Skip to main content

phago_vectors/
memory.rs

1//! In-memory vector store implementation.
2//!
3//! This module provides a simple in-memory vector store that uses brute-force
4//! search. It's useful for testing and small-scale applications.
5
6use crate::{
7    DistanceMetric, SearchResult, VectorError, VectorRecord, VectorResult, VectorStore,
8};
9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::sync::RwLock;
12
13/// In-memory vector store using brute-force search.
14///
15/// This is the simplest implementation, suitable for:
16/// - Testing and development
17/// - Small datasets (< 10,000 vectors)
18/// - Prototyping before moving to a production database
19///
20/// # Example
21///
22/// ```rust
23/// use phago_vectors::{InMemoryStore, VectorStore, VectorRecord};
24///
25/// #[tokio::main]
26/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
27///     let store = InMemoryStore::new(3);
28///
29///     // Insert records
30///     store.upsert(vec![
31///         VectorRecord::new("a", vec![1.0, 0.0, 0.0]),
32///         VectorRecord::new("b", vec![0.0, 1.0, 0.0]),
33///         VectorRecord::new("c", vec![0.7, 0.7, 0.0]),
34///     ]).await?;
35///
36///     // Search
37///     let results = store.search(&[1.0, 0.0, 0.0], 2).await?;
38///     assert_eq!(results[0].id, "a");
39///
40///     Ok(())
41/// }
42/// ```
43pub struct InMemoryStore {
44    records: RwLock<HashMap<String, VectorRecord>>,
45    dimension: usize,
46    metric: DistanceMetric,
47}
48
49impl InMemoryStore {
50    /// Create a new in-memory store with the specified dimension.
51    ///
52    /// Uses cosine similarity by default.
53    pub fn new(dimension: usize) -> Self {
54        Self {
55            records: RwLock::new(HashMap::new()),
56            dimension,
57            metric: DistanceMetric::Cosine,
58        }
59    }
60
61    /// Create a new in-memory store with a specific distance metric.
62    pub fn with_config(dimension: usize, metric: DistanceMetric) -> Self {
63        Self {
64            records: RwLock::new(HashMap::new()),
65            dimension,
66            metric,
67        }
68    }
69
70    /// Compute similarity/distance between two vectors.
71    fn compute_score(&self, a: &[f32], b: &[f32]) -> f32 {
72        match self.metric {
73            DistanceMetric::Cosine => crate::util::cosine_similarity(a, b),
74            DistanceMetric::Euclidean => {
75                // Convert distance to similarity (higher is better)
76                let dist = crate::util::euclidean_distance(a, b);
77                1.0 / (1.0 + dist)
78            }
79            DistanceMetric::DotProduct => crate::util::dot_product(a, b),
80        }
81    }
82}
83
84#[async_trait]
85impl VectorStore for InMemoryStore {
86    fn name(&self) -> &str {
87        "in-memory"
88    }
89
90    fn dimension(&self) -> usize {
91        self.dimension
92    }
93
94    fn metric(&self) -> DistanceMetric {
95        self.metric
96    }
97
98    async fn upsert(&self, records: Vec<VectorRecord>) -> VectorResult<()> {
99        let mut store = self.records.write().map_err(|e| {
100            VectorError::Connection(format!("Failed to acquire write lock: {}", e))
101        })?;
102
103        for record in records {
104            if record.vector.len() != self.dimension {
105                return Err(VectorError::DimensionMismatch {
106                    expected: self.dimension,
107                    actual: record.vector.len(),
108                });
109            }
110            store.insert(record.id.clone(), record);
111        }
112
113        Ok(())
114    }
115
116    async fn search(&self, vector: &[f32], k: usize) -> VectorResult<Vec<SearchResult>> {
117        if vector.len() != self.dimension {
118            return Err(VectorError::DimensionMismatch {
119                expected: self.dimension,
120                actual: vector.len(),
121            });
122        }
123
124        let store = self.records.read().map_err(|e| {
125            VectorError::Connection(format!("Failed to acquire read lock: {}", e))
126        })?;
127
128        // Compute scores for all records
129        let mut scored: Vec<_> = store
130            .values()
131            .map(|record| {
132                let score = self.compute_score(vector, &record.vector);
133                (record, score)
134            })
135            .collect();
136
137        // Sort by score (descending)
138        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
139
140        // Take top k
141        let results = scored
142            .into_iter()
143            .take(k)
144            .map(|(record, score)| SearchResult {
145                id: record.id.clone(),
146                score,
147                vector: Some(record.vector.clone()),
148                metadata: record.metadata.clone(),
149            })
150            .collect();
151
152        Ok(results)
153    }
154
155    async fn search_with_filter(
156        &self,
157        vector: &[f32],
158        k: usize,
159        filter: &HashMap<String, serde_json::Value>,
160    ) -> VectorResult<Vec<SearchResult>> {
161        if vector.len() != self.dimension {
162            return Err(VectorError::DimensionMismatch {
163                expected: self.dimension,
164                actual: vector.len(),
165            });
166        }
167
168        let store = self.records.read().map_err(|e| {
169            VectorError::Connection(format!("Failed to acquire read lock: {}", e))
170        })?;
171
172        // Filter and compute scores
173        let mut scored: Vec<_> = store
174            .values()
175            .filter(|record| {
176                // Check if all filter conditions match
177                filter.iter().all(|(key, value)| {
178                    record.metadata.get(key).map_or(false, |v| v == value)
179                })
180            })
181            .map(|record| {
182                let score = self.compute_score(vector, &record.vector);
183                (record, score)
184            })
185            .collect();
186
187        // Sort by score (descending)
188        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
189
190        // Take top k
191        let results = scored
192            .into_iter()
193            .take(k)
194            .map(|(record, score)| SearchResult {
195                id: record.id.clone(),
196                score,
197                vector: Some(record.vector.clone()),
198                metadata: record.metadata.clone(),
199            })
200            .collect();
201
202        Ok(results)
203    }
204
205    async fn get(&self, id: &str) -> VectorResult<Option<VectorRecord>> {
206        let store = self.records.read().map_err(|e| {
207            VectorError::Connection(format!("Failed to acquire read lock: {}", e))
208        })?;
209
210        Ok(store.get(id).cloned())
211    }
212
213    async fn get_batch(&self, ids: &[&str]) -> VectorResult<Vec<VectorRecord>> {
214        let store = self.records.read().map_err(|e| {
215            VectorError::Connection(format!("Failed to acquire read lock: {}", e))
216        })?;
217
218        let records: Vec<_> = ids
219            .iter()
220            .filter_map(|id| store.get(*id).cloned())
221            .collect();
222
223        Ok(records)
224    }
225
226    async fn delete(&self, id: &str) -> VectorResult<()> {
227        let mut store = self.records.write().map_err(|e| {
228            VectorError::Connection(format!("Failed to acquire write lock: {}", e))
229        })?;
230
231        store.remove(id);
232        Ok(())
233    }
234
235    async fn delete_batch(&self, ids: &[&str]) -> VectorResult<()> {
236        let mut store = self.records.write().map_err(|e| {
237            VectorError::Connection(format!("Failed to acquire write lock: {}", e))
238        })?;
239
240        for id in ids {
241            store.remove(*id);
242        }
243
244        Ok(())
245    }
246
247    async fn count(&self) -> VectorResult<usize> {
248        let store = self.records.read().map_err(|e| {
249            VectorError::Connection(format!("Failed to acquire read lock: {}", e))
250        })?;
251
252        Ok(store.len())
253    }
254
255    async fn clear(&self) -> VectorResult<()> {
256        let mut store = self.records.write().map_err(|e| {
257            VectorError::Connection(format!("Failed to acquire write lock: {}", e))
258        })?;
259
260        store.clear();
261        Ok(())
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_upsert_and_search() {
        let store = InMemoryStore::new(3);

        // Insert records
        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0, 0.0]),
            VectorRecord::new("c", vec![0.7, 0.7, 0.0]),
        ]).await.unwrap();

        // Search for vector close to 'a'
        let results = store.search(&[1.0, 0.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert!((results[0].score - 1.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_getters() {
        let store = InMemoryStore::new(3);
        assert_eq!(store.name(), "in-memory");
        assert_eq!(store.dimension(), 3);
        assert!(matches!(store.metric(), DistanceMetric::Cosine));

        let euclid = InMemoryStore::with_config(4, DistanceMetric::Euclidean);
        assert_eq!(euclid.dimension(), 4);
        assert!(matches!(euclid.metric(), DistanceMetric::Euclidean));
    }

    #[tokio::test]
    async fn test_search_k_bounds_and_order() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0]),
            VectorRecord::new("c", vec![0.7, 0.7]),
        ]).await.unwrap();

        // k = 0 yields nothing.
        assert!(store.search(&[1.0, 0.0], 0).await.unwrap().is_empty());

        // k larger than the store returns everything, best first.
        let all = store.search(&[1.0, 0.0], 10).await.unwrap();
        assert_eq!(all.len(), 3);
        assert!(all.windows(2).all(|w| w[0].score >= w[1].score));
        let ids: Vec<&str> = all.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["a", "c", "b"]);
    }

    #[tokio::test]
    async fn test_get_and_delete() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("x", vec![1.0, 0.0]),
        ]).await.unwrap();

        // Get
        let record = store.get("x").await.unwrap();
        assert!(record.is_some());
        assert_eq!(record.unwrap().id, "x");

        // Delete
        store.delete("x").await.unwrap();

        // Verify deleted
        let record = store.get("x").await.unwrap();
        assert!(record.is_none());
    }

    #[tokio::test]
    async fn test_upsert_replaces_existing_id() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![VectorRecord::new("a", vec![1.0, 0.0])]).await.unwrap();
        store.upsert(vec![VectorRecord::new("a", vec![0.0, 1.0])]).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        let record = store.get("a").await.unwrap().unwrap();
        assert_eq!(record.vector, vec![0.0, 1.0]);
    }

    #[tokio::test]
    async fn test_get_batch_skips_missing_ids() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0]),
        ]).await.unwrap();

        let records = store.get_batch(&["b", "missing", "a"]).await.unwrap();
        let ids: Vec<&str> = records.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["b", "a"]);
    }

    #[tokio::test]
    async fn test_delete_batch() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0]),
        ]).await.unwrap();

        // Unknown ids are ignored.
        store.delete_batch(&["a", "missing"]).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);
        assert!(store.get("a").await.unwrap().is_none());
        assert!(store.get("b").await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_search_with_filter() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]).with_metadata("type", "doc"),
            VectorRecord::new("b", vec![1.0, 0.0]).with_metadata("type", "query"),
            VectorRecord::new("c", vec![1.0, 0.0]).with_metadata("type", "doc"),
        ]).await.unwrap();

        let mut filter = HashMap::new();
        filter.insert("type".to_string(), serde_json::json!("doc"));

        let results = store.search_with_filter(&[1.0, 0.0], 10, &filter).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.id != "b"));

        // An empty filter matches every record.
        let unfiltered = store
            .search_with_filter(&[1.0, 0.0], 10, &HashMap::new())
            .await
            .unwrap();
        assert_eq!(unfiltered.len(), 3);
    }

    #[tokio::test]
    async fn test_count_and_clear() {
        let store = InMemoryStore::new(2);

        store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]),
            VectorRecord::new("b", vec![0.0, 1.0]),
        ]).await.unwrap();

        assert_eq!(store.count().await.unwrap(), 2);

        store.clear().await.unwrap();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let store = InMemoryStore::new(3);

        let result = store.upsert(vec![
            VectorRecord::new("a", vec![1.0, 0.0]), // Wrong dimension
        ]).await;

        assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
    }

    #[tokio::test]
    async fn test_search_dimension_mismatch() {
        let store = InMemoryStore::new(3);

        let result = store.search(&[1.0, 0.0], 1).await;
        assert!(matches!(
            result,
            Err(VectorError::DimensionMismatch { expected: 3, actual: 2 })
        ));

        let filtered = store.search_with_filter(&[1.0, 0.0], 1, &HashMap::new()).await;
        assert!(matches!(
            filtered,
            Err(VectorError::DimensionMismatch { expected: 3, actual: 2 })
        ));
    }

    #[tokio::test]
    async fn test_euclidean_metric() {
        let store = InMemoryStore::with_config(2, DistanceMetric::Euclidean);

        store.upsert(vec![
            VectorRecord::new("close", vec![0.1, 0.0]),
            VectorRecord::new("far", vec![10.0, 0.0]),
        ]).await.unwrap();

        let results = store.search(&[0.0, 0.0], 2).await.unwrap();
        assert_eq!(results[0].id, "close"); // Closer vector should rank first
    }

    #[tokio::test]
    async fn test_dot_product_metric() {
        let store = InMemoryStore::with_config(2, DistanceMetric::DotProduct);

        store.upsert(vec![
            VectorRecord::new("small", vec![1.0, 0.0]),
            VectorRecord::new("large", vec![3.0, 0.0]),
        ]).await.unwrap();

        // Dot product rewards magnitude, unlike cosine similarity.
        let results = store.search(&[1.0, 0.0], 2).await.unwrap();
        assert_eq!(results[0].id, "large");
        assert!((results[0].score - 3.0).abs() < 1e-6);
    }
}