// rag_plusplus_core/retrieval/engine.rs

//! Query Engine
//!
//! End-to-end query execution with validation, search, and result building.
4
5use crate::error::{Error, Result};
6use crate::filter::{CompiledFilter, FilterExpr};
7use crate::index::{
8    IndexRegistry, MultiIndexResults, ParallelSearcher, SearchResult,
9    rrf_fuse,
10};
11use crate::retrieval::rerank::{Reranker, RerankerConfig};
12use crate::stats::OutcomeStats;
13use crate::store::RecordStore;
14use crate::types::{MemoryRecord, PriorBundle, RecordId};
15use std::time::{Duration, Instant};
16
17/// Query engine configuration.
18#[derive(Debug, Clone)]
19pub struct QueryEngineConfig {
20    /// Default number of results to return
21    pub default_k: usize,
22    /// Maximum allowed k
23    pub max_k: usize,
24    /// Query timeout in milliseconds
25    pub timeout_ms: u64,
26    /// Whether to use parallel search for multi-index queries
27    pub parallel_search: bool,
28    /// Reranker configuration
29    pub reranker: Option<RerankerConfig>,
30    /// Whether to build priors from results
31    pub build_priors: bool,
32}
33
34impl Default for QueryEngineConfig {
35    fn default() -> Self {
36        Self {
37            default_k: 10,
38            max_k: 1000,
39            timeout_ms: 5000,
40            parallel_search: true,
41            reranker: None,
42            build_priors: true,
43        }
44    }
45}
46
47impl QueryEngineConfig {
48    /// Create new config with defaults.
49    #[must_use]
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    /// Set default k.
55    #[must_use]
56    pub const fn with_default_k(mut self, k: usize) -> Self {
57        self.default_k = k;
58        self
59    }
60
61    /// Set timeout.
62    #[must_use]
63    pub const fn with_timeout_ms(mut self, ms: u64) -> Self {
64        self.timeout_ms = ms;
65        self
66    }
67
68    /// Set reranker.
69    #[must_use]
70    pub fn with_reranker(mut self, config: RerankerConfig) -> Self {
71        self.reranker = Some(config);
72        self
73    }
74}
75
76/// Query request.
77#[derive(Debug, Clone)]
78pub struct QueryRequest {
79    /// Query embedding
80    pub embedding: Vec<f32>,
81    /// Number of results (uses default if None)
82    pub k: Option<usize>,
83    /// Metadata filter (optional)
84    pub filter: Option<FilterExpr>,
85    /// Specific index names to search (None = all)
86    pub indexes: Option<Vec<String>>,
87    /// Timeout override (milliseconds)
88    pub timeout_ms: Option<u64>,
89}
90
91impl QueryRequest {
92    /// Create a new query request.
93    #[must_use]
94    pub fn new(embedding: Vec<f32>) -> Self {
95        Self {
96            embedding,
97            k: None,
98            filter: None,
99            indexes: None,
100            timeout_ms: None,
101        }
102    }
103
104    /// Set k.
105    #[must_use]
106    pub const fn with_k(mut self, k: usize) -> Self {
107        self.k = Some(k);
108        self
109    }
110
111    /// Set filter.
112    #[must_use]
113    pub fn with_filter(mut self, filter: FilterExpr) -> Self {
114        self.filter = Some(filter);
115        self
116    }
117
118    /// Set specific indexes to search.
119    #[must_use]
120    pub fn with_indexes(mut self, indexes: Vec<String>) -> Self {
121        self.indexes = Some(indexes);
122        self
123    }
124}
125
126/// Single result in query response.
127#[derive(Debug, Clone)]
128pub struct RetrievedRecord {
129    /// The full record
130    pub record: MemoryRecord,
131    /// Similarity score (0-1, higher is better)
132    pub score: f32,
133    /// Rank in results (1-indexed)
134    pub rank: usize,
135    /// Source index name
136    pub source_index: String,
137}
138
139/// Query response.
140#[derive(Debug, Clone)]
141pub struct QueryResponse {
142    /// Retrieved records
143    pub results: Vec<RetrievedRecord>,
144    /// Prior bundle built from results
145    pub priors: Option<PriorBundle>,
146    /// Query execution time
147    pub latency: Duration,
148    /// Number of indexes searched
149    pub indexes_searched: usize,
150    /// Total candidates considered
151    pub candidates_considered: usize,
152}
153
154impl QueryResponse {
155    /// Get top result (if any).
156    #[must_use]
157    pub fn top(&self) -> Option<&RetrievedRecord> {
158        self.results.first()
159    }
160
161    /// Check if any results were found.
162    #[must_use]
163    pub fn is_empty(&self) -> bool {
164        self.results.is_empty()
165    }
166
167    /// Number of results.
168    #[must_use]
169    pub fn len(&self) -> usize {
170        self.results.len()
171    }
172}
173
174/// Query engine for executing retrieval queries.
175///
176/// Provides end-to-end query execution including:
177/// - Query validation
178/// - Vector search (single or multi-index)
179/// - Metadata filtering
180/// - Result reranking
181/// - Prior building
182pub struct QueryEngine<'a, S: RecordStore> {
183    /// Configuration
184    config: QueryEngineConfig,
185    /// Index registry
186    registry: &'a IndexRegistry,
187    /// Record store
188    store: &'a S,
189    /// Reranker (if configured)
190    reranker: Option<Reranker>,
191}
192
193impl<'a, S: RecordStore> QueryEngine<'a, S> {
194    /// Create a new query engine.
195    #[must_use]
196    pub fn new(
197        config: QueryEngineConfig,
198        registry: &'a IndexRegistry,
199        store: &'a S,
200    ) -> Self {
201        let reranker = config.reranker.clone().map(Reranker::new);
202        Self {
203            config,
204            registry,
205            store,
206            reranker,
207        }
208    }
209
210    /// Execute a query.
211    ///
212    /// # Errors
213    ///
214    /// Returns error if query is invalid, timeout occurs, or search fails.
215    pub fn query(&self, request: QueryRequest) -> Result<QueryResponse> {
216        let start = Instant::now();
217        let timeout = Duration::from_millis(
218            request.timeout_ms.unwrap_or(self.config.timeout_ms),
219        );
220
221        // Validate query
222        self.validate_query(&request)?;
223
224        // Determine k
225        let k = request.k.unwrap_or(self.config.default_k).min(self.config.max_k);
226
227        // Execute search
228        let (search_results, indexes_searched) = self.execute_search(&request, k)?;
229
230        // Check timeout
231        if start.elapsed() > timeout {
232            return Err(Error::QueryTimeout {
233                elapsed_ms: start.elapsed().as_millis() as u64,
234                budget_ms: timeout.as_millis() as u64,
235            });
236        }
237
238        // Fetch records and build results
239        let mut results = self.build_results(search_results, &request)?;
240        let candidates_considered = results.len();
241
242        // Apply filter if specified
243        if let Some(ref filter_expr) = request.filter {
244            let filter = CompiledFilter::compile(filter_expr.clone());
245            results.retain(|r| filter.evaluate(&r.record.metadata));
246        }
247
248        // Rerank if configured
249        if let Some(ref reranker) = self.reranker {
250            results = reranker.rerank(results);
251        }
252
253        // Truncate to k
254        results.truncate(k);
255
256        // Update ranks
257        for (i, result) in results.iter_mut().enumerate() {
258            result.rank = i + 1;
259        }
260
261        // Build priors
262        let priors = if self.config.build_priors && !results.is_empty() {
263            Some(self.build_priors(&results))
264        } else {
265            None
266        };
267
268        Ok(QueryResponse {
269            results,
270            priors,
271            latency: start.elapsed(),
272            indexes_searched,
273            candidates_considered,
274        })
275    }
276
277    /// Validate query request.
278    fn validate_query(&self, request: &QueryRequest) -> Result<()> {
279        if request.embedding.is_empty() {
280            return Err(Error::InvalidQuery {
281                reason: "Empty embedding".into(),
282            });
283        }
284
285        if let Some(k) = request.k {
286            if k == 0 {
287                return Err(Error::InvalidQuery {
288                    reason: "k must be > 0".into(),
289                });
290            }
291            if k > self.config.max_k {
292                return Err(Error::InvalidQuery {
293                    reason: format!("k exceeds maximum ({})", self.config.max_k),
294                });
295            }
296        }
297
298        // Check that at least one index has matching dimension
299        let dim = request.embedding.len();
300        let has_compatible = self.registry.info().iter().any(|i| i.dimension == dim);
301
302        if !has_compatible {
303            return Err(Error::InvalidQuery {
304                reason: format!("No index with dimension {dim}"),
305            });
306        }
307
308        Ok(())
309    }
310
311    /// Execute vector search.
312    fn execute_search(
313        &self,
314        request: &QueryRequest,
315        k: usize,
316    ) -> Result<(Vec<(String, SearchResult)>, usize)> {
317        let query = &request.embedding;
318
319        // Multi-index or specific indexes?
320        let multi_results: MultiIndexResults = if let Some(ref index_names) = request.indexes {
321            // Search specific indexes
322            let names: Vec<&str> = index_names.iter().map(String::as_str).collect();
323            if self.config.parallel_search && names.len() > 1 {
324                let searcher = ParallelSearcher::new(self.registry);
325                searcher.search_indexes_parallel(&names, query, k)?
326            } else {
327                self.registry.search_indexes(&names, query, k)?
328            }
329        } else {
330            // Search all compatible indexes
331            if self.config.parallel_search {
332                let searcher = ParallelSearcher::new(self.registry);
333                searcher.search_parallel(query, k)?
334            } else {
335                self.registry.search_all(query, k)?
336            }
337        };
338
339        let indexes_searched = multi_results.by_index.len();
340
341        // Fuse results if multiple indexes
342        let results: Vec<(String, SearchResult)> = if indexes_searched > 1 {
343            let fused = rrf_fuse(&multi_results);
344            fused
345                .into_iter()
346                .map(|f| {
347                    let source = f.sources.first().cloned().unwrap_or_default();
348                    (
349                        source,
350                        SearchResult {
351                            id: f.id,
352                            distance: 0.0, // Not meaningful after fusion
353                            score: f.fused_score,
354                        },
355                    )
356                })
357                .collect()
358        } else {
359            multi_results.flatten()
360        };
361
362        Ok((results, indexes_searched))
363    }
364
365    /// Build result records from search results.
366    fn build_results(
367        &self,
368        search_results: Vec<(String, SearchResult)>,
369        _request: &QueryRequest,
370    ) -> Result<Vec<RetrievedRecord>> {
371        let mut results = Vec::with_capacity(search_results.len());
372
373        for (index_name, sr) in search_results {
374            let id: RecordId = sr.id.into();
375
376            if let Some(record) = self.store.get(&id) {
377                results.push(RetrievedRecord {
378                    record,
379                    score: sr.score,
380                    rank: 0, // Set later
381                    source_index: index_name,
382                });
383            }
384        }
385
386        Ok(results)
387    }
388
389    /// Build priors from results.
390    fn build_priors(&self, results: &[RetrievedRecord]) -> PriorBundle {
391        let mut stats = OutcomeStats::new(1);
392
393        for result in results {
394            stats.update_scalar(result.record.outcome);
395            // Merge record's stats if compatible (same dimension)
396            if result.record.stats.dim() == 1 {
397                stats = stats.merge(&result.record.stats);
398            }
399        }
400
401        let mean = stats.mean_scalar().unwrap_or(0.0);
402        let std_dev = stats.std_scalar().unwrap_or(0.0);
403        let ci = stats.confidence_interval(0.95)
404            .map(|(l, u)| (l[0] as f64, u[0] as f64))
405            .unwrap_or((mean, mean));
406
407        PriorBundle {
408            mean_outcome: mean,
409            std_outcome: std_dev,
410            confidence_interval: ci,
411            sample_count: stats.count(),
412            prototype_ids: results.iter().take(3).map(|r| r.record.id.clone()).collect(),
413        }
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{FlatIndex, IndexConfig, VectorIndex};
    use crate::store::InMemoryStore;
    use crate::types::RecordStatus;
    use crate::OutcomeStats;

    /// Build an active record with a fixed outcome and empty metadata/stats.
    fn create_test_record(id: &str, embedding: Vec<f32>) -> MemoryRecord {
        let context = format!("Context for {id}");
        MemoryRecord {
            id: id.into(),
            embedding,
            context,
            outcome: 0.8,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// One 4-dimensional index named "test" holding `rec-0`..`rec-9`, with
    /// the same records in the store.
    fn setup_test_env() -> (IndexRegistry, InMemoryStore) {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();
        let mut index = FlatIndex::new(IndexConfig::new(4));

        for i in 0..10u8 {
            let record = create_test_record(
                &format!("rec-{i}"),
                vec![f32::from(i), 0.0, 0.0, 0.0],
            );
            index.add(record.id.to_string(), &record.embedding).unwrap();
            store.insert(record).unwrap();
        }

        registry.register("test", index).unwrap();
        (registry, store)
    }

    /// Engine with the default configuration over the given fixtures.
    fn default_engine<'a>(
        registry: &'a IndexRegistry,
        store: &'a InMemoryStore,
    ) -> QueryEngine<'a, InMemoryStore> {
        QueryEngine::new(QueryEngineConfig::new(), registry, store)
    }

    /// Request along the first axis, aimed at the middle of the fixture data.
    fn midpoint_request(k: usize) -> QueryRequest {
        QueryRequest::new(vec![5.0, 0.0, 0.0, 0.0]).with_k(k)
    }

    #[test]
    fn test_basic_query() {
        let (registry, store) = setup_test_env();
        let engine = default_engine(&registry, &store);

        let response = engine.query(midpoint_request(3)).unwrap();

        assert_eq!(response.len(), 3);
        assert!(!response.is_empty());
        assert!(response.priors.is_some());
    }

    #[test]
    fn test_query_validation_empty_embedding() {
        let (registry, store) = setup_test_env();
        let engine = default_engine(&registry, &store);

        let outcome = engine.query(QueryRequest::new(vec![]));

        assert!(outcome.is_err());
    }

    #[test]
    fn test_query_validation_k_zero() {
        let (registry, store) = setup_test_env();
        let engine = default_engine(&registry, &store);

        let zero_k = QueryRequest::new(vec![1.0, 0.0, 0.0, 0.0]).with_k(0);
        let outcome = engine.query(zero_k);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_query_with_priors() {
        let (registry, store) = setup_test_env();
        let engine = default_engine(&registry, &store);

        let response = engine.query(midpoint_request(5)).unwrap();

        let priors = response.priors.unwrap();
        assert!(priors.sample_count > 0);
        assert!(!priors.prototype_ids.is_empty());
    }

    #[test]
    fn test_multi_index_query() {
        let mut registry = IndexRegistry::new();
        let mut store = InMemoryStore::new();

        // Two indexes, each holding a single record on a different axis.
        let fixtures = [
            ("idx1", "rec-a", vec![1.0, 0.0, 0.0, 0.0]),
            ("idx2", "rec-b", vec![0.0, 1.0, 0.0, 0.0]),
        ];
        for (index_name, record_id, embedding) in fixtures {
            let mut index = FlatIndex::new(IndexConfig::new(4));
            let record = create_test_record(record_id, embedding);
            index.add(record.id.to_string(), &record.embedding).unwrap();
            store.insert(record).unwrap();
            registry.register(index_name, index).unwrap();
        }

        let engine = default_engine(&registry, &store);

        let between = QueryRequest::new(vec![0.5, 0.5, 0.0, 0.0]).with_k(5);
        let response = engine.query(between).unwrap();

        assert_eq!(response.indexes_searched, 2);
        assert_eq!(response.len(), 2);
    }

    #[test]
    fn test_response_latency() {
        let (registry, store) = setup_test_env();
        let engine = default_engine(&registry, &store);

        let response = engine.query(midpoint_request(3)).unwrap();

        assert!(response.latency.as_micros() > 0);
    }
}
568}