// claw_vector/grpc/server.rs
1use std::{
2    collections::HashMap,
3    sync::{Arc, Mutex},
4    time::Instant,
5};
6
7use sha2::{Digest, Sha256};
8use tokio_stream::wrappers::ReceiverStream;
9use tonic::{
10    metadata::MetadataValue, service::interceptor::InterceptedService, transport::Server, Request,
11    Response, Status,
12};
13
14use crate::{
15    config::VectorConfig,
16    engine::VectorEngine,
17    grpc::proto::{
18        embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
19        vector_service_server::{VectorService, VectorServiceServer},
20        CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
21        DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse,
22        ListCollectionsResponse, ListRequest, ModelInfoRequest, ModelInfoResponse,
23        SearchMetricsProto, SearchRequest, SearchResponseProto, StatsRequest, UpsertResult,
24        UpsertVectorRequest,
25    },
26    types::{DistanceMetric, SearchQuery},
27    VectorError,
28};
29
30const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
31const API_KEY_HEADER: &str = "x-claw-api-key";
32const TRACE_HEADER: &str = "x-trace-id";
33
/// Request-extension newtype carrying the workspace id resolved by the
/// interceptor (header value or configured default).
#[derive(Clone)]
struct WorkspaceId(String);
36
/// Request-extension newtype carrying the trace id generated by the
/// interceptor and echoed back via the x-trace-id response header.
#[derive(Clone)]
struct TraceId(String);
39
/// Minimal pass-through embedding service stub for local Rust gRPC server mode.
///
/// Embedding RPCs are actually served by the Python embedding service; this
/// stub exists only so the Rust server exposes the full gRPC surface.
pub struct EmbeddingServiceImpl;
42
43#[tonic::async_trait]
44impl EmbeddingService for EmbeddingServiceImpl {
45    async fn embed(
46        &self,
47        _request: Request<EmbedRequest>,
48    ) -> Result<Response<EmbedResponse>, Status> {
49        Err(Status::unimplemented(
50            "Embed is handled by the Python embedding service",
51        ))
52    }
53
54    async fn health(
55        &self,
56        _request: Request<HealthRequest>,
57    ) -> Result<Response<HealthResponse>, Status> {
58        Ok(Response::new(HealthResponse {
59            ready: false,
60            model_name: String::new(),
61            model_load_time_ms: 0,
62        }))
63    }
64
65    async fn model_info(
66        &self,
67        _request: Request<ModelInfoRequest>,
68    ) -> Result<Response<ModelInfoResponse>, Status> {
69        Err(Status::unimplemented(
70            "ModelInfo is handled by the Python embedding service",
71        ))
72    }
73
74    type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;
75
76    async fn embed_stream(
77        &self,
78        _request: Request<tonic::Streaming<EmbedRequest>>,
79    ) -> Result<Response<Self::EmbedStreamStream>, Status> {
80        Err(Status::unimplemented(
81            "EmbedStream is handled by the Python embedding service",
82        ))
83    }
84}
85
/// Shared server configuration plus the mutable rate-limiter buckets.
/// Cloned cheaply (Arc fields) into each interceptor instance.
#[derive(Clone)]
struct ServerState {
    // Workspace used when no x-claw-workspace-id header is present.
    default_workspace_id: String,
    // When true, requests must present a valid x-claw-api-key.
    require_auth: bool,
    // Requests-per-second for workspaces without an explicit override.
    default_rate_limit_rps: u32,
    // SHA-256 key hash -> workspace id, loaded once at startup.
    api_keys: Arc<HashMap<String, String>>,
    // Per-workspace rps overrides, loaded once at startup.
    workspace_rate_limits: Arc<HashMap<String, u32>>,
    // workspace id -> token bucket, created lazily on first request.
    buckets: Arc<Mutex<HashMap<String, TokenBucket>>>,
}
95
/// gRPC interceptor performing API-key auth, per-workspace rate limiting,
/// and trace-id injection before requests reach the services.
#[derive(Clone)]
struct AuthRateTraceInterceptor {
    state: Arc<ServerState>,
}
100
101impl tonic::service::Interceptor for AuthRateTraceInterceptor {
102    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
103        let workspace_id = request
104            .metadata()
105            .get(WORKSPACE_HEADER)
106            .and_then(|value| value.to_str().ok())
107            .map(ToOwned::to_owned)
108            .unwrap_or_else(|| self.state.default_workspace_id.clone());
109
110        if self.state.require_auth {
111            let api_key = request
112                .metadata()
113                .get(API_KEY_HEADER)
114                .and_then(|value| value.to_str().ok())
115                .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
116
117            let hashed = hash_api_key(api_key);
118            let valid_workspace = self
119                .state
120                .api_keys
121                .get(&hashed)
122                .map(|ws| ws == &workspace_id)
123                .unwrap_or(false);
124            if !valid_workspace {
125                return Err(Status::unauthenticated("invalid API key"));
126            }
127        }
128
129        let rate_limit = self
130            .state
131            .workspace_rate_limits
132            .get(&workspace_id)
133            .copied()
134            .unwrap_or(self.state.default_rate_limit_rps)
135            .max(1);
136        {
137            let mut buckets = self
138                .state
139                .buckets
140                .lock()
141                .map_err(|_| Status::internal("rate limiter lock poisoned"))?;
142            let bucket = buckets
143                .entry(workspace_id.clone())
144                .or_insert_with(|| TokenBucket::new(rate_limit));
145            bucket.rate_limit_rps = rate_limit;
146            if !bucket.try_consume(1.0) {
147                return Err(Status::resource_exhausted("rate limit exceeded"));
148            }
149        }
150
151        request.extensions_mut().insert(WorkspaceId(workspace_id));
152        request
153            .extensions_mut()
154            .insert(TraceId(format!("trace-{}", uuid::Uuid::new_v4())));
155        Ok(request)
156    }
157}
158
/// Token-bucket rate limiter: refills at `rate_limit_rps` tokens per second
/// with a burst capacity of one second's worth of tokens.
#[derive(Debug)]
struct TokenBucket {
    // Fractional tokens currently available.
    tokens: f64,
    // Timestamp of the most recent refill.
    last_refill: Instant,
    // Refill rate; values below 1 are treated as 1 when consuming.
    rate_limit_rps: u32,
}

