// frame_catalog/retrieval.rs

//! Document indexing and retrieval system
//!
//! Provides high-level RAG (Retrieval-Augmented Generation) functionality
5use crate::embeddings::{EmbeddingGenerator, SimpleEmbeddingGenerator};
6#[cfg(feature = "onnx")]
7use crate::embeddings::OnnxEmbeddingGenerator;
8use crate::vector_store::{DocumentChunk, SearchResult, VectorStore, VectorStoreConfig};
9use std::sync::Arc;
10
/// Retrieval system error
///
/// Unified error type for this module; store and embedding failures are
/// converted automatically via `#[from]`.
#[derive(Debug, thiserror::Error)]
pub enum RetrievalError {
    /// Failure propagated from the underlying vector store.
    #[error("Vector store error: {0}")]
    VectorStore(#[from] crate::vector_store::VectorStoreError),

    /// Failure propagated from the embedding generator.
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::embeddings::EmbeddingError),

    /// The chunking configuration is unusable (carries the offending size).
    #[error("Invalid chunk size: {0}")]
    InvalidChunkSize(usize),
}

/// Convenience alias for results produced by this module.
pub type Result<T> = std::result::Result<T, RetrievalError>;
25
/// Configuration for document retrieval
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
    /// Maximum chunk size in characters
    // NOTE(review): `split_into_chunks` measures `str::len()` and slices by
    // byte offset, so for non-ASCII text this is effectively a byte limit.
    pub max_chunk_size: usize,

    /// Overlap between chunks in characters
    // Must be smaller than `max_chunk_size`: the chunker advances by
    // `max_chunk_size - chunk_overlap` per step, so an overlap >= the chunk
    // size leaves no forward progress.
    pub chunk_overlap: usize,

    /// Vector store configuration
    pub vector_config: VectorStoreConfig,
}
38
39impl Default for RetrievalConfig {
40    fn default() -> Self {
41        Self {
42            max_chunk_size: 512,
43            chunk_overlap: 50,
44            vector_config: VectorStoreConfig::default(),
45        }
46    }
47}
48
/// Document retrieval system for RAG
///
/// Owns a vector store and an embedding generator: documents are chunked,
/// embedded, and stored so queries can be answered by similarity search.
pub struct RetrievalSystem {
    /// Vector store for similarity search
    vector_store: VectorStore,

    /// Embedding generator
    // Trait object so ONNX (semantic) and simple (hash-based) generators are
    // interchangeable at runtime; `Arc` allows callers to share a generator
    // via `with_embedder`.
    embedder: Arc<dyn EmbeddingGenerator>,

