nodedb_types/
sparse_vector.rs1use serde::{Deserialize, Serialize};
7
8#[derive(
13 Debug,
14 Clone,
15 PartialEq,
16 Serialize,
17 Deserialize,
18 zerompk::ToMessagePack,
19 zerompk::FromMessagePack,
20)]
21pub struct SparseVector {
22 entries: Vec<(u32, f32)>,
24}
25
26impl SparseVector {
27 pub fn from_entries(mut entries: Vec<(u32, f32)>) -> Result<Self, SparseVectorError> {
30 for &(_, w) in &entries {
31 if !w.is_finite() {
32 return Err(SparseVectorError::NonFiniteWeight(w));
33 }
34 }
35 entries.sort_by_key(|&(dim, _)| dim);
37 entries.dedup_by_key(|e| e.0);
38 entries.retain(|&(_, w)| w != 0.0);
40 Ok(Self { entries })
41 }
42
43 pub fn empty() -> Self {
45 Self {
46 entries: Vec::new(),
47 }
48 }
49
50 pub fn parse_literal(s: &str) -> Result<Self, SparseVectorError> {
52 let trimmed = s.trim().trim_matches('\'').trim_matches('"');
53 let inner = trimmed
54 .strip_prefix('{')
55 .and_then(|s| s.strip_suffix('}'))
56 .ok_or(SparseVectorError::InvalidLiteral(
57 "expected '{dim: weight, ...}'".into(),
58 ))?;
59
60 if inner.trim().is_empty() {
61 return Ok(Self::empty());
62 }
63
64 let mut entries = Vec::new();
65 for pair in inner.split(',') {
66 let pair = pair.trim();
67 if pair.is_empty() {
68 continue;
69 }
70 let (dim_str, weight_str) = pair.split_once(':').ok_or_else(|| {
71 SparseVectorError::InvalidLiteral(format!("expected 'dim: weight', got '{pair}'"))
72 })?;
73 let dim: u32 = dim_str.trim().parse().map_err(|_| {
74 SparseVectorError::InvalidLiteral(format!("invalid dimension '{}'", dim_str.trim()))
75 })?;
76 let weight: f32 = weight_str.trim().parse().map_err(|_| {
77 SparseVectorError::InvalidLiteral(format!("invalid weight '{}'", weight_str.trim()))
78 })?;
79 entries.push((dim, weight));
80 }
81
82 Self::from_entries(entries)
83 }
84
85 pub fn entries(&self) -> &[(u32, f32)] {
87 &self.entries
88 }
89
90 pub fn nnz(&self) -> usize {
92 self.entries.len()
93 }
94
95 pub fn is_empty(&self) -> bool {
97 self.entries.is_empty()
98 }
99
100 pub fn dot_product(&self, other: &SparseVector) -> f32 {
105 let mut score = 0.0f32;
106 let (mut i, mut j) = (0, 0);
107 let (a, b) = (&self.entries, &other.entries);
108
109 while i < a.len() && j < b.len() {
110 match a[i].0.cmp(&b[j].0) {
111 std::cmp::Ordering::Equal => {
112 score += a[i].1 * b[j].1;
113 i += 1;
114 j += 1;
115 }
116 std::cmp::Ordering::Less => i += 1,
117 std::cmp::Ordering::Greater => j += 1,
118 }
119 }
120
121 score
122 }
123}
124
/// Errors produced when building or parsing a `SparseVector`.
#[derive(Debug, Clone, PartialEq)]
pub enum SparseVectorError {
    /// A weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f32),
    /// A sparse-vector literal was malformed; carries a description of the problem.
    InvalidLiteral(String),
}

