use std::collections::HashMap;
/// A sparse vector of strictly positive `f32` weights keyed by `u32` term id.
///
/// Entries are only stored when their weight is `> 0.0`; a term id that is
/// absent from the map is reported as weight `0.0` by [`SparseVector::get`].
#[derive(Debug, Clone, PartialEq)]
pub struct SparseVector {
    /// Stored `term id -> weight` entries.
    weights: HashMap<u32, f32>,
}

impl Default for SparseVector {
    /// Returns an empty vector, identical to [`SparseVector::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SparseVector {
    /// Creates an empty vector with no stored entries.
    pub fn new() -> Self {
        Self {
            weights: HashMap::new(),
        }
    }

    /// Builds a vector from `weights`, keeping only entries whose weight is
    /// `> 0.0`. Zero, negative and NaN weights are dropped.
    pub fn from_weights(mut weights: HashMap<u32, f32>) -> Self {
        // Filter in place rather than re-hashing every surviving entry.
        weights.retain(|_, weight| *weight > 0.0);
        Self { weights }
    }

    /// Sets the weight of `term_id` to `weight` when `weight > 0.0`.
    ///
    /// An existing entry for `term_id` is replaced, not accumulated. A weight
    /// that is not `> 0.0` is ignored and leaves any existing entry untouched.
    pub fn add(&mut self, term_id: u32, weight: f32) {
        if weight > 0.0 {
            self.weights.insert(term_id, weight);
        }
    }

    /// Returns the dot product of `self` and `other`.
    ///
    /// Only term ids present in both vectors contribute. The vector with fewer
    /// entries is iterated and the other one is probed, so the number of
    /// lookups is bounded by the smaller `nnz`. Products are summed in hash-map
    /// iteration order, which the standard library leaves unspecified.
    pub fn dot(&self, other: &SparseVector) -> f32 {
        let (smaller, larger) = if self.weights.len() <= other.weights.len() {
            (&self.weights, &other.weights)
        } else {
            (&other.weights, &self.weights)
        };
        smaller
            .iter()
            .filter_map(|(term_id, w1)| larger.get(term_id).map(|w2| w1 * w2))
            .sum()
    }

    /// Returns the Euclidean (L2) norm: the square root of the sum of squared
    /// weights. An empty vector yields `0.0`.
    pub fn norm(&self) -> f32 {
        self.weights.values().map(|w| w * w).sum::<f32>().sqrt()
    }

    /// Returns the sum of all stored weights.
    ///
    /// Stored weights are always positive, so this equals the L1 norm.
    pub fn l1_norm(&self) -> f32 {
        self.weights.values().sum()
    }

