1use std::{collections::HashMap, num::NonZeroU32, sync::Arc};
2
3use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
4use tokio_stream::wrappers::ReceiverStream;
5use tonic::{
6 metadata::MetadataValue, service::interceptor::InterceptedService, transport::Server, Request,
7 Response, Status,
8};
9
10use crate::{
11 config::VectorConfig,
12 engine::VectorEngine,
13 grpc::proto::{
14 embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
15 vector_service_server::{VectorService, VectorServiceServer},
16 CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
17 DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse,
18 ListCollectionsResponse, ListRequest, ModelInfoRequest, ModelInfoResponse,
19 SearchMetricsProto, SearchRequest, SearchResponseProto, StatsRequest, UpsertResult,
20 UpsertVectorRequest,
21 },
22 types::{DistanceMetric, SearchQuery},
23 VectorError,
24};
25
/// Metadata key from which the interceptor reads the caller's workspace id.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Metadata key from which the interceptor reads the API key (only consulted when auth is required).
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Metadata key for the request/trace id; read from requests and echoed on successful responses.
const REQUEST_ID_HEADER: &str = "x-request-id";
29
/// Request extension holding the workspace id resolved by the interceptor.
#[derive(Clone)]
struct WorkspaceId(String);
32
/// Request extension holding the request id (client-supplied or freshly generated by the interceptor).
#[derive(Clone)]
struct TraceId(String);
35
36pub struct EmbeddingServiceImpl;
38
39#[tonic::async_trait]
40impl EmbeddingService for EmbeddingServiceImpl {
41 async fn embed(
42 &self,
43 _request: Request<EmbedRequest>,
44 ) -> Result<Response<EmbedResponse>, Status> {
45 Err(Status::unimplemented(
46 "Embed is handled by the Python embedding service",
47 ))
48 }
49
50 async fn health(
51 &self,
52 _request: Request<HealthRequest>,
53 ) -> Result<Response<HealthResponse>, Status> {
54 Ok(Response::new(HealthResponse {
55 ready: false,
56 model_name: String::new(),
57 model_load_time_ms: 0,
58 }))
59 }
60
61 async fn model_info(
62 &self,
63 _request: Request<ModelInfoRequest>,
64 ) -> Result<Response<ModelInfoResponse>, Status> {
65 Err(Status::unimplemented(
66 "ModelInfo is handled by the Python embedding service",
67 ))
68 }
69
70 type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;
71
72 async fn embed_stream(
73 &self,
74 _request: Request<tonic::Streaming<EmbedRequest>>,
75 ) -> Result<Response<Self::EmbedStreamStream>, Status> {
76 Err(Status::unimplemented(
77 "EmbedStream is handled by the Python embedding service",
78 ))
79 }
80}
81
/// Shared, read-only state consulted by the interceptor on every request.
#[derive(Clone)]
struct ServerState {
    // Workspace used when the workspace header is absent and not required.
    default_workspace_id: String,
    // When true, a request without a readable workspace header is rejected.
    require_workspace_id: bool,
    // When true, requests must carry an API key bound to their workspace.
    require_auth: bool,
    // Maps hashed API key (see `hash_api_key`) -> workspace id the key belongs to.
    api_keys: Arc<HashMap<String, String>>,
    // Keyed rate limiter; the interceptor checks it with the workspace id as key.
    limiter: Arc<RateLimiter<String, DashMapStateStore<String>, DefaultClock>>,
}
90
91#[derive(Clone)]
92struct AuthRateTraceInterceptor {
93 state: Arc<ServerState>,
94}
95
96impl tonic::service::Interceptor for AuthRateTraceInterceptor {
97 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
98 let workspace_id = request
99 .metadata()
100 .get(WORKSPACE_HEADER)
101 .and_then(|value| value.to_str().ok())
102 .map(ToOwned::to_owned)
103 .or_else(|| {
104 if self.state.require_workspace_id {
105 None
106 } else {
107 Some(self.state.default_workspace_id.clone())
108 }
109 })
110 .ok_or_else(|| Status::invalid_argument("missing x-claw-workspace-id"))?;
111
112 if self.state.require_auth {
113 let api_key = request
114 .metadata()
115 .get(API_KEY_HEADER)
116 .and_then(|value| value.to_str().ok())
117 .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
118
119 let hashed = hash_api_key(api_key);
120 let valid_workspace = self
121 .state
122 .api_keys
123 .get(&hashed)
124 .map(|ws| ws == &workspace_id)
125 .unwrap_or(false);
126 if !valid_workspace {
127 return Err(Status::unauthenticated("invalid API key"));
128 }
129 }
130
131 if self.state.limiter.check_key(&workspace_id).is_err() {
132 return Err(Status::resource_exhausted("rate limit exceeded"));
133 }
134
135 let request_id = request
136 .metadata()
137 .get(REQUEST_ID_HEADER)
138 .and_then(|value| value.to_str().ok())
139 .map(ToOwned::to_owned)
140 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
141
142 request.extensions_mut().insert(WorkspaceId(workspace_id));
143 request.extensions_mut().insert(TraceId(request_id));
144 Ok(request)
145 }
146}
147
148pub struct VectorServiceImpl {
150 engine: Arc<VectorEngine>,
151}
152
153impl VectorServiceImpl {
154 fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
155 request
156 .extensions()
157 .get::<WorkspaceId>()
158 .map(|value| value.0.clone())
159 .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
160 }
161
162 fn trace_from_request<T>(
163 &self,
164 request: &Request<T>,
165 ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
166 request
167 .extensions()
168 .get::<TraceId>()
169 .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
170 }
171}
172
173#[tonic::async_trait]
174impl VectorService for VectorServiceImpl {
175 async fn create_collection(
176 &self,
177 request: Request<CreateCollectionRequest>,
178 ) -> Result<Response<CollectionInfo>, Status> {
179 let trace = self.trace_from_request(&request);
180 let workspace_id = self.workspace_from_request(&request);
181 let req = request.into_inner();
182 let metric = parse_distance_metric(&req.distance_metric)?;
183 let created = self
184 .engine
185 .create_collection_in_workspace(
186 &workspace_id,
187 &req.name,
188 req.dimensions as usize,
189 metric,
190 )
191 .await
192 .map_err(Status::from)?;
193 let info = collection_to_proto(&created);
194 let mut response = Response::new(info);
195 if let Some(trace_id) = trace {
196 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
197 }
198 Ok(response)
199 }
200
201 async fn delete_collection(
202 &self,
203 request: Request<DeleteCollectionRequest>,
204 ) -> Result<Response<DeleteResult>, Status> {
205 let trace = self.trace_from_request(&request);
206 let workspace_id = self.workspace_from_request(&request);
207 let req = request.into_inner();
208 let stats = self
209 .engine
210 .collections
211 .store
212 .collection_stats(&workspace_id, &req.name)
213 .await
214 .ok();
215 self.engine
216 .delete_collection_in_workspace(&workspace_id, &req.name)
217 .await
218 .map_err(Status::from)?;
219 let mut response = Response::new(DeleteResult {
220 records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
221 });
222 if let Some(trace_id) = trace {
223 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
224 }
225 Ok(response)
226 }
227
228 async fn upsert_vector(
229 &self,
230 request: Request<UpsertVectorRequest>,
231 ) -> Result<Response<UpsertResult>, Status> {
232 let trace = self.trace_from_request(&request);
233 let workspace_id = self.workspace_from_request(&request);
234 let req = request.into_inner();
235 let metadata = if req.metadata_json.trim().is_empty() {
236 serde_json::json!({})
237 } else {
238 serde_json::from_str(&req.metadata_json)
239 .map_err(|err| Status::invalid_argument(format!("invalid metadata_json: {err}")))?
240 };
241
242 let id = if !req.vector.is_empty() {
243 self.engine
244 .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
245 .await
246 .map_err(Status::from)?
247 } else if !req.text.trim().is_empty() {
248 self.engine
249 .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
250 .await
251 .map_err(Status::from)?
252 } else {
253 return Err(Status::invalid_argument(
254 "either vector or text must be provided",
255 ));
256 };
257
258 let mut response = Response::new(UpsertResult { id: id.to_string() });
259 if let Some(trace_id) = trace {
260 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
261 }
262 Ok(response)
263 }
264
265 async fn search_vectors(
266 &self,
267 request: Request<SearchRequest>,
268 ) -> Result<Response<SearchResponseProto>, Status> {
269 let trace = self.trace_from_request(&request);
270 let workspace_id = self.workspace_from_request(&request);
271 let req = request.into_inner();
272 let filter =
273 if req.filter_json.trim().is_empty() {
274 None
275 } else {
276 Some(serde_json::from_str(&req.filter_json).map_err(|err| {
277 Status::invalid_argument(format!("invalid filter_json: {err}"))
278 })?)
279 };
280
281 let query = if !req.vector.is_empty() {
282 SearchQuery {
283 collection: req.collection,
284 vector: req.vector,
285 top_k: req.top_k.max(1) as usize,
286 filter,
287 include_vectors: req.include_vectors,
288 include_metadata: req.include_metadata,
289 ef_search: None,
290 reranker: None,
291 }
292 } else if !req.text.trim().is_empty() {
293 let embedded = self
294 .engine
295 .embedding_client
296 .embed_one(&req.text)
297 .await
298 .map_err(Status::from)?;
299 SearchQuery {
300 collection: req.collection,
301 vector: embedded,
302 top_k: req.top_k.max(1) as usize,
303 filter,
304 include_vectors: req.include_vectors,
305 include_metadata: req.include_metadata,
306 ef_search: None,
307 reranker: None,
308 }
309 } else {
310 return Err(Status::invalid_argument(
311 "either vector or text must be provided",
312 ));
313 };
314
315 let response = self
316 .engine
317 .search_in_workspace(&workspace_id, query)
318 .await
319 .map_err(Status::from)?;
320
321 let proto = SearchResponseProto {
322 results: response
323 .results
324 .into_iter()
325 .map(|result| crate::grpc::proto::SearchHit {
326 id: result.id.to_string(),
327 score: result.score,
328 vector: result.vector.unwrap_or_default(),
329 metadata_json: if result.metadata.is_null() {
330 "{}".to_string()
331 } else {
332 serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
333 },
334 text: result.text.unwrap_or_default(),
335 })
336 .collect(),
337 metrics: Some(SearchMetricsProto {
338 query_vector_dims: response.metrics.query_vector_dims as u32,
339 candidates_evaluated: response.metrics.candidates_evaluated as u32,
340 post_filter_count: response.metrics.post_filter_count as u32,
341 latency_us: response.metrics.latency_us,
342 }),
343 };
344
345 let mut response = Response::new(proto);
346 if let Some(trace_id) = trace {
347 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
348 }
349 Ok(response)
350 }
351
352 async fn get_collection_stats(
353 &self,
354 request: Request<StatsRequest>,
355 ) -> Result<Response<CollectionStatsResponse>, Status> {
356 let trace = self.trace_from_request(&request);
357 let workspace_id = self.workspace_from_request(&request);
358 let req = request.into_inner();
359 let collection = self
360 .engine
361 .collections
362 .get_collection(&workspace_id, &req.collection)
363 .await
364 .map_err(Status::from)?;
365
366 let response = CollectionStatsResponse {
367 vector_count: collection.vector_count,
368 index_type: format!("{:?}", collection.index_type).to_lowercase(),
369 dimensions: collection.dimensions as u32,
370 last_modified_at: collection.created_at.timestamp_millis(),
371 };
372
373 let mut response = Response::new(response);
374 if let Some(trace_id) = trace {
375 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
376 }
377 Ok(response)
378 }
379
380 async fn list_collections(
381 &self,
382 request: Request<ListRequest>,
383 ) -> Result<Response<ListCollectionsResponse>, Status> {
384 let trace = self.trace_from_request(&request);
385 let workspace_id = self.workspace_from_request(&request);
386 let req = request.into_inner();
387 let page_size = req.page_size.clamp(1, 500) as usize;
388 let page = req.page.max(1) as usize;
389
390 let collections = self
391 .engine
392 .list_collections_in_workspace(&workspace_id)
393 .await
394 .map_err(Status::from)?;
395 let total = collections.len();
396 let start = page_size.saturating_mul(page - 1);
397 let page_items = collections
398 .into_iter()
399 .skip(start)
400 .take(page_size)
401 .map(|collection| collection_to_proto(&collection))
402 .collect::<Vec<_>>();
403
404 let mut response = Response::new(ListCollectionsResponse {
405 collections: page_items,
406 page: page as u32,
407 page_size: page_size as u32,
408 total: total as u32,
409 });
410 if let Some(trace_id) = trace {
411 response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
412 }
413 Ok(response)
414 }
415}
416
417fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
418 CollectionInfo {
419 id: format!("{}:{}", collection.workspace_id, collection.name),
420 name: collection.name.clone(),
421 dimensions: collection.dimensions as u32,
422 distance_metric: format!("{:?}", collection.distance).to_lowercase(),
423 index_type: format!("{:?}", collection.index_type).to_lowercase(),
424 vector_count: collection.vector_count,
425 last_modified_at: collection.created_at.timestamp_millis(),
426 }
427}
428
429#[allow(clippy::result_large_err)]
430fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
431 match raw.trim().to_ascii_lowercase().as_str() {
432 "cosine" => Ok(DistanceMetric::Cosine),
433 "euclidean" => Ok(DistanceMetric::Euclidean),
434 "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
435 _ => Err(Status::invalid_argument("unsupported distance metric")),
436 }
437}
438
439fn hash_api_key(api_key: &str) -> String {
440 blake3::hash(api_key.as_bytes()).to_hex().to_string()
441}
442
443async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
444 let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
445 sqlx::query(
446 "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
447 )
448 .execute(store.pool())
449 .await?;
450
451 let key_rows = sqlx::query_as::<_, (String, String)>(
452 "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
453 )
454 .fetch_all(store.pool())
455 .await?;
456 let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();
457
458 let rate_limit =
459 NonZeroU32::new(config.rate_limit_rps.max(1)).unwrap_or(NonZeroU32::new(1).unwrap());
460 let limiter = RateLimiter::keyed(Quota::per_second(rate_limit));
461
462 Ok(Arc::new(ServerState {
463 default_workspace_id: config.default_workspace_id.clone(),
464 require_workspace_id: config.require_workspace_id,
465 require_auth: config.require_auth,
466 api_keys: Arc::new(api_keys),
467 limiter: Arc::new(limiter),
468 }))
469}
470
471pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
473 let config = VectorConfig::from_env();
474 let state = load_server_state(&config).await?;
475 let engine = Arc::new(VectorEngine::new(config.clone()).await?);
476
477 let interceptor = AuthRateTraceInterceptor { state };
478 let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
479 let vector = VectorServiceServer::new(VectorServiceImpl { engine });
480
481 Server::builder()
482 .add_service(InterceptedService::new(embedding, interceptor.clone()))
483 .add_service(InterceptedService::new(vector, interceptor))
484 .serve(addr)
485 .await?;
486 Ok(())
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Builds an interceptor whose only known key is `valid-key`, bound to
    /// workspace `ws-test`, with the given auth requirement and per-second rate.
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let api_keys: HashMap<String, String> =
            std::iter::once((hash_api_key("valid-key"), "ws-test".to_string())).collect();
        let quota = Quota::per_second(NonZeroU32::new(rate_limit.max(1)).unwrap());
        let state = ServerState {
            default_workspace_id: "default".to_string(),
            require_workspace_id: false,
            require_auth,
            api_keys: Arc::new(api_keys),
            limiter: Arc::new(RateLimiter::keyed(quota)),
        };
        AuthRateTraceInterceptor {
            state: Arc::new(state),
        }
    }

    /// Creates an empty request carrying the given metadata entries.
    fn request_with_metadata(entries: &[(&'static str, &str)]) -> Request<()> {
        let mut request = Request::new(());
        for &(key, value) in entries {
            request
                .metadata_mut()
                .insert(key, MetadataValue::try_from(value).unwrap());
        }
        request
    }

    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let request = request_with_metadata(&[
            (WORKSPACE_HEADER, "ws-test"),
            (API_KEY_HEADER, "wrong-key"),
        ]);

        let outcome = interceptor.call(request);
        assert!(matches!(outcome, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        let mut interceptor = interceptor_for_test(false, 100);
        // 101 back-to-back calls against a 100/s quota: the final one must be throttled.
        let last = (0..101)
            .map(|_| interceptor.call(request_with_metadata(&[(WORKSPACE_HEADER, "ws-test")])))
            .last();

        assert!(
            matches!(last, Some(Err(status)) if status.code() == tonic::Code::ResourceExhausted)
        );
    }
}
542}