use serde::{Deserialize, Serialize};
/// Sparse vector stored as a list of `(dimension, weight)` pairs.
///
/// Values built through `from_entries` (and therefore `parse_literal`) hold
/// their pairs sorted by ascending dimension, with each dimension appearing at
/// most once and every weight finite and non-zero. `dot_product` walks both
/// operands in step and depends on that ordering.
///
/// NOTE(review): the derived `Deserialize` / `FromMessagePack` impls presumably
/// fill `entries` directly without going through `from_entries` — confirm
/// whether decoded data can violate the invariants above.
#[derive(
Debug,
Clone,
PartialEq,
Serialize,
Deserialize,
zerompk::ToMessagePack,
zerompk::FromMessagePack,
)]
pub struct SparseVector {
/// `(dimension index, weight)` pairs; private so construction goes through
/// the validating constructors.
entries: Vec<(u32, f32)>,
}
impl SparseVector {
    /// Builds a sparse vector from `(dimension, weight)` pairs given in any order.
    ///
    /// The pairs are sorted by dimension. When a dimension occurs more than
    /// once, only the occurrence that came first in `entries` is kept. Entries
    /// whose weight compares equal to zero (including `-0.0`) are then dropped.
    ///
    /// # Errors
    ///
    /// Returns [`SparseVectorError::NonFiniteWeight`] carrying the first weight
    /// in `entries` that is NaN or infinite.
    pub fn from_entries(mut entries: Vec<(u32, f32)>) -> Result<Self, SparseVectorError> {
        if let Some(&(_, bad)) = entries.iter().find(|(_, w)| !w.is_finite()) {
            return Err(SparseVectorError::NonFiniteWeight(bad));
        }
        // The sort must stay stable: `dedup_by_key` keeps the first element of
        // each run, so stability is what makes "first occurrence wins" hold.
        entries.sort_by_key(|&(dim, _)| dim);
        entries.dedup_by_key(|e| e.0);
        // Zeros are removed after deduplication, so a dimension whose first
        // occurrence is zero disappears even if a later duplicate is non-zero.
        entries.retain(|&(_, w)| w != 0.0);
        Ok(Self { entries })
    }

    /// Returns a vector with no entries.
    pub fn empty() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Parses a literal of the form `{dim: weight, dim: weight, ...}`.
    ///
    /// Surrounding whitespace and any leading/trailing `'` or `"` characters
    /// are stripped first, and whitespace between the quotes and the braces is
    /// tolerated. Pairs are separated by commas; empty segments (for example
    /// after a trailing comma) are skipped. Each dimension is parsed as `u32`
    /// and each weight as `f32`. The parsed pairs are handed to
    /// [`SparseVector::from_entries`], so the same sorting, deduplication and
    /// zero removal apply.
    ///
    /// # Errors
    ///
    /// Returns [`SparseVectorError::InvalidLiteral`] when the braces are
    /// missing, a pair has no `:`, or a dimension/weight does not parse.
    /// Returns [`SparseVectorError::NonFiniteWeight`] when a weight parses to
    /// a non-finite value.
    pub fn parse_literal(s: &str) -> Result<Self, SparseVectorError> {
        // The final `trim` accepts padding that sits inside the quotes,
        // e.g. `' {1: 0.5} '`.
        let trimmed = s.trim().trim_matches('\'').trim_matches('"').trim();
        let inner = trimmed
            .strip_prefix('{')
            .and_then(|rest| rest.strip_suffix('}'))
            // Built lazily so the success path does not allocate the message.
            .ok_or_else(|| {
                SparseVectorError::InvalidLiteral("expected '{dim: weight, ...}'".into())
            })?;
        if inner.trim().is_empty() {
            return Ok(Self::empty());
        }

        let entries = inner
            .split(',')
            .map(str::trim)
            // Empty segments (e.g. from a trailing comma) are ignored.
            .filter(|pair| !pair.is_empty())
            .map(|pair| -> Result<(u32, f32), SparseVectorError> {
                let (dim_str, weight_str) = pair.split_once(':').ok_or_else(|| {
                    SparseVectorError::InvalidLiteral(format!(
                        "expected 'dim: weight', got '{pair}'"
                    ))
                })?;
                let dim_str = dim_str.trim();
                let dim: u32 = dim_str.parse().map_err(|_| {
                    SparseVectorError::InvalidLiteral(format!("invalid dimension '{dim_str}'"))
                })?;
                let weight_str = weight_str.trim();
                let weight: f32 = weight_str.parse().map_err(|_| {
                    SparseVectorError::InvalidLiteral(format!("invalid weight '{weight_str}'"))
                })?;
                Ok((dim, weight))
            })
            .collect::<Result<Vec<_>, SparseVectorError>>()?;

        Self::from_entries(entries)
    }

    /// Returns the stored `(dimension, weight)` pairs.
    pub fn entries(&self) -> &[(u32, f32)] {
        &self.entries
    }

