Skip to main content

diskann_quantization/multi_vector/distance/
max_sim.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! MaxSim and Chamfer distance types for multi-vector representations.
5
6use thiserror::Error;
7
8/// Error type for [`MaxSim`] operations.
9#[derive(Clone, Debug, Copy, Error)]
10pub enum MaxSimError {
11    #[error("Trying to access score in index {0} for output of size {1}")]
12    IndexOutOfBounds(usize, usize),
13    #[error("Scores buffer length cannot be 0")]
14    BufferLengthIsZero,
15    #[error("Invalid buffer length {0} for query size {0}")]
16    InvalidBufferLength(usize, usize),
17}
18
19////////////
20// MaxSim //
21////////////
22
/// Per-query-vector maximum-similarity scores against a set of document vectors.
///
/// Each query vector `qᵢ` gets one score: its best match among the document
/// vectors. Similarity is expressed as a negated inner product, so the best
/// match is the minimum:
///
/// ```text
/// scores[i] = minⱼ -IP(qᵢ, dⱼ)
/// ```
///
/// The computation itself is provided through `DistanceFnMut` implementations
/// for various matrix types
/// (e.g., [`MatRef<Standard<f32>>`](crate::multi_vector::MatRef)).
///
/// # Usage
/// 1. Build a value with [`MaxSim::new`], handing it a mutable scores buffer.
/// 2. Call `DistanceFnMut::evaluate` with the query and document matrices.
/// 3. Read the results back out of the scores buffer.
#[derive(Debug)]
pub struct MaxSim<'a> {
    /// Borrowed output buffer holding one `f32` score per slot.
    pub(crate) scores: &'a mut [f32],
}
43
44impl<'a> MaxSim<'a> {
45    /// Creates a new [`MaxSim`] with the provided scores buffer.
46    ///
47    /// # Errors
48    /// Returns an error if `scores` is empty.
49    pub fn new(scores: &'a mut [f32]) -> Result<Self, MaxSimError> {
50        if scores.is_empty() {
51            return Err(MaxSimError::BufferLengthIsZero);
52        }
53        Ok(Self { scores })
54    }
55
56    /// Returns the number of score slots in the buffer.
57    #[inline]
58    pub fn size(&self) -> usize {
59        self.scores.len()
60    }
61
62    /// Returns the score at index `i`.
63    #[inline(always)]
64    pub fn get(&self, i: usize) -> Result<f32, MaxSimError> {
65        self.scores
66            .get(i)
67            .copied()
68            .ok_or_else(|| MaxSimError::IndexOutOfBounds(i, self.size()))
69    }
70
71    /// Sets the score at index `i`.
72    #[inline(always)]
73    pub fn set(&mut self, i: usize, x: f32) -> Result<(), MaxSimError> {
74        let size = self.size();
75
76        let s = self
77            .scores
78            .get_mut(i)
79            .ok_or(MaxSimError::IndexOutOfBounds(i, size))?;
80
81        *s = x;
82        Ok(())
83    }
84
85    /// Returns a mutable reference to the internal buffer of scores.
86    ///
87    /// This is useful for implementations external to crate as well as
88    /// optimized implementations to access the buffer if needed.
89    pub fn scores_mut(&mut self) -> &mut [f32] {
90        self.scores
91    }
92}
93
94/////////////
95// Chamfer //
96/////////////
97
/// Asymmetric Chamfer distance between a multi-vector query and document.
///
/// The distance is the sum, over all query vectors, of each query vector's
/// maximum similarity to the document vectors (written as a minimum negated
/// inner product):
///
/// ```text
/// Chamfer(Q, D) = Σᵢ minⱼ -IP(qᵢ, dⱼ)
/// ```
///
/// This is a stateless marker type. It implements
/// [`PureDistanceFunction`](diskann_vector::PureDistanceFunction)
/// for matrix view types.
#[derive(Clone, Copy, Debug)]
pub struct Chamfer;
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Owns a backing buffer so each test can borrow a fresh [`MaxSim`] from it.
    struct TestFixture {
        buffer: Vec<f32>,
    }

    impl TestFixture {
        /// Fixture whose buffer has `size` zero-initialized slots.
        fn new(size: usize) -> Self {
            let buffer = vec![0.0_f32; size];
            Self { buffer }
        }

        /// Fixture whose buffer is a copy of `values`.
        fn with_values(values: &[f32]) -> Self {
            let buffer = Vec::from(values);
            Self { buffer }
        }

        /// Borrows the buffer as a [`MaxSim`], forwarding any construction error.
        fn max_sim(&mut self) -> Result<MaxSim<'_>, MaxSimError> {
            MaxSim::new(self.buffer.as_mut_slice())
        }
    }

    mod max_sim_new {
        use super::*;

        #[test]
        fn fails_with_empty_buffer() {
            let mut empty: Vec<f32> = Vec::new();
            let outcome = MaxSim::new(&mut empty);
            assert!(matches!(outcome, Err(MaxSimError::BufferLengthIsZero)));
        }

        #[test]
        fn returns_correct_size() {
            for &size in &[1_usize, 2, 5, 100, 1000] {
                let mut fixture = TestFixture::new(size);
                let mut max_sim = fixture.max_sim().unwrap();
                assert_eq!(max_sim.size(), size, "size mismatch for buffer of {}", size);

                // The exposed slice must be exactly as long as `size()` reports.
                let exposed_len = max_sim.scores_mut().len();
                assert_eq!(exposed_len, max_sim.size());
            }
        }
    }

    mod max_sim_get {
        use super::*;

        #[test]
        fn returns_value_at_valid_index() {
            let expected = [1.0_f32, 2.0, 3.0];
            let mut fixture = TestFixture::with_values(&expected);
            let max_sim = fixture.max_sim().unwrap();

            for (index, &want) in expected.iter().enumerate() {
                assert_eq!(max_sim.get(index).unwrap(), want);
            }
        }

        #[test]
        fn fails_at_out_of_bounds_index() {
            let mut fixture = TestFixture::new(3);
            let max_sim = fixture.max_sim().unwrap();

            // First index past the end, then one far beyond it.
            for &index in &[3_usize, 100] {
                match max_sim.get(index) {
                    Err(MaxSimError::IndexOutOfBounds(reported, size)) => {
                        assert_eq!((reported, size), (index, 3));
                    }
                    other => panic!("expected IndexOutOfBounds, got {:?}", other),
                }
            }
        }
    }

    mod max_sim_set {
        use super::*;

        #[test]
        fn sets_value_at_valid_index() {
            let mut fixture = TestFixture::new(3);
            let mut max_sim = fixture.max_sim().unwrap();

            let written = [10.0_f32, 20.0, 30.0];
            for (index, &value) in written.iter().enumerate() {
                max_sim.set(index, value).unwrap();
            }

            for (index, &value) in written.iter().enumerate() {
                assert_eq!(max_sim.get(index).unwrap(), value);
            }
        }

        #[test]
        fn fails_at_out_of_bounds_index() {
            let mut fixture = TestFixture::new(3);
            let mut max_sim = fixture.max_sim().unwrap();

            match max_sim.set(3, 999.0) {
                Err(MaxSimError::IndexOutOfBounds(3, 3)) => {}
                other => panic!("expected IndexOutOfBounds(3, 3), got {:?}", other),
            }
        }

        #[test]
        fn overwrites_existing_value() {
            let mut fixture = TestFixture::with_values(&[1.0, 2.0, 3.0]);
            let mut max_sim = fixture.max_sim().unwrap();

            max_sim.set(1, 99.0).unwrap();

            // Only the middle slot may change.
            let after = [1.0_f32, 99.0, 3.0];
            for (index, &want) in after.iter().enumerate() {
                assert_eq!(max_sim.get(index).unwrap(), want);
            }
        }

        #[test]
        fn handles_special_float_values() {
            let mut fixture = TestFixture::new(4);
            let mut max_sim = fixture.max_sim().unwrap();

            let specials = [f32::INFINITY, f32::NEG_INFINITY, f32::NAN, -0.0_f32];
            for (index, &value) in specials.iter().enumerate() {
                max_sim.set(index, value).unwrap();
            }

            assert_eq!(max_sim.get(0).unwrap(), f32::INFINITY);
            assert_eq!(max_sim.get(1).unwrap(), f32::NEG_INFINITY);
            // NaN never compares equal, and -0.0 == 0.0, so check these by property.
            assert!(max_sim.get(2).unwrap().is_nan());
            assert!(max_sim.get(3).unwrap().is_sign_negative());
        }

        #[test]
        fn writes_through_to_underlying_buffer() {
            let mut buffer = vec![0.0_f32; 3];

            let mut max_sim = MaxSim::new(&mut buffer).unwrap();
            max_sim.set(0, 1.0).unwrap();
            max_sim.set(1, 2.0).unwrap();
            drop(max_sim);

            // Once the MaxSim is gone, the buffer itself reflects the writes.
            assert_eq!(buffer, vec![1.0_f32, 2.0, 0.0]);
        }
    }
}