use std::{collections::HashMap, num::NonZeroU32, sync::Arc};
use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{
metadata::MetadataValue, service::interceptor::InterceptedService, transport::Server, Request,
Response, Status,
};
use crate::{
config::VectorConfig,
engine::VectorEngine,
grpc::proto::{
embedding_service_server::{EmbeddingService, EmbeddingServiceServer},
vector_service_server::{VectorService, VectorServiceServer},
CollectionInfo, CollectionStatsResponse, CreateCollectionRequest, DeleteCollectionRequest,
DeleteResult, EmbedRequest, EmbedResponse, HealthRequest, HealthResponse,
ListCollectionsResponse, ListRequest, ModelInfoRequest, ModelInfoResponse,
SearchMetricsProto, SearchRequest, SearchResponseProto, StatsRequest, UpsertResult,
UpsertVectorRequest,
},
types::{DistanceMetric, SearchQuery},
VectorError,
};
/// Request metadata key from which the interceptor reads the caller's workspace id.
const WORKSPACE_HEADER: &str = "x-claw-workspace-id";
/// Request metadata key carrying the caller's API key; the value is hashed before lookup.
const API_KEY_HEADER: &str = "x-claw-api-key";
/// Metadata key for the request/trace id: read from incoming requests and echoed on
/// successful `VectorService` responses.
const REQUEST_ID_HEADER: &str = "x-request-id";
/// Workspace id resolved by the interceptor and stashed in the request extensions so the
/// service handlers can read it back.
#[derive(Clone)]
struct WorkspaceId(String);
/// Request/trace id (client-supplied or freshly generated) stashed in the request extensions.
#[derive(Clone)]
struct TraceId(String);
/// Stub implementation of the `EmbeddingService` gRPC API exposed by this process.
///
/// The embedding RPCs all answer `UNIMPLEMENTED` with a message pointing at the Python
/// embedding service, and `health` always reports the model as not ready.
pub struct EmbeddingServiceImpl;

#[tonic::async_trait]
impl EmbeddingService for EmbeddingServiceImpl {
    /// Stream type required by the trait; no stream is ever produced by this stub.
    type EmbedStreamStream = ReceiverStream<Result<EmbedResponse, Status>>;

    /// Always fails with `UNIMPLEMENTED`.
    async fn embed(
        &self,
        _request: Request<EmbedRequest>,
    ) -> Result<Response<EmbedResponse>, Status> {
        let status = Status::unimplemented("Embed is handled by the Python embedding service");
        Err(status)
    }

    /// Always fails with `UNIMPLEMENTED`.
    async fn embed_stream(
        &self,
        _request: Request<tonic::Streaming<EmbedRequest>>,
    ) -> Result<Response<Self::EmbedStreamStream>, Status> {
        let status =
            Status::unimplemented("EmbedStream is handled by the Python embedding service");
        Err(status)
    }

    /// Always fails with `UNIMPLEMENTED`.
    async fn model_info(
        &self,
        _request: Request<ModelInfoRequest>,
    ) -> Result<Response<ModelInfoResponse>, Status> {
        let status = Status::unimplemented("ModelInfo is handled by the Python embedding service");
        Err(status)
    }

    /// Succeeds, but reports `ready: false` with an empty model name and zero load time.
    async fn health(
        &self,
        _request: Request<HealthRequest>,
    ) -> Result<Response<HealthResponse>, Status> {
        let not_ready = HealthResponse {
            ready: false,
            model_name: String::new(),
            model_load_time_ms: 0,
        };
        Ok(Response::new(not_ready))
    }
}
/// Configuration and shared state consulted by the interceptor on every request.
#[derive(Clone)]
struct ServerState {
// Workspace used when the workspace header is absent and `require_workspace_id` is false.
default_workspace_id: String,
// When true, a request without a usable workspace header is rejected.
require_workspace_id: bool,
// When true, the request's API key must map to the resolved workspace.
require_auth: bool,
// Maps hashed API key (see `hash_api_key`) -> workspace id the key belongs to.
api_keys: Arc<HashMap<String, String>>,
// Keyed rate limiter; the interceptor checks it with the workspace id as the key.
limiter: Arc<RateLimiter<String, DashMapStateStore<String>, DefaultClock>>,
}
/// tonic interceptor that resolves the workspace, optionally authenticates the API key,
/// applies the per-workspace rate limit and assigns a request id.
#[derive(Clone)]
struct AuthRateTraceInterceptor {
    state: Arc<ServerState>,
}

