Skip to main content

claw_vector/grpc/
server.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, Mutex},
4    time::Instant,
5};
6
7use sha2::{Digest, Sha256};
8use tokio_stream::wrappers::ReceiverStream;
9use tonic::{
10    metadata::MetadataValue,
11    service::interceptor::InterceptedService,
12    transport::Server,
13    Request,
14    Response,
15    Status,
16};
17
18use crate::{
19    config::VectorConfig,
20    engine::VectorEngine,
21    grpc::proto::{
22        embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
23        vector_service_server::{VectorService, VectorServiceServer},
24        CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
25        DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse, ListCollectionsResponse,
26        ListRequest, ModelInfoRequest, ModelInfoResponse, SearchMetricsProto, SearchRequest,
27        SearchResponseProto, StatsRequest, UpsertResult, UpsertVectorRequest,
28    },
29    types::{DistanceMetric, SearchQuery},
30    VectorError,
31};
32
/// Request metadata header naming the workspace a call targets.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Request metadata header carrying the caller's plaintext API key (hashed before lookup).
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Response metadata header used to echo the per-request trace id back to the caller.
const TRACE_HEADER: &str = "x-trace-id";

/// Workspace resolved by the interceptor, stored in request extensions for the handlers.
#[derive(Clone)]
struct WorkspaceId(String);

/// Per-request trace id generated by the interceptor, stored in request extensions.
#[derive(Clone)]
struct TraceId(String);
42
43/// Minimal pass-through embedding service stub for local Rust gRPC server mode.
44pub struct EmbeddingServiceImpl;
45
46#[tonic::async_trait]
47impl EmbeddingService for EmbeddingServiceImpl {
48    async fn embed(
49        &self,
50        _request: Request<EmbedRequest>,
51    ) -> Result<Response<EmbedResponse>, Status> {
52        Err(Status::unimplemented(
53            "Embed is handled by the Python embedding service",
54        ))
55    }
56
57    async fn health(
58        &self,
59        _request: Request<HealthRequest>,
60    ) -> Result<Response<HealthResponse>, Status> {
61        Ok(Response::new(HealthResponse {
62            ready: false,
63            model_name: String::new(),
64            model_load_time_ms: 0,
65        }))
66    }
67
68    async fn model_info(
69        &self,
70        _request: Request<ModelInfoRequest>,
71    ) -> Result<Response<ModelInfoResponse>, Status> {
72        Err(Status::unimplemented(
73            "ModelInfo is handled by the Python embedding service",
74        ))
75    }
76
77    type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;
78
79    async fn embed_stream(
80        &self,
81        _request: Request<tonic::Streaming<EmbedRequest>>,
82    ) -> Result<Response<Self::EmbedStreamStream>, Status> {
83        Err(Status::unimplemented(
84            "EmbedStream is handled by the Python embedding service",
85        ))
86    }
87}
88
/// State shared by every clone of [`AuthRateTraceInterceptor`]; built by `load_server_state`.
#[derive(Clone)]
struct ServerState {
    /// Workspace used when a request carries no readable `x-claw-workspace-id` header.
    default_workspace_id: String,
    /// When true, every request must present an API key registered to its workspace.
    require_auth: bool,
    /// Requests-per-second limit for workspaces without an entry in `workspace_rate_limits`.
    default_rate_limit_rps: u32,
    /// Hex SHA-256 hash of each non-revoked API key -> workspace id the key belongs to.
    api_keys: Arc<HashMap<String, String>>,
    /// Per-workspace overrides of the requests-per-second limit.
    workspace_rate_limits: Arc<HashMap<String, u32>>,
    /// Live token bucket per workspace id, created lazily on that workspace's first request.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}
98
99#[derive(Clone)]
100struct AuthRateTraceInterceptor {
101    state: Arc<ServerState>,
102}
103
104impl tonic::service::Interceptor for AuthRateTraceInterceptor {
105    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
106        let workspace_id = request
107            .metadata()
108            .get(WORKSPACE_HEADER)
109            .and_then(|value| value.to_str().ok())
110            .map(ToOwned::to_owned)
111            .unwrap_or_else(|| self.state.default_workspace_id.clone());
112
113        if self.state.require_auth {
114            let api_key = request
115                .metadata()
116                .get(API_KEY_HEADER)
117                .and_then(|value| value.to_str().ok())
118                .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
119
120            let hashed = hash_api_key(api_key);
121            let valid_workspace = self
122                .state
123                .api_keys
124                .get(&hashed)
125                .map(|ws| ws == &workspace_id)
126                .unwrap_or(false);
127            if !valid_workspace {
128                return Err(Status::unauthenticated("invalid API key"));
129            }
130        }
131
132        let rate_limit = self
133            .state
134            .workspace_rate_limits
135            .get(&workspace_id)
136            .copied()
137            .unwrap_or(self.state.default_rate_limit_rps)
138            .max(1);
139        {
140            let mut buckets = self
141                .state
142                .buckets
143                .lock()
144                .map_err(|_| Status::internal("rate limiter lock poisoned"))?;
145            let bucket = buckets
146                .entry(workspace_id.clone())
147                .or_insert_with(|| TokenBucket::new(rate_limit));
148            bucket.rate_limit_rps = rate_limit;
149            if !bucket.try_consume(1.0) {
150                return Err(Status::resource_exhausted("rate limit exceeded"));
151            }
152        }
153
154        request.extensions_mut().insert(WorkspaceId(workspace_id));
155        request
156            .extensions_mut()
157            .insert(TraceId(format!("trace-{}", uuid::Uuid::new_v4())));
158        Ok(request)
159    }
160}
161
/// Token-bucket limiter.
///
/// The bucket holds at most `rate_limit_rps` tokens and refills continuously at
/// `rate_limit_rps` tokens per second. A configured rate of 0 is treated as 1 for both
/// capacity and refill.
#[derive(Debug)]
struct TokenBucket {
    tokens: f64,
    last_refill: Instant,
    rate_limit_rps: u32,
}

