use std::collections::HashMap;
use crate::fulltext::FullTextResult;
/// One vector-search hit as consumed by `HybridSearcher::search`, destructured
/// there as `(id, score, metadata, vector)`.
type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
/// Settings that control how vector and full-text scores are blended.
#[derive(Debug, Clone)]
pub struct HybridConfig {
/// Weight applied to the normalized vector score; the normalized text score
/// is weighted by `1.0 - vector_weight`.
pub vector_weight: f32,
/// When `true`, ids that appear in only one of the two result lists are
/// skipped instead of being scored with `0.0` for the missing side.
pub require_both: bool,
}
impl Default for HybridConfig {
    /// Returns an even blend (`vector_weight = 0.5`) that does not require an
    /// id to be present in both result lists.
    fn default() -> Self {
        let vector_weight = 0.5;
        let require_both = false;
        Self {
            vector_weight,
            require_both,
        }
    }
}
/// Vector-side data kept per document id while `HybridSearcher::search`
/// merges the two result lists.
#[derive(Debug, Clone)]
struct RawScore {
// Un-normalized score exactly as supplied in the vector result row.
score: f32,
// Optional metadata carried through from the vector result row.
metadata: Option<serde_json::Value>,
// Optional vector carried through from the vector result row.
vector: Option<Vec<f32>>,
}
/// One entry of the fused ranking produced by `HybridSearcher::search`.
#[derive(Debug, Clone)]
pub struct HybridResult {
/// Document id shared by the vector and/or text hit.
pub id: String,
/// `vector_weight * vector_score + (1.0 - vector_weight) * text_score`.
pub combined_score: f32,
/// Min-max normalized vector score, or `0.0` when the id had no vector hit.
pub vector_score: f32,
/// Min-max normalized text score, or `0.0` when the id had no text hit.
pub text_score: f32,
/// Metadata from the vector hit; `None` when the id had no vector hit.
pub metadata: Option<serde_json::Value>,
/// Vector from the vector hit; `None` when the id had no vector hit.
pub vector: Option<Vec<f32>>,
}
/// Combines vector-search and full-text results into one ranked list
/// according to its `HybridConfig`.
pub struct HybridSearcher {
// Blending weight and the `require_both` filter used by `search`.
config: HybridConfig,
}
impl HybridSearcher {
    /// Creates a searcher from `config`.
    ///
    /// `config.vector_weight` is clamped to `[0.0, 1.0]`, the same invariant
    /// that [`HybridSearcher::with_vector_weight`] enforces, so neither side
    /// can end up with a negative weight.
    pub fn new(mut config: HybridConfig) -> Self {
        config.vector_weight = config.vector_weight.clamp(0.0, 1.0);
        Self { config }
    }

    /// Returns the searcher with its vector weight replaced by `weight`,
    /// clamped to `[0.0, 1.0]`.
    pub fn with_vector_weight(mut self, weight: f32) -> Self {
        self.config.vector_weight = weight.clamp(0.0, 1.0);
        self
    }

