1use crate::embeddings::{EmbeddingGenerator, SimpleEmbeddingGenerator};
6#[cfg(feature = "onnx")]
7use crate::embeddings::OnnxEmbeddingGenerator;
8use crate::vector_store::{DocumentChunk, SearchResult, VectorStore, VectorStoreConfig};
9use std::sync::Arc;
10
11#[derive(Debug, thiserror::Error)]
13pub enum RetrievalError {
14 #[error("Vector store error: {0}")]
15 VectorStore(#[from] crate::vector_store::VectorStoreError),
16
17 #[error("Embedding error: {0}")]
18 Embedding(#[from] crate::embeddings::EmbeddingError),
19
20 #[error("Invalid chunk size: {0}")]
21 InvalidChunkSize(usize),
22}
23
24pub type Result<T> = std::result::Result<T, RetrievalError>;
25
26#[derive(Debug, Clone)]
28pub struct RetrievalConfig {
29 pub max_chunk_size: usize,
31
32 pub chunk_overlap: usize,
34
35 pub vector_config: VectorStoreConfig,
37}
38
39impl Default for RetrievalConfig {
40 fn default() -> Self {
41 Self {
42 max_chunk_size: 512,
43 chunk_overlap: 50,
44 vector_config: VectorStoreConfig::default(),
45 }
46 }
47}
48
49pub struct RetrievalSystem {
51 vector_store: VectorStore,
53
54 embedder: Arc<dyn EmbeddingGenerator>,
56
57 config: RetrievalConfig,
59}
60
61impl RetrievalSystem {
62 pub fn new(config: RetrievalConfig) -> Result<Self> {
66 let vector_store = VectorStore::new(config.vector_config.clone())?;
67
68 #[cfg(feature = "onnx")]
70 let embedder: Arc<dyn EmbeddingGenerator> = match OnnxEmbeddingGenerator::new() {
71 Ok(onnx_gen) => {
72 tracing::info!("Using ONNX semantic embeddings (MiniLM-L6-v2, 384d)");
73 Arc::new(onnx_gen)
74 }
75 Err(e) => {
76 tracing::warn!(
77 "ONNX embeddings unavailable ({}), falling back to simple hash-based embeddings",
78 e
79 );
80 Arc::new(SimpleEmbeddingGenerator::new())
81 }
82 };
83
84 #[cfg(not(feature = "onnx"))]
85 let embedder: Arc<dyn EmbeddingGenerator> = {
86 tracing::info!("Using simple hash-based embeddings (ONNX feature not enabled)");
87 Arc::new(SimpleEmbeddingGenerator::new())
88 };
89
90 Ok(Self {
91 vector_store,
92 embedder,
93 config,
94 })
95 }
96
97 pub fn with_embedder(
99 config: RetrievalConfig,
100 embedder: Arc<dyn EmbeddingGenerator>,
101 ) -> Result<Self> {
102 let vector_store = VectorStore::new(config.vector_config.clone())?;
103
104 Ok(Self {
105 vector_store,
106 embedder,
107 config,
108 })
109 }
110
111 pub fn index_document(
113 &self,
114 document_id: &str,
115 content: &str,
116 source: &str,
117 metadata: Option<String>,
118 ) -> Result<usize> {
119 let chunks = self.split_into_chunks(content);
120 let mut indexed_count = 0;
121
122 for (i, chunk_text) in chunks.iter().enumerate() {
123 let chunk = DocumentChunk {
124 id: format!("{}:{}", document_id, i),
125 content: chunk_text.to_string(),
126 source: source.to_string(),
127 metadata: metadata.clone(),
128 };
129
130 let embedding = self.embedder.generate(chunk_text)?;
132
133 self.vector_store.add_chunk(chunk, &embedding)?;
135 indexed_count += 1;
136 }
137
138 Ok(indexed_count)
139 }
140
141 pub fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<SearchResult>> {
143 let query_embedding = self.embedder.generate(query)?;
145
146 let results = self.vector_store.search(&query_embedding, top_k)?;
148
149 Ok(results)
150 }
151
152 fn split_into_chunks(&self, text: &str) -> Vec<String> {
154 let max_size = self.config.max_chunk_size;
155 let overlap = self.config.chunk_overlap;
156
157 if text.len() <= max_size {
158 return vec![text.to_string()];
159 }
160
161 let mut chunks = Vec::new();
162 let mut start = 0;
163
164 while start < text.len() {
165 let end = (start + max_size).min(text.len());
166 let chunk = &text[start..end];
167 chunks.push(chunk.to_string());
168
169 if end >= text.len() {
170 break;
171 }
172
173 start += max_size - overlap;
175 }
176
177 chunks
178 }
179
180 pub fn chunk_count(&self) -> usize {
182 self.vector_store.len()
183 }
184
185 pub fn clear(&self) {
187 self.vector_store.clear();
188 }
189
190 pub fn stats(&self) -> RetrievalStats {
192 let vector_stats = self.vector_store.stats();
193 RetrievalStats {
194 num_chunks: vector_stats.num_chunks,
195 embedding_dim: vector_stats.embedding_dim,
196 max_chunk_size: self.config.max_chunk_size,
197 chunk_overlap: self.config.chunk_overlap,
198 }
199 }
200}
201
/// Snapshot of the index size and the chunking settings of a retrieval system.
#[derive(Clone, Debug)]
pub struct RetrievalStats {
    /// Number of chunks currently stored.
    pub num_chunks: usize,
    /// Dimensionality of the stored embeddings.
    pub embedding_dim: usize,
    /// Configured maximum chunk size, in bytes.
    pub max_chunk_size: usize,
    /// Configured overlap between consecutive chunks, in bytes.
    pub chunk_overlap: usize,
}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retrieval_system_creation() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();
        assert_eq!(system.chunk_count(), 0);
    }

    #[test]
    fn test_split_into_chunks() {
        let config = RetrievalConfig {
            max_chunk_size: 20,
            chunk_overlap: 5,
            ..Default::default()
        };
        let system = RetrievalSystem::new(config).unwrap();

        let text = "This is a test document that should be split into multiple chunks.";
        let chunks = system.split_into_chunks(text);

        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.len() <= 20);
            // Every chunk must be a verbatim slice of the input.
            assert!(text.contains(chunk.as_str()));
        }
    }

    #[test]
    fn test_index_short_document() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        let count = system
            .index_document("doc1", "This is a short test document.", "test.txt", None)
            .unwrap();

        assert_eq!(count, 1);
        assert_eq!(system.chunk_count(), 1);
    }

    #[test]
    fn test_index_long_document() {
        let config = RetrievalConfig {
            max_chunk_size: 50,
            chunk_overlap: 10,
            ..Default::default()
        };
        let system = RetrievalSystem::new(config).unwrap();

        let long_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. \
            Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. \
            Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.";

        let count = system
            .index_document("doc1", long_text, "test.txt", None)
            .unwrap();

        assert!(count > 1);
        assert_eq!(system.chunk_count(), count);
    }

    #[test]
    fn test_retrieve() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        system
            .index_document(
                "doc1",
                "The quick brown fox jumps over the lazy dog.",
                "animals.txt",
                None,
            )
            .unwrap();

        system
            .index_document(
                "doc2",
                "Rust is a systems programming language.",
                "programming.txt",
                None,
            )
            .unwrap();

        system
            .index_document(
                "doc3",
                "Machine learning and artificial intelligence.",
                "ai.txt",
                None,
            )
            .unwrap();

        let results = system.retrieve("programming language", 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_retrieve_relevance_order() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        system
            .index_document("doc1", "The fox is brown and quick.", "test1.txt", None)
            .unwrap();

        system
            .index_document("doc2", "The fox jumps over the dog.", "test2.txt", None)
            .unwrap();

        system
            .index_document(
                "doc3",
                "Completely unrelated content about programming.",
                "test3.txt",
                None,
            )
            .unwrap();

        let results = system.retrieve("fox", 3).unwrap();

        // Results must come back sorted by non-increasing score.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_clear() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        system
            .index_document("doc1", "Test document", "test.txt", None)
            .unwrap();

        assert_eq!(system.chunk_count(), 1);

        system.clear();

        assert_eq!(system.chunk_count(), 0);
    }

    #[test]
    fn test_stats() {
        let config = RetrievalConfig {
            max_chunk_size: 100,
            chunk_overlap: 20,
            ..Default::default()
        };
        let system = RetrievalSystem::new(config).unwrap();

        system
            .index_document("doc1", "Test document", "test.txt", None)
            .unwrap();

        let stats = system.stats();
        assert_eq!(stats.num_chunks, 1);
        assert_eq!(stats.max_chunk_size, 100);
        assert_eq!(stats.chunk_overlap, 20);
        assert!(stats.embedding_dim > 0);
    }

    #[test]
    fn test_chunk_with_metadata() {
        let system = RetrievalSystem::new(RetrievalConfig::default()).unwrap();

        let metadata = serde_json::json!({
            "author": "Test Author",
            "date": "2024-01-01"
        })
        .to_string();

        let count = system
            .index_document("doc1", "Test content", "test.txt", Some(metadata))
            .unwrap();

        assert_eq!(count, 1);
    }
}