use anyhow::Result;
use std::fmt;
/// Pair of weights used to blend a vector score with an FTS (keyword) score.
///
/// Both fields are public, so a value built with a struct literal bypasses the
/// rescaling done by `HybridSearchConfig::new`; call
/// `HybridSearchConfig::validate` to check such a value.
#[derive(Clone, Copy, Debug)]
pub struct HybridSearchConfig {
    /// Multiplier applied to an item's vector score.
    pub vector_weight: f32,
    /// Multiplier applied to an item's FTS (keyword) score.
    pub fts_weight: f32,
}
impl HybridSearchConfig {
    /// Builds a configuration from two relative weights, rescaled so that they
    /// sum to `1.0` (for example `new(2.0, 1.0)` stores roughly `0.667` and
    /// `0.333`).
    ///
    /// If the sum of the inputs is zero or not finite, dividing by it could
    /// only produce NaN weights. In that case the inputs are stored unchanged,
    /// so [`validate`](Self::validate) still rejects the configuration but no
    /// NaN leaks into score computations.
    ///
    /// Negative inputs are not rejected here; use [`validate`](Self::validate)
    /// to check the result.
    #[must_use]
    pub fn new(vector_weight: f32, fts_weight: f32) -> Self {
        let total = vector_weight + fts_weight;
        // A zero, NaN or infinite sum cannot be normalized meaningfully.
        if total == 0.0 || !total.is_finite() {
            return Self {
                vector_weight,
                fts_weight,
            };
        }
        Self {
            vector_weight: vector_weight / total,
            fts_weight: fts_weight / total,
        }
    }