    /// Fuses vector hits and full-text hits into a single ranking.
    ///
    /// Each side's raw scores are min-max normalized independently, then
    /// blended as `vector_weight * vector + (1.0 - vector_weight) * text`.
    /// An id missing from one side contributes `0.0` for that side, or is
    /// dropped entirely when `require_both` is set. If an id occurs more than
    /// once within one list, the last occurrence is the one that is scored.
    ///
    /// The result is sorted by descending combined score and truncated to
    /// `top_k`. Ties keep ascending id order, and entries whose combined score
    /// is NaN are placed last.
    pub fn search(
        &self,
        vector_results: Vec<VectorResultRow>,
        text_results: Vec<FullTextResult>,
        top_k: usize,
    ) -> Vec<HybridResult> {
        let mut vector_scores: HashMap<String, RawScore> =
            HashMap::with_capacity(vector_results.len());
        let mut text_scores: HashMap<String, f32> = HashMap::with_capacity(text_results.len());

        let mut vector_min = f32::MAX;
        let mut vector_max = f32::MIN;
        let mut text_min = f32::MAX;
        let mut text_max = f32::MIN;

        for (id, score, metadata, vector) in vector_results {
            vector_min = vector_min.min(score);
            vector_max = vector_max.max(score);
            vector_scores.insert(
                id,
                RawScore {
                    score,
                    metadata,
                    vector,
                },
            );
        }
        for result in text_results {
            text_min = text_min.min(result.score);
            text_max = text_max.max(result.score);
            text_scores.insert(result.doc_id, result.score);
        }

        // Sorted + deduplicated union of ids; the sorted order is what makes
        // tie-breaking in the final (stable) sort deterministic.
        let mut all_ids: Vec<String> = vector_scores
            .keys()
            .chain(text_scores.keys())
            .cloned()
            .collect();
        all_ids.sort_unstable();
        all_ids.dedup();

        let vector_weight = self.config.vector_weight;
        let text_weight = 1.0 - vector_weight;

        let mut results: Vec<HybridResult> = Vec::with_capacity(all_ids.len());
        for id in all_ids {
            // Each id is visited exactly once, so the vector-side entry can be
            // moved out of the map instead of cloning metadata and embedding.
            let vector_raw = vector_scores.remove(&id);
            let text_raw = text_scores.get(&id).copied();

            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
                continue;
            }

            let text_normalized =
                text_raw.map_or(0.0, |score| normalize_score(score, text_min, text_max));
            let (vector_normalized, metadata, vector) = match vector_raw {
                Some(raw) => (
                    normalize_score(raw.score, vector_min, vector_max),
                    raw.metadata,
                    raw.vector,
                ),
                None => (0.0, None, None),
            };

            results.push(HybridResult {
                id,
                combined_score: vector_weight * vector_normalized + text_weight * text_normalized,
                vector_score: vector_normalized,
                text_score: text_normalized,
                metadata,
                vector,
            });
        }

        results.sort_by(Self::by_combined_score_desc);
        results.truncate(top_k);
        results
    }

    /// Orders results by descending combined score, with NaN scores last.
    ///
    /// `partial_cmp` alone is not a total order once NaN is involved, and
    /// `sort_by` may panic or produce an arbitrary order in that case, so NaN
    /// is ranked explicitly before the numeric comparison.
    fn by_combined_score_desc(a: &HybridResult, b: &HybridResult) -> std::cmp::Ordering {
        a.combined_score
            .is_nan()
            .cmp(&b.combined_score.is_nan())
            .then_with(|| {
                b.combined_score
                    .partial_cmp(&a.combined_score)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }
}
impl Default for HybridSearcher {
    /// Builds a searcher from `HybridConfig::default()`.
    fn default() -> Self {
        let config = HybridConfig::default();
        Self::new(config)
    }
}
/// Maps a routed query kind to a vector weight: `Keyword` → 0.25,
/// `Hybrid` → 0.50, `Semantic` → 0.75.
pub fn adaptive_vector_weight(kind: crate::routing::QueryKind) -> f32 {
    const KEYWORD_WEIGHT: f32 = 0.25;
    const HYBRID_WEIGHT: f32 = 0.50;
    const SEMANTIC_WEIGHT: f32 = 0.75;

    match kind {
        crate::routing::QueryKind::Semantic => SEMANTIC_WEIGHT,
        crate::routing::QueryKind::Hybrid => HYBRID_WEIGHT,
        crate::routing::QueryKind::Keyword => KEYWORD_WEIGHT,
    }
}
/// Min-max normalizes `score` against the `[min, max]` range.
///
/// When the range is narrower than `f32::EPSILON` (for example a single
/// result, or all scores equal) every score maps to `1.0` rather than
/// dividing by a near-zero span.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let span = max - min;
    let degenerate = span.abs() < f32::EPSILON;
    if degenerate {
        return 1.0;
    }
    (score - min) / span
}
#[cfg(test)]
mod tests {
    use super::*;

