diskann_quantization/scalar/vectors.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use diskann_vector::{DistanceFunction, PureDistanceFunction};
7
8use super::inverse_bit_scale;
9use crate::{
10 bits::{BitSlice, Dense, Representation, Unsigned},
11 distances::{self, check_lengths, InnerProduct, SquaredL2, MV},
12 meta,
13};
14
/// A per-vector precomputed coefficient to help compute inner products.
///
/// To understand the use of the compensation coefficient, assume that we wish to compute
/// the inner product between two scalar compressed vectors where the quantization has
/// scale parameter `a` and centroid `B` (note: capital letters represent vectors, lower
/// case letters represent scalars).
///
/// The inner product between `X = a * X' + B` and `Y = a * Y' + B` where
/// `X'` and `Y'` are the scalar encodings for `X` and `Y` respectively is:
/// ```math
/// P = <a * X' + B, a * Y' + B>
///   = a^2 * <X', Y'>   (integer dot product)
///   + a * <X', B>      (compensation for X)
///   + a * <Y', B>      (compensation for Y)
///   + <B, B>           (constant for all vectors)
/// ```
/// In other words, the inner product can be decomposed into an integer dot-product plus
/// a bunch of other terms that compensate for the compression.
///
/// These compensation terms can be computed as the vectors are compressed. At run time,
/// we can then return vectors consisting of the quantized encodings (e.g. `X'`) and the
/// compensation `<X', B>`.
///
/// Computation of squared Euclidean distance is more straightforward:
/// ```math
/// P = sum( ((a * X' + B) - (a * Y' + B))^2 )
///   = sum( a^2 * (X' - Y')^2 )
///   = a^2 * sum( (X' - Y')^2 )
/// ```
/// This means the squared Euclidean distance is computed by scaling the squared Euclidean
/// distance computed directly on the integer codes.
///
/// # Distance Implementations
///
/// The following distance function types are implemented:
///
/// * [`CompensatedSquaredL2`]: For computing squared euclidean distances.
/// * [`CompensatedIP`]: For computing inner products.
///
/// # Examples
///
/// The `CompensatedVector` has several named variants that are commonly used:
/// * [`CompensatedVector`]: An owning, independently allocated `CompensatedVector`.
/// * [`MutCompensatedVectorRef`]: A mutable, reference-like type to a `CompensatedVector`.
/// * [`CompensatedVectorRef`]: A const, reference-like type to a `CompensatedVector`.
///
/// ```
/// use diskann_quantization::{
///     scalar::{
///         self,
///         CompensatedVector,
///         MutCompensatedVectorRef,
///         CompensatedVectorRef
///     },
/// };
///
/// use diskann_utils::{Reborrow, ReborrowMut};
///
/// // Create a new heap-allocated CompensatedVector for 4-bit compressions capable of
/// // holding 3 elements.
/// let mut v = CompensatedVector::<4>::new_boxed(3);
///
/// // We can inspect the underlying bitslice.
/// let bitslice = v.vector();
/// assert_eq!(bitslice.get(0).unwrap(), 0);
/// assert_eq!(bitslice.get(1).unwrap(), 0);
/// assert_eq!(v.meta().0, 0.0, "expected default compensation value");
///
/// // If we want, we can mutably borrow the bitslice and mutate its components.
/// let mut bitslice = v.vector_mut();
/// bitslice.set(0, 1).unwrap();
/// bitslice.set(1, 2).unwrap();
/// bitslice.set(2, 3).unwrap();
///
/// assert!(bitslice.set(3, 4).is_err(), "out-of-bounds access");
///
/// // Get the underlying pointer for comparison.
/// let ptr = bitslice.as_ptr();
///
/// // Vectors can be converted to a generalized reference.
/// let mut v_ref = v.reborrow_mut();
///
/// // The generalized reference preserves the underlying pointer.
/// assert_eq!(v_ref.vector().as_ptr(), ptr);
/// let mut bitslice = v_ref.vector_mut();
/// bitslice.set(0, 10).unwrap();
///
/// // Setting the underlying compensation will be visible in the original allocation.
/// v_ref.set_meta(scalar::Compensation(1.0));
///
/// // Check that the changes are visible.
/// assert_eq!(v.meta().0, 1.0);
/// assert_eq!(v.vector().get(0).unwrap(), 10);
///
/// // Finally, the immutable ref also maintains pointer compatibility.
/// let v_ref = v.reborrow();
/// assert_eq!(v_ref.vector().as_ptr(), ptr);
/// ```
///
/// ## Constructing a `MutCompensatedVectorRef` From Components
///
/// The following example shows how to assemble a `MutCompensatedVectorRef` from raw memory.
/// ```
/// use diskann_quantization::{
///     bits::{Unsigned, MutBitSlice},
///     scalar::{self, MutCompensatedVectorRef}
/// };
///
/// // Start with 2 bytes of memory. We will impose a 4-bit scalar quantization on top of
/// // these 2 bytes.
/// let mut data = vec![0u8; 2];
/// let mut compensation = scalar::Compensation(0.0);
/// {
///     // First, we need to construct a bit-slice over the data.
///     // This will check that it is sized properly for 4, 4-bit values.
///     let mut slice = MutBitSlice::<4, Unsigned>::new(data.as_mut_slice(), 4).unwrap();
///
///     // Next, we construct the `MutCompensatedVectorRef`.
///     let mut v = MutCompensatedVectorRef::new(slice, &mut compensation);
///
///     // Through `v`, we can set all the components in `slice` and the compensation.
///     v.set_meta(scalar::Compensation(1.0));
///     let mut from_v = v.vector_mut();
///     from_v.set(0, 1).unwrap();
///     from_v.set(1, 2).unwrap();
///     from_v.set(2, 3).unwrap();
///     from_v.set(3, 4).unwrap();
/// }
///
/// // Now we can check that the changes made internally are visible.
/// assert_eq!(&data, &[0x21, 0x43]);
/// assert_eq!(compensation.0, 1.0);
/// ```
#[derive(Default, Debug, Clone, Copy, bytemuck::Zeroable, bytemuck::Pod)]
#[repr(transparent)]
pub struct Compensation(pub f32);
156
/// A borrowed `CompensatedVector`.
///
/// See: [`meta::Vector`].
pub type CompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
    meta::VectorRef<'a, NBITS, Unsigned, Compensation, Perm>;

