#![allow(clippy::cast_precision_loss)]
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use tokio::sync::RwLock;
use crate::error::VectorStoreError;
use crate::types::DocumentId;
/// A sparse vector stored as two parallel arrays.
///
/// `values[i]` is the entry at position `indices[i]`; positions that are not
/// listed are treated as zero by `to_dense`. `dot` walks both operands with a
/// two-pointer merge, so it only pairs entries up correctly when `indices` is in
/// ascending order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseVector {
/// Positions of the stored entries.
pub indices: Vec<usize>,
/// Stored entries, parallel to `indices`.
pub values: Vec<f32>,
/// Length of the dense vector produced by `to_dense`.
pub dimension: usize,
}
impl SparseVector {
    /// Builds a sparse vector from parallel index/value lists.
    ///
    /// Entries are stored in ascending index order. If `indices` is not already
    /// sorted, the index/value pairs are reordered together (stable sort, so
    /// entries sharing an index keep their relative order). [`SparseVector::dot`]
    /// depends on this ordering; previously unsorted input was stored as-is and
    /// silently produced wrong dot products.
    ///
    /// # Panics
    ///
    /// Panics if `indices` and `values` have different lengths.
    #[must_use]
    pub fn new(indices: Vec<usize>, values: Vec<f32>, dimension: usize) -> Self {
        assert_eq!(
            indices.len(),
            values.len(),
            "Indices and values must have the same length"
        );
        // Fast path: already ascending, keep the caller's buffers untouched.
        if indices.windows(2).all(|pair| pair[0] <= pair[1]) {
            return Self {
                indices,
                values,
                dimension,
            };
        }
        let mut pairs: Vec<(usize, f32)> = indices.into_iter().zip(values).collect();
        pairs.sort_by_key(|&(idx, _)| idx);
        let (indices, values): (Vec<usize>, Vec<f32>) = pairs.into_iter().unzip();
        Self {
            indices,
            values,
            dimension,
        }
    }

    /// Creates a vector with no stored entries and the given `dimension`.
    #[must_use]
    pub fn empty(dimension: usize) -> Self {
        Self {
            indices: Vec::new(),
            values: Vec::new(),
            dimension,
        }
    }

