agentroot_core/search/
hybrid.rs

1//! Hybrid search with Reciprocal Rank Fusion
2
3use super::{SearchOptions, SearchResult, SearchSource};
4use crate::db::Database;
5use crate::error::Result;
6use crate::llm::{Embedder, QueryExpander, RerankDocument, Reranker};
7use std::collections::HashMap;
8
/// RRF smoothing constant `k` in `weight / (k + rank)` (standard value).
const RRF_K: f64 = 60.0;

/// Maximum documents to send to reranker; `cap_for_reranking` truncates the
/// fused list to this many entries.
const MAX_RERANK_DOCS: usize = 40;

/// Strong signal threshold: minimum score the top BM25 result must reach.
const STRONG_SIGNAL_SCORE: f64 = 0.85;
/// Strong signal threshold: minimum lead of the top BM25 result over the
/// second one.
const STRONG_SIGNAL_GAP: f64 = 0.15;
18
19/// Check if top BM25 result is a strong signal (skip expansion)
20pub fn has_strong_signal(results: &[SearchResult]) -> bool {
21    if results.len() < 2 {
22        return results
23            .first()
24            .map(|r| r.score >= STRONG_SIGNAL_SCORE)
25            .unwrap_or(false);
26    }
27
28    let top_score = results[0].score;
29    let second_score = results[1].score;
30    let gap = top_score - second_score;
31
32    top_score >= STRONG_SIGNAL_SCORE && gap >= STRONG_SIGNAL_GAP
33}
34
35/// Cap results for reranking
36pub fn cap_for_reranking(results: Vec<SearchResult>) -> Vec<SearchResult> {
37    results.into_iter().take(MAX_RERANK_DOCS).collect()
38}
39
/// Blend a fused retrieval score with a reranker score, weighting by rank.
///
/// The share given to `rrf_score` depends on `rrf_rank`:
///
/// * ranks up to 3: 0.75 (trust retrieval for top results)
/// * ranks 4 through 10: 0.60
/// * anything lower-ranked: 0.40 (trust the reranker)
///
/// The remaining share goes to `rerank_score`.
pub fn blend_scores(rrf_rank: usize, rrf_score: f64, rerank_score: f64) -> f64 {
    let retrieval_share = match rrf_rank {
        0..=3 => 0.75,
        4..=10 => 0.60,
        _ => 0.40,
    };
    let reranker_share = 1.0 - retrieval_share;

    retrieval_share * rrf_score + reranker_share * rerank_score
}
52
53/// Reciprocal Rank Fusion
54pub fn rrf_fusion(
55    bm25_results: &[SearchResult],
56    vec_results: &[SearchResult],
57) -> Vec<SearchResult> {
58    let mut scores: HashMap<String, (f64, SearchResult)> = HashMap::new();
59
60    // Process BM25 results (weight 2x)
61    for (rank, result) in bm25_results.iter().enumerate() {
62        let rrf_score = 2.0 / (RRF_K + (rank + 1) as f64);
63        // Bonus for appearing in top 3
64        let bonus = if rank < 3 {
65            0.05
66        } else if rank < 10 {
67            0.02
68        } else {
69            0.0
70        };
71
72        let entry = scores
73            .entry(result.hash.clone())
74            .or_insert((0.0, result.clone()));
75        entry.0 += rrf_score + bonus;
76    }
77
78    // Process vector results
79    for (rank, result) in vec_results.iter().enumerate() {
80        let rrf_score = 1.0 / (RRF_K + (rank + 1) as f64);
81        let bonus = if rank < 3 {
82            0.05
83        } else if rank < 10 {
84            0.02
85        } else {
86            0.0
87        };
88
89        let entry = scores
90            .entry(result.hash.clone())
91            .or_insert((0.0, result.clone()));
92        entry.0 += rrf_score + bonus;
93    }
94
95    // Sort by score
96    let mut results: Vec<(f64, SearchResult)> = scores.into_values().collect();
97    results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
98
99    results
100        .into_iter()
101        .map(|(score, mut r)| {
102            r.score = score;
103            r.source = SearchSource::Hybrid;
104            r
105        })
106        .collect()
107}
108
109/// Full hybrid search pipeline
110pub async fn hybrid_search(
111    db: &Database,
112    query: &str,
113    options: &SearchOptions,
114    embedder: &dyn Embedder,
115    expander: Option<&dyn QueryExpander>,
116    reranker: Option<&dyn Reranker>,
117) -> Result<Vec<SearchResult>> {
118    // 1. Initial BM25 search
119    let bm25_results = db.search_fts(query, options)?;
120
121    // 2. Check for strong signal
122    if has_strong_signal(&bm25_results) {
123        return Ok(bm25_results);
124    }
125
126    // 3. Vector search
127    let vec_results = db.search_vec(query, embedder, options).await?;
128
129    // 4. Query expansion (if available and not skipped)
130    let mut all_bm25 = bm25_results.clone();
131    let mut all_vec = vec_results.clone();
132
133    if let Some(exp) = expander {
134        let expanded = exp.expand(query, None).await?;
135
136        // Run lexical variations
137        for lex_query in &expanded.lexical {
138            let results = db.search_fts(lex_query, options)?;
139            all_bm25.extend(results);
140        }
141
142        // Run semantic variations
143        for vec_query in &expanded.semantic {
144            let results = db.search_vec(vec_query, embedder, options).await?;
145            all_vec.extend(results);
146        }
147
148        // Run HyDE if present
149        if let Some(ref hyde) = expanded.hyde {
150            let results = db.search_vec(hyde, embedder, options).await?;
151            all_vec.extend(results);
152        }
153    }
154
155    // 5. RRF fusion
156    let mut fused = rrf_fusion(&all_bm25, &all_vec);
157
158    // 6. Cap for reranking
159    fused = cap_for_reranking(fused);
160
161    // 7. Rerank (if available)
162    if let Some(rr) = reranker {
163        let docs: Vec<RerankDocument> = fused
164            .iter()
165            .map(|r| RerankDocument {
166                id: r.hash.clone(),
167                text: r.body.clone().unwrap_or_default(),
168            })
169            .collect();
170
171        let reranked = rr.rerank(query, &docs).await?;
172
173        // Build hash -> rerank score map
174        let rerank_scores: HashMap<String, f64> =
175            reranked.iter().map(|r| (r.id.clone(), r.score)).collect();
176
177        // Blend scores
178        for (rrf_rank, result) in fused.iter_mut().enumerate() {
179            if let Some(&rerank_score) = rerank_scores.get(&result.hash) {
180                let rrf_score = result.score;
181                result.score = blend_scores(rrf_rank + 1, rrf_score, rerank_score);
182            }
183        }
184
185        // Re-sort by blended score
186        fused.sort_by(|a, b| {
187            b.score
188                .partial_cmp(&a.score)
189                .unwrap_or(std::cmp::Ordering::Equal)
190        });
191    }
192
193    // 8. Apply final limit and min_score
194    let final_results: Vec<SearchResult> = fused
195        .into_iter()
196        .filter(|r| r.score >= options.min_score)
197        .take(options.limit)
198        .collect();
199
200    Ok(final_results)
201}
202
203impl Database {
204    /// Synchronous vector search (placeholder for CLI - needs runtime)
205    pub fn search_vec_sync(
206        &self,
207        _query: &str,
208        options: &SearchOptions,
209    ) -> Result<Vec<SearchResult>> {
210        // Placeholder: In production, this would use a runtime
211        // For now, fall back to BM25
212        eprintln!("Warning: Vector search requires embeddings, falling back to BM25");
213        self.search_fts(_query, options)
214    }
215
216    /// Synchronous hybrid search (placeholder for CLI - needs runtime)
217    pub fn search_hybrid_sync(
218        &self,
219        query: &str,
220        options: &SearchOptions,
221    ) -> Result<Vec<SearchResult>> {
222        // Placeholder: In production, this would use full hybrid pipeline
223        // For now, just run BM25
224        self.search_fts(query, options)
225    }
226}