use super::{SearchOptions, SearchResult, SearchSource};
use crate::db::Database;
use crate::error::Result;
use crate::llm::{Embedder, QueryExpander, RerankDocument, Reranker};
use std::collections::HashMap;
/// Smoothing constant `k` in the reciprocal-rank-fusion term `weight / (k + rank)`
/// used by `rrf_fusion` (rank is 1-based).
const RRF_K: f64 = 60.0;
/// Maximum number of fused results kept by `cap_for_reranking`.
const MAX_RERANK_DOCS: usize = 40;
/// Minimum top score `has_strong_signal` requires before it reports a decisive result.
const STRONG_SIGNAL_SCORE: f64 = 0.85;
/// Minimum lead of the top score over the runner-up that `has_strong_signal` requires.
const STRONG_SIGNAL_GAP: f64 = 0.15;
/// Decides whether a ranked result list is decisive on its own.
///
/// An empty list never qualifies. A single result qualifies when its score
/// reaches `STRONG_SIGNAL_SCORE`. With two or more results, the top score must
/// reach that threshold *and* lead the runner-up by at least
/// `STRONG_SIGNAL_GAP`. Only the first two entries are inspected, so the list
/// is expected to be ordered best-first.
pub fn has_strong_signal(results: &[SearchResult]) -> bool {
    match results {
        [] => false,
        [only] => only.score >= STRONG_SIGNAL_SCORE,
        [leader, runner_up, ..] => {
            let lead = leader.score - runner_up.score;
            leader.score >= STRONG_SIGNAL_SCORE && lead >= STRONG_SIGNAL_GAP
        }
    }
}
/// Keeps at most `MAX_RERANK_DOCS` results and preserves their order, so only
/// the head of a fused list is handed to the reranker.
pub fn cap_for_reranking(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
    results.truncate(MAX_RERANK_DOCS);
    results
}
/// Blends a fused-retrieval score with a reranker score.
///
/// `rrf_rank` is the 1-based position of the document in the fused list. The
/// better a document already ranks, the more its fused score is trusted:
/// ranks 0..=3 weight the fused score at 75 %, ranks 4..=10 at 60 %, and all
/// later ranks at 40 %. The reranker score receives the remaining share.
pub fn blend_scores(rrf_rank: usize, rrf_score: f64, rerank_score: f64) -> f64 {
    let fused_share = match rrf_rank {
        0..=3 => 0.75,
        4..=10 => 0.60,
        _ => 0.40,
    };
    let rerank_share = 1.0 - fused_share;
    fused_share * rrf_score + rerank_share * rerank_score
}
/// Fuses a BM25 ranking and a vector ranking with reciprocal-rank fusion.
///
/// Each hit contributes `weight / (RRF_K + rank)` (1-based rank) plus a small
/// bonus for landing near the top of its list: +0.05 for the first three
/// positions and +0.02 for positions four to ten. BM25 hits use weight `2.0`
/// and vector hits `1.0`. Contributions are summed per document `hash`; when a
/// hash occurs more than once, the first occurrence's fields are kept.
///
/// The returned results carry the fused score in `score`, have `source` set to
/// [`SearchSource::Hybrid`], and are sorted by descending score. Equal scores
/// keep first-seen order (BM25 list first, then the vector list), so the
/// output is deterministic across runs.
pub fn rrf_fusion(
    bm25_results: &[SearchResult],
    vec_results: &[SearchResult],
) -> Vec<SearchResult> {
    // Extra credit for a very high placement within one list (`rank` is 0-based).
    fn top_rank_bonus(rank: usize) -> f64 {
        if rank < 3 {
            0.05
        } else if rank < 10 {
            0.02
        } else {
            0.0
        }
    }

    let weighted_lists: [(&[SearchResult], f64); 2] = [(bm25_results, 2.0), (vec_results, 1.0)];

    // `fused` holds documents in first-seen order and `slot_by_hash` indexes
    // into it. Collecting straight out of a HashMap would yield a per-process
    // random order, making equal scores sort differently from run to run.
    let mut slot_by_hash: HashMap<String, usize> = HashMap::new();
    let mut fused: Vec<(f64, SearchResult)> = Vec::new();
    for &(list, weight) in &weighted_lists {
        for (rank, result) in list.iter().enumerate() {
            let contribution = weight / (RRF_K + (rank + 1) as f64) + top_rank_bonus(rank);
            let slot = *slot_by_hash.entry(result.hash.clone()).or_insert_with(|| {
                fused.push((0.0, result.clone()));
                fused.len() - 1
            });
            fused[slot].0 += contribution;
        }
    }

    // `sort_by` is stable, so ties stay in first-seen order.
    fused.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    fused
        .into_iter()
        .map(|(score, mut r)| {
            r.score = score;
            r.source = SearchSource::Hybrid;
            r
        })
        .collect()
}
/// Fuses several retrieval lists with reciprocal-rank fusion, ranking each list
/// on its own.
///
/// Every list is scored independently through [`rrf_fusion`]: lexical lists go
/// in its BM25 slot and semantic lists in its vector slot. The per-list scores
/// are then summed per document `hash`, and the first occurrence of a document
/// supplies its other fields. Output is sorted by descending fused score;
/// equal scores keep first-seen order.
fn fuse_ranked_lists(
    lexical_lists: &[Vec<SearchResult>],
    semantic_lists: &[Vec<SearchResult>],
) -> Vec<SearchResult> {
    let per_list_scores = lexical_lists
        .iter()
        .map(|list| rrf_fusion(list, &[]))
        .chain(semantic_lists.iter().map(|list| rrf_fusion(&[], list)));

    let mut slot_by_hash: HashMap<String, usize> = HashMap::new();
    let mut fused: Vec<SearchResult> = Vec::new();
    for scored_list in per_list_scores {
        for result in scored_list {
            if let Some(&slot) = slot_by_hash.get(&result.hash) {
                fused[slot].score += result.score;
            } else {
                slot_by_hash.insert(result.hash.clone(), fused.len());
                fused.push(result);
            }
        }
    }
    fused.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    fused
}