    /// Returns the configuration built from the weights `0.7` (vector) and
    /// `0.3` (FTS).
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(0.7, 0.3)
    }

    /// Returns a configuration that scores by the vector score alone.
    #[must_use]
    pub fn vector_only() -> Self {
        Self {
            vector_weight: 1.0,
            fts_weight: 0.0,
        }
    }

    /// Returns a configuration that scores by the FTS (keyword) score alone.
    #[must_use]
    pub fn keyword_only() -> Self {
        Self {
            vector_weight: 0.0,
            fts_weight: 1.0,
        }
    }

    /// Checks that both weights lie in `0.0..=1.0` and that they sum to `1.0`
    /// within a tolerance of `0.0001`.
    ///
    /// # Errors
    ///
    /// Returns an error describing the first violated rule. NaN weights fail
    /// the range check because NaN is contained in no range.
    pub fn validate(&self) -> Result<()> {
        if !(0.0..=1.0).contains(&self.vector_weight) {
            anyhow::bail!("Vector weight must be between 0.0 and 1.0");
        }
        if !(0.0..=1.0).contains(&self.fts_weight) {
            anyhow::bail!("FTS weight must be between 0.0 and 1.0");
        }
        let sum = self.vector_weight + self.fts_weight;
        if (sum - 1.0).abs() > 0.0001 {
            anyhow::bail!(
                "Weights must sum to 1.0 (got {} + {} = {})",
                self.vector_weight,
                self.fts_weight,
                sum
            );
        }
        Ok(())
    }
}
impl Default for HybridSearchConfig {
fn default() -> Self {
Self::default_config()
}
}
/// A single search hit carrying its blended score and both source scores.
#[derive(Clone, Debug)]
pub struct HybridSearchResult<T> {
    /// The matched item.
    pub item: T,
    /// Blended score; `HybridSearchResult::new` computes it as the weighted
    /// sum of `vector_score` and `fts_score`.
    pub hybrid_score: f32,
    /// Score taken from the vector result list.
    pub vector_score: f32,
    /// Score taken from the FTS result list.
    pub fts_score: f32,
}
impl<T> HybridSearchResult<T> {
    /// Wraps `item` with its two source scores and the blended score
    /// `config.vector_weight * vector_score + config.fts_weight * fts_score`.
    ///
    /// The configuration is used as given; it is not validated here.
    #[must_use]
    pub fn new(item: T, vector_score: f32, fts_score: f32, config: &HybridSearchConfig) -> Self {
        let weighted_vector = config.vector_weight * vector_score;
        let weighted_fts = config.fts_weight * fts_score;
        Self {
            item,
            hybrid_score: weighted_vector + weighted_fts,
            vector_score,
            fts_score,
        }
    }
}
/// Search engine front-end that merges a vector result list and an FTS result
/// list into one ranked list, weighting scores with its `HybridSearchConfig`.
pub struct HybridSearch {
// Private: read through `config()` and replaced through `update_config()`.
config: HybridSearchConfig,
}
impl HybridSearch {
    /// Creates an engine that uses `HybridSearchConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config: HybridSearchConfig::default(),
        }
    }

    /// Creates an engine with a caller-supplied configuration.
    ///
    /// # Errors
    ///
    /// Returns the error from `HybridSearchConfig::validate` when `config` is
    /// rejected.
    pub fn with_config(config: HybridSearchConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Merges vector and FTS result lists into one list ranked by hybrid score.
    ///
    /// Items are matched across the two lists by `Eq` + `Hash`. An item that
    /// is missing from one list gets `0.0` for that list's score. If an item
    /// occurs several times within one list, the last score given for it wins.
    ///
    /// The result is ordered by descending `hybrid_score` and holds at most
    /// `limit` entries. Ordering is deterministic: NaN scores sort last, and
    /// items with equal scores keep first-seen order (all of `vector_results`
    /// in input order, then items that only occur in `fts_results`).
    ///
    /// Items are moved into the results, never cloned; the `Clone` bound is
    /// kept only so the signature stays compatible with existing callers.
    #[must_use]
    pub fn search_episodes<T>(
        &self,
        vector_results: Vec<(T, f32)>,
        fts_results: Vec<(T, f32)>,
        limit: usize,
    ) -> Vec<HybridSearchResult<T>>
    where
        T: Clone + PartialEq + Eq + std::hash::Hash,
    {
        let upper_bound = vector_results.len() + fts_results.len();

        // `slots` maps each distinct item to its first-seen position;
        // `scores[position]` holds that item's `(vector_score, fts_score)`.
        let mut slots: std::collections::HashMap<T, usize> =
            std::collections::HashMap::with_capacity(upper_bound);
        let mut scores: Vec<(f32, f32)> = Vec::with_capacity(upper_bound);

        let tagged = vector_results
            .into_iter()
            .map(|(item, score)| (item, score, true))
            .chain(
                fts_results
                    .into_iter()
                    .map(|(item, score)| (item, score, false)),
            );
        for (item, score, from_vector) in tagged {
            let next_slot = scores.len();
            let slot = *slots.entry(item).or_insert(next_slot);
            if slot == next_slot {
                // Every existing slot is < scores.len(), so this item is new.
                scores.push((0.0, 0.0));
            }
            if from_vector {
                scores[slot].0 = score;
            } else {
                scores[slot].1 = score;
            }
        }

        // Map iteration order is arbitrary, so carry each item's slot along
        // and use it as the final tie-breaker.
        let mut ranked: Vec<(usize, HybridSearchResult<T>)> = slots
            .into_iter()
            .map(|(item, slot)| {
                let (vector_score, fts_score) = scores[slot];
                (
                    slot,
                    HybridSearchResult::new(item, vector_score, fts_score, &self.config),
                )
            })
            .collect();

        // Separating NaN first makes `partial_cmp` total on what remains;
        // the unique slot then breaks every tie, so an unstable sort is safe.
        ranked.sort_unstable_by(|(slot_a, a), (slot_b, b)| {
            a.hybrid_score
                .is_nan()
                .cmp(&b.hybrid_score.is_nan())
                .then_with(|| {
                    b.hybrid_score
                        .partial_cmp(&a.hybrid_score)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .then_with(|| slot_a.cmp(slot_b))
        });
        ranked.truncate(limit);
        ranked.into_iter().map(|(_, hit)| hit).collect()
    }

    /// Returns the configuration currently in use.
    #[must_use]
    pub fn config(&self) -> &HybridSearchConfig {
        &self.config
    }

    /// Replaces the configuration after validating it.
    ///
    /// # Errors
    ///
    /// Returns the error from `HybridSearchConfig::validate` when `config` is
    /// rejected; the previous configuration is left in place in that case.
    pub fn update_config(&mut self, config: HybridSearchConfig) -> Result<()> {
        config.validate()?;
        self.config = config;
        Ok(())
    }
}
impl Default for HybridSearch {
fn default() -> Self {
Self::new()
}
}
/// Formats the engine as a `HybridSearch` struct with its single `config` field.
impl fmt::Debug for HybridSearch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("HybridSearch");
        builder.field("config", &self.config);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` is within the tolerance `0.001` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    /// Weights that already sum to one come out of `new` unchanged.
    #[test]
    fn test_hybrid_search_config() {
        let cfg = HybridSearchConfig::new(0.7, 0.3);
        assert!(close(cfg.vector_weight, 0.7));
        assert!(close(cfg.fts_weight, 0.3));
        assert!(cfg.validate().is_ok());
    }

    /// Weights that do not sum to one are rescaled proportionally.
    #[test]
    fn test_hybrid_search_config_normalization() {
        let cfg = HybridSearchConfig::new(2.0, 1.0);
        assert!(close(cfg.vector_weight, 0.666_666_7));
        assert!(close(cfg.fts_weight, 0.333_333_3));
        assert!(cfg.validate().is_ok());
    }

    /// Out-of-range weights given through a struct literal are rejected.
    #[test]
    fn test_hybrid_search_config_validation() {
        let out_of_range = HybridSearchConfig {
            vector_weight: 1.5,
            fts_weight: -0.5,
        };
        assert!(out_of_range.validate().is_err());
    }

    /// A result stores both source scores and their weighted sum.
    #[test]
    fn test_hybrid_search_result() {
        let cfg = HybridSearchConfig::new(0.7, 0.3);
        let payload = "test item".to_string();
        let hit = HybridSearchResult::new(payload.clone(), 0.8, 0.6, &cfg);
        assert!(close(hit.hybrid_score, 0.7 * 0.8 + 0.3 * 0.6));
        assert!(close(hit.vector_score, 0.8));
        assert!(close(hit.fts_score, 0.6));
        assert_eq!(hit.item, payload);
    }

    /// A freshly constructed engine carries a valid configuration.
    #[test]
    fn test_hybrid_search_engine() {
        let search = HybridSearch::new();
        let cfg = search.config();
        assert!(cfg.validate().is_ok());
        assert!(close(cfg.vector_weight + cfg.fts_weight, 1.0));
    }

    /// Items from both lists are merged and ranked by blended score.
    #[test]
    fn test_search_episodes() {
        let search = HybridSearch::with_config(HybridSearchConfig::new(0.5, 0.5)).unwrap();
        let from_vector = vec![("item1".to_string(), 0.9), ("item2".to_string(), 0.7)];
        let from_fts = vec![("item2".to_string(), 0.8), ("item3".to_string(), 0.6)];

        let hits = search.search_episodes(from_vector, from_fts, 5);

        let expected: [(&str, f32); 3] = [("item2", 0.75), ("item1", 0.45), ("item3", 0.30)];
        assert_eq!(hits.len(), 3);
        for (hit, &(name, score)) in hits.iter().zip(expected.iter()) {
            assert_eq!(hit.item, name);
            assert!(close(hit.hybrid_score, score));
        }
    }

    /// `limit` keeps only the best-scoring items.
    #[test]
    fn test_search_episodes_with_limit() {
        let search = HybridSearch::with_config(HybridSearchConfig::new(0.5, 0.5)).unwrap();
        let from_vector = vec![
            ("item1".to_string(), 0.9),
            ("item2".to_string(), 0.7),
            ("item3".to_string(), 0.5),
        ];
        let from_fts = vec![
            ("item1".to_string(), 0.1),
            ("item2".to_string(), 0.8),
            ("item3".to_string(), 0.6),
        ];

        let hits = search.search_episodes(from_vector, from_fts, 2);

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].item, "item2");
        assert_eq!(hits[1].item, "item3");
    }
}