use super::types::SparseVec;
use std::cmp::Ordering;
/// Computes the dot product of two sparse vectors.
///
/// Both index lists are walked in lockstep and a product is accumulated only
/// where the same index appears in both vectors; an index present in just one
/// vector contributes nothing.
///
/// The merge only pairs entries correctly if each index list is sorted in
/// ascending order. That is presumably guaranteed by `SparseVec`; it is not
/// checked here.
#[inline]
pub fn sparse_dot(a: &SparseVec, b: &SparseVec) -> f32 {
    let (lhs_idx, lhs_val) = (a.indices(), a.values());
    let (rhs_idx, rhs_val) = (b.indices(), b.values());

    let mut acc = 0.0_f32;
    let mut p = 0_usize;
    let mut q = 0_usize;

    // Stops as soon as either cursor runs off the end of its index list:
    // nothing left on the other side can match any more.
    while let (Some(li), Some(ri)) = (lhs_idx.get(p), rhs_idx.get(q)) {
        match li.cmp(ri) {
            // Advance whichever cursor is behind.
            Ordering::Less => p += 1,
            Ordering::Greater => q += 1,
            Ordering::Equal => {
                acc += lhs_val[p] * rhs_val[q];
                p += 1;
                q += 1;
            }
        }
    }

    acc
}
/// Computes the cosine similarity of two sparse vectors.
///
/// The result is `sparse_dot(a, b)` divided by the product of the two values
/// returned by `SparseVec::norm` (presumably the Euclidean norm; the
/// definition lives outside this file).
///
/// Returns `0.0` when either norm is exactly zero, which avoids a division by
/// zero.
#[inline]
pub fn sparse_cosine(a: &SparseVec, b: &SparseVec) -> f32 {
    let numerator = sparse_dot(a, b);
    let (len_a, len_b) = (a.norm(), b.norm());

    if len_a != 0.0 && len_b != 0.0 {
        numerator / (len_a * len_b)
    } else {
        // At least one norm is zero: report no similarity.
        0.0
    }
}
#[inline]
pub fn sparse_euclidean(a: &SparseVec, b: &SparseVec) -> f32 {
let mut result = 0.0;
let mut i = 0;
let mut j = 0;
let a_indices = a.indices();
let b_indices = b.indices();
let a_values = a.values();
let b_values = b.values();
while i < a_indices.len() || j < b_indices.len() {
let idx_a = a_indices.get(i).copied().unwrap_or(u32::MAX);
let idx_b = b_indices.get(j).copied().unwrap_or(u32::MAX);
match idx_a.cmp(&idx_b) {
Ordering::Less => {
result += a_values[i] * a_values[i];
i += 1;
}
Ordering::Greater => {
result += b_values[j] * b_values[j];
j += 1;
}
Ordering::Equal => {
let diff = a_values[i] - b_values[j];
result += diff * diff;
i += 1;
j += 1;
}
}
}
result.sqrt()
}
/// Computes the Manhattan (L1) distance between two sparse vectors.
///
/// Both index lists are merged in a single pass. An index present in only one
/// vector is paired with an implicit zero on the other side, so it contributes
/// the absolute value of its own entry; an index present in both contributes
/// the absolute difference.
///
/// The merge only pairs entries correctly if each index list is sorted in
/// ascending order. That is presumably guaranteed by `SparseVec`; it is not
/// checked here.
#[inline]
pub fn sparse_manhattan(a: &SparseVec, b: &SparseVec) -> f32 {
    let a_indices = a.indices();
    let b_indices = b.indices();
    let a_values = a.values();
    let b_values = b.values();

    let mut total = 0.0_f32;
    let mut i = 0_usize;
    let mut j = 0_usize;

    loop {
        // Exhaustion is expressed with `Option` rather than a sentinel index,
        // so a genuine index of `u32::MAX` can never be mistaken for "this
        // side has run out" and lead to an out-of-bounds value access.
        let order = match (a_indices.get(i), b_indices.get(j)) {
            (None, None) => break,
            (Some(_), None) => Ordering::Less,
            (None, Some(_)) => Ordering::Greater,
            (Some(idx_a), Some(idx_b)) => idx_a.cmp(idx_b),
        };

        match order {
            // Only `a` has this index: the other coordinate is zero.
            Ordering::Less => {
                total += a_values[i].abs();
                i += 1;
            }
            // Only `b` has this index: the other coordinate is zero.
            Ordering::Greater => {
                total += b_values[j].abs();
                j += 1;
            }
            Ordering::Equal => {
                total += (a_values[i] - b_values[j]).abs();
                i += 1;
                j += 1;
            }
        }
    }

    total
}
/// Computes a BM25-style relevance score of `doc` against `query`.
///
/// For every index present in both vectors, the query value is used as the
/// term weight (named `idf` in the formula) and the document value as the term
/// frequency `tf`. Each shared index adds
///
/// ```text
/// idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * doc_len / avg_doc_len))
/// ```
///
/// to the score. Indices present in only one of the vectors contribute nothing.
///
/// `avg_doc_len` is used as a divisor without any check, so callers should
/// pass a non-zero value. The merge only pairs entries correctly if each index
/// list is sorted in ascending order, which is presumably guaranteed by
/// `SparseVec` and is not checked here.
#[inline]
pub fn sparse_bm25(
    query: &SparseVec,
    doc: &SparseVec,
    doc_len: f32,
    avg_doc_len: f32,
    k1: f32,
    b: f32,
) -> f32 {
    let (term_ids, term_weights) = (query.indices(), query.values());
    let (doc_ids, doc_freqs) = (doc.indices(), doc.values());

    // Neither factor depends on the term being scored, so both are computed
    // once up front instead of per match.
    let saturation = k1 + 1.0;
    let length_norm = k1 * (1.0 - b + b * doc_len / avg_doc_len);

    let mut total = 0.0_f32;
    let mut qi = 0_usize;
    let mut di = 0_usize;

    while qi < term_ids.len() && di < doc_ids.len() {
        let order = term_ids[qi].cmp(&doc_ids[di]);
        if order == Ordering::Equal {
            let idf = term_weights[qi];
            let tf = doc_freqs[di];
            total += idf * (tf * saturation) / (tf + length_norm);
            qi += 1;
            di += 1;
        } else if order == Ordering::Less {
            // Query term absent from the document.
            qi += 1;
        } else {
            // Document term not asked for by the query.
            di += 1;
        }
    }

    total
}
#[cfg(test)]
mod tests {
    use super::*;