/// Drops results scoring below `options.min_score` and keeps at most
/// `options.limit` of the rest, preserving their order.
fn apply_score_floor_and_limit(
    results: Vec<SearchResult>,
    options: &SearchOptions,
) -> Vec<SearchResult> {
    results
        .into_iter()
        .filter(|r| r.score >= options.min_score)
        .take(options.limit)
        .collect()
}

/// Runs hybrid lexical + semantic retrieval for `query`.
///
/// 1. BM25 full-text search. If [`has_strong_signal`] accepts those hits, they
///    are returned straight away, filtered by `options.min_score` and capped at
///    `options.limit` like the full pipeline's output.
/// 2. Vector search for `query`. When `expander` is given, it also runs one
///    extra full-text search per expanded lexical query and one extra vector
///    search per expanded semantic query and for the HyDE text, if any.
/// 3. Reciprocal-rank fusion of every retrieved list, each ranked on its own.
/// 4. With a `reranker`: the top `MAX_RERANK_DOCS` fused hits are reranked.
///    Each reranker score is blended with the fused score via
///    [`blend_scores`], and the list is re-sorted.
/// 5. Results below `options.min_score` are dropped and at most
///    `options.limit` are returned.
///
/// # Errors
/// Propagates the first error from the database searches, the query expander,
/// or the reranker.
pub async fn hybrid_search(
    db: &Database,
    query: &str,
    options: &SearchOptions,
    embedder: &dyn Embedder,
    expander: Option<&dyn QueryExpander>,
    reranker: Option<&dyn Reranker>,
) -> Result<Vec<SearchResult>> {
    let bm25_results = db.search_fts(query, options)?;
    if has_strong_signal(&bm25_results) {
        return Ok(apply_score_floor_and_limit(bm25_results, options));
    }
    let vec_results = db.search_vec(query, embedder, options).await?;

    // Keep every retrieval as its own ranked list. Concatenating them would
    // push an expanded query's top hit behind all of the original hits, robbing
    // it of its rank (and top-rank bonus) in the fusion.
    let mut lexical_lists = vec![bm25_results];
    let mut semantic_lists = vec![vec_results];
    if let Some(exp) = expander {
        let expanded = exp.expand(query, None).await?;
        for lex_query in &expanded.lexical {
            lexical_lists.push(db.search_fts(lex_query, options)?);
        }
        for vec_query in &expanded.semantic {
            semantic_lists.push(db.search_vec(vec_query, embedder, options).await?);
        }
        if let Some(ref hyde) = expanded.hyde {
            semantic_lists.push(db.search_vec(hyde, embedder, options).await?);
        }
    }

    let mut fused = fuse_ranked_lists(&lexical_lists, &semantic_lists);

    if let Some(rr) = reranker {
        // Cap only on this path. Without a reranker there is nothing to bound,
        // and capping would silently override an `options.limit` above the cap.
        fused = cap_for_reranking(fused);
        let docs: Vec<RerankDocument> = fused
            .iter()
            .map(|r| RerankDocument {
                id: r.hash.clone(),
                text: r.body.clone().unwrap_or_default(),
            })
            .collect();
        let reranked = rr.rerank(query, &docs).await?;
        let rerank_scores: HashMap<String, f64> =
            reranked.iter().map(|r| (r.id.clone(), r.score)).collect();
        // NOTE(review): `result.score` is the raw fused RRF sum. Each list adds
        // at most 2/61 + 0.05, whereas the reranker's score range is not visible
        // here. If it is 0..1, the reranker dominates the blend, so confirm the
        // intended scale. Documents the reranker returned no score for keep
        // their unblended fused score.
        for (rrf_rank, result) in fused.iter_mut().enumerate() {
            if let Some(&rerank_score) = rerank_scores.get(&result.hash) {
                let rrf_score = result.score;
                result.score = blend_scores(rrf_rank + 1, rrf_score, rerank_score);
            }
        }
        fused.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    Ok(apply_score_floor_and_limit(fused, options))
}
impl Database {
    /// Synchronous stand-in for vector search.
    ///
    /// This method takes no embedder, so it cannot run a vector query. Instead
    /// it prints a warning to stderr and returns the BM25 full-text results
    /// from `search_fts`.
    ///
    /// # Errors
    /// Propagates any error returned by `search_fts`.
    pub fn search_vec_sync(
        &self,
        query: &str,
        options: &SearchOptions,
    ) -> Result<Vec<SearchResult>> {
        eprintln!("Warning: Vector search requires embeddings, falling back to BM25");
        self.search_fts(query, options)
    }

    /// Synchronous stand-in for hybrid search.
    ///
    /// Performs BM25 full-text search only: no vector retrieval, query
    /// expansion, fusion, or reranking. Unlike `search_vec_sync`, it prints no
    /// warning about the fallback.
    ///
    /// # Errors
    /// Propagates any error returned by `search_fts`.
    pub fn search_hybrid_sync(
        &self,
        query: &str,
        options: &SearchOptions,
    ) -> Result<Vec<SearchResult>> {
        self.search_fts(query, options)
    }
}