    /// Configuration
    config: RetrievalConfig,
}
60
61impl RetrievalSystem {
62    /// Create a new retrieval system
63    ///
64    /// Attempts to use ONNX embeddings (semantic) with fallback to simple embeddings (hash-based)
65    pub fn new(config: RetrievalConfig) -> Result<Self> {
66        let vector_store = VectorStore::new(config.vector_config.clone())?;
67
68        // Try to use ONNX embeddings first (real semantic embeddings)
69        #[cfg(feature = "onnx")]
70        let embedder: Arc<dyn EmbeddingGenerator> = match OnnxEmbeddingGenerator::new() {
71            Ok(onnx_gen) => {
72                tracing::info!("Using ONNX semantic embeddings (MiniLM-L6-v2, 384d)");
73                Arc::new(onnx_gen)
74            }
75            Err(e) => {
76                tracing::warn!(
77                    "ONNX embeddings unavailable ({}), falling back to simple hash-based embeddings",
78                    e
79                );
80                Arc::new(SimpleEmbeddingGenerator::new())
81            }
82        };
83
84        #[cfg(not(feature = "onnx"))]
85        let embedder: Arc<dyn EmbeddingGenerator> = {
86            tracing::info!("Using simple hash-based embeddings (ONNX feature not enabled)");
87            Arc::new(SimpleEmbeddingGenerator::new())
88        };
89
90        Ok(Self {
91            vector_store,
92            embedder,
93            config,
94        })
95    }
96
97    /// Create a new retrieval system with a custom embedding generator
98    pub fn with_embedder(
99        config: RetrievalConfig,
100        embedder: Arc<dyn EmbeddingGenerator>,
101    ) -> Result<Self> {
102        let vector_store = VectorStore::new(config.vector_config.clone())?;
103
104        Ok(Self {
105            vector_store,
106            embedder,
107            config,
108        })
109    }
110
111    /// Index a document by splitting it into chunks
112    pub fn index_document(
113        &self,
114        document_id: &str,
115        content: &str,
116        source: &str,
117        metadata: Option<String>,
118    ) -> Result<usize> {
119        let chunks = self.split_into_chunks(content);
120        let mut indexed_count = 0;
121
122        for (i, chunk_text) in chunks.iter().enumerate() {
123            let chunk = DocumentChunk {
124                id: format!("{}:{}", document_id, i),
125                content: chunk_text.to_string(),
126                source: source.to_string(),
127                metadata: metadata.clone(),
128            };
129
130            // Generate embedding
131            let embedding = self.embedder.generate(chunk_text)?;
132
133            // Add to vector store
134            self.vector_store.add_chunk(chunk, &embedding)?;
135            indexed_count += 1;
136        }
137
138        Ok(indexed_count)
139    }
140
141    /// Retrieve relevant documents for a query
142    pub fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
143        // Generate query embedding
144        let query_embedding = self.embedder.generate(query)?;
145
146        // Search vector store
147        let results = self.vector_store.search(&query_embedding, top_k)?;
148
149        Ok(results)
150    }
151
152    /// Split text into overlapping chunks
153    fn split_into_chunks(&self, text: &str) -> Vec<String> {
154        let max_size = self.config.max_chunk_size;
155        let overlap = self.config.chunk_overlap;
156
157        if text.len() <= max_size {
158            return vec![text.to_string()];
159        }
160
161        let mut chunks = Vec::new();
162        let mut start = 0;
163
164        while start < text.len() {
165            let end = (start + max_size).min(text.len());
166            let chunk = &text[start..end];
167            chunks.push(chunk.to_string());
168
169            if end >= text.len() {
170                break;
171            }
172
173            // Move forward with overlap
174            start += max_size - overlap;
175        }
176
177        chunks
178    }
179
180    /// Get the number of indexed chunks
181    pub fn chunk_count(&self) -> usize {
182        self.vector_store.len()
183    }
184
185    /// Clear all indexed documents
186    pub fn clear(&self) {
187        self.vector_store.clear();
188    }
189
190    /// Get retrieval statistics
191    pub fn stats(&self) -> RetrievalStats {
192        let vector_stats = self.vector_store.stats();
193        RetrievalStats {
194            num_chunks: vector_stats.num_chunks,
195            embedding_dim: vector_stats.embedding_dim,
196            max_chunk_size: self.config.max_chunk_size,
197            chunk_overlap: self.config.chunk_overlap,
198        }
199    }
200}
201
/// Retrieval system statistics
#[derive(Debug, Clone)]
pub struct RetrievalStats {
    /// Number of chunks currently held in the vector store.
    pub num_chunks: usize,
    /// Embedding dimensionality, as reported by the vector store.
    pub embedding_dim: usize,
    /// Configured maximum chunk size (see `RetrievalConfig::max_chunk_size`).
    pub max_chunk_size: usize,
    /// Configured overlap between consecutive chunks.
    pub chunk_overlap: usize,
}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retrieval_system_creation() {
        // A freshly built system starts out empty.
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();
        assert_eq!(rs.chunk_count(), 0);
    }

    #[test]
    fn test_split_into_chunks() {
        let cfg = RetrievalConfig {
            max_chunk_size: 20,
            chunk_overlap: 5,
            ..Default::default()
        };
        let rs = RetrievalSystem::new(cfg).unwrap();

        let text = "This is a test document that should be split into multiple chunks.";
        let pieces = rs.split_into_chunks(text);

        // Long input must be split, and every piece respects the size cap.
        assert!(pieces.len() > 1);
        assert!(pieces.iter().all(|piece| piece.len() <= 20));
    }

    #[test]
    fn test_index_short_document() {
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        let indexed = rs
            .index_document("doc1", "This is a short test document.", "test.txt", None)
            .unwrap();

        // A document below the chunk limit yields exactly one stored chunk.
        assert_eq!(indexed, 1);
        assert_eq!(rs.chunk_count(), 1);
    }

    #[test]
    fn test_index_long_document() {
        let cfg = RetrievalConfig {
            max_chunk_size: 50,
            chunk_overlap: 10,
            ..Default::default()
        };
        let rs = RetrievalSystem::new(cfg).unwrap();

        let long_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
                        Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. \
                        Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.";

        let indexed = rs
            .index_document("doc1", long_text, "test.txt", None)
            .unwrap();

        // Long input produces several chunks, all of which end up stored.
        assert!(indexed > 1);
        assert_eq!(rs.chunk_count(), indexed);
    }

    #[test]
    fn test_retrieve() {
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        // Seed the store with three small documents on distinct topics.
        let docs = [
            ("doc1", "The quick brown fox jumps over the lazy dog.", "animals.txt"),
            ("doc2", "Rust is a systems programming language.", "programming.txt"),
            ("doc3", "Machine learning and artificial intelligence.", "ai.txt"),
        ];
        for &(id, body, src) in docs.iter() {
            rs.index_document(id, body, src, None).unwrap();
        }

        // Asking for at most two hits must return between one and two.
        let hits = rs.retrieve("programming language", 2).unwrap();
        assert!(!hits.is_empty());
        assert!(hits.len() <= 2);
    }

    #[test]
    fn test_retrieve_relevance_order() {
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        rs.index_document("doc1", "The fox is brown and quick.", "test1.txt", None)
            .unwrap();
        rs.index_document("doc2", "The fox jumps over the dog.", "test2.txt", None)
            .unwrap();
        rs.index_document(
            "doc3",
            "Completely unrelated content about programming.",
            "test3.txt",
            None,
        )
        .unwrap();

        let hits = rs.retrieve("fox", 3).unwrap();

        // Scores must be non-increasing: the best match comes first.
        assert!(hits.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn test_clear() {
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        rs.index_document("doc1", "Test document", "test.txt", None)
            .unwrap();
        assert_eq!(rs.chunk_count(), 1);

        // Clearing drops every indexed chunk.
        rs.clear();
        assert_eq!(rs.chunk_count(), 0);
    }

    #[test]
    fn test_stats() {
        let cfg = RetrievalConfig {
            max_chunk_size: 100,
            chunk_overlap: 20,
            ..Default::default()
        };
        let rs = RetrievalSystem::new(cfg).unwrap();

        rs.index_document("doc1", "Test document", "test.txt", None)
            .unwrap();

        // Stats echo the configuration and reflect the store contents.
        let stats = rs.stats();
        assert_eq!(stats.num_chunks, 1);
        assert_eq!(stats.max_chunk_size, 100);
        assert_eq!(stats.chunk_overlap, 20);
        assert!(stats.embedding_dim > 0);
    }

    #[test]
    fn test_chunk_with_metadata() {
        let rs = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        let meta = serde_json::json!({
            "author": "Test Author",
            "date": "2024-01-01"
        })
        .to_string();

        let indexed = rs
            .index_document("doc1", "Test content", "test.txt", Some(meta))
            .unwrap();

        assert_eq!(indexed, 1);
    }
}