impl TokenBucket {
    /// Creates a bucket that starts full (one second of tokens).
    fn new(rate_limit_rps: u32) -> Self {
        let capacity = f64::from(rate_limit_rps.max(1));
        Self {
            tokens: capacity,
            last_refill: Instant::now(),
            rate_limit_rps,
        }
    }

    /// Refills according to elapsed wall time, then attempts to take `cost`
    /// tokens. Returns `true` when the request is within budget.
    fn try_consume(&mut self, cost: f64) -> bool {
        let rate = f64::from(self.rate_limit_rps.max(1));
        let now = Instant::now();
        let refill = now.duration_since(self.last_refill).as_secs_f64() * rate;
        self.last_refill = now;
        // Cap at the burst capacity (one second of tokens).
        self.tokens = (self.tokens + refill).min(rate);
        if self.tokens < cost {
            return false;
        }
        self.tokens -= cost;
        true
    }
}
190
/// VectorService gRPC implementation backed by [`VectorEngine`].
pub struct VectorServiceImpl {
    // Shared engine handle; all handlers dispatch through it.
    engine: Arc<VectorEngine>,
}
195
196impl VectorServiceImpl {
197    fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
198        request
199            .extensions()
200            .get::<WorkspaceId>()
201            .map(|value| value.0.clone())
202            .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
203    }
204
205    fn trace_from_request<T>(
206        &self,
207        request: &Request<T>,
208    ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
209        request
210            .extensions()
211            .get::<TraceId>()
212            .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
213    }
214}
215
#[tonic::async_trait]
impl VectorService for VectorServiceImpl {
    /// Creates a collection in the caller's workspace and returns its info.
    async fn create_collection(
        &self,
        request: Request<CreateCollectionRequest>,
    ) -> Result<Response<CollectionInfo>, Status> {
        // Read extensions before `into_inner` consumes the request.
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let metric = parse_distance_metric(&req.distance_metric)?;
        let created = self
            .engine
            .create_collection_in_workspace(
                &workspace_id,
                &req.name,
                req.dimensions as usize,
                metric,
            )
            .await
            .map_err(Status::from)?;
        let info = collection_to_proto(&created);
        let mut response = Response::new(info);
        // Echo the trace id so callers can correlate request/response.
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Deletes a collection, reporting how many records were removed.
    async fn delete_collection(
        &self,
        request: Request<DeleteCollectionRequest>,
    ) -> Result<Response<DeleteResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Best-effort: fetch the vector count before deleting; if the stats
        // lookup fails, report 0 removed rather than failing the delete.
        let stats = self
            .engine
            .collections
            .store
            .collection_stats(&workspace_id, &req.name)
            .await
            .ok();
        self.engine
            .delete_collection_in_workspace(&workspace_id, &req.name)
            .await
            .map_err(Status::from)?;
        let mut response = Response::new(DeleteResult {
            records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Upserts a record from either a raw vector or text to be embedded.
    /// A non-empty `vector` takes precedence over `text`.
    async fn upsert_vector(
        &self,
        request: Request<UpsertVectorRequest>,
    ) -> Result<Response<UpsertResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Empty/whitespace metadata_json means "no metadata" ({}), not an error.
        let metadata = if req.metadata_json.trim().is_empty() {
            serde_json::json!({})
        } else {
            serde_json::from_str(&req.metadata_json)
                .map_err(|err| Status::invalid_argument(format!("invalid metadata_json: {err}")))?
        };

        let id = if !req.vector.is_empty() {
            self.engine
                .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
                .await
                .map_err(Status::from)?
        } else if !req.text.trim().is_empty() {
            // Text path: the engine embeds the text before storing.
            self.engine
                .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
                .await
                .map_err(Status::from)?
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let mut response = Response::new(UpsertResult { id: id.to_string() });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Searches a collection by raw vector or by text (embedded on the fly).
    /// A non-empty `vector` takes precedence over `text`.
    async fn search_vectors(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponseProto>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Empty filter_json means "no filter"; malformed JSON is a client error.
        let filter =
            if req.filter_json.trim().is_empty() {
                None
            } else {
                Some(serde_json::from_str(&req.filter_json).map_err(|err| {
                    Status::invalid_argument(format!("invalid filter_json: {err}"))
                })?)
            };

        let query = if !req.vector.is_empty() {
            SearchQuery {
                collection: req.collection,
                vector: req.vector,
                // top_k of 0 is treated as 1.
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else if !req.text.trim().is_empty() {
            // Text query: embed via the embedding client first.
            let embedded = self
                .engine
                .embedding_client
                .embed_one(&req.text)
                .await
                .map_err(Status::from)?;
            SearchQuery {
                collection: req.collection,
                vector: embedded,
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let response = self
            .engine
            .search_in_workspace(&workspace_id, query)
            .await
            .map_err(Status::from)?;

        let proto = SearchResponseProto {
            results: response
                .results
                .into_iter()
                .map(|result| crate::grpc::proto::SearchHit {
                    id: result.id.to_string(),
                    score: result.score,
                    vector: result.vector.unwrap_or_default(),
                    // Null metadata serializes as "{}" so clients always get JSON.
                    metadata_json: if result.metadata.is_null() {
                        "{}".to_string()
                    } else {
                        serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
                    },
                    text: result.text.unwrap_or_default(),
                })
                .collect(),
            metrics: Some(SearchMetricsProto {
                query_vector_dims: response.metrics.query_vector_dims as u32,
                candidates_evaluated: response.metrics.candidates_evaluated as u32,
                post_filter_count: response.metrics.post_filter_count as u32,
                latency_us: response.metrics.latency_us,
            }),
        };

        let mut response = Response::new(proto);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Returns basic stats for one collection in the caller's workspace.
    async fn get_collection_stats(
        &self,
        request: Request<StatsRequest>,
    ) -> Result<Response<CollectionStatsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let collection = self
            .engine
            .collections
            .get_collection(&workspace_id, &req.collection)
            .await
            .map_err(Status::from)?;

        let response = CollectionStatsResponse {
            vector_count: collection.vector_count,
            index_type: format!("{:?}", collection.index_type).to_lowercase(),
            dimensions: collection.dimensions as u32,
            // NOTE(review): reports `created_at` as last_modified_at — the
            // engine appears not to track modification time; confirm intended.
            last_modified_at: collection.created_at.timestamp_millis(),
        };

        let mut response = Response::new(response);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Lists collections with 1-based pagination (page_size clamped to 1..=500).
    async fn list_collections(
        &self,
        request: Request<ListRequest>,
    ) -> Result<Response<ListCollectionsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let page_size = req.page_size.clamp(1, 500) as usize;
        // page is clamped to >= 1, so `page - 1` below cannot underflow.
        let page = req.page.max(1) as usize;

        let collections = self
            .engine
            .list_collections_in_workspace(&workspace_id)
            .await
            .map_err(Status::from)?;
        let total = collections.len();
        let start = page_size.saturating_mul(page - 1);
        let page_items = collections
            .into_iter()
            .skip(start)
            .take(page_size)
            .map(|collection| collection_to_proto(&collection))
            .collect::<Vec<_>>();

        let mut response = Response::new(ListCollectionsResponse {
            collections: page_items,
            page: page as u32,
            page_size: page_size as u32,
            total: total as u32,
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(TRACE_HEADER, trace_id);
        }
        Ok(response)
    }
}
459
460fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
461    CollectionInfo {
462        id: format!("{}:{}", collection.workspace_id, collection.name),
463        name: collection.name.clone(),
464        dimensions: collection.dimensions as u32,
465        distance_metric: format!("{:?}", collection.distance).to_lowercase(),
466        index_type: format!("{:?}", collection.index_type).to_lowercase(),
467        vector_count: collection.vector_count,
468        last_modified_at: collection.created_at.timestamp_millis(),
469    }
470}
471
472#[allow(clippy::result_large_err)]
473fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
474    match raw.trim().to_ascii_lowercase().as_str() {
475        "cosine" => Ok(DistanceMetric::Cosine),
476        "euclidean" => Ok(DistanceMetric::Euclidean),
477        "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
478        _ => Err(Status::invalid_argument("unsupported distance metric")),
479    }
480}
481
482fn hash_api_key(api_key: &str) -> String {
483    let mut hasher = Sha256::new();
484    hasher.update(api_key.as_bytes());
485    hex::encode(hasher.finalize())
486}
487
/// Loads authentication and rate-limit state from the API-key SQLite store.
///
/// Ensures the `api_keys` table exists, loads all non-revoked key hashes,
/// and — when the optional `workspace_rate_limits` table exists — loads
/// per-workspace rps overrides. Errors while probing/reading the optional
/// table are treated as "no overrides" rather than startup failures.
async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
    let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
    // Idempotent schema bootstrap for the key store.
    sqlx::query(
        "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
    )
    .execute(store.pool())
    .await?;

    // key_hash -> workspace_id for every active (non-revoked) key.
    let key_rows = sqlx::query_as::<_, (String, String)>(
        "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
    )
    .fetch_all(store.pool())
    .await?;
    let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();

    let mut workspace_rate_limits = HashMap::new();
    // The rate-limit table is optional: probe sqlite_master instead of
    // failing when it is missing; probe errors count as "missing".
    let has_rate_table = sqlx::query_scalar::<_, i64>(
        "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='workspace_rate_limits'",
    )
    .fetch_one(store.pool())
    .await
    .unwrap_or(0)
        > 0;
    if has_rate_table {
        let rows = sqlx::query_as::<_, (String, i64)>(
            "SELECT workspace_id, rate_limit_rps FROM workspace_rate_limits",
        )
        .fetch_all(store.pool())
        .await
        .unwrap_or_default();
        for (workspace_id, rps) in rows {
            // NOTE(review): `rps as u32` wraps for negative or oversized i64
            // values before the .max(1) clamp — confirm stored values are
            // always small non-negative integers.
            workspace_rate_limits.insert(workspace_id, (rps as u32).max(1));
        }
    }

    Ok(Arc::new(ServerState {
        default_workspace_id: config.default_workspace_id.clone(),
        require_auth: config.require_auth,
        // Clamp so the token bucket always admits some traffic.
        default_rate_limit_rps: config.rate_limit_rps.max(1),
        api_keys: Arc::new(api_keys),
        workspace_rate_limits: Arc::new(workspace_rate_limits),
        buckets: Arc::new(Mutex::new(HashMap::new())),
    }))
}
532
533/// Start the Rust gRPC server with auth, rate limiting, and trace-id interception.
534pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
535    let config = VectorConfig::from_env();
536    let state = load_server_state(&config).await?;
537    let engine = Arc::new(VectorEngine::new(config.clone()).await?);
538
539    let interceptor = AuthRateTraceInterceptor { state };
540    let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
541    let vector = VectorServiceServer::new(VectorServiceImpl { engine });
542
543    Server::builder()
544        .add_service(InterceptedService::new(embedding, interceptor.clone()))
545        .add_service(InterceptedService::new(vector, interceptor))
546        .serve(addr)
547        .await?;
548    Ok(())
549}
550
551#[cfg(test)]
552mod tests {
553    use super::*;
554    use tonic::service::Interceptor;
555
556    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
557        let mut api_keys = HashMap::new();
558        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
559        AuthRateTraceInterceptor {
560            state: Arc::new(ServerState {
561                default_workspace_id: "default".to_string(),
562                require_auth,
563                default_rate_limit_rps: rate_limit,
564                api_keys: Arc::new(api_keys),
565                workspace_rate_limits: Arc::new(HashMap::new()),
566                buckets: Arc::new(Mutex::new(HashMap::new())),
567            }),
568        }
569    }
570
    // An API key that does not hash to an entry for the requested workspace
    // must be rejected with `Unauthenticated`, even when the workspace
    // header itself is well-formed.
    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let mut request = Request::new(());
        request.metadata_mut().insert(
            WORKSPACE_HEADER,
            MetadataValue::try_from("ws-test").unwrap(),
        );
        request.metadata_mut().insert(
            API_KEY_HEADER,
            MetadataValue::try_from("wrong-key").unwrap(),
        );

        let result = interceptor.call(request);
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }
587
588    #[test]
589    fn interceptor_applies_workspace_rate_limit() {
590        let mut interceptor = interceptor_for_test(false, 100);
591        let mut last: Result<Request<()>, Status> = Ok(Request::new(()));
592        for _ in 0..101 {
593            let mut request = Request::new(());
594            request.metadata_mut().insert(
595                WORKSPACE_HEADER,
596                MetadataValue::try_from("ws-test").unwrap(),
597            );
598            last = interceptor.call(request);
599        }
600
601        assert!(matches!(last, Err(status) if status.code() == tonic::Code::ResourceExhausted));
602    }
603}