1use crate::error::GrpcError;
18use crate::proto::{
19 self,
20 vector_index_service_server::{VectorIndexService, VectorIndexServiceServer},
21 CreateIndexRequest, CreateIndexResponse, DropIndexRequest, DropIndexResponse,
22 GetStatsRequest, GetStatsResponse, HealthCheckRequest, HealthCheckResponse,
23 HnswConfig as ProtoHnswConfig, IndexInfo, IndexStats, InsertBatchRequest,
24 InsertBatchResponse, InsertStreamRequest, InsertStreamResponse, QueryResults,
25 SearchBatchRequest, SearchBatchResponse, SearchRequest, SearchResponse, SearchResult,
26};
27use dashmap::DashMap;
28use std::sync::Arc;
29use std::time::Instant;
30use tokio_stream::StreamExt;
31use tonic::{Request, Response, Status, Streaming};
32use toondb_index::hnsw::{DistanceMetric, HnswConfig, HnswIndex};
33
/// Crate version, taken from Cargo's `CARGO_PKG_VERSION` at compile time and
/// reported to clients by `health_check`.
const VERSION: &str = env!("CARGO_PKG_VERSION");
36
37#[allow(dead_code)]
39struct IndexEntry {
40 index: Arc<HnswIndex>,
41 name: String,
42 dimension: usize,
43 metric: proto::DistanceMetric,
44 config: ProtoHnswConfig,
45 created_at: u64,
46}
47
/// gRPC `VectorIndexService` implementation that keeps its indexes in a
/// process-local map keyed by index name.
pub struct VectorIndexServer {
    /// Live indexes by name. A `DashMap` is used so handlers, which only get
    /// `&self`, can read and modify the registry concurrently.
    indexes: DashMap<String, IndexEntry>,
}
53
54impl VectorIndexServer {
55 pub fn new() -> Self {
57 Self {
58 indexes: DashMap::new(),
59 }
60 }
61
62 pub fn into_service(self) -> VectorIndexServiceServer<Self> {
64 VectorIndexServiceServer::new(self)
65 }
66
67 fn get_index_with_dim(&self, name: &str) -> Result<(Arc<HnswIndex>, usize), GrpcError> {
69 self.indexes
70 .get(name)
71 .map(|entry| (entry.index.clone(), entry.dimension))
72 .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
73 }
74
75 fn get_index(&self, name: &str) -> Result<Arc<HnswIndex>, GrpcError> {
77 self.indexes
78 .get(name)
79 .map(|entry| entry.index.clone())
80 .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
81 }
82
83 fn convert_metric(metric: proto::DistanceMetric) -> DistanceMetric {
85 match metric {
86 proto::DistanceMetric::L2 => DistanceMetric::Euclidean,
87 proto::DistanceMetric::Cosine => DistanceMetric::Cosine,
88 proto::DistanceMetric::DotProduct => DistanceMetric::DotProduct,
89 _ => DistanceMetric::Cosine, }
91 }
92}
93
94impl Default for VectorIndexServer {
95 fn default() -> Self {
96 Self::new()
97 }
98}
99
100#[tonic::async_trait]
101impl VectorIndexService for VectorIndexServer {
102 async fn create_index(
103 &self,
104 request: Request<CreateIndexRequest>,
105 ) -> Result<Response<CreateIndexResponse>, Status> {
106 let req = request.into_inner();
107 let name = req.name.clone();
108
109 if self.indexes.contains_key(&name) {
111 return Ok(Response::new(CreateIndexResponse {
112 success: false,
113 error: format!("Index '{}' already exists", name),
114 info: None,
115 }));
116 }
117
118 let proto_config = req.config.unwrap_or_default();
120 let config = HnswConfig {
121 max_connections: if proto_config.max_connections > 0 {
122 proto_config.max_connections as usize
123 } else {
124 16
125 },
126 max_connections_layer0: if proto_config.max_connections_layer0 > 0 {
127 proto_config.max_connections_layer0 as usize
128 } else {
129 32
130 },
131 ef_construction: if proto_config.ef_construction > 0 {
132 proto_config.ef_construction as usize
133 } else {
134 200
135 },
136 ef_search: if proto_config.ef_search > 0 {
137 proto_config.ef_search as usize
138 } else {
139 50
140 },
141 metric: Self::convert_metric(req.metric()),
142 ..Default::default()
143 };
144
145 let dimension = req.dimension as usize;
146 let index = HnswIndex::new(dimension, config.clone());
147 let created_at = std::time::SystemTime::now()
148 .duration_since(std::time::UNIX_EPOCH)
149 .unwrap()
150 .as_secs();
151
152 let entry = IndexEntry {
153 index: Arc::new(index),
154 name: name.clone(),
155 dimension,
156 metric: req.metric(),
157 config: proto_config.clone(),
158 created_at,
159 };
160
161 self.indexes.insert(name.clone(), entry);
162
163 tracing::info!("Created index '{}' with dimension {}", name, dimension);
164
165 Ok(Response::new(CreateIndexResponse {
166 success: true,
167 error: String::new(),
168 info: Some(IndexInfo {
169 name,
170 dimension: dimension as u32,
171 metric: req.metric.into(),
172 config: Some(proto_config),
173 created_at,
174 }),
175 }))
176 }
177
178 async fn drop_index(
179 &self,
180 request: Request<DropIndexRequest>,
181 ) -> Result<Response<DropIndexResponse>, Status> {
182 let name = request.into_inner().name;
183
184 match self.indexes.remove(&name) {
185 Some(_) => {
186 tracing::info!("Dropped index '{}'", name);
187 Ok(Response::new(DropIndexResponse {
188 success: true,
189 error: String::new(),
190 }))
191 }
192 None => Ok(Response::new(DropIndexResponse {
193 success: false,
194 error: format!("Index '{}' not found", name),
195 })),
196 }
197 }
198
199 async fn insert_batch(
200 &self,
201 request: Request<InsertBatchRequest>,
202 ) -> Result<Response<InsertBatchResponse>, Status> {
203 let start = Instant::now();
204 let req = request.into_inner();
205
206 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
207
208 if req.vectors.len() != req.ids.len() * dimension {
210 return Err(Status::invalid_argument(format!(
211 "Vector data size mismatch: expected {} floats, got {}",
212 req.ids.len() * dimension,
213 req.vectors.len()
214 )));
215 }
216
217 let ids: Vec<u128> = req.ids.iter().map(|&id| id as u128).collect();
219
220 match index.insert_batch_flat(&ids, &req.vectors, dimension) {
222 Ok(count) => {
223 let duration_us = start.elapsed().as_micros() as u64;
224 tracing::debug!(
225 "Inserted {} vectors into '{}' in {}µs",
226 count,
227 req.index_name,
228 duration_us
229 );
230 Ok(Response::new(InsertBatchResponse {
231 inserted_count: count as u32,
232 error: String::new(),
233 duration_us,
234 }))
235 }
236 Err(e) => Ok(Response::new(InsertBatchResponse {
237 inserted_count: 0,
238 error: e,
239 duration_us: start.elapsed().as_micros() as u64,
240 })),
241 }
242 }
243
244 async fn insert_stream(
245 &self,
246 request: Request<Streaming<InsertStreamRequest>>,
247 ) -> Result<Response<InsertStreamResponse>, Status> {
248 let start = Instant::now();
249 let mut stream = request.into_inner();
250
251 let mut index_name: Option<String> = None;
252 let mut index: Option<Arc<HnswIndex>> = None;
253 let mut total_inserted = 0u32;
254 let mut errors = Vec::new();
255
256 while let Some(result) = stream.next().await {
257 match result {
258 Ok(req) => {
259 if index.is_none() {
261 if req.index_name.is_empty() {
262 errors.push("First message must include index_name".to_string());
263 continue;
264 }
265 index_name = Some(req.index_name.clone());
266 match self.get_index(&req.index_name) {
267 Ok(idx) => index = Some(idx),
268 Err(e) => {
269 errors.push(e.to_string());
270 break;
271 }
272 }
273 }
274
275 if let Some(ref idx) = index {
277 let vector: Vec<f32> = req.vector;
278 match idx.insert_one_from_slice(req.id as u128, &vector) {
279 Ok(()) => total_inserted += 1,
280 Err(e) => errors.push(format!("ID {}: {}", req.id, e)),
281 }
282 }
283 }
284 Err(e) => {
285 errors.push(format!("Stream error: {}", e));
286 break;
287 }
288 }
289 }
290
291 let duration_us = start.elapsed().as_micros() as u64;
292
293 if let Some(name) = &index_name {
294 tracing::debug!(
295 "Stream inserted {} vectors into '{}' in {}µs",
296 total_inserted,
297 name,
298 duration_us
299 );
300 }
301
302 Ok(Response::new(InsertStreamResponse {
303 total_inserted,
304 errors,
305 duration_us,
306 }))
307 }
308
309 async fn search(
310 &self,
311 request: Request<SearchRequest>,
312 ) -> Result<Response<SearchResponse>, Status> {
313 let start = Instant::now();
314 let req = request.into_inner();
315
316 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
317
318 if req.query.len() != dimension {
320 return Err(Status::invalid_argument(format!(
321 "Query dimension mismatch: expected {}, got {}",
322 dimension,
323 req.query.len()
324 )));
325 }
326
327 let k = req.k.max(1) as usize;
328
329 let results = match index.search(&req.query, k) {
331 Ok(r) => r,
332 Err(e) => {
333 return Ok(Response::new(SearchResponse {
334 results: vec![],
335 duration_us: start.elapsed().as_micros() as u64,
336 error: e,
337 }));
338 }
339 };
340
341 let duration_us = start.elapsed().as_micros() as u64;
342
343 Ok(Response::new(SearchResponse {
344 results: results
345 .into_iter()
346 .map(|(id, distance)| SearchResult {
347 id: id as u64,
348 distance,
349 })
350 .collect(),
351 duration_us,
352 error: String::new(),
353 }))
354 }
355
356 async fn search_batch(
357 &self,
358 request: Request<SearchBatchRequest>,
359 ) -> Result<Response<SearchBatchResponse>, Status> {
360 let start = Instant::now();
361 let req = request.into_inner();
362
363 let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
364 let num_queries = req.num_queries as usize;
365 let k = req.k.max(1) as usize;
366
367 if req.queries.len() != num_queries * dimension {
369 return Err(Status::invalid_argument(format!(
370 "Query data size mismatch: expected {} floats, got {}",
371 num_queries * dimension,
372 req.queries.len()
373 )));
374 }
375
376 let mut all_results = Vec::with_capacity(num_queries);
378
379 for i in 0..num_queries {
380 let query = &req.queries[i * dimension..(i + 1) * dimension];
381 let results = match index.search(query, k) {
382 Ok(r) => r,
383 Err(_) => vec![],
384 };
385
386 all_results.push(QueryResults {
387 results: results
388 .into_iter()
389 .map(|(id, distance)| SearchResult {
390 id: id as u64,
391 distance,
392 })
393 .collect(),
394 });
395 }
396
397 let duration_us = start.elapsed().as_micros() as u64;
398
399 Ok(Response::new(SearchBatchResponse {
400 results: all_results,
401 duration_us,
402 }))
403 }
404
405 async fn get_stats(
406 &self,
407 request: Request<GetStatsRequest>,
408 ) -> Result<Response<GetStatsResponse>, Status> {
409 let name = request.into_inner().index_name;
410
411 match self.indexes.get(&name) {
412 Some(entry) => {
413 let stats = entry.index.stats();
414 Ok(Response::new(GetStatsResponse {
415 stats: Some(IndexStats {
416 num_vectors: stats.num_vectors as u64,
417 dimension: entry.dimension as u32,
418 max_layer: stats.max_layer as u32,
419 memory_bytes: 0, avg_connections: stats.avg_connections,
421 }),
422 error: String::new(),
423 }))
424 }
425 None => Ok(Response::new(GetStatsResponse {
426 stats: None,
427 error: format!("Index '{}' not found", name),
428 })),
429 }
430 }
431
432 async fn health_check(
433 &self,
434 _request: Request<HealthCheckRequest>,
435 ) -> Result<Response<HealthCheckResponse>, Status> {
436 let indexes: Vec<String> = self.indexes.iter().map(|e| e.name.clone()).collect();
437
438 Ok(Response::new(HealthCheckResponse {
439 status: proto::health_check_response::Status::Serving.into(),
440 version: VERSION.to_string(),
441 indexes,
442 }))
443 }
444}