use std::collections::HashMap;
/// Picks the `(keyword, semantic)` weight pair for a query of `token_count` tokens.
///
/// | `token_count` | keyword | semantic |
/// |---------------|---------|----------|
/// | 1–2           | 0.9     | 0.1      |
/// | 3–4           | 0.7     | 0.3      |
/// | 5–8           | 0.4     | 0.6      |
/// | anything else | 0.2     | 0.8      |
///
/// "Anything else" covers counts above 8 and also a count of `0`.
pub fn compute_weights(token_count: usize) -> (f32, f32) {
    // Zero and everything past eight tokens share the catch-all weighting.
    if token_count == 0 || token_count > 8 {
        return (0.2, 0.8);
    }

    // From here on token_count is within 1..=8.
    if token_count <= 2 {
        (0.9, 0.1)
    } else if token_count <= 4 {
        (0.7, 0.3)
    } else {
        (0.4, 0.6)
    }
}
/// Min-max rescales a list of `(id, score)` pairs, keeping ids and input order.
///
/// * Empty input yields an empty vector.
/// * When the spread `max - min` is below `1e-10` (which includes a single
///   entry or all-equal scores), every entry is assigned `1.0`.
/// * Otherwise each score becomes `(score - min) / (max - min)`, so for finite
///   inputs the smallest score maps to `0.0` and the largest to `1.0`.
pub fn normalize_scores(scores: &[(String, f32)]) -> Vec<(String, f32)> {
    // Spreads smaller than this are treated as "all scores equal".
    const EPSILON: f32 = 1e-10;

    if scores.is_empty() {
        return Vec::new();
    }

    // Track the smallest and largest score in a single pass.
    let (lowest, highest) = scores.iter().fold(
        (f32::INFINITY, f32::NEG_INFINITY),
        |(lo, hi), (_, value)| (lo.min(*value), hi.max(*value)),
    );
    let span = highest - lowest;
    let is_flat = span < EPSILON;

    let mut rescaled = Vec::with_capacity(scores.len());
    for (id, value) in scores {
        // A flat list cannot be ranked internally, so every entry gets full credit.
        let unit = if is_flat { 1.0 } else { (value - lowest) / span };
        rescaled.push((id.clone(), unit));
    }
    rescaled
}
/// Fuses two scored result lists into a single ranked list.
///
/// Each list is rescaled independently with [`normalize_scores`]; the combined
/// score of an id is then
/// `weights.0 * normalized_bm25 + weights.1 * normalized_hdc`.
/// An id that is absent from one list contributes nothing for that list, and an
/// id that appears several times in one list has all its contributions summed.
///
/// # Arguments
///
/// * `bm25_results` - keyword-side `(id, score)` pairs, weighted by `weights.0`.
/// * `hdc_results` - semantic-side `(id, score)` pairs, weighted by `weights.1`.
/// * `weights` - `(keyword_weight, semantic_weight)`.
///
/// # Returns
///
/// One `(id, combined_score)` entry per distinct id, sorted by descending score.
/// Entries with equal scores are ordered by ascending id, so the output does not
/// depend on `HashMap` iteration order.
pub fn merge_results(
    bm25_results: &[(String, f32)],
    hdc_results: &[(String, f32)],
    weights: (f32, f32),
) -> Vec<(String, f32)> {
    let (kw_weight, sem_weight) = weights;

    let mut combined: HashMap<String, f32> =
        HashMap::with_capacity(bm25_results.len() + hdc_results.len());

    // The normalized vectors are owned, so their ids can be moved into the map
    // without a second clone.
    for (id, score) in normalize_scores(bm25_results) {
        *combined.entry(id).or_insert(0.0) += kw_weight * score;
    }
    for (id, score) in normalize_scores(hdc_results) {
        *combined.entry(id).or_insert(0.0) += sem_weight * score;
    }

    let mut results: Vec<(String, f32)> = combined.into_iter().collect();

    // `total_cmp` is a genuine total order even when a score is NaN, which the
    // sort requires; mapping incomparable pairs to `Equal` is not. The id
    // tie-break makes every pair of entries strictly ordered (ids are unique map
    // keys), so an unstable sort is safe and the ranking is reproducible.
    results.sort_unstable_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    results
}
/// Mode selector for hybrid search.
///
/// NOTE(review): nothing in this file reads the mode, so the per-variant
/// descriptions below are inferred from the names only — confirm against the
/// code that consumes `HybridMode`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum HybridMode {
/// The variant returned by `HybridMode::default()`. Presumably lets the
/// weights be chosen automatically (e.g. via `compute_weights`) — TODO confirm.
#[default]
Auto,
/// Presumably restricts ranking to the semantic side — TODO confirm.
SemanticOnly,
/// Presumably restricts ranking to the keyword side — TODO confirm.
KeywordOnly,
/// Carries a caller-supplied `f32`. Presumably a blend weight; which side it
/// applies to and its valid range are not visible here — TODO confirm.
Custom(f32),
}
#[derive(Debug, Clone)]
pub struct HybridConfig {
pub mode: HybridMode,
pub min_score: f32,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
mode: HybridMode::Auto,
min_score: 0.0,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used when comparing `f32` scores.
    const TOLERANCE: f32 = 1e-6;

