use crate::sparse::error::SparseError;
use crate::sparse::metrics::{sparse_cosine, sparse_dot_product, sparse_norm};
use serde::{Deserialize, Serialize};
/// A sparse vector of logical dimension `dim`, stored as parallel `indices` / `values`.
///
/// Invariants, enforced by [`SparseVector::new`] and, through
/// `#[serde(try_from = "SparseVectorRaw")]`, by deserialization (only the
/// crate-internal `new_unchecked` bypasses them):
/// - there is at least one stored entry;
/// - `indices` is strictly increasing and every index is `< dim`;
/// - `values` has the same length as `indices` and every value is finite.
///
/// Explicit zero values are permitted.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(try_from = "SparseVectorRaw")]
pub struct SparseVector {
    // Field names and declaration order define the serialized form; keep them
    // identical to `SparseVectorRaw` so serialized data deserializes again.
    /// Positions of the stored entries, strictly increasing.
    indices: Vec<u32>,
    /// Value stored at the matching position of `indices`.
    values: Vec<f32>,
    /// Logical dimensionality; every stored index is below it.
    dim: u32,
}
/// Unvalidated wire form of [`SparseVector`].
///
/// Deserialization fills this struct first and then converts it with `TryFrom`,
/// which runs `SparseVector::new`'s validation, so malformed input is rejected
/// instead of producing a vector that violates its invariants.
// Field names and declaration order must match `SparseVector`: the names are the
// serialized keys, and formats that encode structs positionally read fields in
// declaration order.
#[derive(Deserialize)]
struct SparseVectorRaw {
    indices: Vec<u32>,
    values: Vec<f32>,
    dim: u32,
}
impl TryFrom<SparseVectorRaw> for SparseVector {
type Error = SparseError;
fn try_from(raw: SparseVectorRaw) -> Result<Self, Self::Error> {
SparseVector::new(raw.indices, raw.values, raw.dim)
}
}
impl SparseVector {
    /// Checks every invariant a `SparseVector` must uphold.
    ///
    /// The checks run in this order and the first failure is returned:
    ///
    /// 1. `indices` and `values` have equal length, else [`SparseError::LengthMismatch`];
    /// 2. there is at least one entry, else [`SparseError::EmptyVector`];
    /// 3. indices are strictly increasing, else [`SparseError::UnsortedIndices`] for a
    ///    descending pair, or [`SparseError::DuplicateIndex`] carrying the position of
    ///    the second of two equal indices;
    /// 4. position by position: the index is `< dim`, else
    ///    [`SparseError::IndexOutOfBounds`]; then the value is finite, else
    ///    [`SparseError::InvalidValue`] carrying that position.
    fn validate(indices: &[u32], values: &[f32], dim: u32) -> Result<(), SparseError> {
        if indices.len() != values.len() {
            return Err(SparseError::LengthMismatch {
                indices: indices.len(),
                values: values.len(),
            });
        }
        if indices.is_empty() {
            return Err(SparseError::EmptyVector);
        }
        for (pos, pair) in indices.windows(2).enumerate() {
            match pair[0].cmp(&pair[1]) {
                std::cmp::Ordering::Greater => return Err(SparseError::UnsortedIndices),
                // `pair[1]` is the repeated entry; it lives at `pos + 1` in `indices`.
                std::cmp::Ordering::Equal => return Err(SparseError::DuplicateIndex(pos + 1)),
                std::cmp::Ordering::Less => {}
            }
        }
        for (pos, (&idx, value)) in indices.iter().zip(values).enumerate() {
            if idx >= dim {
                return Err(SparseError::IndexOutOfBounds { index: idx, dim });
            }
            if !value.is_finite() {
                return Err(SparseError::InvalidValue(pos));
            }
        }
        Ok(())
    }