    /// Number of stored entries.
    #[must_use]
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Returns `true` when no entries are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.indices.is_empty()
    }

    /// Dot product of two sparse vectors.
    ///
    /// Implemented as a two-pointer merge over both index lists, so both
    /// operands must have ascending `indices`. [`SparseVector::new`] and
    /// [`SparseVector::from_dense`] produce that ordering; because the fields are
    /// public, vectors assembled by hand must uphold it themselves.
    #[must_use]
    pub fn dot(&self, other: &Self) -> f32 {
        let mut result = 0.0_f32;
        let (mut i, mut j) = (0, 0);
        while i < self.indices.len() && j < other.indices.len() {
            match self.indices[i].cmp(&other.indices[j]) {
                std::cmp::Ordering::Less => i += 1,
                std::cmp::Ordering::Greater => j += 1,
                std::cmp::Ordering::Equal => {
                    result += self.values[i] * other.values[j];
                    i += 1;
                    j += 1;
                }
            }
        }
        result
    }

    /// Euclidean (L2) norm of the stored values.
    #[must_use]
    pub fn norm(&self) -> f32 {
        self.values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Cosine similarity with `other`; returns `0.0` if either norm is zero.
    #[must_use]
    pub fn cosine_similarity(&self, other: &Self) -> f32 {
        let norm_a = self.norm();
        let norm_b = other.norm();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        self.dot(other) / (norm_a * norm_b)
    }

    /// Expands into a dense vector of length `dimension`.
    ///
    /// Entries whose index is `>= dimension` are skipped rather than causing a
    /// panic.
    #[must_use]
    pub fn to_dense(&self) -> Vec<f32> {
        let mut dense = vec![0.0; self.dimension];
        for (&idx, &val) in self.indices.iter().zip(&self.values) {
            if let Some(slot) = dense.get_mut(idx) {
                *slot = val;
            }
        }
        dense
    }

    /// Builds a sparse vector from a dense slice, keeping only entries whose
    /// absolute value exceeds `f32::EPSILON`. Indices come out ascending.
    #[must_use]
    pub fn from_dense(dense: &[f32]) -> Self {
        let (indices, values): (Vec<usize>, Vec<f32>) = dense
            .iter()
            .copied()
            .enumerate()
            .filter(|(_, v)| v.abs() > f32::EPSILON)
            .unzip();
        Self {
            indices,
            values,
            dimension: dense.len(),
        }
    }
}
impl Default for SparseVector {
fn default() -> Self {
Self::empty(0)
}
}
/// Asynchronous storage backend for sparse vectors keyed by `DocumentId`.
///
/// Mutating operations take `&mut self`; read operations take `&self`.
#[async_trait]
pub trait SparseVectorStore: Send + Sync {
/// Stores `vector` under `id`.
async fn insert(
&mut self,
id: DocumentId,
vector: SparseVector,
) -> Result<(), VectorStoreError>;
/// Searches the store with `query` and returns `(id, score)` pairs.
///
/// The in-memory implementation in this file scores by dot product, keeps
/// only positive scores, sorts them in descending order and truncates the
/// list to `top_k`.
async fn search(
&self,
query: &SparseVector,
top_k: usize,
) -> Result<Vec<(DocumentId, f32)>, VectorStoreError>;
/// Returns the vector stored under `id`, or `None` if there is none.
async fn get(&self, id: &DocumentId) -> Result<Option<SparseVector>, VectorStoreError>;
/// Removes the vector stored under `id`.
///
/// The in-memory implementation in this file returns `true` when a vector
/// was removed and `false` when `id` was not present.
async fn delete(&mut self, id: &DocumentId) -> Result<bool, VectorStoreError>;
/// Number of stored vectors.
async fn count(&self) -> usize;
/// Removes every stored vector.
async fn clear(&mut self) -> Result<(), VectorStoreError>;
}
/// Tunable constants used by `BM25Encoder`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BM25Params {
/// Term-frequency saturation constant: the encoder scores a term as
/// `tf * (k1 + 1) / (tf + k1 * length_norm)`.
pub k1: f32,
/// Length-normalisation weight: `length_norm = 1 - b + b * len / avg_len`.
pub b: f32,
/// Added to the document frequency in the IDF denominator:
/// `idf = ln((N + 1) / (df + delta))`.
pub delta: f32,
}
impl Default for BM25Params {
fn default() -> Self {
Self {
k1: 1.5,
b: 0.75,
delta: 0.5,
}
}
}
/// BM25-style text encoder that turns text into a `SparseVector`.
///
/// All statistics are rebuilt from scratch by `fit`; `encode` only reads them.
pub struct BM25Encoder {
// Term -> position of that term in encoded vectors.
vocabulary: HashMap<String, usize>,
// Term -> inverse document frequency computed by `fit`.
idf: HashMap<String, f32>,
// Term -> number of fitted documents containing the term at least once.
document_frequencies: HashMap<String, usize>,
// Number of documents passed to the last `fit` call.
total_documents: usize,
// `total_terms / total_documents` (0.0 when no documents were fitted).
average_doc_length: f32,
// Total token count across all fitted documents.
total_terms: usize,
// BM25 constants used by `fit` (delta) and `encode` (k1, b).
params: BM25Params,
// Number of distinct terms; used as the dimension of encoded vectors.
vocab_size: usize,
}
impl BM25Encoder {
#[must_use]
pub fn new() -> Self {
Self {
vocabulary: HashMap::new(),
idf: HashMap::new(),
document_frequencies: HashMap::new(),
total_documents: 0,
average_doc_length: 0.0,
total_terms: 0,
params: BM25Params::default(),
vocab_size: 0,
}
}
#[must_use]
pub fn with_params(params: BM25Params) -> Self {
Self {
vocabulary: HashMap::new(),
idf: HashMap::new(),
document_frequencies: HashMap::new(),
total_documents: 0,
average_doc_length: 0.0,
total_terms: 0,
params,
vocab_size: 0,
}
}
#[must_use]
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
#[must_use]
pub fn total_documents(&self) -> usize {
self.total_documents
}
#[must_use]
pub fn params(&self) -> &BM25Params {
&self.params
}
fn tokenize(text: &str) -> Vec<String> {
text.to_lowercase()
.split(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty() && s.len() > 1)
.map(String::from)
.collect()
}
pub fn fit(&mut self, documents: &[&str]) {
self.vocabulary.clear();
self.document_frequencies.clear();
self.idf.clear();
self.total_documents = documents.len();
self.total_terms = 0;
for doc in documents {
let terms = Self::tokenize(doc);
self.total_terms += terms.len();
let unique_terms: HashSet<_> = terms.into_iter().collect();
for term in unique_terms {
if !self.vocabulary.contains_key(&term) {
let idx = self.vocabulary.len();
self.vocabulary.insert(term.clone(), idx);
}
*self.document_frequencies.entry(term).or_insert(0) += 1;
}
}
self.vocab_size = self.vocabulary.len();
self.average_doc_length = if self.total_documents > 0 {
self.total_terms as f32 / self.total_documents as f32
} else {
0.0
};
let n = self.total_documents as f32;
for (term, df) in &self.document_frequencies {
let df_f = *df as f32;
let idf = ((n + 1.0) / (df_f + self.params.delta)).ln();
self.idf.insert(term.clone(), idf.max(f32::EPSILON));
}
}
pub fn fit_owned(&mut self, documents: &[String]) {
let refs: Vec<&str> = documents.iter().map(String::as_str).collect();
self.fit(&refs);
}
#[must_use]
pub fn encode(&self, text: &str) -> SparseVector {
if self.vocab_size == 0 {
return SparseVector::empty(0);
}
let terms = Self::tokenize(text);
let doc_length = terms.len() as f32;
let mut term_freqs: HashMap<String, usize> = HashMap::new();
for term in terms {
*term_freqs.entry(term).or_insert(0) += 1;
}
let mut indices = Vec::new();
let mut values = Vec::new();
for (term, tf) in term_freqs {
if let Some(&idx) = self.vocabulary.get(&term) {
let idf = self.idf.get(&term).copied().unwrap_or(0.0);
let tf_f = tf as f32;
let length_norm = 1.0 - self.params.b
+ self.params.b * doc_length / self.average_doc_length.max(1.0);
let tf_score =
(tf_f * (self.params.k1 + 1.0)) / (tf_f + self.params.k1 * length_norm);
let score = idf * tf_score;
if score > 0.0 {
indices.push(idx);
values.push(score);
}
}
}
let mut pairs: Vec<_> = indices.into_iter().zip(values).collect();
pairs.sort_by_key(|(idx, _)| *idx);
let (indices, values): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();
SparseVector::new(indices, values, self.vocab_size)
}
#[must_use]
pub fn encode_batch(&self, texts: &[&str]) -> Vec<SparseVector> {
texts.iter().map(|text| self.encode(text)).collect()
}
#[must_use]
pub fn encode_batch_owned(&self, texts: &[String]) -> Vec<SparseVector> {
texts.iter().map(|text| self.encode(text)).collect()
}
}
impl Default for BM25Encoder {
fn default() -> Self {
Self::new()
}
}
/// One fused search hit produced by `HybridSearcher`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridResult {
/// Identifier of the matched document.
pub document_id: DocumentId,
/// Dense-side score as recorded by the fusion routine (normalised for the
/// weighted-sum and distribution-based strategies, raw for RRF).
pub dense_score: f32,
/// Sparse-side score, recorded the same way as `dense_score`.
pub sparse_score: f32,
/// Final fused score that results are ordered by.
pub combined_score: f32,
/// Zero-based position in the fused ordering.
pub rank: usize,
}
impl HybridResult {
#[must_use]
pub fn new(
document_id: DocumentId,
dense_score: f32,
sparse_score: f32,
combined_score: f32,
rank: usize,
) -> Self {
Self {
document_id,
dense_score,
sparse_score,
combined_score,
rank,
}
}
}
/// How `HybridSearcher` combines the dense and sparse result lists.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FusionStrategy {
/// `dense_weight * dense + sparse_weight * sparse`, where each list is
/// min-max normalised first when `HybridConfig::normalize_scores` is set.
WeightedSum {
dense_weight: f32,
sparse_weight: f32,
},
/// Reciprocal rank fusion: each list contributes `1 / (k + rank + 1)` for a
/// document at zero-based position `rank` in that list's descending order.
ReciprocalRankFusion {
k: usize,
},
/// Weighted sum of z-scores: each list is standardised with its own mean
/// and standard deviation before the weights are applied.
DistributionBased {
dense_weight: f32,
sparse_weight: f32,
},
}
impl Default for FusionStrategy {
fn default() -> Self {
Self::WeightedSum {
dense_weight: 0.5,
sparse_weight: 0.5,
}
}
}
impl FusionStrategy {
#[must_use]
pub fn weighted_sum(dense_weight: f32, sparse_weight: f32) -> Self {
Self::WeightedSum {
dense_weight,
sparse_weight,
}
}
#[must_use]
pub fn rrf(k: usize) -> Self {
Self::ReciprocalRankFusion { k }
}
#[must_use]
pub fn distribution_based(dense_weight: f32, sparse_weight: f32) -> Self {
Self::DistributionBased {
dense_weight,
sparse_weight,
}
}
}
/// Configuration for `HybridSearcher`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridConfig {
/// Strategy used to fuse dense and sparse results.
pub fusion_strategy: FusionStrategy,
/// Dense weight recorded on the config. NOTE(review): the fusion code in
/// this file reads its weights from `fusion_strategy`, not from this field.
pub dense_weight: f32,
/// Sparse weight recorded on the config; same caveat as `dense_weight`.
pub sparse_weight: f32,
/// When `true`, the weighted-sum strategy min-max normalises each result
/// list before combining.
pub normalize_scores: bool,
/// When set, results whose combined score is below this value are dropped.
pub min_score: Option<f32>,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
fusion_strategy: FusionStrategy::default(),
dense_weight: 0.5,
sparse_weight: 0.5,
normalize_scores: true,
min_score: None,
}
}
}
impl HybridConfig {
#[must_use]
pub fn weighted_sum(dense_weight: f32, sparse_weight: f32) -> Self {
Self {
fusion_strategy: FusionStrategy::weighted_sum(dense_weight, sparse_weight),
dense_weight,
sparse_weight,
normalize_scores: true,
min_score: None,
}
}
#[must_use]
pub fn rrf(k: usize) -> Self {
Self {
fusion_strategy: FusionStrategy::rrf(k),
dense_weight: 0.5,
sparse_weight: 0.5,
normalize_scores: false,
min_score: None,
}
}
#[must_use]
pub fn with_min_score(mut self, min_score: f32) -> Self {
self.min_score = Some(min_score);
self
}
#[must_use]
pub fn with_normalize(mut self, normalize: bool) -> Self {
self.normalize_scores = normalize;
self
}
}
/// Combines caller-supplied dense results with sparse results fetched from a
/// `SparseVectorStore`, using the fusion strategy in `config`.
pub struct HybridSearcher<S: SparseVectorStore> {
// Backend queried for sparse matches.
sparse_store: S,
// Encoder used to turn query text into a sparse query vector.
encoder: BM25Encoder,
// Fusion strategy, normalisation flag and optional score threshold.
config: HybridConfig,
}
impl<S: SparseVectorStore> HybridSearcher<S> {
    /// Bundles a sparse store, an encoder and a fusion configuration.
    #[must_use]
    pub fn new(sparse_store: S, encoder: BM25Encoder, config: HybridConfig) -> Self {
        Self {
            sparse_store,
            encoder,
            config,
        }
    }

