// claw_vector/grpc/server.rs
1use std::{collections::HashMap, num::NonZeroU32, sync::Arc};
2
3use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
4use tokio_stream::wrappers::ReceiverStream;
5use tonic::{
6    metadata::MetadataValue, service::interceptor::InterceptedService, transport::Server, Request,
7    Response, Status,
8};
9
10use crate::{
11    config::VectorConfig,
12    engine::VectorEngine,
13    grpc::proto::{
14        embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
15        vector_service_server::{VectorService, VectorServiceServer},
16        CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
17        DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse,
18        ListCollectionsResponse, ListRequest, ModelInfoRequest, ModelInfoResponse,
19        SearchMetricsProto, SearchRequest, SearchResponseProto, StatsRequest, UpsertResult,
20        UpsertVectorRequest,
21    },
22    types::{DistanceMetric, SearchQuery},
23    VectorError,
24};
25
/// Metadata header carrying the tenant/workspace identifier.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Metadata header carrying the caller's raw API key (hashed before lookup).
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Metadata header used to propagate a request/trace id end to end.
const REQUEST_ID_HEADER: &str = "x-request-id";

/// Newtype stored in request extensions holding the resolved workspace id.
#[derive(Clone)]
struct WorkspaceId(String);

/// Newtype stored in request extensions holding the per-request trace id.
#[derive(Clone)]
struct TraceId(String);
35
/// Minimal pass-through embedding service stub for local Rust gRPC server mode.
///
/// Real embedding work lives in the Python embedding service; every RPC here
/// either returns `unimplemented` or reports the service as not ready.
pub struct EmbeddingServiceImpl;

#[tonic::async_trait]
impl EmbeddingService for EmbeddingServiceImpl {
    /// Always `unimplemented`: embedding is delegated to the Python service.
    async fn embed(
        &self,
        _request: Request<EmbedRequest>,
    ) -> Result<Response<EmbedResponse>, Status> {
        Err(Status::unimplemented(
            "Embed is handled by the Python embedding service",
        ))
    }

    /// Reports `ready: false` with an empty model name, so callers polling
    /// health never treat this stub as a usable embedding backend.
    async fn health(
        &self,
        _request: Request<HealthRequest>,
    ) -> Result<Response<HealthResponse>, Status> {
        Ok(Response::new(HealthResponse {
            ready: false,
            model_name: String::new(),
            model_load_time_ms: 0,
        }))
    }

    /// Always `unimplemented`: model metadata is owned by the Python service.
    async fn model_info(
        &self,
        _request: Request<ModelInfoRequest>,
    ) -> Result<Response<ModelInfoResponse>, Status> {
        Err(Status::unimplemented(
            "ModelInfo is handled by the Python embedding service",
        ))
    }

    // Channel-backed stream type required by the trait, even though this stub
    // never produces a stream.
    type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;

    /// Always `unimplemented`: streaming embedding is delegated as well.
    async fn embed_stream(
        &self,
        _request: Request<tonic::Streaming<EmbedRequest>>,
    ) -> Result<Response<Self::EmbedStreamStream>, Status> {
        Err(Status::unimplemented(
            "EmbedStream is handled by the Python embedding service",
        ))
    }
}
81
/// Shared, immutable per-server state consulted by the interceptor.
#[derive(Clone)]
struct ServerState {
    // Workspace used when the header is absent and not required.
    default_workspace_id: String,
    // When true, requests without x-claw-workspace-id are rejected.
    require_workspace_id: bool,
    // When true, a valid x-claw-api-key bound to the workspace is required.
    require_auth: bool,
    // BLAKE3 key-hash -> workspace id, loaded from the api_keys table.
    api_keys: Arc<HashMap<String, String>>,
    // Keyed (per-workspace-id) rate limiter.
    limiter: Arc<RateLimiter<String, DashMapStateStore<String>, DefaultClock>>,
}

