use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::distance::Distance;
use crate::encoding::{Codec, EncodedVector, EncodingError};
use crate::index::{HnswIndex, HnswParams, IndexError, NodeId, SearchResult};
use crate::turbo_hnsw::TurboHnswIndex;
use crate::turbo_index::TurboTable;
/// Raw byte key identifying a row within a table.
pub type RowKey = Vec<u8>;
/// Index structure requested for a table.
///
/// Serialized in `snake_case`. The choice is only consulted when the table's codec reports a
/// turbovec bit width; for every other codec the store builds a plain HNSW index regardless
/// of this value.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum IndexAlgorithm {
    /// Graph-based index. This is the default, including when the field is absent from a
    /// serialized schema.
    #[default]
    Hnsw,
    /// Flat (non-graph) table.
    Flat,
}
/// Persistent description of a table: its name, vector shape, and index configuration.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableSchema {
/// Table name; used as the key in the store's table registry and in the backend.
pub name: String,
/// Number of components every vector and query for this table must have.
pub dim: u16,
/// Codec whose encoder turns `f32` vectors into the stored `EncodedVector` and back.
pub codec: Codec,
/// Distance function handed to the index on construction.
pub distance: Distance,
/// HNSW parameters handed to the graph-based index constructors.
pub hnsw: HnswParams,
/// Requested index structure; falls back to `IndexAlgorithm::default()` when the field is
/// missing from serialized input.
#[serde(default)]
pub algorithm: IndexAlgorithm,
}
/// One stored row: the encoded vector plus caller metadata and timestamps.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VectorRow {
/// Key the row was upserted under.
pub key: RowKey,
/// Vector as produced by the table codec's encoder.
pub vector: EncodedVector,
/// Arbitrary JSON metadata supplied by the caller on upsert.
pub metadata: HashMap<String, serde_json::Value>,
/// Milliseconds since the Unix epoch at which the key was first written; preserved across
/// later upserts of the same key.
pub created_at: u64,
/// Milliseconds since the Unix epoch of the most recent upsert.
pub updated_at: u64,
}
/// Errors returned by the store and by `Backend` implementations.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum StoreError {
/// No table with the given name is registered in the store.
#[error("table not found: {0}")]
UnknownTable(String),
/// `create_table` was called with a name that is already registered.
#[error("table already exists: {0}")]
TableExists(String),
/// A vector or query did not have the number of components the table expects.
#[error("dimension mismatch: table {table} expects {expected}, got {got}")]
DimensionMismatch {
table: String,
expected: u16,
got: u16,
},
/// A row was required but missing. Not constructed anywhere in this part of the file.
#[error("row not found in table {table}: {key:?}")]
RowNotFound {
table: String,
key: RowKey,
},
/// Failure reported by the codec while encoding or decoding a vector.
#[error("encoding: {0}")]
Encoding(#[from] EncodingError),
/// Failure reported by the underlying index.
#[error("index: {0}")]
Index(#[from] IndexError),
/// Free-form failure message; presumably for `Backend` implementations to report storage
/// errors (not constructed in this part of the file).
#[error("backend: {0}")]
Backend(String),
}
/// Storage interface for rows and table schemas.
///
/// The store holds a backend as `Arc<dyn Backend>`, hence the `Send + Sync` bound and the
/// `&self` receivers on every method.
pub trait Backend: Send + Sync {
/// Stores `row` under `key` in `table`. `MemoryBackend` overwrites any existing row and
/// creates the table's row map on first use.
fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError>;
/// Returns the row stored under `key` in `table`, or `None` when there is none.
fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError>;
/// Deletes the row under `key`. In `MemoryBackend` the returned flag is `true` only when a
/// row was actually removed.
fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError>;
/// Calls `f` with each row of `table`. `MemoryBackend` stops at, and returns, the first
/// error produced by `f`, and visits rows in hash-map order.
fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError>;
/// Persists `schema`; `MemoryBackend` keys it by `schema.name`, replacing any previous one.
fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError>;
/// Returns every persisted schema.
fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError>;
}
/// Callback type accepted by `Backend::for_each_row`: receives a row's key and the row itself,
/// and may return an error.
pub type RowVisitor<'a> = dyn FnMut(&[u8], &VectorRow) -> Result<(), StoreError> + 'a;
/// `Backend` implementation that keeps everything in process memory behind read-write locks.
#[derive(Default)]
pub struct MemoryBackend {
// Table name -> (row key -> row).
rows: RwLock<HashMap<String, HashMap<Vec<u8>, VectorRow>>>,
// Table name -> schema.
schemas: RwLock<HashMap<String, TableSchema>>,
}
impl MemoryBackend {
#[must_use]
pub fn new() -> Self {
Self::default()
}
}
impl Backend for MemoryBackend {
    /// Inserts or replaces the row under `key`, creating the table's row map if needed.
    fn put_row(&self, table: &str, key: &[u8], row: &VectorRow) -> Result<(), StoreError> {
        self.rows
            .write()
            .entry(table.to_string())
            .or_default()
            .insert(key.to_vec(), row.clone());
        Ok(())
    }

    /// Returns a clone of the row under `key`, or `None` if the table or the key is absent.
    fn get_row(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
        let tables = self.rows.read();
        let found = match tables.get(table) {
            Some(table_rows) => table_rows.get(key).cloned(),
            None => None,
        };
        Ok(found)
    }

    /// Removes the row under `key`; `true` only when a row was actually removed.
    fn delete_row(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
        let mut tables = self.rows.write();
        let removed = match tables.get_mut(table) {
            Some(table_rows) => table_rows.remove(key).is_some(),
            None => false,
        };
        Ok(removed)
    }

    /// Visits every row of `table` while holding the read lock; the first visitor error
    /// aborts the walk and is returned. An unknown table visits nothing.
    fn for_each_row(&self, table: &str, f: &mut RowVisitor<'_>) -> Result<(), StoreError> {
        let tables = self.rows.read();
        let Some(table_rows) = tables.get(table) else {
            return Ok(());
        };
        for (row_key, row) in table_rows {
            f(row_key, row)?;
        }
        Ok(())
    }

    /// Stores a clone of `schema` under its name, replacing any previous schema.
    fn put_schema(&self, schema: &TableSchema) -> Result<(), StoreError> {
        let mut schemas = self.schemas.write();
        schemas.insert(schema.name.clone(), schema.clone());
        Ok(())
    }

    /// Returns clones of all stored schemas, in hash-map order.
    fn list_schemas(&self) -> Result<Vec<TableSchema>, StoreError> {
        let schemas = self.schemas.read();
        let listed: Vec<TableSchema> = schemas.values().cloned().collect();
        Ok(listed)
    }
}
// In-memory state of one open table: its schema, its index, and the key <-> node id mapping.
struct TableState {
schema: TableSchema,
// Nearest-neighbour index holding the table's vectors.
ann: AnnContainer,
// Row key -> id of the index node currently holding that row's vector.
key_to_node: HashMap<RowKey, NodeId>,
// Inverse of `key_to_node`; used to map search hits back to row keys.
node_to_key: HashMap<NodeId, RowKey>,
// Id handed to the next inserted node. Starts at 1 and only ever grows, so an upsert of an
// existing key gets a fresh id rather than reusing the old one.
next_node_id: NodeId,
}
// Closed set of index implementations a table can be backed by; `AnnContainer::new` picks
// the variant from the schema's codec and algorithm.
enum AnnContainer {
// Plain HNSW index; used when the codec reports no turbovec bit width.
Hnsw(HnswIndex),
// Flat turbovec table; used for `IndexAlgorithm::Flat` with a turbovec codec.
TurboFlat(TurboTable),
// Turbovec HNSW indexes; the const parameter is the codec's bit width (2, 3 or 4).
TurboHnsw2(TurboHnswIndex<2>),
TurboHnsw3(TurboHnswIndex<3>),
TurboHnsw4(TurboHnswIndex<4>),
}
impl AnnContainer {
    /// Builds the index variant matching `schema`.
    ///
    /// A codec without a turbovec bit width always gets a plain [`HnswIndex`], whatever
    /// `schema.algorithm` says. With a turbovec codec, `Flat` selects [`TurboTable`] and
    /// `Hnsw` selects the [`TurboHnswIndex`] for 2, 3 or 4 bits; any other bit width is
    /// reported as `IndexError::Empty`.
    fn new(schema: &TableSchema) -> Result<Self, StoreError> {
        let Some(bits) = schema.codec.turbovec_bits() else {
            return Ok(Self::Hnsw(HnswIndex::new(schema.distance, schema.hnsw)));
        };
        let container = match (schema.algorithm, bits) {
            (IndexAlgorithm::Flat, _) => {
                Self::TurboFlat(TurboTable::new(schema.distance, schema.dim, bits)?)
            }
            (IndexAlgorithm::Hnsw, 2) => Self::TurboHnsw2(TurboHnswIndex::<2>::new(
                schema.distance,
                schema.dim,
                schema.hnsw,
            )?),
            (IndexAlgorithm::Hnsw, 3) => Self::TurboHnsw3(TurboHnswIndex::<3>::new(
                schema.distance,
                schema.dim,
                schema.hnsw,
            )?),
            (IndexAlgorithm::Hnsw, 4) => Self::TurboHnsw4(TurboHnswIndex::<4>::new(
                schema.distance,
                schema.dim,
                schema.hnsw,
            )?),
            (IndexAlgorithm::Hnsw, _) => return Err(StoreError::Index(IndexError::Empty)),
        };
        Ok(container)
    }

    /// Inserts `vector` under node id `id` into whichever index this container wraps.
    fn insert(&mut self, id: NodeId, vector: Vec<f32>) -> Result<(), IndexError> {
        match self {
            Self::Hnsw(graph) => graph.insert(id, vector),
            Self::TurboFlat(flat) => flat.insert(id, vector),
            Self::TurboHnsw2(turbo) => turbo.insert(id, vector),
            Self::TurboHnsw3(turbo) => turbo.insert(id, vector),
            Self::TurboHnsw4(turbo) => turbo.insert(id, vector),
        }
    }

    /// Deletes node `id` from the wrapped index and returns the flag that index reports.
    fn delete(&mut self, id: NodeId) -> bool {
        match self {
            Self::Hnsw(graph) => graph.delete(id),
            Self::TurboFlat(flat) => flat.delete(id),
            Self::TurboHnsw2(turbo) => turbo.delete(id),
            Self::TurboHnsw3(turbo) => turbo.delete(id),
            Self::TurboHnsw4(turbo) => turbo.delete(id),
        }
    }

    /// Forwards a nearest-neighbour query (`query`, `k`, optional `ef`) to the wrapped index.
    fn search(
        &self,
        query: &[f32],
        k: usize,
        ef: Option<usize>,
    ) -> Result<Vec<SearchResult>, IndexError> {
        match self {
            Self::Hnsw(graph) => graph.search(query, k, ef),
            Self::TurboFlat(flat) => flat.search(query, k, ef),
            Self::TurboHnsw2(turbo) => turbo.search(query, k, ef),
            Self::TurboHnsw3(turbo) => turbo.search(query, k, ef),
            Self::TurboHnsw4(turbo) => turbo.search(query, k, ef),
        }
    }

    /// Returns the length reported by the wrapped index.
    fn len(&self) -> usize {
        match self {
            Self::Hnsw(graph) => graph.len(),
            Self::TurboFlat(flat) => flat.len(),
            Self::TurboHnsw2(turbo) => turbo.len(),
            Self::TurboHnsw3(turbo) => turbo.len(),
            Self::TurboHnsw4(turbo) => turbo.len(),
        }
    }
}
/// Vector store façade: rows and schemas live in a `Backend`, while each table's index and
/// key/node mapping are kept in memory.
pub struct VectorStore {
// Storage for rows and schemas.
backend: Arc<dyn Backend>,
// Table name -> that table's in-memory state, each behind its own mutex.
tables: RwLock<HashMap<String, Arc<parking_lot::Mutex<TableState>>>>,
}
impl VectorStore {
/// Opens a store on `backend`, rebuilding the in-memory index of every table whose schema
/// the backend lists.
///
/// # Errors
///
/// Returns the first error from listing schemas or from rebuilding a table.
pub fn open(backend: Arc<dyn Backend>) -> Result<Self, StoreError> {
    let store = Self {
        backend,
        tables: RwLock::new(HashMap::new()),
    };
    for schema in store.backend.list_schemas()? {
        store.rehydrate_table(&schema)?;
    }
    Ok(store)
}
#[must_use]
pub fn in_memory() -> Self {
Self {
backend: Arc::new(MemoryBackend::new()),
tables: RwLock::new(HashMap::new()),
}
}
/// Registers a new, empty table and persists its schema in the backend.
///
/// The registry write lock is held for the whole call, so two concurrent creations of the
/// same name cannot both succeed.
///
/// # Errors
///
/// * [`StoreError::TableExists`] if a table with this name is already registered.
/// * Any error from building the index or from `Backend::put_schema`; in both cases the
///   table is not registered.
pub fn create_table(&self, schema: TableSchema) -> Result<(), StoreError> {
    let mut registry = self.tables.write();
    if registry.contains_key(&schema.name) {
        return Err(StoreError::TableExists(schema.name));
    }
    // Build the index before touching the backend so an unsupported schema is never persisted.
    let ann = AnnContainer::new(&schema)?;
    self.backend.put_schema(&schema)?;
    let name = schema.name.clone();
    let state = TableState {
        schema,
        ann,
        key_to_node: HashMap::new(),
        node_to_key: HashMap::new(),
        next_node_id: 1,
    };
    registry.insert(name, Arc::new(parking_lot::Mutex::new(state)));
    Ok(())
}
/// Returns a clone of the schema of every registered table, in hash-map order.
///
/// Each table's mutex is taken briefly to read its schema.
pub fn tables(&self) -> Vec<TableSchema> {
    let registry = self.tables.read();
    let mut schemas = Vec::with_capacity(registry.len());
    for cell in registry.values() {
        schemas.push(cell.lock().schema.clone());
    }
    schemas
}
/// Inserts or replaces the row stored under `key` in `table`.
///
/// The vector is encoded with the table's codec and written to the backend, then indexed
/// under a fresh node id; a previous node for the same key is removed from the index.
/// `created_at` is carried over from an existing backend row, `updated_at` is set to now.
///
/// # Errors
///
/// * [`StoreError::UnknownTable`] if `table` is not registered.
/// * [`StoreError::DimensionMismatch`] if `vector.len()` differs from the table's `dim`.
///   `got` is a `u16`, so lengths above `u16::MAX` are reported saturated.
/// * Encoding, backend and index errors are propagated. The backend row is written before
///   the index is updated, so if the index insert fails the row is persisted but not
///   indexed in this process.
pub fn upsert(
    &self,
    table: &str,
    key: RowKey,
    vector: &[f32],
    metadata: HashMap<String, serde_json::Value>,
) -> Result<(), StoreError> {
    let state = self.table_state(table)?;
    let mut state = state.lock();
    // Compare in `usize`. Saturating the length to `u16` first would let a vector longer
    // than `u16::MAX` pass the check for a table whose `dim` is exactly `u16::MAX`.
    if vector.len() != usize::from(state.schema.dim) {
        return Err(StoreError::DimensionMismatch {
            table: table.to_string(),
            expected: state.schema.dim,
            got: u16::try_from(vector.len()).unwrap_or(u16::MAX),
        });
    }
    let codec_encoder = state.schema.codec.encoder();
    let encoded = codec_encoder.encode(vector)?;
    let now = now_millis();
    let prior = self.backend.get_row(table, &key)?;
    let row = VectorRow {
        key: key.clone(),
        vector: encoded,
        metadata,
        created_at: prior.as_ref().map_or(now, |r| r.created_at),
        updated_at: now,
    };
    self.backend.put_row(table, &key, &row)?;
    // Unlink the old node from BOTH maps before inserting the replacement. If the insert
    // below fails, no stale key -> deleted-node mapping is left behind.
    if let Some(old_node) = state.key_to_node.remove(&key) {
        state.ann.delete(old_node);
        state.node_to_key.remove(&old_node);
    }
    let node_id = state.next_node_id;
    state.next_node_id += 1;
    state.ann.insert(node_id, vector.to_vec())?;
    state.key_to_node.insert(key.clone(), node_id);
    state.node_to_key.insert(node_id, key);
    Ok(())
}
/// Fetches the row stored under `key` from the backend.
///
/// # Errors
///
/// [`StoreError::UnknownTable`] if `table` is not registered; otherwise whatever the
/// backend returns.
pub fn get(&self, table: &str, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
    // The lookup only validates that the table exists; its state is not needed here.
    self.table_state(table)
        .and_then(|_| self.backend.get_row(table, key))
}
/// Deletes the row under `key` from the backend and drops its node from the index.
///
/// Returns the flag reported by `Backend::delete_row`.
///
/// # Errors
///
/// [`StoreError::UnknownTable`] if `table` is not registered, or the backend's error, in
/// which case the index is left untouched.
pub fn delete(&self, table: &str, key: &[u8]) -> Result<bool, StoreError> {
    let cell = self.table_state(table)?;
    let mut guard = cell.lock();
    let removed_from_backend = self.backend.delete_row(table, key)?;
    let unlinked = guard.key_to_node.remove(key);
    if let Some(node_id) = unlinked {
        guard.ann.delete(node_id);
        guard.node_to_key.remove(&node_id);
    }
    Ok(removed_from_backend)
}
/// Runs a nearest-neighbour query against `table` and resolves the hits to stored rows.
///
/// `k` and `ef` are forwarded to the index unchanged. Each returned pair is the backend row
/// and the score the index reported for it, in the order the index returned the hits. Hits
/// whose node id has no key mapping, or whose row is missing from the backend, are skipped.
///
/// # Errors
///
/// * [`StoreError::UnknownTable`] if `table` is not registered.
/// * [`StoreError::DimensionMismatch`] if `query.len()` differs from the table's `dim`.
///   `got` is a `u16`, so lengths above `u16::MAX` are reported saturated.
/// * Index and backend errors are propagated.
pub fn search(
    &self,
    table: &str,
    query: &[f32],
    k: usize,
    ef: Option<usize>,
) -> Result<Vec<(VectorRow, f32)>, StoreError> {
    let state = self.table_state(table)?;
    let state = state.lock();
    // Compare in `usize`. Saturating the length to `u16` first would let a query longer
    // than `u16::MAX` pass the check for a table whose `dim` is exactly `u16::MAX`.
    if query.len() != usize::from(state.schema.dim) {
        return Err(StoreError::DimensionMismatch {
            table: table.to_string(),
            expected: state.schema.dim,
            got: u16::try_from(query.len()).unwrap_or(u16::MAX),
        });
    }
    let hits: Vec<SearchResult> = state.ann.search(query, k, ef)?;
    let mut out = Vec::with_capacity(hits.len());
    for hit in hits {
        let Some(key) = state.node_to_key.get(&hit.id) else {
            continue;
        };
        if let Some(row) = self.backend.get_row(table, key)? {
            out.push((row, hit.score));
        }
    }
    Ok(out)
}
pub fn stats(&self, table: &str) -> Result<TableStats, StoreError> {
let state = self.table_state(table)?;
let state = state.lock();
Ok(TableStats {
name: state.schema.name.clone(),
dim: state.schema.dim,
codec: state.schema.codec,
distance: state.schema.distance,
live_rows: state.ann.len(),
tracked_rows: state.key_to_node.len(),
})
}
/// Looks up the shared state cell for `table`.
///
/// Only the registry read lock is taken, and only for the duration of the lookup; the
/// caller locks the returned cell itself.
///
/// # Errors
///
/// [`StoreError::UnknownTable`] if no table of that name is registered.
fn table_state(&self, table: &str) -> Result<Arc<parking_lot::Mutex<TableState>>, StoreError> {
    let registry = self.tables.read();
    let found = registry.get(table).cloned();
    found.ok_or_else(|| StoreError::UnknownTable(table.to_string()))
}
/// Rebuilds the in-memory state of one table from the rows stored in the backend.
///
/// Rows are first decoded into a buffer while the backend walks them, then inserted into a
/// fresh index with node ids assigned from 1 upward in visit order. The finished table is
/// registered only after every row has been indexed, so a failure never leaves a
/// half-populated table in the registry.
///
/// # Errors
///
/// Propagates errors from building the index, from the backend walk, from decoding a row
/// and from inserting into the index.
fn rehydrate_table(&self, schema: &TableSchema) -> Result<(), StoreError> {
    let mut state = TableState {
        schema: schema.clone(),
        ann: AnnContainer::new(schema)?,
        key_to_node: HashMap::new(),
        node_to_key: HashMap::new(),
        next_node_id: 1,
    };
    let encoder = schema.codec.encoder();
    // Decode into a buffer first so the index is built after the backend walk has finished.
    let mut pending: Vec<(RowKey, Vec<f32>)> = Vec::new();
    self.backend.for_each_row(&schema.name, &mut |row_key, row| {
        let decoded = encoder.decode(&row.vector)?;
        pending.push((row_key.to_vec(), decoded));
        Ok(())
    })?;
    for (row_key, decoded) in pending {
        let node_id = state.next_node_id;
        state.next_node_id += 1;
        state.ann.insert(node_id, decoded)?;
        state.key_to_node.insert(row_key.clone(), node_id);
        state.node_to_key.insert(node_id, row_key);
    }
    self.tables.write().insert(
        schema.name.clone(),
        Arc::new(parking_lot::Mutex::new(state)),
    );
    Ok(())
}
}
/// Snapshot of a table's configuration and row counts, as returned by `VectorStore::stats`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TableStats {
/// Table name, copied from the schema.
pub name: String,
/// Vector dimensionality, copied from the schema.
pub dim: u16,
/// Codec, copied from the schema.
pub codec: Codec,
/// Distance function, copied from the schema.
pub distance: Distance,
/// Length reported by the table's index.
pub live_rows: usize,
/// Number of keys in the table's key -> node mapping.
pub tracked_rows: usize,
}
/// Current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields 0 when the system clock reads earlier than the epoch, and saturates at `u64::MAX`
/// if the millisecond count does not fit in a `u64`.
fn now_millis() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::HnswParams;

    /// Int8-quantized, Euclidean, HNSW schema with default graph parameters.
    fn schema(name: &str, dim: u16) -> TableSchema {
        TableSchema {
            name: name.to_string(),
            dim,
            codec: Codec::Int8Quantized,
            distance: Distance::Euclidean,
            hnsw: HnswParams::default(),
            algorithm: IndexAlgorithm::Hnsw,
        }
    }

    /// Fresh in-memory store that already contains one table named `t`.
    fn store_with_table(dim: u16) -> VectorStore {
        let store = VectorStore::in_memory();
        store.create_table(schema("t", dim)).unwrap();
        store
    }

    #[test]
    fn create_and_list_tables() {
        let store = store_with_table(4);
        let listed = store.tables();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.name, "t");
        assert_eq!(only.dim, 4);
    }

    #[test]
    fn duplicate_table_rejected() {
        let store = store_with_table(4);
        let second_attempt = store.create_table(schema("t", 4));
        assert!(matches!(second_attempt, Err(StoreError::TableExists(_))));
    }

    #[test]
    fn upsert_get_delete_round_trip() {
        let store = store_with_table(3);
        store
            .upsert("t", b"a".to_vec(), &[1.0, 2.0, 3.0], HashMap::new())
            .unwrap();

        let fetched = store.get("t", b"a").unwrap().expect("row present");
        assert_eq!(fetched.key, b"a");
        assert_eq!(fetched.vector.dim, 3);

        // First delete removes the row; the second finds nothing left to remove.
        assert!(store.delete("t", b"a").unwrap());
        assert!(store.get("t", b"a").unwrap().is_none());
        assert!(!store.delete("t", b"a").unwrap());
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let store = store_with_table(3);
        let outcome = store.upsert("t", b"a".to_vec(), &[1.0, 2.0], HashMap::new());
        assert!(matches!(outcome, Err(StoreError::DimensionMismatch { .. })));
    }

    #[test]
    fn search_returns_nearest_first() {
        let store = store_with_table(2);
        let fixtures = [
            (&b"origin"[..], [0.0_f32, 0.0]),
            (&b"unit_x"[..], [1.0, 0.0]),
            (&b"unit_y"[..], [0.0, 1.0]),
            (&b"diag"[..], [1.0, 1.0]),
        ];
        for (key, point) in fixtures {
            store
                .upsert("t", key.to_vec(), &point, HashMap::new())
                .unwrap();
        }
        let nearest = store.search("t", &[0.05, 0.05], 1, None).unwrap();
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0].0.key, b"origin");
    }

    #[test]
    fn rehydrate_rebuilds_index() {
        let backend = Arc::new(MemoryBackend::new());
        {
            // Populate through a first store, then drop it so only the backend survives.
            let store = VectorStore::open(backend.clone()).unwrap();
            store.create_table(schema("t", 2)).unwrap();
            for i in 0..10_u8 {
                let coord = f32::from(i);
                let key = format!("k{i}").into_bytes();
                store
                    .upsert("t", key, &[coord, coord * 2.0], HashMap::new())
                    .unwrap();
            }
        }
        let reopened = VectorStore::open(backend).unwrap();
        assert_eq!(reopened.stats("t").unwrap().live_rows, 10);
        let hits = reopened.search("t", &[3.0, 6.0], 1, None).unwrap();
        assert_eq!(hits[0].0.key, b"k3");
    }

    #[test]
    fn stats_reports_live_rows() {
        let store = store_with_table(2);
        for (key, point) in [(b"a", [1.0_f32, 2.0]), (b"b", [3.0, 4.0])] {
            store
                .upsert("t", key.to_vec(), &point, HashMap::new())
                .unwrap();
        }
        let stats = store.stats("t").unwrap();
        assert_eq!(stats.live_rows, 2);
        assert_eq!(stats.tracked_rows, 2);
    }
}