use proptest::collection::vec;
use proptest::prelude::*;
use std::collections::HashMap;
use uni_sparse_vector::SparseVector;
use uni_sparse_vector::encode::{decode_slice, encode};
use uni_sparse_vector::ops::{l2_norm, prune_top_k, sparse_dot};
/// Strategy yielding `f32` weights from the half-open range `[-1e6, 1e6)`.
fn weight() -> impl Strategy<Value = f32> {
    const LIMIT: f32 = 1.0e6;
    -LIMIT..LIMIT
}
/// Strategy yielding between 0 and 63 `(index, weight)` pairs.
///
/// Indices are drawn independently from `0..2_000`, so a generated list may
/// contain repeated indices and is not ordered in any particular way.
fn pairs() -> impl Strategy<Value = Vec<(u32, f32)>> {
    let index = 0u32..2_000u32;
    let entry = (index, weight());
    vec(entry, 0..64)
}
/// Strategy yielding `SparseVector`s built from [`pairs`] via `SparseVector::from_pairs`.
///
/// # Panics
///
/// Value generation panics if `from_pairs` returns an error for the generated pairs.
fn sparse_vector() -> impl Strategy<Value = SparseVector> {
    pairs().prop_map(|entries| {
        SparseVector::from_pairs(entries).expect("finite weights -> valid")
    })
}
/// Straightforward hash-map dot product used as an oracle for `sparse_dot`.
///
/// Collects the entries of `a` into a map keyed by index, then walks `b` in its
/// iteration order and sums the products of the weights whose index is in both.
fn reference_dot(a: &SparseVector, b: &SparseVector) -> f32 {
    let lhs_weights: HashMap<u32, f32> = a.iter().collect();
    let products = b.iter().filter_map(|(term, rhs)| {
        // Indices that only appear in `b` contribute nothing.
        let lhs = lhs_weights.get(&term)?;
        Some(lhs * rhs)
    });
    products.sum()
}
proptest! {
// `from_pairs` must produce strictly increasing indices, equally long index and
// value arrays, and a result that `SparseVector::new` accepts unchanged.
#[test]
fn from_pairs_yields_valid_sorted_unique(p in pairs()) {
    let built = SparseVector::from_pairs(p).unwrap();
    let (indices, values) = (built.indices(), built.values());
    // Strictly increasing neighbours imply both sortedness and uniqueness.
    for neighbours in indices.windows(2) {
        prop_assert!(neighbours[0] < neighbours[1]);
    }
    prop_assert_eq!(indices.len(), values.len());
    let rebuilt = SparseVector::new(indices.to_vec(), values.to_vec());
    prop_assert!(rebuilt.is_ok());
}
// Decoding an encoded vector must give back a value equal to the original.
#[test]
fn encode_decode_roundtrip(sv in sparse_vector()) {
    let wire = encode(&sv);
    let restored = decode_slice(&wire).unwrap();
    prop_assert_eq!(sv, restored);
}
// The encoding must be exactly 4 bytes plus 8 bytes per stored entry.
// NOTE(review): presumably a 4-byte header followed by one (u32, f32) pair per
// entry; confirm against the encoder.
#[test]
fn encoded_length_is_exact(sv in sparse_vector()) {
    let expected_len = 4 + sv.len() * 8;
    let actual_len = encode(&sv).len();
    prop_assert_eq!(actual_len, expected_len);
}
// Dropping the final byte of an encoding must make decoding fail.
#[test]
fn truncating_any_byte_is_rejected(sv in sparse_vector()) {
    let encoded = encode(&sv);
    // `split_last` yields `None` for an empty buffer, so that case is skipped.
    if let Some((_, shortened)) = encoded.split_last() {
        prop_assert!(decode_slice(shortened).is_err());
    }
}
// `sparse_dot` must agree with the hash-map oracle `reference_dot`.
//
// The tolerance scales with the sum of the absolute products rather than with
// the result itself. Floating-point error in a dot product is proportional to
// that sum, so when large products of opposite sign cancel, a tolerance based
// on `|reference|` shrinks to almost nothing and would reject correct results
// that merely accumulate in a different order or precision.
#[test]
fn dot_matches_reference(a in sparse_vector(), b in sparse_vector()) {
    let merge = sparse_dot(&a, &b);
    let reference = reference_dot(&a, &b);
    // Sum of |a_i * b_i| over shared indices; never smaller than |reference|,
    // so this tolerance is never stricter than the previous one.
    let lhs: HashMap<u32, f32> = a.iter().collect();
    let magnitude: f32 = b
        .iter()
        .filter_map(|(t, w)| lhs.get(&t).map(|aw| (aw * w).abs()))
        .sum();
    let tol = 1e-2 + 1e-4 * magnitude;
    prop_assert!((merge - reference).abs() <= tol,
        "merge={merge} reference={reference}");
}
// Swapping the operands of `sparse_dot` must give the same value; the
// comparison is exact `f32` equality with no tolerance.
#[test]
fn dot_is_commutative(a in sparse_vector(), b in sparse_vector()) {
    let forward = sparse_dot(&a, &b);
    let backward = sparse_dot(&b, &a);
    prop_assert_eq!(forward, backward);
}
// The dot product of a vector with itself must match its squared L2 norm,
// within an absolute slack of 1e-1 plus 1e-4 relative to the squared norm.
#[test]
fn self_dot_equals_l2_norm_squared(sv in sparse_vector()) {
    let norm = l2_norm(&sv);
    let norm_sq = norm.powi(2);
    let dot = sparse_dot(&sv, &sv);
    let allowed = 1e-1 + 1e-4 * norm_sq.abs();
    let gap = (dot - norm_sq).abs();
    prop_assert!(gap <= allowed, "dot={dot} norm_sq={norm_sq}");
}
// `prune_top_k` must keep exactly `min(k, len)` entries, each carrying the
// weight it had in the input, with indices still strictly increasing.
#[test]
fn prune_is_valid_subset(sv in sparse_vector(), k in 0usize..80usize) {
    let pruned = prune_top_k(&sv, k);
    let expected_len = k.min(sv.len());
    prop_assert_eq!(pruned.len(), expected_len);

    // Every surviving entry must exist in the input with an identical weight.
    let source_weights: HashMap<u32, f32> = sv.iter().collect();
    for (term, kept) in pruned.iter() {
        let found = source_weights.get(&term).copied();
        prop_assert_eq!(found, Some(kept));
    }

    // Pruning must not disturb the sorted, duplicate-free index order.
    for neighbours in pruned.indices().windows(2) {
        prop_assert!(neighbours[0] < neighbours[1]);
    }
}
}