    /// Shared access to the underlying sparse store.
    #[must_use]
    pub fn sparse_store(&self) -> &S {
        &self.sparse_store
    }

    /// Exclusive access to the underlying sparse store (e.g. for inserts).
    pub fn sparse_store_mut(&mut self) -> &mut S {
        &mut self.sparse_store
    }

    /// The encoder used to turn query text into sparse vectors.
    #[must_use]
    pub fn encoder(&self) -> &BM25Encoder {
        &self.encoder
    }

    /// The fusion configuration.
    #[must_use]
    pub fn config(&self) -> &HybridConfig {
        &self.config
    }

    /// Encodes `query`, runs the sparse search and fuses it with
    /// `dense_results`, returning at most `top_k` results.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the sparse store's `search`.
    pub async fn search(
        &self,
        query: &str,
        dense_results: &[(DocumentId, f32)],
        top_k: usize,
    ) -> Result<Vec<HybridResult>, VectorStoreError> {
        let sparse_query = self.encoder.encode(query);
        self.search_with_sparse(&sparse_query, dense_results, top_k)
            .await
    }

    /// Like [`HybridSearcher::search`], but with an already encoded sparse
    /// query.
    ///
    /// The sparse store is asked for twice as many candidates as `top_k` so the
    /// fusion step has some slack; the doubling saturates instead of
    /// overflowing for very large `top_k`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the sparse store's `search`.
    pub async fn search_with_sparse(
        &self,
        sparse_query: &SparseVector,
        dense_results: &[(DocumentId, f32)],
        top_k: usize,
    ) -> Result<Vec<HybridResult>, VectorStoreError> {
        let sparse_results = self
            .sparse_store
            .search(sparse_query, top_k.saturating_mul(2))
            .await?;
        Ok(self.fuse_results(dense_results, &sparse_results, top_k))
    }