/// Tonic interceptor enforcing workspace resolution, API-key auth,
/// per-workspace rate limiting, and trace-id propagation.
#[derive(Clone)]
struct AuthRateTraceInterceptor {
    state: Arc<ServerState>,
}
95
96impl tonic::service::Interceptor for AuthRateTraceInterceptor {
97    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
98        let workspace_id = request
99            .metadata()
100            .get(WORKSPACE_HEADER)
101            .and_then(|value| value.to_str().ok())
102            .map(ToOwned::to_owned)
103            .or_else(|| {
104                if self.state.require_workspace_id {
105                    None
106                } else {
107                    Some(self.state.default_workspace_id.clone())
108                }
109            })
110            .ok_or_else(|| Status::invalid_argument("missing x-claw-workspace-id"))?;
111
112        if self.state.require_auth {
113            let api_key = request
114                .metadata()
115                .get(API_KEY_HEADER)
116                .and_then(|value| value.to_str().ok())
117                .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
118
119            let hashed = hash_api_key(api_key);
120            let valid_workspace = self
121                .state
122                .api_keys
123                .get(&hashed)
124                .map(|ws| ws == &workspace_id)
125                .unwrap_or(false);
126            if !valid_workspace {
127                return Err(Status::unauthenticated("invalid API key"));
128            }
129        }
130
131        if self.state.limiter.check_key(&workspace_id).is_err() {
132            return Err(Status::resource_exhausted("rate limit exceeded"));
133        }
134
135        let request_id = request
136            .metadata()
137            .get(REQUEST_ID_HEADER)
138            .and_then(|value| value.to_str().ok())
139            .map(ToOwned::to_owned)
140            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
141
142        request.extensions_mut().insert(WorkspaceId(workspace_id));
143        request.extensions_mut().insert(TraceId(request_id));
144        Ok(request)
145    }
146}
147
148/// VectorService gRPC implementation backed by [`VectorEngine`].
149pub struct VectorServiceImpl {
150    engine: Arc<VectorEngine>,
151}
152
153impl VectorServiceImpl {
154    fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
155        request
156            .extensions()
157            .get::<WorkspaceId>()
158            .map(|value| value.0.clone())
159            .unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
160    }
161
162    fn trace_from_request<T>(
163        &self,
164        request: &Request<T>,
165    ) -> Option<MetadataValue<tonic::metadata::Ascii>> {
166        request
167            .extensions()
168            .get::<TraceId>()
169            .and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
170    }
171}
172
#[tonic::async_trait]
impl VectorService for VectorServiceImpl {
    /// Creates a collection in the caller's workspace.
    ///
    /// The distance metric string is validated before the engine is touched;
    /// the caller's trace id (if any) is echoed in the response metadata.
    async fn create_collection(
        &self,
        request: Request<CreateCollectionRequest>,
    ) -> Result<Response<CollectionInfo>, Status> {
        // Extract extensions before into_inner() consumes the request.
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let metric = parse_distance_metric(&req.distance_metric)?;
        let created = self
            .engine
            .create_collection_in_workspace(
                &workspace_id,
                &req.name,
                req.dimensions as usize,
                metric,
            )
            .await
            .map_err(Status::from)?;
        let info = collection_to_proto(&created);
        let mut response = Response::new(info);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Deletes a collection and reports how many records were removed.
    ///
    /// The vector count is read best-effort *before* deletion (afterwards the
    /// stats would be gone); a failed stats read yields `records_removed: 0`.
    async fn delete_collection(
        &self,
        request: Request<DeleteCollectionRequest>,
    ) -> Result<Response<DeleteResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        // Best-effort snapshot of the pre-delete count; errors are ignored.
        let stats = self
            .engine
            .collections
            .store
            .collection_stats(&workspace_id, &req.name)
            .await
            .ok();
        self.engine
            .delete_collection_in_workspace(&workspace_id, &req.name)
            .await
            .map_err(Status::from)?;
        let mut response = Response::new(DeleteResult {
            records_removed: stats.map(|value| value.vector_count).unwrap_or(0),
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Upserts a single record, supplied either as a raw vector or as text.
    ///
    /// A raw vector takes precedence over text when both are present. Empty
    /// `metadata_json` is treated as `{}`; otherwise it must parse as JSON.
    async fn upsert_vector(
        &self,
        request: Request<UpsertVectorRequest>,
    ) -> Result<Response<UpsertResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let metadata = if req.metadata_json.trim().is_empty() {
            serde_json::json!({})
        } else {
            serde_json::from_str(&req.metadata_json)
                .map_err(|err| Status::invalid_argument(format!("invalid metadata_json: {err}")))?
        };

        // Vector-first: if a vector is supplied it is stored directly; text is
        // embedded by the engine only when no vector was given.
        let id = if !req.vector.is_empty() {
            self.engine
                .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
                .await
                .map_err(Status::from)?
        } else if !req.text.trim().is_empty() {
            self.engine
                .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
                .await
                .map_err(Status::from)?
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let mut response = Response::new(UpsertResult { id: id.to_string() });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Searches a collection by raw vector or by text (embedded on the fly).
    ///
    /// As with upsert, a raw vector takes precedence over text. `top_k` is
    /// clamped to at least 1, and an empty `filter_json` means "no filter".
    async fn search_vectors(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponseProto>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let filter =
            if req.filter_json.trim().is_empty() {
                None
            } else {
                Some(serde_json::from_str(&req.filter_json).map_err(|err| {
                    Status::invalid_argument(format!("invalid filter_json: {err}"))
                })?)
            };

        let query = if !req.vector.is_empty() {
            SearchQuery {
                collection: req.collection,
                vector: req.vector,
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else if !req.text.trim().is_empty() {
            // Text path: embed the query text via the shared embedding client.
            let embedded = self
                .engine
                .embedding_client
                .embed_one(&req.text)
                .await
                .map_err(Status::from)?;
            SearchQuery {
                collection: req.collection,
                vector: embedded,
                top_k: req.top_k.max(1) as usize,
                filter,
                include_vectors: req.include_vectors,
                include_metadata: req.include_metadata,
                ef_search: None,
                reranker: None,
            }
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let response = self
            .engine
            .search_in_workspace(&workspace_id, query)
            .await
            .map_err(Status::from)?;

        let proto = SearchResponseProto {
            results: response
                .results
                .into_iter()
                .map(|result| crate::grpc::proto::SearchHit {
                    id: result.id.to_string(),
                    score: result.score,
                    vector: result.vector.unwrap_or_default(),
                    // Null metadata is rendered as "{}" so clients always get
                    // parseable JSON; serialization failures degrade to "{}".
                    metadata_json: if result.metadata.is_null() {
                        "{}".to_string()
                    } else {
                        serde_json::to_string(&result.metadata).unwrap_or_else(|_| "{}".to_string())
                    },
                    text: result.text.unwrap_or_default(),
                })
                .collect(),
            metrics: Some(SearchMetricsProto {
                query_vector_dims: response.metrics.query_vector_dims as u32,
                candidates_evaluated: response.metrics.candidates_evaluated as u32,
                post_filter_count: response.metrics.post_filter_count as u32,
                latency_us: response.metrics.latency_us,
            }),
        };

        // Shadowing: `response` now refers to the tonic Response wrapper.
        let mut response = Response::new(proto);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Returns vector-count/index/dimension stats for one collection.
    async fn get_collection_stats(
        &self,
        request: Request<StatsRequest>,
    ) -> Result<Response<CollectionStatsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let collection = self
            .engine
            .collections
            .get_collection(&workspace_id, &req.collection)
            .await
            .map_err(Status::from)?;

        let response = CollectionStatsResponse {
            vector_count: collection.vector_count,
            index_type: format!("{:?}", collection.index_type).to_lowercase(),
            dimensions: collection.dimensions as u32,
            // NOTE(review): last_modified_at mirrors created_at — no separate
            // modification timestamp is tracked here; confirm upstream.
            last_modified_at: collection.created_at.timestamp_millis(),
        };

        let mut response = Response::new(response);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }

    /// Lists collections in the workspace with simple 1-based pagination.
    ///
    /// `page_size` is clamped to 1..=500 and `page` to at least 1, so the
    /// `page - 1` offset arithmetic below cannot underflow.
    async fn list_collections(
        &self,
        request: Request<ListRequest>,
    ) -> Result<Response<ListCollectionsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();
        let page_size = req.page_size.clamp(1, 500) as usize;
        let page = req.page.max(1) as usize;

        // Pagination is applied in memory over the full listing.
        let collections = self
            .engine
            .list_collections_in_workspace(&workspace_id)
            .await
            .map_err(Status::from)?;
        let total = collections.len();
        let start = page_size.saturating_mul(page - 1);
        let page_items = collections
            .into_iter()
            .skip(start)
            .take(page_size)
            .map(|collection| collection_to_proto(&collection))
            .collect::<Vec<_>>();

        let mut response = Response::new(ListCollectionsResponse {
            collections: page_items,
            page: page as u32,
            page_size: page_size as u32,
            total: total as u32,
        });
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        Ok(response)
    }
}
416
417fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
418    CollectionInfo {
419        id: format!("{}:{}", collection.workspace_id, collection.name),
420        name: collection.name.clone(),
421        dimensions: collection.dimensions as u32,
422        distance_metric: format!("{:?}", collection.distance).to_lowercase(),
423        index_type: format!("{:?}", collection.index_type).to_lowercase(),
424        vector_count: collection.vector_count,
425        last_modified_at: collection.created_at.timestamp_millis(),
426    }
427}
428
429#[allow(clippy::result_large_err)]
430fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
431    match raw.trim().to_ascii_lowercase().as_str() {
432        "cosine" => Ok(DistanceMetric::Cosine),
433        "euclidean" => Ok(DistanceMetric::Euclidean),
434        "dot" | "dot_product" => Ok(DistanceMetric::DotProduct),
435        _ => Err(Status::invalid_argument("unsupported distance metric")),
436    }
437}
438
/// Hashes a raw API key with BLAKE3 and returns the lowercase hex digest.
///
/// Only hashes are stored and compared (see `ServerState::api_keys`), so raw
/// keys never touch persistent storage.
fn hash_api_key(api_key: &str) -> String {
    blake3::hash(api_key.as_bytes()).to_hex().to_string()
}
442
443async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
444    let store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;
445    sqlx::query(
446        "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
447    )
448    .execute(store.pool())
449    .await?;
450
451    let key_rows = sqlx::query_as::<_, (String, String)>(
452        "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
453    )
454    .fetch_all(store.pool())
455    .await?;
456    let api_keys = key_rows.into_iter().collect::<HashMap<_, _>>();
457
458    let rate_limit =
459        NonZeroU32::new(config.rate_limit_rps.max(1)).unwrap_or(NonZeroU32::new(1).unwrap());
460    let limiter = RateLimiter::keyed(Quota::per_second(rate_limit));
461
462    Ok(Arc::new(ServerState {
463        default_workspace_id: config.default_workspace_id.clone(),
464        require_workspace_id: config.require_workspace_id,
465        require_auth: config.require_auth,
466        api_keys: Arc::new(api_keys),
467        limiter: Arc::new(limiter),
468    }))
469}
470
471/// Start the Rust gRPC server with auth, rate limiting, and trace-id interception.
472pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
473    let config = VectorConfig::from_env();
474    let state = load_server_state(&config).await?;
475    let engine = Arc::new(VectorEngine::new(config.clone()).await?);
476
477    let interceptor = AuthRateTraceInterceptor { state };
478    let embedding = EmbeddingServiceServer::new(EmbeddingServiceImpl);
479    let vector = VectorServiceServer::new(VectorServiceImpl { engine });
480
481    Server::builder()
482        .add_service(InterceptedService::new(embedding, interceptor.clone()))
483        .add_service(InterceptedService::new(vector, interceptor))
484        .serve(addr)
485        .await?;
486    Ok(())
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Builds an interceptor with one known key ("valid-key" -> "ws-test"),
    /// no workspace requirement, and the given per-second rate limit.
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let mut api_keys = HashMap::new();
        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
        AuthRateTraceInterceptor {
            state: Arc::new(ServerState {
                default_workspace_id: "default".to_string(),
                require_workspace_id: false,
                require_auth,
                api_keys: Arc::new(api_keys),
                limiter: Arc::new(RateLimiter::keyed(Quota::per_second(
                    NonZeroU32::new(rate_limit.max(1)).unwrap(),
                ))),
            }),
        }
    }

    /// A key whose hash is unknown must be rejected as Unauthenticated.
    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let mut request = Request::new(());
        request.metadata_mut().insert(
            WORKSPACE_HEADER,
            MetadataValue::try_from("ws-test").unwrap(),
        );
        request.metadata_mut().insert(
            API_KEY_HEADER,
            MetadataValue::try_from("wrong-key").unwrap(),
        );

        let result = interceptor.call(request);
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    /// With a burst capacity of 100 (per-second quota), the 101st immediate
    /// request for the same workspace key must be ResourceExhausted.
    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        let mut interceptor = interceptor_for_test(false, 100);
        let mut last: Result<Request<()>, Status> = Ok(Request::new(()));
        for _ in 0..101 {
            let mut request = Request::new(());
            request.metadata_mut().insert(
                WORKSPACE_HEADER,
                MetadataValue::try_from("ws-test").unwrap(),
            );
            last = interceptor.call(request);
        }

        assert!(matches!(last, Err(status) if status.code() == tonic::Code::ResourceExhausted));
    }
}