1use std::borrow::Cow;
2use std::hash::Hash;
3
4use crate::common::types::ScoreType;
5use crate::gridstore::Blob;
6use itertools::Itertools;
7use ordered_float::OrderedFloat;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use validator::{Validate, ValidationError, ValidationErrors};
11
12use crate::sparse::common::types::{DimId, DimOffset, DimWeight};
13
// Sparse vector stored as two parallel lists: `values[i]` is the weight of the
// dimension whose id is `indices[i]`.
//
// Field order is significant for any positional serialization of this struct
// (the `Blob` impl in this file serializes it with bincode).
//
// NOTE(review): plain `//` comments are used on purpose. `///` doc comments on a
// type deriving `JsonSchema` are presumably copied into the generated schema as
// descriptions — confirm before converting these to doc comments.
#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub struct SparseVector {
    // Dimension ids. Validation rejects duplicates; it does not require sorted order.
    pub indices: Vec<DimId>,
    // Weights, one per entry of `indices`. Validation rejects a length mismatch.
    pub values: Vec<DimWeight>,
}
23
24impl Hash for SparseVector {
25 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
26 let Self { indices, values } = self;
27 indices.hash(state);
28 for &value in values {
29 OrderedFloat(value).hash(state);
30 }
31 }
32}
33
/// Counterpart of [`SparseVector`] whose indices are [`DimOffset`] values
/// instead of [`DimId`] values.
///
/// NOTE(review): presumably produced by remapping external dimension ids to
/// internal offsets — confirm against the code that builds these vectors.
#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)]
pub struct RemappedSparseVector {
    /// Dimension offsets. Validation rejects duplicates; it does not require sorted order.
    pub indices: Vec<DimOffset>,
    /// Weights, one per entry of `indices`. Validation rejects a length mismatch.
    pub values: Vec<DimWeight>,
}
43
/// Sorts `indices` in ascending order and applies the same permutation to `values`,
/// so that every index keeps the value it was paired with.
///
/// Only the first `min(indices.len(), values.len())` positions take part in the
/// sort; any excess tail of the longer slice is left untouched.
pub fn double_sort<T: Ord + Copy, V: Copy>(indices: &mut [T], values: &mut [V]) {
    // Fast path: strictly increasing indices are already in their final order.
    let strictly_increasing = indices.windows(2).all(|pair| pair[0] < pair[1]);
    if strictly_increasing {
        return;
    }

    // Pair every index with its value so both move together while sorting.
    let mut pairs: Vec<(T, V)> = indices
        .iter()
        .copied()
        .zip(values.iter().copied())
        .collect();

    pairs.sort_unstable_by_key(|pair| pair.0);

    // Write the sorted pairs back into the two parallel slices.
    let slots = indices.iter_mut().zip(values.iter_mut());
    for ((index_slot, value_slot), (index, value)) in slots.zip(pairs) {
        *index_slot = index;
        *value_slot = value;
    }
}
65
66pub fn score_vectors<T: Ord + Eq>(
67 self_indices: &[T],
68 self_values: &[DimWeight],
69 other_indices: &[T],
70 other_values: &[DimWeight],
71) -> Option<ScoreType> {
72 let mut score = 0.0;
73 let mut overlap = false;
75 let mut i = 0;
76 let mut j = 0;
77 while i < self_indices.len() && j < other_indices.len() {
78 match self_indices[i].cmp(&other_indices[j]) {
79 std::cmp::Ordering::Less => i += 1,
80 std::cmp::Ordering::Greater => j += 1,
81 std::cmp::Ordering::Equal => {
82 overlap = true;
83 score += self_values[i] * other_values[j];
84 i += 1;
85 j += 1;
86 }
87 }
88 }
89 if overlap { Some(score) } else { None }
90}
91
92impl RemappedSparseVector {
93 pub fn new(indices: Vec<DimId>, values: Vec<DimWeight>) -> Result<Self, ValidationErrors> {
94 let vector = Self { indices, values };
95 vector.validate()?;
96 Ok(vector)
97 }
98
99 pub fn sort_by_indices(&mut self) {
100 double_sort(&mut self.indices, &mut self.values);
101 }
102
103 pub fn is_sorted(&self) -> bool {
105 self.indices.array_windows().all(|[a, b]| a < b)
106 }
107
108 pub fn score(&self, other: &RemappedSparseVector) -> Option<ScoreType> {
113 debug_assert!(self.is_sorted());
114 debug_assert!(other.is_sorted());
115 score_vectors(&self.indices, &self.values, &other.indices, &other.values)
116 }
117
118 pub fn len(&self) -> usize {
120 self.indices.len()
121 }
122
123 pub fn is_empty(&self) -> bool {
124 self.len() == 0
125 }
126}
127
128impl SparseVector {
129 pub fn new(indices: Vec<DimId>, values: Vec<DimWeight>) -> Result<Self, ValidationErrors> {
130 let vector = SparseVector { indices, values };
131 vector.validate()?;
132 Ok(vector)
133 }
134
135 pub fn sort_by_indices(&mut self) {
139 double_sort(&mut self.indices, &mut self.values);
140 }
141
142 pub fn is_sorted(&self) -> bool {
144 self.indices.windows(2).all(|w| w[0] < w[1])
145 }
146
147 pub fn is_empty(&self) -> bool {
149 self.indices.is_empty() && self.values.is_empty()
150 }
151
152 pub fn len(&self) -> usize {
154 self.indices.len()
155 }
156
157 pub fn score(&self, other: &SparseVector) -> Option<ScoreType> {
162 debug_assert!(self.is_sorted());
163 debug_assert!(other.is_sorted());
164 score_vectors(&self.indices, &self.values, &other.indices, &other.values)
165 }
166
167 pub fn combine_aggregate(
170 &self,
171 other: &SparseVector,
172 op: impl Fn(DimWeight, DimWeight) -> DimWeight,
173 ) -> Self {
174 let this: Cow<SparseVector> = if !self.is_sorted() {
176 let mut this = self.clone();
177 this.sort_by_indices();
178 Cow::Owned(this)
179 } else {
180 Cow::Borrowed(self)
181 };
182 assert!(this.is_sorted());
183
184 let cow_other: Cow<SparseVector> = if !other.is_sorted() {
186 let mut other = other.clone();
187 other.sort_by_indices();
188 Cow::Owned(other)
189 } else {
190 Cow::Borrowed(other)
191 };
192 let other = &cow_other;
193 assert!(other.is_sorted());
194
195 let mut result = SparseVector::default();
196 let mut i = 0;
197 let mut j = 0;
198 while i < this.indices.len() && j < other.indices.len() {
199 match this.indices[i].cmp(&other.indices[j]) {
200 std::cmp::Ordering::Less => {
201 result.indices.push(this.indices[i]);
202 result.values.push(op(this.values[i], 0.0));
203 i += 1;
204 }
205 std::cmp::Ordering::Greater => {
206 result.indices.push(other.indices[j]);
207 result.values.push(op(0.0, other.values[j]));
208 j += 1;
209 }
210 std::cmp::Ordering::Equal => {
211 result.indices.push(this.indices[i]);
212 result.values.push(op(this.values[i], other.values[j]));
213 i += 1;
214 j += 1;
215 }
216 }
217 }
218 while i < this.indices.len() {
219 result.indices.push(this.indices[i]);
220 result.values.push(op(this.values[i], 0.0));
221 i += 1;
222 }
223 while j < other.indices.len() {
224 result.indices.push(other.indices[j]);
225 result.values.push(op(0.0, other.values[j]));
226 j += 1;
227 }
228 debug_assert!(result.is_sorted());
229 debug_assert!(result.validate().is_ok());
230 result
231 }
232
233 #[cfg(feature = "testing")]
235 pub fn into_remapped(self) -> RemappedSparseVector {
236 RemappedSparseVector {
237 indices: self.indices,
238 values: self.values,
239 }
240 }
241}
242
243impl TryFrom<Vec<(u32, f32)>> for RemappedSparseVector {
244 type Error = ValidationErrors;
245
246 fn try_from(tuples: Vec<(u32, f32)>) -> Result<Self, Self::Error> {
247 let (indices, values): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
248 RemappedSparseVector::new(indices, values)
249 }
250}
251
252impl TryFrom<Vec<(u32, f32)>> for SparseVector {
253 type Error = ValidationErrors;
254
255 fn try_from(tuples: Vec<(u32, f32)>) -> Result<Self, Self::Error> {
256 let (indices, values): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
257 SparseVector::new(indices, values)
258 }
259}
260
261impl Blob for SparseVector {
262 fn to_bytes(&self) -> Vec<u8> {
263 bincode::serialize(&self).expect("Sparse vector serialization should not fail")
264 }
265
266 fn from_bytes(data: &[u8]) -> Self {
267 bincode::deserialize(data).expect("Sparse vector deserialization should not fail")
268 }
269}
270
/// Test-only shortcut: builds a [`SparseVector`] from an array of
/// `(index, value)` pairs, panicking when validation fails.
#[cfg(test)]
impl<const N: usize> From<[(u32, f32); N]> for SparseVector {
    fn from(value: [(u32, f32); N]) -> Self {
        let tuples = Vec::from(value);
        Self::try_from(tuples).unwrap()
    }
}
277
/// Test-only shortcut: builds a [`RemappedSparseVector`] from an array of
/// `(index, value)` pairs, panicking when validation fails.
#[cfg(test)]
impl<const N: usize> From<[(u32, f32); N]> for RemappedSparseVector {
    fn from(value: [(u32, f32); N]) -> Self {
        let tuples = Vec::from(value);
        Self::try_from(tuples).unwrap()
    }
}
284
285impl Validate for SparseVector {
286 fn validate(&self) -> Result<(), ValidationErrors> {
287 validate_sparse_vector_impl(&self.indices, &self.values)
288 }
289}
290
291impl Validate for RemappedSparseVector {
292 fn validate(&self) -> Result<(), ValidationErrors> {
293 validate_sparse_vector_impl(&self.indices, &self.values)
294 }
295}
296
297pub fn validate_sparse_vector_impl<T: Clone + Eq + Hash>(
298 indices: &[T],
299 values: &[DimWeight],
300) -> Result<(), ValidationErrors> {
301 let mut errors = ValidationErrors::default();
302
303 if indices.len() != values.len() {
304 errors.add(
305 "values",
306 ValidationError::new("must be the same length as indices"),
307 );
308 }
309 if indices.iter().unique().count() != indices.len() {
310 errors.add("indices", ValidationError::new("must be unique"));
311 }
312
313 if errors.is_empty() {
314 Ok(())
315 } else {
316 Err(errors)
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    // Scoring fixtures are built through the test-only `From<[(u32, f32); N]>`
    // conversions, which validate the vector and panic on invalid input.

    #[test]
    fn test_score_aligned_same_size() {
        let v1 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        let v2 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        assert_eq!(v1.score(&v2), Some(14.0));
    }

    #[test]
    fn test_score_not_aligned_same_size() {
        let v1 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        let v2 = RemappedSparseVector::from([(2, 2.0), (3, 3.0), (4, 4.0)]);
        assert_eq!(v1.score(&v2), Some(13.0));
    }

    #[test]
    fn test_score_aligned_different_size() {
        let v1 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        let v2 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0), (4, 4.0)]);
        assert_eq!(v1.score(&v2), Some(14.0));
    }

    #[test]
    fn test_score_not_aligned_different_size() {
        let v1 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        let v2 = RemappedSparseVector::from([(2, 2.0), (3, 3.0), (4, 4.0), (5, 5.0)]);
        assert_eq!(v1.score(&v2), Some(13.0));
    }

    #[test]
    fn test_score_no_overlap() {
        let v1 = RemappedSparseVector::from([(1, 1.0), (2, 2.0), (3, 3.0)]);
        let v2 = RemappedSparseVector::from([(4, 2.0), (5, 3.0), (6, 4.0)]);
        assert_eq!(v1.score(&v2), None);
    }

    #[test]
    fn validation_test() {
        // An entirely empty vector is valid and reports itself as empty.
        let fully_empty = SparseVector::new(Vec::new(), Vec::new());
        assert!(fully_empty.is_ok());
        assert!(fully_empty.unwrap().is_empty());

        // Index and value lists must have the same length.
        let different_length = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0]);
        assert!(different_length.is_err());

        // Unsorted indices are accepted by validation.
        let not_sorted = SparseVector::new(vec![1, 3, 2], vec![1.0, 2.0, 3.0]);
        assert!(not_sorted.is_ok());

        // Repeated indices are rejected.
        let not_unique = SparseVector::new(vec![1, 2, 3, 2], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(not_unique.is_err());
    }

    #[test]
    fn sorting_test() {
        let mut vector = SparseVector::new(vec![1, 3, 2], vec![1.0, 2.0, 3.0]).unwrap();
        assert!(!vector.is_sorted());

        vector.sort_by_indices();
        assert!(vector.is_sorted());
    }

    #[test]
    fn combine_aggregate_test() {
        let weighted_sum = |x: DimWeight, y: DimWeight| x + 2.0 * y;

        // Both operands sorted.
        let a = SparseVector::from([(1, 0.1), (2, 0.2), (3, 0.3)]);
        let b = SparseVector::from([(2, 2.0), (3, 3.0), (4, 4.0)]);
        let sum = a.combine_aggregate(&b, weighted_sum);
        assert_eq!(sum.indices, [1, 2, 3, 4]);
        assert_eq!(sum.values, [0.1, 4.2, 6.3, 8.0]);

        // Operands swapped: the operator is not symmetric.
        let sum = b.combine_aggregate(&a, weighted_sum);
        assert_eq!(sum.indices, [1, 2, 3, 4]);
        assert_eq!(sum.values, [0.2, 2.4, 3.6, 4.0]);

        // Second operand unsorted: same result as in the sorted case.
        let a = SparseVector::from([(1, 0.1), (2, 0.2), (3, 0.3)]);
        let b = SparseVector::from([(4, 4.0), (2, 2.0), (3, 3.0)]);
        let sum = a.combine_aggregate(&b, weighted_sum);
        assert_eq!(sum.indices, [1, 2, 3, 4]);
        assert_eq!(sum.values, [0.1, 4.2, 6.3, 8.0]);
    }
}