impl TokenBucket {
    /// Create a bucket that starts full.
    fn new(rate_limit_rps: u32) -> Self {
        let capacity = f64::from(rate_limit_rps.max(1));
        Self {
            rate_limit_rps,
            last_refill: Instant::now(),
            tokens: capacity,
        }
    }

    /// Maximum token count, which is also the refill rate per second.
    fn capacity(&self) -> f64 {
        f64::from(self.rate_limit_rps.max(1))
    }

    /// Refill for the time elapsed since the last call, then try to spend `cost` tokens.
    ///
    /// Returns `true` and deducts `cost` if enough tokens are available. Otherwise returns
    /// `false` and leaves the refilled balance untouched.
    fn try_consume(&mut self, cost: f64) -> bool {
        let now = Instant::now();
        let capacity = self.capacity();
        let since_last = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + since_last * capacity).min(capacity);
        self.last_refill = now;

        let allowed = self.tokens >= cost;
        if allowed {
            self.tokens -= cost;
        }
        allowed
    }
}
193
194/// VectorService gRPC implementation backed by [`VectorEngine`].
195pub struct VectorServiceImpl {
196    engine: Arc<VectorEngine>,
197}
198
199impl VectorServiceImpl {
200    fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
201        request
202            .extensions()
203            .get::<WorkspaceId>()
204            .map(|value| value.0.clone())
205            .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
206    }
207
208    fn trace_from_request<T>(
209        &self,
210        request: &Request<T>,
211    ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
212        request
213            .extensions()
214            .get::<TraceId>()
215            .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
216    }
217}
218
219#[tonic::async_trait]
220impl VectorService for VectorServiceImpl {
221    async fn create_collection(
222        &self,
223        request: Request<CreateCollectionRequest>,
224    ) -> Result<Response<CollectionInfo>, Status> {
225        let trace = self.trace_from_request(&request);
226        let workspace_id = self.workspace_from_request(&request);
227        let req = request.into_inner();
228        let metric = parse_distance_metric(&req.distance_metric)?;
229        let created = self
230            .engine
231            .create_collection_in_workspace(&workspace_id, &req.name, req.dimensions as usize, metric)
232            .await
233            .map_err(Status::from)?;
234        let info = collection_to_proto(&created);
235        let mut response = Response::new(info);
236        if let Some(trace_id) = trace {
237            response.metadata_mut().insert(TRACE_HEADER, trace_id);
238        }
239        Ok(response)
240    }
241
242    async fn delete_collection(
243        &self,
244        request: Request<DeleteCollectionRequest>,
245    ) -> Result<Response<DeleteResult>, Status> {
246        let trace = self.trace_from_request(&request);
247        let workspace_id = self.workspace_from_request(&request);
248        let req = request.into_inner();
249        let stats = self
250            .engine
251            .collections
252            .store
253            .collection_stats(&workspace_id, &req.name)
254            .await
255            .ok();
256        self.engine
257            .delete_collection_in_workspace(&workspace_id, &req.name)
258            .await
259            .map_err(Status::from)?;
260        let mut response = Response::new(DeleteResult {
261            records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
262        });
263        if let Some(trace_id) = trace {
264            response.metadata_mut().insert(TRACE_HEADER, trace_id);
265        }
266        Ok(response)
267    }
268
269    async fn upsert_vector(
270        &self,
271        request: Request<UpsertVectorRequest>,
272    ) -> Result<Response<UpsertResult>, Status> {
273        let trace = self.trace_from_request(&request);
274        let workspace_id = self.workspace_from_request(&request);
275        let req = request.into_inner();
276        let metadata = if req.metadata_json.trim().is_empty() {
277            serde_json::json!({})
278        } else {
279            serde_json::from_str(&req.metadata_json).map_err(|err| {
280                Status::invalid_argument(format!("invalid metadata_json: {err}"))
281            })?
282        };
283
284        let id = if !req.vector.is_empty() {
285            self.engine
286                .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
287                .await
288                .map_err(Status::from)?
289        } else if !req.text.trim().is_empty() {
290            self.engine
291                .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
292                .await
293                .map_err(Status::from)?
294        } else {
295            return Err(Status::invalid_argument(
296                "either vector or text must be provided",
297            ));
298        };
299
300        let mut response = Response::new(UpsertResult { id: id.to_string() });
301        if let Some(trace_id) = trace {
302            response.metadata_mut().insert(TRACE_HEADER, trace_id);
303        }
304        Ok(response)
305    }
306
307    async fn search_vectors(
308        &self,
309        request: Request<SearchRequest>,
310    ) -> Result<Response<SearchResponseProto>, Status> {
311        let trace = self.trace_from_request(&request);
312        let workspace_id = self.workspace_from_request(&request);
313        let req = request.into_inner();
314        let filter = if req.filter_json.trim().is_empty() {
315            None
316        } else {
317            Some(serde_json::from_str(&req.filter_json).map_err(|err| {
318                Status::invalid_argument(format!("invalid filter_json: {err}"))
319            })?)
320        };
321
322        let query = if !req.vector.is_empty() {
323            SearchQuery {
324                collection: req.collection,
325                vector: req.vector,
326                top_k: req.top_k.max(1) as usize,
327                filter,
328                include_vectors: req.include_vectors,
329                include_metadata: req.include_metadata,
330                ef_search: None,
331                reranker: None,
332            }
333        } else if !req.text.trim().is_empty() {
334            let embedded = self
335                .engine
336                .embedding_client
337                .embed_one(&req.text)
338                .await
339                .map_err(Status::from)?;
340            SearchQuery {
341                collection: req.collection,
342                vector: embedded,
343                top_k: req.top_k.max(1) as usize,
344                filter,
345                include_vectors: req.include_vectors,
346                include_metadata: req.include_metadata,
347                ef_search: None,
348                reranker: None,
349            }
350        } else {
351            return Err(Status::invalid_argument(
352                "either vector or text must be provided",
353            ));
354        };
355
356        let response = self
357            .engine
358            .search_in_workspace(&workspace_id, query)
359            .await
360            .map_err(Status::from)?;
361
362        let proto = SearchResponseProto {
363            results: response
364                .results
365                .into_iter()
366                .map(|result| crate::grpc::proto::SearchHit {
367                    id: result.id.to_string(),
368                    score: result.score,
369                    vector: result.vector.unwrap_or_default(),
370                    metadata_json: if result.metadata.is_null() {
371                        "{}".to_string()
372                    } else {
373                        serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
374                    },
375                    text: result.text.unwrap_or_default(),
376                })
377                .collect(),
378            metrics: Some(SearchMetricsProto {
379                query_vector_dims: response.metrics.query_vector_dims as u32,
380                candidates_evaluated: response.metrics.candidates_evaluated as u32,
381                post_filter_count: response.metrics.post_filter_count as u32,
382                latency_us: response.metrics.latency_us,
383            }),
384        };
385
386        let mut response = Response::new(proto);
387        if let Some(trace_id) = trace {
388            response.metadata_mut().insert(TRACE_HEADER, trace_id);
389        }
390        Ok(response)
391    }
392
393    async fn get_collection_stats(
394        &self,
395        request: Request<StatsRequest>,
396    ) -> Result<Response<CollectionStatsResponse>, Status> {
397        let trace = self.trace_from_request(&request);
398        let workspace_id = self.workspace_from_request(&request);
399        let req = request.into_inner();
400        let collection = self
401            .engine
402            .collections
403            .get_collection(&workspace_id, &req.collection)
404            .await
405            .map_err(Status::from)?;
406
407        let response = CollectionStatsResponse {
408            vector_count: collection.vector_count,
409            index_type: format!("{:?}", collection.index_type).to_lowercase(),
410            dimensions: collection.dimensions as u32,
411            last_modified_at: collection.created_at.timestamp_millis(),
412        };
413
414        let mut response = Response::new(response);
415        if let Some(trace_id) = trace {
416            response.metadata_mut().insert(TRACE_HEADER, trace_id);
417        }
418        Ok(response)
419    }
420
421    async fn list_collections(
422        &self,
423        request: Request<ListRequest>,
424    ) -> Result<Response<ListCollectionsResponse>, Status> {
425        let trace = self.trace_from_request(&request);
426        let workspace_id = self.workspace_from_request(&request);
427        let req = request.into_inner();
428        let page_size = req.page_size.clamp(1, 500) as usize;
429        let page = req.page.max(1) as usize;
430
431        let collections = self
432            .engine
433            .list_collections_in_workspace(&workspace_id)
434            .await
435            .map_err(Status::from)?;
436        let total = collections.len();
437        let start = page_size.saturating_mul(page - 1);
438        let page_items = collections
439            .into_iter()
440            .skip(start)
441            .take(page_size)
442            .map(|collection| collection_to_proto(&collection))
443            .collect::<Vec<_>>();
444
445        let mut response = Response::new(ListCollectionsResponse {
446            collections: page_items,
447            page: page as u32,
448            page_size: page_size as u32,
449            total: total as u32,
450        });
451        if let Some(trace_id) = trace {
452            response.metadata_mut().insert(TRACE_HEADER, trace_id);
453        }
454        Ok(response)
455    }
456}
457
458fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
459    CollectionInfo {
460        id: format!("{}:{}", collection.workspace_id, collection.name),
461        name: collection.name.clone(),
462        dimensions: collection.dimensions as u32,
463        distance_metric: format!("{:?}", collection.distance).to_lowercase(),
464        index_type: format!("{:?}", collection.index_type).to_lowercase(),
465        vector_count: collection.vector_count,
466        last_modified_at: collection.created_at.timestamp_millis(),
467    }
468}
469
470#[allow(clippy::result_large_err)]
471fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
472    match raw.trim().to_ascii_lowercase().as_str() {
473        "cosine" => Ok(DistanceMetric::Cosine),
474        "euclidean" => Ok(DistanceMetric::Euclidean),
475        "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
476        _ => Err(Status::invalid_argument("unsupported distance metric")),
477    }
478}
479
480fn hash_api_key(api_key: &str) -> String {
481    let mut hasher = Sha256::new();
482    hasher.update(api_key.as_bytes());
483    hex::encode(hasher.finalize())
484}
485
486async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
487    let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
488    sqlx::query(
489        "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
490    )
491    .execute(store.pool())
492    .await?;
493
494    let key_rows = sqlx::query_as::<_, (String, String)>(
495        "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
496    )
497    .fetch_all(store.pool())
498    .await?;
499    let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();
500
501    let mut workspace_rate_limits = HashMap::new();
502    let has_rate_table = sqlx::query_scalar::<_, i64>(
503        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='workspace_rate_limits'",
504    )
505    .fetch_one(store.pool())
506    .await
507    .unwrap_or(0)
508        > 0;
509    if has_rate_table {
510        let rows = sqlx::query_as::<_, (String, i64)>(
511            "SELECT workspace_id, rate_limit_rps FROM workspace_rate_limits",
512        )
513        .fetch_all(store.pool())
514        .await
515        .unwrap_or_default();
516        for (workspace_id, rps) in rows {
517            workspace_rate_limits.insert(workspace_id, (rps as u32).max(1));
518        }
519    }
520
521    Ok(Arc::new(ServerState {
522        default_workspace_id: config.default_workspace_id.clone(),
523        require_auth: config.require_auth,
524        default_rate_limit_rps: config.rate_limit_rps.max(1),
525        api_keys: Arc::new(api_keys),
526        workspace_rate_limits: Arc::new(workspace_rate_limits),
527        buckets: Arc::new(Mutex::new(HashMap::new())),
528    }))
529}
530
531/// Start the Rust gRPC server with auth, rate limiting, and trace-id interception.
532pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
533    let config = VectorConfig::from_env();
534    let state = load_server_state(&config).await?;
535    let engine = Arc::new(VectorEngine::new(config.clone()).await?);
536
537    let interceptor = AuthRateTraceInterceptor { state };
538    let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
539    let vector = VectorServiceServer::new(VectorServiceImpl { engine });
540
541    Server::builder()
542        .add_service(InterceptedService::new(embedding, interceptor.clone()))
543        .add_service(InterceptedService::new(vector, interceptor))
544        .serve(addr)
545        .await?;
546    Ok(())
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Interceptor with one known key (`valid-key` -> `ws-test`), the given default rate
    /// limit, and no per-workspace overrides.
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let mut api_keys = HashMap::new();
        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
        AuthRateTraceInterceptor {
            state: Arc::new(ServerState {
                default_workspace_id: "default".to_string(),
                require_auth,
                default_rate_limit_rps: rate_limit,
                api_keys: Arc::new(api_keys),
                workspace_rate_limits: Arc::new(HashMap::new()),
                buckets: Arc::new(Mutex::new(HashMap::new())),
            }),
        }
    }

    /// Request with a workspace header and, optionally, an API key header.
    fn request_for(workspace: &str, api_key: Option<&str>) -> Request<()> {
        let mut request = Request::new(());
        request
            .metadata_mut()
            .insert(WORKSPACE_HEADER, MetadataValue::try_from(workspace).unwrap());
        if let Some(key) = api_key {
            request
                .metadata_mut()
                .insert(API_KEY_HEADER, MetadataValue::try_from(key).unwrap());
        }
        request
    }

    /// Assert that the interceptor rejected the request with `expected`.
    fn assert_rejected(result: Result<Request<()>, Status>, expected: tonic::Code) {
        match result {
            Err(status) => assert_eq!(status.code(), expected),
            Ok(_) => panic!("expected {expected:?}, but the request was accepted"),
        }
    }

    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let result = interceptor.call(request_for("ws-test", Some("wrong-key")));
        assert_rejected(result, tonic::Code::Unauthenticated);
    }

    #[test]
    fn interceptor_rejects_missing_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let result = interceptor.call(request_for("ws-test", None));
        assert_rejected(result, tonic::Code::Unauthenticated);
    }

    #[test]
    fn interceptor_rejects_key_used_for_another_workspace() {
        let mut interceptor = interceptor_for_test(true, 100);
        let result = interceptor.call(request_for("ws-other", Some("valid-key")));
        assert_rejected(result, tonic::Code::Unauthenticated);
    }

    #[test]
    fn interceptor_accepts_valid_key_and_tags_request() {
        let mut interceptor = interceptor_for_test(true, 100);
        let request = interceptor
            .call(request_for("ws-test", Some("valid-key")))
            .expect("a key used for its own workspace must be accepted");

        let workspace = request.extensions().get::<WorkspaceId>().map(|ws| ws.0.as_str());
        assert_eq!(workspace, Some("ws-test"));
        let trace = request.extensions().get::<TraceId>().map(|t| t.0.as_str());
        assert!(matches!(trace, Some(id) if id.starts_with("trace-")));
    }

    #[test]
    fn interceptor_defaults_workspace_when_header_missing() {
        let mut interceptor = interceptor_for_test(false, 100);
        let request = interceptor
            .call(Request::new(()))
            .expect("auth is disabled, so the request must pass");
        let workspace = request.extensions().get::<WorkspaceId>().map(|ws| ws.0.as_str());
        assert_eq!(workspace, Some("default"));
    }

    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        // A 1 rps bucket starts with a single token and needs a full second to earn
        // another, so the back-to-back second call is rejected on any realistic machine.
        let mut interceptor = interceptor_for_test(false, 1);
        assert!(interceptor.call(request_for("ws-test", None)).is_ok());
        let second = interceptor.call(request_for("ws-test", None));
        assert_rejected(second, tonic::Code::ResourceExhausted);
    }

    #[test]
    fn rate_limits_are_tracked_per_workspace() {
        let mut interceptor = interceptor_for_test(false, 1);
        assert!(interceptor.call(request_for("ws-a", None)).is_ok());
        assert_rejected(
            interceptor.call(request_for("ws-a", None)),
            tonic::Code::ResourceExhausted,
        );
        assert!(interceptor.call(request_for("ws-b", None)).is_ok());
    }
}