impl tonic::service::Interceptor for AuthRateTraceInterceptor {
    /// Runs the checks in a fixed order: workspace resolution, API-key check (only when
    /// `require_auth` is set), rate limit, request id.
    ///
    /// On success the request is returned with `WorkspaceId` and `TraceId` extensions set.
    ///
    /// # Errors
    /// - `INVALID_ARGUMENT` when a workspace id is required but no usable header was sent.
    /// - `UNAUTHENTICATED` when auth is required and the key is missing, unknown, or mapped
    ///   to a different workspace.
    /// - `RESOURCE_EXHAUSTED` when the workspace's rate limit check fails.
    fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
        /// Reads a metadata entry as a string slice. An entry that is present but fails
        /// `to_str` is reported as `None`, i.e. it is indistinguishable from a missing one.
        fn ascii_header<'a>(request: &'a Request<()>, name: &str) -> Option<&'a str> {
            request
                .metadata()
                .get(name)
                .and_then(|value| value.to_str().ok())
        }

        let state = &self.state;

        let workspace_id = match ascii_header(&request, WORKSPACE_HEADER) {
            Some(id) => id.to_owned(),
            None if state.require_workspace_id => {
                return Err(Status::invalid_argument("missing x-claw-workspace-id"));
            }
            None => state.default_workspace_id.clone(),
        };

        if state.require_auth {
            let api_key = ascii_header(&request, API_KEY_HEADER)
                .ok_or_else(|| Status::unauthenticated("missing x-claw-api-key"))?;
            // The key is only valid for the workspace it was issued to.
            let key_workspace = state.api_keys.get(&hash_api_key(api_key));
            if key_workspace != Some(&workspace_id) {
                return Err(Status::unauthenticated("invalid API key"));
            }
        }

        if state.limiter.check_key(&workspace_id).is_err() {
            return Err(Status::resource_exhausted("rate limit exceeded"));
        }

        // Reuse the caller's request id when one was sent; otherwise mint a fresh UUID.
        let request_id = match ascii_header(&request, REQUEST_ID_HEADER) {
            Some(id) => id.to_owned(),
            None => uuid::Uuid::new_v4().to_string(),
        };

        let extensions = request.extensions_mut();
        extensions.insert(WorkspaceId(workspace_id));
        extensions.insert(TraceId(request_id));
        Ok(request)
    }
}
pub struct VectorServiceImpl {
engine: Arc<VectorEngine>,
}
impl VectorServiceImpl {
fn workspace_from_request<T>(&self, request: &Request<T>) -> String {
request
.extensions()
.get::<WorkspaceId>()
.map(|value| value.0.clone())
.unwrap_or_else(|| self.engine.config.default_workspace_id.clone())
}
fn trace_from_request<T>(
&self,
request: &Request<T>,
) -> Option<MetadataValue<tonic::metadata::Ascii>> {
request
.extensions()
.get::<TraceId>()
.and_then(|value| MetadataValue::try_from(value.0.as_str()).ok())
}
}
/// `VectorService` RPC handlers.
///
/// Every handler reads the trace id and workspace id from the request, delegates to the
/// engine, and echoes the trace id in the `x-request-id` metadata of a successful response.
/// Error statuses are returned as-is, without that metadata.
#[tonic::async_trait]
impl VectorService for VectorServiceImpl {
    /// Creates a collection in the caller's workspace and returns its description.
    ///
    /// # Errors
    /// `INVALID_ARGUMENT` for an unsupported distance metric; otherwise whatever status the
    /// engine error converts into.
    async fn create_collection(
        &self,
        request: Request<CreateCollectionRequest>,
    ) -> Result<Response<CollectionInfo>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        let metric = parse_distance_metric(&req.distance_metric)?;
        let created = self
            .engine
            .create_collection_in_workspace(
                &workspace_id,
                &req.name,
                req.dimensions as usize,
                metric,
            )
            .await
            .map_err(Status::from)?;

