1use std::{
2 collections::HashMap,
3 sync::{Arc, Mutex},
4 time::Instant,
5};
6
7use sha2::{Digest, Sha256};
8use tokio_stream::wrappers::ReceiverStream;
9use tonic::{
10 metadata::MetadataValue,
11 service::interceptor::InterceptedService,
12 transport::Server,
13 Request,
14 Response,
15 Status,
16};
17
18use crate::{
19 config::VectorConfig,
20 engine::VectorEngine,
21 grpc::proto::{
22 embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
23 vector_service_server::{VectorService, VectorServiceServer},
24 CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
25 DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse, ListCollectionsResponse,
26 ListRequest, ModelInfoRequest, ModelInfoResponse, SearchMetricsProto, SearchRequest,
27 SearchResponseProto, StatsRequest, UpsertResult, UpsertVectorRequest,
28 },
29 types::{DistanceMetric, SearchQuery},
30 VectorError,
31};
32
/// Request metadata key the interceptor reads to pick the caller's workspace.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Request metadata key carrying the plaintext API key; only consulted when auth is required.
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Response metadata key under which handlers echo the per-request trace id.
const TRACE_HEADER: &str = "x-trace-id";
36
/// Request extension holding the workspace id resolved by the interceptor.
///
/// Handlers read it back through `VectorServiceImpl::workspace_from_request`.
#[derive(Clone)]
struct WorkspaceId(String);
39
/// Request extension holding the trace id generated by the interceptor
/// (formatted as `trace-<uuid v4>`).
#[derive(Clone)]
struct TraceId(String);
42
43pub struct EmbeddingServiceImpl;
45
46#[tonic::async_trait]
47impl EmbeddingService for EmbeddingServiceImpl {
48 async fn embed(
49 &self,
50 _request: Request<EmbedRequest>,
51 ) -> Result<Response<EmbedResponse>, Status> {
52 Err(Status::unimplemented(
53 "Embed is handled by the Python embedding service",
54 ))
55 }
56
57 async fn health(
58 &self,
59 _request: Request<HealthRequest>,
60 ) -> Result<Response<HealthResponse>, Status> {
61 Ok(Response::new(HealthResponse {
62 ready: false,
63 model_name: String::new(),
64 model_load_time_ms: 0,
65 }))
66 }
67
68 async fn model_info(
69 &self,
70 _request: Request<ModelInfoRequest>,
71 ) -> Result<Response<ModelInfoResponse>, Status> {
72 Err(Status::unimplemented(
73 "ModelInfo is handled by the Python embedding service",
74 ))
75 }
76
77 type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;
78
79 async fn embed_stream(
80 &self,
81 _request: Request<tonic::Streaming<EmbedRequest>>,
82 ) -> Result<Response<Self::EmbedStreamStream>, Status> {
83 Err(Status::unimplemented(
84 "EmbedStream is handled by the Python embedding service",
85 ))
86 }
87}
88
/// Shared, mostly read-only state consulted by the interceptor on every request.
#[derive(Clone)]
struct ServerState {
    /// Workspace used when a request carries no workspace header.
    default_workspace_id: String,
    /// When `true`, requests must present an API key that maps to their workspace.
    require_auth: bool,
    /// Requests-per-second limit for workspaces without an explicit entry.
    default_rate_limit_rps: u32,
    /// Hex-encoded SHA-256 of each accepted API key -> workspace id it belongs to.
    api_keys: Arc<HashMap<String, String>>,
    /// Per-workspace requests-per-second overrides, keyed by workspace id.
    workspace_rate_limits: Arc<HashMap<String, u32>>,
    /// Token buckets keyed by workspace id; entries are created on first use.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}
98
/// tonic interceptor that resolves the workspace, optionally checks the API key,
/// applies the per-workspace rate limit and attaches a trace id to the request.
#[derive(Clone)]
struct AuthRateTraceInterceptor {
    /// State shared between all clones of the interceptor.
    state: Arc<ServerState>,
}
103
104impl tonic::service::Interceptor for AuthRateTraceInterceptor {
105 fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
106 let workspace_id = request
107 .metadata()
108 .get(WORKSPACE_HEADER)
109 .and_then(|value| value.to_str().ok())
110 .map(ToOwned::to_owned)
111 .unwrap_or_else(|| self.state.default_workspace_id.clone());
112
113 if self.state.require_auth {
114 let api_key = request
115 .metadata()
116 .get(API_KEY_HEADER)
117 .and_then(|value| value.to_str().ok())
118 .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
119
120 let hashed = hash_api_key(api_key);
121 let valid_workspace = self
122 .state
123 .api_keys
124 .get(&hashed)
125 .map(|ws| ws == &workspace_id)
126 .unwrap_or(false);
127 if !valid_workspace {
128 return Err(Status::unauthenticated("invalid API key"));
129 }
130 }
131
132 let rate_limit = self
133 .state
134 .workspace_rate_limits
135 .get(&workspace_id)
136 .copied()
137 .unwrap_or(self.state.default_rate_limit_rps)
138 .max(1);
139 {
140 let mut buckets = self
141 .state
142 .buckets
143 .lock()
144 .map_err(|_| Status::internal("rate limiter lock poisoned"))?;
145 let bucket = buckets
146 .entry(workspace_id.clone())
147 .or_insert_with(|| TokenBucket::new(rate_limit));
148 bucket.rate_limit_rps = rate_limit;
149 if !bucket.try_consume(1.0) {
150 return Err(Status::resource_exhausted("rate limit exceeded"));
151 }
152 }
153
154 request.extensions_mut().insert(WorkspaceId(workspace_id));
155 request
156 .extensions_mut()
157 .insert(TraceId(format!("trace-{}", uuid::Uuid::new_v4())));
158 Ok(request)
159 }
160}
161
/// Token bucket whose capacity and refill rate are both `rate_limit_rps` tokens
/// (a stored rate of `0` is treated as `1`).
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional because refill is proportional to elapsed time.
    tokens: f64,
    /// Instant of the most recent refill computation.
    last_refill: Instant,
    /// Tokens added per second, which is also the burst capacity.
    rate_limit_rps: u32,
}

