1use std::collections::HashMap;
22use std::sync::Arc;
23
24use parking_lot::RwLock;
25use serde::{Deserialize, Serialize};
26use thiserror::Error;
27
28use crate::distance::Distance;
29use crate::encoding::{Codec, EncodedVector, EncodingError};
30use crate::index::{HnswIndex, HnswParams, IndexError, NodeId, SearchResult};
31use crate::turbo_index::TurboTable;
32
/// Primary key of a row within a table: an arbitrary, caller-chosen byte string.
pub type RowKey = Vec<u8>;
35
/// Definition of a vector table: name, vector dimensionality, storage codec, distance
/// metric, and HNSW parameters.
///
/// `VectorStore::create_table` persists it via [`Backend::put_schema`], and
/// `VectorStore::open` replays every schema returned by [`Backend::list_schemas`].
///
/// NOTE: serde's derive serializes fields in declaration order; keep the order stable
/// if schemas are stored with an order-sensitive format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableSchema {
    /// Table name; also the key used for backend calls and the store's table map.
    pub name: String,
    /// Exact length every upserted vector and every search query must have.
    pub dim: u16,
    /// Codec used to encode vectors for storage. When `codec.turbovec_bits()` returns
    /// `Some`, the table is indexed with a `TurboTable` instead of an `HnswIndex`.
    pub codec: Codec,
    /// Distance metric handed to the ANN index.
    pub distance: Distance,
    /// HNSW parameters; only consulted when the table is backed by an `HnswIndex`.
    pub hnsw: HnswParams,
}
50
/// A stored row: the encoded vector plus caller metadata and timestamps.
///
/// NOTE: serde's derive serializes fields in declaration order; keep the order stable
/// if rows are stored with an order-sensitive format.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VectorRow {
    /// Key of the row within its table.
    pub key: RowKey,
    /// The vector, encoded with the table's codec.
    pub vector: EncodedVector,
    /// Arbitrary JSON metadata supplied on upsert.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Milliseconds since the Unix epoch when the key was first written; carried over
    /// unchanged when the same key is upserted again.
    pub created_at: u64,
    /// Milliseconds since the Unix epoch of the most recent upsert.
    pub updated_at: u64,
}
67
/// Errors returned by [`VectorStore`] and [`Backend`] operations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum StoreError {
    /// No open table has this name.
    #[error("table not found: {0}")]
    UnknownTable(String),
    /// A table with this name is already open.
    #[error("table already exists: {0}")]
    TableExists(String),
    /// A vector or query length differs from the table's configured `dim`.
    #[error("dimension mismatch: table {table} expects {expected}, got {got}")]
    DimensionMismatch {
        /// Table the vector was meant for.
        table: String,
        /// The table's configured dimension.
        expected: u16,
        /// Length of the supplied vector. `VectorStore` saturates lengths above
        /// `u16::MAX` to `u16::MAX`, since this field is a `u16`.
        got: u16,
    },
    /// A requested row does not exist.
    /// NOTE(review): not constructed by any code visible in this module.
    #[error("row not found in table {table}: {key:?}")]
    RowNotFound {
        /// Table that was searched.
        table: String,
        /// Key that was not found.
        key: RowKey,
    },
    /// Encoding or decoding a vector failed.
    #[error("encoding: {0}")]
    Encoding(#[from] EncodingError),
    /// The ANN index reported an error.
    #[error("index: {0}")]
    Index(#[from] IndexError),
    /// Storage-layer failure described by a message. `MemoryBackend` never produces it;
    /// presumably used by persistent `Backend` implementations.
    #[error("backend: {0}")]
    Backend(String),
}
106
/// Storage for rows and table schemas used by [`VectorStore`].
///
/// The store shares one backend as `Arc<dyn Backend>`, hence the `Send + Sync` bound.
pub trait Backend: Send + Sync {
    /// Inserts or replaces the row stored under `key` in `table`.
    fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError>;

    /// Returns the row stored under `key` in `table`, or `None` if there is none.
    fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError>;

    /// Removes the row stored under `key` in `table`. Returns `true` if a row was removed.
    fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError>;

    /// Invokes `f` for every row of `table`. An error returned by `f` is expected to abort
    /// the scan and be propagated (as `MemoryBackend` does).
    fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError>;

    /// Persists a table schema, overwriting any schema with the same name
    /// (as `MemoryBackend` does).
    fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError>;

    /// Returns every persisted table schema.
    fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError>;
}
136
/// Row callback for [`Backend::for_each_row`]: receives each row's key and value and may
/// return an error to stop the scan.
pub type RowVisitor<'a> = dyn FnMut(&[u8], &VectorRow) -> Result<(), StoreError> + 'a;
139
140#[derive(Default)]
144pub struct MemoryBackend {
145 rows: RwLock<HashMap<String, HashMap<Vec<u8>, VectorRow>>>,
146 schemas: RwLock<HashMap<String, TableSchema>>,
147}
148
149impl MemoryBackend {
150 #[must_use]
152 pub fn new() -> Self {
153 Self::default()
154 }
155}
156
157impl Backend for MemoryBackend {
158 fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError> {
159 let mut rows = self.rows.write();
160 let entry = rows.entry(table.to_string()).or_default();
161 entry.insert(key.to_vec(), row.clone());
162 Ok(())
163 }
164
165 fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
166 let rows = self.rows.read();
167 Ok(rows.get(table).and_then(|m| m.get(key).cloned()))
168 }
169
170 fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
171 let mut rows = self.rows.write();
172 Ok(rows.get_mut(table).is_some_and(|m| m.remove(key).is_some()))
173 }
174
175 fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError> {
176 let rows = self.rows.read();
177 if let Some(m) = rows.get(table) {
178 for (k, v) in m {
179 f(k, v)?;
180 }
181 }
182 Ok(())
183 }
184
185 fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError> {
186 self.schemas
187 .write()
188 .insert(schema.name.clone(), schema.clone());
189 Ok(())
190 }
191
192 fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError> {
193 Ok(self.schemas.read().values().cloned().collect())
194 }
195}
196
/// In-memory state of one open table: its schema, its ANN index, and the two-way mapping
/// between row keys and ANN node ids.
struct TableState {
    /// Schema the table was created (or rehydrated) with.
    schema: TableSchema,
    /// ANN index holding one node per row currently tracked in `key_to_node`.
    ann: AnnContainer,
    /// Row key -> node id currently representing that row in `ann`.
    key_to_node: HashMap<RowKey, NodeId>,
    /// Reverse of `key_to_node`; turns search hits back into row keys.
    node_to_key: HashMap<NodeId, RowKey>,
    /// Next node id to hand out. Ids are never reused: re-upserting a key allocates a
    /// fresh id and deletes the previous node.
    next_node_id: NodeId,
}
213
/// The ANN index backing a table. Which variant is used is decided by
/// `AnnContainer::new` from the table's codec.
enum AnnContainer {
    /// Graph index built from the schema's `HnswParams`.
    Hnsw(HnswIndex),
    /// Used when `schema.codec.turbovec_bits()` returns `Some`.
    Turbo(TurboTable),
}
224
225impl AnnContainer {
226 fn new(schema: &TableSchema) -> Result<Self, StoreError> {
227 if let Some(bits) = schema.codec.turbovec_bits() {
228 let table = TurboTable::new(schema.distance, schema.dim, bits)?;
229 Ok(Self::Turbo(table))
230 } else {
231 Ok(Self::Hnsw(HnswIndex::new(schema.distance, schema.hnsw)))
232 }
233 }
234
235 fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
236 match self {
237 Self::Hnsw(idx) => idx.insert(id, vector),
238 Self::Turbo(t) => t.insert(id, vector),
239 }
240 }
241
242 fn delete(&mut self, id: NodeId) -> bool {
243 match self {
244 Self::Hnsw(idx) => idx.delete(id),
245 Self::Turbo(t) => t.delete(id),
246 }
247 }
248
249 fn search(
250 &self,
251 query: &[f32],
252 k: usize,
253 ef: Option<usize>,
254 ) -> Result<Vec<SearchResult>, IndexError> {
255 match self {
256 Self::Hnsw(idx) => idx.search(query, k, ef),
257 Self::Turbo(t) => t.search(query, k, ef),
258 }
259 }
260
261 fn len(&self) -> usize {
262 match self {
263 Self::Hnsw(idx) => idx.len(),
264 Self::Turbo(t) => t.len(),
265 }
266 }
267}
268
/// A set of named vector tables. Rows and schemas live in a [`Backend`]; each table's ANN
/// index lives in memory and is rebuilt from the backend by `VectorStore::open`.
pub struct VectorStore {
    /// Row and schema storage shared with any other holder of the `Arc`.
    backend: Arc<dyn Backend>,
    /// Open tables. The outer lock guards only the map; `table_state` clones the per-table
    /// `Arc` and releases the map lock, so each table is then serialized by its own mutex.
    tables: RwLock<HashMap<String, Arc<parking_lot::Mutex<TableState>>>>,
}
274
275impl VectorStore {
276 pub fn open(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
289 let tables = RwLock::new(HashMap::new());
290 let store = Self { backend, tables };
291 let schemas = store.backend.list_schemas()?;
292 for schema in schemas {
293 store.rehydrate_table(&schema)?;
294 }
295 Ok(store)
296 }
297
298 #[must_use]
301 pub fn in_memory() -> Self {
302 Self {
303 backend: Arc::new(MemoryBackend::new()),
304 tables: RwLock::new(HashMap::new()),
305 }
306 }
307
308 pub fn create_table(&self, schema: TableSchema) -> Result<(), StoreError> {
315 let mut tables = self.tables.write();
316 if tables.contains_key(&schema.name) {
317 return Err(StoreError::TableExists(schema.name));
318 }
319 let state = TableState {
320 schema: schema.clone(),
321 ann: AnnContainer::new(&schema)?,
322 key_to_node: HashMap::new(),
323 node_to_key: HashMap::new(),
324 next_node_id: 1,
325 };
326 self.backend.put_schema(&schema)?;
327 tables.insert(
328 schema.name.clone(),
329 Arc::new(parking_lot::Mutex::new(state)),
330 );
331 Ok(())
332 }
333
334 pub fn tables(&self) -> Vec<TableSchema> {
336 self.tables
337 .read()
338 .values()
339 .map(|s| s.lock().schema.clone())
340 .collect()
341 }
342
343 pub fn upsert(
358 &self,
359 table: &str,
360 key: RowKey,
361 vector: &[f32],
362 metadata: HashMap<String, serde_json::Value>,
363 ) -> Result<(), StoreError> {
364 let state = self.table_state(table)?;
365 let mut state = state.lock();
366 let dim = u16::try_from(vector.len()).unwrap_or(u16::MAX);
367 if dim != state.schema.dim {
368 return Err(StoreError::DimensionMismatch {
369 table: table.to_string(),
370 expected: state.schema.dim,
371 got: dim,
372 });
373 }
374 let codec_encoder = state.schema.codec.encoder();
375 let encoded = codec_encoder.encode(vector)?;
376 let now = now_millis();
377 let prior = self.backend.get_row(table, &key)?;
378 let row = VectorRow {
379 key: key.clone(),
380 vector: encoded,
381 metadata,
382 created_at: prior.as_ref().map_or(now, |r| r.created_at),
383 updated_at: now,
384 };
385 self.backend.put_row(table, &key, &row)?;
386 if let Some(&old_node) = state.key_to_node.get(&key) {
387 state.ann.delete(old_node);
388 state.node_to_key.remove(&old_node);
389 }
390 let node_id = state.next_node_id;
391 state.next_node_id += 1;
392 state.ann.insert(node_id, vector.to_vec())?;
393 state.key_to_node.insert(key.clone(), node_id);
394 state.node_to_key.insert(node_id, key);
395 Ok(())
396 }
397
398 pub fn get(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
405 let _ = self.table_state(table)?;
406 self.backend.get_row(table, key)
407 }
408
409 pub fn delete(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
417 let state = self.table_state(table)?;
418 let mut state = state.lock();
419 let removed = self.backend.delete_row(table, key)?;
420 if let Some(node_id) = state.key_to_node.remove(key) {
421 state.ann.delete(node_id);
422 state.node_to_key.remove(&node_id);
423 }
424 Ok(removed)
425 }
426
427 pub fn search(
437 &self,
438 table: &str,
439 query: &[f32],
440 k: usize,
441 ef: Option<usize>,
442 ) -> Result<Vec<(VectorRow, f32)>, StoreError> {
443 let state = self.table_state(table)?;
444 let state = state.lock();
445 let dim = u16::try_from(query.len()).unwrap_or(u16::MAX);
446 if dim != state.schema.dim {
447 return Err(StoreError::DimensionMismatch {
448 table: table.to_string(),
449 expected: state.schema.dim,
450 got: dim,
451 });
452 }
453 let hits: Vec<SearchResult> = state.ann.search(query, k, ef)?;
454 let mut out = Vec::with_capacity(hits.len());
455 for hit in hits {
456 if let Some(key) = state.node_to_key.get(&hit.id) {
457 if let Some(row) = self.backend.get_row(table, key)? {
458 out.push((row, hit.score));
459 }
460 }
461 }
462 Ok(out)
463 }
464
465 pub fn stats(&self, table: &str) -> Result<TableStats, StoreError> {
472 let state = self.table_state(table)?;
473 let state = state.lock();
474 Ok(TableStats {
475 name: state.schema.name.clone(),
476 dim: state.schema.dim,
477 codec: state.schema.codec,
478 distance: state.schema.distance,
479 live_rows: state.ann.len(),
480 tracked_rows: state.key_to_node.len(),
481 })
482 }
483
484 fn table_state(&self, table: &str) -> Result<Arc<parking_lot::Mutex<TableState>>, StoreError> {
485 self.tables
486 .read()
487 .get(table)
488 .cloned()
489 .ok_or_else(|| StoreError::UnknownTable(table.to_string()))
490 }
491
492 fn rehydrate_table(&self, schema: &TableSchema) -> Result<(), StoreError> {
493 let state = TableState {
494 schema: schema.clone(),
495 ann: AnnContainer::new(schema)?,
496 key_to_node: HashMap::new(),
497 node_to_key: HashMap::new(),
498 next_node_id: 1,
499 };
500 let cell = Arc::new(parking_lot::Mutex::new(state));
501 self.tables
502 .write()
503 .insert(schema.name.clone(), cell.clone());
504 let mut guard = cell.lock();
505 let encoder = guard.schema.codec.encoder();
506 let mut to_insert: Vec<(NodeId, RowKey, Vec<f32>)> = Vec::new();
507 let table_name = schema.name.clone();
508 let mut next = 1u64;
509 self.backend.for_each_row(&table_name, &mut |k, row| {
510 let v = encoder.decode(&row.vector)?;
511 to_insert.push((next, k.to_vec(), v));
512 next += 1;
513 Ok(())
514 })?;
515 for (node, key, v) in to_insert {
516 guard.ann.insert(node, v)?;
517 guard.key_to_node.insert(key.clone(), node);
518 guard.node_to_key.insert(node, key);
519 guard.next_node_id = node + 1;
520 }
521 Ok(())
522 }
523}
524
/// Snapshot of a table's configuration and row counts, as returned by `VectorStore::stats`.
///
/// NOTE: serde's derive serializes fields in declaration order; keep the order stable.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableStats {
    /// Table name.
    pub name: String,
    /// Configured vector dimensionality.
    pub dim: u16,
    /// Configured storage codec.
    pub codec: Codec,
    /// Configured distance metric.
    pub distance: Distance,
    /// Entry count reported by the ANN index's `len()`. NOTE(review): whether deleted
    /// (tombstoned) nodes are counted depends on the index implementation; confirm.
    pub live_rows: usize,
    /// Number of row keys the store currently maps to ANN nodes.
    pub tracked_rows: usize,
}
541
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch, and saturates at
/// `u64::MAX` if the millisecond count does not fit.
fn now_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
549
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::HnswParams;

    /// Int8-quantized, Euclidean schema with default HNSW settings.
    fn schema(name: &str, dim: u16) -> TableSchema {
        TableSchema {
            name: name.to_owned(),
            dim,
            codec: Codec::Int8Quantized,
            distance: Distance::Euclidean,
            hnsw: HnswParams::default(),
        }
    }

    /// Fresh in-memory store holding a single table `t` of dimension `dim`.
    fn store_with_table(dim: u16) -> VectorStore {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", dim)).unwrap();
        store
    }

    #[test]
    fn create_and_list_tables() {
        let listed = store_with_table(4).tables();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.name, "t");
        assert_eq!(only.dim, 4);
    }

    #[test]
    fn duplicate_table_rejected() {
        let store = store_with_table(4);
        let second = store.create_table(schema("t", 4));
        assert!(matches!(second, Err(StoreError::TableExists(_))));
    }

    #[test]
    fn upsert_get_delete_round_trip() {
        let store = store_with_table(3);
        let key: &[u8] = b"a";
        store
            .upsert("t", key.to_vec(), &[1.0, 2.0, 3.0], HashMap::new())
            .unwrap();
        let fetched = store.get("t", key).unwrap().expect("row present");
        assert_eq!(fetched.key, b"a");
        assert_eq!(fetched.vector.dim, 3);
        assert!(store.delete("t", key).unwrap());
        assert!(store.get("t", key).unwrap().is_none());
        assert!(!store.delete("t", key).unwrap());
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let store = store_with_table(3);
        let outcome = store.upsert("t", b"a".to_vec(), &[1.0, 2.0], HashMap::new());
        assert!(matches!(outcome, Err(StoreError::DimensionMismatch { .. })));
    }

    #[test]
    fn search_returns_nearest_first() {
        let store = store_with_table(2);
        let points: [(&[u8], [f32; 2]); 4] = [
            (b"origin", [0.0, 0.0]),
            (b"unit_x", [1.0, 0.0]),
            (b"unit_y", [0.0, 1.0]),
            (b"diag", [1.0, 1.0]),
        ];
        for (key, vector) in points {
            store.upsert("t", key.to_vec(), &vector, HashMap::new()).unwrap();
        }
        let nearest = store.search("t", &[0.05, 0.05], 1, None).unwrap();
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0].0.key, b"origin");
    }

    #[test]
    fn rehydrate_rebuilds_index() {
        let backend = Arc::new(MemoryBackend::new());
        {
            let store = VectorStore::open(backend.clone()).unwrap();
            store.create_table(schema("t", 2)).unwrap();
            for i in 0..10_u8 {
                let x = f32::from(i);
                let key = format!("k{i}").into_bytes();
                store.upsert("t", key, &[x, x * 2.0], HashMap::new()).unwrap();
            }
        } // The first store is dropped here; only the backend survives.
        let reopened = VectorStore::open(backend).unwrap();
        assert_eq!(reopened.stats("t").unwrap().live_rows, 10);
        let nearest = reopened.search("t", &[3.0, 6.0], 1, None).unwrap();
        assert_eq!(nearest[0].0.key, b"k3");
    }

    #[test]
    fn stats_reports_live_rows() {
        let store = store_with_table(2);
        for (key, vector) in [(b"a".to_vec(), [1.0_f32, 2.0]), (b"b".to_vec(), [3.0, 4.0])] {
            store.upsert("t", key, &vector, HashMap::new()).unwrap();
        }
        let stats = store.stats("t").unwrap();
        assert_eq!(stats.live_rows, 2);
        assert_eq!(stats.tracked_rows, 2);
    }
}