use std::collections::{HashMap, HashSet};
use serde::{Deserialize, Serialize};
use crate::embeddings::EmbeddingProvider;
use crate::storage::Storage;
use crate::types::MemoryRecord;
/// One ranked hit produced by [`hybrid_search`], [`adaptive_hybrid_search`] or
/// [`reciprocal_rank_fusion`]; each of those returns these sorted by `score`, highest first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridSearchResult {
/// Id of the matched item: an FTS record id and/or an embedding id from storage.
pub id: String,
/// Final ranking score. For the weighted searches it is a weighted sum of
/// `vector_score` and `fts_score`; for `reciprocal_rank_fusion` it is the RRF sum.
pub score: f64,
/// Vector-side component as computed by the producing function; 0.0 when the id had no
/// vector score.
pub vector_score: f64,
/// Full-text-search component (derived from the FTS rank by the producing function); 0.0
/// when the id was not among the FTS hits.
pub fts_score: f64,
/// Full record loaded via `Storage::get`; `None` when loading was skipped or the lookup
/// returned nothing.
pub record: Option<MemoryRecord>,
}
/// Tuning knobs for [`hybrid_search`].
#[derive(Debug, Clone)]
pub struct HybridSearchOpts {
/// Multiplier applied to the vector score when combining; not normalized against `fts_weight`.
pub vector_weight: f64,
/// Multiplier applied to the full-text-search score when combining.
pub fts_weight: f64,
/// Maximum number of results returned. `hybrid_search` also fetches `3 * limit` FTS candidates.
pub limit: usize,
/// Namespace forwarded to the storage queries. Presumably `None` means "no namespace
/// filter" — TODO confirm against `Storage`.
pub namespace: Option<String>,
/// When true, each returned result's `record` is loaded via `Storage::get`.
pub include_records: bool,
}
impl Default for HybridSearchOpts {
fn default() -> Self {
Self {
vector_weight: 0.7,
fts_weight: 0.3,
limit: 10,
namespace: None,
include_records: true,
}
}
}
/// Combines full-text and (optional) vector similarity scores with the fixed weights in
/// `opts` and returns the best `opts.limit` hits.
///
/// * FTS: up to `3 * opts.limit` hits are fetched from `opts.namespace`; the hit at 0-based
///   position `i` scores `1 - i / fetch_limit`, so the first hit scores 1.0.
/// * Vector: when `query_vector` is `Some`, every embedding stored for `model` in the
///   namespace is compared with it, and its cosine similarity is mapped from `[-1, 1]` to
///   `[0, 1]`. Non-finite similarities (NaN/inf) carry no ranking signal and are skipped.
/// * Combined: `opts.vector_weight * vector + opts.fts_weight * fts`. An id missing from one
///   side contributes 0.0 for that side. The weights are used as given, not normalized.
///
/// Results are sorted by descending score, with ties broken by ascending id so the order is
/// deterministic. When `opts.include_records` is set, each returned hit is loaded via
/// `storage.get`.
///
/// # Errors
/// Propagates any error from `search_fts_ns`, `get_embeddings_in_namespace` or `get`.
pub fn hybrid_search(
    storage: &Storage,
    query_vector: Option<&[f32]>,
    query_text: &str,
    opts: HybridSearchOpts,
    model: &str,
) -> Result<Vec<HybridSearchResult>, Box<dyn std::error::Error>> {
    let ns = opts.namespace.as_deref();
    // Over-fetch FTS candidates so fusion has room to re-rank. Saturate rather than overflow
    // (which panics in debug builds) for absurdly large limits.
    let fetch_limit = opts.limit.saturating_mul(3);
    let fts_results = storage.search_fts_ns(query_text, fetch_limit, ns)?;
    let fts_count = fts_results.len();
    let fts_scores: HashMap<String, f64> = fts_results
        .iter()
        .enumerate()
        .map(|(rank, record)| {
            let score = 1.0 - (rank as f64 / fetch_limit.max(1) as f64);
            (record.id.clone(), score)
        })
        .collect();
    let vector_scores: HashMap<String, f64> = if let Some(qvec) = query_vector {
        let embeddings = storage.get_embeddings_in_namespace(ns, model)?;
        embeddings
            .iter()
            .filter_map(|(id, emb)| {
                let sim = EmbeddingProvider::cosine_similarity(qvec, emb);
                // A NaN/inf similarity (e.g. from a degenerate embedding, depending on
                // cosine_similarity's implementation) would poison the ranking; skip it.
                sim.is_finite()
                    .then(|| (id.clone(), ((sim + 1.0) / 2.0) as f64))
            })
            .collect()
    } else {
        HashMap::new()
    };
    let all_ids: HashSet<String> = fts_scores
        .keys()
        .chain(vector_scores.keys())
        .cloned()
        .collect();
    let mut results: Vec<HybridSearchResult> = all_ids
        .into_iter()
        .map(|id| {
            let vs = vector_scores.get(&id).copied().unwrap_or(0.0);
            let fs = fts_scores.get(&id).copied().unwrap_or(0.0);
            let score = opts.vector_weight * vs + opts.fts_weight * fs;
            HybridSearchResult {
                id,
                score,
                vector_score: vs,
                fts_score: fs,
                record: None,
            }
        })
        .collect();
    // `total_cmp` is a true total order (sort_by may panic on inconsistent comparators), and
    // the id tie-break removes HashSet iteration order from the output.
    results.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
    results.truncate(opts.limit);
    if opts.include_records {
        for result in &mut results {
            result.record = storage.get(&result.id)?;
        }
    }
    log::debug!(
        "Hybrid search: {} FTS results, {} vector results, {} combined",
        fts_count,
        vector_scores.len(),
        results.len()
    );
    Ok(results)
}
/// Hybrid search whose vector/FTS weights adapt to how much the two retrievers agree.
///
/// FTS and vector search each contribute up to `3 * limit` candidates; `None` is passed as
/// the namespace to storage. The Jaccard overlap of the two candidate-id sets selects the
/// weights (vector/FTS):
/// * `>= 0.6` → 0.5/0.5
/// * `>= 0.3` → 0.6/0.4
/// * otherwise → 0.7/0.3
///
/// With no vector candidates the search is FTS-only (0.0/1.0).
///
/// Scoring matches [`hybrid_search`]:
/// * FTS hit `i` (0-based) scores `1 - i / fetch_limit`.
/// * Vector hits map cosine similarity from `[-1, 1]` to `[0, 1]`.
/// * Non-finite similarities are skipped.
///
/// Candidates whose vector and FTS scores are both exactly 0.0 are dropped. Results are
/// sorted by descending score, with ties broken by ascending id. Every returned hit has its
/// `record` loaded via `storage.get`.
///
/// # Errors
/// Propagates any error from the storage calls.
pub fn adaptive_hybrid_search(
    storage: &Storage,
    query_vector: Option<&[f32]>,
    query_text: &str,
    limit: usize,
    model: &str,
) -> Result<Vec<HybridSearchResult>, Box<dyn std::error::Error>> {
    // Saturate rather than overflow (a debug-build panic) for absurdly large limits.
    let fetch_limit = limit.saturating_mul(3);
    let fts_results = storage.search_fts_ns(query_text, fetch_limit, None)?;
    let fts_ids: HashSet<String> = fts_results.iter().map(|r| r.id.clone()).collect();
    let vector_scores: HashMap<String, f64> = if let Some(qvec) = query_vector {
        let embeddings = storage.get_embeddings_in_namespace(None, model)?;
        let mut scores: Vec<(String, f64)> = embeddings
            .iter()
            .filter_map(|(id, emb)| {
                let sim = EmbeddingProvider::cosine_similarity(qvec, emb);
                // NaN/inf similarities carry no ranking signal and would break the sort.
                sim.is_finite()
                    .then(|| (id.clone(), ((sim + 1.0) / 2.0) as f64))
            })
            .collect();
        // Keep only the top `fetch_limit` vector hits, mirroring the FTS fetch size. The
        // stable sort keeps storage order among equal similarities.
        scores.sort_by(|a, b| b.1.total_cmp(&a.1));
        scores.truncate(fetch_limit);
        scores.into_iter().collect()
    } else {
        HashMap::new()
    };
    let vector_ids: HashSet<String> = vector_scores.keys().cloned().collect();
    // Computed once; used both for the weights and for the debug log.
    let overlap = fts_ids.intersection(&vector_ids).count();
    let (vector_weight, fts_weight) = if vector_ids.is_empty() {
        // No vector signal at all: rank purely by FTS.
        (0.0, 1.0)
    } else {
        // `vector_ids` is non-empty, so the union is too.
        let union = fts_ids.union(&vector_ids).count();
        let jaccard = overlap as f64 / union as f64;
        // The more the retrievers agree, the more evenly they are weighted.
        if jaccard >= 0.6 {
            (0.5, 0.5)
        } else if jaccard >= 0.3 {
            (0.6, 0.4)
        } else {
            (0.7, 0.3)
        }
    };
    log::debug!(
        "Adaptive weights: vector={:.2}, fts={:.2} (overlap={})",
        vector_weight,
        fts_weight,
        overlap
    );
    let fts_scores: HashMap<String, f64> = fts_results
        .iter()
        .enumerate()
        .map(|(rank, record)| {
            let score = 1.0 - (rank as f64 / fetch_limit.max(1) as f64);
            (record.id.clone(), score)
        })
        .collect();
    let all_ids: HashSet<String> = fts_scores
        .keys()
        .chain(vector_scores.keys())
        .cloned()
        .collect();
    let mut results: Vec<HybridSearchResult> = all_ids
        .into_iter()
        .filter_map(|id| {
            let vs = vector_scores.get(&id).copied().unwrap_or(0.0);
            let fs = fts_scores.get(&id).copied().unwrap_or(0.0);
            if vs == 0.0 && fs == 0.0 {
                return None;
            }
            Some(HybridSearchResult {
                score: vector_weight * vs + fts_weight * fs,
                id,
                vector_score: vs,
                fts_score: fs,
                record: None,
            })
        })
        .collect();
    // Total order plus id tie-break: deterministic output, no comparator panics.
    results.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
    results.truncate(limit);
    for result in &mut results {
        result.record = storage.get(&result.id)?;
    }
    Ok(results)
}
/// Fuses FTS and vector rankings with Reciprocal Rank Fusion (RRF).
///
/// Each id scores `Σ 1 / (k + rank)` over the lists it appears in, using 1-based ranks:
/// * FTS contributes its top `3 * limit` hits (`None` is passed as the namespace).
/// * Vector search, when `query_vector` is `Some`, ranks every embedding stored for `model`
///   by cosine similarity. Non-finite similarities are skipped.
///
/// `k` damps the advantage of the top positions. It is expected to be positive; `k + rank`
/// must never be zero.
///
/// The reported `vector_score` / `fts_score` are informational only:
/// * A rank is mapped onto `[0, 1]` as `1 - (rank - 1) / fetch_limit`, so the top hit
///   scores 1.0, the same rank normalization [`hybrid_search`] applies to FTS hits.
/// * The value is clamped at 0.0, because vector ranks can exceed `fetch_limit`.
/// * It is 0.0 when the id is absent from that list.
///
/// Results are sorted by descending RRF score, with ties broken by ascending id, and every
/// returned hit has its `record` loaded via `storage.get`.
///
/// # Errors
/// Propagates any error from the storage calls.
pub fn reciprocal_rank_fusion(
    storage: &Storage,
    query_vector: Option<&[f32]>,
    query_text: &str,
    limit: usize,
    k: f64,
    model: &str,
) -> Result<Vec<HybridSearchResult>, Box<dyn std::error::Error>> {
    // Saturate rather than overflow (a debug-build panic) for absurdly large limits.
    let fetch_limit = limit.saturating_mul(3);
    let fts_results = storage.search_fts_ns(query_text, fetch_limit, None)?;
    let fts_ranks: HashMap<String, usize> = fts_results
        .iter()
        .enumerate()
        .map(|(rank, r)| (r.id.clone(), rank + 1))
        .collect();
    let vector_ranks: HashMap<String, usize> = if let Some(qvec) = query_vector {
        let embeddings = storage.get_embeddings_in_namespace(None, model)?;
        let mut scored: Vec<(String, f64)> = embeddings
            .iter()
            .filter_map(|(id, emb)| {
                let sim = EmbeddingProvider::cosine_similarity(qvec, emb);
                // NaN/inf similarities cannot be ranked meaningfully.
                sim.is_finite().then(|| (id.clone(), sim as f64))
            })
            .collect();
        // Stable sort: equal similarities keep storage order, as before.
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
        scored
            .into_iter()
            .enumerate()
            .map(|(rank, (id, _))| (id, rank + 1))
            .collect()
    } else {
        HashMap::new()
    };
    // Reciprocal-rank contribution of a single 1-based rank.
    let rrf = |rank: usize| 1.0 / (k + rank as f64);
    // Maps a 1-based rank onto [0, 1] with the top hit at 1.0. Vector ranks are not capped at
    // `fetch_limit`, so clamp to keep the reported score non-negative. `max(1)` avoids a
    // division by zero when `limit == 0`.
    let normalize = |rank: usize| {
        (1.0_f64 - (rank - 1) as f64 / fetch_limit.max(1) as f64).max(0.0)
    };
    let all_ids: HashSet<String> = fts_ranks
        .keys()
        .chain(vector_ranks.keys())
        .cloned()
        .collect();
    let mut results: Vec<HybridSearchResult> = all_ids
        .into_iter()
        .map(|id| {
            let fts_rank = fts_ranks.get(&id).copied();
            let vector_rank = vector_ranks.get(&id).copied();
            let rrf_score = fts_rank.map_or(0.0, rrf) + vector_rank.map_or(0.0, rrf);
            HybridSearchResult {
                id,
                score: rrf_score,
                vector_score: vector_rank.map_or(0.0, normalize),
                fts_score: fts_rank.map_or(0.0, normalize),
                record: None,
            }
        })
        .collect();
    // Total order plus id tie-break: deterministic output, no comparator panics.
    results.sort_by(|a, b| b.score.total_cmp(&a.score).then_with(|| a.id.cmp(&b.id)));
    results.truncate(limit);
    for result in &mut results {
        result.record = storage.get(&result.id)?;
    }
    Ok(results)
}
/// Jaccard index `|A ∩ B| / |A ∪ B|` of two string sets, in `[0, 1]`.
///
/// Two empty sets are considered identical and yield 1.0.
pub fn jaccard_similarity(set_a: &HashSet<String>, set_b: &HashSet<String>) -> f64 {
    let union_size = set_a.union(set_b).count();
    // The union is empty exactly when both inputs are empty.
    if union_size == 0 {
        return 1.0;
    }
    let shared = set_a.iter().filter(|item| set_b.contains(*item)).count();
    shared as f64 / union_size as f64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned string set from string literals.
    fn set_of(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts `actual` is within 0.01 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected ~{}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_jaccard_similarity() {
        let overlap = jaccard_similarity(&set_of(&["1", "2", "3"]), &set_of(&["2", "3", "4"]));
        assert_close(overlap, 0.5);
    }

    #[test]
    fn test_jaccard_identical() {
        let same = set_of(&["1", "2", "3"]);
        assert_close(jaccard_similarity(&same, &same.clone()), 1.0);
    }

    #[test]
    fn test_jaccard_disjoint() {
        let overlap = jaccard_similarity(&set_of(&["1", "2"]), &set_of(&["3", "4"]));
        assert_close(overlap, 0.0);
    }

    #[test]
    fn test_jaccard_empty() {
        assert_close(jaccard_similarity(&set_of(&[]), &set_of(&[])), 1.0);
    }

    #[test]
    fn test_hybrid_search_opts_default() {
        let HybridSearchOpts {
            vector_weight,
            fts_weight,
            limit,
            include_records,
            ..
        } = HybridSearchOpts::default();
        assert_close(vector_weight, 0.7);
        assert_close(fts_weight, 0.3);
        assert_eq!(limit, 10);
        assert!(include_records);
    }
}