Skip to main content

qdrant_edge/sparse/common/
sparse_vector.rs

1use std::borrow::Cow;
2use std::hash::Hash;
3
4use crate::common::types::ScoreType;
5use crate::gridstore::Blob;
6use itertools::Itertools;
7use ordered_float::OrderedFloat;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use validator::{Validate, ValidationError, ValidationErrors};
11
12use crate::sparse::common::types::{DimId, DimOffset, DimWeight};
13
14/// Sparse vector structure
15#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize, JsonSchema)]
16#[serde(rename_all = "snake_case")]
17pub struct SparseVector {
18    /// Indices must be unique
19    pub indices: Vec<DimId>,
20    /// Values and indices must be the same length
21    pub values: Vec<DimWeight>,
22}
23
24impl Hash for SparseVector {
25    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
26        let Self { indices, values } = self;
27        indices.hash(state);
28        for &value in values {
29            OrderedFloat(value).hash(state);
30        }
31    }
32}
33
34/// Same as `SparseVector` but with `DimOffset` indices.
35/// Meaning that is uses internal segment-specific indices.
36#[derive(Debug, PartialEq, Clone, Default, Serialize, Deserialize)]
37pub struct RemappedSparseVector {
38    /// indices must be unique
39    pub indices: Vec<DimOffset>,
40    /// values and indices must be the same length
41    pub values: Vec<DimWeight>,
42}
43
/// Sort two parallel slices in lockstep, ordering both by the keys in `indices`.
///
/// Afterwards `indices` is ascending and every value still sits at the same position as the
/// index it was paired with before the call. If `indices` is already strictly increasing the
/// function returns immediately without allocating.
///
/// Pairs are formed with `zip`, so when the slices differ in length only the first
/// `min(indices.len(), values.len())` positions are rearranged.
pub fn double_sort<T: Ord + Copy, V: Copy>(indices: &mut [T], values: &mut [V]) {
    let already_sorted = indices.windows(2).all(|pair| pair[0] < pair[1]);
    if already_sorted {
        return;
    }

    let mut pairs = indices
        .iter()
        .copied()
        .zip(values.iter().copied())
        .collect::<Vec<(T, V)>>();

    // Order the (index, value) pairs by index only.
    pairs.sort_unstable_by_key(|&(key, _)| key);

    let slots = indices.iter_mut().zip(values.iter_mut());
    for ((index_slot, value_slot), (key, value)) in slots.zip(pairs) {
        *index_slot = key;
        *value_slot = value;
    }
}
65
66pub fn score_vectors<T: Ord + Eq>(
67    self_indices: &[T],
68    self_values: &[DimWeight],
69    other_indices: &[T],
70    other_values: &[DimWeight],
71) -> Option<ScoreType> {
72    let mut score = 0.0;
73    // track whether there is any overlap
74    let mut overlap = false;
75    let mut i = 0;
76    let mut j = 0;
77    while i < self_indices.len() && j < other_indices.len() {
78        match self_indices[i].cmp(&other_indices[j]) {
79            std::cmp::Ordering::Less => i += 1,
80            std::cmp::Ordering::Greater => j += 1,
81            std::cmp::Ordering::Equal => {
82                overlap = true;
83                score += self_values[i] * other_values[j];
84                i += 1;
85                j += 1;
86            }
87        }
88    }
89    if overlap { Some(score) } else { None }
90}
91
92impl RemappedSparseVector {
93    pub fn new(indices: Vec<DimId>, values: Vec<DimWeight>) -> Result<Self, ValidationErrors> {
94        let vector = Self { indices, values };
95        vector.validate()?;
96        Ok(vector)
97    }
98
99    pub fn sort_by_indices(&mut self) {
100        double_sort(&mut self.indices, &mut self.values);
101    }
102
103    /// Check if this vector is sorted by indices.
104    pub fn is_sorted(&self) -> bool {
105        self.indices.array_windows().all(|[a, b]| a < b)
106    }
107
108    /// Score this vector against another vector using dot product.
109    /// Warning: Expects both vectors to be sorted by indices.
110    ///
111    /// Return None if the vectors do not overlap.
112    pub fn score(&self, other: &RemappedSparseVector) -> Option<ScoreType> {
113        debug_assert!(self.is_sorted());
114        debug_assert!(other.is_sorted());
115        score_vectors(&self.indices, &self.values, &other.indices, &other.values)
116    }
117
118    /// Returns the number of elements in the vector.
119    pub fn len(&self) -> usize {
120        self.indices.len()
121    }
122
123    pub fn is_empty(&self) -> bool {
124        self.len() == 0
125    }
126}
127
128impl SparseVector {
129    pub fn new(indices: Vec<DimId>, values: Vec<DimWeight>) -> Result<Self, ValidationErrors> {
130        let vector = SparseVector { indices, values };
131        vector.validate()?;
132        Ok(vector)
133    }
134
135    /// Sort this vector by indices.
136    ///
137    /// Sorting is required for scoring and overlap checks.
138    pub fn sort_by_indices(&mut self) {
139        double_sort(&mut self.indices, &mut self.values);
140    }
141
142    /// Check if this vector is sorted by indices.
143    pub fn is_sorted(&self) -> bool {
144        self.indices.windows(2).all(|w| w[0] < w[1])
145    }
146
147    /// Check if this vector is empty.
148    pub fn is_empty(&self) -> bool {
149        self.indices.is_empty() && self.values.is_empty()
150    }
151
152    /// Returns the number of elements in the vector.
153    pub fn len(&self) -> usize {
154        self.indices.len()
155    }
156
157    /// Score this vector against another vector using dot product.
158    /// Warning: Expects both vectors to be sorted by indices.
159    ///
160    /// Return None if the vectors do not overlap.
161    pub fn score(&self, other: &SparseVector) -> Option<ScoreType> {
162        debug_assert!(self.is_sorted());
163        debug_assert!(other.is_sorted());
164        score_vectors(&self.indices, &self.values, &other.indices, &other.values)
165    }
166
167    /// Construct a new vector that is the result of performing all indices-wise operations.
168    /// Automatically sort input vectors if necessary.
169    pub fn combine_aggregate(
170        &self,
171        other: &SparseVector,
172        op: impl Fn(DimWeight, DimWeight) -> DimWeight,
173    ) -> Self {
174        // Copy and sort `self` vector if not already sorted
175        let this: Cow<SparseVector> = if !self.is_sorted() {
176            let mut this = self.clone();
177            this.sort_by_indices();
178            Cow::Owned(this)
179        } else {
180            Cow::Borrowed(self)
181        };
182        assert!(this.is_sorted());
183
184        // Copy and sort `other` vector if not already sorted
185        let cow_other: Cow<SparseVector> = if !other.is_sorted() {
186            let mut other = other.clone();
187            other.sort_by_indices();
188            Cow::Owned(other)
189        } else {
190            Cow::Borrowed(other)
191        };
192        let other = &cow_other;
193        assert!(other.is_sorted());
194
195        let mut result = SparseVector::default();
196        let mut i = 0;
197        let mut j = 0;
198        while i < this.indices.len() && j < other.indices.len() {
199            match this.indices[i].cmp(&other.indices[j]) {
200                std::cmp::Ordering::Less => {
201                    result.indices.push(this.indices[i]);
202                    result.values.push(op(this.values[i], 0.0));
203                    i += 1;
204                }
205                std::cmp::Ordering::Greater => {
206                    result.indices.push(other.indices[j]);
207                    result.values.push(op(0.0, other.values[j]));
208                    j += 1;
209                }
210                std::cmp::Ordering::Equal => {
211                    result.indices.push(this.indices[i]);
212                    result.values.push(op(this.values[i], other.values[j]));
213                    i += 1;
214                    j += 1;
215                }
216            }
217        }
218        while i < this.indices.len() {
219            result.indices.push(this.indices[i]);
220            result.values.push(op(this.values[i], 0.0));
221            i += 1;
222        }
223        while j < other.indices.len() {
224            result.indices.push(other.indices[j]);
225            result.values.push(op(0.0, other.values[j]));
226            j += 1;
227        }
228        debug_assert!(result.is_sorted());
229        debug_assert!(result.validate().is_ok());
230        result
231    }
232
233    /// Create [RemappedSparseVector] from this vector in a naive way. Only suitable for testing.
234    #[cfg(feature = "testing")]
235    pub fn into_remapped(self) -> RemappedSparseVector {
236        RemappedSparseVector {
237            indices: self.indices,
238            values: self.values,
239        }
240    }
241}
242
243impl TryFrom<Vec<(u32, f32)>> for RemappedSparseVector {
244    type Error = ValidationErrors;
245
246    fn try_from(tuples: Vec<(u32, f32)>) -> Result<Self, Self::Error> {
247        let (indices, values): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
248        RemappedSparseVector::new(indices, values)
249    }
250}
251
252impl TryFrom<Vec<(u32, f32)>> for SparseVector {
253    type Error = ValidationErrors;
254
255    fn try_from(tuples: Vec<(u32, f32)>) -> Result<Self, Self::Error> {
256        let (indices, values): (Vec<_>, Vec<_>) = tuples.into_iter().unzip();
257        SparseVector::new(indices, values)
258    }
259}
260
261impl Blob for SparseVector {
262    fn to_bytes(&self) -> Vec<u8> {
263        bincode::serialize(&self).expect("Sparse vector serialization should not fail")
264    }
265
266    fn from_bytes(data: &[u8]) -> Self {
267        bincode::deserialize(data).expect("Sparse vector deserialization should not fail")
268    }
269}
270
#[cfg(test)]
impl<const N: usize> From<[(u32, f32); N]> for SparseVector {
    /// Test helper: builds a validated vector from a literal array of `(index, weight)`
    /// pairs, panicking if validation fails.
    fn from(pairs: [(u32, f32); N]) -> Self {
        Self::try_from(Vec::from(pairs)).unwrap()
    }
}
277
#[cfg(test)]
impl<const N: usize> From<[(u32, f32); N]> for RemappedSparseVector {
    /// Test helper: builds a validated remapped vector from a literal array of
    /// `(index, weight)` pairs, panicking if validation fails.
    fn from(pairs: [(u32, f32); N]) -> Self {
        Self::try_from(Vec::from(pairs)).unwrap()
    }
}
284
285impl Validate for SparseVector {
286    fn validate(&self) -> Result<(), ValidationErrors> {
287        validate_sparse_vector_impl(&self.indices, &self.values)
288    }
289}
290
291impl Validate for RemappedSparseVector {
292    fn validate(&self) -> Result<(), ValidationErrors> {
293        validate_sparse_vector_impl(&self.indices, &self.values)
294    }
295}
296
297pub fn validate_sparse_vector_impl<T: Clone + Eq + Hash>(
298    indices: &[T],
299    values: &[DimWeight],
300) -> Result<(), ValidationErrors> {
301    let mut errors = ValidationErrors::default();
302
303    if indices.len() != values.len() {
304        errors.add(
305            "values",
306            ValidationError::new("must be the same length as indices"),
307        );
308    }
309    if indices.iter().unique().count() != indices.len() {
310        errors.add("indices", ValidationError::new("must be unique"));
311    }
312
313    if errors.is_empty() {
314        Ok(())
315    } else {
316        Err(errors)
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Hash a sparse vector with the standard library's default hasher.
    fn hash_of(vector: &SparseVector) -> u64 {
        use std::hash::Hasher;
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        vector.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_score_aligned_same_size() {
        let v1 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let v2 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(v1.score(&v2), Some(14.0));
    }

    #[test]
    fn test_score_not_aligned_same_size() {
        let v1 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let v2 = RemappedSparseVector::new(vec![2, 3, 4], vec![2.0, 3.0, 4.0]).unwrap();
        assert_eq!(v1.score(&v2), Some(13.0));
    }

    #[test]
    fn test_score_aligned_different_size() {
        let v1 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let v2 = RemappedSparseVector::new(vec![1, 2, 3, 4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(v1.score(&v2), Some(14.0));
    }

    #[test]
    fn test_score_not_aligned_different_size() {
        let v1 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let v2 = RemappedSparseVector::new(vec![2, 3, 4, 5], vec![2.0, 3.0, 4.0, 5.0]).unwrap();
        assert_eq!(v1.score(&v2), Some(13.0));
    }

    #[test]
    fn test_score_no_overlap() {
        let v1 = RemappedSparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let v2 = RemappedSparseVector::new(vec![4, 5, 6], vec![2.0, 3.0, 4.0]).unwrap();
        assert!(v1.score(&v2).is_none());
    }

    #[test]
    fn test_sparse_vector_score() {
        let a = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0, 3.0]).unwrap();
        let b = SparseVector::new(vec![2, 3, 4], vec![2.0, 3.0, 4.0]).unwrap();
        assert_eq!(a.score(&b), Some(13.0));

        let disjoint = SparseVector::new(vec![7, 8], vec![1.0, 1.0]).unwrap();
        assert_eq!(a.score(&disjoint), None);
    }

    #[test]
    fn validation_test() {
        let fully_empty = SparseVector::new(vec![], vec![]);
        assert!(fully_empty.is_ok());
        assert!(fully_empty.unwrap().is_empty());

        let different_length = SparseVector::new(vec![1, 2, 3], vec![1.0, 2.0]);
        assert!(different_length.is_err());

        let not_sorted = SparseVector::new(vec![1, 3, 2], vec![1.0, 2.0, 3.0]);
        assert!(not_sorted.is_ok());

        let not_unique = SparseVector::new(vec![1, 2, 3, 2], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(not_unique.is_err());
    }

    #[test]
    fn try_from_pairs_test() {
        let v: SparseVector = [(2u32, 0.5f32), (1, 1.5)].into();
        assert_eq!(v.indices, vec![2, 1]);
        assert_eq!(v.values, vec![0.5, 1.5]);

        assert!(SparseVector::try_from(vec![(1u32, 1.0f32), (1, 2.0)]).is_err());
        assert!(RemappedSparseVector::try_from(vec![(1u32, 1.0f32), (1, 2.0)]).is_err());
    }

    #[test]
    fn sorting_test() {
        let mut not_sorted = SparseVector::new(vec![1, 3, 2], vec![1.0, 2.0, 3.0]).unwrap();
        assert!(!not_sorted.is_sorted());
        not_sorted.sort_by_indices();
        assert!(not_sorted.is_sorted());
    }

    #[test]
    fn double_sort_keeps_pairs_together() {
        let mut indices = vec![3u32, 1, 2];
        let mut values = vec![30.0f32, 10.0, 20.0];
        double_sort(&mut indices, &mut values);
        assert_eq!(indices, vec![1, 2, 3]);
        assert_eq!(values, vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn hash_is_consistent_with_eq() {
        let a = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
        let b = a.clone();
        assert_eq!(a, b);
        assert_eq!(hash_of(&a), hash_of(&b));
    }

    #[test]
    fn combine_aggregate_test() {
        // Test with missing index
        let a = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
        let b = SparseVector::new(vec![2, 3, 4], vec![2.0, 3.0, 4.0]).unwrap();
        let sum = a.combine_aggregate(&b, |x, y| x + 2.0 * y);
        assert_eq!(sum.indices, vec![1, 2, 3, 4]);
        assert_eq!(sum.values, vec![0.1, 4.2, 6.3, 8.0]);

        // reverse arguments
        let sum = b.combine_aggregate(&a, |x, y| x + 2.0 * y);
        assert_eq!(sum.indices, vec![1, 2, 3, 4]);
        assert_eq!(sum.values, vec![0.2, 2.4, 3.6, 4.0]);

        // Test with non-sorted input
        let a = SparseVector::new(vec![1, 2, 3], vec![0.1, 0.2, 0.3]).unwrap();
        let b = SparseVector::new(vec![4, 2, 3], vec![4.0, 2.0, 3.0]).unwrap();
        let sum = a.combine_aggregate(&b, |x, y| x + 2.0 * y);
        assert_eq!(sum.indices, vec![1, 2, 3, 4]);
        assert_eq!(sum.values, vec![0.1, 4.2, 6.3, 8.0]);
    }

    #[test]
    fn combine_aggregate_with_empty_test() {
        let empty = SparseVector::default();
        let b = SparseVector::new(vec![2, 1], vec![2.0, 1.0]).unwrap();

        let sum = empty.combine_aggregate(&b, |x, y| x + y);
        assert_eq!(sum.indices, vec![1, 2]);
        assert_eq!(sum.values, vec![1.0, 2.0]);

        let sum = b.combine_aggregate(&empty, |x, y| x + y);
        assert_eq!(sum.indices, vec![1, 2]);
        assert_eq!(sum.values, vec![1.0, 2.0]);

        assert!(empty.combine_aggregate(&empty, |x, y| x + y).is_empty());
    }
}