Skip to main content

nodedb_types/
sparse_vector.rs

1//! Sparse vector type for learned sparse retrieval (SPLADE, SPLADE++).
2//!
3//! Internal representation: sorted `Vec<(u32, f32)>` — (dimension_index, weight).
4//! Only nonzero entries stored — storage proportional to nnz, not total dimensions.
5
6use serde::{Deserialize, Serialize};
7
/// A sparse vector: a set of (dimension_index, weight) pairs.
///
/// Entries are sorted by dimension index for efficient intersection during
/// dot-product scoring. Dimension indices are `u32` (hence non-negative) and
/// weights are `f32`.
///
/// The `entries` field is private, so values built through
/// [`SparseVector::from_entries`], [`SparseVector::parse_literal`] or
/// [`SparseVector::empty`] always satisfy these invariants:
///
/// * entries are sorted by ascending dimension index,
/// * each dimension appears at most once,
/// * every weight is finite and nonzero.
///
/// NOTE(review): the derived `Deserialize` / `FromMessagePack` impls presumably
/// populate `entries` directly without going through `from_entries`, so data
/// decoded from an untrusted source may not uphold the invariants above —
/// confirm against the callers that deserialize this type.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct SparseVector {
    /// Sorted by dimension index (ascending). No duplicate dimensions.
    /// Stored as `(dimension_index, weight)` pairs; only nonzero weights are kept
    /// by the constructors in this file.
    entries: Vec<(u32, f32)>,
}
25
26impl SparseVector {
27    /// Create from unsorted entries. Sorts and deduplicates by dimension.
28    /// Last-writer-wins for duplicate dimensions. Validates all weights are finite.
29    pub fn from_entries(mut entries: Vec<(u32, f32)>) -> Result<Self, SparseVectorError> {
30        for &(_, w) in &entries {
31            if !w.is_finite() {
32                return Err(SparseVectorError::NonFiniteWeight(w));
33            }
34        }
35        // Sort by dimension, deduplicate (last wins).
36        entries.sort_by_key(|&(dim, _)| dim);
37        entries.dedup_by_key(|e| e.0);
38        // Remove zero-weight entries (they contribute nothing to scoring).
39        entries.retain(|&(_, w)| w != 0.0);
40        Ok(Self { entries })
41    }
42
43    /// Create an empty sparse vector.
44    pub fn empty() -> Self {
45        Self {
46            entries: Vec::new(),
47        }
48    }
49
50    /// Parse from literal syntax: `'{103: 0.85, 2941: 0.42, 15003: 0.91}'`.
51    pub fn parse_literal(s: &str) -> Result<Self, SparseVectorError> {
52        let trimmed = s.trim().trim_matches('\'').trim_matches('"');
53        let inner = trimmed
54            .strip_prefix('{')
55            .and_then(|s| s.strip_suffix('}'))
56            .ok_or(SparseVectorError::InvalidLiteral(
57                "expected '{dim: weight, ...}'".into(),
58            ))?;
59
60        if inner.trim().is_empty() {
61            return Ok(Self::empty());
62        }
63
64        let mut entries = Vec::new();
65        for pair in inner.split(',') {
66            let pair = pair.trim();
67            if pair.is_empty() {
68                continue;
69            }
70            let (dim_str, weight_str) = pair.split_once(':').ok_or_else(|| {
71                SparseVectorError::InvalidLiteral(format!("expected 'dim: weight', got '{pair}'"))
72            })?;
73            let dim: u32 = dim_str.trim().parse().map_err(|_| {
74                SparseVectorError::InvalidLiteral(format!("invalid dimension '{}'", dim_str.trim()))
75            })?;
76            let weight: f32 = weight_str.trim().parse().map_err(|_| {
77                SparseVectorError::InvalidLiteral(format!("invalid weight '{}'", weight_str.trim()))
78            })?;
79            entries.push((dim, weight));
80        }
81
82        Self::from_entries(entries)
83    }
84
85    /// Access the sorted entries.
86    pub fn entries(&self) -> &[(u32, f32)] {
87        &self.entries
88    }
89
90    /// Number of nonzero entries.
91    pub fn nnz(&self) -> usize {
92        self.entries.len()
93    }
94
95    /// Whether the vector has no entries.
96    pub fn is_empty(&self) -> bool {
97        self.entries.is_empty()
98    }
99
100    /// Dot product with another sparse vector.
101    ///
102    /// `score = Σ self[d] * other[d]` for dimensions present in both vectors.
103    /// Runs in O(nnz_self + nnz_other) via sorted merge.
104    pub fn dot_product(&self, other: &SparseVector) -> f32 {
105        let mut score = 0.0f32;
106        let (mut i, mut j) = (0, 0);
107        let (a, b) = (&self.entries, &other.entries);
108
109        while i < a.len() && j < b.len() {
110            match a[i].0.cmp(&b[j].0) {
111                std::cmp::Ordering::Equal => {
112                    score += a[i].1 * b[j].1;
113                    i += 1;
114                    j += 1;
115                }
116                std::cmp::Ordering::Less => i += 1,
117                std::cmp::Ordering::Greater => j += 1,
118            }
119        }
120
121        score
122    }
123}
124
/// Failure modes of sparse vector construction and literal parsing.
#[derive(Debug, Clone)]
pub enum SparseVectorError {
    /// A supplied weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f32),
    /// The literal text could not be parsed; carries a description of the problem.
    InvalidLiteral(String),
}

