use crate::error::Result;
use crate::storage::{SearchResult, VectorStore};
use std::collections::HashMap;
use uuid::Uuid;
/// Re-ranks vector-search candidates by blending their semantic score with a
/// simple keyword-overlap score.
///
/// The blend is linear: `semantic_weight * semantic + (1 - semantic_weight) * keyword`.
#[derive(Debug, Clone, PartialEq)]
pub struct HybridSearch {
    /// Share of the final score taken from the semantic score; the remainder
    /// (`1.0 - semantic_weight`) goes to the keyword score.
    semantic_weight: f32,
}
impl HybridSearch {
    /// Creates a hybrid searcher with the given semantic weight.
    ///
    /// `semantic_weight` is clamped into `[0.0, 1.0]`. Note that `f32::clamp`
    /// returns NaN for a NaN input, so a NaN weight is stored unchanged.
    pub fn new(semantic_weight: f32) -> Self {
        Self {
            semantic_weight: semantic_weight.clamp(0.0, 1.0),
        }
    }

    /// Runs a semantic search against `store` and re-ranks the candidates with
    /// a keyword-overlap score, returning at most `k` results.
    ///
    /// # Parameters
    /// * `store` - vector store queried for the semantic candidates.
    /// * `query` - raw query text; it is lower-cased and split on whitespace
    ///   into keyword terms.
    /// * `embedding` - query embedding forwarded to `store.search`.
    /// * `k` - maximum number of results to return.
    ///
    /// # Scoring
    /// Each candidate's final score is
    /// `semantic_weight * semantic_score + (1 - semantic_weight) * keyword_score`,
    /// where `keyword_score` is the fraction of query terms found as substrings
    /// of the candidate's lower-cased `"text"` metadata field. A missing or
    /// non-string `"text"` field is treated as an empty string.
    ///
    /// Results are sorted by descending score. NaN scores are ranked last and
    /// ties are broken by ascending id, so the output order is deterministic.
    ///
    /// # Errors
    /// Propagates any error returned by `store.search`.
    pub fn search(
        &self,
        store: &VectorStore,
        query: &str,
        embedding: &[f32],
        k: usize,
    ) -> Result<Vec<SearchResult>> {
        // Over-fetch (3x, at least 20) so keyword re-ranking can promote
        // candidates from outside the semantic top-k. `saturating_mul` keeps a
        // huge `k` from overflowing; the store size caps the request anyway.
        let candidate_k = k.saturating_mul(3).max(20).min(store.len());
        let semantic_results = store.search(embedding, candidate_k)?;

        let query_lower = query.to_lowercase();
        let query_terms: Vec<&str> = query_lower.split_whitespace().collect();

        // id -> (semantic score, keyword score, metadata). Keying by id also
        // collapses any duplicate ids among the candidates: metadata comes
        // from the first occurrence, scores from the last.
        let mut scores: HashMap<Uuid, (f32, f32, serde_json::Value)> = HashMap::new();
        for result in &semantic_results {
            let text = result
                .metadata
                .get("text")
                .and_then(|v| v.as_str())
                .unwrap_or("");
            let keyword_score = compute_keyword_score(&query_terms, &text.to_lowercase());

            // Clone the metadata only when the id is seen for the first time.
            let entry = scores
                .entry(result.id)
                .or_insert_with(|| (0.0, 0.0, result.metadata.clone()));
            entry.0 = result.score;
            entry.1 = keyword_score;
        }

        let keyword_weight = 1.0 - self.semantic_weight;
        let mut combined: Vec<SearchResult> = scores
            .into_iter()
            .map(|(id, (sem_score, kw_score, metadata))| SearchResult {
                id,
                score: self.semantic_weight * sem_score + keyword_weight * kw_score,
                metadata,
            })
            .collect();

        // Sort on a total order: NaN is mapped below every real score so it
        // ranks last, and the id tie-break removes any dependence on HashMap
        // iteration order. Ids are unique here, so an unstable sort is safe.
        let rank_key = |score: f32| if score.is_nan() { f32::NEG_INFINITY } else { score };
        combined.sort_unstable_by(|a, b| {
            rank_key(b.score)
                .total_cmp(&rank_key(a.score))
                .then_with(|| a.id.cmp(&b.id))
        });
        combined.truncate(k);
        Ok(combined)
    }

    /// Returns the weight applied to the semantic score, as stored by [`HybridSearch::new`].
    pub fn semantic_weight(&self) -> f32 {
        self.semantic_weight
    }
}
/// Returns the fraction of `query_terms` that occur as substrings of `text_lower`.
///
/// The result lies in `[0.0, 1.0]`; an empty term list scores `0.0`. Matching
/// is plain substring containment, so the function does no case folding of its
/// own. Each term is counted separately, so repeated terms are counted
/// repeatedly.
fn compute_keyword_score(query_terms: &[&str], text_lower: &str) -> f32 {
    let total = query_terms.len();
    if total == 0 {
        // Guard against dividing by zero when there are no terms.
        return 0.0;
    }

    let mut hits = 0usize;
    for &term in query_terms {
        if text_lower.contains(term) {
            hits += 1;
        }
    }

    hits as f32 / total as f32
}