1use crate::error::GrpcError;
21use crate::proto::{
22 self,
23 vector_index_service_server::{VectorIndexService, VectorIndexServiceServer},
24 CreateIndexRequest, CreateIndexResponse, DropIndexRequest, DropIndexResponse,
25 GetStatsRequest, GetStatsResponse, HealthCheckRequest, HealthCheckResponse,
26 HnswConfig as ProtoHnswConfig, IndexInfo, IndexStats, InsertBatchRequest,
27 InsertBatchResponse, InsertStreamRequest, InsertStreamResponse, QueryResults,
28 SearchBatchRequest, SearchBatchResponse, SearchRequest, SearchResponse, SearchResult,
29};
30use dashmap::DashMap;
31use std::sync::Arc;
32use std::time::Instant;
33use tokio_stream::StreamExt;
34use tonic::{Request, Response, Status, Streaming};
35use sochdb_index::hnsw::{DistanceMetric, HnswConfig, HnswIndex};
36
/// Crate version captured at compile time from Cargo (`CARGO_PKG_VERSION`);
/// returned as the `version` field of `health_check` responses.
const VERSION: &str = env!("CARGO_PKG_VERSION");
39
40#[allow(dead_code)]
42struct IndexEntry {
43 index: Arc<HnswIndex>,
44 name: String,
45 dimension: usize,
46 metric: proto::DistanceMetric,
47 config: ProtoHnswConfig,
48 created_at: u64,
49}
50
/// gRPC vector-index server that keeps every index in process memory,
/// keyed by index name.
pub struct VectorIndexServer {
    /// Registered indexes by name. A `DashMap` is used so the `&self` service
    /// handlers can create, drop and look up indexes without an outer lock.
    indexes: DashMap<String, IndexEntry>,
}
56
57impl VectorIndexServer {
58 pub fn new() -> Self {
60 Self {
61 indexes: DashMap::new(),
62 }
63 }
64
65 pub fn into_service(self) -> VectorIndexServiceServer<Self> {
67 VectorIndexServiceServer::new(self)
68 }
69
70 fn get_index_with_dim(&self, name: &str) -> Result<(Arc<HnswIndex>, usize), GrpcError> {
72 self.indexes
73 .get(name)
74 .map(|entry| (entry.index.clone(), entry.dimension))
75 .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
76 }
77
78 fn get_index(&self, name: &str) -> Result<Arc<HnswIndex>, GrpcError> {
80 self.indexes
81 .get(name)
82 .map(|entry| entry.index.clone())
83 .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
84 }
85
86 fn convert_metric(metric: proto::DistanceMetric) -> DistanceMetric {
88 match metric {
89 proto::DistanceMetric::L2 => DistanceMetric::Euclidean,
90 proto::DistanceMetric::Cosine => DistanceMetric::Cosine,
91 proto::DistanceMetric::DotProduct => DistanceMetric::DotProduct,
92 _ => DistanceMetric::Cosine, }
94 }
95}
96
97impl Default for VectorIndexServer {
98 fn default() -> Self {
99 Self::new()
100 }
101}
102
103#[tonic::async_trait]
104impl VectorIndexService for VectorIndexServer {
105 async fn create_index(
106 &self,
107 request: Request<CreateIndexRequest>,
108 ) -> Result<Response<CreateIndexResponse>, Status> {
109 let req = request.into_inner();
110 let name = req.name.clone();
111
112 if self.indexes.contains_key(&name) {
114 return Ok(Response::new(CreateIndexResponse {
115 success: false,
116 error: format!("Index '{}' already exists", name),
117 info: None,
118 }));
119 }
120
121 let proto_config = req.config.unwrap_or_default();
123 let config = HnswConfig {
124 max_connections: if proto_config.max_connections > 0 {
125 proto_config.max_connections as usize
126 } else {
127 16
128 },
129 max_connections_layer0: if proto_config.max_connections_layer0 > 0 {
130 proto_config.max_connections_layer0 as usize
131 } else {
132 32
133 },
134 ef_construction: if proto_config.ef_construction > 0 {
135 proto_config.ef_construction as usize
136 } else {
137 200
138 },
139 ef_search: if proto_config.ef_search > 0 {
140 proto_config.ef_search as usize
141 } else {
142 50
143 },
144 metric: Self::convert_metric(req.metric()),
145 ..Default::default()
146 };
147
148 let dimension = req.dimension as usize;
149 let index = HnswIndex::new(dimension, config.clone());
150 let created_at = std::time::SystemTime::now()
151 .duration_since(std::time::UNIX_EPOCH)
152 .unwrap()
153 .as_secs();
154
155 let entry = IndexEntry {
156 index: Arc::new(index),
157 name: name.clone(),
158 dimension,
159 metric: req.metric(),
160 config: proto_config.clone(),
161 created_at,
162 };
163
164 self.indexes.insert(name.clone(), entry);
165
166 tracing::info!("Created index '{}' with dimension {}", name, dimension);
167
168 Ok(Response::new(CreateIndexResponse {
169 success: true,
170 error: String::new(),
171 info: Some(IndexInfo {
172 name,
173 dimension: dimension as u32,
174 metric: req.metric.into(),
175 config: Some(proto_config),
176 created_at,
177 }),
178 }))
179 }
180
181 async fn drop_index(
182 &self,
183 request: Request<DropIndexRequest>,
184 ) -> Result<Response<DropIndexResponse>, Status> {
185 let name = request.into_inner().name;
186
187 match self.indexes.remove(&name) {
188 Some(_) => {
189 tracing::info!("Dropped index '{}'", name);
190 Ok(Response::new(DropIndexResponse {
191 success: true,
192 error: String::new(),
193 }))
194 }
195 None => Ok(Response::new(DropIndexResponse {
196 success: false,
197 error: format!("Index '{}' not found", name),
198 })),
199 }
200 }
201
202 async fn insert_batch(
203 &self,
204 request: Request<InsertBatchRequest>,
205 ) -> Result<Response<InsertBatchResponse>, Status> {
206 let start = Instant::now();
207 let req = request.into_inner();
208
209 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
210
211 if req.vectors.len() != req.ids.len() * dimension {
213 return Err(Status::invalid_argument(format!(
214 "Vector data size mismatch: expected {} floats, got {}",
215 req.ids.len() * dimension,
216 req.vectors.len()
217 )));
218 }
219
220 let ids: Vec<u128> = req.ids.iter().map(|&id| id as u128).collect();
222
223 match index.insert_batch_flat(&ids, &req.vectors, dimension) {
225 Ok(count) => {
226 let duration_us = start.elapsed().as_micros() as u64;
227 tracing::debug!(
228 "Inserted {} vectors into '{}' in {}µs",
229 count,
230 req.index_name,
231 duration_us
232 );
233 Ok(Response::new(InsertBatchResponse {
234 inserted_count: count as u32,
235 error: String::new(),
236 duration_us,
237 }))
238 }
239 Err(e) => Ok(Response::new(InsertBatchResponse {
240 inserted_count: 0,
241 error: e,
242 duration_us: start.elapsed().as_micros() as u64,
243 })),
244 }
245 }
246
247 async fn insert_stream(
248 &self,
249 request: Request<Streaming<InsertStreamRequest>>,
250 ) -> Result<Response<InsertStreamResponse>, Status> {
251 let start = Instant::now();
252 let mut stream = request.into_inner();
253
254 let mut index_name: Option<String> = None;
255 let mut index: Option<Arc<HnswIndex>> = None;
256 let mut total_inserted = 0u32;
257 let mut errors = Vec::new();
258
259 while let Some(result) = stream.next().await {
260 match result {
261 Ok(req) => {
262 if index.is_none() {
264 if req.index_name.is_empty() {
265 errors.push("First message must include index_name".to_string());
266 continue;
267 }
268 index_name = Some(req.index_name.clone());
269 match self.get_index(&req.index_name) {
270 Ok(idx) => index = Some(idx),
271 Err(e) => {
272 errors.push(e.to_string());
273 break;
274 }
275 }
276 }
277
278 if let Some(ref idx) = index {
280 let vector: Vec<f32> = req.vector;
281 match idx.insert_one_from_slice(req.id as u128, &vector) {
282 Ok(()) => total_inserted += 1,
283 Err(e) => errors.push(format!("ID {}: {}", req.id, e)),
284 }
285 }
286 }
287 Err(e) => {
288 errors.push(format!("Stream error: {}", e));
289 break;
290 }
291 }
292 }
293
294 let duration_us = start.elapsed().as_micros() as u64;
295
296 if let Some(name) = &index_name {
297 tracing::debug!(
298 "Stream inserted {} vectors into '{}' in {}µs",
299 total_inserted,
300 name,
301 duration_us
302 );
303 }
304
305 Ok(Response::new(InsertStreamResponse {
306 total_inserted,
307 errors,
308 duration_us,
309 }))
310 }
311
312 async fn search(
313 &self,
314 request: Request<SearchRequest>,
315 ) -> Result<Response<SearchResponse>, Status> {
316 let start = Instant::now();
317 let req = request.into_inner();
318
319 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
320
321 if req.query.len() != dimension {
323 return Err(Status::invalid_argument(format!(
324 "Query dimension mismatch: expected {}, got {}",
325 dimension,
326 req.query.len()
327 )));
328 }
329
330 let k = req.k.max(1) as usize;
331
332 let results = match index.search(&req.query, k) {
334 Ok(r) => r,
335 Err(e) => {
336 return Ok(Response::new(SearchResponse {
337 results: vec![],
338 duration_us: start.elapsed().as_micros() as u64,
339 error: e,
340 }));
341 }
342 };
343
344 let duration_us = start.elapsed().as_micros() as u64;
345
346 Ok(Response::new(SearchResponse {
347 results: results
348 .into_iter()
349 .map(|(id, distance)| SearchResult {
350 id: id as u64,
351 distance,
352 })
353 .collect(),
354 duration_us,
355 error: String::new(),
356 }))
357 }
358
359 async fn search_batch(
360 &self,
361 request: Request<SearchBatchRequest>,
362 ) -> Result<Response<SearchBatchResponse>, Status> {
363 let start = Instant::now();
364 let req = request.into_inner();
365
366 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
367 let num_queries = req.num_queries as usize;
368 let k = req.k.max(1) as usize;
369
370 if req.queries.len() != num_queries * dimension {
372 return Err(Status::invalid_argument(format!(
373 "Query data size mismatch: expected {} floats, got {}",
374 num_queries * dimension,
375 req.queries.len()
376 )));
377 }
378
379 let mut all_results = Vec::with_capacity(num_queries);
381
382 for i in 0..num_queries {
383 let query = &req.queries[i * dimension..(i + 1) * dimension];
384 let results = match index.search(query, k) {
385 Ok(r) => r,
386 Err(_) => vec![],
387 };
388
389 all_results.push(QueryResults {
390 results: results
391 .into_iter()
392 .map(|(id, distance)| SearchResult {
393 id: id as u64,
394 distance,
395 })
396 .collect(),
397 });
398 }
399
400 let duration_us = start.elapsed().as_micros() as u64;
401
402 Ok(Response::new(SearchBatchResponse {
403 results: all_results,
404 duration_us,
405 }))
406 }
407
408 async fn get_stats(
409 &self,
410 request: Request<GetStatsRequest>,
411 ) -> Result<Response<GetStatsResponse>, Status> {
412 let name = request.into_inner().index_name;
413
414 match self.indexes.get(&name) {
415 Some(entry) => {
416 let stats = entry.index.stats();
417 Ok(Response::new(GetStatsResponse {
418 stats: Some(IndexStats {
419 num_vectors: stats.num_vectors as u64,
420 dimension: entry.dimension as u32,
421 max_layer: stats.max_layer as u32,
422 memory_bytes: 0, avg_connections: stats.avg_connections,
424 }),
425 error: String::new(),
426 }))
427 }
428 None => Ok(Response::new(GetStatsResponse {
429 stats: None,
430 error: format!("Index '{}' not found", name),
431 })),
432 }
433 }
434
435 async fn health_check(
436 &self,
437 _request: Request<HealthCheckRequest>,
438 ) -> Result<Response<HealthCheckResponse>, Status> {
439 let indexes: Vec<String> = self.indexes.iter().map(|e| e.name.clone()).collect();
440
441 Ok(Response::new(HealthCheckResponse {
442 status: proto::health_check_response::Status::Serving.into(),
443 version: VERSION.to_string(),
444 indexes,
445 }))
446 }
447}