Skip to main content

nodedb_types/
sparse_vector.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Sparse vector type for learned sparse retrieval (SPLADE, SPLADE++).
4//!
5//! Internal representation: sorted `Vec<(u32, f32)>` — (dimension_index, weight).
6//! Only nonzero entries stored — storage proportional to nnz, not total dimensions.
7
8use serde::{Deserialize, Serialize};
9
10/// A sparse vector: a set of (dimension_index, weight) pairs.
11///
12/// Entries are sorted by dimension index for efficient intersection during
13/// dot-product scoring. Dimension indices are non-negative, weights are finite.
14#[derive(
15    Debug,
16    Clone,
17    PartialEq,
18    Serialize,
19    Deserialize,
20    zerompk::ToMessagePack,
21    zerompk::FromMessagePack,
22)]
23pub struct SparseVector {
24    /// Sorted by dimension index (ascending). No duplicate dimensions.
25    entries: Vec<(u32, f32)>,
26}
27
28impl SparseVector {
29    /// Create from unsorted entries. Sorts and deduplicates by dimension.
30    /// Last-writer-wins for duplicate dimensions. Validates all weights are finite.
31    pub fn from_entries(mut entries: Vec<(u32, f32)>) -> Result<Self, SparseVectorError> {
32        for &(_, w) in &entries {
33            if !w.is_finite() {
34                return Err(SparseVectorError::NonFiniteWeight(w));
35            }
36        }
37        // Sort by dimension, deduplicate (last wins).
38        entries.sort_by_key(|&(dim, _)| dim);
39        entries.dedup_by_key(|e| e.0);
40        // Remove zero-weight entries (they contribute nothing to scoring).
41        entries.retain(|&(_, w)| w != 0.0);
42        Ok(Self { entries })
43    }
44
45    /// Create an empty sparse vector.
46    pub fn empty() -> Self {
47        Self {
48            entries: Vec::new(),
49        }
50    }
51
52    /// Parse from literal syntax: `'{103: 0.85, 2941: 0.42, 15003: 0.91}'`.
53    pub fn parse_literal(s: &str) -> Result<Self, SparseVectorError> {
54        let trimmed = s.trim().trim_matches('\'').trim_matches('"');
55        let inner = trimmed
56            .strip_prefix('{')
57            .and_then(|s| s.strip_suffix('}'))
58            .ok_or(SparseVectorError::InvalidLiteral(
59                "expected '{dim: weight, ...}'".into(),
60            ))?;
61
62        if inner.trim().is_empty() {
63            return Ok(Self::empty());
64        }
65
66        let mut entries = Vec::new();
67        for pair in inner.split(',') {
68            let pair = pair.trim();
69            if pair.is_empty() {
70                continue;
71            }
72            let (dim_str, weight_str) = pair.split_once(':').ok_or_else(|| {
73                SparseVectorError::InvalidLiteral(format!("expected 'dim: weight', got '{pair}'"))
74            })?;
75            let dim: u32 = dim_str.trim().parse().map_err(|_| {
76                SparseVectorError::InvalidLiteral(format!("invalid dimension '{}'", dim_str.trim()))
77            })?;
78            let weight: f32 = weight_str.trim().parse().map_err(|_| {
79                SparseVectorError::InvalidLiteral(format!("invalid weight '{}'", weight_str.trim()))
80            })?;
81            entries.push((dim, weight));
82        }
83
84        Self::from_entries(entries)
85    }
86
87    /// Access the sorted entries.
88    pub fn entries(&self) -> &[(u32, f32)] {
89        &self.entries
90    }
91
92    /// Number of nonzero entries.
93    pub fn nnz(&self) -> usize {
94        self.entries.len()
95    }
96
97    /// Whether the vector has no entries.
98    pub fn is_empty(&self) -> bool {
99        self.entries.is_empty()
100    }
101
102    /// Dot product with another sparse vector.
103    ///
104    /// `score = Σ self[d] * other[d]` for dimensions present in both vectors.
105    /// Runs in O(nnz_self + nnz_other) via sorted merge.
106    pub fn dot_product(&self, other: &SparseVector) -> f32 {
107        let mut score = 0.0f32;
108        let (mut i, mut j) = (0, 0);
109        let (a, b) = (&self.entries, &other.entries);
110
111        while i < a.len() && j < b.len() {
112            match a[i].0.cmp(&b[j].0) {
113                std::cmp::Ordering::Equal => {
114                    score += a[i].1 * b[j].1;
115                    i += 1;
116                    j += 1;
117                }
118                std::cmp::Ordering::Less => i += 1,
119                std::cmp::Ordering::Greater => j += 1,
120            }
121        }
122
123        score
124    }
125}
126
/// Failure modes reported while constructing or parsing a sparse vector.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum SparseVectorError {
    /// A weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f32),
    /// The literal text was malformed; carries a description of the problem.
    InvalidLiteral(String),
}
136
137impl std::fmt::Display for SparseVectorError {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        match self {
140            Self::NonFiniteWeight(w) => write!(f, "sparse vector weight must be finite, got {w}"),
141            Self::InvalidLiteral(msg) => write!(f, "invalid sparse vector literal: {msg}"),
142        }
143    }
144}
145
// Marker impl: the empty body relies on the trait's default methods, so no
// underlying source error is reported.
impl std::error::Error for SparseVectorError {}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Unordered input with a repeated dimension comes out sorted, with a
    /// single entry per dimension.
    #[test]
    fn from_entries_sorts_and_deduplicates() {
        let sv = SparseVector::from_entries(vec![(5, 0.5), (2, 0.3), (5, 0.9)]).unwrap();
        let dims: Vec<u32> = sv.entries().iter().map(|&(dim, _)| dim).collect();
        assert_eq!(sv.nnz(), 2);
        assert_eq!(dims, vec![2u32, 5u32]);
    }

    /// Entries with a zero weight are not stored.
    #[test]
    fn zero_weights_removed() {
        let sv = SparseVector::from_entries(vec![(1, 0.5), (2, 0.0), (3, 0.3)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        assert!(!sv.entries().iter().any(|&(_, w)| w == 0.0));
    }

    /// NaN and infinite weights are refused at construction.
    #[test]
    fn non_finite_rejected() {
        for bad in [f32::NAN, f32::INFINITY] {
            assert!(SparseVector::from_entries(vec![(1, bad)]).is_err());
        }
    }

    /// A quoted literal parses into the expected ordered pairs.
    #[test]
    fn parse_literal() {
        let sv = SparseVector::parse_literal("'{103: 0.85, 2941: 0.42, 15003: 0.91}'").unwrap();
        let expected: [(u32, f32); 3] = [(103, 0.85), (2941, 0.42), (15003, 0.91)];
        assert_eq!(sv.nnz(), 3);
        assert_eq!(sv.entries(), expected);
    }

    /// `{}` parses to a vector without entries.
    #[test]
    fn parse_empty_literal() {
        assert!(SparseVector::parse_literal("'{}'").unwrap().is_empty());
    }

    /// Only dimensions present on both sides contribute to the score.
    #[test]
    fn dot_product_basic() {
        let a = SparseVector::from_entries(vec![(1, 2.0), (3, 4.0), (5, 6.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(1, 0.5), (5, 0.5), (7, 1.0)]).unwrap();
        // Dim 1 gives 2.0 * 0.5 = 1.0 and dim 5 gives 6.0 * 0.5 = 3.0; total 4.0.
        let diff = (a.dot_product(&b) - 4.0).abs();
        assert!(diff < 1e-6);
    }

    /// Vectors sharing no dimension score exactly zero.
    #[test]
    fn dot_product_no_overlap() {
        let a = SparseVector::from_entries(vec![(1, 1.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(2, 1.0)]).unwrap();
        assert_eq!(a.dot_product(&b), 0.0);
    }

    /// Encoding with `zerompk` and decoding again yields an equal vector.
    #[test]
    fn serde_roundtrip() {
        let original = SparseVector::from_entries(vec![(10, 0.5), (20, 0.8)]).unwrap();
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: SparseVector = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}
212}