1use super::profiling::{record_query_latency, span};
15use super::types::{Bm25Config, RetrievalResult, RrfConfig, ScoreBreakdown};
16use super::DocumentIndex;
17use serde::{Deserialize, Serialize};
18use std::collections::{HashMap, HashSet};
19use std::time::Instant;
20
/// Hybrid retriever that fuses a sparse BM25 ranking with a TF-IDF
/// cosine-similarity ("dense") ranking via Reciprocal Rank Fusion.
#[derive(Debug)]
pub struct HybridRetriever {
    // BM25 parameters (k1 saturation and b length-normalization are read in scoring).
    bm25_config: Bm25Config,
    // RRF parameters (the rank constant `k` is read during fusion).
    rrf_config: RrfConfig,
    // Sparse term index shared by both search channels.
    inverted_index: InvertedIndex,
    // Cached mean document length in tokens; kept in sync on add/remove
    // so BM25 length normalization does not rescan the index per query.
    avg_doc_length: f64,
}
33
34impl HybridRetriever {
35 pub fn new() -> Self {
37 Self {
38 bm25_config: Bm25Config::default(),
39 rrf_config: RrfConfig::default(),
40 inverted_index: InvertedIndex::new(),
41 avg_doc_length: 0.0,
42 }
43 }
44
45 pub fn with_config(bm25_config: Bm25Config, rrf_config: RrfConfig) -> Self {
47 Self { bm25_config, rrf_config, inverted_index: InvertedIndex::new(), avg_doc_length: 0.0 }
48 }
49
50 pub fn index_document(&mut self, doc_id: &str, content: &str) {
52 self.inverted_index.add_document(doc_id, content);
53 self.update_avg_doc_length();
54 }
55
56 pub fn remove_document(&mut self, doc_id: &str) {
58 self.inverted_index.remove_document(doc_id);
59 self.update_avg_doc_length();
60 }
61
62 fn update_avg_doc_length(&mut self) {
64 let total_length: usize = self.inverted_index.doc_lengths.values().sum();
65 let doc_count = self.inverted_index.doc_lengths.len();
66 self.avg_doc_length =
67 if doc_count > 0 { total_length as f64 / doc_count as f64 } else { 0.0 };
68 }
69
70 pub fn retrieve(
75 &self,
76 query: &str,
77 _index: &DocumentIndex,
78 top_k: usize,
79 ) -> Vec<RetrievalResult> {
80 let start = Instant::now();
81 let _retrieve_span = span("retrieve");
82
83 let bm25_results = {
85 let _bm25_span = span("bm25_search");
86 self.bm25_search(query, top_k * 2)
87 };
88
89 let dense_results = {
91 let _dense_span = span("dense_search");
92 self.dense_search(query, top_k * 2)
93 };
94
95 let mut results = {
97 let _fuse_span = span("rrf_fuse");
98 self.rrf_fuse(&bm25_results, &dense_results, top_k)
99 };
100
101 {
103 let _boost_span = span("component_boost");
104 self.apply_component_boost(&mut results, query);
105 }
106
107 record_query_latency(start.elapsed());
109
110 results
111 }
112
113 fn bm25_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
115 let query_terms = tokenize(query);
116 let mut scores: HashMap<String, f64> = HashMap::new();
117
118 let n = self.inverted_index.doc_lengths.len() as f64;
119
120 for term in &query_terms {
121 if let Some(postings) = self.inverted_index.index.get(term) {
122 let df = postings.len() as f64;
124 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
125
126 for (doc_id, tf) in postings {
127 let doc_len =
128 self.inverted_index.doc_lengths.get(doc_id).copied().unwrap_or(1) as f64;
129
130 let k1 = self.bm25_config.k1 as f64;
132 let b = self.bm25_config.b as f64;
133 let tf_norm = (*tf as f64 * (k1 + 1.0))
134 / (*tf as f64
135 + k1 * (1.0 - b + b * doc_len / self.avg_doc_length.max(1.0)));
136
137 *scores.entry(doc_id.clone()).or_insert(0.0) += idf * tf_norm;
138 }
139 }
140 }
141
142 let mut results: Vec<_> = scores.into_iter().collect();
143 sort_and_truncate(&mut results, top_k);
144
145 results
146 }
147
148 fn dense_search(&self, query: &str, top_k: usize) -> Vec<(String, f64)> {
153 let query_terms = tokenize(query);
154 if query_terms.is_empty() {
155 return vec![];
156 }
157
158 let n = self.inverted_index.doc_lengths.len() as f64;
159 if n == 0.0 {
160 return vec![];
161 }
162
163 let mut query_vec: HashMap<&str, f64> = HashMap::new();
165 let mut candidates: HashSet<String> = HashSet::new();
166
167 for term in &query_terms {
168 if let Some(postings) = self.inverted_index.index.get(term.as_str()) {
169 let df = postings.len() as f64;
170 let idf = (n / df).ln() + 1.0; *query_vec.entry(term.as_str()).or_insert(0.0) += idf;
172 candidates.extend(postings.keys().cloned());
173 }
174 }
175
176 let query_norm: f64 = query_vec.values().map(|v| v * v).sum::<f64>().sqrt();
178 if query_norm == 0.0 {
179 return vec![];
180 }
181
182 let mut scores: Vec<(String, f64)> = candidates
183 .into_iter()
184 .filter_map(|doc_id| {
185 let doc_len = *self.inverted_index.doc_lengths.get(&doc_id)? as f64;
186 let mut dot = 0.0;
187 let mut doc_norm_sq = 0.0;
188
189 for term in &query_terms {
190 if let Some(postings) = self.inverted_index.index.get(term.as_str()) {
191 if let Some(&tf) = postings.get(&doc_id) {
192 let df = postings.len() as f64;
193 let idf = (n / df).ln() + 1.0;
194 let tfidf = (tf as f64 / doc_len.max(1.0)) * idf;
195 dot += tfidf * query_vec.get(term.as_str()).unwrap_or(&0.0);
196 doc_norm_sq += tfidf * tfidf;
197 }
198 }
199 }
200
201 let doc_norm = doc_norm_sq.sqrt();
202 if doc_norm == 0.0 {
203 return None;
204 }
205 let cosine = dot / (query_norm * doc_norm);
206 Some((doc_id, cosine))
207 })
208 .collect();
209
210 sort_and_truncate(&mut scores, top_k);
211 scores
212 }
213
214 fn rrf_fuse(
219 &self,
220 sparse_results: &[(String, f64)],
221 dense_results: &[(String, f64)],
222 top_k: usize,
223 ) -> Vec<RetrievalResult> {
224 let k = self.rrf_config.k as f64;
225 let mut rrf_scores: HashMap<String, (f64, f64, f64)> = HashMap::new(); let mut accumulate =
230 |results: &[(String, f64)], set_field: fn(&mut (f64, f64, f64), f64)| {
231 for (rank, (doc_id, raw_score)) in results.iter().enumerate() {
232 let entry = rrf_scores.entry(doc_id.clone()).or_insert((0.0, 0.0, 0.0));
233 entry.0 += 1.0 / (k + rank as f64 + 1.0);
234 set_field(entry, *raw_score);
235 }
236 };
237
238 accumulate(sparse_results, |e, s| e.1 = s); accumulate(dense_results, |e, s| e.2 = s); let mut results: Vec<_> = rrf_scores
243 .into_iter()
244 .map(|(doc_id, (rrf_score, bm25_score, dense_score))| {
245 let max_rrf = 2.0 / (k + 1.0); let normalized_score = (rrf_score / max_rrf).min(1.0);
248
249 let component = extract_component(&doc_id);
250 let id = doc_id.clone();
251 RetrievalResult {
252 id,
253 component,
254 source: doc_id,
255 content: String::new(), score: normalized_score,
257 start_line: 1,
258 end_line: 1,
259 score_breakdown: ScoreBreakdown {
260 bm25_score,
261 dense_score,
262 rrf_score,
263 rerank_score: None,
264 },
265 }
266 })
267 .collect();
268
269 results.sort();
271 results.truncate(top_k);
272
273 results
274 }
275
276 pub fn stats(&self) -> RetrieverStats {
278 RetrieverStats {
279 total_documents: self.inverted_index.doc_lengths.len(),
280 total_terms: self.inverted_index.index.len(),
281 avg_doc_length: self.avg_doc_length,
282 }
283 }
284
285 fn apply_component_boost(&self, results: &mut [RetrievalResult], query: &str) {
291 let query_lower = query.to_lowercase();
292
293 let mut components: Vec<String> = self
295 .inverted_index
296 .doc_lengths
297 .keys()
298 .filter_map(|k| k.split('/').next())
299 .collect::<HashSet<_>>()
300 .into_iter()
301 .map(|s| s.to_string())
302 .collect();
303 components.sort_by_key(|c| std::cmp::Reverse(c.len()));
304
305 let mentioned: Vec<String> =
307 components.into_iter().filter(|c| query_lower.contains(&c.to_lowercase())).collect();
308
309 if mentioned.is_empty() {
310 return;
311 }
312
313 for result in results.iter_mut() {
315 if mentioned.iter().any(|m| result.component.eq_ignore_ascii_case(m)) {
316 result.score = (result.score * 1.5).min(1.0);
317 }
318 }
319
320 results.sort();
321 }
322
323 pub fn to_persisted(&self) -> super::persistence::PersistedIndex {
325 super::persistence::PersistedIndex {
326 inverted_index: self.inverted_index.index.clone(),
327 doc_lengths: self.inverted_index.doc_lengths.clone(),
328 bm25_config: self.bm25_config,
329 rrf_config: self.rrf_config,
330 avg_doc_length: self.avg_doc_length,
331 }
332 }
333
334 pub fn from_persisted(persisted: super::persistence::PersistedIndex) -> Self {
336 Self {
337 bm25_config: persisted.bm25_config,
338 rrf_config: persisted.rrf_config,
339 inverted_index: InvertedIndex {
340 index: persisted.inverted_index,
341 doc_lengths: persisted.doc_lengths,
342 },
343 avg_doc_length: persisted.avg_doc_length,
344 }
345 }
346}
347
/// `Default` is equivalent to `HybridRetriever::new()`: default configs,
/// empty index.
impl Default for HybridRetriever {
    fn default() -> Self {
        Self::new()
    }
}
353
/// Summary statistics about a `HybridRetriever`'s index, as reported by
/// `HybridRetriever::stats`.
#[derive(Debug, Clone)]
pub struct RetrieverStats {
    /// Number of indexed documents.
    pub total_documents: usize,
    /// Number of distinct terms in the inverted index.
    pub total_terms: usize,
    /// Mean document length in tokens (0.0 when the index is empty).
    pub avg_doc_length: f64,
}
364
/// Sparse term index: per-term posting lists plus per-document token counts
/// used for BM25 length normalization.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct InvertedIndex {
    /// term -> (doc_id -> occurrences of that term in the document)
    pub index: HashMap<String, HashMap<String, usize>>,
    /// doc_id -> total token count of the document
    pub doc_lengths: HashMap<String, usize>,
}
373
374impl InvertedIndex {
375 fn new() -> Self {
376 Self::default()
377 }
378
379 fn add_document(&mut self, doc_id: &str, content: &str) {
380 let tokens = tokenize(content);
381 self.doc_lengths.insert(doc_id.to_string(), tokens.len());
382
383 let mut term_freqs: HashMap<String, usize> = HashMap::new();
385 for token in tokens {
386 *term_freqs.entry(token).or_insert(0) += 1;
387 }
388
389 for (term, freq) in term_freqs {
391 self.index.entry(term).or_default().insert(doc_id.to_string(), freq);
392 }
393 }
394
395 fn remove_document(&mut self, doc_id: &str) {
396 self.doc_lengths.remove(doc_id);
397
398 for postings in self.index.values_mut() {
400 postings.remove(doc_id);
401 }
402
403 self.index.retain(|_, postings| !postings.is_empty());
405 }
406}
407
/// Stems a word with the Porter stemmer from `aprender`; falls back to the
/// unstemmed word if the stemmer returns an error.
#[cfg(feature = "ml")]
fn stem(word: &str) -> String {
    use aprender::text::stem::{PorterStemmer, Stemmer};
    PorterStemmer::new().stem(word).unwrap_or_else(|_| word.to_string())
}
415
/// Lightweight suffix-stripping stemmer used when the `ml` feature (and its
/// Porter stemmer) is unavailable.
///
/// Strips the first matching suffix — list order matters, earlier entries
/// win — as long as at least 3 characters remain; words of 3 or fewer
/// characters are returned unchanged.
#[cfg(not(feature = "ml"))]
fn stem(word: &str) -> String {
    if word.len() <= 3 {
        return word.to_string();
    }
    // "mming" was removed from this list: any word ending in "mming" also
    // ends in "ming" (which appears earlier), and whenever the "mming" strip
    // would leave >= 3 chars, the "ming" strip already did — so the entry
    // was unreachable.
    for suffix in &[
        "ization", "isation", "ation", "tion", "sion", "ment", "ness", "ible", "able", "ence",
        "ance", "zing", "ying", "ming", "ning", "ting", "ring", "ling", "sing", "ious", "eous",
        "ful", "ive", "ize", "ise", "ity", "ist", "ism", "ied", "ies", "ing", "ous", "ers",
        "est", "ely", "ory", "ant", "ent", "ial", "ual", "ly", "ed", "er", "al", "ic",
    ] {
        if let Some(stripped) = word.strip_suffix(suffix) {
            if stripped.len() >= 3 {
                return stripped.to_string();
            }
        }
    }
    word.to_string()
}
438
/// Stop-word check backed by `aprender`'s English stop-word list.
#[cfg(feature = "ml")]
fn is_stop_word(word: &str) -> bool {
    use aprender::text::stopwords::StopWordsFilter;
    use std::sync::LazyLock;
    // Built lazily on first use, then shared for the process lifetime.
    static FILTER: LazyLock<StopWordsFilter> = LazyLock::new(StopWordsFilter::english);
    FILTER.is_stop_word(word)
}
447
/// Minimal English stop-word list used when the `ml` feature is disabled.
#[cfg(not(feature = "ml"))]
fn is_stop_word(word: &str) -> bool {
    matches!(
        word,
        "the" | "is" | "at" | "which" | "on" | "in" | "to" | "for" | "of" | "and" | "or"
            | "an" | "be" | "by" | "as" | "do" | "if" | "it" | "no" | "so" | "up" | "how"
            | "can" | "its" | "has" | "had" | "was" | "are" | "were" | "been" | "have"
            | "from" | "this" | "that" | "with" | "what" | "when" | "where" | "will"
            | "not" | "but" | "all" | "each" | "than"
    )
}
459
460fn tokenize(text: &str) -> Vec<String> {
467 text.to_lowercase()
468 .split(|c: char| !c.is_alphanumeric() && c != '_')
469 .filter(|s| !s.is_empty() && s.len() > 1)
470 .filter(|s| !is_stop_word(s))
471 .map(stem)
472 .collect()
473}
474
/// Sorts `(doc_id, score)` pairs by score, highest first, and keeps the top `k`.
///
/// Uses `f64::total_cmp`, a true total order. The previous
/// `partial_cmp(..).unwrap_or(Equal)` comparator is not transitive when NaN
/// is present, which `sort_by` may punish with a panic or a garbled order.
fn sort_and_truncate(results: &mut Vec<(String, f64)>, k: usize) {
    results.sort_by(|a, b| b.1.total_cmp(&a.1));
    results.truncate(k);
}
480
/// Returns the leading path segment of `doc_id` (text before the first '/'),
/// which names the component the document belongs to.
fn extract_component(doc_id: &str) -> String {
    match doc_id.split('/').next() {
        Some(first) => first.to_string(),
        None => "unknown".to_string(),
    }
}
485
486#[cfg(test)]
487#[path = "retriever_tests.rs"]
488mod tests;