    /// Builds a vector from parallel `indices` / `values`, validating all invariants.
    ///
    /// # Errors
    ///
    /// Returns the first violation found: mismatched lengths
    /// ([`SparseError::LengthMismatch`]), no entries ([`SparseError::EmptyVector`]),
    /// unsorted or repeated indices ([`SparseError::UnsortedIndices`],
    /// [`SparseError::DuplicateIndex`]), an index `>= dim`
    /// ([`SparseError::IndexOutOfBounds`]), or a NaN/infinite value
    /// ([`SparseError::InvalidValue`]).
    pub fn new(indices: Vec<u32>, values: Vec<f32>, dim: u32) -> Result<Self, SparseError> {
        Self::validate(&indices, &values, dim)?;
        Ok(Self {
            indices,
            values,
            dim,
        })
    }

    /// Builds a vector without checking any invariant.
    ///
    /// The caller must guarantee everything [`SparseVector::new`] would check;
    /// other methods rely on it (e.g. [`SparseVector::get`] binary-searches the
    /// indices, which is only correct if they are sorted).
    #[doc(hidden)]
    #[must_use]
    pub(crate) fn new_unchecked(indices: Vec<u32>, values: Vec<f32>, dim: u32) -> Self {
        Self {
            indices,
            values,
            dim,
        }
    }

    /// Builds a vector from `(index, value)` pairs supplied in any order.
    ///
    /// The pairs are sorted by index before validation. Any position carried by an
    /// error (`DuplicateIndex`, `InvalidValue`) therefore refers to the sorted
    /// order, not to the caller's slice.
    ///
    /// # Errors
    ///
    /// [`SparseError::EmptyVector`] for an empty slice; otherwise the same errors as
    /// [`SparseVector::new`]. A repeated index yields [`SparseError::DuplicateIndex`].
    pub fn from_pairs(pairs: &[(u32, f32)], dim: u32) -> Result<Self, SparseError> {
        if pairs.is_empty() {
            return Err(SparseError::EmptyVector);
        }
        let mut sorted = pairs.to_vec();
        // Stability is irrelevant: equal keys are always rejected as duplicates, at a
        // position determined by the key values alone, so the relative order of
        // their values can never be observed.
        sorted.sort_unstable_by_key(|&(idx, _)| idx);
        let (indices, values): (Vec<u32>, Vec<f32>) = sorted.into_iter().unzip();
        Self::new(indices, values, dim)
    }

    /// Builds a vector with a single entry `value` at `index`.
    ///
    /// # Errors
    ///
    /// [`SparseError::IndexOutOfBounds`] if `index >= dim`;
    /// [`SparseError::InvalidValue`] if `value` is NaN or infinite.
    pub fn singleton(index: u32, value: f32, dim: u32) -> Result<Self, SparseError> {
        Self::new(vec![index], vec![value], dim)
    }

    /// Positions of the stored entries, in strictly increasing order.
    #[must_use]
    #[inline]
    pub fn indices(&self) -> &[u32] {
        &self.indices
    }

    /// Stored values, parallel to [`SparseVector::indices`].
    #[must_use]
    #[inline]
    pub fn values(&self) -> &[f32] {
        &self.values
    }

    /// Logical dimensionality of the vector.
    #[must_use]
    #[inline]
    pub fn dim(&self) -> u32 {
        self.dim
    }

    /// Number of stored entries (including any explicitly stored zeros).
    #[must_use]
    #[inline]
    pub fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Returns the stored entries as `(index, value)` pairs in ascending index order.
    #[must_use]
    pub fn to_pairs(&self) -> Vec<(u32, f32)> {
        self.indices
            .iter()
            .copied()
            .zip(self.values.iter().copied())
            .collect()
    }

    /// Returns the value stored at `index`, or `None` if `index` has no stored entry.
    ///
    /// Uses a binary search over the (sorted) indices.
    #[must_use]
    pub fn get(&self, index: u32) -> Option<f32> {
        self.indices
            .binary_search(&index)
            .ok()
            .map(|pos| self.values[pos])
    }

    /// Dot product with `other`, as computed by `sparse_dot_product`.
    #[must_use]
    #[inline]
    pub fn dot(&self, other: &SparseVector) -> f32 {
        sparse_dot_product(self, other)
    }