    /// Returns the number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.weights.len()
    }

    /// Returns the weight stored for `term_id`, or `0.0` when it is absent.
    pub fn get(&self, term_id: u32) -> f32 {
        self.weights.get(&term_id).copied().unwrap_or(0.0)
    }

    /// Iterates over the stored `(term id, weight)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&u32, &f32)> {
        self.weights.iter()
    }

    /// Serializes the vector into a byte buffer.
    ///
    /// Layout: a little-endian `u32` entry count, followed by one 8-byte record
    /// per entry (little-endian `u32` term id, then little-endian `f32`
    /// weight). Records are written in ascending term-id order so the output
    /// does not depend on hash-map iteration order.
    pub fn to_bytes(&self) -> Vec<u8> {
        // Keys are `u32`, so the map holds at most 2^32 entries; this cast can
        // only truncate in the single case where every possible id is present.
        let count = self.weights.len() as u32;

        let mut pairs: Vec<(u32, f32)> = self
            .weights
            .iter()
            .map(|(&term_id, &weight)| (term_id, weight))
            .collect();
        // Keys are unique, so an unstable sort yields the same order as a
        // stable one without the extra allocation.
        pairs.sort_unstable_by_key(|&(term_id, _)| term_id);

        let mut bytes = Vec::with_capacity(4 + pairs.len() * 8);
        bytes.extend_from_slice(&count.to_le_bytes());
        for (term_id, weight) in pairs {
            bytes.extend_from_slice(&term_id.to_le_bytes());
            bytes.extend_from_slice(&weight.to_le_bytes());
        }
        bytes
    }

    /// Deserializes a vector from the layout produced by
    /// [`SparseVector::to_bytes`].
    ///
    /// Returns `None` when `data` is shorter than the 4-byte header or shorter
    /// than the number of records the header declares. Bytes after the declared
    /// records are ignored. Records whose weight is not `> 0.0` (zero, negative
    /// or NaN) are skipped, and when a term id occurs more than once the last
    /// record wins.
    pub fn from_bytes(data: &[u8]) -> Option<Self> {
        const HEADER_LEN: usize = 4;
        const ENTRY_LEN: usize = 8;

        let header = data.get(..HEADER_LEN)?;
        // `u32 -> usize` is lossless wherever `usize` is at least 32 bits wide.
        let count = u32::from_le_bytes([header[0], header[1], header[2], header[3]]) as usize;

        // The count comes from untrusted input: use checked arithmetic so a
        // huge value cannot wrap `usize` on 32-bit targets and slip past the
        // length check below.
        let payload_len = count.checked_mul(ENTRY_LEN)?;
        let payload = data[HEADER_LEN..].get(..payload_len)?;

        // Safe to pre-size: `count` records are known to be present in `data`.
        let mut weights = HashMap::with_capacity(count);
        for entry in payload.chunks_exact(ENTRY_LEN) {
            // `chunks_exact` guarantees 8-byte chunks, so this always matches.
            if let [i0, i1, i2, i3, w0, w1, w2, w3] = *entry {
                let term_id = u32::from_le_bytes([i0, i1, i2, i3]);
                let weight = f32::from_le_bytes([w0, w1, w2, w3]);
                if weight > 0.0 {
                    weights.insert(term_id, weight);
                }
            }
        }
        Some(Self { weights })
    }

    /// Returns a reference to the underlying `term id -> weight` map.
    pub fn weights(&self) -> &HashMap<u32, f32> {
        &self.weights
    }

    /// Returns `true` when the vector has no stored entries.
    pub fn is_empty(&self) -> bool {
        self.weights.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a vector by calling `add` once per `(term id, weight)` pair, in
    /// the order given.
    fn build(entries: &[(u32, f32)]) -> SparseVector {
        let mut sv = SparseVector::new();
        for &(term_id, weight) in entries {
            sv.add(term_id, weight);
        }
        sv
    }

    // A freshly constructed vector holds no entries.
    #[test]
    fn test_new_empty() {
        let sv = SparseVector::new();
        assert_eq!(sv.nnz(), 0);
        assert!(sv.is_empty());
    }

    // A positive weight is stored and can be read back.
    #[test]
    fn test_add_positive_weight() {
        let mut sv = SparseVector::new();
        sv.add(100, 1.5);
        assert_eq!(sv.nnz(), 1);
        assert_eq!(sv.get(100), 1.5);
    }

    // A zero weight is never stored.
    #[test]
    fn test_add_zero_weight_ignored() {
        let mut sv = SparseVector::new();
        sv.add(100, 0.0);
        assert_eq!(sv.nnz(), 0);
    }

    // A negative weight is never stored.
    #[test]
    fn test_add_negative_weight_ignored() {
        let mut sv = SparseVector::new();
        sv.add(100, -1.0);
        assert_eq!(sv.nnz(), 0);
    }

    // `from_weights` keeps only the strictly positive entries.
    #[test]
    fn test_from_weights_filters() {
        let raw: HashMap<u32, f32> = vec![(100, 1.5), (200, 0.0), (300, -0.5)]
            .into_iter()
            .collect();
        let sv = SparseVector::from_weights(raw);
        assert_eq!(sv.nnz(), 1);
        assert_eq!(sv.get(100), 1.5);
    }

    // Only the shared term id (100) contributes: 1.0 * 3.0.
    #[test]
    fn test_dot_product() {
        let lhs = build(&[(100, 1.0), (200, 2.0)]);
        let rhs = build(&[(100, 3.0), (300, 4.0)]);
        assert_eq!(lhs.dot(&rhs), 3.0);
    }

    // Vectors with disjoint term ids have a zero dot product.
    #[test]
    fn test_dot_product_no_overlap() {
        let lhs = build(&[(100, 1.0)]);
        let rhs = build(&[(200, 2.0)]);
        assert_eq!(lhs.dot(&rhs), 0.0);
    }

    // Swapping the operands must not change the result.
    #[test]
    fn test_dot_product_commutative() {
        let lhs = build(&[(100, 1.5), (200, 2.5)]);
        let rhs = build(&[(100, 3.0), (200, 1.0), (300, 5.0)]);
        assert_eq!(lhs.dot(&rhs), rhs.dot(&lhs));
    }

    // Classic 3-4-5 triangle for the Euclidean norm.
    #[test]
    fn test_norm() {
        let sv = build(&[(100, 3.0), (200, 4.0)]);
        assert_eq!(sv.norm(), 5.0);
    }

    // The L1 norm is the plain sum of the weights.
    #[test]
    fn test_l1_norm() {
        let sv = build(&[(100, 1.5), (200, 2.5)]);
        assert_eq!(sv.l1_norm(), 4.0);
    }

    // Serializing and deserializing preserves every entry.
    #[test]
    fn test_serialization_roundtrip() {
        let original = build(&[(100, 1.5), (200, 2.5), (300, 3.5)]);
        let encoded = original.to_bytes();
        let decoded = SparseVector::from_bytes(&encoded).unwrap();
        assert_eq!(original.nnz(), decoded.nnz());
        for &term_id in &[100u32, 200, 300] {
            assert_eq!(original.get(term_id), decoded.get(term_id));
        }
    }

    // An empty vector survives a round trip as an empty vector.
    #[test]
    fn test_serialization_empty() {
        let encoded = SparseVector::new().to_bytes();
        let decoded = SparseVector::from_bytes(&encoded).unwrap();
        assert_eq!(decoded.nnz(), 0);
    }

    // Missing header, and a header that declares more entries than are present.
    #[test]
    fn test_from_bytes_invalid() {
        assert!(SparseVector::from_bytes(&[]).is_none());
        assert!(SparseVector::from_bytes(&[1, 0, 0, 0]).is_none());
    }
}