    // Vector-side row `(id, score, metadata, vector)` without metadata or vector.
    macro_rules! row {
        ($id:expr, $score:expr) => {
            ($id.to_string(), $score, None, None)
        };
    }

    // Full-text hit for the given document id.
    macro_rules! hit {
        ($id:expr, $score:expr) => {
            FullTextResult {
                doc_id: $id.to_string(),
                score: $score,
            }
        };
    }

    /// Ids from both sides are unioned and each side is normalized on its own.
    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();
        let vector_results = vec![row!("doc1", 0.9), row!("doc2", 0.7), row!("doc3", 0.5)];
        let text_results = vec![hit!("doc1", 3.0), hit!("doc2", 4.0), hit!("doc4", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);
        assert_eq!(results.len(), 4);

        let first = results.iter().find(|entry| entry.id == "doc1").unwrap();
        assert!(first.vector_score > 0.0);
        assert!(first.text_score >= 0.0);
        assert!(first.combined_score > 0.0);

        let second = results.iter().find(|entry| entry.id == "doc2").unwrap();
        assert!(second.vector_score > 0.0);
        assert!(second.text_score > 0.0);
        assert!(second.combined_score > 0.0);
        assert_eq!(second.text_score, 1.0);
    }

    /// With weight 1.0 the combined score equals the vector score.
    #[test]
    fn test_hybrid_search_vector_only() {
        let config = HybridConfig {
            vector_weight: 1.0,
            require_both: false,
        };
        let searcher = HybridSearcher::new(config);
        let vector_results = vec![row!("doc1", 0.9), row!("doc2", 0.5)];
        let text_results = vec![hit!("doc1", 1.0)];

        let results = searcher.search(vector_results, text_results, 10);
        let top = &results[0];
        assert_eq!(top.id, "doc1");
        assert_eq!(top.combined_score, top.vector_score);
    }

    /// With weight 0.0 the combined score equals the text score.
    #[test]
    fn test_hybrid_search_text_only() {
        let config = HybridConfig {
            vector_weight: 0.0,
            require_both: false,
        };
        let searcher = HybridSearcher::new(config);
        let vector_results = vec![row!("doc1", 0.9), row!("doc2", 0.5)];
        let text_results = vec![hit!("doc1", 1.0), hit!("doc2", 3.0)];

        let results = searcher.search(vector_results, text_results, 10);
        let top = &results[0];
        assert_eq!(top.id, "doc2");
        assert_eq!(top.combined_score, top.text_score);
    }

    /// `require_both` drops ids that are missing from either side.
    #[test]
    fn test_hybrid_search_require_both() {
        let config = HybridConfig {
            vector_weight: 0.5,
            require_both: true,
        };
        let searcher = HybridSearcher::new(config);
        let vector_results = vec![row!("doc1", 0.9), row!("doc2", 0.7)];
        let text_results = vec![hit!("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    /// The result list is truncated to `top_k`.
    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();
        let vector_results = vec![
            row!("doc1", 0.9),
            row!("doc2", 0.8),
            row!("doc3", 0.7),
            row!("doc4", 0.6),
            row!("doc5", 0.5),
        ];
        let text_results = Vec::new();

        let results = searcher.search(vector_results, text_results, 3);
        assert_eq!(results.len(), 3);
    }

    /// Metadata and vector from the vector hit are carried into the result.
    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();
        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];
        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];
        let text_results = vec![hit!("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);
        assert_eq!(results.len(), 1);
        let only = &results[0];
        assert_eq!(only.metadata, Some(metadata));
        assert_eq!(only.vector, Some(vector));
    }

    /// Midpoint, both endpoints, and the degenerate zero-width range.
    #[test]
    fn test_normalize_score() {
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    /// The builder stores an in-range weight unchanged.
    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);
        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    /// Out-of-range weights are clamped to `[0.0, 1.0]` by the builder.
    #[test]
    fn test_vector_weight_clamping() {
        let too_high = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(too_high.config.vector_weight, 1.0);

        let too_low = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(too_low.config.vector_weight, 0.0);
    }
}