1use std::{
2 collections::HashMap,
3 sync::{Arc, Mutex},
4 time::Instant,
5};
6
7use sha2::{Digest, Sha256};
8use tokio_stream::wrappers::ReceiverStream;
9use tonic::{
10 metadata::MetadataValue, service::interceptor::InterceptedService, transport::Server, Request,
11 Response, Status,
12};
13
14use crate::{
15 config::VectorConfig,
16 engine::VectorEngine,
17 grpc::proto::{
18 embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
19 vector_service_server::{VectorService, VectorServiceServer},
20 CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
21 DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse,
22 ListCollectionsResponse, ListRequest, ModelInfoRequest, ModelInfoResponse,
23 SearchMetricsProto, SearchRequest, SearchResponseProto, StatsRequest, UpsertResult,
24 UpsertVectorRequest,
25 },
26 types::{DistanceMetric, SearchQuery},
27 VectorError,
28};
29
/// Metadata header that selects the tenant workspace for a request.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Metadata header carrying the caller's API key (checked when auth is on).
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Metadata header used to echo the request's trace id on responses.
const TRACE_HEADER: &str = "x-trace-id";

/// Request-extension newtype holding the resolved workspace id.
#[derive(Clone)]
struct WorkspaceId(String);

/// Request-extension newtype holding the per-request trace id.
#[derive(Clone)]
struct TraceId(String);
39
40pub struct EmbeddingServiceImpl;
42
43#[tonic::async_trait]
44impl EmbeddingService for EmbeddingServiceImpl {
45 async fn embed(
46 &self,
47 _request: Request<EmbedRequest>,
48 ) -> Result<Response<EmbedResponse>, Status> {
49 Err(Status::unimplemented(
50 "Embed is handled by the Python embedding service",
51 ))
52 }
53
54 async fn health(
55 &self,
56 _request: Request<HealthRequest>,
57 ) -> Result<Response<HealthResponse>, Status> {
58 Ok(Response::new(HealthResponse {
59 ready: false,
60 model_name: String::new(),
61 model_load_time_ms: 0,
62 }))
63 }
64
65 async fn model_info(
66 &self,
67 _request: Request<ModelInfoRequest>,
68 ) -> Result<Response<ModelInfoResponse>, Status> {
69 Err(Status::unimplemented(
70 "ModelInfo is handled by the Python embedding service",
71 ))
72 }
73
74 type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;
75
76 async fn embed_stream(
77 &self,
78 _request: Request<tonic::Streaming<EmbedRequest>>,
79 ) -> Result<Response<Self::EmbedStreamStream>, Status> {
80 Err(Status::unimplemented(
81 "EmbedStream is handled by the Python embedding service",
82 ))
83 }
84}
85
/// Immutable per-process configuration shared by every interceptor clone.
#[derive(Clone)]
struct ServerState {
    // Workspace used when a request carries no x-claw-workspace-id header.
    default_workspace_id: String,
    // When true, requests must present an API key mapped to their workspace.
    require_auth: bool,
    // Fallback requests-per-second limit for workspaces without an override.
    default_rate_limit_rps: u32,
    // SHA-256(api key) hex digest -> workspace id the key is valid for.
    api_keys: Arc<HashMap<String, String>>,
    // Per-workspace rate-limit overrides loaded at startup.
    workspace_rate_limits: Arc<HashMap<String, u32>>,
    // Live token buckets keyed by workspace id; mutated under the mutex.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}
95
/// tonic interceptor that authenticates, rate-limits, and tags every
/// request with a workspace id and trace id (via request extensions).
#[derive(Clone)]
struct AuthRateTraceInterceptor {
    state: Arc<ServerState>,
}
100
101impl tonic::service::Interceptor for AuthRateTraceInterceptor {
102 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
103 let workspace_id = request
104 .metadata()
105 .get(WORKSPACE_HEADER)
106 .and_then(|value| value.to_str().ok())
107 .map(ToOwned::to_owned)
108 .unwrap_or_else(|| self.state.default_workspace_id.clone());
109
110 if self.state.require_auth {
111 let api_key = request
112 .metadata()
113 .get(API_KEY_HEADER)
114 .and_then(|value| value.to_str().ok())
115 .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
116
117 let hashed = hash_api_key(api_key);
118 let valid_workspace = self
119 .state
120 .api_keys
121 .get(&hashed)
122 .map(|ws| ws == &workspace_id)
123 .unwrap_or(false);
124 if !valid_workspace {
125 return Err(Status::unauthenticated("invalid API key"));
126 }
127 }
128
129 let rate_limit = self
130 .state
131 .workspace_rate_limits
132 .get(&workspace_id)
133 .copied()
134 .unwrap_or(self.state.default_rate_limit_rps)
135 .max(1);
136 {
137 let mut buckets = self
138 .state
139 .buckets
140 .lock()
141 .map_err(|_| Status::internal("rate limiter lock poisoned"))?;
142 let bucket = buckets
143 .entry(workspace_id.clone())
144 .or_insert_with(|| TokenBucket::new(rate_limit));
145 bucket.rate_limit_rps = rate_limit;
146 if !bucket.try_consume(1.0) {
147 return Err(Status::resource_exhausted("rate limit exceeded"));
148 }
149 }
150
151 request.extensions_mut().insert(WorkspaceId(workspace_id));
152 request
153 .extensions_mut()
154 .insert(TraceId(format!("trace-{}", uuid::Uuid::new_v4())));
155 Ok(request)
156 }
157}
158
/// Simple token-bucket rate limiter: capacity equals the per-second rate,
/// and tokens refill continuously based on wall-clock time.
#[derive(Debug)]
struct TokenBucket {
    // Currently available tokens (fractional while refilling).
    tokens: f64,
    // Timestamp of the most recent refill calculation.
    last_refill: Instant,
    // Configured requests-per-second; a value of 0 is treated as 1.
    rate_limit_rps: u32,
}