impl TokenBucket {
    /// Creates a full bucket. The rate is stored as given; the floor of `1` is applied
    /// wherever the rate is used.
    fn new(rate_limit_rps: u32) -> Self {
        Self {
            tokens: f64::from(rate_limit_rps.max(1)),
            last_refill: Instant::now(),
            rate_limit_rps,
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then tries to
    /// spend `cost` tokens.
    ///
    /// Returns `true` and deducts `cost` when enough tokens are available; otherwise
    /// returns `false` and leaves the (refilled) balance untouched.
    fn try_consume(&mut self, cost: f64) -> bool {
        let capacity = f64::from(self.rate_limit_rps.max(1));

        let now = Instant::now();
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        self.last_refill = now;
        // Refill never exceeds capacity, so idle periods cannot bank an unbounded burst.
        self.tokens = (self.tokens + elapsed_secs * capacity).min(capacity);

        let affordable = self.tokens >= cost;
        if affordable {
            self.tokens -= cost;
        }
        affordable
    }
}
193
194pub struct VectorServiceImpl {
196 engine: Arc<VectorEngine>,
197}
198
199impl VectorServiceImpl {
200 fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
201 request
202 .extensions()
203 .get::<WorkspaceId>()
204 .map(|value| value.0.clone())
205 .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
206 }
207
208 fn trace_from_request<T>(
209 &self,
210 request: &Request<T>,
211 ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
212 request
213 .extensions()
214 .get::<TraceId>()
215 .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
216 }
217}
218
219#[tonic::async_trait]
220impl VectorService for VectorServiceImpl {
221 async fn create_collection(
222 &self,
223 request: Request<CreateCollectionRequest>,
224 ) -> Result<Response<CollectionInfo>, Status> {
225 let trace = self.trace_from_request(&request);
226 let workspace_id = self.workspace_from_request(&request);
227 let req = request.into_inner();
228 let metric = parse_distance_metric(&req.distance_metric)?;
229 let created = self
230 .engine
231 .create_collection_in_workspace(&workspace_id, &req.name, req.dimensions as usize, metric)
232 .await
233 .map_err(Status::from)?;
234 let info = collection_to_proto(&created);
235 let mut response = Response::new(info);
236 if let Some(trace_id) = trace {
237 response.metadata_mut().insert(TRACE_HEADER, trace_id);
238 }
239 Ok(response)
240 }
241
242 async fn delete_collection(
243 &self,
244 request: Request<DeleteCollectionRequest>,
245 ) -> Result<Response<DeleteResult>, Status> {
246 let trace = self.trace_from_request(&request);
247 let workspace_id = self.workspace_from_request(&request);
248 let req = request.into_inner();
249 let stats = self
250 .engine
251 .collections
252 .store
253 .collection_stats(&workspace_id, &req.name)
254 .await
255 .ok();
256 self.engine
257 .delete_collection_in_workspace(&workspace_id, &req.name)
258 .await
259 .map_err(Status::from)?;
260 let mut response = Response::new(DeleteResult {
261 records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
262 });
263 if let Some(trace_id) = trace {
264 response.metadata_mut().insert(TRACE_HEADER, trace_id);
265 }
266 Ok(response)
267 }
268
269 async fn upsert_vector(
270 &self,
271 request: Request<UpsertVectorRequest>,
272 ) -> Result<Response<UpsertResult>, Status> {
273 let trace = self.trace_from_request(&request);
274 let workspace_id = self.workspace_from_request(&request);
275 let req = request.into_inner();
276 let metadata = if req.metadata_json.trim().is_empty() {
277 serde_json::json!({})
278 } else {
279 serde_json::from_str(&req.metadata_json).map_err(|err| {
280 Status::invalid_argument(format!("invalid metadata_json: {err}"))
281 })?
282 };
283
284 let id = if !req.vector.is_empty() {
285 self.engine
286 .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
287 .await
288 .map_err(Status::from)?
289 } else if !req.text.trim().is_empty() {
290 self.engine
291 .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
292 .await
293 .map_err(Status::from)?
294 } else {
295 return Err(Status::invalid_argument(
296 "either vector or text must be provided",
297 ));
298 };
299
300 let mut response = Response::new(UpsertResult { id: id.to_string() });
301 if let Some(trace_id) = trace {
302 response.metadata_mut().insert(TRACE_HEADER, trace_id);
303 }
304 Ok(response)
305 }
306
307 async fn search_vectors(
308 &self,
309 request: Request<SearchRequest>,
310 ) -> Result<Response<SearchResponseProto>, Status> {
311 let trace = self.trace_from_request(&request);
312 let workspace_id = self.workspace_from_request(&request);
313 let req = request.into_inner();
314 let filter = if req.filter_json.trim().is_empty() {
315 None
316 } else {
317 Some(serde_json::from_str(&req.filter_json).map_err(|err| {
318 Status::invalid_argument(format!("invalid filter_json: {err}"))
319 })?)
320 };
321
322 let query = if !req.vector.is_empty() {
323 SearchQuery {
324 collection: req.collection,
325 vector: req.vector,
326 top_k: req.top_k.max(1) as usize,
327 filter,
328 include_vectors: req.include_vectors,
329 include_metadata: req.include_metadata,
330 ef_search: None,
331 reranker: None,
332 }
333 } else if !req.text.trim().is_empty() {
334 let embedded = self
335 .engine
336 .embedding_client
337 .embed_one(&req.text)
338 .await
339 .map_err(Status::from)?;
340 SearchQuery {
341 collection: req.collection,
342 vector: embedded,
343 top_k: req.top_k.max(1) as usize,
344 filter,
345 include_vectors: req.include_vectors,
346 include_metadata: req.include_metadata,
347 ef_search: None,
348 reranker: None,
349 }
350 } else {
351 return Err(Status::invalid_argument(
352 "either vector or text must be provided",
353 ));
354 };
355
356 let response = self
357 .engine
358 .search_in_workspace(&workspace_id, query)
359 .await
360 .map_err(Status::from)?;
361
362 let proto = SearchResponseProto {
363 results: response
364 .results
365 .into_iter()
366 .map(|result| crate::grpc::proto::SearchHit {
367 id: result.id.to_string(),
368 score: result.score,
369 vector: result.vector.unwrap_or_default(),
370 metadata_json: if result.metadata.is_null() {
371 "{}".to_string()
372 } else {
373 serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
374 },
375 text: result.text.unwrap_or_default(),
376 })
377 .collect(),
378 metrics: Some(SearchMetricsProto {
379 query_vector_dims: response.metrics.query_vector_dims as u32,
380 candidates_evaluated: response.metrics.candidates_evaluated as u32,
381 post_filter_count: response.metrics.post_filter_count as u32,
382 latency_us: response.metrics.latency_us,
383 }),
384 };
385
386 let mut response = Response::new(proto);
387 if let Some(trace_id) = trace {
388 response.metadata_mut().insert(TRACE_HEADER, trace_id);
389 }
390 Ok(response)
391 }
392
393 async fn get_collection_stats(
394 &self,
395 request: Request<StatsRequest>,
396 ) -> Result<Response<CollectionStatsResponse>, Status> {
397 let trace = self.trace_from_request(&request);
398 let workspace_id = self.workspace_from_request(&request);
399 let req = request.into_inner();
400 let collection = self
401 .engine
402 .collections
403 .get_collection(&workspace_id, &req.collection)
404 .await
405 .map_err(Status::from)?;
406
407 let response = CollectionStatsResponse {
408 vector_count: collection.vector_count,
409 index_type: format!("{:?}", collection.index_type).to_lowercase(),
410 dimensions: collection.dimensions as u32,
411 last_modified_at: collection.created_at.timestamp_millis(),
412 };
413
414 let mut response = Response::new(response);
415 if let Some(trace_id) = trace {
416 response.metadata_mut().insert(TRACE_HEADER, trace_id);
417 }
418 Ok(response)
419 }
420
421 async fn list_collections(
422 &self,
423 request: Request<ListRequest>,
424 ) -> Result<Response<ListCollectionsResponse>, Status> {
425 let trace = self.trace_from_request(&request);
426 let workspace_id = self.workspace_from_request(&request);
427 let req = request.into_inner();
428 let page_size = req.page_size.clamp(1, 500) as usize;
429 let page = req.page.max(1) as usize;
430
431 let collections = self
432 .engine
433 .list_collections_in_workspace(&workspace_id)
434 .await
435 .map_err(Status::from)?;
436 let total = collections.len();
437 let start = page_size.saturating_mul(page - 1);
438 let page_items = collections
439 .into_iter()
440 .skip(start)
441 .take(page_size)
442 .map(|collection| collection_to_proto(&collection))
443 .collect::<Vec<_>>();
444
445 let mut response = Response::new(ListCollectionsResponse {
446 collections: page_items,
447 page: page as u32,
448 page_size: page_size as u32,
449 total: total as u32,
450 });
451 if let Some(trace_id) = trace {
452 response.metadata_mut().insert(TRACE_HEADER, trace_id);
453 }
454 Ok(response)
455 }
456}
457
458fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
459 CollectionInfo {
460 id: format!("{}:{}", collection.workspace_id, collection.name),
461 name: collection.name.clone(),
462 dimensions: collection.dimensions as u32,
463 distance_metric: format!("{:?}", collection.distance).to_lowercase(),
464 index_type: format!("{:?}", collection.index_type).to_lowercase(),
465 vector_count: collection.vector_count,
466 last_modified_at: collection.created_at.timestamp_millis(),
467 }
468}
469
470#[allow(clippy::result_large_err)]
471fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
472 match raw.trim().to_ascii_lowercase().as_str() {
473 "cosine" => Ok(DistanceMetric::Cosine),
474 "euclidean" => Ok(DistanceMetric::Euclidean),
475 "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
476 _ => Err(Status::invalid_argument("unsupported distance metric")),
477 }
478}
479
480fn hash_api_key(api_key: &str) -> String {
481 let mut hasher = Sha256::new();
482 hasher.update(api_key.as_bytes());
483 hex::encode(hasher.finalize())
484}
485
486async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
487 let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
488 sqlx::query(
489 "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
490 )
491 .execute(store.pool())
492 .await?;
493
494 let key_rows = sqlx::query_as::<_, (String, String)>(
495 "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
496 )
497 .fetch_all(store.pool())
498 .await?;
499 let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();
500
501 let mut workspace_rate_limits = HashMap::new();
502 let has_rate_table = sqlx::query_scalar::<_, i64>(
503 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='workspace_rate_limits'",
504 )
505 .fetch_one(store.pool())
506 .await
507 .unwrap_or(0)
508 > 0;
509 if has_rate_table {
510 let rows = sqlx::query_as::<_, (String, i64)>(
511 "SELECT workspace_id, rate_limit_rps FROM workspace_rate_limits",
512 )
513 .fetch_all(store.pool())
514 .await
515 .unwrap_or_default();
516 for (workspace_id, rps) in rows {
517 workspace_rate_limits.insert(workspace_id, (rps as u32).max(1));
518 }
519 }
520
521 Ok(Arc::new(ServerState {
522 default_workspace_id: config.default_workspace_id.clone(),
523 require_auth: config.require_auth,
524 default_rate_limit_rps: config.rate_limit_rps.max(1),
525 api_keys: Arc::new(api_keys),
526 workspace_rate_limits: Arc::new(workspace_rate_limits),
527 buckets: Arc::new(Mutex::new(HashMap::new())),
528 }))
529}
530
531pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
533 let config = VectorConfig::from_env();
534 let state = load_server_state(&config).await?;
535 let engine = Arc::new(VectorEngine::new(config.clone()).await?);
536
537 let interceptor = AuthRateTraceInterceptor { state };
538 let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
539 let vector = VectorServiceServer::new(VectorServiceImpl { engine });
540
541 Server::builder()
542 .add_service(InterceptedService::new(embedding, interceptor.clone()))
543 .add_service(InterceptedService::new(vector, interceptor))
544 .serve(addr)
545 .await?;
546 Ok(())
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Builds an interceptor whose only registered key is `"valid-key"`, owned by
    /// workspace `"ws-test"`, with no per-workspace rate overrides.
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let mut api_keys = HashMap::new();
        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
        AuthRateTraceInterceptor {
            state: Arc::new(ServerState {
                default_workspace_id: "default".to_string(),
                require_auth,
                default_rate_limit_rps: rate_limit,
                api_keys: Arc::new(api_keys),
                workspace_rate_limits: Arc::new(HashMap::new()),
                buckets: Arc::new(Mutex::new(HashMap::new())),
            }),
        }
    }

    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert(WORKSPACE_HEADER, MetadataValue::try_from("ws-test").unwrap());
        request
            .metadata_mut()
            .insert(API_KEY_HEADER, MetadataValue::try_from("wrong-key").unwrap());

        let result = interceptor.call(request);
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    #[test]
    fn interceptor_accepts_valid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert(WORKSPACE_HEADER, MetadataValue::try_from("ws-test").unwrap());
        request
            .metadata_mut()
            .insert(API_KEY_HEADER, MetadataValue::try_from("valid-key").unwrap());

        assert!(interceptor.call(request).is_ok());
    }

    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        // At 1 rps the bucket holds a single token and needs a full second to refill,
        // so the second call is rejected even on a slow or heavily loaded runner.
        let mut interceptor = interceptor_for_test(false, 1);
        let mut outcomes = Vec::new();
        for _ in 0..2 {
            let mut request = Request::new(());
            request
                .metadata_mut()
                .insert(WORKSPACE_HEADER, MetadataValue::try_from("ws-test").unwrap());
            outcomes.push(interceptor.call(request));
        }

        assert!(outcomes[0].is_ok());
        assert!(matches!(
            &outcomes[1],
            Err(status) if status.code() == tonic::Code::ResourceExhausted
        ));
    }
}