Skip to main content

fabryk_vector/
backend.rs

1//! Vector backend trait and simple fallback implementation.
2//!
3//! This module defines the `VectorBackend` trait that all vector search
4//! implementations must satisfy. It follows the same pattern as
5//! `fabryk_fts::SearchBackend`.
6//!
7//! # Backends
8//!
9//! - `LancedbBackend`: Vector search with LanceDB (requires `vector-lancedb` feature)
10//! - `SimpleVectorBackend`: In-memory brute-force fallback for small collections
11
12use async_trait::async_trait;
13use fabryk_core::{Error, Result};
14use serde::{Deserialize, Serialize};
15
16use crate::embedding::EmbeddingProvider;
17use crate::types::{
18    EmbeddedDocument, VectorConfig, VectorSearchParams, VectorSearchResult, VectorSearchResults,
19};
20use std::path::Path;
21use std::sync::Arc;
22
23/// Abstract vector search backend trait.
24///
25/// Implementations provide different vector search strategies:
26/// - `LancedbBackend`: Approximate nearest neighbor with LanceDB
27/// - `SimpleVectorBackend`: Brute-force cosine similarity (fallback)
28///
29/// # Async
30///
31/// The `search` method is async to support I/O-bound operations (embedding
32/// generation, index access) without blocking.
33#[async_trait]
34pub trait VectorBackend: Send + Sync {
35    /// Execute a vector similarity search.
36    ///
37    /// The query string is embedded using the backend's embedding provider,
38    /// then compared against indexed vectors. Results are ordered by
39    /// similarity score (highest first).
40    async fn search(&self, params: VectorSearchParams) -> Result<VectorSearchResults>;
41
42    /// Get the backend name for diagnostics.
43    fn name(&self) -> &str;
44
45    /// Check if the backend is ready to handle queries.
46    fn is_ready(&self) -> bool {
47        true
48    }
49
50    /// Get the number of indexed documents.
51    fn document_count(&self) -> Result<usize>;
52}
53
54/// Create a vector backend based on configuration.
55///
56/// Selection logic:
57/// 1. If `vector-lancedb` feature enabled and config says "lancedb" → `LancedbBackend`
58/// 2. Otherwise → `SimpleVectorBackend` (brute-force fallback)
59///
60/// Note: This creates an empty backend. Use `VectorIndexBuilder` to populate it.
61pub fn create_vector_backend(
62    config: &VectorConfig,
63    provider: Arc<dyn EmbeddingProvider>,
64) -> Result<Box<dyn VectorBackend>> {
65    if !config.enabled {
66        return Ok(Box::new(SimpleVectorBackend::new(provider)));
67    }
68
69    match config.backend.as_str() {
70        #[cfg(feature = "vector-lancedb")]
71        "lancedb" => {
72            // LanceDB backend requires async initialization; return simple as default.
73            // Use LancedbBackend::build() for full initialization.
74            log::info!("LanceDB requested but requires async build(); returning simple backend");
75            Ok(Box::new(SimpleVectorBackend::new(provider)))
76        }
77        _ => Ok(Box::new(SimpleVectorBackend::new(provider))),
78    }
79}
80
81// ============================================================================
82// SimpleVectorBackend
83// ============================================================================
84
85/// Serializable vector cache for persistence.
86#[derive(Serialize, Deserialize)]
87struct VectorCache {
88    content_hash: String,
89    documents: Vec<EmbeddedDocument>,
90}
91
92/// Lightweight header for checking freshness without loading all documents.
93#[derive(Deserialize)]
94struct VectorCacheHeader {
95    content_hash: String,
96}
97
98/// Brute-force vector search backend.
99///
100/// Stores documents in memory and computes cosine similarity for each query.
101/// Used as a fallback when LanceDB is not available or for small collections.
102///
103/// # Caching
104///
105/// Supports cache persistence via [`save_cache`](Self::save_cache) and
106/// [`load_cache`](Self::load_cache). Use [`is_cache_fresh`](Self::is_cache_fresh)
107/// to check if a cached index is still valid.
108///
109/// # Limitations
110///
111/// - O(n) search time
112/// - All documents must fit in memory
113pub struct SimpleVectorBackend {
114    provider: Arc<dyn EmbeddingProvider>,
115    documents: Vec<EmbeddedDocument>,
116}
117
118impl SimpleVectorBackend {
119    /// Create a new empty simple vector backend.
120    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
121        Self {
122            provider,
123            documents: Vec::new(),
124        }
125    }
126
127    /// Add documents to the backend.
128    pub fn add_documents(&mut self, documents: Vec<EmbeddedDocument>) {
129        self.documents.extend(documents);
130    }
131
132    /// Save the backend's documents to a cache file.
133    ///
134    /// Stores documents and a content hash for freshness checking.
135    /// Uses JSON format for simplicity and debuggability.
136    pub fn save_cache(&self, path: &Path, content_hash: &str) -> Result<()> {
137        let cache = VectorCache {
138            content_hash: content_hash.to_string(),
139            documents: self.documents.clone(),
140        };
141
142        // Ensure parent directory exists
143        if let Some(parent) = path.parent() {
144            if !parent.exists() {
145                std::fs::create_dir_all(parent).map_err(|e| Error::io_with_path(e, parent))?;
146            }
147        }
148
149        let json = serde_json::to_string(&cache)
150            .map_err(|e| Error::operation(format!("Failed to serialize vector cache: {e}")))?;
151
152        std::fs::write(path, json).map_err(|e| Error::io_with_path(e, path))?;
153
154        log::info!(
155            "Saved vector cache: {} documents to {}",
156            self.documents.len(),
157            path.display()
158        );
159
160        Ok(())
161    }
162
163    /// Load a cached backend from disk.
164    ///
165    /// Returns `Ok(Some(backend))` if the cache exists and loaded successfully,
166    /// `Ok(None)` if the cache doesn't exist, or `Err` on read/parse errors.
167    pub fn load_cache(path: &Path, provider: Arc<dyn EmbeddingProvider>) -> Result<Option<Self>> {
168        if !path.exists() {
169            return Ok(None);
170        }
171
172        let json = std::fs::read_to_string(path).map_err(|e| Error::io_with_path(e, path))?;
173
174        let cache: VectorCache = serde_json::from_str(&json)
175            .map_err(|e| Error::parse(format!("Failed to parse vector cache: {e}")))?;
176
177        let mut backend = Self::new(provider);
178        backend.documents = cache.documents;
179
180        log::info!(
181            "Loaded vector cache: {} documents from {}",
182            backend.documents.len(),
183            path.display()
184        );
185
186        Ok(Some(backend))
187    }
188
189    /// Check if the cache is fresh (content hasn't changed).
190    pub fn is_cache_fresh(path: &Path, content_hash: &str) -> bool {
191        if !path.exists() {
192            return false;
193        }
194
195        // Read only the content_hash field without deserializing the full document array
196        if let Ok(json) = std::fs::read_to_string(path) {
197            if let Ok(cache) = serde_json::from_str::<VectorCacheHeader>(&json) {
198                return cache.content_hash == content_hash;
199            }
200        }
201
202        false
203    }
204
205    /// Compute cosine similarity between two vectors.
206    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
207        if a.len() != b.len() || a.is_empty() {
208            return 0.0;
209        }
210
211        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
212        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
213        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
214
215        if norm_a == 0.0 || norm_b == 0.0 {
216            return 0.0;
217        }
218
219        dot / (norm_a * norm_b)
220    }
221}
222
223#[async_trait]
224impl VectorBackend for SimpleVectorBackend {
225    async fn search(&self, params: VectorSearchParams) -> Result<VectorSearchResults> {
226        if self.documents.is_empty() {
227            return Ok(VectorSearchResults::empty(self.name()));
228        }
229
230        let query_embedding = self.provider.embed(&params.query).await?;
231        let limit = params.limit.unwrap_or(10);
232        let threshold = params.similarity_threshold.unwrap_or(0.0);
233
234        let mut scored: Vec<(usize, f32)> = self
235            .documents
236            .iter()
237            .enumerate()
238            .map(|(i, doc)| {
239                let sim = Self::cosine_similarity(&query_embedding, &doc.embedding);
240                (i, sim)
241            })
242            .filter(|(_, sim)| *sim >= threshold)
243            .collect();
244
245        // Filter by category if specified
246        if let Some(ref category) = params.category {
247            scored.retain(|(i, _)| {
248                self.documents[*i].document.category.as_deref() == Some(category.as_str())
249            });
250        }
251
252        // Filter by metadata
253        for (key, value) in &params.metadata_filters {
254            scored.retain(|(i, _)| {
255                self.documents[*i]
256                    .document
257                    .metadata
258                    .get(key)
259                    .map(|v| v == value)
260                    .unwrap_or(false)
261            });
262        }
263
264        // Sort by similarity (highest first)
265        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
266        scored.truncate(limit);
267
268        let total = scored.len();
269        let items: Vec<VectorSearchResult> = scored
270            .into_iter()
271            .map(|(i, score)| {
272                let doc = &self.documents[i];
273                let distance = 1.0 - score; // cosine distance
274                VectorSearchResult {
275                    id: doc.document.id.clone(),
276                    score,
277                    distance,
278                    metadata: doc.document.metadata.clone(),
279                }
280            })
281            .collect();
282
283        Ok(VectorSearchResults {
284            items,
285            total,
286            backend: self.name().to_string(),
287        })
288    }
289
290    fn name(&self) -> &str {
291        "simple"
292    }
293
294    fn document_count(&self) -> Result<usize> {
295        Ok(self.documents.len())
296    }
297}
298
299impl std::fmt::Debug for SimpleVectorBackend {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        f.debug_struct("SimpleVectorBackend")
302            .field("documents", &self.documents.len())
303            .finish()
304    }
305}
306
307// ============================================================================
308// Tests
309// ============================================================================
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::MockEmbeddingProvider;
    use crate::types::VectorDocument;

    fn mock_provider() -> Arc<dyn EmbeddingProvider> {
        Arc::new(MockEmbeddingProvider::new(8))
    }

    #[test]
    fn test_simple_backend_creation() {
        let backend = SimpleVectorBackend::new(mock_provider());
        assert_eq!(backend.name(), "simple");
        assert!(backend.is_ready());
        assert_eq!(backend.document_count().unwrap(), 0);
    }

    #[test]
    fn test_simple_backend_add_documents() {
        let provider = mock_provider();
        let mut backend = SimpleVectorBackend::new(provider);

        let docs = vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "hello"),
                vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "world"),
                vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
        ];

        backend.add_documents(docs);
        assert_eq!(backend.document_count().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_simple_backend_search_empty() {
        let backend = SimpleVectorBackend::new(mock_provider());

        let params = VectorSearchParams::new("test query");
        let results = backend.search(params).await.unwrap();

        assert!(results.items.is_empty());
        assert_eq!(results.total, 0);
        assert_eq!(results.backend, "simple");
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_results() {
        let provider = Arc::new(MockEmbeddingProvider::new(4));
        let mut backend = SimpleVectorBackend::new(provider.clone());

        // Add documents with known embeddings
        let docs = vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-close", "close match"),
                vec![0.9, 0.1, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-far", "far away"),
                vec![0.0, 0.0, 0.1, 0.9],
            ),
        ];

        backend.add_documents(docs);

        let params = VectorSearchParams::new("test").with_limit(10);
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 2);
        assert_eq!(results.total, 2);
        assert_eq!(results.backend, "simple");
        // Results should be ordered by score (highest first)
        assert!(results.items[0].score >= results.items[1].score);
        // Distance is reported as cosine distance.
        for item in &results.items {
            assert!((item.distance - (1.0 - item.score)).abs() < 1e-6);
        }
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_threshold() {
        let provider = Arc::new(MockEmbeddingProvider::new(4));
        let mut backend = SimpleVectorBackend::new(provider.clone());

        let docs = vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "text"),
                vec![1.0, 0.0, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "text"),
                vec![0.0, 0.0, 0.0, 1.0],
            ),
        ];

        backend.add_documents(docs);

        let params = VectorSearchParams::new("test").with_threshold(0.99);
        let results = backend.search(params).await.unwrap();

        // Every returned item must meet the threshold.
        assert!(results.items.iter().all(|item| item.score >= 0.99));
        // The two documents are orthogonal, so no query vector can be within
        // cosine 0.99 of both: at most one may pass.
        assert!(results.items.len() <= 1);
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_category() {
        let provider = Arc::new(MockEmbeddingProvider::new(4));
        let mut backend = SimpleVectorBackend::new(provider.clone());

        let docs = vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "harmony text").with_category("harmony"),
                vec![0.5, 0.5, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "rhythm text").with_category("rhythm"),
                vec![0.5, 0.0, 0.5, 0.0],
            ),
        ];

        backend.add_documents(docs);

        let params = VectorSearchParams::new("test").with_category("harmony");
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 1);
        assert_eq!(results.items[0].id, "doc-1");
    }

    #[tokio::test]
    async fn test_simple_backend_search_with_metadata_filter() {
        let provider = Arc::new(MockEmbeddingProvider::new(4));
        let mut backend = SimpleVectorBackend::new(provider.clone());

        let docs = vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "text").with_metadata("tier", "beginner"),
                vec![0.5, 0.5, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "text").with_metadata("tier", "advanced"),
                vec![0.5, 0.0, 0.5, 0.0],
            ),
        ];

        backend.add_documents(docs);

        let params = VectorSearchParams::new("test").with_filter("tier", "beginner");
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 1);
        assert_eq!(results.items[0].id, "doc-1");
    }

    #[tokio::test]
    async fn test_simple_backend_search_limit() {
        let provider = Arc::new(MockEmbeddingProvider::new(4));
        let mut backend = SimpleVectorBackend::new(provider.clone());

        let docs: Vec<EmbeddedDocument> = (0..20)
            .map(|i| {
                EmbeddedDocument::new(
                    VectorDocument::new(format!("doc-{i}"), format!("text {i}")),
                    vec![0.5, 0.5, 0.0, 0.0],
                )
            })
            .collect();

        backend.add_documents(docs);

        let params = VectorSearchParams::new("test").with_limit(5);
        let results = backend.search(params).await.unwrap();

        assert_eq!(results.items.len(), 5);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let sim = SimpleVectorBackend::cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let sim = SimpleVectorBackend::cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_create_vector_backend_simple() {
        let config = VectorConfig {
            backend: "simple".to_string(),
            ..Default::default()
        };
        let provider = mock_provider();

        let backend = create_vector_backend(&config, provider).unwrap();
        assert_eq!(backend.name(), "simple");
    }

    #[test]
    fn test_create_vector_backend_disabled() {
        let config = VectorConfig {
            enabled: false,
            ..Default::default()
        };
        let provider = mock_provider();

        let backend = create_vector_backend(&config, provider).unwrap();
        assert_eq!(backend.name(), "simple");
    }

    #[test]
    fn test_trait_object_safety() {
        // Verify VectorBackend can be used as a trait object
        fn _assert_object_safe(_: &dyn VectorBackend) {}
    }

    #[test]
    fn test_simple_backend_debug() {
        let backend = SimpleVectorBackend::new(mock_provider());
        let debug = format!("{:?}", backend);
        assert!(debug.contains("SimpleVectorBackend"));
        assert!(debug.contains("documents"));
    }

    // ================================================================
    // Cache tests
    // ================================================================

    #[test]
    fn test_save_and_load_cache() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("test-cache.json");
        let provider = mock_provider();

        let mut backend = SimpleVectorBackend::new(provider.clone());
        backend.add_documents(vec![
            EmbeddedDocument::new(
                VectorDocument::new("doc-1", "hello"),
                vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
            EmbeddedDocument::new(
                VectorDocument::new("doc-2", "world"),
                vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            ),
        ]);

        backend.save_cache(&cache_path, "hash123").unwrap();
        assert!(cache_path.exists());

        let loaded = SimpleVectorBackend::load_cache(&cache_path, provider)
            .unwrap()
            .unwrap();
        assert_eq!(loaded.document_count().unwrap(), 2);
        // Document order and identity survive the round trip.
        assert_eq!(loaded.documents[0].document.id, "doc-1");
        assert_eq!(loaded.documents[1].document.id, "doc-2");
    }

    #[test]
    fn test_save_cache_creates_dirs_and_overwrites() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("nested").join("deeper").join("cache.json");
        let provider = mock_provider();

        let mut backend = SimpleVectorBackend::new(provider.clone());
        backend.add_documents(vec![EmbeddedDocument::new(
            VectorDocument::new("doc-1", "hello"),
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        )]);

        backend.save_cache(&cache_path, "first").unwrap();
        assert!(SimpleVectorBackend::is_cache_fresh(&cache_path, "first"));

        backend.add_documents(vec![EmbeddedDocument::new(
            VectorDocument::new("doc-2", "world"),
            vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        )]);
        backend.save_cache(&cache_path, "second").unwrap();

        assert!(SimpleVectorBackend::is_cache_fresh(&cache_path, "second"));
        assert!(!SimpleVectorBackend::is_cache_fresh(&cache_path, "first"));

        let loaded = SimpleVectorBackend::load_cache(&cache_path, provider)
            .unwrap()
            .unwrap();
        assert_eq!(loaded.document_count().unwrap(), 2);
    }

    #[test]
    fn test_cache_freshness() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("test-cache.json");
        let provider = mock_provider();

        let mut backend = SimpleVectorBackend::new(provider);
        backend.add_documents(vec![EmbeddedDocument::new(
            VectorDocument::new("doc-1", "hello"),
            vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        )]);

        backend.save_cache(&cache_path, "hash123").unwrap();

        assert!(SimpleVectorBackend::is_cache_fresh(&cache_path, "hash123"));
        assert!(!SimpleVectorBackend::is_cache_fresh(
            &cache_path,
            "different_hash"
        ));
        assert!(!SimpleVectorBackend::is_cache_fresh(
            &dir.path().join("missing.json"),
            "hash123"
        ));
    }

    #[test]
    fn test_load_cache_nonexistent() {
        let result = SimpleVectorBackend::load_cache(
            std::path::Path::new("/nonexistent/path.json"),
            mock_provider(),
        )
        .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_load_cache_corrupt_file() {
        let dir = tempfile::tempdir().unwrap();
        let cache_path = dir.path().join("corrupt.json");
        std::fs::write(&cache_path, "not valid json").unwrap();

        assert!(SimpleVectorBackend::load_cache(&cache_path, mock_provider()).is_err());
        assert!(!SimpleVectorBackend::is_cache_fresh(&cache_path, "hash123"));
    }
}