1use std::collections::HashMap;
22use std::sync::Arc;
23
24use parking_lot::RwLock;
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28use crate::distance::Distance;
29use crate::encoding::{Codec, EncodedVector, EncodingError};
30use crate::index::{HnswIndex, HnswParams, IndexError, NodeId, SearchResult};
31use crate::turbo_hnsw::TurboHnswIndex;
32use crate::turbo_index::TurboTable;
33
/// Key identifying a row within a table, stored as raw bytes.
pub type RowKey = Vec<u8>;
36
37#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
50#[serde(rename_all = "snake_case")]
51#[non_exhaustive]
52pub enum IndexAlgorithm {
53 Hnsw,
55 Flat,
57}
58
59impl Default for IndexAlgorithm {
60 fn default() -> Self {
61 Self::Hnsw
62 }
63}
64
/// Declarative description of a vector table.
///
/// Persisted through [`Backend::put_schema`] and used to rebuild the table's
/// in-memory index when a [`VectorStore`] is opened.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableSchema {
    /// Table name; must be unique within a store.
    pub name: String,
    /// Exact length every upserted or queried vector must have.
    pub dim: u16,
    /// Codec applied to vectors before they are persisted. A codec that reports a
    /// turbovec bit width also selects one of the turbovec index variants.
    pub codec: Codec,
    /// Distance metric used by the table's index.
    pub distance: Distance,
    /// HNSW parameters, passed to the HNSW-based index variants.
    pub hnsw: HnswParams,
    /// Index algorithm. Falls back to `IndexAlgorithm::default()` when the field is
    /// missing from serialized data, so such schemas still deserialize.
    #[serde(default)]
    pub algorithm: IndexAlgorithm,
}
86
/// A persisted row: the encoded vector plus caller metadata and timestamps.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VectorRow {
    /// The row's key (the same bytes the backend stores it under).
    pub key: RowKey,
    /// The vector, encoded with the table's codec.
    pub vector: EncodedVector,
    /// Arbitrary JSON metadata supplied on upsert.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Milliseconds since the Unix epoch when the key was first written.
    /// Carried over unchanged by later upserts of the same key.
    pub created_at: u64,
    /// Milliseconds since the Unix epoch of the most recent upsert.
    pub updated_at: u64,
}
103
/// Errors produced by [`VectorStore`] operations and [`Backend`] implementations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum StoreError {
    /// The named table is not registered in the store.
    #[error("table not found: {0}")]
    UnknownTable(String),
    /// A table with this name is already registered.
    #[error("table already exists: {0}")]
    TableExists(String),
    /// A supplied vector's length differs from the table's declared dimension.
    #[error("dimension mismatch: table {table} expects {expected}, got {got}")]
    DimensionMismatch {
        /// Name of the table the vector was checked against.
        table: String,
        /// Dimension declared in the table's schema.
        expected: u16,
        /// Length of the supplied vector, saturated to `u16::MAX` when it does not
        /// fit in a `u16`.
        got: u16,
    },
    /// No row with `key` exists in `table`.
    #[error("row not found in table {table}: {key:?}")]
    RowNotFound {
        /// Table that was looked up.
        table: String,
        /// Key that was not found.
        key: RowKey,
    },
    /// Encoding or decoding a vector with the table's codec failed.
    #[error("encoding: {0}")]
    Encoding(#[from] EncodingError),
    /// The ANN index rejected an operation. Also returned when a turbovec codec's
    /// bit width has no HNSW specialisation.
    #[error("index: {0}")]
    Index(#[from] IndexError),
    /// Failure reported by a storage backend, carried as a message.
    #[error("backend: {0}")]
    Backend(String),
}
142
/// Storage layer that persists table schemas and rows for a [`VectorStore`].
///
/// `Send + Sync` is required because the store holds the backend as an
/// `Arc<dyn Backend>` and calls it from `&self` methods.
pub trait Backend: Send + Sync {
    /// Writes `row` under `key` in `table`. `VectorStore::upsert` relies on this
    /// overwriting any row already stored under the same key.
    fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError>;

    /// Reads the row stored under `key` in `table`, or `None` if there is none.
    fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError>;

    /// Removes the row under `key` in `table`, returning `true` if one existed.
    fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError>;

    /// Calls `f` once per row of `table`; `VectorStore::open` uses this to rebuild
    /// indexes. An error returned by `f` is expected to stop the walk and be
    /// propagated, as `MemoryBackend` does.
    fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError>;

    /// Persists `schema` (the in-memory backend keys it by table name).
    fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError>;

    /// Returns every persisted schema; `VectorStore::open` rehydrates one table
    /// per entry.
    fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError>;
}
172
/// Callback given to [`Backend::for_each_row`]: receives each row's key and
/// contents; returning `Err` is intended to stop the iteration.
pub type RowVisitor<'a> = dyn FnMut(&[u8], &VectorRow) -> Result<(), StoreError> + 'a;
175
176#[derive(Default)]
180pub struct MemoryBackend {
181 rows: RwLock<HashMap<String, HashMap<Vec<u8>, VectorRow>>>,
182 schemas: RwLock<HashMap<String, TableSchema>>,
183}
184
185impl MemoryBackend {
186 #[must_use]
188 pub fn new() -> Self {
189 Self::default()
190 }
191}
192
193impl Backend for MemoryBackend {
194 fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError> {
195 let mut rows = self.rows.write();
196 let entry = rows.entry(table.to_string()).or_default();
197 entry.insert(key.to_vec(), row.clone());
198 Ok(())
199 }
200
201 fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
202 let rows = self.rows.read();
203 Ok(rows.get(table).and_then(|m| m.get(key).cloned()))
204 }
205
206 fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
207 let mut rows = self.rows.write();
208 Ok(rows.get_mut(table).is_some_and(|m| m.remove(key).is_some()))
209 }
210
211 fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError> {
212 let rows = self.rows.read();
213 if let Some(m) = rows.get(table) {
214 for (k, v) in m {
215 f(k, v)?;
216 }
217 }
218 Ok(())
219 }
220
221 fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError> {
222 self.schemas
223 .write()
224 .insert(schema.name.clone(), schema.clone());
225 Ok(())
226 }
227
228 fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError> {
229 Ok(self.schemas.read().values().cloned().collect())
230 }
231}
232
/// In-memory state of one table, held by [`VectorStore`] behind a per-table mutex.
struct TableState {
    /// The table's schema.
    schema: TableSchema,
    /// ANN index holding the table's vectors under internal node ids.
    ann: AnnContainer,
    /// Row key -> node id currently indexed for that key.
    key_to_node: HashMap<RowKey, NodeId>,
    /// Node id -> row key; used to map search hits back to stored rows.
    node_to_key: HashMap<NodeId, RowKey>,
    /// Next node id to hand out. Every upsert takes a fresh id, even when it
    /// replaces an existing key. Numbering restarts at 1 when the table is rebuilt
    /// on open.
    next_node_id: NodeId,
}
249
/// The concrete ANN index backing a table, chosen from its codec and algorithm.
///
/// Codecs without a turbovec bit width use [`HnswIndex`]. Turbovec codecs use
/// either a flat [`TurboTable`] or a [`TurboHnswIndex`] specialised for 2, 3 or
/// 4 bits.
enum AnnContainer {
    /// HNSW over `f32` vectors (every non-turbovec codec).
    Hnsw(HnswIndex),
    /// Turbovec codec with `IndexAlgorithm::Flat`.
    TurboFlat(TurboTable),
    /// Turbovec HNSW, 2-bit codes.
    TurboHnsw2(TurboHnswIndex<2>),
    /// Turbovec HNSW, 3-bit codes.
    TurboHnsw3(TurboHnswIndex<3>),
    /// Turbovec HNSW, 4-bit codes.
    TurboHnsw4(TurboHnswIndex<4>),
}
272
273impl AnnContainer {
274 fn new(schema: &TableSchema) -> Result<Self, StoreError> {
275 if let Some(bits) = schema.codec.turbovec_bits() {
276 match schema.algorithm {
277 IndexAlgorithm::Flat => {
278 let table = TurboTable::new(schema.distance, schema.dim, bits)?;
279 Ok(Self::TurboFlat(table))
280 }
281 IndexAlgorithm::Hnsw => match bits {
282 2 => Ok(Self::TurboHnsw2(TurboHnswIndex::<2>::new(
283 schema.distance,
284 schema.dim,
285 schema.hnsw,
286 )?)),
287 3 => Ok(Self::TurboHnsw3(TurboHnswIndex::<3>::new(
288 schema.distance,
289 schema.dim,
290 schema.hnsw,
291 )?)),
292 4 => Ok(Self::TurboHnsw4(TurboHnswIndex::<4>::new(
293 schema.distance,
294 schema.dim,
295 schema.hnsw,
296 )?)),
297 _ => Err(StoreError::Index(IndexError::Empty)),
298 },
299 }
300 } else {
301 Ok(Self::Hnsw(HnswIndex::new(schema.distance, schema.hnsw)))
305 }
306 }
307
308 fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
309 match self {
310 Self::Hnsw(idx) => idx.insert(id, vector),
311 Self::TurboFlat(t) => t.insert(id, vector),
312 Self::TurboHnsw2(t) => t.insert(id, vector),
313 Self::TurboHnsw3(t) => t.insert(id, vector),
314 Self::TurboHnsw4(t) => t.insert(id, vector),
315 }
316 }
317
318 fn delete(&mut self, id: NodeId) -> bool {
319 match self {
320 Self::Hnsw(idx) => idx.delete(id),
321 Self::TurboFlat(t) => t.delete(id),
322 Self::TurboHnsw2(t) => t.delete(id),
323 Self::TurboHnsw3(t) => t.delete(id),
324 Self::TurboHnsw4(t) => t.delete(id),
325 }
326 }
327
328 fn search(
329 &self,
330 query: &[f32],
331 k: usize,
332 ef: Option<usize>,
333 ) -> Result<Vec<SearchResult>, IndexError> {
334 match self {
335 Self::Hnsw(idx) => idx.search(query, k, ef),
336 Self::TurboFlat(t) => t.search(query, k, ef),
337 Self::TurboHnsw2(t) => t.search(query, k, ef),
338 Self::TurboHnsw3(t) => t.search(query, k, ef),
339 Self::TurboHnsw4(t) => t.search(query, k, ef),
340 }
341 }
342
343 fn len(&self) -> usize {
344 match self {
345 Self::Hnsw(idx) => idx.len(),
346 Self::TurboFlat(t) => t.len(),
347 Self::TurboHnsw2(t) => t.len(),
348 Self::TurboHnsw3(t) => t.len(),
349 Self::TurboHnsw4(t) => t.len(),
350 }
351 }
352}
353
/// A set of named vector tables over a pluggable [`Backend`].
///
/// Rows and schemas are stored through the backend. Each table's ANN index lives
/// only in memory and is rebuilt from the backend's rows by [`VectorStore::open`].
pub struct VectorStore {
    /// Storage for schemas and rows.
    backend: Arc<dyn Backend>,
    /// Registered tables by name. Each table has its own mutex, and the map lock
    /// is released before that mutex is taken, so different tables do not
    /// serialize on each other.
    tables: RwLock<HashMap<String, Arc<parking_lot::Mutex<TableState>>>>,
}
359
360impl VectorStore {
361 pub fn open(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
374 let tables = RwLock::new(HashMap::new());
375 let store = Self { backend, tables };
376 let schemas = store.backend.list_schemas()?;
377 for schema in schemas {
378 store.rehydrate_table(&schema)?;
379 }
380 Ok(store)
381 }
382
383 #[must_use]
386 pub fn in_memory() -> Self {
387 Self {
388 backend: Arc::new(MemoryBackend::new()),
389 tables: RwLock::new(HashMap::new()),
390 }
391 }
392
393 pub fn create_table(&self, schema: TableSchema) -> Result<(), StoreError> {
400 let mut tables = self.tables.write();
401 if tables.contains_key(&schema.name) {
402 return Err(StoreError::TableExists(schema.name));
403 }
404 let state = TableState {
405 schema: schema.clone(),
406 ann: AnnContainer::new(&schema)?,
407 key_to_node: HashMap::new(),
408 node_to_key: HashMap::new(),
409 next_node_id: 1,
410 };
411 self.backend.put_schema(&schema)?;
412 tables.insert(
413 schema.name.clone(),
414 Arc::new(parking_lot::Mutex::new(state)),
415 );
416 Ok(())
417 }
418
419 pub fn tables(&self) -> Vec<TableSchema> {
421 self.tables
422 .read()
423 .values()
424 .map(|s| s.lock().schema.clone())
425 .collect()
426 }
427
428 pub fn upsert(
443 &self,
444 table: &str,
445 key: RowKey,
446 vector: &[f32],
447 metadata: HashMap<String, serde_json::Value>,
448 ) -> Result<(), StoreError> {
449 let state = self.table_state(table)?;
450 let mut state = state.lock();
451 let dim = u16::try_from(vector.len()).unwrap_or(u16::MAX);
452 if dim != state.schema.dim {
453 return Err(StoreError::DimensionMismatch {
454 table: table.to_string(),
455 expected: state.schema.dim,
456 got: dim,
457 });
458 }
459 let codec_encoder = state.schema.codec.encoder();
460 let encoded = codec_encoder.encode(vector)?;
461 let now = now_millis();
462 let prior = self.backend.get_row(table, &key)?;
463 let row = VectorRow {
464 key: key.clone(),
465 vector: encoded,
466 metadata,
467 created_at: prior.as_ref().map_or(now, |r| r.created_at),
468 updated_at: now,
469 };
470 self.backend.put_row(table, &key, &row)?;
471 if let Some(&old_node) = state.key_to_node.get(&key) {
472 state.ann.delete(old_node);
473 state.node_to_key.remove(&old_node);
474 }
475 let node_id = state.next_node_id;
476 state.next_node_id += 1;
477 state.ann.insert(node_id, vector.to_vec())?;
478 state.key_to_node.insert(key.clone(), node_id);
479 state.node_to_key.insert(node_id, key);
480 Ok(())
481 }
482
483 pub fn get(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
490 let _ = self.table_state(table)?;
491 self.backend.get_row(table, key)
492 }
493
494 pub fn delete(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
502 let state = self.table_state(table)?;
503 let mut state = state.lock();
504 let removed = self.backend.delete_row(table, key)?;
505 if let Some(node_id) = state.key_to_node.remove(key) {
506 state.ann.delete(node_id);
507 state.node_to_key.remove(&node_id);
508 }
509 Ok(removed)
510 }
511
512 pub fn search(
522 &self,
523 table: &str,
524 query: &[f32],
525 k: usize,
526 ef: Option<usize>,
527 ) -> Result<Vec<(VectorRow, f32)>, StoreError> {
528 let state = self.table_state(table)?;
529 let state = state.lock();
530 let dim = u16::try_from(query.len()).unwrap_or(u16::MAX);
531 if dim != state.schema.dim {
532 return Err(StoreError::DimensionMismatch {
533 table: table.to_string(),
534 expected: state.schema.dim,
535 got: dim,
536 });
537 }
538 let hits: Vec<SearchResult> = state.ann.search(query, k, ef)?;
539 let mut out = Vec::with_capacity(hits.len());
540 for hit in hits {
541 if let Some(key) = state.node_to_key.get(&hit.id) {
542 if let Some(row) = self.backend.get_row(table, key)? {
543 out.push((row, hit.score));
544 }
545 }
546 }
547 Ok(out)
548 }
549
550 pub fn stats(&self, table: &str) -> Result<TableStats, StoreError> {
557 let state = self.table_state(table)?;
558 let state = state.lock();
559 Ok(TableStats {
560 name: state.schema.name.clone(),
561 dim: state.schema.dim,
562 codec: state.schema.codec,
563 distance: state.schema.distance,
564 live_rows: state.ann.len(),
565 tracked_rows: state.key_to_node.len(),
566 })
567 }
568
569 fn table_state(&self, table: &str) -> Result<Arc<parking_lot::Mutex<TableState>>, StoreError> {
570 self.tables
571 .read()
572 .get(table)
573 .cloned()
574 .ok_or_else(|| StoreError::UnknownTable(table.to_string()))
575 }
576
577 fn rehydrate_table(&self, schema: &TableSchema) -> Result<(), StoreError> {
578 let state = TableState {
579 schema: schema.clone(),
580 ann: AnnContainer::new(schema)?,
581 key_to_node: HashMap::new(),
582 node_to_key: HashMap::new(),
583 next_node_id: 1,
584 };
585 let cell = Arc::new(parking_lot::Mutex::new(state));
586 self.tables
587 .write()
588 .insert(schema.name.clone(), cell.clone());
589 let mut guard = cell.lock();
590 let encoder = guard.schema.codec.encoder();
591 let mut to_insert: Vec<(NodeId, RowKey, Vec<f32>)> = Vec::new();
592 let table_name = schema.name.clone();
593 let mut next = 1u64;
594 self.backend.for_each_row(&table_name, &mut |k, row| {
595 let v = encoder.decode(&row.vector)?;
596 to_insert.push((next, k.to_vec(), v));
597 next += 1;
598 Ok(())
599 })?;
600 for (node, key, v) in to_insert {
601 guard.ann.insert(node, v)?;
602 guard.key_to_node.insert(key.clone(), node);
603 guard.node_to_key.insert(node, key);
604 guard.next_node_id = node + 1;
605 }
606 Ok(())
607 }
608}
609
/// Snapshot of a table's configuration and size, as returned by [`VectorStore::stats`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableStats {
    /// Table name.
    pub name: String,
    /// Declared vector dimension.
    pub dim: u16,
    /// Codec used to encode persisted vectors.
    pub codec: Codec,
    /// Distance metric of the table's index.
    pub distance: Distance,
    /// Entry count reported by the ANN index's `len()`.
    pub live_rows: usize,
    /// Number of row keys currently mapped to an index node.
    pub tracked_rows: usize,
}
626
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch, and saturates at
/// `u64::MAX` if the millisecond count does not fit in a `u64`.
fn now_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
634
/// Unit tests exercising `VectorStore` against the in-memory backend.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::HnswParams;

    /// Builds an Int8-quantized, Euclidean, HNSW schema with default HNSW params.
    fn schema(name: &str, dim: u16) -> TableSchema {
        TableSchema {
            name: name.to_string(),
            dim,
            codec: Codec::Int8Quantized,
            distance: Distance::Euclidean,
            hnsw: HnswParams::default(),
            algorithm: IndexAlgorithm::Hnsw,
        }
    }

    /// A created table is listed by `tables()` with its name and dimension.
    #[test]
    fn create_and_list_tables() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 4)).unwrap();
        let tables = store.tables();
        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0].name, "t");
        assert_eq!(tables[0].dim, 4);
    }

    /// Creating a second table with the same name fails with `TableExists`.
    #[test]
    fn duplicate_table_rejected() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 4)).unwrap();
        assert!(matches!(
            store.create_table(schema("t", 4)),
            Err(StoreError::TableExists(_))
        ));
    }

    /// Upsert, get and delete round-trip; deleting twice reports `false` the second time.
    #[test]
    fn upsert_get_delete_round_trip() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 3)).unwrap();
        store
            .upsert("t", b"a".to_vec(), &[1.0, 2.0, 3.0], HashMap::new())
            .unwrap();
        let row = store.get("t", b"a").unwrap().expect("row present");
        assert_eq!(row.key, b"a");
        assert_eq!(row.vector.dim, 3);
        assert!(store.delete("t", b"a").unwrap());
        assert!(store.get("t", b"a").unwrap().is_none());
        assert!(!store.delete("t", b"a").unwrap());
    }

    /// A vector shorter than the table's dimension is rejected.
    #[test]
    fn dimension_mismatch_rejected() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 3)).unwrap();
        assert!(matches!(
            store.upsert("t", b"a".to_vec(), &[1.0, 2.0], HashMap::new()),
            Err(StoreError::DimensionMismatch { .. })
        ));
    }

    /// The single nearest neighbour of a point close to the origin is the origin row.
    #[test]
    fn search_returns_nearest_first() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 2)).unwrap();
        for (k, v) in [
            (&b"origin"[..], [0.0_f32, 0.0]),
            (&b"unit_x"[..], [1.0, 0.0]),
            (&b"unit_y"[..], [0.0, 1.0]),
            (&b"diag"[..], [1.0, 1.0]),
        ] {
            store.upsert("t", k.to_vec(), &v, HashMap::new()).unwrap();
        }
        let res = store.search("t", &[0.05, 0.05], 1, None).unwrap();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0.key, b"origin");
    }

    /// Reopening over the same backend rebuilds the index from the persisted rows.
    #[test]
    fn rehydrate_rebuilds_index() {
        let backend = Arc::new(MemoryBackend::new());
        let store = VectorStore::open(backend.clone()).unwrap();
        store.create_table(schema("t", 2)).unwrap();
        for i in 0..10_u8 {
            let k = format!("k{i}").into_bytes();
            let v = [f32::from(i), f32::from(i) * 2.0];
            store.upsert("t", k, &v, HashMap::new()).unwrap();
        }
        drop(store);
        // Only the backend survives; the new store must rebuild its index from it.
        let reopened = VectorStore::open(backend).unwrap();
        let stats = reopened.stats("t").unwrap();
        assert_eq!(stats.live_rows, 10);
        let res = reopened.search("t", &[3.0, 6.0], 1, None).unwrap();
        assert_eq!(res[0].0.key, b"k3");
    }

    /// After two distinct upserts, the index and the key map each report two entries.
    #[test]
    fn stats_reports_live_rows() {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", 2)).unwrap();
        store
            .upsert("t", b"a".to_vec(), &[1.0, 2.0], HashMap::new())
            .unwrap();
        store
            .upsert("t", b"b".to_vec(), &[3.0, 4.0], HashMap::new())
            .unwrap();
        let s = store.stats("t").unwrap();
        assert_eq!(s.live_rows, 2);
        assert_eq!(s.tracked_rows, 2);
    }
}