1use crate::{
7 DistanceMetric, SearchResult, VectorError, VectorRecord, VectorResult, VectorStore,
8};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::RwLock;
12
13pub struct InMemoryStore {
44 records: RwLock<HashMap<String, VectorRecord>>,
45 dimension: usize,
46 metric: DistanceMetric,
47}
48
49impl InMemoryStore {
50 pub fn new(dimension: usize) -> Self {
54 Self {
55 records: RwLock::new(HashMap::new()),
56 dimension,
57 metric: DistanceMetric::Cosine,
58 }
59 }
60
61 pub fn with_config(dimension: usize, metric: DistanceMetric) -> Self {
63 Self {
64 records: RwLock::new(HashMap::new()),
65 dimension,
66 metric,
67 }
68 }
69
70 fn compute_score(&self, a: &[f32], b: &[f32]) -> f32 {
72 match self.metric {
73 DistanceMetric::Cosine => crate::util::cosine_similarity(a, b),
74 DistanceMetric::Euclidean => {
75 let dist = crate::util::euclidean_distance(a, b);
77 1.0 / (1.0 + dist)
78 }
79 DistanceMetric::DotProduct => crate::util::dot_product(a, b),
80 }
81 }
82}
83
84#[async_trait]
85impl VectorStore for InMemoryStore {
86 fn name(&self) -> &str {
87 "in-memory"
88 }
89
90 fn dimension(&self) -> usize {
91 self.dimension
92 }
93
94 fn metric(&self) -> DistanceMetric {
95 self.metric
96 }
97
98 async fn upsert(&self, records: Vec<VectorRecord>) -> VectorResult<()> {
99 let mut store = self.records.write().map_err(|e| {
100 VectorError::Connection(format!("Failed to acquire write lock: {}", e))
101 })?;
102
103 for record in records {
104 if record.vector.len() != self.dimension {
105 return Err(VectorError::DimensionMismatch {
106 expected: self.dimension,
107 actual: record.vector.len(),
108 });
109 }
110 store.insert(record.id.clone(), record);
111 }
112
113 Ok(())
114 }
115
116 async fn search(&self, vector: &[f32], k: usize) -> VectorResult<Vec<SearchResult>> {
117 if vector.len() != self.dimension {
118 return Err(VectorError::DimensionMismatch {
119 expected: self.dimension,
120 actual: vector.len(),
121 });
122 }
123
124 let store = self.records.read().map_err(|e| {
125 VectorError::Connection(format!("Failed to acquire read lock: {}", e))
126 })?;
127
128 let mut scored: Vec<_> = store
130 .values()
131 .map(|record| {
132 let score = self.compute_score(vector, &record.vector);
133 (record, score)
134 })
135 .collect();
136
137 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
139
140 let results = scored
142 .into_iter()
143 .take(k)
144 .map(|(record, score)| SearchResult {
145 id: record.id.clone(),
146 score,
147 vector: Some(record.vector.clone()),
148 metadata: record.metadata.clone(),
149 })
150 .collect();
151
152 Ok(results)
153 }
154
155 async fn search_with_filter(
156 &self,
157 vector: &[f32],
158 k: usize,
159 filter: &HashMap<String, serde_json::Value>,
160 ) -> VectorResult<Vec<SearchResult>> {
161 if vector.len() != self.dimension {
162 return Err(VectorError::DimensionMismatch {
163 expected: self.dimension,
164 actual: vector.len(),
165 });
166 }
167
168 let store = self.records.read().map_err(|e| {
169 VectorError::Connection(format!("Failed to acquire read lock: {}", e))
170 })?;
171
172 let mut scored: Vec<_> = store
174 .values()
175 .filter(|record| {
176 filter.iter().all(|(key, value)| {
178 record.metadata.get(key).map_or(false, |v| v == value)
179 })
180 })
181 .map(|record| {
182 let score = self.compute_score(vector, &record.vector);
183 (record, score)
184 })
185 .collect();
186
187 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189
190 let results = scored
192 .into_iter()
193 .take(k)
194 .map(|(record, score)| SearchResult {
195 id: record.id.clone(),
196 score,
197 vector: Some(record.vector.clone()),
198 metadata: record.metadata.clone(),
199 })
200 .collect();
201
202 Ok(results)
203 }
204
205 async fn get(&self, id: &str) -> VectorResult<Option<VectorRecord>> {
206 let store = self.records.read().map_err(|e| {
207 VectorError::Connection(format!("Failed to acquire read lock: {}", e))
208 })?;
209
210 Ok(store.get(id).cloned())
211 }
212
213 async fn get_batch(&self, ids: &[&str]) -> VectorResult<Vec<VectorRecord>> {
214 let store = self.records.read().map_err(|e| {
215 VectorError::Connection(format!("Failed to acquire read lock: {}", e))
216 })?;
217
218 let records: Vec<_> = ids
219 .iter()
220 .filter_map(|id| store.get(*id).cloned())
221 .collect();
222
223 Ok(records)
224 }
225
226 async fn delete(&self, id: &str) -> VectorResult<()> {
227 let mut store = self.records.write().map_err(|e| {
228 VectorError::Connection(format!("Failed to acquire write lock: {}", e))
229 })?;
230
231 store.remove(id);
232 Ok(())
233 }
234
235 async fn delete_batch(&self, ids: &[&str]) -> VectorResult<()> {
236 let mut store = self.records.write().map_err(|e| {
237 VectorError::Connection(format!("Failed to acquire write lock: {}", e))
238 })?;
239
240 for id in ids {
241 store.remove(*id);
242 }
243
244 Ok(())
245 }
246
247 async fn count(&self) -> VectorResult<usize> {
248 let store = self.records.read().map_err(|e| {
249 VectorError::Connection(format!("Failed to acquire read lock: {}", e))
250 })?;
251
252 Ok(store.len())
253 }
254
255 async fn clear(&self) -> VectorResult<()> {
256 let mut store = self.records.write().map_err(|e| {
257 VectorError::Connection(format!("Failed to acquire write lock: {}", e))
258 })?;
259
260 store.clear();
261 Ok(())
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = InMemoryStore::new(3);
        let records = vec![
            VectorRecord::new("a", vec![1.0, 0.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0, 0.0]),
            VectorRecord::new("c", vec![0.7, 0.7, 0.0]),
        ];
        store.upsert(records).await.unwrap();

        let hits = store.search(&[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits.len(), 2);

        // The query equals record "a", so it must rank first with cosine 1.
        let best = &hits[0];
        assert_eq!(best.id, "a");
        assert!((best.score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_get_and_delete() {
        let store = InMemoryStore::new(2);
        store
            .upsert(vec![VectorRecord::new("x", vec![1.0, 0.0])])
            .await
            .unwrap();

        let fetched = store.get("x").await.unwrap();
        assert_eq!(fetched.as_ref().map(|r| r.id.as_str()), Some("x"));

        store.delete("x").await.unwrap();
        assert!(store.get("x").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let store = InMemoryStore::new(2);
        let tagged = [("a", "doc"), ("b", "query"), ("c", "doc")];
        let records = tagged
            .iter()
            .map(|(id, kind)| VectorRecord::new(*id, vec![1.0, 0.0]).with_metadata("type", *kind))
            .collect();
        store.upsert(records).await.unwrap();

        let filter: HashMap<String, serde_json::Value> =
            std::iter::once(("type".to_string(), serde_json::json!("doc"))).collect();

        let hits = store.search_with_filter(&[1.0, 0.0], 10, &filter).await.unwrap();
        assert_eq!(hits.len(), 2);
        assert!(!hits.iter().any(|hit| hit.id == "b"));
    }

    #[tokio::test]
    async fn test_count_and_clear() {
        let store = InMemoryStore::new(2);
        let records = vec![
            VectorRecord::new("a", vec![1.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0]),
        ];
        store.upsert(records).await.unwrap();

        let before = store.count().await.unwrap();
        assert_eq!(before, 2);

        store.clear().await.unwrap();
        let after = store.count().await.unwrap();
        assert_eq!(after, 0);
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let store = InMemoryStore::new(3);
        let too_short = VectorRecord::new("a", vec![1.0, 0.0]);

        let outcome = store.upsert(vec![too_short]).await;
        assert!(matches!(outcome, Err(VectorError::DimensionMismatch { .. })));
    }

    #[tokio::test]
    async fn test_euclidean_metric() {
        let store = InMemoryStore::with_config(2, DistanceMetric::Euclidean);
        let records = vec![
            VectorRecord::new("close", vec![0.1, 0.0]),
            VectorRecord::new("far", vec![10.0, 0.0]),
        ];
        store.upsert(records).await.unwrap();

        let hits = store.search(&[0.0, 0.0], 2).await.unwrap();
        assert_eq!(hits[0].id, "close");
    }
}