impl std::fmt::Display for SparseVectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NonFiniteWeight(w) => write!(f, "sparse vector weight must be finite, got {w}"),
            Self::InvalidLiteral(msg) => write!(f, "invalid sparse vector literal: {msg}"),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for SparseVectorError {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_entries_sorts_and_deduplicates() {
        let sv = SparseVector::from_entries(vec![(5, 0.5), (2, 0.3), (5, 0.9)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        // Entries come back ordered by dimension.
        assert_eq!(sv.entries()[0].0, 2);
        assert_eq!(sv.entries()[1].0, 5);
    }

    #[test]
    fn zero_weights_removed() {
        let sv = SparseVector::from_entries(vec![(1, 0.5), (2, 0.0), (3, 0.3)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        assert!(sv.entries().iter().all(|&(_, w)| w != 0.0));
    }

    #[test]
    fn non_finite_rejected() {
        assert!(SparseVector::from_entries(vec![(1, f32::NAN)]).is_err());
        assert!(SparseVector::from_entries(vec![(1, f32::INFINITY)]).is_err());
    }

    #[test]
    fn parse_literal() {
        let sv = SparseVector::parse_literal("'{103: 0.85, 2941: 0.42, 15003: 0.91}'").unwrap();
        assert_eq!(sv.nnz(), 3);
        assert_eq!(sv.entries()[0], (103, 0.85));
        assert_eq!(sv.entries()[1], (2941, 0.42));
        assert_eq!(sv.entries()[2], (15003, 0.91));
    }

    #[test]
    fn parse_empty_literal() {
        let sv = SparseVector::parse_literal("'{}'").unwrap();
        assert!(sv.is_empty());
    }

    #[test]
    fn parse_literal_accepts_double_quotes() {
        let sv = SparseVector::parse_literal("\"{7: 1.5}\"").unwrap();
        assert_eq!(sv.nnz(), 1);
        assert_eq!(sv.entries()[0], (7, 1.5));
    }

    #[test]
    fn parse_literal_ignores_trailing_comma_and_zero_weights() {
        let sv = SparseVector::parse_literal("{1: 0.5, 2: 0.0,}").unwrap();
        assert_eq!(sv.nnz(), 1);
        assert_eq!(sv.entries()[0], (1, 0.5));
    }

    #[test]
    fn parse_literal_rejects_malformed_input() {
        for bad in ["1: 0.5", "{1 0.5}", "{x: 0.5}", "{1: abc}", "{-1: 0.5}"] {
            assert!(
                matches!(
                    SparseVector::parse_literal(bad),
                    Err(SparseVectorError::InvalidLiteral(_))
                ),
                "expected InvalidLiteral for {bad:?}"
            );
        }
    }

    #[test]
    fn parse_literal_rejects_non_finite_weight() {
        assert!(matches!(
            SparseVector::parse_literal("{1: inf}"),
            Err(SparseVectorError::NonFiniteWeight(_))
        ));
    }

    #[test]
    fn dot_product_basic() {
        let a = SparseVector::from_entries(vec![(1, 2.0), (3, 4.0), (5, 6.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(1, 0.5), (5, 0.5), (7, 1.0)]).unwrap();
        let score = a.dot_product(&b);
        // Shared dims 1 and 5: 2.0 * 0.5 + 6.0 * 0.5 = 4.0
        assert!((score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn dot_product_no_overlap() {
        let a = SparseVector::from_entries(vec![(1, 1.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(2, 1.0)]).unwrap();
        assert_eq!(a.dot_product(&b), 0.0);
    }

    #[test]
    fn dot_product_with_empty_is_zero() {
        let a = SparseVector::from_entries(vec![(1, 2.0)]).unwrap();
        let e = SparseVector::empty();
        assert!(e.is_empty());
        assert_eq!(e.nnz(), 0);
        assert_eq!(a.dot_product(&e), 0.0);
        assert_eq!(e.dot_product(&a), 0.0);
    }

    #[test]
    fn serde_roundtrip() {
        let sv = SparseVector::from_entries(vec![(10, 0.5), (20, 0.8)]).unwrap();
        let bytes = zerompk::to_msgpack_vec(&sv).unwrap();
        let restored: SparseVector = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(sv, restored);
    }
}