    /// Dispatches to the fusion routine selected by the configuration.
    fn fuse_results(
        &self,
        dense_results: &[(DocumentId, f32)],
        sparse_results: &[(DocumentId, f32)],
        top_k: usize,
    ) -> Vec<HybridResult> {
        match &self.config.fusion_strategy {
            FusionStrategy::WeightedSum {
                dense_weight,
                sparse_weight,
            } => self.fuse_weighted_sum(
                dense_results,
                sparse_results,
                *dense_weight,
                *sparse_weight,
                top_k,
            ),
            FusionStrategy::ReciprocalRankFusion { k } => {
                self.fuse_rrf(dense_results, sparse_results, *k, top_k)
            }
            FusionStrategy::DistributionBased {
                dense_weight,
                sparse_weight,
            } => self.fuse_distribution_based(
                dense_results,
                sparse_results,
                *dense_weight,
                *sparse_weight,
                top_k,
            ),
        }
    }

    /// Weighted sum of (optionally min-max normalised) dense and sparse scores.
    ///
    /// A document missing from one list contributes `0.0` for that side.
    fn fuse_weighted_sum(
        &self,
        dense_results: &[(DocumentId, f32)],
        sparse_results: &[(DocumentId, f32)],
        dense_weight: f32,
        sparse_weight: f32,
        top_k: usize,
    ) -> Vec<HybridResult> {
        // With normalisation off, the (0, 1) range makes `normalize_score` the
        // identity.
        let range_of = |results: &[(DocumentId, f32)]| {
            if self.config.normalize_scores {
                Self::score_range(results)
            } else {
                (0.0, 1.0)
            }
        };
        let (dense_min, dense_max) = range_of(dense_results);
        let (sparse_min, sparse_max) = range_of(sparse_results);

        let mut scores: HashMap<DocumentId, (f32, f32)> = HashMap::new();
        for (id, score) in dense_results {
            scores.entry(id.clone()).or_insert((0.0, 0.0)).0 =
                Self::normalize_score(*score, dense_min, dense_max);
        }
        for (id, score) in sparse_results {
            scores.entry(id.clone()).or_insert((0.0, 0.0)).1 =
                Self::normalize_score(*score, sparse_min, sparse_max);
        }

        let fused = scores.into_iter().map(|(id, (dense, sparse))| {
            let combined = dense_weight * dense + sparse_weight * sparse;
            HybridResult::new(id, dense, sparse, combined, 0)
        });
        self.rank_and_truncate(fused, top_k)
    }

    /// Reciprocal rank fusion: every list a document appears in contributes
    /// `1 / (k + rank + 1)`, where `rank` is the zero-based position in that
    /// list sorted by descending score. The raw scores are kept on the result.
    fn fuse_rrf(
        &self,
        dense_results: &[(DocumentId, f32)],
        sparse_results: &[(DocumentId, f32)],
        k: usize,
        top_k: usize,
    ) -> Vec<HybridResult> {
        let sorted_desc = |results: &[(DocumentId, f32)]| {
            let mut sorted = results.to_vec();
            sorted.sort_by(|a, b| Self::cmp_desc(a.1, b.1));
            sorted
        };
        let reciprocal = |rank: usize| 1.0 / (k as f32 + rank as f32 + 1.0);

        // (raw dense score, raw sparse score, accumulated RRF score)
        let mut rrf_scores: HashMap<DocumentId, (f32, f32, f32)> = HashMap::new();
        for (rank, (id, score)) in sorted_desc(dense_results).into_iter().enumerate() {
            let entry = rrf_scores.entry(id).or_insert((0.0, 0.0, 0.0));
            entry.0 = score;
            entry.2 += reciprocal(rank);
        }
        for (rank, (id, score)) in sorted_desc(sparse_results).into_iter().enumerate() {
            let entry = rrf_scores.entry(id).or_insert((0.0, 0.0, 0.0));
            entry.1 = score;
            entry.2 += reciprocal(rank);
        }

        let fused = rrf_scores
            .into_iter()
            .map(|(id, (dense, sparse, combined))| {
                HybridResult::new(id, dense, sparse, combined, 0)
            });
        self.rank_and_truncate(fused, top_k)
    }

    /// Weighted sum of z-scores: each list is standardised with its own mean
    /// and standard deviation before the weights are applied.
    ///
    /// A document missing from one list contributes `0.0` for that side.
    fn fuse_distribution_based(
        &self,
        dense_results: &[(DocumentId, f32)],
        sparse_results: &[(DocumentId, f32)],
        dense_weight: f32,
        sparse_weight: f32,
        top_k: usize,
    ) -> Vec<HybridResult> {
        let (dense_mean, dense_std) = Self::score_stats(dense_results);
        let (sparse_mean, sparse_std) = Self::score_stats(sparse_results);

        let mut scores: HashMap<DocumentId, (f32, f32)> = HashMap::new();
        for (id, score) in dense_results {
            scores.entry(id.clone()).or_insert((0.0, 0.0)).0 =
                Self::z_normalize(*score, dense_mean, dense_std);
        }
        for (id, score) in sparse_results {
            scores.entry(id.clone()).or_insert((0.0, 0.0)).1 =
                Self::z_normalize(*score, sparse_mean, sparse_std);
        }

        let fused = scores.into_iter().map(|(id, (dense, sparse))| {
            let combined = dense_weight * dense + sparse_weight * sparse;
            HybridResult::new(id, dense, sparse, combined, 0)
        });
        self.rank_and_truncate(fused, top_k)
    }