impl TokenBucket {
    /// Creates a bucket that starts full (rate tokens, minimum 1).
    fn new(rate_limit_rps: u32) -> Self {
        Self {
            tokens: f64::from(rate_limit_rps.max(1)),
            last_refill: Instant::now(),
            rate_limit_rps,
        }
    }

    /// Refills proportionally to elapsed time (capped at the rate), then
    /// attempts to spend `cost` tokens; returns whether that succeeded.
    fn try_consume(&mut self, cost: f64) -> bool {
        let now = Instant::now();
        let rate = f64::from(self.rate_limit_rps.max(1));
        let refilled = now.duration_since(self.last_refill).as_secs_f64() * rate;
        self.tokens = rate.min(self.tokens + refilled);
        self.last_refill = now;
        let allowed = self.tokens >= cost;
        if allowed {
            self.tokens -= cost;
        }
        allowed
    }
}
190
/// gRPC front-end for vector operations; all real work is delegated to
/// the shared `VectorEngine`.
pub struct VectorServiceImpl {
    engine: Arc<VectorEngine>,
}
195
196impl VectorServiceImpl {
197 fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
198 request
199 .extensions()
200 .get::<WorkspaceId>()
201 .map(|value| value.0.clone())
202 .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
203 }
204
205 fn trace_from_request<T>(
206 &self,
207 request: &Request<T>,
208 ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
209 request
210 .extensions()
211 .get::<TraceId>()
212 .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
213 }
214}
215
#[tonic::async_trait]
impl VectorService for VectorServiceImpl {
    /// Creates a collection in the caller's workspace and returns its info.
    /// Echoes the request trace id on the response metadata when present.
    async fn create_collection(
        &self,
        request: Request<CreateCollectionRequest>,
    ) -> Result<Response<CollectionInfo>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let metric = parse_distance_metric(&req.distance_metric)?;
        let created = self
            .engine
            .create_collection_in_workspace(
                &workspace_id,
                &req.name,
                req.dimensions as usize,
                metric,
            )
            .await
            .map_err(Status::from)?;
        let info = collection_to_proto(&created);
        let mut response = Response::new(info);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Deletes a collection, reporting how many vectors were removed based
    /// on stats captured just before deletion (0 when stats are unavailable).
    async fn delete_collection(
        &self,
        request: Request<DeleteCollectionRequest>,
    ) -> Result<Response<DeleteResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Snapshot the vector count first: the stats are gone once the
        // collection is deleted. A failed stats read is deliberately
        // non-fatal (best-effort count).
        let stats = self
            .engine
            .collections
            .store
            .collection_stats(&workspace_id, &req.name)
            .await
            .ok();
        self.engine
            .delete_collection_in_workspace(&workspace_id, &req.name)
            .await
            .map_err(Status::from)?;
        let mut response = Response::new(DeleteResult {
            records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Upserts one record. Accepts either a raw vector (used as-is) or raw
    /// text (embedded server-side); rejects requests providing neither.
    async fn upsert_vector(
        &self,
        request: Request<UpsertVectorRequest>,
    ) -> Result<Response<UpsertResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Empty/blank metadata_json is treated as an empty JSON object.
        let metadata = if req.metadata_json.trim().is_empty() {
            serde_json::json!({})
        } else {
            serde_json::from_str(&req.metadata_json)
                .map_err(|err| Status::invalid_argument(format!("invalid metadata_json: {err}")))?
        };

        // A non-empty vector takes precedence over text when both are set.
        let id = if !req.vector.is_empty() {
            self.engine
                .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
                .await
                .map_err(Status::from)?
        } else if !req.text.trim().is_empty() {
            self.engine
                .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
                .await
                .map_err(Status::from)?
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let mut response = Response::new(UpsertResult { id: id.to_string() });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Searches a collection by vector, or by text (embedded first via the
    /// embedding client). `top_k` is clamped to at least 1; an optional
    /// JSON filter narrows the candidates.
    async fn search_vectors(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponseProto>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Blank filter_json means "no filter"; malformed JSON is a client error.
        let filter =
            if req.filter_json.trim().is_empty() {
                None
            } else {
                Some(serde_json::from_str(&req.filter_json).map_err(|err| {
                    Status::invalid_argument(format!("invalid filter_json: {err}"))
                })?)
            };

        // As in upsert, an explicit vector wins over text when both are set.
        let query = if !req.vector.is_empty() {
            SearchQuery {
                collection: req.collection,
                vector: req.vector,
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else if !req.text.trim().is_empty() {
            let embedded = self
                .engine
                .embedding_client
                .embed_one(&req.text)
                .await
                .map_err(Status::from)?;
            SearchQuery {
                collection: req.collection,
                vector: embedded,
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let response = self
            .engine
            .search_in_workspace(&workspace_id, query)
            .await
            .map_err(Status::from)?;

        // Map engine hits to the wire type; null metadata becomes "{}" and
        // a metadata serialization failure also degrades to "{}".
        let proto = SearchResponseProto {
            results: response
                .results
                .into_iter()
                .map(|result| crate::grpc::proto::SearchHit {
                    id: result.id.to_string(),
                    score: result.score,
                    vector: result.vector.unwrap_or_default(),
                    metadata_json: if result.metadata.is_null() {
                        "{}".to_string()
                    } else {
                        serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
                    },
                    text: result.text.unwrap_or_default(),
                })
                .collect(),
            metrics: Some(SearchMetricsProto {
                query_vector_dims: response.metrics.query_vector_dims as u32,
                candidates_evaluated: response.metrics.candidates_evaluated as u32,
                post_filter_count: response.metrics.post_filter_count as u32,
                latency_us: response.metrics.latency_us,
            }),
        };

        let mut response = Response::new(proto);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Returns basic stats for one collection in the caller's workspace.
    async fn get_collection_stats(
        &self,
        request: Request<StatsRequest>,
    ) -> Result<Response<CollectionStatsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let collection = self
            .engine
            .collections
            .get_collection(&workspace_id, &req.collection)
            .await
            .map_err(Status::from)?;

        let response = CollectionStatsResponse {
            vector_count: collection.vector_count,
            index_type: format!("{:?}", collection.index_type).to_lowercase(),
            dimensions: collection.dimensions as u32,
            // NOTE(review): populated from created_at — the engine does not
            // appear to track a modification timestamp; confirm intent.
            last_modified_at: collection.created_at.timestamp_millis(),
        };

        let mut response = Response::new(response);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Lists collections in the caller's workspace with 1-based pagination;
    /// page_size is clamped to [1, 500] and page to at least 1.
    async fn list_collections(
        &self,
        request: Request<ListRequest>,
    ) -> Result<Response<ListCollectionsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let page_size = req.page_size.clamp(1, 500) as usize;
        let page = req.page.max(1) as usize;

        let collections = self
            .engine
            .list_collections_in_workspace(&workspace_id)
            .await
            .map_err(Status::from)?;
        let total = collections.len();
        // Pagination is applied in memory over the full listing.
        let start = page_size.saturating_mul(page - 1);
        let page_items = collections
            .into_iter()
            .skip(start)
            .take(page_size)
            .map(|collection| collection_to_proto(&collection))
            .collect::<Vec<_>>();

        let mut response = Response::new(ListCollectionsResponse {
            collections: page_items,
            page: page as u32,
            page_size: page_size as u32,
            total: total as u32,
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }
}
459
460fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
461 CollectionInfo {
462 id: format!("{}:{}", collection.workspace_id, collection.name),
463 name: collection.name.clone(),
464 dimensions: collection.dimensions as u32,
465 distance_metric: format!("{:?}", collection.distance).to_lowercase(),
466 index_type: format!("{:?}", collection.index_type).to_lowercase(),
467 vector_count: collection.vector_count,
468 last_modified_at: collection.created_at.timestamp_millis(),
469 }
470}
471
472#[allow(clippy::result_large_err)]
473fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
474 match raw.trim().to_ascii_lowercase().as_str() {
475 "cosine" => Ok(DistanceMetric::Cosine),
476 "euclidean" => Ok(DistanceMetric::Euclidean),
477 "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
478 _ => Err(Status::invalid_argument("unsupported distance metric")),
479 }
480}
481
482fn hash_api_key(api_key: &str) -> String {
483 let mut hasher = Sha256::new();
484 hasher.update(api_key.as_bytes());
485 hex::encode(hasher.finalize())
486}
487
488async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
489 let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
490 sqlx::query(
491 "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
492 )
493 .execute(store.pool())
494 .await?;
495
496 let key_rows = sqlx::query_as::<_, (String, String)>(
497 "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
498 )
499 .fetch_all(store.pool())
500 .await?;
501 let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();
502
503 let mut workspace_rate_limits = HashMap::new();
504 let has_rate_table = sqlx::query_scalar::<_, i64>(
505 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='workspace_rate_limits'",
506 )
507 .fetch_one(store.pool())
508 .await
509 .unwrap_or(0)
510 > 0;
511 if has_rate_table {
512 let rows = sqlx::query_as::<_, (String, i64)>(
513 "SELECT workspace_id, rate_limit_rps FROM workspace_rate_limits",
514 )
515 .fetch_all(store.pool())
516 .await
517 .unwrap_or_default();
518 for (workspace_id, rps) in rows {
519 workspace_rate_limits.insert(workspace_id, (rps as u32).max(1));
520 }
521 }
522
523 Ok(Arc::new(ServerState {
524 default_workspace_id: config.default_workspace_id.clone(),
525 require_auth: config.require_auth,
526 default_rate_limit_rps: config.rate_limit_rps.max(1),
527 api_keys: Arc::new(api_keys),
528 workspace_rate_limits: Arc::new(workspace_rate_limits),
529 buckets: Arc::new(Mutex::new(HashMap::new())),
530 }))
531}
532
533pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
535 let config = VectorConfig::from_env();
536 let state = load_server_state(&config).await?;
537 let engine = Arc::new(VectorEngine::new(config.clone()).await?);
538
539 let interceptor = AuthRateTraceInterceptor { state };
540 let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
541 let vector = VectorServiceServer::new(VectorServiceImpl { engine });
542
543 Server::builder()
544 .add_service(InterceptedService::new(embedding, interceptor.clone()))
545 .add_service(InterceptedService::new(vector, interceptor))
546 .serve(addr)
547 .await?;
548 Ok(())
549}
550
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Builds an interceptor whose state knows exactly one key
    /// ("valid-key") bound to workspace "ws-test".
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let mut api_keys = HashMap::new();
        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
        let state = ServerState {
            default_workspace_id: "default".to_string(),
            require_auth,
            default_rate_limit_rps: rate_limit,
            api_keys: Arc::new(api_keys),
            workspace_rate_limits: Arc::new(HashMap::new()),
            buckets: Arc::new(Mutex::new(HashMap::new())),
        };
        AuthRateTraceInterceptor {
            state: Arc::new(state),
        }
    }

    /// Builds an empty request carrying only the workspace header.
    fn request_for_workspace(workspace: &str) -> Request<()> {
        let mut request = Request::new(());
        request.metadata_mut().insert(
            WORKSPACE_HEADER,
            MetadataValue::try_from(workspace).unwrap(),
        );
        request
    }

    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let mut request = request_for_workspace("ws-test");
        request.metadata_mut().insert(
            API_KEY_HEADER,
            MetadataValue::try_from("wrong-key").unwrap(),
        );

        let result = interceptor.call(request);
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        let mut interceptor = interceptor_for_test(false, 100);
        // The bucket starts with 100 tokens, so the 101st call inside the
        // same second must be rejected.
        let mut last: Result<Request<()>, Status> = Ok(Request::new(()));
        for _ in 0..101 {
            last = interceptor.call(request_for_workspace("ws-test"));
        }

        assert!(matches!(last, Err(status) if status.code() == tonic::Code::ResourceExhausted));
    }
}
603}