/// A mutably borrowed `CompensatedVector`.
///
/// See: [`meta::Vector`].
pub type MutCompensatedVectorRef<'a, const NBITS: usize, Perm = Dense> =
    meta::VectorMut<'a, NBITS, Unsigned, Compensation, Perm>;

/// An owning `CompensatedVector`.
///
/// See: [`meta::Vector`].
pub type CompensatedVector<const NBITS: usize, Perm = Dense> =
    meta::Vector<NBITS, Unsigned, Compensation, Perm>;
174
175////////////////////////////
176// Compensated Squared L2 //
177////////////////////////////
178
179/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
180/// distance function over `CompensatedVectors` belonging to the same quantization space.
181#[derive(Debug, Clone, Copy)]
182pub struct CompensatedSquaredL2 {
183 pub(super) scale_squared: f32,
184}
185
186impl CompensatedSquaredL2 {
187 /// Construct a new `CompensatedSquaredL2` with the given scaling factor.
188 pub fn new(scale_squared: f32) -> Self {
189 Self { scale_squared }
190 }
191}
192
/// Compute the squared euclidean distance between the two compensated vectors.
///
/// The value returned by this function is scaled properly, meaning that distances returned
/// by this method are compatible with full-precision distances.
///
/// # Validity
///
/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
/// the same quantizer.
///
/// # Errors
///
/// Returns an error if `x.len() != y.len()` (via `check_lengths!`), or if the underlying
/// integer distance computation fails.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::MathematicalResult<f32>,
    > for CompensatedSquaredL2