    /// L2 norm as computed by `sparse_norm` (e.g. entries `[3, 4]` give `5`).
    #[must_use]
    #[inline]
    pub fn norm(&self) -> f32 {
        sparse_norm(self)
    }

    /// Cosine similarity with `other`, as computed by `sparse_cosine`.
    #[must_use]
    #[inline]
    pub fn cosine(&self, other: &SparseVector) -> f32 {
        sparse_cosine(self, other)
    }

    /// Returns a copy scaled to unit L2 norm, with the same indices and `dim`.
    ///
    /// # Errors
    ///
    /// [`SparseError::ZeroNorm`] if every stored value is zero.
    pub fn normalize(&self) -> Result<Self, SparseError> {
        // Accumulate in f64 rather than trusting an f32 norm. For finite f32 inputs
        // the squares can neither overflow nor underflow to zero in f64. Without
        // this, a huge value (|v| > ~1.8e19) could drive an f32 norm to +inf and
        // silently zero every output, and a tiny non-zero vector could be rejected
        // as ZeroNorm. Here the norm is 0 exactly when every value is +/-0.
        let norm = self
            .values
            .iter()
            .map(|&v| {
                let v = f64::from(v);
                v * v
            })
            .sum::<f64>()
            .sqrt();
        if norm == 0.0 {
            return Err(SparseError::ZeroNorm);
        }
        // Each quotient lies in [-1, 1] (up to rounding), so narrowing back to f32
        // cannot overflow and the result stays finite, upholding the invariant.
        #[allow(clippy::cast_possible_truncation)]
        let normalized_values: Vec<f32> = self
            .values
            .iter()
            .map(|&v| (f64::from(v) / norm) as f32)
            .collect();
        Ok(Self {
            indices: self.indices.clone(),
            values: normalized_values,
            dim: self.dim,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_valid() {
        assert!(SparseVector::validate(&[0, 5, 10], &[0.1, 0.2, 0.3], 100).is_ok());
    }

    #[test]
    fn test_validate_length_mismatch() {
        let result = SparseVector::validate(&[0, 5], &[0.1], 100);
        assert!(matches!(result, Err(SparseError::LengthMismatch { .. })));
    }

    #[test]
    fn test_validate_empty() {
        let result = SparseVector::validate(&[], &[], 100);
        assert!(matches!(result, Err(SparseError::EmptyVector)));
    }

    #[test]
    fn test_validate_unsorted() {
        let result = SparseVector::validate(&[5, 0, 10], &[0.1, 0.2, 0.3], 100);
        assert!(matches!(result, Err(SparseError::UnsortedIndices)));
    }

    #[test]
    fn test_validate_duplicate() {
        let result = SparseVector::validate(&[0, 5, 5], &[0.1, 0.2, 0.3], 100);
        assert!(matches!(result, Err(SparseError::DuplicateIndex(2))));
    }

    #[test]
    fn test_validate_out_of_bounds() {
        let result = SparseVector::validate(&[0, 100], &[0.1, 0.2], 100);
        assert!(matches!(
            result,
            Err(SparseError::IndexOutOfBounds {
                index: 100,
                dim: 100
            })
        ));
    }

    #[test]
    fn test_validate_nan() {
        let result = SparseVector::validate(&[0, 5], &[0.1, f32::NAN], 100);
        assert!(matches!(result, Err(SparseError::InvalidValue(1))));
    }

    #[test]
    fn test_validate_infinity() {
        let result = SparseVector::validate(&[0, 5], &[f32::INFINITY, 0.2], 100);
        assert!(matches!(result, Err(SparseError::InvalidValue(0))));
    }

    #[test]
    fn test_validate_neg_infinity() {
        let result = SparseVector::validate(&[0], &[f32::NEG_INFINITY], 100);
        assert!(matches!(result, Err(SparseError::InvalidValue(0))));
    }

    #[test]
    fn test_validate_checks_positions_in_order() {
        // Position 0 is in bounds but NaN; position 1 is out of bounds. The
        // per-position scan must report position 0 first.
        let result = SparseVector::validate(&[0, 200], &[f32::NAN, 0.1], 100);
        assert!(matches!(result, Err(SparseError::InvalidValue(0))));
    }

    #[test]
    fn test_new_valid() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100)
            .expect("sorted, in-bounds, finite input must be accepted");
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.dim(), 100);
    }