        Ok(VectorServiceImpl::traced_response(
            collection_to_proto(&created),
            trace,
        ))
    }

    /// Deletes a collection and reports how many vectors it held just before deletion.
    async fn delete_collection(
        &self,
        request: Request<DeleteCollectionRequest>,
    ) -> Result<Response<DeleteResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        // Best effort: the count is sampled before deleting, and a failed stats lookup only
        // degrades `records_removed` to 0 instead of failing the RPC.
        let records_removed = self
            .engine
            .collections
            .store
            .collection_stats(&workspace_id, &req.name)
            .await
            .ok()
            .map(|stats| stats.vector_count)
            .unwrap_or(0);

        self.engine
            .delete_collection_in_workspace(&workspace_id, &req.name)
            .await
            .map_err(Status::from)?;

        Ok(VectorServiceImpl::traced_response(
            DeleteResult { records_removed },
            trace,
        ))
    }

    /// Upserts one record, either from a raw vector or from text.
    ///
    /// A non-empty `vector` takes precedence over `text`. Blank `metadata_json` is treated
    /// as an empty JSON object.
    ///
    /// # Errors
    /// `INVALID_ARGUMENT` when `metadata_json` is not valid JSON or when neither a vector
    /// nor non-blank text is supplied.
    async fn upsert_vector(
        &self,
        request: Request<UpsertVectorRequest>,
    ) -> Result<Response<UpsertResult>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        let metadata = if req.metadata_json.trim().is_empty() {
            serde_json::json!({})
        } else {
            serde_json::from_str(&req.metadata_json)
                .map_err(|err| Status::invalid_argument(format!("invalid metadata_json: {err}")))?
        };

        let id = if !req.vector.is_empty() {
            self.engine
                .upsert_vector_in_workspace(&workspace_id, &req.collection, req.vector, metadata)
                .await
                .map_err(Status::from)?
        } else if !req.text.trim().is_empty() {
            self.engine
                .upsert_in_workspace(&workspace_id, &req.collection, &req.text, metadata)
                .await
                .map_err(Status::from)?
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        Ok(VectorServiceImpl::traced_response(
            UpsertResult { id: id.to_string() },
            trace,
        ))
    }

    /// Runs a similarity search from a raw query vector or from text.
    ///
    /// A non-empty `vector` takes precedence; otherwise non-blank `text` is embedded through
    /// the engine's embedding client. `top_k` is raised to at least 1.
    ///
    /// # Errors
    /// `INVALID_ARGUMENT` when `filter_json` is not valid JSON or when neither a vector nor
    /// non-blank text is supplied.
    async fn search_vectors(
        &self,
        request: Request<SearchRequest>,
    ) -> Result<Response<SearchResponseProto>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        // The filter is validated before any embedding call is made.
        let filter = if req.filter_json.trim().is_empty() {
            None
        } else {
            Some(
                serde_json::from_str(&req.filter_json).map_err(|err| {
                    Status::invalid_argument(format!("invalid filter_json: {err}"))
                })?,
            )
        };

        // Resolve the query vector first so the query itself is built in one place.
        let vector = if !req.vector.is_empty() {
            req.vector
        } else if !req.text.trim().is_empty() {
            self.engine
                .embedding_client
                .embed_one(&req.text)
                .await
                .map_err(Status::from)?
        } else {
            return Err(Status::invalid_argument(
                "either vector or text must be provided",
            ));
        };

        let query = SearchQuery {
            collection: req.collection,
            vector,
            top_k: req.top_k.max(1) as usize,
            filter,
            include_vectors: req.include_vectors,
            include_metadata: req.include_metadata,
            ef_search: None,
            reranker: None,
        };

        let outcome = self
            .engine
            .search_in_workspace(&workspace_id, query)
            .await
            .map_err(Status::from)?;

        let results = outcome
            .results
            .into_iter()
            .map(|hit| {
                // Null metadata (and metadata that fails to serialize) is sent as `{}`.
                let metadata_json = if hit.metadata.is_null() {
                    "{}".to_string()
                } else {
                    serde_json::to_string(&hit.metadata).unwrap_or_else(|_| "{}".to_string())
                };
                crate::grpc::proto::SearchHit {
                    id: hit.id.to_string(),
                    score: hit.score,
                    vector: hit.vector.unwrap_or_default(),
                    metadata_json,
                    text: hit.text.unwrap_or_default(),
                }
            })
            .collect();

        let proto = SearchResponseProto {
            results,
            metrics: Some(SearchMetricsProto {
                query_vector_dims: outcome.metrics.query_vector_dims as u32,
                candidates_evaluated: outcome.metrics.candidates_evaluated as u32,
                post_filter_count: outcome.metrics.post_filter_count as u32,
                latency_us: outcome.metrics.latency_us,
            }),
        };

        Ok(VectorServiceImpl::traced_response(proto, trace))
    }

    /// Returns size and index information for one collection in the caller's workspace.
    async fn get_collection_stats(
        &self,
        request: Request<StatsRequest>,
    ) -> Result<Response<CollectionStatsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        let collection = self
            .engine
            .collections
            .get_collection(&workspace_id, &req.collection)
            .await
            .map_err(Status::from)?;

        let stats = CollectionStatsResponse {
            vector_count: collection.vector_count,
            index_type: format!("{:?}", collection.index_type).to_lowercase(),
            dimensions: collection.dimensions as u32,
            // NOTE(review): `last_modified_at` is populated from `created_at`; confirm
            // whether the collection tracks a real modification timestamp.
            last_modified_at: collection.created_at.timestamp_millis(),
        };

        Ok(VectorServiceImpl::traced_response(stats, trace))
    }

    /// Lists the workspace's collections one page at a time.
    ///
    /// `page_size` is clamped to `1..=500` and `page` is 1-based (values below 1 become 1).
    /// All collections are fetched from the engine and the page is sliced in memory;
    /// `total` is the unpaginated count.
    async fn list_collections(
        &self,
        request: Request<ListRequest>,
    ) -> Result<Response<ListCollectionsResponse>, Status> {
        let trace = self.trace_from_request(&request);
        let workspace_id = self.workspace_from_request(&request);
        let req = request.into_inner();

        let page_size = req.page_size.clamp(1, 500) as usize;
        let page = req.page.max(1) as usize;

        let collections = self
            .engine
            .list_collections_in_workspace(&workspace_id)
            .await
            .map_err(Status::from)?;

        // Saturate instead of silently wrapping if the count ever exceeds `u32::MAX`.
        let total = u32::try_from(collections.len()).unwrap_or(u32::MAX);
        // `page >= 1`, so the subtraction cannot underflow.
        let start = page_size.saturating_mul(page - 1);
        let page_items = collections
            .into_iter()
            .skip(start)
            .take(page_size)
            .map(|collection| collection_to_proto(&collection))
            .collect::<Vec<_>>();

        Ok(VectorServiceImpl::traced_response(
            ListCollectionsResponse {
                collections: page_items,
                page: page as u32,
                page_size: page_size as u32,
                total,
            },
            trace,
        ))
    }
}

