ceylon_next/memory/vector/
local_store.rs1use super::{SearchResult, VectorEntry, VectorStore};
4use crate::memory::vector::utils::cosine_similarity;
5use async_trait::async_trait;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9pub struct LocalVectorStore {
52 dimension: usize,
54 vectors: Arc<RwLock<Vec<VectorEntry>>>,
56}
57
58impl LocalVectorStore {
59 pub fn new(dimension: usize) -> Self {
73 Self {
74 dimension,
75 vectors: Arc::new(RwLock::new(Vec::new())),
76 }
77 }
78
79 pub fn with_capacity(dimension: usize, capacity: usize) -> Self {
86 Self {
87 dimension,
88 vectors: Arc::new(RwLock::new(Vec::with_capacity(capacity))),
89 }
90 }
91}
92
93#[async_trait]
94impl VectorStore for LocalVectorStore {
95 async fn store(&self, entry: VectorEntry) -> Result<String, String> {
96 if entry.vector.len() != self.dimension {
98 return Err(format!(
99 "Vector dimension mismatch: expected {}, got {}",
100 self.dimension,
101 entry.vector.len()
102 ));
103 }
104
105 let id = entry.id.clone();
106 let mut vectors = self.vectors.write().await;
107 vectors.push(entry);
108
109 Ok(id)
110 }
111
112 async fn store_batch(&self, entries: Vec<VectorEntry>) -> Result<Vec<String>, String> {
113 for entry in &entries {
115 if entry.vector.len() != self.dimension {
116 return Err(format!(
117 "Vector dimension mismatch: expected {}, got {}",
118 self.dimension,
119 entry.vector.len()
120 ));
121 }
122 }
123
124 let ids: Vec<String> = entries.iter().map(|e| e.id.clone()).collect();
125 let mut vectors = self.vectors.write().await;
126 vectors.extend(entries);
127
128 Ok(ids)
129 }
130
131 async fn get(&self, id: &str) -> Result<Option<VectorEntry>, String> {
132 let vectors = self.vectors.read().await;
133 Ok(vectors.iter().find(|v| v.id == id).cloned())
134 }
135
136 async fn search(
137 &self,
138 query_vector: &[f32],
139 agent_id: Option<&str>,
140 limit: usize,
141 threshold: Option<f32>,
142 ) -> Result<Vec<SearchResult>, String> {
143 if query_vector.len() != self.dimension {
145 return Err(format!(
146 "Query vector dimension mismatch: expected {}, got {}",
147 self.dimension,
148 query_vector.len()
149 ));
150 }
151
152 let vectors = self.vectors.read().await;
153
154 let filtered: Vec<&VectorEntry> = if let Some(aid) = agent_id {
156 vectors.iter().filter(|v| v.agent_id == aid).collect()
157 } else {
158 vectors.iter().collect()
159 };
160
161 let mut results: Vec<SearchResult> = Vec::new();
163
164 for entry in filtered {
165 match cosine_similarity(query_vector, &entry.vector) {
166 Ok(score) => {
167 if let Some(min_score) = threshold {
169 if score < min_score {
170 continue;
171 }
172 }
173
174 results.push(SearchResult {
175 entry: entry.clone(),
176 score,
177 });
178 }
179 Err(e) => {
180 log::warn!("Failed to compute similarity for vector {}: {}", entry.id, e);
181 continue;
182 }
183 }
184 }
185
186 results.sort_by(|a, b| {
188 b.score
189 .partial_cmp(&a.score)
190 .unwrap_or(std::cmp::Ordering::Equal)
191 });
192
193 results.truncate(limit);
195
196 Ok(results)
197 }
198
199 async fn clear_agent_vectors(&self, agent_id: &str) -> Result<(), String> {
200 let mut vectors = self.vectors.write().await;
201 vectors.retain(|v| v.agent_id != agent_id);
202 Ok(())
203 }
204
205 async fn count(&self) -> Result<usize, String> {
206 let vectors = self.vectors.read().await;
207 Ok(vectors.len())
208 }
209
210 fn dimension(&self) -> usize {
211 self.dimension
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with no extra metadata; cuts the constructor boilerplate in each test.
    fn make_entry(memory_id: &str, agent_id: &str, text: &str, vector: Vec<f32>) -> VectorEntry {
        VectorEntry::new(
            memory_id.to_string(),
            agent_id.to_string(),
            text.to_string(),
            vector,
            None,
        )
    }

    #[tokio::test]
    async fn test_store_and_get() {
        let store = LocalVectorStore::new(3);
        let entry = make_entry("mem-1", "agent-1", "test", vec![1.0, 2.0, 3.0]);

        let id = entry.id.clone();
        store.store(entry).await.unwrap();

        let retrieved = store.get(&id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().text, "test");
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let store = LocalVectorStore::new(3);
        assert!(store.get("does-not-exist").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_dimension_validation() {
        let store = LocalVectorStore::new(3);
        let entry = make_entry("mem-1", "agent-1", "test", vec![1.0, 2.0]);

        let result = store.store(entry).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_search() {
        let store = LocalVectorStore::new(3);

        let entries = vec![
            make_entry("mem-1", "agent-1", "cat", vec![1.0, 0.0, 0.0]),
            make_entry("mem-2", "agent-1", "dog", vec![0.9, 0.1, 0.0]),
            make_entry("mem-3", "agent-1", "car", vec![0.0, 1.0, 0.0]),
        ];

        for entry in entries {
            store.store(entry).await.unwrap();
        }

        let query = vec![1.0, 0.0, 0.0];
        let results = store.search(&query, Some("agent-1"), 2, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].entry.text, "cat");
        assert_eq!(results[1].entry.text, "dog");
        assert!(results[0].score > results[1].score);
    }

    #[tokio::test]
    async fn test_search_query_dimension_mismatch() {
        let store = LocalVectorStore::new(3);
        store
            .store(make_entry("mem-1", "agent-1", "x", vec![1.0, 0.0, 0.0]))
            .await
            .unwrap();

        let result = store.search(&[1.0, 0.0], None, 10, None).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_search_limit_zero_returns_nothing() {
        let store = LocalVectorStore::new(2);
        store
            .store(make_entry("mem-1", "agent-1", "x", vec![1.0, 0.0]))
            .await
            .unwrap();

        let results = store.search(&[1.0, 0.0], None, 0, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_search_with_threshold() {
        let store = LocalVectorStore::new(2);

        store
            .store(make_entry("mem-1", "agent-1", "similar", vec![1.0, 0.0]))
            .await
            .unwrap();
        store
            .store(make_entry("mem-2", "agent-1", "different", vec![0.0, 1.0]))
            .await
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store
            .search(&query, Some("agent-1"), 10, Some(0.5))
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.text, "similar");
    }

    #[tokio::test]
    async fn test_agent_filtering() {
        let store = LocalVectorStore::new(2);

        store
            .store(make_entry("mem-1", "agent-1", "agent1", vec![1.0, 0.0]))
            .await
            .unwrap();
        store
            .store(make_entry("mem-2", "agent-2", "agent2", vec![1.0, 0.0]))
            .await
            .unwrap();

        let query = vec![1.0, 0.0];
        let results = store.search(&query, Some("agent-1"), 10, None).await.unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.agent_id, "agent-1");
    }

    #[tokio::test]
    async fn test_clear_agent_vectors() {
        let store = LocalVectorStore::new(2);

        store
            .store(make_entry("mem-1", "agent-1", "test", vec![1.0, 0.0]))
            .await
            .unwrap();
        store
            .store(make_entry("mem-2", "agent-2", "test", vec![1.0, 0.0]))
            .await
            .unwrap();

        assert_eq!(store.count().await.unwrap(), 2);

        store.clear_agent_vectors("agent-1").await.unwrap();

        assert_eq!(store.count().await.unwrap(), 1);

        let query = vec![1.0, 0.0];
        let results = store.search(&query, None, 10, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entry.agent_id, "agent-2");
    }

    #[tokio::test]
    async fn test_store_batch() {
        let store = LocalVectorStore::new(2);

        let entries = vec![
            make_entry("mem-1", "agent-1", "test1", vec![1.0, 0.0]),
            make_entry("mem-2", "agent-1", "test2", vec![0.0, 1.0]),
        ];

        let ids = store.store_batch(entries).await.unwrap();
        assert_eq!(ids.len(), 2);
        assert_eq!(store.count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_store_batch_rejects_whole_batch_on_dimension_mismatch() {
        let store = LocalVectorStore::new(2);

        let entries = vec![
            make_entry("mem-1", "agent-1", "ok", vec![1.0, 0.0]),
            make_entry("mem-2", "agent-1", "bad", vec![1.0, 0.0, 0.0]),
        ];

        assert!(store.store_batch(entries).await.is_err());
        // The valid entry must not have been inserted either.
        assert_eq!(store.count().await.unwrap(), 0);
    }
}