    /// Shared tail of every fusion routine: applies the optional `min_score`
    /// threshold, sorts by descending combined score, assigns zero-based ranks
    /// and keeps the first `top_k` results.
    fn rank_and_truncate(
        &self,
        fused: impl Iterator<Item = HybridResult>,
        top_k: usize,
    ) -> Vec<HybridResult> {
        let min_score = self.config.min_score;
        let mut results: Vec<HybridResult> = fused
            .filter(|r| min_score.is_none_or(|min| r.combined_score >= min))
            .collect();
        results.sort_by(|a, b| Self::cmp_desc(a.combined_score, b.combined_score));
        for (rank, result) in results.iter_mut().enumerate() {
            result.rank = rank;
        }
        results.truncate(top_k);
        results
    }

    /// Descending comparison for scores that is a genuine total order: NaN
    /// values compare equal to each other and sort after every number.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` is not a total order once a NaN is
    /// present, and `sort_by` is documented to possibly panic in that case.
    fn cmp_desc(a: f32, b: f32) -> std::cmp::Ordering {
        b.partial_cmp(&a)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Minimum and maximum score of `results`; `(0.0, 1.0)` for an empty list.
    /// NaN scores are ignored by `f32::min` / `f32::max`.
    fn score_range(results: &[(DocumentId, f32)]) -> (f32, f32) {
        if results.is_empty() {
            return (0.0, 1.0);
        }
        results
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(min, max), (_, s)| {
                (min.min(*s), max.max(*s))
            })
    }