impl VectorServiceImpl {
    /// Wraps `body` in a [`Response`] and, when a trace id is available, echoes it in the
    /// `x-request-id` response metadata.
    fn traced_response<T>(
        body: T,
        trace: Option<MetadataValue<tonic::metadata::Ascii>>,
    ) -> Response<T> {
        let mut response = Response::new(body);
        if let Some(trace_id) = trace {
            response.metadata_mut().insert(REQUEST_ID_HEADER, trace_id);
        }
        response
    }
}
/// Converts an engine-side collection into its gRPC `CollectionInfo` representation.
///
/// The `id` is `"<workspace_id>:<name>"`. The distance metric and index type are sent as
/// the lowercased `Debug` rendering of their values.
fn collection_to_proto(collection: &crate::types::Collection) -> CollectionInfo {
    /// Lowercased `Debug` rendering, used as the wire name for enum-like values.
    fn debug_lowercase<T: std::fmt::Debug>(value: &T) -> String {
        format!("{value:?}").to_lowercase()
    }

    let id = format!("{}:{}", collection.workspace_id, collection.name);
    let distance_metric = debug_lowercase(&collection.distance);
    let index_type = debug_lowercase(&collection.index_type);
    // NOTE(review): `last_modified_at` is populated from `created_at`; confirm whether the
    // collection tracks a real modification timestamp.
    let last_modified_at = collection.created_at.timestamp_millis();

    CollectionInfo {
        id,
        name: collection.name.clone(),
        dimensions: collection.dimensions as u32,
        distance_metric,
        index_type,
        vector_count: collection.vector_count,
        last_modified_at,
    }
}
/// Parses the distance metric name sent by a client.
///
/// Matching ignores surrounding whitespace and ASCII case. Accepted spellings:
/// - `cosine`
/// - `euclidean`
/// - `dot`, `dot_product`, `dotproduct`
///
/// `dotproduct` is accepted because `collection_to_proto` reports metrics as the lowercased
/// `Debug` name of the variant (presumably `dotproduct` for `DistanceMetric::DotProduct`),
/// so a metric name read from a response can be sent back when creating a collection.
///
/// # Errors
/// Returns `INVALID_ARGUMENT` for any other value.
// The error type is `tonic::Status`, which callers return unchanged; boxing it only to
// satisfy the lint would add an allocation on every error path.
#[allow(clippy::result_large_err)]
fn parse_distance_metric(raw: &str) -> Result<DistanceMetric, Status> {
    match raw.trim().to_ascii_lowercase().as_str() {
        "cosine" => Ok(DistanceMetric::Cosine),
        "euclidean" => Ok(DistanceMetric::Euclidean),
        "dot" | "dot_product" | "dotproduct" => Ok(DistanceMetric::DotProduct),
        _ => Err(Status::invalid_argument("unsupported distance metric")),
    }
}
/// Returns the hex-encoded BLAKE3 digest of `api_key`.
///
/// This is the form in which keys are looked up in `ServerState::api_keys`.
fn hash_api_key(api_key: &str) -> String {
    let digest = blake3::hash(api_key.as_bytes());
    digest.to_hex().to_string()
}
/// Builds the interceptor's shared state from configuration and the API-key store.
///
/// Opens the store at `config.api_key_store_path`, creates the `api_keys` table if it does
/// not exist yet, and loads every non-revoked `(key_hash, workspace_id)` row into an
/// in-memory map. The map is built once here; nothing in this function refreshes it later.
///
/// The rate limiter allows `config.rate_limit_rps` requests per second per key, with the
/// configured value raised to at least 1.
///
/// # Errors
/// Propagates failures from opening the store and from either SQL statement.
async fn load_server_state(config: &VectorConfig) -> Result<Arc<ServerState>, VectorError> {
    let key_store = crate::store::sqlite::VectorStore::new(&config.api_key_store_path).await?;

    sqlx::query(
        "CREATE TABLE IF NOT EXISTS api_keys (key_hash TEXT PRIMARY KEY, workspace_id TEXT NOT NULL, created_at TEXT NOT NULL, revoked INTEGER NOT NULL DEFAULT 0)",
    )
    .execute(key_store.pool())
    .await?;

    let active_keys = sqlx::query_as::<_, (String, String)>(
        "SELECT key_hash, workspace_id FROM api_keys WHERE revoked = 0",
    )
    .fetch_all(key_store.pool())
    .await?;
    let api_keys: HashMap<String, String> = active_keys.into_iter().collect();

    // Clamping to >= 1 first makes the non-zero conversion infallible.
    let per_second = NonZeroU32::new(config.rate_limit_rps.max(1))
        .expect("a value clamped to at least 1 is never zero");
    let limiter = RateLimiter::keyed(Quota::per_second(per_second));

    let state = ServerState {
        default_workspace_id: config.default_workspace_id.clone(),
        require_workspace_id: config.require_workspace_id,
        require_auth: config.require_auth,
        api_keys: Arc::new(api_keys),
        limiter: Arc::new(limiter),
    };
    Ok(Arc::new(state))
}
/// Starts the gRPC server on `addr` and runs it until the transport stops.
///
/// Configuration comes from the environment. Both the embedding and the vector service are
/// wrapped in the same auth / rate-limit / trace interceptor, sharing one `ServerState`.
///
/// # Errors
/// Returns any failure from loading the server state, constructing the engine, or serving.
pub async fn serve(addr: std::net::SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
    let config = VectorConfig::from_env();

    // State is loaded before the engine is constructed, as two separate fallible steps.
    let state = load_server_state(&config).await?;
    let engine = Arc::new(VectorEngine::new(config).await?);

    let interceptor = AuthRateTraceInterceptor { state };
    let embedding_service = InterceptedService::new(
        EmbeddingServiceServer::new(EmbeddingServiceImpl),
        interceptor.clone(),
    );
    let vector_service = InterceptedService::new(
        VectorServiceServer::new(VectorServiceImpl { engine }),
        interceptor,
    );

    Server::builder()
        .add_service(embedding_service)
        .add_service(vector_service)
        .serve(addr)
        .await?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::service::Interceptor;

    /// Builds an interceptor whose only known API key is `valid-key`, bound to `ws-test`.
    /// The workspace header is optional and falls back to the `default` workspace.
    fn interceptor_for_test(require_auth: bool, rate_limit: u32) -> AuthRateTraceInterceptor {
        let mut api_keys = HashMap::new();
        api_keys.insert(hash_api_key("valid-key"), "ws-test".to_string());
        AuthRateTraceInterceptor {
            state: Arc::new(ServerState {
                default_workspace_id: "default".to_string(),
                require_workspace_id: false,
                require_auth,
                api_keys: Arc::new(api_keys),
                limiter: Arc::new(RateLimiter::keyed(Quota::per_second(
                    NonZeroU32::new(rate_limit.max(1)).unwrap(),
                ))),
            }),
        }
    }

    /// Builds an empty request carrying the workspace header and, optionally, an API key.
    fn request_for(workspace: &str, api_key: Option<&str>) -> Request<()> {
        let mut request = Request::new(());
        request.metadata_mut().insert(
            WORKSPACE_HEADER,
            MetadataValue::try_from(workspace).unwrap(),
        );
        if let Some(key) = api_key {
            request
                .metadata_mut()
                .insert(API_KEY_HEADER, MetadataValue::try_from(key).unwrap());
        }
        request
    }

    #[test]
    fn interceptor_rejects_invalid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let result = interceptor.call(request_for("ws-test", Some("wrong-key")));
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    #[test]
    fn interceptor_rejects_missing_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let result = interceptor.call(request_for("ws-test", None));
        assert!(matches!(result, Err(status) if status.code() == tonic::Code::Unauthenticated));
    }

    #[test]
    fn interceptor_accepts_valid_api_key() {
        let mut interceptor = interceptor_for_test(true, 100);
        let accepted = interceptor
            .call(request_for("ws-test", Some("valid-key")))
            .expect("a valid key for its own workspace must be accepted");
        let workspace = accepted
            .extensions()
            .get::<WorkspaceId>()
            .map(|value| value.0.as_str());
        assert_eq!(workspace, Some("ws-test"));
        assert!(accepted.extensions().get::<TraceId>().is_some());
    }

    #[test]
    fn interceptor_applies_workspace_rate_limit() {
        // A quota of 1 request/second makes the outcome independent of how fast the test
        // runs: the second call would only pass if a full second elapsed in between.
        let mut interceptor = interceptor_for_test(false, 1);

        assert!(interceptor.call(request_for("ws-test", None)).is_ok());

        let throttled = interceptor.call(request_for("ws-test", None));
        assert!(
            matches!(throttled, Err(status) if status.code() == tonic::Code::ResourceExhausted)
        );

        // The limiter is keyed by workspace, so another workspace is unaffected.
        assert!(interceptor.call(request_for("ws-other", None)).is_ok());
    }
}