nodedb_types/
sparse_vector.rs1use serde::{Deserialize, Serialize};
9
10#[derive(
15 Debug,
16 Clone,
17 PartialEq,
18 Serialize,
19 Deserialize,
20 zerompk::ToMessagePack,
21 zerompk::FromMessagePack,
22)]
23pub struct SparseVector {
24 entries: Vec<(u32, f32)>,
26}
27
28impl SparseVector {
29 pub fn from_entries(mut entries: Vec<(u32, f32)>) -> Result<Self, SparseVectorError> {
32 for &(_, w) in &entries {
33 if !w.is_finite() {
34 return Err(SparseVectorError::NonFiniteWeight(w));
35 }
36 }
37 entries.sort_by_key(|&(dim, _)| dim);
39 entries.dedup_by_key(|e| e.0);
40 entries.retain(|&(_, w)| w != 0.0);
42 Ok(Self { entries })
43 }
44
45 pub fn empty() -> Self {
47 Self {
48 entries: Vec::new(),
49 }
50 }
51
52 pub fn parse_literal(s: &str) -> Result<Self, SparseVectorError> {
54 let trimmed = s.trim().trim_matches('\'').trim_matches('"');
55 let inner = trimmed
56 .strip_prefix('{')
57 .and_then(|s| s.strip_suffix('}'))
58 .ok_or(SparseVectorError::InvalidLiteral(
59 "expected '{dim: weight, ...}'".into(),
60 ))?;
61
62 if inner.trim().is_empty() {
63 return Ok(Self::empty());
64 }
65
66 let mut entries = Vec::new();
67 for pair in inner.split(',') {
68 let pair = pair.trim();
69 if pair.is_empty() {
70 continue;
71 }
72 let (dim_str, weight_str) = pair.split_once(':').ok_or_else(|| {
73 SparseVectorError::InvalidLiteral(format!("expected 'dim: weight', got '{pair}'"))
74 })?;
75 let dim: u32 = dim_str.trim().parse().map_err(|_| {
76 SparseVectorError::InvalidLiteral(format!("invalid dimension '{}'", dim_str.trim()))
77 })?;
78 let weight: f32 = weight_str.trim().parse().map_err(|_| {
79 SparseVectorError::InvalidLiteral(format!("invalid weight '{}'", weight_str.trim()))
80 })?;
81 entries.push((dim, weight));
82 }
83
84 Self::from_entries(entries)
85 }
86
87 pub fn entries(&self) -> &[(u32, f32)] {
89 &self.entries
90 }
91
92 pub fn nnz(&self) -> usize {
94 self.entries.len()
95 }
96
97 pub fn is_empty(&self) -> bool {
99 self.entries.is_empty()
100 }
101
102 pub fn dot_product(&self, other: &SparseVector) -> f32 {
107 let mut score = 0.0f32;
108 let (mut i, mut j) = (0, 0);
109 let (a, b) = (&self.entries, &other.entries);
110
111 while i < a.len() && j < b.len() {
112 match a[i].0.cmp(&b[j].0) {
113 std::cmp::Ordering::Equal => {
114 score += a[i].1 * b[j].1;
115 i += 1;
116 j += 1;
117 }
118 std::cmp::Ordering::Less => i += 1,
119 std::cmp::Ordering::Greater => j += 1,
120 }
121 }
122
123 score
124 }
125}
126
/// Errors reported while building or parsing a [`SparseVector`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum SparseVectorError {
    /// A weight was NaN or infinite; carries the offending value.
    NonFiniteWeight(f32),
    /// A textual literal could not be parsed; carries a description of what
    /// was wrong with it.
    InvalidLiteral(String),
}
136
137impl std::fmt::Display for SparseVectorError {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 match self {
140 Self::NonFiniteWeight(w) => write!(f, "sparse vector weight must be finite, got {w}"),
141 Self::InvalidLiteral(msg) => write!(f, "invalid sparse vector literal: {msg}"),
142 }
143 }
144}
145
146impl std::error::Error for SparseVectorError {}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_entries_sorts_and_deduplicates() {
        let sv = SparseVector::from_entries(vec![(5, 0.5), (2, 0.3), (5, 0.9)]).unwrap();
        // Dimension 5 appears twice and must collapse to a single entry.
        assert_eq!(sv.nnz(), 2);
        assert_eq!(sv.entries()[0].0, 2);
        assert_eq!(sv.entries()[1].0, 5);
    }

    #[test]
    fn zero_weights_removed() {
        let sv = SparseVector::from_entries(vec![(1, 0.5), (2, 0.0), (3, 0.3)]).unwrap();
        assert_eq!(sv.nnz(), 2);
        assert!(sv.entries().iter().all(|&(_, w)| w != 0.0));
    }

    #[test]
    fn non_finite_rejected() {
        assert!(SparseVector::from_entries(vec![(1, f32::NAN)]).is_err());
        assert!(SparseVector::from_entries(vec![(1, f32::INFINITY)]).is_err());
    }

    #[test]
    fn parse_literal() {
        let sv = SparseVector::parse_literal("'{103: 0.85, 2941: 0.42, 15003: 0.91}'").unwrap();
        assert_eq!(sv.nnz(), 3);
        assert_eq!(sv.entries()[0], (103, 0.85));
        assert_eq!(sv.entries()[1], (2941, 0.42));
        assert_eq!(sv.entries()[2], (15003, 0.91));
    }

    #[test]
    fn parse_empty_literal() {
        let sv = SparseVector::parse_literal("'{}'").unwrap();
        assert!(sv.is_empty());
    }

    #[test]
    fn parse_literal_skips_empty_segments() {
        // A trailing comma leaves an empty segment that the parser ignores.
        let sv = SparseVector::parse_literal("{1: 0.5, }").unwrap();
        assert_eq!(sv.entries(), &[(1, 0.5)]);
    }

    #[test]
    fn parse_literal_rejects_missing_braces() {
        assert!(matches!(
            SparseVector::parse_literal("103: 0.85"),
            Err(SparseVectorError::InvalidLiteral(_))
        ));
    }

    #[test]
    fn parse_literal_rejects_malformed_pairs() {
        // No ':' separator, unparsable dimension, negative dimension, and
        // unparsable weight respectively.
        for literal in ["{103 0.85}", "{abc: 0.5}", "{-1: 0.5}", "{1: heavy}"] {
            assert!(
                matches!(
                    SparseVector::parse_literal(literal),
                    Err(SparseVectorError::InvalidLiteral(_))
                ),
                "literal {literal:?} should be rejected as invalid"
            );
        }
    }

    #[test]
    fn parse_literal_rejects_non_finite_weights() {
        // "NaN" and "inf" parse as f32 values, so the rejection comes from
        // the finiteness check in `from_entries`.
        for literal in ["{1: NaN}", "{1: inf}"] {
            assert!(
                matches!(
                    SparseVector::parse_literal(literal),
                    Err(SparseVectorError::NonFiniteWeight(_))
                ),
                "literal {literal:?} should be rejected as non-finite"
            );
        }
    }

    #[test]
    fn dot_product_basic() {
        let a = SparseVector::from_entries(vec![(1, 2.0), (3, 4.0), (5, 6.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(1, 0.5), (5, 0.5), (7, 1.0)]).unwrap();
        let score = a.dot_product(&b);
        // Overlap on dimensions 1 and 5: 2.0 * 0.5 + 6.0 * 0.5 = 4.0.
        assert!((score - 4.0).abs() < 1e-6);
    }

    #[test]
    fn dot_product_no_overlap() {
        let a = SparseVector::from_entries(vec![(1, 1.0)]).unwrap();
        let b = SparseVector::from_entries(vec![(2, 1.0)]).unwrap();
        assert_eq!(a.dot_product(&b), 0.0);
    }

    #[test]
    fn dot_product_with_empty_is_zero() {
        let a = SparseVector::from_entries(vec![(1, 1.0), (4, 2.0)]).unwrap();
        let none = SparseVector::empty();
        assert_eq!(a.dot_product(&none), 0.0);
        assert_eq!(none.dot_product(&a), 0.0);
    }

    #[test]
    fn serde_roundtrip() {
        let sv = SparseVector::from_entries(vec![(10, 0.5), (20, 0.8)]).unwrap();
        let bytes = zerompk::to_msgpack_vec(&sv).unwrap();
        let restored: SparseVector = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(sv, restored);
    }
}