use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single hybrid-search request combining keyword and semantic signals.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridQuery {
/// Free-text query string.
pub query_text: String,
/// Optional pre-computed embedding of the query; `None` when no vector is
/// available (presumably keyword-only search then — confirm with the caller).
pub query_vector: Option<Vec<f32>>,
/// Maximum number of results requested.
pub top_k: usize,
/// Relative weighting of the keyword / semantic / recency signals.
pub weights: SearchWeights,
/// Key/value filters applied to results; exact matching semantics are
/// defined by the search backend, not by this type.
pub filters: HashMap<String, String>,
}
/// Blend weights for combining the three ranking signals.
///
/// Intended to sum to approximately 1.0 — see [`SearchWeights::validate`]
/// and [`SearchWeights::normalize`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct SearchWeights {
/// Weight applied to the keyword (lexical) score.
pub keyword_weight: f32,
/// Weight applied to the semantic (vector) score.
pub semantic_weight: f32,
/// Weight applied to the recency score.
pub recency_weight: f32,
}
impl Default for SearchWeights {
fn default() -> Self {
Self {
keyword_weight: 0.3,
semantic_weight: 0.7,
recency_weight: 0.0,
}
}
}
impl SearchWeights {
    /// Checks that the weights form a usable blend.
    ///
    /// # Errors
    ///
    /// Returns an error if any weight is non-finite (NaN or infinity), if any
    /// weight is negative, or if the weights do not sum to approximately 1.0
    /// (tolerance 0.1).
    pub fn validate(&self) -> anyhow::Result<()> {
        let weights = [
            self.keyword_weight,
            self.semantic_weight,
            self.recency_weight,
        ];
        // Reject NaN/inf explicitly. The original comparison-based checks let
        // NaN slip through, because every comparison involving NaN is false:
        // `(sum - 1.0).abs() > 0.1` and `w < 0.0` both evaluate to false.
        if weights.iter().any(|w| !w.is_finite()) {
            anyhow::bail!("Search weights must be finite numbers");
        }
        // Check per-weight sign before the aggregate sum so the more specific
        // error is reported (e.g. {1.5, -0.5, 0.0} sums to 1.0 but is invalid).
        if weights.iter().any(|&w| w < 0.0) {
            anyhow::bail!("Search weights must be non-negative");
        }
        let sum: f32 = weights.iter().sum();
        if (sum - 1.0).abs() > 0.1 {
            anyhow::bail!(
                "Search weights should sum to approximately 1.0, got {}",
                sum
            );
        }
        Ok(())
    }

    /// Rescales the weights in place so they sum to 1.0.
    ///
    /// A non-positive (or NaN) sum leaves the weights untouched, which avoids
    /// division by zero; callers wanting stricter guarantees should call
    /// [`SearchWeights::validate`] afterwards.
    pub fn normalize(&mut self) {
        let sum = self.keyword_weight + self.semantic_weight + self.recency_weight;
        if sum > 0.0 {
            self.keyword_weight /= sum;
            self.semantic_weight /= sum;
            self.recency_weight /= sum;
        }
    }
}
/// One scored document returned by a hybrid search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridResult {
/// Identifier of the matched document.
pub doc_id: String,
/// Final combined score (weighted sum of the component scores —
/// see [`HybridResult::new`]).
pub score: f32,
/// Per-signal scores and ranks that produced `score`.
pub score_breakdown: ScoreBreakdown,
/// Arbitrary key/value metadata attached to the result; starts empty
/// when constructed via [`HybridResult::new`].
pub metadata: HashMap<String, String>,
}
/// Per-signal components behind a [`HybridResult`]'s combined score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreBreakdown {
/// Raw keyword (lexical) score before weighting.
pub keyword_score: f32,
/// Raw semantic (vector) score before weighting.
pub semantic_score: f32,
/// Raw recency score before weighting.
pub recency_score: f32,
/// Position in the keyword-only ranking, if the document appeared there.
pub keyword_rank: Option<usize>,
/// Position in the semantic-only ranking, if the document appeared there.
pub semantic_rank: Option<usize>,
}
/// A document's score and position within a single ranking
/// (presumably an intermediate value before fusion — confirm with callers).
#[derive(Debug, Clone)]
pub struct DocumentScore {
/// Identifier of the scored document.
pub doc_id: String,
/// Score within this ranking.
pub score: f32,
/// Position within this ranking.
pub rank: usize,
}
/// A keyword-search hit with details on which terms matched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeywordMatch {
/// Identifier of the matched document.
pub doc_id: String,
/// Keyword relevance score for this document.
pub score: f32,
/// Query terms that matched in the document.
pub matched_terms: Vec<String>,
/// Occurrence count per matched term (term -> frequency).
pub term_frequencies: HashMap<String, usize>,
}
impl HybridResult {
    /// Builds a result for `doc_id`, combining the three component scores
    /// into one final score as a weighted linear sum using `weights`.
    ///
    /// The breakdown records the raw (unweighted) component scores; both
    /// ranks start as `None` and `metadata` starts empty.
    pub fn new(
        doc_id: String,
        keyword_score: f32,
        semantic_score: f32,
        recency_score: f32,
        weights: &SearchWeights,
    ) -> Self {
        // Weighted linear combination of the individual signals.
        let combined = weights.keyword_weight * keyword_score
            + weights.semantic_weight * semantic_score
            + weights.recency_weight * recency_score;
        let breakdown = ScoreBreakdown {
            keyword_score,
            semantic_score,
            recency_score,
            keyword_rank: None,
            semantic_rank: None,
        };
        Self {
            doc_id,
            score: combined,
            score_breakdown: breakdown,
            metadata: HashMap::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_search_weights_validation() {
        // A blend that sums to exactly 1.0 passes validation.
        let ok = SearchWeights {
            keyword_weight: 0.3,
            semantic_weight: 0.7,
            recency_weight: 0.0,
        };
        assert!(ok.validate().is_ok());

        // 0.5 + 0.8 = 1.3 is outside the 0.1 tolerance and must be rejected.
        let off = SearchWeights {
            keyword_weight: 0.5,
            semantic_weight: 0.8,
            recency_weight: 0.0,
        };
        assert!(off.validate().is_err());
    }

    #[test]
    fn test_weights_normalization() {
        // 1 + 2 + 1 = 4, so normalizing should yield 0.25 / 0.5 / 0.25.
        let mut w = SearchWeights {
            keyword_weight: 1.0,
            semantic_weight: 2.0,
            recency_weight: 1.0,
        };
        w.normalize();
        for (actual, expected) in [
            (w.keyword_weight, 0.25f32),
            (w.semantic_weight, 0.5),
            (w.recency_weight, 0.25),
        ] {
            assert!((actual - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_hybrid_result_scoring() {
        let w = SearchWeights {
            keyword_weight: 0.4,
            semantic_weight: 0.6,
            recency_weight: 0.0,
        };
        let res = HybridResult::new("doc1".to_string(), 0.8, 0.9, 0.0, &w);
        // Combined score is the weighted sum of the component scores.
        let want = 0.8f32 * 0.4 + 0.9 * 0.6;
        assert!((res.score - want).abs() < 0.001);
    }
}