// batuta/oracle/rag/retriever.rs
//! Hybrid Retriever - BM25 + Dense with RRF Fusion
//!
//! Implements state-of-the-art hybrid retrieval following:
//! - Karpukhin et al. (2020) for dense retrieval
//! - Robertson & Zaragoza (2009) for BM25
//! - Cormack et al. (2009) for Reciprocal Rank Fusion
//!
//! The "dense" ranker is currently a TF-IDF cosine-similarity stand-in computed over the
//! same inverted index as BM25, not a learned embedding model.
//!
//! # Performance
//!
//! - Query latency tracked via `profiling::GLOBAL_METRICS`
//! - Spans: `retrieve`, `bm25_search`, `dense_search`, `rrf_fuse`, `component_boost`
//! - Targets: p50 <20ms, p99 <100ms
13
14use super::profiling::{record_query_latency, span};
15use super::types::{Bm25Config, RetrievalResult, RrfConfig, ScoreBreakdown};
16use super::DocumentIndex;
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19use std::time::Instant;
20
21/// Hybrid retriever combining sparse (BM25) and dense retrieval
22#[derive(Debug)]
23pub struct HybridRetriever {
24    /// BM25 configuration
25    bm25_config: Bm25Config,
26    /// RRF configuration
27    rrf_config: RrfConfig,
28    /// Inverted index for BM25
29    inverted_index: InvertedIndex,
30    /// Average document length for BM25
31    avg_doc_length: f64,
32}
33
34impl HybridRetriever {
35    /// Create a new hybrid retriever
36    pub fn new() -> Self {
37        Self {
38            bm25_config: Bm25Config::default(),
39            rrf_config: RrfConfig::default(),
40            inverted_index: InvertedIndex::new(),
41            avg_doc_length: 0.0,
42        }
43    }
44
45    /// Create with custom configuration
46    pub fn with_config(bm25_config: Bm25Config, rrf_config: RrfConfig) -> Self {
47        Self { bm25_config, rrf_config, inverted_index: InvertedIndex::new(), avg_doc_length: 0.0 }
48    }
49
50    /// Index a document for retrieval
51    pub fn index_document(&mut self, doc_id: &str, content: &str) {
52        self.inverted_index.add_document(doc_id, content);
53        self.update_avg_doc_length();
54    }
55
56    /// Remove a document from the index
57    pub fn remove_document(&mut self, doc_id: &str) {
58        self.inverted_index.remove_document(doc_id);
59        self.update_avg_doc_length();
60    }
61
62    /// Update average document length
63    fn update_avg_doc_length(&mut self) {
64        let total_length: usize = self.inverted_index.doc_lengths.values().sum();
65        let doc_count = self.inverted_index.doc_lengths.len();
66        self.avg_doc_length =
67            if doc_count > 0 { total_length as f64 / doc_count as f64 } else { 0.0 };
68    }
69
70    /// Retrieve documents matching query
71    ///
72    /// Records latency metrics to `GLOBAL_METRICS` for performance monitoring.
73    /// Target: p50 <20ms, p99 <100ms
74    pub fn retrieve(
75        &self,
76        query: &str,
77        _index: &DocumentIndex,
78        top_k: usize,
79    ) -> Vec<RetrievalResult> {
80        let start = Instant::now();
81        let _retrieve_span = span("retrieve");
82
83        // Get BM25 results
84        let bm25_results = {
85            let _bm25_span = span("bm25_search");
86            self.bm25_search(query, top_k * 2)
87        };
88
89        // Get dense results (TF-IDF cosine similarity)
90        let dense_results = {
91            let _dense_span = span("dense_search");
92            self.dense_search(query, top_k * 2)
93        };
94
95        // Fuse with RRF
96        let mut results = {
97            let _fuse_span = span("rrf_fuse");
98            self.rrf_fuse(&bm25_results, &dense_results, top_k)
99        };
100
101        // Apply component boosting
102        {
103            let _boost_span = span("component_boost");
104            self.apply_component_boost(&mut results, query);
105        }
106
107        // Record query latency for performance tracking
108        record_query_latency(start.elapsed());
109
110        results
111    }
112
113    /// BM25 sparse retrieval
114    fn bm25_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
115        let query_terms = tokenize(query);
116        let mut scores: HashMap<String, f64> = HashMap::new();
117
118        let n = self.inverted_index.doc_lengths.len() as f64;
119
120        for term in &query_terms {
121            if let Some(postings) = self.inverted_index.index.get(term) {
122                // IDF calculation: log((N - n + 0.5) / (n + 0.5))
123                let df = postings.len() as f64;
124                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
125
126                for (doc_id, tf) in postings {
127                    let doc_len =
128                        self.inverted_index.doc_lengths.get(doc_id).copied().unwrap_or(1) as f64;
129
130                    // BM25 score
131                    let k1 = self.bm25_config.k1 as f64;
132                    let b = self.bm25_config.b as f64;
133                    let tf_norm = (*tf as f64 * (k1 + 1.0))
134                        / (*tf as f64
135                            + k1 * (1.0 - b + b * doc_len / self.avg_doc_length.max(1.0)));
136
137                    *scores.entry(doc_id.clone()).or_insert(0.0) += idf * tf_norm;
138                }
139            }
140        }
141
142        let mut results: Vec<_> = scores.into_iter().collect();
143        sort_and_truncate(&mut results, top_k);
144
145        results
146    }
147
148    /// Dense retrieval using TF-IDF cosine similarity
149    ///
150    /// Only iterates candidate documents (those containing at least one query term),
151    /// not the full index. This makes it efficient even for large indices.
152    fn dense_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
153        let query_terms = tokenize(query);
154        if query_terms.is_empty() {
155            return vec![];
156        }
157
158        let n = self.inverted_index.doc_lengths.len() as f64;
159        if n == 0.0 {
160            return vec![];
161        }
162
163        // Build query TF-IDF vector + collect candidate docs
164        let mut query_vec: HashMap<&str, f64> = HashMap::new();
165        let mut candidates: HashSet<String> = HashSet::new();
166
167        for term in &query_terms {
168            if let Some(postings) = self.inverted_index.index.get(term.as_str()) {
169                let df = postings.len() as f64;
170                let idf = (n / df).ln() + 1.0; // smoothed IDF
171                *query_vec.entry(term.as_str()).or_insert(0.0) += idf;
172                candidates.extend(postings.keys().cloned());
173            }
174        }
175
176        // Score each candidate doc by cosine similarity
177        let query_norm: f64 = query_vec.values().map(|v| v * v).sum::<f64>().sqrt();
178        if query_norm == 0.0 {
179            return vec![];
180        }
181
182        let mut scores: Vec<(String, f64)> = candidates
183            .into_iter()
184            .filter_map(|doc_id| {
185                let doc_len = *self.inverted_index.doc_lengths.get(&doc_id)? as f64;
186                let mut dot = 0.0;
187                let mut doc_norm_sq = 0.0;
188
189                for term in &query_terms {
190                    if let Some(postings) = self.inverted_index.index.get(term.as_str()) {
191                        if let Some(&tf) = postings.get(&doc_id) {
192                            let df = postings.len() as f64;
193                            let idf = (n / df).ln() + 1.0;
194                            let tfidf = (tf as f64 / doc_len.max(1.0)) * idf;
195                            dot += tfidf * query_vec.get(term.as_str()).unwrap_or(&0.0);
196                            doc_norm_sq += tfidf * tfidf;
197                        }
198                    }
199                }
200
201                let doc_norm = doc_norm_sq.sqrt();
202                if doc_norm == 0.0 {
203                    return None;
204                }
205                let cosine = dot / (query_norm * doc_norm);
206                Some((doc_id, cosine))
207            })
208            .collect();
209
210        sort_and_truncate(&mut scores, top_k);
211        scores
212    }
213
214    /// Reciprocal Rank Fusion
215    ///
216    /// RRF score = Σ 1/(k + rank) for each retriever
217    /// Following Cormack et al. (2009)
218    fn rrf_fuse(
219        &self,
220        sparse_results: &[(String, f64)],
221        dense_results: &[(String, f64)],
222        top_k: usize,
223    ) -> Vec<RetrievalResult> {
224        let k = self.rrf_config.k as f64;
225        let mut rrf_scores: HashMap<String, (f64, f64, f64)> = HashMap::new(); // (rrf, bm25, dense)
226
227        // Accumulate RRF contribution from a single ranked list.
228        // `set_field` stores the raw score into the appropriate tuple slot.
229        let mut accumulate =
230            |results: &[(String, f64)], set_field: fn(&mut (f64, f64, f64), f64)| {
231                for (rank, (doc_id, raw_score)) in results.iter().enumerate() {
232                    let entry = rrf_scores.entry(doc_id.clone()).or_insert((0.0, 0.0, 0.0));
233                    entry.0 += 1.0 / (k + rank as f64 + 1.0);
234                    set_field(entry, *raw_score);
235                }
236            };
237
238        accumulate(sparse_results, |e, s| e.1 = s); // BM25
239        accumulate(dense_results, |e, s| e.2 = s); // Dense
240
241        // Convert to results
242        let mut results: Vec<_> = rrf_scores
243            .into_iter()
244            .map(|(doc_id, (rrf_score, bm25_score, dense_score))| {
245                // Normalize score to 0-1 range
246                let max_rrf = 2.0 / (k + 1.0); // Max possible RRF score (rank 1 in both)
247                let normalized_score = (rrf_score / max_rrf).min(1.0);
248
249                let component = extract_component(&doc_id);
250                let id = doc_id.clone();
251                RetrievalResult {
252                    id,
253                    component,
254                    source: doc_id,
255                    content: String::new(), // Would be filled from index
256                    score: normalized_score,
257                    start_line: 1,
258                    end_line: 1,
259                    score_breakdown: ScoreBreakdown {
260                        bm25_score,
261                        dense_score,
262                        rrf_score,
263                        rerank_score: None,
264                    },
265                }
266            })
267            .collect();
268
269        // Sort by score descending
270        results.sort();
271        results.truncate(top_k);
272
273        results
274    }
275
276    /// Get index statistics
277    pub fn stats(&self) -> RetrieverStats {
278        RetrieverStats {
279            total_documents: self.inverted_index.doc_lengths.len(),
280            total_terms: self.inverted_index.index.len(),
281            avg_doc_length: self.avg_doc_length,
282        }
283    }
284
285    /// Boost results whose component matches a component name mentioned in the query.
286    ///
287    /// Extracts component names from `doc_lengths` keys (first path segment),
288    /// sorts longest-first to handle hyphenated names (e.g., "trueno-ublk" before "trueno"),
289    /// and applies a 1.5x multiplier to matching results, then re-sorts.
290    fn apply_component_boost(&self, results: &mut [RetrievalResult], query: &str) {
291        let query_lower = query.to_lowercase();
292
293        // Collect unique component names from index, sorted longest first
294        let mut components: Vec<String> = self
295            .inverted_index
296            .doc_lengths
297            .keys()
298            .filter_map(|k| k.split('/').next())
299            .collect::<HashSet<_>>()
300            .into_iter()
301            .map(|s| s.to_string())
302            .collect();
303        components.sort_by_key(|c| std::cmp::Reverse(c.len()));
304
305        // Find which components are mentioned in query
306        let mentioned: Vec<String> =
307            components.into_iter().filter(|c| query_lower.contains(&c.to_lowercase())).collect();
308
309        if mentioned.is_empty() {
310            return;
311        }
312
313        // Apply 1.5x boost to matching results
314        for result in results.iter_mut() {
315            if mentioned.iter().any(|m| result.component.eq_ignore_ascii_case(m)) {
316                result.score = (result.score * 1.5).min(1.0);
317            }
318        }
319
320        results.sort();
321    }
322
323    /// Convert to persisted format for serialization
324    pub fn to_persisted(&self) -> super::persistence::PersistedIndex {
325        super::persistence::PersistedIndex {
326            inverted_index: self.inverted_index.index.clone(),
327            doc_lengths: self.inverted_index.doc_lengths.clone(),
328            bm25_config: self.bm25_config,
329            rrf_config: self.rrf_config,
330            avg_doc_length: self.avg_doc_length,
331        }
332    }
333
334    /// Restore from persisted format
335    pub fn from_persisted(persisted: super::persistence::PersistedIndex) -> Self {
336        Self {
337            bm25_config: persisted.bm25_config,
338            rrf_config: persisted.rrf_config,
339            inverted_index: InvertedIndex {
340                index: persisted.inverted_index,
341                doc_lengths: persisted.doc_lengths,
342            },
343            avg_doc_length: persisted.avg_doc_length,
344        }
345    }
346}
347
348impl Default for HybridRetriever {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
/// Snapshot of index size figures, as produced by `HybridRetriever::stats`.
#[derive(Debug, Clone)]
pub struct RetrieverStats {
    /// Number of documents currently in the index.
    pub total_documents: usize,
    /// Number of distinct terms in the inverted index.
    pub total_terms: usize,
    /// Mean token count per indexed document.
    pub avg_doc_length: f64,
}
364
365/// Inverted index for BM25
366#[derive(Debug, Default, Clone, Serialize, Deserialize)]
367pub struct InvertedIndex {
368    /// Term -> (doc_id -> term_frequency)
369    pub index: HashMap<String, HashMap<String, usize>>,
370    /// Document lengths
371    pub doc_lengths: HashMap<String, usize>,
372}
373
374impl InvertedIndex {
375    fn new() -> Self {
376        Self::default()
377    }
378
379    fn add_document(&mut self, doc_id: &str, content: &str) {
380        let tokens = tokenize(content);
381        self.doc_lengths.insert(doc_id.to_string(), tokens.len());
382
383        // Count term frequencies
384        let mut term_freqs: HashMap<String, usize> = HashMap::new();
385        for token in tokens {
386            *term_freqs.entry(token).or_insert(0) += 1;
387        }
388
389        // Add to inverted index
390        for (term, freq) in term_freqs {
391            self.index.entry(term).or_default().insert(doc_id.to_string(), freq);
392        }
393    }
394
395    fn remove_document(&mut self, doc_id: &str) {
396        self.doc_lengths.remove(doc_id);
397
398        // Remove from all posting lists
399        for postings in self.index.values_mut() {
400            postings.remove(doc_id);
401        }
402
403        // Clean up empty posting lists
404        self.index.retain(|_, postings| !postings.is_empty());
405    }
406}
407
/// Stem `word` with aprender's `PorterStemmer` (this variant is only compiled when the `ml`
/// feature is enabled; otherwise a simple suffix-stripping fallback is used).
///
/// If the stemmer reports an error, the word is returned unchanged.
#[cfg(feature = "ml")]
fn stem(word: &str) -> String {
    use aprender::text::stem::{PorterStemmer, Stemmer};
    match PorterStemmer::new().stem(word) {
        Ok(stemmed) => stemmed,
        Err(_) => word.to_owned(),
    }
}
415
/// Fallback suffix stripping when aprender is not available.
///
/// Words of at most three characters are returned unchanged. Otherwise the suffix list is
/// scanned in order, and the first suffix whose removal leaves a stem of at least three
/// characters is stripped. Longer suffixes are listed before the shorter suffixes they end
/// with (e.g. `"ation"` before `"tion"`), so the longest applicable suffix wins. Lengths
/// are counted in `char`s so that multi-byte letters are not over-stemmed.
#[cfg(not(feature = "ml"))]
fn stem(word: &str) -> String {
    // Priority order matters; see the ordering rule above.
    const SUFFIXES: &[&str] = &[
        "ization", "isation", "ation", "tion", "sion", "ment", "ness", "ible", "able", "ence",
        "ance", "zing", "ying", "ming", "ning", "ting", "ring", "ling", "sing", "ious", "eous",
        "ful", "ive", "ize", "ise", "ity", "ist", "ism", "ied", "ies", "ing", "ous", "ers",
        "est", "ely", "ory", "ant", "ent", "ial", "ual", "ly", "ed", "er", "al", "ic",
    ];
    const MIN_STEM_CHARS: usize = 3;

    if word.chars().count() <= MIN_STEM_CHARS {
        return word.to_string();
    }

    SUFFIXES
        .iter()
        .copied()
        .filter_map(|suffix| word.strip_suffix(suffix))
        .find(|stemmed| stemmed.chars().count() >= MIN_STEM_CHARS)
        .unwrap_or(word)
        .to_string()
}
438
/// Check whether `word` is an English stop word, using aprender's `StopWordsFilter`
/// (only compiled when the `ml` feature is enabled).
///
/// The filter is built on first use and shared by all later calls.
#[cfg(feature = "ml")]
fn is_stop_word(word: &str) -> bool {
    use aprender::text::stopwords::StopWordsFilter;
    use std::sync::OnceLock;

    static ENGLISH_STOP_WORDS: OnceLock<StopWordsFilter> = OnceLock::new();
    ENGLISH_STOP_WORDS.get_or_init(StopWordsFilter::english).is_stop_word(word)
}
447
/// Fallback stop-word check used when aprender is not available.
///
/// Case-sensitive exact match against a small built-in English list; callers are expected
/// to pass lowercased tokens.
#[cfg(not(feature = "ml"))]
fn is_stop_word(word: &str) -> bool {
    matches!(
        word,
        "the" | "is" | "at" | "which" | "on" | "in" | "to" | "for" | "of" | "and" | "or"
            | "an" | "be" | "by" | "as" | "do" | "if" | "it" | "no" | "so" | "up" | "how"
            | "can" | "its" | "has" | "had" | "was" | "are" | "were" | "been" | "have"
            | "from" | "this" | "that" | "with" | "what" | "when" | "where" | "will"
            | "not" | "but" | "all" | "each" | "than"
    )
}
459
460/// Tokenizer with stop-word filtering and stemming.
461///
462/// Splits on non-alphanumeric characters (preserving underscores),
463/// removes single-character tokens and stop words, then applies stemming.
464/// When the `ml` feature is enabled, uses aprender's PorterStemmer and
465/// 171-word English stop words list. Otherwise falls back to simple suffix stripping.
466fn tokenize(text: &str) -> Vec<String> {
467    text.to_lowercase()
468        .split(|c: char| !c.is_alphanumeric() && c != '_')
469        .filter(|s| !s.is_empty() && s.len() > 1)
470        .filter(|s| !is_stop_word(s))
471        .map(stem)
472        .collect()
473}
474
/// Sort `(id, score)` pairs by score descending and keep only the top `k`.
///
/// Uses a total order, so sorting never sees an inconsistent comparator: NaN scores sort
/// last, and equal scores are ordered by id ascending. This makes the kept top-`k`
/// deterministic even when the input comes from hash-map iteration.
fn sort_and_truncate(results: &mut Vec<(String, f64)>, k: usize) {
    results.sort_unstable_by(|(id_a, score_a), (id_b, score_b)| {
        score_a
            .is_nan()
            .cmp(&score_b.is_nan())
            .then_with(|| score_b.total_cmp(score_a))
            .then_with(|| id_a.cmp(id_b))
    });
    results.truncate(k);
}
480
/// Extract the component name, i.e. the first `/`-separated segment, from a doc id.
///
/// Returns `"unknown"` when that segment is empty (an empty id, or one starting with `/`).
/// `str::split` always yields at least one item, so an `Option` fallback alone never fires.
fn extract_component(doc_id: &str) -> String {
    match doc_id.split('/').next() {
        Some(first) if !first.is_empty() => first.to_string(),
        _ => "unknown".to_string(),
    }
}
485
// Unit tests live in the sibling file `retriever_tests.rs`.
#[path = "retriever_tests.rs"]
#[cfg(test)]
mod tests;