use crate::sparse::SparseVector;
/// Dot product of two sparse vectors.
///
/// Walks both index lists in a single merge-style pass. Only positions that
/// appear in *both* vectors contribute `a_k * b_k` to the sum. Each step
/// advances at least one side, so the work is linear in the combined number
/// of stored entries.
///
/// The merge assumes each vector's indices are in ascending order
/// (presumably an invariant enforced by `SparseVector` — TODO confirm).
/// Returns `0.0` when the vectors share no index.
#[inline]
#[must_use]
pub fn sparse_dot_product(a: &SparseVector, b: &SparseVector) -> f32 {
    use std::cmp::Ordering;

    let (a_vals, b_vals) = (a.values(), b.values());
    // Each cursor yields (position, &index). Values are looked up by position
    // only when the two indices match.
    let mut left = a.indices().iter().enumerate().peekable();
    let mut right = b.indices().iter().enumerate().peekable();
    let mut total = 0.0f32;

    while let (Some(&(pa, ka)), Some(&(pb, kb))) = (left.peek(), right.peek()) {
        match ka.cmp(kb) {
            // With ascending indices, the smaller one cannot appear later in
            // the other vector, so skip past it.
            Ordering::Less => {
                left.next();
            }
            Ordering::Greater => {
                right.next();
            }
            Ordering::Equal => {
                total += a_vals[pa] * b_vals[pb];
                left.next();
                right.next();
            }
        }
    }
    total
}
/// Euclidean (L2) norm of a sparse vector: `sqrt(Σ vᵢ²)` over the stored values.
///
/// Indices are not consulted. Positions that are not stored contribute
/// nothing to the sum.
#[inline]
#[must_use]
pub fn sparse_norm(v: &SparseVector) -> f32 {
    let sum_of_squares = v
        .values()
        .iter()
        .map(|&component| component * component)
        .sum::<f32>();
    sum_of_squares.sqrt()
}
/// Cosine similarity of two sparse vectors: `dot(a, b) / (‖a‖ · ‖b‖)`.
///
/// Returns `0.0` when the product of the two norms is zero. This covers an
/// empty or all-zero operand, where the ratio would otherwise be `0/0 = NaN`.
///
/// The result is clamped to `[-1.0, 1.0]`. Mathematically the ratio always
/// lies in that range, but `f32` rounding can push it slightly outside.
/// Example: for a vector with values `[1.0, 1.0]`, `‖v‖ · ‖v‖` rounds to
/// `2 - 2⁻²³` while `dot(v, v)` is exactly `2`, so the unclamped ratio is
/// `1.0000001`. Callers passing such a value to `acos` would get `NaN`.
/// A `NaN` ratio (e.g. from `NaN` components) is returned unchanged, because
/// `f32::clamp` propagates `NaN`.
#[inline]
#[must_use]
pub fn sparse_cosine(a: &SparseVector, b: &SparseVector) -> f32 {
    let norm_a = sparse_norm(a);
    let norm_b = sparse_norm(b);
    let denom = norm_a * norm_b;
    if denom == 0.0 {
        return 0.0;
    }
    // Only compute the dot product once we know the ratio is defined.
    let dot = sparse_dot_product(a, b);
    // Rounding in the norms can overshoot the mathematical bound; clamp it.
    (dot / denom).clamp(-1.0, 1.0)
}
// Unit tests for the sparse dot product, L2 norm and cosine similarity.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f32 = 1e-6;

    /// Asserts that `actual` lies strictly within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOL);
    }

    // ---- sparse_dot_product ----

    #[test]
    fn test_dot_product_overlap() {
        // Shared indices 5 and 10 contribute 2.0 * 0.5 + 3.0 * 0.5 = 2.5.
        let lhs = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        let rhs = SparseVector::new(vec![5, 10, 20], vec![0.5, 0.5, 1.0], 100).unwrap();
        assert_close(sparse_dot_product(&lhs, &rhs), 2.5);
    }

    #[test]
    fn test_dot_product_no_overlap() {
        let lhs = SparseVector::new(vec![0, 1, 2], vec![1.0, 1.0, 1.0], 100).unwrap();
        let rhs = SparseVector::new(vec![10, 11, 12], vec![1.0, 1.0, 1.0], 100).unwrap();
        assert_close(sparse_dot_product(&lhs, &rhs), 0.0);
    }

    #[test]
    fn test_dot_product_self() {
        // 1² + 2² + 3² = 14.
        let vec_a = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        assert_close(sparse_dot_product(&vec_a, &vec_a), 14.0);
    }

    #[test]
    fn test_dot_product_exact_match() {
        // 1*4 + 2*5 + 3*6 = 32.
        let lhs = SparseVector::new(vec![0, 1, 2], vec![1.0, 2.0, 3.0], 100).unwrap();
        let rhs = SparseVector::new(vec![0, 1, 2], vec![4.0, 5.0, 6.0], 100).unwrap();
        assert_close(sparse_dot_product(&lhs, &rhs), 32.0);
    }

    #[test]
    fn test_dot_product_singleton() {
        let lhs = SparseVector::singleton(5, 3.0, 100).unwrap();
        let rhs = SparseVector::singleton(5, 4.0, 100).unwrap();
        assert_close(sparse_dot_product(&lhs, &rhs), 12.0);
    }

    #[test]
    fn test_dot_product_commutative() {
        let lhs = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        let rhs = SparseVector::new(vec![5, 10, 20], vec![0.5, 0.5, 1.0], 100).unwrap();
        let forward = sparse_dot_product(&lhs, &rhs);
        let backward = sparse_dot_product(&rhs, &lhs);
        assert_close(forward, backward);
    }

    // ---- sparse_norm ----

    #[test]
    fn test_norm_345() {
        let vec_a = SparseVector::new(vec![0, 1], vec![3.0, 4.0], 100).unwrap();
        assert_close(sparse_norm(&vec_a), 5.0);
    }

    #[test]
    fn test_norm_unit() {
        let vec_a = SparseVector::singleton(0, 1.0, 100).unwrap();
        assert_close(sparse_norm(&vec_a), 1.0);
    }

    #[test]
    fn test_norm_multiple() {
        let vec_a = SparseVector::new(vec![0, 1, 2], vec![1.0, 1.0, 1.0], 100).unwrap();
        assert_close(sparse_norm(&vec_a), 3.0f32.sqrt());
    }

    #[test]
    fn test_norm_equals_sqrt_dot_self() {
        let vec_a = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        let self_dot = sparse_dot_product(&vec_a, &vec_a);
        assert_close(sparse_norm(&vec_a), self_dot.sqrt());
    }

    // ---- sparse_cosine ----

    #[test]
    fn test_cosine_self_is_one() {
        let vec_a = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        assert_close(sparse_cosine(&vec_a, &vec_a), 1.0);
    }

    #[test]
    fn test_cosine_parallel() {
        let lhs = SparseVector::new(vec![0, 1], vec![1.0, 0.0], 100).unwrap();
        let rhs = SparseVector::new(vec![0, 1], vec![5.0, 0.0], 100).unwrap();
        assert_close(sparse_cosine(&lhs, &rhs), 1.0);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let lhs = SparseVector::new(vec![0], vec![1.0], 100).unwrap();
        let rhs = SparseVector::new(vec![1], vec![1.0], 100).unwrap();
        assert_close(sparse_cosine(&lhs, &rhs), 0.0);
    }

    #[test]
    fn test_cosine_antiparallel() {
        let lhs = SparseVector::new(vec![0], vec![1.0], 100).unwrap();
        let rhs = SparseVector::new(vec![0], vec![-1.0], 100).unwrap();
        assert_close(sparse_cosine(&lhs, &rhs), -1.0);
    }

    #[test]
    fn test_cosine_commutative() {
        let lhs = SparseVector::new(vec![0, 5, 10], vec![1.0, 2.0, 3.0], 100).unwrap();
        let rhs = SparseVector::new(vec![5, 10, 20], vec![0.5, 0.5, 1.0], 100).unwrap();
        assert_close(sparse_cosine(&lhs, &rhs), sparse_cosine(&rhs, &lhs));
    }

    #[test]
    fn test_cosine_in_range() {
        // Mixed-sign components should still give a cosine in [-1, 1].
        let lhs = SparseVector::new(vec![0, 5, 10], vec![1.0, -2.0, 3.0], 100).unwrap();
        let rhs = SparseVector::new(vec![5, 10, 20], vec![0.5, -0.5, 1.0], 100).unwrap();
        let similarity = sparse_cosine(&lhs, &rhs);
        assert!((-1.0..=1.0).contains(&similarity));
    }
}