    /// Min-max normalisation; returns `0.5` when the range is degenerate.
    fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
        let span = max - min;
        if span.abs() < f32::EPSILON {
            return 0.5;
        }
        (score - min) / span
    }

    /// Mean and population standard deviation of the scores; `(0.0, 1.0)` for
    /// an empty list. The deviation is floored at `f32::EPSILON` so it can be
    /// used as a divisor.
    fn score_stats(results: &[(DocumentId, f32)]) -> (f32, f32) {
        if results.is_empty() {
            return (0.0, 1.0);
        }
        let n = results.len() as f32;
        let mean = results.iter().map(|(_, s)| *s).sum::<f32>() / n;
        let variance = results
            .iter()
            .map(|(_, s)| (*s - mean).powi(2))
            .sum::<f32>()
            / n;
        (mean, variance.sqrt().max(f32::EPSILON))
    }

    /// Standard score of `score` for the given mean and standard deviation.
    fn z_normalize(score: f32, mean: f32, std: f32) -> f32 {
        (score - mean) / std
    }
}
/// In-memory `SparseVectorStore` backed by hash maps.
pub struct InMemorySparseStore {
// Document id -> stored vector.
vectors: RwLock<HashMap<DocumentId, SparseVector>>,
// Vector index -> ids of documents whose vector lists that index; used by
// `search` to collect candidate documents.
inverted_index: RwLock<HashMap<usize, HashSet<DocumentId>>>,
}
impl InMemorySparseStore {
#[must_use]
pub fn new() -> Self {
Self {
vectors: RwLock::new(HashMap::new()),
inverted_index: RwLock::new(HashMap::new()),
}
}
}
impl Default for InMemorySparseStore {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl SparseVectorStore for InMemorySparseStore {
async fn insert(
&mut self,
id: DocumentId,
vector: SparseVector,
) -> Result<(), VectorStoreError> {
{
let mut index = self.inverted_index.write().await;
for &idx in &vector.indices {
index.entry(idx).or_default().insert(id.clone());
}
}
{
let mut vectors = self.vectors.write().await;
vectors.insert(id, vector);
}
Ok(())
}
async fn search(
&self,
query: &SparseVector,
top_k: usize,
) -> Result<Vec<(DocumentId, f32)>, VectorStoreError> {
let vectors = self.vectors.read().await;
let inverted_index = self.inverted_index.read().await;
let mut candidates: HashSet<DocumentId> = HashSet::new();
for &idx in &query.indices {
if let Some(doc_ids) = inverted_index.get(&idx) {
candidates.extend(doc_ids.iter().cloned());
}
}
let mut scored: Vec<(DocumentId, f32)> = candidates
.into_iter()
.filter_map(|id| {
vectors.get(&id).map(|vec| {
let score = query.dot(vec);
(id, score)
})
})
.filter(|(_, score)| *score > 0.0)
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(top_k);
Ok(scored)
}
async fn get(&self, id: &DocumentId) -> Result<Option<SparseVector>, VectorStoreError> {
let vectors = self.vectors.read().await;
Ok(vectors.get(id).cloned())
}
async fn delete(&mut self, id: &DocumentId) -> Result<bool, VectorStoreError> {
let vector = {
let mut vectors = self.vectors.write().await;
vectors.remove(id)
};
if let Some(vec) = vector {
let mut index = self.inverted_index.write().await;
for &idx in &vec.indices {
if let Some(doc_ids) = index.get_mut(&idx) {
doc_ids.remove(id);
if doc_ids.is_empty() {
index.remove(&idx);
}
}
}
Ok(true)
} else {
Ok(false)
}
}
async fn count(&self) -> usize {
self.vectors.read().await.len()
}
async fn clear(&mut self) -> Result<(), VectorStoreError> {
self.vectors.write().await.clear();
self.inverted_index.write().await.clear();
Ok(())
}
}
// Unit tests for the sparse vector type, the BM25 encoder, the fusion
// configuration types, the in-memory store and the hybrid searcher.
#[cfg(test)]
mod tests {
use super::*;
// `new` stores the given indices, values and dimension; `nnz` counts entries.
#[test]
fn test_sparse_vector_new() {
let v = SparseVector::new(vec![0, 2, 5], vec![1.0, 2.0, 3.0], 10);
assert_eq!(v.indices, vec![0, 2, 5]);
assert_eq!(v.values, vec![1.0, 2.0, 3.0]);
assert_eq!(v.dimension, 10);
assert_eq!(v.nnz(), 3);
}
// `empty` keeps the dimension but stores no entries.
#[test]
fn test_sparse_vector_empty() {
let v = SparseVector::empty(100);
assert!(v.is_empty());
assert_eq!(v.nnz(), 0);
assert_eq!(v.dimension, 100);
}
// Only the shared indices 2 and 4 contribute: 2*2 + 3*1 = 7.
#[test]
fn test_sparse_vector_dot() {
let a = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 10);
let b = SparseVector::new(vec![1, 2, 4], vec![1.0, 2.0, 1.0], 10);
assert!((a.dot(&b) - 7.0).abs() < 1e-6);
}
// 3-4-5 triangle: the norm of (3, 4) is 5.
#[test]
fn test_sparse_vector_norm() {
let v = SparseVector::new(vec![0, 1], vec![3.0, 4.0], 10);
assert!((v.norm() - 5.0).abs() < 1e-6);
}
// Identical vectors have similarity 1; orthogonal vectors have similarity 0.
#[test]
fn test_sparse_vector_cosine_similarity() {
let a = SparseVector::new(vec![0, 1], vec![1.0, 0.0], 10);
let b = SparseVector::new(vec![0, 1], vec![1.0, 0.0], 10);
assert!((a.cosine_similarity(&b) - 1.0).abs() < 1e-6);
let c = SparseVector::new(vec![0, 1], vec![0.0, 1.0], 10);
assert!(a.cosine_similarity(&c).abs() < 1e-6);
}
// Unlisted positions become 0.0 in the dense form.
#[test]
fn test_sparse_vector_to_dense() {
let v = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 5);
let dense = v.to_dense();
assert_eq!(dense, vec![1.0, 0.0, 2.0, 0.0, 3.0]);
}
// Zero entries are dropped; the dimension is the dense length.
#[test]
fn test_sparse_vector_from_dense() {
let dense = vec![1.0, 0.0, 2.0, 0.0, 3.0];
let sparse = SparseVector::from_dense(&dense);
assert_eq!(sparse.indices, vec![0, 2, 4]);
assert_eq!(sparse.values, vec![1.0, 2.0, 3.0]);
assert_eq!(sparse.dimension, 5);
}
// Tokens are lowercased, punctuation splits them, and the one-letter token
// "a" is dropped while the two-letter token "is" is kept.
#[test]
fn test_bm25_encoder_tokenize() {
let tokens = BM25Encoder::tokenize("Hello, World! This is a test.");
assert!(tokens.contains(&"hello".to_string()));
assert!(tokens.contains(&"world".to_string()));
assert!(tokens.contains(&"this".to_string()));
assert!(tokens.contains(&"test".to_string()));
assert!(tokens.contains(&"is".to_string())); assert!(!tokens.contains(&"a".to_string()));
}
// Fitting records the document count and builds a non-empty vocabulary.
#[test]
fn test_bm25_encoder_fit() {
let mut encoder = BM25Encoder::new();
let docs = vec!["the quick brown fox", "the lazy dog", "quick brown dog"];
encoder.fit(&docs);
assert_eq!(encoder.total_documents(), 3);
assert!(encoder.vocab_size() > 0);
}
// A query made of known terms encodes to a non-empty sparse vector.
#[test]
fn test_bm25_encoder_encode() {
let mut encoder = BM25Encoder::new();
let docs = vec![
"the quick brown fox jumps",
"the lazy dog sleeps",
"quick brown dog runs",
];
encoder.fit(&docs);
let query = "quick brown";
let sparse = encoder.encode(query);
assert!(!sparse.is_empty());
assert!(sparse.nnz() > 0);
}
// Batch encoding yields one vector per input text.
#[test]
fn test_bm25_encoder_encode_batch() {
let mut encoder = BM25Encoder::new();
let docs = vec!["document one", "document two"];
encoder.fit(&docs);
let queries = vec!["query one", "query two"];
let vectors = encoder.encode_batch(&queries);
assert_eq!(vectors.len(), 2);
}
// Default BM25 constants: k1 = 1.5, b = 0.75.
#[test]
fn test_bm25_params_default() {
let params = BM25Params::default();
assert!((params.k1 - 1.5).abs() < 1e-6);
assert!((params.b - 0.75).abs() < 1e-6);
}
// `HybridResult::new` stores each argument in the matching field.
#[test]
fn test_hybrid_result_new() {
let result = HybridResult::new(DocumentId::from("doc1"), 0.8, 0.6, 0.7, 0);
assert_eq!(result.document_id.as_str(), "doc1");
assert!((result.dense_score - 0.8).abs() < 1e-6);
assert!((result.sparse_score - 0.6).abs() < 1e-6);
assert!((result.combined_score - 0.7).abs() < 1e-6);
assert_eq!(result.rank, 0);
}
// The default strategy is an evenly weighted sum.
#[test]
fn test_fusion_strategy_default() {
let strategy = FusionStrategy::default();
match strategy {
FusionStrategy::WeightedSum {
dense_weight,
sparse_weight,
} => {
assert!((dense_weight - 0.5).abs() < 1e-6);
assert!((sparse_weight - 0.5).abs() < 1e-6);
}
_ => panic!("Expected WeightedSum as default"),
}
}
// `FusionStrategy::rrf` builds the RRF variant with the given `k`.
#[test]
fn test_fusion_strategy_rrf() {
let strategy = FusionStrategy::rrf(60);
match strategy {
FusionStrategy::ReciprocalRankFusion { k } => {
assert_eq!(k, 60);
}
_ => panic!("Expected ReciprocalRankFusion"),
}
}
// Default config: even weights, normalisation on, no minimum score.
#[test]
fn test_hybrid_config_default() {
let config = HybridConfig::default();
assert!((config.dense_weight - 0.5).abs() < 1e-6);
assert!((config.sparse_weight - 0.5).abs() < 1e-6);
assert!(config.normalize_scores);
assert!(config.min_score.is_none());
}
// `HybridConfig::weighted_sum` records the weights on the config.
#[test]
fn test_hybrid_config_weighted_sum() {
let config = HybridConfig::weighted_sum(0.7, 0.3);
assert!((config.dense_weight - 0.7).abs() < 1e-6);
assert!((config.sparse_weight - 0.3).abs() < 1e-6);
}
// `HybridConfig::rrf` selects the RRF strategy with the given `k`.
#[test]
fn test_hybrid_config_rrf() {
let config = HybridConfig::rrf(60);
match config.fusion_strategy {
FusionStrategy::ReciprocalRankFusion { k } => {
assert_eq!(k, 60);
}
_ => panic!("Expected RRF fusion strategy"),
}
}
// A stored vector can be read back unchanged.
#[tokio::test]
async fn test_in_memory_sparse_store_insert_and_get() {
let mut store = InMemorySparseStore::new();
let id = DocumentId::from("doc1");
let vector = SparseVector::new(vec![0, 2, 4], vec![1.0, 2.0, 3.0], 10);
store.insert(id.clone(), vector.clone()).await.unwrap();
let retrieved = store.get(&id).await.unwrap();
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.indices, vector.indices);
assert_eq!(retrieved.values, vector.values);
}
// doc3 shares no index with the query and is not returned; doc1 scores
// 1.0*1.0 + 0.5*1.0 = 1.5 and ranks first.
#[tokio::test]
async fn test_in_memory_sparse_store_search() {
let mut store = InMemorySparseStore::new();
store
.insert(
DocumentId::from("doc1"),
SparseVector::new(vec![0, 1], vec![1.0, 0.5], 10),
)
.await
.unwrap();
store
.insert(
DocumentId::from("doc2"),
SparseVector::new(vec![0, 2], vec![0.5, 1.0], 10),
)
.await
.unwrap();
store
.insert(
DocumentId::from("doc3"),
SparseVector::new(vec![3, 4], vec![1.0, 1.0], 10),
)
.await
.unwrap();
let query = SparseVector::new(vec![0, 1], vec![1.0, 1.0], 10);
let results = store.search(&query, 10).await.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].0.as_str(), "doc1");
assert!((results[0].1 - 1.5).abs() < 1e-6);
}
// Deleting an existing id returns true; deleting it again returns false.
#[tokio::test]
async fn test_in_memory_sparse_store_delete() {
let mut store = InMemorySparseStore::new();
let id = DocumentId::from("doc1");
let vector = SparseVector::new(vec![0, 2], vec![1.0, 2.0], 10);
store.insert(id.clone(), vector).await.unwrap();
assert_eq!(store.count().await, 1);
let deleted = store.delete(&id).await.unwrap();
assert!(deleted);
assert_eq!(store.count().await, 0);
let deleted = store.delete(&id).await.unwrap();
assert!(!deleted);
}
// `clear` removes every stored vector.
#[tokio::test]
async fn test_in_memory_sparse_store_clear() {
let mut store = InMemorySparseStore::new();
store
.insert(
DocumentId::from("doc1"),
SparseVector::new(vec![0], vec![1.0], 10),
)
.await
.unwrap();
store
.insert(
DocumentId::from("doc2"),
SparseVector::new(vec![1], vec![1.0], 10),
)
.await
.unwrap();
assert_eq!(store.count().await, 2);
store.clear().await.unwrap();
assert_eq!(store.count().await, 0);
}
// Weighted-sum fusion returns results ordered by descending combined score.
#[tokio::test]
async fn test_hybrid_searcher_weighted_sum() {
let mut encoder = BM25Encoder::new();
let docs = vec!["quick brown fox", "lazy dog", "brown dog"];
encoder.fit(&docs);
let mut store = InMemorySparseStore::new();
for (i, doc) in docs.iter().enumerate() {
let id = DocumentId::from(format!("doc{}", i + 1));
let vector = encoder.encode(doc);
store.insert(id, vector).await.unwrap();
}
let config = HybridConfig::weighted_sum(0.5, 0.5);
let searcher = HybridSearcher::new(store, encoder, config);
let dense_results = vec![
(DocumentId::from("doc1"), 0.9),
(DocumentId::from("doc2"), 0.5),
(DocumentId::from("doc3"), 0.7),
];
let results = searcher
.search("brown fox", &dense_results, 3)
.await
.unwrap();
assert!(!results.is_empty());
for i in 1..results.len() {
assert!(results[i - 1].combined_score >= results[i].combined_score);
}
}
// RRF fusion returns results ordered by descending combined score.
#[tokio::test]
async fn test_hybrid_searcher_rrf() {
let mut encoder = BM25Encoder::new();
let docs = vec!["quick brown fox", "lazy dog", "brown dog"];
encoder.fit(&docs);
let mut store = InMemorySparseStore::new();
for (i, doc) in docs.iter().enumerate() {
let id = DocumentId::from(format!("doc{}", i + 1));
let vector = encoder.encode(doc);
store.insert(id, vector).await.unwrap();
}
let config = HybridConfig::rrf(60);
let searcher = HybridSearcher::new(store, encoder, config);
let dense_results = vec![
(DocumentId::from("doc1"), 0.9),
(DocumentId::from("doc3"), 0.7),
(DocumentId::from("doc2"), 0.5),
];
let results = searcher.search("brown", &dense_results, 3).await.unwrap();
assert!(!results.is_empty());
for i in 1..results.len() {
assert!(results[i - 1].combined_score >= results[i].combined_score);
}
}
// Distribution-based fusion produces at least one result for this corpus.
#[tokio::test]
async fn test_hybrid_searcher_distribution_based() {
let mut encoder = BM25Encoder::new();
let docs = vec!["quick brown fox", "lazy dog", "brown dog"];
encoder.fit(&docs);
let mut store = InMemorySparseStore::new();
for (i, doc) in docs.iter().enumerate() {
let id = DocumentId::from(format!("doc{}", i + 1));
let vector = encoder.encode(doc);
store.insert(id, vector).await.unwrap();
}
let config = HybridConfig {
fusion_strategy: FusionStrategy::distribution_based(0.5, 0.5),
dense_weight: 0.5,
sparse_weight: 0.5,
normalize_scores: true,
min_score: None,
};
let searcher = HybridSearcher::new(store, encoder, config);
let dense_results = vec![
(DocumentId::from("doc1"), 0.9),
(DocumentId::from("doc2"), 0.5),
(DocumentId::from("doc3"), 0.7),
];
let results = searcher
.search("brown fox", &dense_results, 3)
.await
.unwrap();
assert!(!results.is_empty());
}
// Every returned result respects the configured minimum score.
#[tokio::test]
async fn test_hybrid_searcher_with_min_score() {
let mut encoder = BM25Encoder::new();
let docs = vec!["quick brown fox", "lazy dog"];
encoder.fit(&docs);
let mut store = InMemorySparseStore::new();
for (i, doc) in docs.iter().enumerate() {
let id = DocumentId::from(format!("doc{}", i + 1));
let vector = encoder.encode(doc);
store.insert(id, vector).await.unwrap();
}
let config = HybridConfig::weighted_sum(0.5, 0.5).with_min_score(0.8);
let searcher = HybridSearcher::new(store, encoder, config);
let dense_results = vec![
(DocumentId::from("doc1"), 0.9),
(DocumentId::from("doc2"), 0.3),
];
let results = searcher
.search("quick fox", &dense_results, 10)
.await
.unwrap();
for result in &results {
assert!(result.combined_score >= 0.8);
}
}
// Min-max normalisation, including the degenerate range (min == max) that
// maps to 0.5.
#[test]
fn test_normalize_score() {
let score = HybridSearcher::<InMemorySparseStore>::normalize_score(0.5, 0.0, 1.0);
assert!((score - 0.5).abs() < 1e-6);
let score = HybridSearcher::<InMemorySparseStore>::normalize_score(5.0, 0.0, 10.0);
assert!((score - 0.5).abs() < 1e-6);
let score = HybridSearcher::<InMemorySparseStore>::normalize_score(5.0, 5.0, 5.0);
assert!((score - 0.5).abs() < 1e-6);
}
// Mean of 1, 2, 3 is 2; population standard deviation is sqrt(2/3).
#[test]
fn test_score_stats() {
let results = vec![
(DocumentId::from("a"), 1.0),
(DocumentId::from("b"), 2.0),
(DocumentId::from("c"), 3.0),
];
let (mean, std) = HybridSearcher::<InMemorySparseStore>::score_stats(&results);
assert!((mean - 2.0).abs() < 1e-6);
let expected_std = (2.0_f32 / 3.0).sqrt();
assert!((std - expected_std).abs() < 1e-6);
}
// (5 - 3) / 2 = 1.
#[test]
fn test_z_normalize() {
let normalized = HybridSearcher::<InMemorySparseStore>::z_normalize(5.0, 3.0, 2.0);
assert!((normalized - 1.0).abs() < 1e-6);
}
// Custom BM25 constants are stored on the encoder.
#[test]
fn test_bm25_encoder_with_custom_params() {
let params = BM25Params {
k1: 2.0,
b: 0.5,
delta: 1.0,
};
let encoder = BM25Encoder::with_params(params);
assert!((encoder.params().k1 - 2.0).abs() < 1e-6);
assert!((encoder.params().b - 0.5).abs() < 1e-6);
}
// The default sparse vector is empty with dimension 0.
#[test]
fn test_sparse_vector_default() {
let v = SparseVector::default();
assert!(v.is_empty());
assert_eq!(v.dimension, 0);
}
// An unfitted encoder, an empty store and no dense results give no results.
#[tokio::test]
async fn test_hybrid_searcher_empty_results() {
let encoder = BM25Encoder::new();
let store = InMemorySparseStore::new();
let config = HybridConfig::default();
let searcher = HybridSearcher::new(store, encoder, config);
let results = searcher.search("query", &[], 10).await.unwrap();
assert!(results.is_empty());
}
// `search_with_sparse` accepts a pre-encoded query and returns fused results.
#[tokio::test]
async fn test_hybrid_searcher_search_with_sparse() {
let mut encoder = BM25Encoder::new();
let docs = vec!["hello world", "world peace"];
encoder.fit(&docs);
let mut store = InMemorySparseStore::new();
for (i, doc) in docs.iter().enumerate() {
let id = DocumentId::from(format!("doc{}", i + 1));
let vector = encoder.encode(doc);
store.insert(id, vector).await.unwrap();
}
let config = HybridConfig::default();
let searcher = HybridSearcher::new(store, encoder, config);
let sparse_query = searcher.encoder().encode("world");
let dense_results = vec![
(DocumentId::from("doc1"), 0.8),
(DocumentId::from("doc2"), 0.6),
];
let results = searcher
.search_with_sparse(&sparse_query, &dense_results, 10)
.await
.unwrap();
assert!(!results.is_empty());
}
}