1use std::collections::HashMap;
13use std::sync::Arc;
14
15use serde_json::Value;
16
17use crate::storage::{RowKey, StoreError, TableSchema, TableStats, VectorRow, VectorStore};
18
19#[derive(Clone)]
26pub struct Engine {
27 store: Arc<VectorStore>,
28 table: String,
29}
30
31impl std::fmt::Debug for Engine {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("Engine")
34 .field("table", &self.table)
35 .finish_non_exhaustive()
36 }
37}
38
39impl Engine {
40 pub fn in_memory(schema: TableSchema) -> Result<Self, StoreError> {
52 let store = Arc::new(VectorStore::in_memory());
53 let table = schema.name.clone();
54 store.create_table(schema)?;
55 Ok(Self { store, table })
56 }
57
58 #[must_use]
62 pub fn with_store(store: Arc<VectorStore>, table: String) -> Self {
63 Self { store, table }
64 }
65
66 #[must_use]
68 pub fn table_name(&self) -> &str {
69 &self.table
70 }
71
72 #[must_use]
74 pub fn store(&self) -> &Arc<VectorStore> {
75 &self.store
76 }
77
78 pub fn upsert(
85 &self,
86 key: RowKey,
87 vector: &[f32],
88 metadata: HashMap<String, Value>,
89 ) -> Result<(), StoreError> {
90 self.store.upsert(&self.table, key, vector, metadata)
91 }
92
93 pub fn get(&self, key: &[u8]) -> Result<Option<VectorRow>, StoreError> {
99 self.store.get(&self.table, key)
100 }
101
102 pub fn delete(&self, key: &[u8]) -> Result<bool, StoreError> {
109 self.store.delete(&self.table, key)
110 }
111
112 pub fn search(
119 &self,
120 query: &[f32],
121 k: usize,
122 ef: Option<usize>,
123 ) -> Result<Vec<(VectorRow, f32)>, StoreError> {
124 self.store.search(&self.table, query, k, ef)
125 }
126
127 pub fn stats(&self) -> Result<TableStats, StoreError> {
134 self.store.stats(&self.table)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distance::Distance;
    use crate::encoding::Codec;
    use crate::index::HnswParams;

    /// Builds a test schema with the given name and dimension, using the
    /// int8-quantized codec, Euclidean distance and default HNSW parameters.
    fn schema(name: &str, dim: u16) -> TableSchema {
        let name = name.to_string();
        let hnsw = HnswParams::default();
        TableSchema {
            name,
            dim,
            codec: Codec::Int8Quantized,
            distance: Distance::Euclidean,
            hnsw,
        }
    }

    /// A row written through the engine can be read back, and the table
    /// name and live-row count reflect it.
    #[test]
    fn engine_round_trips_a_row() {
        let engine = Engine::in_memory(schema("t", 3)).unwrap();

        let key = b"a".to_vec();
        let vector = [1.0_f32, 2.0, 3.0];
        engine.upsert(key, &vector, HashMap::new()).unwrap();

        let fetched = engine.get(b"a").unwrap();
        let row = fetched.expect("row present");
        assert_eq!(row.key, b"a");
        assert_eq!(row.vector.dim, 3);
        assert_eq!(engine.table_name(), "t");

        let stats = engine.stats().unwrap();
        assert_eq!(stats.live_rows, 1);
    }

    /// With three points stored, a query next to the origin returns the
    /// origin row as the single top hit.
    #[test]
    fn engine_search_returns_nearest_first() {
        let engine = Engine::in_memory(schema("t", 2)).unwrap();

        let points: [(&[u8], [f32; 2]); 3] = [
            (b"origin", [0.0, 0.0]),
            (b"unit_x", [1.0, 0.0]),
            (b"unit_y", [0.0, 1.0]),
        ];
        for (key, coords) in points {
            engine
                .upsert(key.to_vec(), &coords, HashMap::new())
                .unwrap();
        }

        let hits = engine.search(&[0.05, 0.05], 1, None).unwrap();
        assert_eq!(hits.len(), 1);
        let (nearest, _) = &hits[0];
        assert_eq!(nearest.key, b"origin");
    }
}