    // Indices 2 and 5 are shared: 2*4 + 3*6 = 26.
    #[test]
    fn test_sparse_dot() {
        let lhs = SparseVec::new(vec![0, 2, 5], vec![1.0, 2.0, 3.0], 10).unwrap();
        let rhs = SparseVec::new(vec![2, 3, 5], vec![4.0, 5.0, 6.0], 10).unwrap();
        let error = (sparse_dot(&lhs, &rhs) - 26.0).abs();
        assert!(error < 1e-5);
    }

    // No index is shared, so nothing is accumulated.
    #[test]
    fn test_sparse_dot_no_overlap() {
        let lhs = SparseVec::new(vec![0, 1], vec![1.0, 2.0], 10).unwrap();
        let rhs = SparseVec::new(vec![3, 4], vec![3.0, 4.0], 10).unwrap();
        assert_eq!(sparse_dot(&lhs, &rhs), 0.0);
    }

    // Identical vectors are expected to have a similarity of 1.
    #[test]
    fn test_sparse_cosine() {
        let lhs = SparseVec::new(vec![0, 1], vec![3.0, 4.0], 10).unwrap();
        let rhs = SparseVec::new(vec![0, 1], vec![3.0, 4.0], 10).unwrap();
        let error = (sparse_cosine(&lhs, &rhs) - 1.0).abs();
        assert!(error < 1e-5);
    }

    // Vectors on different indices have a zero dot product.
    #[test]
    fn test_sparse_cosine_orthogonal() {
        let lhs = SparseVec::new(vec![0], vec![1.0], 10).unwrap();
        let rhs = SparseVec::new(vec![1], vec![1.0], 10).unwrap();
        assert_eq!(sparse_cosine(&lhs, &rhs), 0.0);
    }

    // Same indices on both sides: sqrt((0-4)^2 + (3-0)^2) = 5.
    #[test]
    fn test_sparse_euclidean() {
        let lhs = SparseVec::new(vec![0, 2], vec![0.0, 3.0], 10).unwrap();
        let rhs = SparseVec::new(vec![0, 2], vec![4.0, 0.0], 10).unwrap();
        let error = (sparse_euclidean(&lhs, &rhs) - 5.0).abs();
        assert!(error < 1e-5);
    }

    // Disjoint indices: each value is compared against zero, sqrt(9 + 16) = 5.
    #[test]
    fn test_sparse_euclidean_different_indices() {
        let lhs = SparseVec::new(vec![0], vec![3.0], 10).unwrap();
        let rhs = SparseVec::new(vec![1], vec![4.0], 10).unwrap();
        let error = (sparse_euclidean(&lhs, &rhs) - 5.0).abs();
        assert!(error < 1e-5);
    }

    // |1 - 4| + |3 - 1| = 5.
    #[test]
    fn test_sparse_manhattan() {
        let lhs = SparseVec::new(vec![0, 2], vec![1.0, 3.0], 10).unwrap();
        let rhs = SparseVec::new(vec![0, 2], vec![4.0, 1.0], 10).unwrap();
        assert_eq!(sparse_manhattan(&lhs, &rhs), 5.0);
    }

    // Smoke test: positive weights and frequencies must give a positive score.
    #[test]
    fn test_sparse_bm25() {
        let terms = SparseVec::new(vec![0, 2], vec![2.0, 3.0], 10).unwrap();
        let document = SparseVec::new(vec![0, 2], vec![1.0, 2.0], 10).unwrap();
        let relevance = sparse_bm25(&terms, &document, 10.0, 10.0, 1.2, 0.75);
        assert!(relevance > 0.0);
    }
}