use super::constants::SQUARED_MAGNITUDE_THRESHOLD;
use crate::core::error::{Error, Result, VectorError};
use crate::core::property::MAX_VECTOR_DIMENSIONS;
/// Sparse `f32` vector stored as parallel index/value lists plus a logical
/// dimension.
///
/// Instances built through `SparseVec::new` hold their indices in strictly
/// ascending order, each below `dimension`, and contain no NaN, infinite or
/// zero values.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseVec {
// Positions of the stored entries; `values[k]` belongs to `indices[k]`.
indices: Vec<u32>,
// Stored entry values, one per element of `indices`.
values: Vec<f32>,
// Logical length of the vector (the size of its dense expansion).
dimension: u32,
}
impl SparseVec {
pub fn new(mut indices: Vec<u32>, mut values: Vec<f32>, dimension: u32) -> Result<Self> {
if dimension as usize > MAX_VECTOR_DIMENSIONS {
return Err(VectorError::DimensionTooLarge {
dimension: dimension as usize,
max_allowed: MAX_VECTOR_DIMENSIONS,
}
.into());
}
if indices.len() != values.len() {
return Err(VectorError::DimensionMismatch {
expected: indices.len(),
actual: values.len(),
}
.into());
}
if !indices.is_empty() {
let mut is_sorted = true;
if indices[0] >= dimension {
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: format!(
"Index {} is out of bounds for dimension {}",
indices[0], dimension
),
}));
}
Self::validate_value(values[0])?;
for i in 1..indices.len() {
let prev = indices[i - 1];
let curr = indices[i];
if curr <= prev {
if curr == prev {
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: format!("Duplicate index {} found", curr),
}));
}
is_sorted = false;
break;
}
if curr >= dimension {
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: format!(
"Index {} is out of bounds for dimension {}",
curr, dimension
),
}));
}
Self::validate_value(values[i])?;
}
if is_sorted {
return Ok(Self {
indices,
values,
dimension,
});
}
let mut index_value_pairs: Vec<(u32, f32)> = indices.into_iter().zip(values).collect();
index_value_pairs.sort_by_key(|(idx, _)| *idx);
let mut prev_idx = None;
for (idx, val) in &index_value_pairs {
Self::validate_value(*val)?;
if *idx >= dimension {
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: format!(
"Index {} is out of bounds for dimension {}",
idx, dimension
),
}));
}
if let Some(prev) = prev_idx
&& prev == *idx
{
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: format!("Duplicate index {} found", idx),
}));
}
prev_idx = Some(*idx);
}
let (sorted_indices, sorted_values): (Vec<u32>, Vec<f32>) =
index_value_pairs.into_iter().unzip();
indices = sorted_indices;
values = sorted_values;
}
Ok(Self {
indices,
values,
dimension,
})
}
#[inline(always)]
fn validate_value(val: f32) -> Result<()> {
if val.is_nan() {
return Err(VectorError::ContainsNaN { count: 1 }.into());
}
if val.is_infinite() {
return Err(VectorError::ContainsInfinity { count: 1 }.into());
}
if val == 0.0 {
return Err(Error::Vector(VectorError::InvalidSparseVector {
reason: "Sparse vector contains zero value".to_string(),
}));
}
Ok(())
}
#[inline]
pub fn nnz(&self) -> usize {
self.indices.len()
}
#[inline]
pub fn dimension(&self) -> usize {
self.dimension as usize
}
#[inline]
pub fn indices(&self) -> &[u32] {
&self.indices
}
#[inline]
pub fn values(&self) -> &[f32] {
&self.values
}
pub fn to_dense(&self) -> Vec<f32> {
let mut dense = vec![0.0; self.dimension as usize];
for (&idx, &val) in self.indices.iter().zip(self.values.iter()) {
dense[idx as usize] = val;
}
dense
}
pub fn squared_magnitude(&self) -> f32 {
self.values.iter().map(|v| v * v).sum()
}
#[inline]
pub fn magnitude(&self) -> f32 {
self.squared_magnitude().sqrt()
}
pub fn approx_eq(&self, other: &SparseVec, epsilon: f32) -> bool {
self.dimension == other.dimension
&& self.indices == other.indices
&& self.values.len() == other.values.len()
&& self
.values
.iter()
.zip(other.values.iter())
.all(|(a, b)| (a - b).abs() < epsilon)
}
}
/// Dot product of two sparse vectors of equal dimension.
///
/// Walks both index lists in step and multiplies values only where the
/// indices coincide.
///
/// # Errors
///
/// Returns `VectorError::DimensionMismatch` when the dimensions differ.
pub fn sparse_dot_product(a: &SparseVec, b: &SparseVec) -> Result<f32> {
    use std::cmp::Ordering;

    let (dim_a, dim_b) = (a.dimension(), b.dimension());
    if dim_a != dim_b {
        return Err(VectorError::DimensionMismatch {
            expected: dim_a,
            actual: dim_b,
        }
        .into());
    }

    let (lhs_idx, lhs_val) = (a.indices(), a.values());
    let (rhs_idx, rhs_val) = (b.indices(), b.values());
    let (mut p, mut q) = (0usize, 0usize);
    let mut acc = 0.0f32;

    while p < lhs_idx.len() && q < rhs_idx.len() {
        match lhs_idx[p].cmp(&rhs_idx[q]) {
            Ordering::Equal => {
                acc += lhs_val[p] * rhs_val[q];
                p += 1;
                q += 1;
            }
            // Advance whichever side is behind.
            Ordering::Less => p += 1,
            Ordering::Greater => q += 1,
        }
    }

    Ok(acc)
}
/// Cosine similarity of two sparse vectors, clamped to `[-1.0, 1.0]`.
///
/// Yields `0.0` when either squared magnitude is below
/// `SQUARED_MAGNITUDE_THRESHOLD`.
///
/// # Errors
///
/// Propagates any error from [`sparse_dot_product`].
pub fn sparse_cosine_similarity(a: &SparseVec, b: &SparseVec) -> Result<f32> {
    let numerator = sparse_dot_product(a, b)?;
    let squared_norms = [a.squared_magnitude(), b.squared_magnitude()];

    if squared_norms
        .iter()
        .any(|&norm| norm < SQUARED_MAGNITUDE_THRESHOLD)
    {
        return Ok(0.0);
    }

    let denominator = squared_norms[0].sqrt() * squared_norms[1].sqrt();
    Ok((numerator / denominator).clamp(-1.0, 1.0))
}
/// Squared Euclidean distance between two sparse vectors of equal dimension.
///
/// Walks both index lists in step: an index present on both sides contributes
/// the squared difference, and an index present on one side only contributes
/// the square of that value.
///
/// # Errors
///
/// Returns `VectorError::DimensionMismatch` when the dimensions differ.
pub fn sparse_squared_euclidean_distance(a: &SparseVec, b: &SparseVec) -> Result<f32> {
    let (dim_a, dim_b) = (a.dimension(), b.dimension());
    if dim_a != dim_b {
        return Err(VectorError::DimensionMismatch {
            expected: dim_a,
            actual: dim_b,
        }
        .into());
    }

    let mut lhs = a.indices().iter().zip(a.values()).peekable();
    let mut rhs = b.indices().iter().zip(b.values()).peekable();
    let mut total = 0.0f32;

    while let (Some(&(lhs_index, lhs_value)), Some(&(rhs_index, rhs_value))) =
        (lhs.peek(), rhs.peek())
    {
        if lhs_index == rhs_index {
            let delta = lhs_value - rhs_value;
            total += delta * delta;
            lhs.next();
            rhs.next();
        } else if lhs_index < rhs_index {
            total += lhs_value * lhs_value;
            lhs.next();
        } else {
            total += rhs_value * rhs_value;
            rhs.next();
        }
    }

    // At most one side still has entries; each leftover counts in full.
    for (_, value) in lhs {
        total += value * value;
    }
    for (_, value) in rhs {
        total += value * value;
    }

    Ok(total)
}
/// Euclidean distance between two sparse vectors: the square root of
/// [`sparse_squared_euclidean_distance`].
///
/// # Errors
///
/// Propagates any error from [`sparse_squared_euclidean_distance`].
#[inline]
pub fn sparse_euclidean_distance(a: &SparseVec, b: &SparseVec) -> Result<f32> {
    let squared = sparse_squared_euclidean_distance(a, b)?;
    Ok(squared.sqrt())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each kind of invalid constructor input must be rejected with an error
    /// whose message contains the expected fragment.
    #[test]
    fn test_sparse_vec_new_invalid_inputs() {
        // (name, indices, values, dimension, expected message fragment)
        type Case = (&'static str, Vec<u32>, Vec<f32>, u32, &'static str);

        let cases: Vec<Case> = vec![
            (
                "Dimension too large",
                vec![0],
                vec![1.0],
                (MAX_VECTOR_DIMENSIONS + 1) as u32,
                "exceeds maximum allowed",
            ),
            (
                "Mismatched lengths",
                vec![0, 1],
                vec![1.0],
                10,
                "Vector dimension mismatch",
            ),
            (
                "Index out of bounds (first element)",
                vec![10],
                vec![1.0],
                10,
                "out of bounds",
            ),
            (
                "Index out of bounds (subsequent element)",
                vec![0, 10],
                vec![1.0, 2.0],
                10,
                "out of bounds",
            ),
            (
                "Duplicate index",
                vec![1, 1],
                vec![1.0, 2.0],
                10,
                "Duplicate index",
            ),
            ("Zero value", vec![1], vec![0.0], 10, "zero value"),
            ("NaN value", vec![1], vec![f32::NAN], 10, "NaN"),
            (
                "Infinity value",
                vec![1],
                vec![f32::INFINITY],
                10,
                "infinity",
            ),
            (
                "Negative Infinity value",
                vec![1],
                vec![f32::NEG_INFINITY],
                10,
                "infinity",
            ),
        ];

        for (name, indices, values, dimension, expected_fragment) in cases {
            match SparseVec::new(indices, values, dimension) {
                Ok(_) => panic!("Test '{}' should have failed", name),
                Err(err) => {
                    let message = err.to_string();
                    assert!(
                        message.contains(expected_fragment),
                        "Test '{}' failed with wrong message: '{}', expected to contain '{}'",
                        name,
                        message,
                        expected_fragment
                    );
                }
            }
        }
    }

    /// Unsorted input is accepted and stored in ascending index order, with
    /// the values following their indices.
    #[test]
    fn test_sparse_vec_new_sorts_unsorted_input() {
        let sv = SparseVec::new(vec![5, 1, 3], vec![5.0, 1.0, 3.0], 10)
            .expect("Should construct successfully");
        assert_eq!(sv.indices(), &[1, 3, 5]);
        assert_eq!(sv.values(), &[1.0, 3.0, 5.0]);
    }

    /// The smallest positive subnormal is non-zero and must therefore pass
    /// the zero-value check.
    #[test]
    fn test_sparse_vec_subnormal_value() {
        let smallest_subnormal = f32::from_bits(0x0000_0001);
        let result = SparseVec::new(vec![1], vec![smallest_subnormal], 10);
        assert!(
            result.is_ok(),
            "Subnormal value should be accepted as non-zero"
        );
    }

    /// Every binary operation must refuse operands of different dimension.
    #[test]
    fn test_sparse_vec_operation_dimension_mismatch() {
        let a = SparseVec::new(vec![0], vec![1.0], 5).unwrap();
        let b = SparseVec::new(vec![0], vec![1.0], 10).unwrap();

        let outcomes = [
            (
                sparse_dot_product(&a, &b).is_err(),
                "dot_product should fail on mismatched dimensions",
            ),
            (
                sparse_cosine_similarity(&a, &b).is_err(),
                "cosine_similarity should fail on mismatched dimensions",
            ),
            (
                sparse_euclidean_distance(&a, &b).is_err(),
                "euclidean_distance should fail on mismatched dimensions",
            ),
            (
                sparse_squared_euclidean_distance(&a, &b).is_err(),
                "squared_euclidean_distance should fail on mismatched dimensions",
            ),
        ];

        for (failed, message) in outcomes {
            assert!(failed, "{}", message);
        }
    }
}