    #[test]
    fn test_new_invalid() {
        let result = SparseVector::new(vec![10, 5, 0], vec![0.1, 0.2, 0.3], 100);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_pairs_sorts() {
        let sparse = SparseVector::from_pairs(&[(10, 0.3), (0, 0.1), (5, 0.2)], 100)
            .expect("unsorted but otherwise valid pairs must be accepted");
        assert_eq!(sparse.indices(), &[0, 5, 10]);
        assert_eq!(sparse.values(), &[0.1, 0.2, 0.3]);
    }

    #[test]
    fn test_from_pairs_duplicate_fails() {
        let result = SparseVector::from_pairs(&[(5, 0.1), (5, 0.2)], 100);
        assert!(matches!(result, Err(SparseError::DuplicateIndex(_))));
    }

    #[test]
    fn test_from_pairs_duplicate_reports_sorted_position() {
        // Sorted order is [0, 5, 5]; the repeated entry sits at position 2, even
        // though in the caller's slice the duplicates are at positions 0 and 2.
        let result = SparseVector::from_pairs(&[(5, 0.1), (0, 0.3), (5, 0.2)], 100);
        assert!(matches!(result, Err(SparseError::DuplicateIndex(2))));
    }

    #[test]
    fn test_from_pairs_empty_fails() {
        let result = SparseVector::from_pairs(&[], 100);
        assert!(matches!(result, Err(SparseError::EmptyVector)));
    }

    #[test]
    fn test_singleton() {
        let sparse = SparseVector::singleton(42, 1.0, 100).expect("valid singleton");
        assert_eq!(sparse.nnz(), 1);
        assert_eq!(sparse.indices(), &[42]);
        assert_eq!(sparse.values(), &[1.0]);
    }

    #[test]
    fn test_singleton_out_of_bounds() {
        let result = SparseVector::singleton(100, 1.0, 100);
        assert!(matches!(result, Err(SparseError::IndexOutOfBounds { .. })));
    }

    #[test]
    fn test_singleton_nan() {
        let result = SparseVector::singleton(0, f32::NAN, 100);
        assert!(matches!(result, Err(SparseError::InvalidValue(_))));
    }

    #[test]
    fn test_accessors() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        assert_eq!(sparse.indices(), &[0, 5, 10]);
        assert_eq!(sparse.values(), &[0.1, 0.2, 0.3]);
        assert_eq!(sparse.dim(), 100);
        assert_eq!(sparse.nnz(), 3);
    }

    #[test]
    fn test_to_pairs() {
        let sparse = SparseVector::new(vec![0, 5], vec![0.1, 0.2], 100).unwrap();
        assert_eq!(sparse.to_pairs(), vec![(0, 0.1), (5, 0.2)]);
    }