    /// Asserts that two scores agree within [`TOLERANCE`].
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual}"
        );
    }

    /// Returns the score stored for `id`, panicking with the full result list
    /// when the id is missing.
    fn score_of(results: &[(String, f32)], id: &str) -> f32 {
        results
            .iter()
            .find(|(candidate, _)| candidate == id)
            .map(|(_, score)| *score)
            .unwrap_or_else(|| panic!("expected id {id:?} in {results:?}"))
    }

    /// Asserts that scores never increase from one entry to the next.
    fn assert_descending(results: &[(String, f32)]) {
        assert!(
            results.windows(2).all(|pair| pair[0].1 >= pair[1].1),
            "results are not sorted by descending score: {results:?}"
        );
    }

    #[test]
    fn test_compute_weights_short_query() {
        for tokens in 1..=2 {
            assert_eq!(compute_weights(tokens), (0.9, 0.1));
        }
    }

    #[test]
    fn test_compute_weights_medium_query() {
        for tokens in 3..=4 {
            assert_eq!(compute_weights(tokens), (0.7, 0.3));
        }
    }

    #[test]
    fn test_compute_weights_long_query() {
        for tokens in 5..=8 {
            assert_eq!(compute_weights(tokens), (0.4, 0.6));
        }
    }

    #[test]
    fn test_compute_weights_very_long_query() {
        for tokens in [9, 100] {
            assert_eq!(compute_weights(tokens), (0.2, 0.8));
        }
    }

    #[test]
    fn test_normalize_scores_basic() {
        let scores = vec![
            ("a".to_string(), 0.0),
            ("b".to_string(), 0.5),
            ("c".to_string(), 1.0),
        ];
        let normalized = normalize_scores(&scores);

        // Ids and input order are preserved.
        let ids: Vec<&str> = normalized.iter().map(|(id, _)| id.as_str()).collect();
        assert_eq!(ids, ["a", "b", "c"]);

        assert_close(normalized[0].1, 0.0);
        assert_close(normalized[1].1, 0.5);
        assert_close(normalized[2].1, 1.0);
    }

    #[test]
    fn test_normalize_scores_empty() {
        let normalized = normalize_scores(&[]);
        assert!(normalized.is_empty());
    }

    #[test]
    fn test_normalize_scores_equal() {
        let scores = vec![("a".to_string(), 5.0), ("b".to_string(), 5.0)];
        let normalized = normalize_scores(&scores);
        assert_eq!(normalized.len(), 2);
        assert_close(normalized[0].1, 1.0);
        assert_close(normalized[1].1, 1.0);
    }

    #[test]
    fn test_merge_results_basic() {
        let bm25 = vec![("doc1".to_string(), 1.0), ("doc2".to_string(), 0.5)];
        let hdc = vec![("doc1".to_string(), 0.5), ("doc3".to_string(), 1.0)];
        let merged = merge_results(&bm25, &hdc, (0.5, 0.5));

        // Every distinct id appears exactly once.
        assert_eq!(merged.len(), 3);

        // After per-list normalization: bm25 -> doc1 = 1, doc2 = 0;
        // hdc -> doc1 = 0, doc3 = 1.
        assert_close(score_of(&merged, "doc1"), 0.5);
        assert_close(score_of(&merged, "doc2"), 0.0);
        assert_close(score_of(&merged, "doc3"), 0.5);

        // doc1 and doc3 tie, so only the position of the strictly lowest entry
        // is pinned here.
        assert_descending(&merged);
        assert_eq!(merged[2].0, "doc2");
    }

    #[test]
    fn test_merge_results_weighted() {
        // A single-entry list normalizes to 1.0, so the combined score is the
        // sum of both weights.
        let bm25 = vec![("doc1".to_string(), 1.0)];
        let hdc = vec![("doc1".to_string(), 1.0)];
        let merged = merge_results(&bm25, &hdc, (0.9, 0.1));
        assert_eq!(merged.len(), 1);
        assert_close(score_of(&merged, "doc1"), 1.0);

        // The two lists disagree on the winner; the heavier weight must decide.
        let bm25 = vec![("a".to_string(), 2.0), ("b".to_string(), 1.0)];
        let hdc = vec![("a".to_string(), 1.0), ("b".to_string(), 2.0)];

        let keyword_heavy = merge_results(&bm25, &hdc, (0.9, 0.1));
        assert_eq!(keyword_heavy[0].0, "a");
        assert_eq!(keyword_heavy[1].0, "b");
        assert_close(keyword_heavy[0].1, 0.9);
        assert_close(keyword_heavy[1].1, 0.1);

        let semantic_heavy = merge_results(&bm25, &hdc, (0.1, 0.9));
        assert_eq!(semantic_heavy[0].0, "b");
        assert_eq!(semantic_heavy[1].0, "a");
        assert_close(semantic_heavy[0].1, 0.9);
        assert_close(semantic_heavy[1].1, 0.1);
    }

    #[test]
    fn test_merge_results_empty() {
        let merged = merge_results(&[], &[], (0.5, 0.5));
        assert!(merged.is_empty());

        let merged = merge_results(&[("a".to_string(), 1.0)], &[], (0.5, 0.5));
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].0, "a");
        assert_close(merged[0].1, 0.5);

        let merged = merge_results(&[], &[("a".to_string(), 1.0)], (0.5, 0.5));
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].0, "a");
        assert_close(merged[0].1, 0.5);
    }
}