1use async_trait::async_trait;
13use fabryk_core::{Error, Result};
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::EmbeddingProvider;
17use crate::types::{
18 EmbeddedDocument, VectorConfig, VectorSearchParams, VectorSearchResult, VectorSearchResults,
19};
20use std::path::Path;
21use std::sync::Arc;
22
23#[async_trait]
34pub trait VectorBackend: Send + Sync {
35 async fn search(&self, params: VectorSearchParams) -> Result<VectorSearchResults>;
41
42 fn name(&self) -> &str;
44
45 fn is_ready(&self) -> bool {
47 true
48 }
49
50 fn document_count(&self) -> Result<usize>;
52}
53
54pub fn create_vector_backend(
62 config: &VectorConfig,
63 provider: Arc<dyn EmbeddingProvider>,
64) -> Result<Box<dyn VectorBackend>> {
65 if !config.enabled {
66 return Ok(Box::new(SimpleVectorBackend::new(provider)));
67 }
68
69 match config.backend.as_str() {
70 #[cfg(feature = "vector-lancedb")]
71 "lancedb" => {
72 log::info!("LanceDB requested but requires async build(); returning simple backend");
75 Ok(Box::new(SimpleVectorBackend::new(provider)))
76 }
77 _ => Ok(Box::new(SimpleVectorBackend::new(provider))),
78 }
79}
80
81#[derive(Serialize, Deserialize)]
87struct VectorCache {
88 content_hash: String,
89 documents: Vec<EmbeddedDocument>,
90}
91
92#[derive(Deserialize)]
94struct VectorCacheHeader {
95 content_hash: String,
96}
97
98pub struct SimpleVectorBackend {
114 provider: Arc<dyn EmbeddingProvider>,
115 documents: Vec<EmbeddedDocument>,
116}
117
118impl SimpleVectorBackend {
119 pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
121 Self {
122 provider,
123 documents: Vec::new(),
124 }
125 }
126
127 pub fn add_documents(&mut self, documents: Vec<EmbeddedDocument>) {
129 self.documents.extend(documents);
130 }
131
132 pub fn save_cache(&self, path: &Path, content_hash: &str) -> Result<()> {
137 let cache = VectorCache {
138 content_hash: content_hash.to_string(),
139 documents: self.documents.clone(),
140 };
141
142 if let Some(parent) = path.parent() {
144 if !parent.exists() {
145 std::fs::create_dir_all(parent).map_err(|e| Error::io_with_path(e, parent))?;
146 }
147 }
148
149 let json = serde_json::to_string(&cache)
150 .map_err(|e| Error::operation(format!("Failed to serialize vector cache: {e}")))?;
151
152 std::fs::write(path, json).map_err(|e| Error::io_with_path(e, path))?;
153
154 log::info!(
155 "Saved vector cache: {} documents to {}",
156 self.documents.len(),
157 path.display()
158 );
159
160 Ok(())
161 }
162
163 pub fn load_cache(path: &Path, provider: Arc<dyn EmbeddingProvider>) -> Result<Option<Self>> {
168 if !path.exists() {
169 return Ok(None);
170 }
171
172 let json = std::fs::read_to_string(path).map_err(|e| Error::io_with_path(e, path))?;
173
174 let cache: VectorCache = serde_json::from_str(&json)
175 .map_err(|e| Error::parse(format!("Failed to parse vector cache: {e}")))?;
176
177 let mut backend = Self::new(provider);
178 backend.documents = cache.documents;
179
180 log::info!(
181 "Loaded vector cache: {} documents from {}",
182 backend.documents.len(),
183 path.display()
184 );
185
186 Ok(Some(backend))
187 }
188
189 pub fn is_cache_fresh(path: &Path, content_hash: &str) -> bool {
191 if !path.exists() {
192 return false;
193 }
194
195 if let Ok(json) = std::fs::read_to_string(path) {
197 if let Ok(cache) = serde_json::from_str::<VectorCacheHeader>(&json) {
198 return cache.content_hash == content_hash;
199 }
200 }
201
202 false
203 }
204
205 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
207 if a.len() != b.len() || a.is_empty() {
208 return 0.0;
209 }
210
211 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
212 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
213 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
214
215 if norm_a == 0.0 || norm_b == 0.0 {
216 return 0.0;
217 }
218
219 dot / (norm_a * norm_b)
220 }
221}
222
223#[async_trait]
224impl VectorBackend for SimpleVectorBackend {
225 async fn search(&self, params: VectorSearchParams) -> Result<VectorSearchResults> {
226 if self.documents.is_empty() {
227 return Ok(VectorSearchResults::empty(self.name()));
228 }
229
230 let query_embedding = self.provider.embed(¶ms.query).await?;
231 let limit = params.limit.unwrap_or(10);
232 let threshold = params.similarity_threshold.unwrap_or(0.0);
233
234 let mut scored: Vec<(usize, f32)> = self
235 .documents
236 .iter()
237 .enumerate()
238 .map(|(i, doc)| {
239 let sim = Self::cosine_similarity(&query_embedding, &doc.embedding);
240 (i, sim)
241 })
242 .filter(|(_, sim)| *sim >= threshold)
243 .collect();
244
245 if let Some(ref category) = params.category {
247 scored.retain(|(i, _)| {
248 self.documents[*i].document.category.as_deref() == Some(category.as_str())
249 });
250 }
251
252 for (key, value) in ¶ms.metadata_filters {
254 scored.retain(|(i, _)| {
255 self.documents[*i]
256 .document
257 .metadata
258 .get(key)
259 .map(|v| v == value)
260 .unwrap_or(false)
261 });
262 }
263
264 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
266 scored.truncate(limit);
267
268 let total = scored.len();
269 let items: Vec<VectorSearchResult> = scored
270 .into_iter()
271 .map(|(i, score)| {
272 let doc = &self.documents[i];
273 let distance = 1.0 - score; VectorSearchResult {
275 id: doc.document.id.clone(),
276 score,
277 distance,
278 metadata: doc.document.metadata.clone(),
279 }
280 })
281 .collect();
282
283 Ok(VectorSearchResults {
284 items,
285 total,
286 backend: self.name().to_string(),
287 })
288 }
289
290 fn name(&self) -> &str {
291 "simple"
292 }
293
294 fn document_count(&self) -> Result<usize> {
295 Ok(self.documents.len())
296 }
297}
298
299impl std::fmt::Debug for SimpleVectorBackend {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 f.debug_struct("SimpleVectorBackend")
302 .field("documents", &self.documents.len())
303 .finish()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::MockEmbeddingProvider;
    use crate::types::VectorDocument;

    /// Mock provider producing 8-dimensional embeddings.
    fn mock_provider() -> Arc<dyn EmbeddingProvider> {
        Arc::new(MockEmbeddingProvider::new(8))
    }

    /// Backend with a 4-dimensional mock provider, pre-loaded with `docs`.
    fn backend_4d(docs: Vec<EmbeddedDocument>) -> SimpleVectorBackend {
        let mut backend = SimpleVectorBackend::new(Arc::new(MockEmbeddingProvider::new(4)));
        backend.add_documents(docs);
        backend
    }

    /// Two 8-dimensional documents with orthogonal unit embeddings.
    fn two_unit_docs() -> Vec<EmbeddedDocument> {
        vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "hello"),
                vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "world"),
                vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
        ]
    }

    #[test]
    fn test_simple_backend_creation() {
        let backend = SimpleVectorBackend::new(mock_provider());
        assert_eq!(backend.name(), "simple");
        assert!(backend.is_ready());
        assert_eq!(backend.document_count().unwrap(), 0);
    }

    #[test]
    fn test_simple_backend_add_documents() {
        let mut backend = SimpleVectorBackend::new(mock_provider());

        backend.add_documents(two_unit_docs());
        assert_eq!(backend.document_count().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_simple_backend_search_empty() {
        let backend = SimpleVectorBackend::new(mock_provider());

        let params = VectorSearchParams::new("test query");
        let results = backend.search(params).await.unwrap();

        assert!(results.items.is_empty());
        assert_eq!(results.total, 0);
        assert_eq!(results.backend, "simple");
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_results() {
        let backend = backend_4d(vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-close", "close match"),
                vec![0.9, 0.1, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-far", "far away"),
                vec![0.0, 0.0, 0.1, 0.9],
            ),
        ]);

        let params = VectorSearchParams::new("test").with_limit(10);
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 2);
        assert_eq!(results.total, results.items.len());
        // Results must come back ordered by descending similarity.
        assert!(results.items[0].score >= results.items[1].score);
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_threshold() {
        let backend = backend_4d(vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "text"),
                vec![1.0, 0.0, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "text"),
                vec![0.0, 0.0, 0.0, 1.0],
            ),
        ]);

        let params = VectorSearchParams::new("test").with_threshold(0.99);
        let results = backend.search(params).await.unwrap();

        // The mock query embedding is opaque to this test, so the number of
        // matches is not fixed, but nothing below the threshold may be returned.
        assert!(results.items.len() <= 2);
        assert!(results.items.iter().all(|item| item.score >= 0.99));
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_category() {
        let backend = backend_4d(vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "harmony text").with_category("harmony"),
                vec![0.5, 0.5, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "rhythm text").with_category("rhythm"),
                vec![0.5, 0.0, 0.5, 0.0],
            ),
        ]);

        let params = VectorSearchParams::new("test").with_category("harmony");
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 1);
        assert_eq!(results.items[0].id, "doc-1");
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_metadata_filter() {
        let backend = backend_4d(vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "text").with_metadata("tier", "beginner"),
                vec![0.5, 0.5, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "text").with_metadata("tier", "advanced"),
                vec![0.5, 0.0, 0.5, 0.0],
            ),
        ]);

        let params = VectorSearchParams::new("test").with_filter("tier", "beginner");
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 1);
        assert_eq!(results.items[0].id, "doc-1");
    }

    #[tokio::test]
    async fn test_simple_backend_search_limit() {
        let docs: Vec<EmbeddedDocument> = (0..20)
            .map(|i| {
                EmbeddedDocument::new(
                    VectorDocument::new(format!("doc-{i}"), format!("text {i}")),
                    vec![0.5, 0.5, 0.0, 0.0],
                )
            })
            .collect();
        let backend = backend_4d(docs);

        let params = VectorSearchParams::new("test").with_limit(5);
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 5);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let sim = SimpleVectorBackend::cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_create_vector_backend_simple() {
        let config = VectorConfig {
            backend: "simple".to_string(),
            ..Default::default()
        };

        let backend = create_vector_backend(&config, mock_provider()).unwrap();
        assert_eq!(backend.name(), "simple");
    }

    #[test]
    fn test_create_vector_backend_disabled() {
        let config = VectorConfig {
            enabled: false,
            ..Default::default()
        };

        let backend = create_vector_backend(&config, mock_provider()).unwrap();
        assert_eq!(backend.name(), "simple");
    }

    #[test]
    fn test_trait_object_safety() {
        // Accepting `&dyn VectorBackend` only compiles if the trait is object
        // safe; passing a real backend also exercises the coercion.
        fn assert_object_safe(backend: &dyn VectorBackend) -> &str {
            backend.name()
        }

        let backend = SimpleVectorBackend::new(mock_provider());
        assert_eq!(assert_object_safe(&backend), "simple");
    }

    #[test]
    fn test_simple_backend_debug() {
        let backend = SimpleVectorBackend::new(mock_provider());
        let debug = format!("{:?}", backend);
        assert!(debug.contains("SimpleVectorBackend"));
        assert!(debug.contains("documents"));
    }

    #[test]
    fn test_save_and_load_cache() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("test-cache.json");

        let mut backend = SimpleVectorBackend::new(mock_provider());
        backend.add_documents(two_unit_docs());

        backend.save_cache(&cache_path, "hash123").unwrap();
        assert!(cache_path.exists());

        let loaded = SimpleVectorBackend::load_cache(&cache_path, mock_provider())
            .unwrap()
            .unwrap();
        assert_eq!(loaded.document_count().unwrap(), 2);

        // The round trip must preserve document identity and order, not just
        // the count.
        assert_eq!(loaded.documents[0].document.id, "doc-1");
        assert_eq!(loaded.documents[1].document.id, "doc-2");
        assert_eq!(loaded.documents[0].embedding.len(), 8);
    }

    #[test]
    fn test_cache_freshness() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("test-cache.json");

        let mut backend = SimpleVectorBackend::new(mock_provider());
        backend.add_documents(vec![EmbeddedDocument::new(
            VectorDocument::new("doc-1", "hello"),
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        )]);

        backend.save_cache(&cache_path, "hash123").unwrap();

        assert!(SimpleVectorBackend::is_cache_fresh(&cache_path, "hash123"));
        assert!(!SimpleVectorBackend::is_cache_fresh(
            &cache_path,
            "different_hash"
        ));
        assert!(!SimpleVectorBackend::is_cache_fresh(
            &dir.path().join("missing.json"),
            "hash123"
        ));

        // A file that is not valid JSON is never considered fresh.
        let corrupt_path = dir.path().join("corrupt.json");
        std::fs::write(&corrupt_path, "not json").unwrap();
        assert!(!SimpleVectorBackend::is_cache_fresh(
            &corrupt_path,
            "hash123"
        ));
    }

    #[test]
    fn test_load_cache_nonexistent() {
        let result = SimpleVectorBackend::load_cache(
            std::path::Path::new("/nonexistent/path.json"),
            mock_provider(),
        )
        .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_load_cache_corrupt() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("corrupt.json");
        std::fs::write(&cache_path, "not json").unwrap();

        // An existing but unparsable cache is an error, not a silent `None`.
        let result = SimpleVectorBackend::load_cache(&cache_path, mock_provider());
        assert!(result.is_err());
    }
}