impl std::fmt::Display for SparseVectorError {
    /// Render a one-line, human-readable description of the error.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            SparseVectorError::NonFiniteWeight(weight) => {
                write!(formatter, "sparse vector weight must be finite, got {weight}")
            }
            SparseVectorError::InvalidLiteral(ref detail) => {
                write!(formatter, "invalid sparse vector literal: {detail}")
            }
        }
    }
}

// No underlying source error to expose; the default trait methods suffice.
impl std::error::Error for SparseVectorError {}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsorted input with a repeated dimension comes back ordered by
    /// dimension, with the repeat collapsed into a single entry.
    #[test]
    fn from_entries_sorts_and_deduplicates() {
        let sv = SparseVector::from_entries(vec![(5, 0.5), (2, 0.3), (5, 0.9)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        let dims: Vec<u32> = sv.entries().iter().map(|&(dim, _)| dim).collect();
        assert_eq!(dims, vec![2, 5]);
    }

    /// Entries with a weight of exactly zero are not stored.
    #[test]
    fn zero_weights_removed() {
        let sv = SparseVector::from_entries(vec![(1, 0.5), (2, 0.0), (3, 0.3)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        let has_zero = sv.entries().iter().any(|&(_, weight)| weight == 0.0);
        assert!(!has_zero);
    }

    /// NaN and infinity are both refused at construction time.
    #[test]
    fn non_finite_rejected() {
        for bad in [f32::NAN, f32::INFINITY] {
            let result = SparseVector::from_entries(vec![(1, bad)]);
            assert!(result.is_err());
        }
    }

    /// A quoted literal parses into the expected (dimension, weight) pairs.
    #[test]
    fn parse_literal() {
        let sv = SparseVector::parse_literal("'{103: 0.85, 2941: 0.42, 15003: 0.91}'").unwrap();
        assert_eq!(sv.nnz(), 3);
        let expected: [(u32, f32); 3] = [(103, 0.85), (2941, 0.42), (15003, 0.91)];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(sv.entries()[index], *want);
        }
    }

    /// `{}` is a valid literal for the empty vector.
    #[test]
    fn parse_empty_literal() {
        let sv = SparseVector::parse_literal("'{}'").unwrap();
        assert!(sv.is_empty());
    }

    /// Only dimensions present in both operands contribute to the score.
    #[test]
    fn dot_product_basic() {
        let lhs = SparseVector::from_entries(vec![(1, 2.0), (3, 4.0), (5, 6.0)]).unwrap();
        let rhs = SparseVector::from_entries(vec![(1, 0.5), (5, 0.5), (7, 1.0)]).unwrap();
        // Shared dims: 1 (2.0 * 0.5 = 1.0) and 5 (6.0 * 0.5 = 3.0), total 4.0.
        let difference = (lhs.dot_product(&rhs) - 4.0).abs();
        assert!(difference < 1e-6);
    }

    /// Vectors with disjoint dimensions score exactly zero.
    #[test]
    fn dot_product_no_overlap() {
        let lhs = SparseVector::from_entries(vec![(1, 1.0)]).unwrap();
        let rhs = SparseVector::from_entries(vec![(2, 1.0)]).unwrap();
        assert_eq!(lhs.dot_product(&rhs), 0.0);
    }

    /// Encoding to MessagePack and decoding again yields an equal vector.
    #[test]
    fn serde_roundtrip() {
        let original = SparseVector::from_entries(vec![(10, 0.5), (20, 0.8)]).unwrap();
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: SparseVector = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}