use std::collections::HashMap;
use crate::bm25::{Bm25Error, Bm25Index};
/// Tuning knobs for [`hybrid_search`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HybridSearchConfig {
    /// Weight applied to the normalised BM25 score; the vector score receives
    /// `1 - alpha`. `hybrid_search` clamps this into `[0, 1]`.
    pub alpha: f32,
    /// Maximum number of results returned. The BM25 index is queried for
    /// (saturating) twice this many candidates before fusion.
    pub top_k: usize,
}

impl Default for HybridSearchConfig {
    /// Equal weighting of lexical and vector scores, returning up to ten hits.
    fn default() -> Self {
        let alpha = 0.5;
        let top_k = 10;
        Self { alpha, top_k }
    }
}
/// One ranked hit produced by [`hybrid_search`].
///
/// `hybrid_search` in this module fills only `id` and `score`; `file`,
/// `symbol` and `content` are left as empty placeholders there.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct HybridResult {
    /// Identifier of the hit, as reported by the BM25 index.
    pub id: String,
    /// Presumably the source file of the hit — TODO confirm against callers.
    pub file: String,
    /// Presumably the symbol name associated with the hit, if any — TODO confirm.
    pub symbol: Option<String>,
    /// Presumably the matched text — TODO confirm against callers.
    pub content: String,
    /// Fused relevance score; higher ranks earlier.
    pub score: f32,
}
pub fn hybrid_search(
bm25: &Bm25Index,
vector_results: &[(String, f32)],
query: &str,
config: &HybridSearchConfig,
) -> Result<Vec<HybridResult>, Bm25Error> {
let candidates = bm25.search(query, config.top_k.saturating_mul(2))?;
if candidates.is_empty() {
return Ok(Vec::new());
}
let max_bm25 = candidates
.iter()
.map(|r| r.score)
.fold(f32::NEG_INFINITY, f32::max);
let vector_map: HashMap<&str, f32> = vector_results
.iter()
.map(|(id, score)| (id.as_str(), *score))
.collect();
let alpha = config.alpha.clamp(0.0, 1.0);
let mut combined: Vec<(String, f32)> = candidates
.into_iter()
.map(|r| {
let norm = if max_bm25 > 0.0 {
r.score / max_bm25
} else {
0.0
};
let vec_score = vector_map.get(r.id.as_str()).copied().unwrap_or(0.0);
let score = alpha * norm + (1.0 - alpha) * vec_score;
(r.id, score)
})
.collect();
combined.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
combined.truncate(config.top_k);
let results = combined
.into_iter()
.map(|(id, score)| HybridResult {
id,
file: String::new(),
symbol: None,
content: String::new(),
score,
})
.collect();
Ok(results)
}