Skip to main content

diskann_providers/model/pq/distance/
multi.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{
14    QueryComputer,
15    dynamic::{DistanceComputerConstructionError, VTable},
16};
17use crate::model::FixedChunkPQTable;
18
/// Marker trait for values used to tag PQ vectors and PQ tables with a version.
///
/// Every `Eq + Copy` type satisfies it automatically through the blanket implementation
/// below, so plain integers or small fieldless enums can serve as versions.
pub trait PQVersion
where
    Self: Eq + Copy,
{
}

impl<V: Eq + Copy> PQVersion for V {}
21
22/// A PQ vector with an associated version.
23#[derive(Debug, Clone, PartialEq)]
24pub struct VersionedPQVector<I: PQVersion> {
25    data: Vec<u8>,
26    version: I,
27}
28
29impl<I> VersionedPQVector<I>
30where
31    I: PQVersion,
32{
33    /// Construct a new `VersionedPQVector` taking ownership of the provided data and version.
34    pub fn new(data: Vec<u8>, version: I) -> Self {
35        Self { data, version }
36    }
37
38    /// Return a `VersionedPQVectorRef` over the data owned by this vector.
39    pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
40        VersionedPQVectorRef::new(&self.data, self.version)
41    }
42
43    /// Return the version associated with this vector.
44    pub fn version(&self) -> &I {
45        &self.version
46    }
47
48    /// Return the raw underlying data.
49    pub fn data(&self) -> &[u8] {
50        &self.data
51    }
52
53    /// Return the components of the vector. This is a low-level API.
54    pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
55        (&mut self.data, &mut self.version)
56    }
57}
58
59impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
60where
61    I: PQVersion,
62{
63    type Target = VersionedPQVectorRef<'a, I>;
64    fn reborrow(&'a self) -> Self::Target {
65        self.as_ref()
66    }
67}
68
69/// A reference version of `VersionedPQVector`.
70#[derive(Debug, Clone, Copy)]
71pub struct VersionedPQVectorRef<'a, I: PQVersion> {
72    data: &'a [u8],
73    version: I,
74}
75
76impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
77    /// Construct a new `VersionedPQVectorRef` around the provided data.
78    pub fn new(data: &'a [u8], version: I) -> Self {
79        Self { data, version }
80    }
81
82    /// Return the version associated with this vector.
83    pub fn version(&self) -> &I {
84        &self.version
85    }
86
87    /// Return the raw underlying data.
88    pub fn data(&self) -> &[u8] {
89        self.data
90    }
91}
92
93/// A wrapper for `FixedChunkPQTable` that contains either one or two inner
94/// `FixedChunkPQTables` with associated versions.
95#[derive(Debug, Clone)]
96pub enum MultiTable<T, I>
97where
98    T: Deref<Target = FixedChunkPQTable>,
99    I: PQVersion,
100{
101    /// Only one table is present with an associated version.
102    One { table: T, version: I },
103    /// Two tables are present, an incoming "new" table and an outgoing "old" table.
104    /// The versions of these tables are recorded respectively in `new_version` and
105    /// `old_version`.
106    Two {
107        new: T,
108        old: T,
109        new_version: I,
110        old_version: I,
111    },
112}
113
/// Error returned by `MultiTable::two` when `new_version` and `old_version` are equal.
///
/// The version is what distinguishes the two tables, so the two versions must differ.
#[derive(Debug, Error)]
#[error("provided versions must not be equal")]
pub struct EqualVersionsError;
117
118impl<T, I> MultiTable<T, I>
119where
120    T: Deref<Target = FixedChunkPQTable>,
121    I: PQVersion,
122{
123    /// Construct a new `MultiTable` containing a single `FixedChunkPQTable`.
124    pub fn one(table: T, version: I) -> Self {
125        Self::One { table, version }
126    }
127
128    /// Construct a new `MultiTable` with two `FixedChunkPQTable`s.
129    ///
130    /// Returns an `Err` if the two provided versions are equal.
131    pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
132        if new_version == old_version {
133            Err(EqualVersionsError)
134        } else {
135            Ok(Self::Two {
136                new,
137                old,
138                new_version,
139                old_version,
140            })
141        }
142    }
143
144    /// Return the versions associated with the tables in this schema.
145    ///
146    /// The returned tuple depends on whether this table has one or two registerd schemas.
147    ///
148    /// * If there is only one schema, return the `(version, None)` where `version` is the
149    ///   version of the only schema.
150    /// * If there are two schema, return `(new_version, Some(old_version))` where
151    ///   `new_version` is the version of the most recently registered schema while
152    ///   `old_version` is the old version.
153    pub fn versions(&self) -> (&I, Option<&I>) {
154        match &self {
155            Self::One { version, .. } => (version, None),
156            Self::Two {
157                new_version,
158                old_version,
159                ..
160            } => (new_version, Some(old_version)),
161        }
162    }
163}
164
165/// A distance computer implementing
166///
167/// * `DistanceFunction<&[f32], &VersionedPQVector, Option<f32>`
168/// * `DistanceFunction<&VersionedPQVector, &VersionedPQVector, Option<f32>`
169///
170/// That can contain either one or two PQ schemas, disambiguating which schema to use based
171/// on the version numbers contained in the `VersionedPQVectors`.
172///
173/// Since this struct stores at most two PQ tables, that means there is the possibility
174/// that a PQ vector is provided that does not match either of the tables.
175///
176/// Returns `None` for distance computations when the version of the PQ vector does not
177/// match with any version in the local table.
178#[derive(Debug, Clone)]
179pub struct MultiDistanceComputer<T, I>
180where
181    T: Deref<Target = FixedChunkPQTable>,
182    I: PQVersion,
183{
184    table: MultiTable<T, I>,
185    vtable: VTable,
186}
187
188impl<T, I> MultiDistanceComputer<T, I>
189where
190    T: Deref<Target = FixedChunkPQTable>,
191    I: PQVersion,
192{
193    /// Construct a `MultiDistanceComputer` from the provided table implementing the
194    /// requested metric.
195    pub fn new(
196        table: MultiTable<T, I>,
197        metric: Metric,
198    ) -> Result<Self, DistanceComputerConstructionError> {
199        // Check if OPQ is used. If so, we cannot correctly perform distance computations.
200        match &table {
201            MultiTable::One { table, .. } => {
202                if table.has_opq() {
203                    return Err(DistanceComputerConstructionError::OPQNotSupported);
204                }
205            }
206            MultiTable::Two { new, old, .. } => {
207                if new.has_opq() || old.has_opq() {
208                    return Err(DistanceComputerConstructionError::OPQNotSupported);
209                }
210            }
211        };
212        Ok(Self {
213            table,
214            vtable: VTable::new(metric),
215        })
216    }
217
218    /// Return the versions associated with the tables in this schema.
219    ///
220    /// The returned tuple depends on whether this table has one or two registerd schemas.
221    ///
222    /// * If there is only one schema, return the `(version, None)` where `version` is the
223    ///   version of the only schema.
224    /// * If there are two schema, return `(new_version, Some(old_version))` where
225    ///   `new_version` is the version of the most recently registered schema while
226    ///   `old_version` is the old version.
227    pub fn versions(&self) -> (&I, Option<&I>) {
228        self.table.versions()
229    }
230}
231
232impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
233    for MultiDistanceComputer<T, I>
234where
235    T: Deref<Target = FixedChunkPQTable>,
236    I: PQVersion,
237{
238    #[inline(always)]
239    fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
240        self.evaluate_similarity(x, y.reborrow())
241    }
242}
243
244impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
245    for MultiDistanceComputer<T, I>
246where
247    T: Deref<Target = FixedChunkPQTable>,
248    I: PQVersion,
249{
250    fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
251        match &self.table {
252            MultiTable::One { table, version } => {
253                if version != &y.version {
254                    None
255                } else {
256                    Some((self.vtable.distance_fn)(table, x, y.data))
257                }
258            }
259            MultiTable::Two {
260                old,
261                new,
262                old_version,
263                new_version,
264            } => {
265                if old_version == &y.version {
266                    Some((self.vtable.distance_fn)(old, x, y.data))
267                } else if new_version == &y.version {
268                    Some((self.vtable.distance_fn)(new, x, y.data))
269                } else {
270                    None
271                }
272            }
273        }
274    }
275}
276
277impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
278    for MultiDistanceComputer<T, I>
279where
280    T: Deref<Target = FixedChunkPQTable>,
281    I: PQVersion,
282{
283    #[inline(always)]
284    fn evaluate_similarity(
285        &self,
286        x: &VersionedPQVector<I>,
287        y: &VersionedPQVector<I>,
288    ) -> Option<f32> {
289        self.evaluate_similarity(x.reborrow(), y.reborrow())
290    }
291}
292
293/// Compute the distance between two versioned quantized vectors.
294///
295/// If one schema is currently being used and at least one of the versions of the argument
296/// vectors does not match, return `None`.
297///
298/// If two schemas are used and at least one of the versions of the argument vectors is not
299/// recognized, then return `None`.
300impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
301    for MultiDistanceComputer<T, I>
302where
303    T: Deref<Target = FixedChunkPQTable>,
304    I: PQVersion,
305{
306    fn evaluate_similarity(
307        &self,
308        x: VersionedPQVectorRef<'_, I>,
309        y: VersionedPQVectorRef<'_, I>,
310    ) -> Option<f32> {
311        match &self.table {
312            MultiTable::One { table, version } => {
313                if (&x.version != version) || (&y.version != version) {
314                    None
315                } else {
316                    Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
317                }
318            }
319            MultiTable::Two {
320                new,
321                old,
322                new_version,
323                old_version,
324            } => {
325                let x_new = &x.version == new_version;
326                let x_old = &x.version == old_version;
327
328                let y_new = &y.version == new_version;
329                let y_old = &y.version == old_version;
330
331                if x_old {
332                    if y_old {
333                        // Both Old
334                        Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
335                    } else if y_new {
336                        let x_full = old.inflate_vector(x.data);
337                        // X Old, Y New
338                        Some((self.vtable.distance_fn)(new, &x_full, y.data))
339                    } else {
340                        None
341                    }
342                } else if x_new {
343                    if y_old {
344                        let y_full = old.inflate_vector(y.data);
345                        // X New, Y Old
346                        Some((self.vtable.distance_fn)(new, &y_full, x.data))
347                    } else if y_new {
348                        // Both new
349                        Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
350                    } else {
351                        None
352                    }
353                } else {
354                    None
355                }
356            }
357        }
358    }
359}
360
361////////////////////
362// Query Computer //
363////////////////////
364
365/// A `PreprocessedDistanceFunction` containing either one or two PQ schemas, capable of
366/// performing distance computations with either. Upon a version mismatch with a query,
367/// `None` is returned.
368#[derive(Debug)]
369pub enum MultiQueryComputer<T, I>
370where
371    T: Deref<Target = FixedChunkPQTable>,
372    I: PQVersion,
373{
374    One {
375        computer: QueryComputer<T>,
376        version: I,
377    },
378    Two {
379        new: QueryComputer<T>,
380        old: QueryComputer<T>,
381        new_version: I,
382        old_version: I,
383    },
384}
385
386impl<T, I> MultiQueryComputer<T, I>
387where
388    T: Deref<Target = FixedChunkPQTable>,
389    I: PQVersion,
390{
391    /// Construct a new `MultiQueryComputer` with the requested metric and query.
392    pub fn new<U>(table: MultiTable<T, I>, metric: Metric, query: &[U]) -> ANNResult<Self>
393    where
394        U: Into<f32> + Copy,
395    {
396        let s = match table {
397            MultiTable::One { table, version } => Self::One {
398                computer: { QueryComputer::new(table, metric, query, None)? },
399                version,
400            },
401            MultiTable::Two {
402                new,
403                old,
404                new_version,
405                old_version,
406            } => Self::Two {
407                new: { QueryComputer::new(new, metric, query, None)? },
408                old: { QueryComputer::new(old, metric, query, None)? },
409                new_version,
410                old_version,
411            },
412        };
413        Ok(s)
414    }
415
416    /// Return a tuple that is either:
417    /// 1. (The only table version, None)
418    /// 2. (New Table Version, Old Table Version)
419    pub fn versions(&self) -> (&I, Option<&I>) {
420        match &self {
421            Self::One { version, .. } => (version, None),
422            Self::Two {
423                new_version,
424                old_version,
425                ..
426            } => (new_version, Some(old_version)),
427        }
428    }
429}
430
431impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
432    for MultiQueryComputer<T, I>
433where
434    T: Deref<Target = FixedChunkPQTable>,
435    I: PQVersion,
436{
437    #[inline(always)]
438    fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
439        self.evaluate_similarity(x.reborrow())
440    }
441}
442
443impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
444    for MultiQueryComputer<T, I>
445where
446    T: Deref<Target = FixedChunkPQTable>,
447    I: PQVersion,
448{
449    fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
450        match &self {
451            Self::One { computer, version } => {
452                if version != &x.version {
453                    None
454                } else {
455                    Some(computer.evaluate_similarity(x.data))
456                }
457            }
458            Self::Two {
459                new,
460                old,
461                new_version,
462                old_version,
463            } => {
464                if old_version == &x.version {
465                    Some(old.evaluate_similarity(x.data))
466                } else if new_version == &x.version {
467                    Some(new.evaluate_similarity(x.data))
468                } else {
469                    None
470                }
471            }
472        }
473    }
474}
475
476/// # Testing Strategies.
477///
478/// ## Distance Computations
479///
/// At this point, we assume that the lower-level distance functions are more-or-less
/// accurate. That is, a given `QueryComputer` or VTable-based distance function works
/// correctly.
///
/// The testing functions at this level are designed more to check that versioned vectors
/// get routed to the right table and that the error handling is correct.
485#[cfg(test)]
486mod tests {
487    use std::marker::PhantomData;
488
489    use approx::assert_relative_eq;
490    use diskann::utils::{IntoUsize, VectorRepr};
491    use diskann_vector::{Half, PreprocessedDistanceFunction};
492    use rand::{Rng, SeedableRng, distr::Distribution};
493    use rstest::rstest;
494
495    use super::{
496        super::test_utils::{self, TestDistribution},
497        *,
498    };
499
500    fn to_f32<T>(x: &[T]) -> Vec<f32>
501    where
502        T: Into<f32> + Copy,
503    {
504        x.iter().map(|i| (*i).into()).collect()
505    }
506
507    /////////////////////////
508    // Versioned PQ Vector //
509    /////////////////////////
510
511    #[test]
512    fn test_versioned_pq_vector() {
513        let vec = vec![1, 2, 3];
514        let ptr = vec.as_ptr();
515        let pq = VersionedPQVector::<usize>::new(vec, 10);
516        assert_eq!(*pq.version(), 10);
517        assert_eq!(pq.data().len(), 3);
518
519        let data_ptr = pq.data().as_ptr();
520        let pq_ref = pq.as_ref();
521        assert_eq!(pq_ref.version(), pq.version());
522        assert_eq!(data_ptr, ptr);
523        assert_eq!(
524            pq_ref.data().as_ptr(),
525            data_ptr,
526            "expected VersionedPQVectorRef to have the same underlying data as the \
527             original VersionedPQVector"
528        );
529
530        let pq_ref = pq.reborrow();
531        assert_eq!(pq_ref.version(), pq.version());
532        assert_eq!(data_ptr, ptr);
533        assert_eq!(
534            pq_ref.data().as_ptr(),
535            data_ptr,
536            "expected VersionedPQVectorRef to have the same underlying data as the \
537             original VersionedPQVector"
538        );
539    }
540
541    ////////////////
542    // MultiTable //
543    ////////////////
544
545    #[test]
546    fn test_table_error() {
547        let config = test_utils::TableConfig {
548            dim: 17,
549            pq_chunks: 4,
550            num_pivots: 20,
551            start_value: 10.0,
552            use_opq: false,
553        };
554
555        let new = test_utils::seed_pivot_table(config);
556        let old = test_utils::seed_pivot_table(config);
557
558        let result = MultiTable::two(&new, &old, 0, 0);
559        assert!(
560            matches!(result, Err(EqualVersionsError)),
561            "MultiTable should now allow construction of the Two variant with equal versions"
562        );
563    }
564
565    ///////////////////////////////////
566    // Distance Computer - One Table //
567    ///////////////////////////////////
568
569    /// Test that the table works correctl where there is one inner PQ table.
570    fn test_distance_computer_multi_with_one<R>(
571        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
572        table: &FixedChunkPQTable,
573        config: &test_utils::TableConfig,
574        reference: &<f32 as VectorRepr>::Distance,
575        num_trials: usize,
576        rng: &mut R,
577    ) where
578        R: Rng,
579    {
580        // Check that there is just one version.
581        let (&version, should_be_none) = computer.versions();
582        assert!(
583            should_be_none.is_none(),
584            "expected just one schema in test computer"
585        );
586        let invalid_version = version.wrapping_add(1);
587
588        for _ in 0..num_trials {
589            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
590            let expected0 = test_utils::generate_expected_vector(
591                &code0,
592                table.get_chunk_offsets(),
593                config.start_value,
594            );
595
596            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
597            let expected1 = test_utils::generate_expected_vector(
598                &code1,
599                table.get_chunk_offsets(),
600                config.start_value,
601            );
602
603            let expected = reference.evaluate_similarity(&expected0, &expected1);
604
605            // Test full-precision/quant.
606            let got = computer
607                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
608                .expect("evaluate_similarity should return Some");
609            assert_eq!(got, expected);
610
611            let got = computer
612                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
613                .expect("evaluate_similarity should return Some");
614            assert_eq!(got, expected);
615
616            // Test quant/quant.
617            let got = computer
618                .evaluate_similarity(
619                    &VersionedPQVector::new(code0.clone(), version),
620                    &VersionedPQVector::new(code1.clone(), version),
621                )
622                .expect("evaluate_similarity should return Some");
623            assert_eq!(got, expected);
624
625            // Check that version mismatches return `None`.
626            let got = computer.evaluate_similarity(
627                &*expected0,
628                &VersionedPQVector::new(code0.clone(), invalid_version),
629            );
630            assert!(got.is_none(), "version mismatches should return `None`");
631
632            let got = computer.evaluate_similarity(
633                &VersionedPQVector::new(code0.clone(), invalid_version),
634                &VersionedPQVector::new(code1.clone(), version),
635            );
636            assert!(got.is_none(), "version mismatches should return `None`");
637
638            let got = computer.evaluate_similarity(
639                &VersionedPQVector::new(code0.clone(), version),
640                &VersionedPQVector::new(code1.clone(), invalid_version),
641            );
642            assert!(got.is_none(), "version mismatches should return `None`");
643        }
644    }
645
646    #[rstest]
647    fn test_multi_distance_computer_one(
648        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
649    ) {
650        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
651
652        let config = test_utils::TableConfig {
653            dim: 17,
654            pq_chunks: 4,
655            num_pivots: 20,
656            start_value: 10.0,
657            use_opq: false,
658        };
659
660        let table = test_utils::seed_pivot_table(config);
661
662        let version: usize = 0x625b215f82f38008;
663
664        let multi_table = MultiTable::one(&table, version);
665        let (n, o) = multi_table.versions();
666        assert_eq!(*n, version);
667        assert!(o.is_none());
668
669        let computer = MultiDistanceComputer::new(multi_table, metric).unwrap();
670
671        test_distance_computer_multi_with_one(
672            &computer,
673            &table,
674            &config,
675            &f32::distance(metric, None),
676            100,
677            &mut rng,
678        );
679    }
680
681    ////////////////////////////////////
682    // Distance Computer - Two Tables //
683    ////////////////////////////////////
684
685    /// Test that the table works correctly when there are two inner PQ tables.
686    #[allow(clippy::too_many_arguments)]
687    fn test_distance_computer_multi_with_two<R>(
688        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
689        new: &FixedChunkPQTable,
690        old: &FixedChunkPQTable,
691        new_config: &test_utils::TableConfig,
692        old_config: &test_utils::TableConfig,
693        reference: &<f32 as VectorRepr>::Distance,
694        num_trials: usize,
695        rng: &mut R,
696    ) where
697        R: Rng,
698    {
699        // Check that there are indeed two versions registered.
700        let (&new_version, old_version) = computer.versions();
701        let &old_version = old_version.expect("expected two schemas in test computer");
702
703        for _ in 0..num_trials {
704            // Generate a code for the old schema
705            let old_code =
706                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
707            let old_expected = test_utils::generate_expected_vector(
708                &old_code,
709                old.get_chunk_offsets(),
710                old_config.start_value,
711            );
712
713            // Generate a code for the new schema
714            let new_code =
715                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
716            let new_expected = test_utils::generate_expected_vector(
717                &new_code,
718                new.get_chunk_offsets(),
719                new_config.start_value,
720            );
721
722            // Generate reference results.
723            let oo = reference.evaluate_similarity(&old_expected, &old_expected);
724            let nn = reference.evaluate_similarity(&new_expected, &new_expected);
725            let on = reference.evaluate_similarity(&old_expected, &new_expected);
726
727            // Quant + Quant
728            {
729                let got_oo_qq = computer.evaluate_similarity(
730                    &VersionedPQVector::new(old_code.clone(), old_version),
731                    &VersionedPQVector::new(old_code.clone(), old_version),
732                );
733                assert_eq!(got_oo_qq.unwrap(), oo);
734
735                let got_on_qq = computer.evaluate_similarity(
736                    &VersionedPQVector::new(old_code.clone(), old_version),
737                    &VersionedPQVector::new(new_code.clone(), new_version),
738                );
739                assert_eq!(got_on_qq.unwrap(), on);
740
741                let got_no_qq = computer.evaluate_similarity(
742                    &VersionedPQVector::new(new_code.clone(), new_version),
743                    &VersionedPQVector::new(old_code.clone(), old_version),
744                );
745                assert_eq!(got_no_qq.unwrap(), on);
746
747                let got_nn_qq = computer.evaluate_similarity(
748                    &VersionedPQVector::new(new_code.clone(), new_version),
749                    &VersionedPQVector::new(new_code.clone(), new_version),
750                );
751                assert_eq!(got_nn_qq.unwrap(), nn);
752            }
753
754            // Full Precision + Quant
755            {
756                let got_oo_qq = computer.evaluate_similarity(
757                    &*old_expected,
758                    &VersionedPQVector::new(old_code.clone(), old_version),
759                );
760                assert_eq!(got_oo_qq.unwrap(), oo);
761
762                let got_on_qq = computer.evaluate_similarity(
763                    &*old_expected,
764                    &VersionedPQVector::new(new_code.clone(), new_version),
765                );
766                assert_eq!(got_on_qq.unwrap(), on);
767
768                let got_no_qq = computer.evaluate_similarity(
769                    &*new_expected,
770                    &VersionedPQVector::new(old_code.clone(), old_version),
771                );
772                assert_eq!(got_no_qq.unwrap(), on);
773
774                let got_nn_qq = computer.evaluate_similarity(
775                    &*new_expected,
776                    &VersionedPQVector::new(new_code.clone(), new_version),
777                );
778                assert_eq!(got_nn_qq.unwrap(), nn);
779            }
780
781            // Ensure that version mismatches return `None` for all combinations.
782            let mut bad_version = old_version.wrapping_add(1);
783            if bad_version == new_version {
784                bad_version = bad_version.wrapping_add(1);
785            }
786
787            // mismatch for first argument.
788            let got = computer.evaluate_similarity(
789                VersionedPQVectorRef::new(&old_code, bad_version),
790                VersionedPQVectorRef::new(&new_code, new_version),
791            );
792            assert!(got.is_none());
793
794            // mismatch for second argument.
795            let got = computer.evaluate_similarity(
796                &VersionedPQVector::new(new_code.clone(), new_version),
797                &VersionedPQVector::new(old_code.clone(), bad_version),
798            );
799            assert!(got.is_none());
800
801            // mismatch for full precision.
802            let got = computer.evaluate_similarity(
803                &*new_expected,
804                &VersionedPQVector::new(old_code.clone(), bad_version),
805            );
806            assert!(got.is_none());
807        }
808    }
809
810    #[rstest]
811    fn test_multi_distance_computer_two(
812        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
813    ) {
814        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
815
816        let old_config = test_utils::TableConfig {
817            dim: 17,
818            pq_chunks: 4,
819            num_pivots: 20,
820            start_value: 10.0,
821            use_opq: false,
822        };
823
824        let new_config = test_utils::TableConfig {
825            dim: 17,
826            pq_chunks: 5,
827            num_pivots: 16,
828            start_value: 1.0,
829            use_opq: false,
830        };
831
832        let new = test_utils::seed_pivot_table(new_config);
833        let old = test_utils::seed_pivot_table(old_config);
834
835        let new_version: usize = 0x5a2b92a731766613;
836        let old_version: usize = 0x2fab58c9c8b73841;
837
838        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
839        let (n, o) = multi_table.versions();
840        assert_eq!(*n, new_version);
841        assert_eq!(*o.unwrap(), old_version);
842
843        let computer = MultiDistanceComputer::new(multi_table.clone(), metric).unwrap();
844        test_distance_computer_multi_with_two(
845            &computer,
846            &new,
847            &old,
848            &new_config,
849            &old_config,
850            &f32::distance(metric, None),
851            100,
852            &mut rng,
853        );
854    }
855
856    ///////////////////////////////////////////
857    // Distance Computer Construction Errors //
858    ///////////////////////////////////////////
859
860    #[rstest]
861    fn test_multi_distance_computer_opq_error(
862        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
863    ) {
864        let config_with_opq = test_utils::TableConfig {
865            dim: 17,
866            pq_chunks: 4,
867            num_pivots: 20,
868            start_value: 10.0,
869            use_opq: true,
870        };
871
872        let config = test_utils::TableConfig {
873            dim: 17,
874            pq_chunks: 4,
875            num_pivots: 20,
876            start_value: 10.0,
877            use_opq: false,
878        };
879
880        let expected_err = (DistanceComputerConstructionError::OPQNotSupported).to_string();
881        let table_with_opq = test_utils::seed_pivot_table(config_with_opq);
882        let table = test_utils::seed_pivot_table(config);
883
884        let schema = MultiTable::one(&table_with_opq, 0);
885        let result = MultiDistanceComputer::new(schema, metric);
886        assert!(result.is_err(), "expected OPQ to not be supported");
887        assert_eq!(result.unwrap_err().to_string(), expected_err);
888
889        // Try all combinations of tables with OPQ.
890        let schema = MultiTable::two(&table_with_opq, &table, 0, 1).unwrap();
891        let result = MultiDistanceComputer::new(schema, metric);
892        assert!(result.is_err(), "expected OPQ to not be supported");
893        assert_eq!(result.unwrap_err().to_string(), expected_err);
894
895        let schema = MultiTable::two(&table, &table_with_opq, 0, 1).unwrap();
896        let result = MultiDistanceComputer::new(schema, metric);
897        assert!(result.is_err(), "expected OPQ to not be supported");
898        assert_eq!(result.unwrap_err().to_string(), expected_err);
899
900        let schema = MultiTable::two(&table_with_opq, &table_with_opq, 0, 1).unwrap();
901        let result = MultiDistanceComputer::new(schema, metric);
902        assert!(result.is_err(), "expected OPQ to not be supported");
903        assert_eq!(result.unwrap_err().to_string(), expected_err);
904    }
905
906    ////////////////////////////////
907    // Query Computer - One Table //
908    ////////////////////////////////
909
910    #[allow(clippy::too_many_arguments)]
911    fn check_query_computer<R: Rng>(
912        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
913        table: &FixedChunkPQTable,
914        config: &test_utils::TableConfig,
915        query: &[f32],
916        version: usize,
917        rng: &mut R,
918        reference: &<f32 as VectorRepr>::Distance,
919        errors: test_utils::RelativeAndAbsolute,
920    ) {
921        // Generate a code for the old table.
922        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
923        let expected_vector = test_utils::generate_expected_vector(
924            &code,
925            table.get_chunk_offsets(),
926            config.start_value,
927        );
928        let got = computer
929            .evaluate_similarity(&VersionedPQVector {
930                data: code,
931                version,
932            })
933            .unwrap();
934        let expected = reference.evaluate_similarity(query, &expected_vector);
935        assert_relative_eq!(
936            got,
937            expected,
938            epsilon = errors.absolute,
939            max_relative = errors.relative
940        );
941    }
942
943    fn test_query_computer_multi_with_one<'a, T, R>(
944        mut create: impl FnMut(usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
945        table: &'a FixedChunkPQTable,
946        config: &test_utils::TableConfig,
947        reference: &<f32 as VectorRepr>::Distance,
948        num_trials: usize,
949        rng: &mut R,
950        errors: test_utils::RelativeAndAbsolute,
951    ) where
952        T: Into<f32> + TestDistribution,
953        R: Rng,
954    {
955        let standard = rand::distr::StandardUniform {};
956        for _ in 0..num_trials {
957            let input: Vec<T> = T::generate(config.dim, rng);
958            let input_f32 = to_f32(&input);
959
960            let version: u64 = standard.sample(rng);
961            let version: usize = version.into_usize();
962            let invalid_version = version.wrapping_add(1);
963
964            let computer = create(version, &input);
965
966            assert_eq!(
967                computer.versions(),
968                (&version, None),
969                "expected the computer to only have one version"
970            );
971
972            for _ in 0..num_trials {
973                check_query_computer(
974                    &computer, table, config, &input_f32, version, rng, reference, errors,
975                );
976            }
977
978            // Check the error path on mismatched versions.
979            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
980            let got =
981                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
982            assert!(got.is_none(), "Expected `None` for unmatched versions");
983        }
984    }
985
986    #[rstest]
987    fn test_query_computer_one<T>(
988        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
989        _datatype: PhantomData<T>,
990        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
991    ) where
992        T: Into<f32> + TestDistribution,
993    {
994        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);
995
996        let config = test_utils::TableConfig {
997            dim: 17,
998            pq_chunks: 4,
999            num_pivots: 20,
1000            start_value: 10.0,
1001            use_opq: false,
1002        };
1003
1004        let table = test_utils::seed_pivot_table(config);
1005        let num_trials = 20;
1006
1007        let errors = test_utils::RelativeAndAbsolute {
1008            relative: 5.0e-5,
1009            absolute: 0.0,
1010        };
1011
1012        let create = |version: usize, query: &[T]| {
1013            let schema = MultiTable::one(&table, version);
1014            MultiQueryComputer::new(schema, metric, query).unwrap()
1015        };
1016        test_query_computer_multi_with_one(
1017            create,
1018            &table,
1019            &config,
1020            &f32::distance(metric, None),
1021            num_trials,
1022            &mut rng,
1023            errors,
1024        );
1025    }
1026
1027    /////////////////////////////////
1028    // Query Computer - Two Tables //
1029    /////////////////////////////////
1030
1031    #[allow(clippy::too_many_arguments)]
1032    fn test_query_computer_multi_with_two<'a, T, R>(
1033        create: impl Fn(usize, usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
1034        new: &'a FixedChunkPQTable,
1035        old: &'a FixedChunkPQTable,
1036        new_config: &test_utils::TableConfig,
1037        old_config: &test_utils::TableConfig,
1038        reference: &<f32 as VectorRepr>::Distance,
1039        num_trials: usize,
1040        rng: &mut R,
1041        errors: test_utils::RelativeAndAbsolute,
1042    ) where
1043        T: Into<f32> + TestDistribution,
1044        R: Rng,
1045    {
1046        let standard = rand::distr::StandardUniform {};
1047        for _ in 0..num_trials {
1048            let input: Vec<T> = T::generate(old_config.dim, rng);
1049            let input_f32: Vec<f32> = to_f32(&input);
1050
1051            // Create a computer with two random versions.
1052            let old_version: u64 = standard.sample(rng);
1053            let mut new_version: u64 = standard.sample(rng);
1054            while new_version == old_version {
1055                new_version = standard.sample(rng);
1056            }
1057
1058            let mut invalid_version: u64 = standard.sample(rng);
1059            while invalid_version == old_version || invalid_version == new_version {
1060                invalid_version = standard.sample(rng);
1061            }
1062
1063            let old_version = old_version.into_usize();
1064            let new_version = new_version.into_usize();
1065            let invalid_version = invalid_version.into_usize();
1066
1067            let computer = create(new_version, old_version, &input);
1068
1069            assert_eq!(
1070                computer.versions(),
1071                (&new_version, Some(&old_version)),
1072                "versions were not propagated successfully",
1073            );
1074
1075            for _ in 0..num_trials {
1076                check_query_computer(
1077                    &computer,
1078                    old,
1079                    old_config,
1080                    &input_f32,
1081                    old_version,
1082                    rng,
1083                    reference,
1084                    errors,
1085                );
1086
1087                check_query_computer(
1088                    &computer,
1089                    new,
1090                    new_config,
1091                    &input_f32,
1092                    new_version,
1093                    rng,
1094                    reference,
1095                    errors,
1096                );
1097
1098                let code = test_utils::generate_random_code(
1099                    old_config.num_pivots,
1100                    old_config.pq_chunks,
1101                    rng,
1102                );
1103                let got = computer.evaluate_similarity(&VersionedPQVector {
1104                    data: code,
1105                    version: invalid_version,
1106                });
1107                assert!(
1108                    got.is_none(),
1109                    "expected a distance computation with an invalid version to return None"
1110                );
1111            }
1112        }
1113    }
1114
1115    #[rstest]
1116    fn test_query_computer_two<T>(
1117        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
1118        _datatype: PhantomData<T>,
1119        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
1120    ) where
1121        T: Into<f32> + TestDistribution,
1122    {
1123        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
1124
1125        let old_config = test_utils::TableConfig {
1126            dim: 17,
1127            pq_chunks: 4,
1128            num_pivots: 20,
1129            start_value: 10.0,
1130            use_opq: false,
1131        };
1132
1133        let new_config = test_utils::TableConfig {
1134            dim: 17,
1135            pq_chunks: 5,
1136            num_pivots: 16,
1137            start_value: 1.0,
1138            use_opq: false,
1139        };
1140
1141        let old = test_utils::seed_pivot_table(old_config);
1142        let new = test_utils::seed_pivot_table(new_config);
1143        let num_trials = 20;
1144
1145        let create = |new_version: usize, old_version: usize, query: &[T]| {
1146            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
1147            MultiQueryComputer::new(schema, metric, query).unwrap()
1148        };
1149
1150        let errors = test_utils::RelativeAndAbsolute {
1151            relative: 5.0e-5,
1152            absolute: 0.0,
1153        };
1154
1155        test_query_computer_multi_with_two(
1156            create,
1157            &new,
1158            &old,
1159            &new_config,
1160            &old_config,
1161            &f32::distance(metric, None),
1162            num_trials,
1163            &mut rng,
1164            errors,
1165        );
1166    }
1167}