    #[test]
    fn test_get_present() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        assert_eq!(sparse.get(0), Some(0.1));
        assert_eq!(sparse.get(5), Some(0.2));
        assert_eq!(sparse.get(10), Some(0.3));
    }

    #[test]
    fn test_get_absent() {
        let sparse = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        assert_eq!(sparse.get(1), None);
        assert_eq!(sparse.get(99), None);
    }

    #[test]
    fn test_serde_roundtrip() {
        let original = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        let serialized = serde_json::to_string(&original).unwrap();
        let deserialized: SparseVector = serde_json::from_str(&serialized).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_dot_method() {
        let a = SparseVector::new(vec![0, 5], vec![1.0, 2.0], 100).unwrap();
        let b = SparseVector::new(vec![5, 10], vec![3.0, 1.0], 100).unwrap();
        assert!((a.dot(&b) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_method() {
        let v = SparseVector::new(vec![0, 1], vec![3.0, 4.0], 100).unwrap();
        assert!((v.norm() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_method() {
        let a = SparseVector::new(vec![0], vec![1.0], 100).unwrap();
        let b = SparseVector::new(vec![0], vec![2.0], 100).unwrap();
        assert!((a.cosine(&b) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize() {
        let v = SparseVector::new(vec![0, 1], vec![3.0, 4.0], 100).unwrap();
        let normalized = v.normalize().unwrap();
        assert!((normalized.norm() - 1.0).abs() < 1e-6);
        assert!((normalized.values()[0] - 0.6).abs() < 1e-6);
        assert!((normalized.values()[1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_preserves_indices() {
        let v = SparseVector::new(vec![5, 10, 20], vec![1.0, 2.0, 3.0], 100).unwrap();
        let normalized = v.normalize().unwrap();
        assert_eq!(normalized.indices(), &[5, 10, 20]);
        assert_eq!(normalized.dim(), 100);
    }

    #[test]
    fn test_normalize_zero_values() {
        let v = SparseVector::new(vec![0], vec![0.0], 100).unwrap();
        assert_eq!(v.norm(), 0.0);
        let result = v.normalize();
        assert!(
            matches!(result, Err(SparseError::ZeroNorm)),
            "normalize() on all-zero-values vector must return ZeroNorm, got: {:?}",
            result
        );
        let v2 = SparseVector::new(vec![0, 5, 10], vec![0.0, 0.0, 0.0], 100).unwrap();
        assert!(matches!(v2.normalize(), Err(SparseError::ZeroNorm)));
    }

    #[test]
    fn test_deserialize_rejects_unsorted_indices() {
        let json = r#"{"indices":[5,0,10],"values":[0.1,0.2,0.3],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "Deserialization must reject unsorted indices"
        );
    }

    #[test]
    fn test_deserialize_rejects_duplicate_indices() {
        let json = r#"{"indices":[0,5,5],"values":[0.1,0.2,0.3],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "Deserialization must reject duplicate indices"
        );
    }

    #[test]
    fn test_deserialize_rejects_null_values() {
        // JSON has no NaN/Infinity literal (serde_json writes non-finite floats as
        // `null`), so the closest wire-level case is `null`, which serde rejects for
        // an `f32` field before validation runs. Finiteness checking itself is
        // covered by `test_validate_nan` and the `test_validate_*infinity` tests.
        let json = r#"{"indices":[0,5],"values":[0.1,null],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "Deserialization must reject null values"
        );
    }

    #[test]
    fn test_deserialize_rejects_empty_vector() {
        let json = r#"{"indices":[],"values":[],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(result.is_err(), "Deserialization must reject empty vectors");
    }

    #[test]
    fn test_deserialize_rejects_out_of_bounds_indices() {
        let json = r#"{"indices":[0,100],"values":[0.1,0.2],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "Deserialization must reject out-of-bounds indices"
        );
    }

    #[test]
    fn test_deserialize_rejects_length_mismatch() {
        let json = r#"{"indices":[0,5],"values":[0.1],"dim":100}"#;
        let result: Result<SparseVector, _> = serde_json::from_str(json);
        assert!(
            result.is_err(),
            "Deserialization must reject length mismatch"
        );
    }

    #[test]
    fn test_deserialize_valid_json() {
        // Hand-written input (not produced by our own serializer) must be accepted
        // and equal the vector built through `new`.
        let json = r#"{"indices":[0,5,10],"values":[0.1,0.2,0.3],"dim":100}"#;
        let deserialized: SparseVector =
            serde_json::from_str(json).expect("valid JSON must deserialize");
        let expected = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        assert_eq!(deserialized, expected);
    }

    #[test]
    fn test_serialize_field_layout() {
        // Pins the wire format: field names and order come from the struct
        // declaration and must stay compatible with `SparseVectorRaw`.
        let v = SparseVector::new(vec![0, 5, 10], vec![0.1, 0.2, 0.3], 100).unwrap();
        assert_eq!(
            serde_json::to_string(&v).unwrap(),
            r#"{"indices":[0,5,10],"values":[0.1,0.2,0.3],"dim":100}"#
        );
    }
}