    /// Returns the number of stored entries.
    pub fn nnz(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the vector stores no entries.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Computes the dot product of `self` and `other`.
    ///
    /// Both entry lists are walked in step, advancing whichever side has the
    /// smaller dimension and multiplying weights where the dimensions match.
    /// This relies on both lists being sorted by ascending dimension, which
    /// `from_entries` establishes.
    pub fn dot_product(&self, other: &SparseVector) -> f32 {
        use std::cmp::Ordering;

        let (lhs, rhs) = (self.entries.as_slice(), other.entries.as_slice());
        let (mut i, mut j) = (0, 0);
        let mut score = 0.0f32;
        // The loop ends as soon as either side is exhausted: the remaining
        // entries on the other side have no partner to multiply with.
        while let (Some(&(l_dim, l_weight)), Some(&(r_dim, r_weight))) = (lhs.get(i), rhs.get(j)) {
            match l_dim.cmp(&r_dim) {
                Ordering::Equal => {
                    score += l_weight * r_weight;
                    i += 1;
                    j += 1;
                }
                Ordering::Less => i += 1,
                Ordering::Greater => j += 1,
            }
        }
        score
    }
}
/// Errors produced while building or parsing a [`SparseVector`].
///
/// `PartialEq` is derived so callers and tests can compare errors directly.
/// Because the payload is an `f32`, a `NonFiniteWeight` holding NaN never
/// compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseVectorError {
    /// A weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f32),
    /// A textual literal could not be parsed; carries a description of the problem.
    InvalidLiteral(String),
}
/// Human-readable messages for each [`SparseVectorError`] variant.
impl std::fmt::Display for SparseVectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Reports the rejected weight value.
            Self::NonFiniteWeight(w) => {
                f.write_fmt(format_args!("sparse vector weight must be finite, got {w}"))
            }
            // Prefixes the stored description of the parse failure.
            Self::InvalidLiteral(msg) => {
                f.write_fmt(format_args!("invalid sparse vector literal: {msg}"))
            }
        }
    }
}
// Marker impl: only the trait's default methods are used, so `source()`
// reports no underlying cause for either variant.
impl std::error::Error for SparseVectorError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsorted input with a repeated dimension yields one entry per
    /// dimension, ordered by dimension.
    #[test]
    fn from_entries_sorts_and_deduplicates() {
        let input = vec![(5, 0.5), (2, 0.3), (5, 0.9)];
        let vector = SparseVector::from_entries(input).unwrap();
        assert_eq!(vector.nnz(), 2);
        let dims: Vec<u32> = vector.entries().iter().map(|&(dim, _)| dim).collect();
        assert_eq!(dims, [2, 5]);
    }

    /// Entries with a zero weight do not survive construction.
    #[test]
    fn zero_weights_removed() {
        let input = vec![(1, 0.5), (2, 0.0), (3, 0.3)];
        let vector = SparseVector::from_entries(input).unwrap();
        assert_eq!(vector.nnz(), 2);
        assert!(!vector.entries().iter().any(|&(_, weight)| weight == 0.0));
    }

    /// NaN and infinity are both refused by the constructor.
    #[test]
    fn non_finite_rejected() {
        let nan_result = SparseVector::from_entries(vec![(1, f32::NAN)]);
        assert!(nan_result.is_err());
        let inf_result = SparseVector::from_entries(vec![(1, f32::INFINITY)]);
        assert!(inf_result.is_err());
    }

    /// A quoted literal parses into the expected pairs.
    #[test]
    fn parse_literal() {
        let literal = "'{103: 0.85, 2941: 0.42, 15003: 0.91}'";
        let vector = SparseVector::parse_literal(literal).unwrap();
        assert_eq!(vector.nnz(), 3);
        let expected: [(u32, f32); 3] = [(103, 0.85), (2941, 0.42), (15003, 0.91)];
        assert_eq!(vector.entries(), expected);
    }

    /// An empty pair of braces parses to an empty vector.
    #[test]
    fn parse_empty_literal() {
        let parsed = SparseVector::parse_literal("'{}'");
        assert!(parsed.unwrap().is_empty());
    }

    /// Only dimensions present on both sides contribute to the score.
    #[test]
    fn dot_product_basic() {
        let lhs = SparseVector::from_entries(vec![(1, 2.0), (3, 4.0), (5, 6.0)]).unwrap();
        let rhs = SparseVector::from_entries(vec![(1, 0.5), (5, 0.5), (7, 1.0)]).unwrap();
        let error = (lhs.dot_product(&rhs) - 4.0).abs();
        assert!(error < 1e-6);
    }

    /// Vectors with no shared dimension score zero.
    #[test]
    fn dot_product_no_overlap() {
        let lhs = SparseVector::from_entries(vec![(1, 1.0)]).unwrap();
        let rhs = SparseVector::from_entries(vec![(2, 1.0)]).unwrap();
        let score = lhs.dot_product(&rhs);
        assert_eq!(score, 0.0);
    }

    /// Encoding to MessagePack and decoding again returns an equal value.
    #[test]
    fn serde_roundtrip() {
        let original = SparseVector::from_entries(vec![(10, 0.5), (20, 0.8)]).unwrap();
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: SparseVector = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}