where
    Unsigned: Representation<NBITS>,
    SquaredL2: for<'a, 'b> PureDistanceFunction<
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'b, NBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        // Mismatched dimensions are reported as an error (not a panic).
        check_lengths!(x, y)?;
        // Squared L2 over the raw integer codes: sum((X' - Y')^2).
        let squared_l2: distances::MathematicalResult<u32> =
            SquaredL2::evaluate(x.vector(), y.vector());
        let squared_l2 = squared_l2?.into_inner() as f32;

        // This should constant-propagate.
        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();

        // Rescale the integer distance into the full-precision domain:
        // a^2 * sum((X' - Y')^2), with the inverse bit scale folded in.
        let result = bit_scale * self.scale_squared * squared_l2;
        Ok(MV::new(result))
    }
}
237
/// Compute the squared euclidean distance between the two compensated vectors,
/// returning a plain `f32`.
///
/// This is a thin wrapper that forwards to the `MathematicalResult` implementation and
/// strips the mathematical-value wrapper from the result.
///
/// # Validity
///
/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
/// the same quantizer.
///
/// # Errors
///
/// Propagates any error returned by the underlying `MathematicalResult` implementation
/// (including mismatched lengths).
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedSquaredL2
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
        CompensatedVectorRef<'a, NBITS>,
        CompensatedVectorRef<'b, NBITS>,
        distances::MathematicalResult<f32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        // Defer to the mathematical-value implementation, then unwrap the value.
        let v: MV<f32> = self.evaluate_similarity(x, y)?;
        Ok(v.into_inner())
    }
}
274
275////////////////////
276// Compensated IP //
277////////////////////
278
279/// A `DistanceFunction` containing scaling parameters to enable distance the SquaredL2
280/// distance function over `CompensatedVectors` belonging to the same quantization space.
281#[derive(Debug, Clone, Copy)]
282pub struct CompensatedIP {
283 pub(super) scale_squared: f32,
284 pub(super) shift_square_norm: f32,
285}
286
287impl CompensatedIP {
288 /// Construct a new `CompensatedIP` with the given scaling factor and shift norm.
289 pub fn new(scale_squared: f32, shift_square_norm: f32) -> Self {
290 Self {
291 scale_squared,
292 shift_square_norm,
293 }
294 }
295}
296
/// Compute the inner product between the two compensated vectors.
///
/// The value returned by this function is scaled properly, meaning that distances returned
/// by this method are compatible with full-precision computations.
///
/// # Validity
///
/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
/// the same quantizer.
///
/// # Errors
///
/// Errors from the underlying integer inner-product evaluation are propagated.
/// NOTE(review): unlike the `CompensatedSquaredL2` implementation, there is no explicit
/// `check_lengths!` here — confirm that `InnerProduct::evaluate` rejects mismatched
/// lengths.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::MathematicalResult<f32>,
    > for CompensatedIP
