diskann_quantization/multi_vector/distance/
max_sim.rs1use thiserror::Error;
7
8#[derive(Clone, Debug, Copy, Error)]
10pub enum MaxSimError {
11 #[error("Trying to access score in index {0} for output of size {1}")]
12 IndexOutOfBounds(usize, usize),
13 #[error("Scores buffer length cannot be 0")]
14 BufferLengthIsZero,
15 #[error("Invalid buffer length {0} for query size {0}")]
16 InvalidBufferLength(usize, usize),
17}
18
19#[derive(Debug)]
40pub struct MaxSim<'a> {
41 pub(crate) scores: &'a mut [f32],
42}
43
44impl<'a> MaxSim<'a> {
45 pub fn new(scores: &'a mut [f32]) -> Result<Self, MaxSimError> {
50 if scores.is_empty() {
51 return Err(MaxSimError::BufferLengthIsZero);
52 }
53 Ok(Self { scores })
54 }
55
56 #[inline]
58 pub fn size(&self) -> usize {
59 self.scores.len()
60 }
61
62 #[inline(always)]
64 pub fn get(&self, i: usize) -> Result<f32, MaxSimError> {
65 self.scores
66 .get(i)
67 .copied()
68 .ok_or_else(|| MaxSimError::IndexOutOfBounds(i, self.size()))
69 }
70
71 #[inline(always)]
73 pub fn set(&mut self, i: usize, x: f32) -> Result<(), MaxSimError> {
74 let size = self.size();
75
76 let s = self
77 .scores
78 .get_mut(i)
79 .ok_or(MaxSimError::IndexOutOfBounds(i, size))?;
80
81 *s = x;
82 Ok(())
83 }
84
85 pub fn scores_mut(&mut self) -> &mut [f32] {
90 self.scores
91 }
92}
93
/// Zero-sized, data-free marker type named `Chamfer`.
///
/// NOTE(review): its role (presumably selecting a Chamfer, i.e. summed
/// MaxSim, multi-vector distance in trait impls elsewhere) is not visible
/// here — confirm against the implementing code.
#[derive(Clone, Copy, Debug)]
pub struct Chamfer;

#[cfg(test)]
mod tests {
    use super::*;

    /// Owns a score buffer so individual tests can borrow it as a [`MaxSim`].
    struct TestFixture {
        buffer: Vec<f32>,
    }

    impl TestFixture {
        /// Zero-filled buffer with `size` slots.
        fn new(size: usize) -> Self {
            Self {
                buffer: vec![0.0; size],
            }
        }

        /// Buffer initialised with a copy of `values`.
        fn with_values(values: &[f32]) -> Self {
            Self {
                buffer: values.to_vec(),
            }
        }

