/// Computes the dot product of two sparse vectors given as parallel index/value slices.
///
/// `a_values[k]` is the value stored at coordinate `a_indices[k]` (and likewise for `b`).
/// After checking that each index slice has the same length as its value slice, the work is
/// delegated to [`sparse_dot_portable`], which walks both index lists in lock-step and sums
/// the products of values whose indices are equal.
///
/// The lock-step walk only ever moves forward, so both index slices are expected to be
/// sorted in ascending order without duplicates; otherwise shared indices can be skipped.
///
/// Returns `0.0` when either vector is empty or when no index is shared.
///
/// # Panics
///
/// Panics if `a_indices.len() != a_values.len()` or `b_indices.len() != b_values.len()`.
#[inline]
#[must_use]
pub fn sparse_dot(a_indices: &[u32], a_values: &[f32], b_indices: &[u32], b_values: &[f32]) -> f32 {
    // Validate up front so a malformed vector fails with a descriptive message instead of an
    // index-out-of-bounds panic somewhere inside the merge loop.
    assert_eq!(
        a_indices.len(),
        a_values.len(),
        "sparse_dot: a indices/values length mismatch"
    );
    assert_eq!(
        b_indices.len(),
        b_values.len(),
        "sparse_dot: b indices/values length mismatch"
    );
    sparse_dot_portable(a_indices, a_values, b_indices, b_values)
}
/// Dot product of two sparse vectors via a two-cursor merge over their index slices.
///
/// One cursor walks `a_indices`, the other `b_indices`. On each step the cursor pointing at
/// the smaller index advances; when both point at the same index, the product of the
/// corresponding values is added to the running total and both cursors advance. The walk
/// stops as soon as either index slice is exhausted.
///
/// Because the cursors only move forward, every shared index is found only if both index
/// slices are sorted in ascending order.
///
/// No length validation is performed here.
///
/// # Panics
///
/// Panics with an out-of-bounds index if a matching index sits at a position that does not
/// exist in the corresponding value slice.
#[inline]
#[must_use]
pub fn sparse_dot_portable(
    a_indices: &[u32],
    a_values: &[f32],
    b_indices: &[u32],
    b_values: &[f32],
) -> f32 {
    use std::cmp::Ordering;

    let (mut a_pos, mut b_pos) = (0usize, 0usize);
    let mut acc = 0.0f32;

    // `get` yields `None` once a cursor runs off its slice, which ends the merge.
    while let (Some(a_idx), Some(b_idx)) = (a_indices.get(a_pos), b_indices.get(b_pos)) {
        match a_idx.cmp(b_idx) {
            Ordering::Less => a_pos += 1,
            Ordering::Greater => b_pos += 1,
            Ordering::Equal => {
                acc += a_values[a_pos] * b_values[b_pos];
                a_pos += 1;
                b_pos += 1;
            }
        }
    }

    acc
}
/// Scores a multi-token query against a multi-token document.
///
/// Every token is a sparse vector given as an `(indices, values)` pair. For each query token
/// the largest [`sparse_dot`] against any document token is taken, and those per-query maxima
/// are summed.
///
/// Returns `0.0` if either `query_tokens` or `doc_tokens` is empty.
///
/// # Panics
///
/// Panics under the same conditions as [`sparse_dot`], which is called for every
/// query/document token pair.
#[must_use]
pub fn sparse_maxsim(query_tokens: &[(&[u32], &[f32])], doc_tokens: &[(&[u32], &[f32])]) -> f32 {
    let nothing_to_score = query_tokens.is_empty() || doc_tokens.is_empty();
    if nothing_to_score {
        return 0.0;
    }

    // Best similarity of one query token against all document tokens. Seeding with negative
    // infinity lets the first dot product always win the comparison.
    let best_match = |q_idx: &[u32], q_val: &[f32]| -> f32 {
        let mut best = f32::NEG_INFINITY;
        for &(d_idx, d_val) in doc_tokens {
            best = best.max(sparse_dot(q_idx, q_val, d_idx, d_val));
        }
        best
    };

    query_tokens
        .iter()
        .map(|&(q_idx, q_val)| best_match(q_idx, q_val))
        .sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within `1e-6` of `expected`.
    ///
    /// `#[track_caller]` makes a failure point at the calling test instead of this helper.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    /// Interleaved even/odd indices never coincide.
    #[test]
    fn test_sparse_dot_no_overlap() {
        let evens = ([0u32, 2, 4], [1.0f32, 2.0, 3.0]);
        let odds = ([1u32, 3, 5], [4.0f32, 5.0, 6.0]);
        assert_eq!(sparse_dot(&evens.0, &evens.1, &odds.0, &odds.1), 0.0);
    }

    /// Identical index sets: 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_sparse_dot_full_overlap() {
        let shared_idx = [0u32, 1, 2];
        let lhs_val = [1.0f32, 2.0, 3.0];
        let rhs_val = [4.0f32, 5.0, 6.0];
        assert_close(
            sparse_dot(&shared_idx, &lhs_val, &shared_idx, &rhs_val),
            32.0,
        );
    }

    /// Only index 2 is shared: 2 * 5 = 10.
    #[test]
    fn test_sparse_dot_partial_overlap() {
        let lhs = ([0u32, 2, 4], [1.0f32, 2.0, 3.0]);
        let rhs = ([1u32, 2, 3], [4.0f32, 5.0, 6.0]);
        assert_close(sparse_dot(&lhs.0, &lhs.1, &rhs.0, &rhs.1), 10.0);
    }

    /// An empty operand on either side yields zero.
    #[test]
    fn test_sparse_dot_empty() {
        let no_idx: &[u32] = &[];
        let no_val: &[f32] = &[];
        let filled = ([0u32, 1], [1.0f32, 2.0]);
        assert_eq!(sparse_dot(no_idx, no_val, &filled.0, &filled.1), 0.0);
        assert_eq!(sparse_dot(&filled.0, &filled.1, no_idx, no_val), 0.0);
    }

    /// A five-entry vector against a single entry at index 2: 3 * 10 = 30.
    #[test]
    fn test_sparse_dot_different_lengths() {
        let long = ([0u32, 1, 2, 3, 4], [1.0f32, 2.0, 3.0, 4.0, 5.0]);
        let short = ([2u32], [10.0f32]);
        assert_close(sparse_dot(&long.0, &long.1, &short.0, &short.1), 30.0);
    }

    /// q1 best = max(0.5, 5.0) = 5, q2 best = max(4.5, 14.0) = 14, total 19.
    #[test]
    fn test_sparse_maxsim_basic() {
        let (q1_idx, q1_val) = ([0u32, 1], [1.0f32, 2.0]);
        let (q2_idx, q2_val) = ([2u32, 3], [3.0f32, 4.0]);
        let (d1_idx, d1_val) = ([0u32, 2], [0.5f32, 1.5]);
        let (d2_idx, d2_val) = ([1u32, 3], [2.5f32, 3.5]);

        let query_tokens = [(&q1_idx[..], &q1_val[..]), (&q2_idx[..], &q2_val[..])];
        let doc_tokens = [(&d1_idx[..], &d1_val[..]), (&d2_idx[..], &d2_val[..])];

        assert_close(sparse_maxsim(&query_tokens, &doc_tokens), 19.0);
    }

    /// No query tokens means nothing to score.
    #[test]
    fn test_sparse_maxsim_empty_query() {
        let (d_idx, d_val) = ([0u32], [1.0f32]);
        let doc_tokens = [(&d_idx[..], &d_val[..])];
        let query_tokens: [(&[u32], &[f32]); 0] = [];
        assert_eq!(sparse_maxsim(&query_tokens, &doc_tokens), 0.0);
    }

    /// No document tokens means nothing to score.
    #[test]
    fn test_sparse_maxsim_empty_doc() {
        let (q_idx, q_val) = ([0u32], [1.0f32]);
        let query_tokens = [(&q_idx[..], &q_val[..])];
        let doc_tokens: [(&[u32], &[f32]); 0] = [];
        assert_eq!(sparse_maxsim(&query_tokens, &doc_tokens), 0.0);
    }

    /// One shared index: 7 * 3 = 21.
    #[test]
    fn test_sparse_dot_single_element_overlap() {
        let lhs = ([42u32], [7.0f32]);
        let rhs = ([42u32], [3.0f32]);
        assert_close(sparse_dot(&lhs.0, &lhs.1, &rhs.0, &rhs.1), 21.0);
    }

    /// Single entries at different indices do not interact.
    #[test]
    fn test_sparse_dot_single_element_no_overlap() {
        let lhs = ([10u32], [5.0f32]);
        let rhs = ([20u32], [5.0f32]);
        assert_eq!(sparse_dot(&lhs.0, &lhs.1, &rhs.0, &rhs.1), 0.0);
    }

    /// Indices near the top of the `u32` range: 2*4 + 3*5 = 23.
    #[test]
    fn test_sparse_dot_large_index_values() {
        let lhs = ([0u32, 1_000_000, u32::MAX - 1], [1.0f32, 2.0, 3.0]);
        let rhs = ([500_000u32, 1_000_000, u32::MAX - 1], [9.0f32, 4.0, 5.0]);
        assert_close(sparse_dot(&lhs.0, &lhs.1, &rhs.0, &rhs.1), 23.0);
    }

    /// Two empty vectors yield zero.
    #[test]
    fn test_sparse_dot_both_empty() {
        let no_idx: &[u32] = &[];
        let no_val: &[f32] = &[];
        assert_eq!(sparse_dot(no_idx, no_val, no_idx, no_val), 0.0);
    }

    /// Mixed signs: -4 - 10 - 18 = -32.
    #[test]
    fn test_sparse_dot_negative_values() {
        let shared_idx = [0u32, 1, 2];
        let lhs_val = [-1.0f32, 2.0, -3.0];
        let rhs_val = [4.0f32, -5.0, 6.0];
        assert_close(
            sparse_dot(&shared_idx, &lhs_val, &shared_idx, &rhs_val),
            -32.0,
        );
    }

    /// The validating entry point and the portable kernel agree: 2.5*0.5 + 3.5*1.0 = 4.75.
    #[test]
    fn test_sparse_dot_portable_matches_dispatch() {
        let lhs = ([0u32, 3, 7, 15], [1.5f32, 2.5, 3.5, 4.5]);
        let rhs = ([3u32, 7, 20], [0.5f32, 1.0, 2.0]);

        let via_entry_point = sparse_dot(&lhs.0, &lhs.1, &rhs.0, &rhs.1);
        let via_kernel = sparse_dot_portable(&lhs.0, &lhs.1, &rhs.0, &rhs.1);

        assert_close(via_entry_point, via_kernel);
        assert_close(via_kernel, 4.75);
    }

    /// One token per side, sharing index 5: 2 * 3 = 6.
    #[test]
    fn test_sparse_maxsim_single_token_each() {
        let (q_idx, q_val) = ([0u32, 5], [1.0f32, 2.0]);
        let (d_idx, d_val) = ([5u32, 10], [3.0f32, 4.0]);
        let query_tokens = [(&q_idx[..], &q_val[..])];
        let doc_tokens = [(&d_idx[..], &d_val[..])];
        assert_close(sparse_maxsim(&query_tokens, &doc_tokens), 6.0);
    }

    /// A query token with no matching document index contributes a zero maximum.
    #[test]
    fn test_sparse_maxsim_no_overlap_returns_zero_per_query() {
        let (q_idx, q_val) = ([0u32, 1], [1.0f32, 1.0]);
        let (d_idx, d_val) = ([100u32, 200], [1.0f32, 1.0]);
        let query_tokens = [(&q_idx[..], &q_val[..])];
        let doc_tokens = [(&d_idx[..], &d_val[..])];
        assert_eq!(sparse_maxsim(&query_tokens, &doc_tokens), 0.0);
    }
}
// Property-based tests: algebraic properties of `sparse_dot` and `sparse_maxsim` checked
// against randomly generated sparse vectors.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
/// Strategy producing a sparse vector as `(indices, values)`.
///
/// Up to `max_len` indices are drawn from `0..max_idx`, then sorted and deduplicated, so the
/// final vector may be shorter than the number of indices drawn. Exactly one value in
/// `-1000.0..1000.0` is generated per surviving index, so both vectors have equal length.
fn arb_sparse_vec_bounded(
max_len: usize,
max_idx: u32,
) -> impl Strategy<Value = (Vec<u32>, Vec<f32>)> {
prop::collection::vec(0..max_idx, 0..=max_len).prop_flat_map(move |mut indices| {
// Put the indices into ascending, duplicate-free order.
indices.sort_unstable();
indices.dedup();
// The value vector is sized from the deduplicated index count.
let n = indices.len();
prop::collection::vec(-1000.0f32..1000.0f32, n)
// `prop_map` takes an `Fn` closure, so the captured indices are cloned per generated value.
.prop_map(move |values| (indices.clone(), values))
})
}
proptest! {
// Swapping the operands must not change the result (up to a relative tolerance).
#[test]
fn sparse_dot_commutative(
(a_idx, a_val) in arb_sparse_vec_bounded(20, 1000),
(b_idx, b_val) in arb_sparse_vec_bounded(20, 1000)
) {
let ab = sparse_dot(&a_idx, &a_val, &b_idx, &b_val);
let ba = sparse_dot(&b_idx, &b_val, &a_idx, &a_val);
// Relative tolerance of 1e-3, scaled by the larger magnitude but never below 1.0.
let is_equal = (ab - ba).abs() < 1e-3 * ab.abs().max(ba.abs()).max(1.0);
// Matching non-finite results are also accepted as equal.
let both_inf = ab.is_infinite() && ba.is_infinite() && ab.signum() == ba.signum();
let both_nan = ab.is_nan() && ba.is_nan();
prop_assert!(is_equal || both_inf || both_nan,
"ab={}, ba={}", ab, ba);
}
// A vector dotted with itself equals the sum of the squares of its values.
#[test]
fn sparse_dot_self_is_norm_squared(
(idx, val) in arb_sparse_vec_bounded(50, 10000)
) {
let result = sparse_dot(&idx, &val, &idx, &val);
let expected: f32 = val.iter().map(|v| v * v).sum();
// Relative tolerance of 1e-4, never below an absolute 1e-4.
let tolerance = 1e-4 * expected.abs().max(1.0);
prop_assert!(
(result - expected).abs() < tolerance,
"result={}, expected={}, tolerance={}",
result,
expected,
tolerance
);
}
// With at most 20 entries of magnitude below 1000 per side, the result must stay finite.
#[test]
fn sparse_dot_finite_result(
(a_idx, a_val) in arb_sparse_vec_bounded(20, 1000),
(b_idx, b_val) in arb_sparse_vec_bounded(20, 1000)
) {
let result = sparse_dot(&a_idx, &a_val, &b_idx, &b_val);
prop_assert!(result.is_finite(), "result was not finite: {}", result);
}
// Vectors with no index in common have a dot product of exactly zero.
#[test]
fn sparse_dot_disjoint_is_zero(
(idx, val) in arb_sparse_vec_bounded(20, 500)
) {
// Shifting every index by (largest index + 1) moves the second vector entirely past the
// first; `unwrap_or(0)` covers the empty vector.
let shift = idx.iter().max().copied().unwrap_or(0) + 1;
let b_idx: Vec<u32> = idx.iter().map(|i| i + shift).collect();
let result = sparse_dot(&idx, &val, &b_idx, &val);
prop_assert_eq!(result, 0.0);
}
// With only positive values the MaxSim score cannot be negative.
#[test]
fn sparse_maxsim_nonnegative_for_positive_values(
n_query in 1usize..5,
n_doc in 1usize..5
) {
let mut query_tokens = Vec::new();
let mut doc_tokens = Vec::new();
// Token `i` has a single entry: value 1.0 at index `i`.
for i in 0..n_query {
query_tokens.push((vec![i as u32], vec![1.0f32]));
}
for i in 0..n_doc {
doc_tokens.push((vec![i as u32], vec![1.0f32]));
}
// Borrow the owned vectors as the slice pairs that `sparse_maxsim` takes.
let query: Vec<(&[u32], &[f32])> = query_tokens
.iter()
.map(|(idx, val)| (idx.as_slice(), val.as_slice()))
.collect();
let doc: Vec<(&[u32], &[f32])> = doc_tokens
.iter()
.map(|(idx, val)| (idx.as_slice(), val.as_slice()))
.collect();
let result = sparse_maxsim(&query, &doc);
prop_assert!(result >= 0.0, "sparse_maxsim should be non-negative for positive values");
}
}
}