where
    Unsigned: Representation<NBITS>,
    InnerProduct: for<'a, 'b> PureDistanceFunction<
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'b, NBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        // Integer dot product over the raw codes: <X', Y'>.
        let product: MV<u32> = InnerProduct::evaluate(x.vector(), y.vector())?;

        // This should constant-propagate.
        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();

        // Reassemble the full-precision inner product (see [`Compensation`]):
        //   a^2 * <X', Y'> + <B, B> + (compensation for x) + (compensation for y)
        // where the per-vector compensations stored in `meta()` already carry the
        // scale and bit-scale factors.
        let result = (bit_scale * self.scale_squared)
            .mul_add(product.into_inner() as f32, self.shift_square_norm)
            + (y.meta().0 + x.meta().0);
        Ok(MV::new(result))
    }
}
340
/// Compute the inner-product *distance* between the two compensated vectors.
///
/// This forwards to the `MathematicalResult` implementation and negates the result so
/// that a larger inner product (higher similarity) yields a smaller distance value.
///
/// # Validity
///
/// The results of this function are only meaningful if both `x`, `y`, and `Self` belong to
/// the same quantizer.
///
/// # Errors
///
/// Propagates any error returned by the underlying `MathematicalResult` implementation.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedIP
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
        CompensatedVectorRef<'a, NBITS>,
        CompensatedVectorRef<'b, NBITS>,
        distances::MathematicalResult<f32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        let v: MV<f32> = self.evaluate_similarity(x, y)?;
        // Negate: the inner product is a similarity; this impl returns a distance.
        Ok(-v.into_inner())
    }
}
377
378/// Compensated CosineNormalized distance function.
379#[derive(Debug, Clone, Copy)]
380pub struct CompensatedCosineNormalized {
381 pub(super) scale_squared: f32,
382}
383
384impl CompensatedCosineNormalized {
385 pub fn new(scale_squared: f32) -> Self {
386 Self { scale_squared }
387 }
388}
389
/// CosineNormalized
///
/// This implementation calculates the `<x, y> = 1 - L2 / 2` value, which will be further
/// used to compute the CosineNormalized distance function.
///
/// # Notes
///
/// s = 1 - cosine(X, Y) = 1 - <X, Y> / (||X|| * ||Y||)
///
/// We can make the simplifying assumption that ||X|| = 1 and ||Y|| = 1.
/// Then s = 1 - <X, Y>
///
/// The squared L2 distance can be computed as follows:
/// p = ||x||^2 + ||y||^2 - 2<x, y>
/// When vectors are normalized, this becomes
/// p = 2 - 2<x, y> = 2 * (1 - <x, y>)
///
/// In other words, the similarity score for the squared L2 distance in an ideal world is
/// 2 times that for cosine similarity. Therefore, squared L2 may serve as a stand-in for
/// cosine normalized as ordering is preserved.
///
/// NOTE(review): there is no explicit `check_lengths!` here — mismatched lengths are
/// assumed to be rejected by `SquaredL2::evaluate`; confirm.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::MathematicalResult<f32>,
    > for CompensatedCosineNormalized
where
    Unsigned: Representation<NBITS>,
    SquaredL2: for<'a, 'b> PureDistanceFunction<
        BitSlice<'a, NBITS, Unsigned>,
        BitSlice<'b, NBITS, Unsigned>,
        distances::MathematicalResult<u32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::MathematicalResult<f32> {
        // Squared L2 over the raw integer codes.
        let squared_l2: MV<u32> = SquaredL2::evaluate(x.vector(), y.vector())?;

        // This should constant-propagate.
        let bit_scale = inverse_bit_scale::<NBITS>() * inverse_bit_scale::<NBITS>();

        // Full-precision squared L2 distance between the (assumed unit-norm) vectors.
        let l2 = bit_scale * self.scale_squared * squared_l2.into_inner() as f32;

        // For unit vectors: <x, y> = 1 - ||x - y||^2 / 2.
        let result = 1.0 - l2 / 2.0;
        Ok(MV::new(result))
    }
}
440
/// Plain-`f32` CosineNormalized distance between the two compensated vectors.
///
/// Forwards to the `MathematicalResult` implementation (which yields the approximate
/// cosine similarity `1 - L2 / 2` for unit-norm inputs) and returns `1 - similarity`
/// as the distance.
impl<const NBITS: usize>
    DistanceFunction<
        CompensatedVectorRef<'_, NBITS>,
        CompensatedVectorRef<'_, NBITS>,
        distances::Result<f32>,
    > for CompensatedCosineNormalized