        /// Borrows the owned buffer as a [`MaxSim`] view.
        fn max_sim(&mut self) -> Result<MaxSim<'_>, MaxSimError> {
            MaxSim::new(&mut self.buffer)
        }
    }

    mod max_sim_new {
        use super::*;

        #[test]
        fn fails_with_empty_buffer() {
            let mut buffer: Vec<f32> = vec![];
            let result = MaxSim::new(&mut buffer);
            assert!(matches!(result, Err(MaxSimError::BufferLengthIsZero)));
        }

        #[test]
        fn returns_correct_size() {
            let sizes = [1, 2, 5, 100, 1000];
            for size in sizes {
                let mut fixture = TestFixture::new(size);
                let mut max_sim = fixture.max_sim().unwrap();
                assert_eq!(max_sim.size(), size, "size mismatch for buffer of {}", size);

                let scores = max_sim.scores_mut();
                assert_eq!(scores.len(), max_sim.size());
            }
        }
    }

    mod max_sim_get {
        use super::*;

        #[test]
        fn returns_value_at_valid_index() {
            let mut fixture = TestFixture::with_values(&[1.0, 2.0, 3.0]);
            let max_sim = fixture.max_sim().unwrap();

            assert_eq!(max_sim.get(0).unwrap(), 1.0);
            assert_eq!(max_sim.get(1).unwrap(), 2.0);
            assert_eq!(max_sim.get(2).unwrap(), 3.0);
        }

        #[test]
        fn fails_at_out_of_bounds_index() {
            let mut fixture = TestFixture::new(3);
            let max_sim = fixture.max_sim().unwrap();

            let result = max_sim.get(3);
            assert!(matches!(result, Err(MaxSimError::IndexOutOfBounds(3, 3))));

            let result = max_sim.get(100);
            assert!(matches!(result, Err(MaxSimError::IndexOutOfBounds(100, 3))));
        }
    }

    mod max_sim_set {
        use super::*;

        #[test]
        fn sets_value_at_valid_index() {
            let mut fixture = TestFixture::new(3);
            let mut max_sim = fixture.max_sim().unwrap();

            max_sim.set(0, 10.0).unwrap();
            max_sim.set(1, 20.0).unwrap();
            max_sim.set(2, 30.0).unwrap();

            assert_eq!(max_sim.get(0).unwrap(), 10.0);
            assert_eq!(max_sim.get(1).unwrap(), 20.0);
            assert_eq!(max_sim.get(2).unwrap(), 30.0);
        }

        #[test]
        fn fails_at_out_of_bounds_index() {
            let mut fixture = TestFixture::new(3);
            let mut max_sim = fixture.max_sim().unwrap();

            let result = max_sim.set(3, 999.0);
            assert!(matches!(result, Err(MaxSimError::IndexOutOfBounds(3, 3))));
        }

        #[test]
        fn overwrites_existing_value() {
            let mut fixture = TestFixture::with_values(&[1.0, 2.0, 3.0]);
            let mut max_sim = fixture.max_sim().unwrap();

            max_sim.set(1, 99.0).unwrap();

            assert_eq!(max_sim.get(0).unwrap(), 1.0);
            assert_eq!(max_sim.get(1).unwrap(), 99.0);
            assert_eq!(max_sim.get(2).unwrap(), 3.0);
        }

        #[test]
        fn handles_special_float_values() {
            let mut fixture = TestFixture::new(4);
            let mut max_sim = fixture.max_sim().unwrap();

            max_sim.set(0, f32::INFINITY).unwrap();
            max_sim.set(1, f32::NEG_INFINITY).unwrap();
            max_sim.set(2, f32::NAN).unwrap();
            max_sim.set(3, -0.0).unwrap();

            assert_eq!(max_sim.get(0).unwrap(), f32::INFINITY);
            assert_eq!(max_sim.get(1).unwrap(), f32::NEG_INFINITY);
            assert!(max_sim.get(2).unwrap().is_nan());
            assert!(max_sim.get(3).unwrap().is_sign_negative());
        }

        #[test]
        fn writes_through_to_underlying_buffer() {
            let mut buffer = vec![0.0f32; 3];
            {
                let mut max_sim = MaxSim::new(&mut buffer).unwrap();
                max_sim.set(0, 1.0).unwrap();
                max_sim.set(1, 2.0).unwrap();
            }
            assert_eq!(buffer, vec![1.0, 2.0, 0.0]);
        }
    }

    mod max_sim_scores_mut {
        use super::*;

        #[test]
        fn exposes_whole_buffer_for_writing() {
            let mut buffer = vec![0.0f32; 3];
            {
                let mut max_sim = MaxSim::new(&mut buffer).unwrap();
                max_sim.scores_mut().copy_from_slice(&[4.0, 5.0, 6.0]);
                // Writes through the raw slice must be observable via `get`.
                assert_eq!(max_sim.get(2).unwrap(), 6.0);
            }
            assert_eq!(buffer, vec![4.0, 5.0, 6.0]);
        }
    }

    mod max_sim_error {
        use super::*;

        #[test]
        fn index_out_of_bounds_message_reports_index_and_size() {
            let err = MaxSimError::IndexOutOfBounds(7, 3);
            assert_eq!(
                err.to_string(),
                "Trying to access score in index 7 for output of size 3"
            );
        }

        #[test]
        fn buffer_length_is_zero_message() {
            assert_eq!(
                MaxSimError::BufferLengthIsZero.to_string(),
                "Scores buffer length cannot be 0"
            );
        }
    }
}