agentroot_core/search/
hybrid.rs1use super::{SearchOptions, SearchResult, SearchSource};
4use crate::db::Database;
5use crate::error::Result;
6use crate::llm::{Embedder, QueryExpander, RerankDocument, Reranker};
7use std::collections::HashMap;
8
/// Smoothing constant added to the 1-based rank in every reciprocal-rank-fusion
/// denominator (`weight / (RRF_K + rank)`).
const RRF_K: f64 = 60.0;

/// Upper bound on the number of fused results kept by `cap_for_reranking`.
const MAX_RERANK_DOCS: usize = 40;

/// Minimum score the top result must reach for `has_strong_signal` to return `true`.
const STRONG_SIGNAL_SCORE: f64 = 0.85;
/// Minimum lead of the top score over the second score required by
/// `has_strong_signal` when at least two results are present.
const STRONG_SIGNAL_GAP: f64 = 0.15;
18
19pub fn has_strong_signal(results: &[SearchResult]) -> bool {
21 if results.len() < 2 {
22 return results
23 .first()
24 .map(|r| r.score >= STRONG_SIGNAL_SCORE)
25 .unwrap_or(false);
26 }
27
28 let top_score = results[0].score;
29 let second_score = results[1].score;
30 let gap = top_score - second_score;
31
32 top_score >= STRONG_SIGNAL_SCORE && gap >= STRONG_SIGNAL_GAP
33}
34
35pub fn cap_for_reranking(results: Vec<SearchResult>) -> Vec<SearchResult> {
37 results.into_iter().take(MAX_RERANK_DOCS).collect()
38}
39
/// Blends a fused (RRF) score with a reranker score using a rank-dependent weight.
///
/// The weight given to `rrf_score` depends on `rrf_rank`:
///
/// | `rrf_rank` | RRF weight | reranker weight |
/// |------------|------------|-----------------|
/// | `<= 3`     | 0.75       | 0.25            |
/// | `4..=10`   | 0.60       | 0.40            |
/// | `> 10`     | 0.40       | 0.60            |
///
/// `hybrid_search` passes a 1-based rank; a rank of `0` falls into the first tier.
///
/// Returns `rrf_weight * rrf_score + (1.0 - rrf_weight) * rerank_score`.
pub fn blend_scores(rrf_rank: usize, rrf_score: f64, rerank_score: f64) -> f64 {
    let rrf_weight = match rrf_rank {
        0..=3 => 0.75,
        4..=10 => 0.60,
        _ => 0.40,
    };

    rrf_weight * rrf_score + (1.0 - rrf_weight) * rerank_score
}
52
53pub fn rrf_fusion(
55 bm25_results: &[SearchResult],
56 vec_results: &[SearchResult],
57) -> Vec<SearchResult> {
58 let mut scores: HashMap<String, (f64, SearchResult)> = HashMap::new();
59
60 for (rank, result) in bm25_results.iter().enumerate() {
62 let rrf_score = 2.0 / (RRF_K + (rank + 1) as f64);
63 let bonus = if rank < 3 {
65 0.05
66 } else if rank < 10 {
67 0.02
68 } else {
69 0.0
70 };
71
72 let entry = scores
73 .entry(result.hash.clone())
74 .or_insert((0.0, result.clone()));
75 entry.0 += rrf_score + bonus;
76 }
77
78 for (rank, result) in vec_results.iter().enumerate() {
80 let rrf_score = 1.0 / (RRF_K + (rank + 1) as f64);
81 let bonus = if rank < 3 {
82 0.05
83 } else if rank < 10 {
84 0.02
85 } else {
86 0.0
87 };
88
89 let entry = scores
90 .entry(result.hash.clone())
91 .or_insert((0.0, result.clone()));
92 entry.0 += rrf_score + bonus;
93 }
94
95 let mut results: Vec<(f64, SearchResult)> = scores.into_values().collect();
97 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
98
99 results
100 .into_iter()
101 .map(|(score, mut r)| {
102 r.score = score;
103 r.source = SearchSource::Hybrid;
104 r
105 })
106 .collect()
107}
108
109pub async fn hybrid_search(
111 db: &Database,
112 query: &str,
113 options: &SearchOptions,
114 embedder: &dyn Embedder,
115 expander: Option<&dyn QueryExpander>,
116 reranker: Option<&dyn Reranker>,
117) -> Result<Vec<SearchResult>> {
118 let bm25_results = db.search_fts(query, options)?;
120
121 if has_strong_signal(&bm25_results) {
123 return Ok(bm25_results);
124 }
125
126 let vec_results = db.search_vec(query, embedder, options).await?;
128
129 let mut all_bm25 = bm25_results.clone();
131 let mut all_vec = vec_results.clone();
132
133 if let Some(exp) = expander {
134 let expanded = exp.expand(query, None).await?;
135
136 for lex_query in &expanded.lexical {
138 let results = db.search_fts(lex_query, options)?;
139 all_bm25.extend(results);
140 }
141
142 for vec_query in &expanded.semantic {
144 let results = db.search_vec(vec_query, embedder, options).await?;
145 all_vec.extend(results);
146 }
147
148 if let Some(ref hyde) = expanded.hyde {
150 let results = db.search_vec(hyde, embedder, options).await?;
151 all_vec.extend(results);
152 }
153 }
154
155 let mut fused = rrf_fusion(&all_bm25, &all_vec);
157
158 fused = cap_for_reranking(fused);
160
161 if let Some(rr) = reranker {
163 let docs: Vec<RerankDocument> = fused
164 .iter()
165 .map(|r| RerankDocument {
166 id: r.hash.clone(),
167 text: r.body.clone().unwrap_or_default(),
168 })
169 .collect();
170
171 let reranked = rr.rerank(query, &docs).await?;
172
173 let rerank_scores: HashMap<String, f64> =
175 reranked.iter().map(|r| (r.id.clone(), r.score)).collect();
176
177 for (rrf_rank, result) in fused.iter_mut().enumerate() {
179 if let Some(&rerank_score) = rerank_scores.get(&result.hash) {
180 let rrf_score = result.score;
181 result.score = blend_scores(rrf_rank + 1, rrf_score, rerank_score);
182 }
183 }
184
185 fused.sort_by(|a, b| {
187 b.score
188 .partial_cmp(&a.score)
189 .unwrap_or(std::cmp::Ordering::Equal)
190 });
191 }
192
193 let final_results: Vec<SearchResult> = fused
195 .into_iter()
196 .filter(|r| r.score >= options.min_score)
197 .take(options.limit)
198 .collect();
199
200 Ok(final_results)
201}
202
203impl Database {
204 pub fn search_vec_sync(
206 &self,
207 _query: &str,
208 options: &SearchOptions,
209 ) -> Result<Vec<SearchResult>> {
210 eprintln!("Warning: Vector search requires embeddings, falling back to BM25");
213 self.search_fts(_query, options)
214 }
215
216 pub fn search_hybrid_sync(
218 &self,
219 query: &str,
220 options: &SearchOptions,
221 ) -> Result<Vec<SearchResult>> {
222 self.search_fts(query, options)
225 }
226}