where
    Unsigned: Representation<NBITS>,
    Self: for<'a, 'b> DistanceFunction<
        CompensatedVectorRef<'a, NBITS>,
        CompensatedVectorRef<'b, NBITS>,
        distances::MathematicalResult<f32>,
    >,
{
    fn evaluate_similarity(
        &self,
        x: CompensatedVectorRef<'_, NBITS>,
        y: CompensatedVectorRef<'_, NBITS>,
    ) -> distances::Result<f32> {
        // Approximate cosine similarity of `x` and `y`.
        let v: MV<f32> = self.evaluate_similarity(x, y)?;
        // Convert similarity to distance: s = 1 - cos(x, y).
        Ok(1.0 - v.into_inner())
    }
}
464
465///////////
466// Tests //
467///////////
468
#[cfg(test)]
mod tests {
    use diskann_utils::{Reborrow, ReborrowMut};
    use rand::{
        distr::{Distribution, Uniform},
        rngs::StdRng,
        Rng, SeedableRng,
    };

    use super::*;
    use crate::{
        bits::{Representation, Unsigned},
        scalar::bit_scale,
        test_util,
    };

    ///////////////
    // Distances //
    ///////////////

    /// This test works as follows:
    ///
    /// First, generate a random value for `a`, `X'` and `B` where:
    ///
    /// * `a`: Is the scaling parameter.
    /// * `X'`: Is the integer compressed codes for a vector.
    /// * `B`: The floating point vector representing the dataset center.
    ///
    /// Next, compute the reconstructed vector using `X = a * X' + B`.
    /// Repeat this process for another vector `Y` using the same `a` and `B`.
    ///
    /// Then, the result of a distance computation can be done on the compressed
    /// representation and on the reconstructed representation. The results should match
    /// (modulo floating-point rounding).
    ///
    /// To get a handle on floating point issues, we pick "nice" numbers for the values of
    /// `a` and each component of `B` that are either small integers, or nice binary
    /// fractions like 1/2 or 3/4.
    ///
    /// Even with nice numbers, there is still a small amount of rounding instability.
    fn test_compensated_distance<const NBITS: usize, R>(
        dim: usize,
        ntrials: usize,
        max_relative_err_l2: f32,
        max_relative_err_ip: f32,
        max_relative_err_cos: f32,
        max_absolute_error: f32,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS>,
        R: Rng,
        CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
        CompensatedSquaredL2: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::Result<f32>,
        >,
        CompensatedIP: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
        CompensatedIP: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::Result<f32>,
        >,
        CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::MathematicalResult<f32>,
        >,
        CompensatedCosineNormalized: for<'a, 'b> DistanceFunction<
            CompensatedVectorRef<'a, NBITS>,
            CompensatedVectorRef<'b, NBITS>,
            distances::Result<f32>,
        >,
    {
        // The distributions we use for `a` and `B` are taken from integer distributions,
        // which we then convert to `f32` and divide by a power of 2.
        //
        // This helps keep computations exact so we don't also have to worry about
        // tracking floating-point rounding issues.
        //
        // Here, `alpha` refers to `a` in the function docstring and `beta` refers to `B`.
        let alpha_distribution = Uniform::new_inclusive(-16, 16).unwrap();
        let beta_distribution = Uniform::new_inclusive(-32, 32).unwrap();

        // What we divide the results generated by the alpha and beta distributions by.
        let alpha_divisor: f32 = 64.0;
        let beta_divisor: f32 = 128.0;

        // Codes are drawn uniformly from the full NBITS-wide unsigned domain.
        let domain = Unsigned::domain_const::<NBITS>();
        let code_distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        // Preallocate buffers.
        let mut beta: Vec<f32> = vec![0.0; dim];
        let mut x_prime: Vec<u8> = vec![0; dim];
        let mut y_prime: Vec<u8> = vec![0; dim];
        let mut x_reconstructed: Vec<f32> = vec![0.0; dim];
        let mut y_reconstructed: Vec<f32> = vec![0.0; dim];

        let mut x_compensated = CompensatedVector::<NBITS>::new_boxed(dim);
        let mut y_compensated = CompensatedVector::<NBITS>::new_boxed(dim);

        // Populate a compensated vector from the codes and `beta`.
        // Note that the stored compensation is `a * <X', B>` with the bit scale divided
        // out, matching what the `CompensatedIP` implementation expects.
        let populate_compensation = |mut dst: MutCompensatedVectorRef<'_, NBITS>,
                                     codes: &[u8],
                                     alpha: f32,
                                     beta: &[f32]| {
            assert_eq!(dst.len(), codes.len());
            assert_eq!(dst.len(), beta.len());

            let mut compensation: f32 = 0.0;
            let mut vector = dst.vector_mut();
            for (i, (&c, &b)) in std::iter::zip(codes.iter(), beta.iter()).enumerate() {
                vector.set(i, c.into()).unwrap();

                let c: f32 = c.into();
                compensation += c * b;
            }
            dst.set_meta(Compensation(alpha * compensation / bit_scale::<NBITS>()));
        };

        for trial in 0..ntrials {
            // Generate the problem.
            let alpha = (alpha_distribution.sample(rng) as f32) / alpha_divisor;
            beta.iter_mut().for_each(|b| {
                *b = (beta_distribution.sample(rng) as f32) / beta_divisor;
            });
            x_prime
                .iter_mut()
                .for_each(|x| *x = code_distribution.sample(rng).try_into().unwrap());
            y_prime
                .iter_mut()
                .for_each(|y| *y = code_distribution.sample(rng).try_into().unwrap());

            // Generate the reconstructed vectors: X = a * X' * (1 / bit_scale) + B.
            let bit_scale = inverse_bit_scale::<NBITS>();
            x_reconstructed
                .iter_mut()
                .zip(x_prime.iter())
                .zip(beta.iter())
                .for_each(|((x, xp), b)| {
                    *x = (alpha * *xp as f32) * bit_scale + *b;
                });

            y_reconstructed
                .iter_mut()
                .zip(y_prime.iter())
                .zip(beta.iter())
                .for_each(|((y, yp), b)| {
                    *y = (alpha * *yp as f32) * bit_scale + *b;
                });

            populate_compensation(x_compensated.reborrow_mut(), &x_prime, alpha, &beta);
            populate_compensation(y_compensated.reborrow_mut(), &y_prime, alpha, &beta);

            // Squared L2
            let expected: MV<f32> =
                diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);

            let distance = CompensatedSquaredL2::new(alpha * alpha);
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            let relative_err =
                test_util::compute_relative_error(got.into_inner(), expected.into_inner());
            let absolute_err =
                test_util::compute_absolute_error(got.into_inner(), expected.into_inner());

            assert!(
                relative_err <= max_relative_err_l2 || absolute_err <= max_absolute_error,
                "failed SquaredL2 for NBITS = {}, dim = {}, trial = {}. \
                 Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                 Expected {}, got {}",
                NBITS,
                dim,
                trial,
                relative_err,
                absolute_err,
                max_relative_err_l2,
                max_absolute_error,
                expected.into_inner(),
                got.into_inner(),
            );

            // The `f32` result should match the `MathematicalResult` value exactly.
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();
            assert_eq!(got.into_inner(), got_f32);

            // Inner Product
            let expected: MV<f32> = diskann_vector::distance::InnerProduct::evaluate(
                &*x_reconstructed,
                &*y_reconstructed,
            );

            // The second argument is `<B, B>` -- the squared norm of the shift.
            let distance =
                CompensatedIP::new(alpha * alpha, beta.iter().map(|&i| i * i).sum::<f32>());
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            let relative_err =
                test_util::compute_relative_error(got.into_inner(), expected.into_inner());
            let absolute_err =
                test_util::compute_absolute_error(got.into_inner(), expected.into_inner());

            assert!(
                relative_err <= max_relative_err_ip || absolute_err < max_absolute_error,
                "failed InnerProduct for NBITS = {}, dim = {}, trial = {}. \
                 Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                 Expected {}, got {}",
                NBITS,
                dim,
                trial,
                relative_err,
                absolute_err,
                max_relative_err_ip,
                max_absolute_error,
                expected.into_inner(),
                got.into_inner(),
            );

            // The `f32` result should be the negated `MathematicalResult` value
            // (similarity converted to a distance).
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();

            assert_eq!(-got.into_inner(), got_f32);

            // CosineNormalized:
            // expected value is cosine similarity of reconstructed vectors (no scale/shift)
            let expected: MV<f32> =
                diskann_vector::distance::SquaredL2::evaluate(&*x_reconstructed, &*y_reconstructed);
            let expected = 1.0 - expected.into_inner() / 2.0;

            let distance = CompensatedCosineNormalized::new(alpha * alpha);
            let got: distances::MathematicalResult<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got = got.unwrap();

            // Relative error is meaningless when the expected value is exactly zero,
            // so fall back to an absolute-error check in that case.
            if expected != 0.0 {
                let relative_err = test_util::compute_relative_error(got.into_inner(), expected);
                let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
                assert!(
                    relative_err < max_relative_err_cos || absolute_err < max_absolute_error,
                    "failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
                     Got an error {} (rel) / {} (abs) with tolerance {}/{}. \
                     Expected {}, got {}",
                    NBITS,
                    dim,
                    trial,
                    relative_err,
                    absolute_err,
                    max_relative_err_cos,
                    max_absolute_error,
                    expected,
                    got.into_inner(),
                );
            } else {
                let absolute_err = test_util::compute_absolute_error(got.into_inner(), expected);
                assert!(
                    absolute_err < max_absolute_error,
                    "failed CosineNormalized for NBITS = {}, dim = {}, trial = {}. \
                     Got an absolute error {} with tolerance {}. \
                     Expected {}, got {}",
                    NBITS,
                    dim,
                    trial,
                    absolute_err,
                    max_absolute_error,
                    expected,
                    got.into_inner(),
                );
            }

            // The `f32` result should be `1 - similarity`.
            let got_f32: distances::Result<f32> =
                distance.evaluate_similarity(x_compensated.reborrow(), y_compensated.reborrow());
            let got_f32 = got_f32.unwrap();
            assert_eq!(1.0 - got.into_inner(), got_f32);
        }
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // The max dim does not need to be as high for `CompensatedVectors` because they
            // defer their distance function implementation to `BitSlice`, which is more
            // heavily tested.
            const MAX_DIM: usize = 37;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }

