ceylon_next/memory/vector/
local_store.rs

1//! In-memory vector storage with cosine similarity search.
2
3use super::{SearchResult, VectorEntry, VectorStore};
4use crate::memory::vector::utils::cosine_similarity;
5use async_trait::async_trait;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// In-memory vector store using brute-force cosine similarity search.
10///
11/// This store keeps all vectors in memory and performs exact nearest neighbor
12/// search using cosine similarity. While not as efficient as approximate methods
13/// like HNSW for large datasets, it's simple, fast for small to medium datasets,
14/// and provides exact results.
15///
16/// # Features
17///
18/// - Fast in-memory storage with RwLock for thread-safety
19/// - Exact nearest neighbor search using cosine similarity
20/// - Optional similarity threshold filtering
21/// - Agent-specific filtering
22/// - No external dependencies
23///
24/// # Example
25///
26/// ```rust
27/// use ceylon_next::memory::vector::{LocalVectorStore, VectorEntry, VectorStore};
28/// use std::sync::Arc;
29///
30/// #[tokio::main]
31/// async fn main() {
32///     let store = LocalVectorStore::new(384);
33///
34///     let entry = VectorEntry::new(
35///         "memory-1".to_string(),
36///         "agent-1".to_string(),
37///         "Hello world".to_string(),
38///         vec![0.1; 384],
39///         None,
40///     );
41///
42///     let id = store.store(entry).await.unwrap();
43///     println!("Stored vector with ID: {}", id);
44///
45///     // Search for similar vectors
46///     let query = vec![0.1; 384];
47///     let results = store.search(&query, Some("agent-1"), 5, None).await.unwrap();
48///     println!("Found {} similar vectors", results.len());
49/// }
50/// ```
51pub struct LocalVectorStore {
52    /// The dimensionality of vectors in this store
53    dimension: usize,
54    /// The stored vectors
55    vectors: Arc<RwLock<Vec<VectorEntry>>>,
56}
57
58impl LocalVectorStore {
59    /// Creates a new local vector store.
60    ///
61    /// # Arguments
62    ///
63    /// * `dimension` - The dimensionality of vectors to store
64    ///
65    /// # Example
66    ///
67    /// ```rust
68    /// use ceylon_next::memory::vector::LocalVectorStore;
69    ///
70    /// let store = LocalVectorStore::new(384);
71    /// ```
72    pub fn new(dimension: usize) -> Self {
73        Self {
74            dimension,
75            vectors: Arc::new(RwLock::new(Vec::new())),
76        }
77    }
78
79    /// Creates a new local vector store with pre-allocated capacity.
80    ///
81    /// # Arguments
82    ///
83    /// * `dimension` - The dimensionality of vectors to store
84    /// * `capacity` - Initial capacity for the vector storage
85    pub fn with_capacity(dimension: usize, capacity: usize) -> Self {
86        Self {
87            dimension,
88            vectors: Arc::new(RwLock::new(Vec::with_capacity(capacity))),
89        }
90    }
91}
92
93#[async_trait]
94impl VectorStore for LocalVectorStore {
95    async fn store(&self, entry: VectorEntry) -> Result<String, String> {
96        // Validate vector dimension
97        if entry.vector.len() != self.dimension {
98            return Err(format!(
99                "Vector dimension mismatch: expected {}, got {}",
100                self.dimension,
101                entry.vector.len()
102            ));
103        }
104
105        let id = entry.id.clone();
106        let mut vectors = self.vectors.write().await;
107        vectors.push(entry);
108
109        Ok(id)
110    }
111
112    async fn store_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<String>, String> {
113        // Validate all vectors first
114        for entry in &entries {
115            if entry.vector.len() != self.dimension {
116                return Err(format!(
117                    "Vector dimension mismatch: expected {}, got {}",
118                    self.dimension,
119                    entry.vector.len()
120                ));
121            }
122        }
123
124        let ids: Vec<String> = entries.iter().map(|e| e.id.clone()).collect();
125        let mut vectors = self.vectors.write().await;
126        vectors.extend(entries);
127
128        Ok(ids)
129    }
130
131    async fn get(&self, id: &str) -> Result<Option<VectorEntry>, String> {
132        let vectors = self.vectors.read().await;
133        Ok(vectors.iter().find(|v| v.id == id).cloned())
134    }
135
136    async fn search(
137        &self,
138        query_vector: &[f32],
139        agent_id: Option<&str>,
140        limit: usize,
141        threshold: Option<f32>,
142    ) -> Result<Vec<SearchResult>, String> {
143        // Validate query vector dimension
144        if query_vector.len() != self.dimension {
145            return Err(format!(
146                "Query vector dimension mismatch: expected {}, got {}",
147                self.dimension,
148                query_vector.len()
149            ));
150        }
151
152        let vectors = self.vectors.read().await;
153
154        // Filter by agent if specified
155        let filtered: Vec<&VectorEntry> = if let Some(aid) = agent_id {
156            vectors.iter().filter(|v| v.agent_id == aid).collect()
157        } else {
158            vectors.iter().collect()
159        };
160
161        // Compute similarities for all vectors
162        let mut results: Vec<SearchResult> = Vec::new();
163
164        for entry in filtered {
165            match cosine_similarity(query_vector, &entry.vector) {
166                Ok(score) => {
167                    // Apply threshold if specified
168                    if let Some(min_score) = threshold {
169                        if score < min_score {
170                            continue;
171                        }
172                    }
173
174                    results.push(SearchResult {
175                        entry: entry.clone(),
176                        score,
177                    });
178                }
179                Err(e) => {
180                    log::warn!("Failed to compute similarity for vector {}: {}", entry.id, e);
181                    continue;
182                }
183            }
184        }
185
186        // Sort by score descending (highest similarity first)
187        results.sort_by(|a, b| {
188            b.score
189                .partial_cmp(&a.score)
190                .unwrap_or(std::cmp::Ordering::Equal)
191        });
192
193        // Limit results
194        results.truncate(limit);
195
196        Ok(results)
197    }
198
199    async fn clear_agent_vectors(&self, agent_id: &str) -> Result<(), String> {
200        let mut vectors = self.vectors.write().await;
201        vectors.retain(|v| v.agent_id != agent_id);
202        Ok(())
203    }
204
205    async fn count(&self) -> Result<usize, String> {
206        let vectors = self.vectors.read().await;
207        Ok(vectors.len())
208    }
209
210    fn dimension(&self) -> usize {
211        self.dimension
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a test entry from string slices, passing `None` for the final
    /// optional argument of `VectorEntry::new`.
    fn entry(memory: &str, agent: &str, text: &str, vector: Vec<f32>) -> VectorEntry {
        VectorEntry::new(
            memory.to_string(),
            agent.to_string(),
            text.to_string(),
            vector,
            None,
        )
    }

    #[tokio::test]
    async fn test_store_and_get() {
        let store = LocalVectorStore::new(3);
        let stored = entry("mem-1", "agent-1", "test", vec![1.0, 2.0, 3.0]);

        let id = stored.id.clone();
        store.store(stored).await.unwrap();

        let retrieved = store.get(&id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().text, "test");
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let store = LocalVectorStore::new(3);

        let retrieved = store.get("does-not-exist").await.unwrap();
        assert!(retrieved.is_none());
    }

    #[tokio::test]
    async fn test_dimension_validation() {
        let store = LocalVectorStore::new(3);
        // Two components where the store expects three.
        let wrong = entry("mem-1", "agent-1", "test", vec![1.0, 2.0]);

        let result = store.store(wrong).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_store_batch_rejects_mismatched_dimension() {
        let store = LocalVectorStore::new(2);

        let entries = vec![
            entry("mem-1", "agent-1", "ok", vec![1.0, 0.0]),
            entry("mem-2", "agent-1", "bad", vec![1.0, 0.0, 0.0]),
        ];

        let result = store.store_batch(entries).await;
        assert!(result.is_err());
        // The valid entry must not have been stored either.
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_search() {
        let store = LocalVectorStore::new(3);

        let entries = vec![
            entry("mem-1", "agent-1", "cat", vec![1.0, 0.0, 0.0]),
            entry("mem-2", "agent-1", "dog", vec![0.9, 0.1, 0.0]),
            entry("mem-3", "agent-1", "car", vec![0.0, 1.0, 0.0]),
        ];

        for e in entries {
            store.store(e).await.unwrap();
        }

        // Query with the exact "cat" vector.
        let query = vec![1.0, 0.0, 0.0];
        let results = store.search(&query, Some("agent-1"), 2, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.text, "cat"); // Most similar
        assert_eq!(results[1].entry.text, "dog"); // Second most similar
        assert!(results[0].score > results[1].score);
    }

    #[tokio::test]
    async fn test_search_rejects_wrong_query_dimension() {
        let store = LocalVectorStore::new(3);

        let query = vec![1.0, 0.0];
        let result = store.search(&query, None, 5, None).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_search_with_threshold() {
        let store = LocalVectorStore::new(2);

        store
            .store(entry("mem-1", "agent-1", "similar", vec![1.0, 0.0]))
            .await
            .unwrap();

        store
            .store(entry("mem-2", "agent-1", "different", vec![0.0, 1.0]))
            .await
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store
            .search(&query, Some("agent-1"), 10, Some(0.5))
            .await
            .unwrap();

        // Should only return the similar vector
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.text, "similar");
    }

    #[tokio::test]
    async fn test_agent_filtering() {
        let store = LocalVectorStore::new(2);

        store
            .store(entry("mem-1", "agent-1", "agent1", vec![1.0, 0.0]))
            .await
            .unwrap();

        store
            .store(entry("mem-2", "agent-2", "agent2", vec![1.0, 0.0]))
            .await
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.search(&query, Some("agent-1"), 10, None).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.agent_id, "agent-1");
    }

    #[tokio::test]
    async fn test_clear_agent_vectors() {
        let store = LocalVectorStore::new(2);

        store
            .store(entry("mem-1", "agent-1", "test", vec![1.0, 0.0]))
            .await
            .unwrap();

        store
            .store(entry("mem-2", "agent-2", "test", vec![1.0, 0.0]))
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 2);

        store.clear_agent_vectors("agent-1").await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        let query = vec![1.0, 0.0];
        let results = store.search(&query, None, 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.agent_id, "agent-2");
    }

    #[tokio::test]
    async fn test_store_batch() {
        let store = LocalVectorStore::new(2);

        let entries = vec![
            entry("mem-1", "agent-1", "test1", vec![1.0, 0.0]),
            entry("mem-2", "agent-1", "test2", vec![0.0, 1.0]),
        ];

        let ids = store.store_batch(entries).await.unwrap();
        assert_eq!(ids.len(), 2);
        assert_eq!(store.count().await.unwrap(), 2);
    }
}