    // Stamp out one deterministic (seeded) test per bit width, with per-metric
    // relative-error tolerances.
    macro_rules! test_unsigned_compensated {
        (
            $name:ident,
            $nbits:literal,
            $relative_err_l2:literal,
            $relative_err_ip:literal,
            $relative_err_cos:literal,
            $seed:literal
        ) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                let absolute_error: f32 = 2.0e-7;
                for dim in 0..MAX_DIM {
                    test_compensated_distance::<$nbits, _>(
                        dim,
                        TRIALS_PER_DIM,
                        $relative_err_l2,
                        $relative_err_ip,
                        $relative_err_cos,
                        absolute_error,
                        &mut rng,
                    );
                }
            }
        };
    }

    test_unsigned_compensated!(
        unsigned_compensated_distances_8bit,
        8,
        4.0e-4,
        3.0e-6,
        1.0e-3,
        0xa32d5658097a1c35
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_7bit,
        7,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0x0b65ca44ec7b47d8
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_6bit,
        6,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0x471b640fba5c520b
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_5bit,
        5,
        5.0e-6,
        3.0e-6,
        1.0e-3,
        0xf60c0c8d1aadc126
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_4bit,
        4,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0xcc2b897373a143f3
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_3bit,
        3,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0xaedf3d2a223b7b77
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_2bit,
        2,
        3.0e-6,
        3.0e-6,
        1.0e-3,
        0x2b34015910b34083
    );
    test_unsigned_compensated!(
        unsigned_compensated_distances_1bit,
        1,
        0.0,
        0.0,
        0.0,
        0x09fa14c42a9d7d98
    );
}