Skip to main content

diskann_providers/model/pq/distance/
multi.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{QueryComputer, dynamic::VTable};
14use crate::model::FixedChunkPQTable;
15
/// Marker trait for the identifiers used to tag PQ data with a version.
///
/// Nothing needs to implement this by hand: the blanket implementation below covers
/// every type that is both `Eq` and `Copy`.
pub trait PQVersion: Eq + Copy {}

impl<T: Eq + Copy> PQVersion for T {}
18
19/// A PQ vector with an associated version.
20#[derive(Debug, Clone, PartialEq)]
21pub struct VersionedPQVector<I: PQVersion> {
22    data: Vec<u8>,
23    version: I,
24}
25
26impl<I> VersionedPQVector<I>
27where
28    I: PQVersion,
29{
30    /// Construct a new `VersionedPQVector` taking ownership of the provided data and version.
31    pub fn new(data: Vec<u8>, version: I) -> Self {
32        Self { data, version }
33    }
34
35    /// Return a `VersionedPQVectorRef` over the data owned by this vector.
36    pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
37        VersionedPQVectorRef::new(&self.data, self.version)
38    }
39
40    /// Return the version associated with this vector.
41    pub fn version(&self) -> &I {
42        &self.version
43    }
44
45    /// Return the raw underlying data.
46    pub fn data(&self) -> &[u8] {
47        &self.data
48    }
49
50    /// Return the components of the vector. This is a low-level API.
51    pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
52        (&mut self.data, &mut self.version)
53    }
54}
55
56impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
57where
58    I: PQVersion,
59{
60    type Target = VersionedPQVectorRef<'a, I>;
61    fn reborrow(&'a self) -> Self::Target {
62        self.as_ref()
63    }
64}
65
66/// A reference version of `VersionedPQVector`.
67#[derive(Debug, Clone, Copy)]
68pub struct VersionedPQVectorRef<'a, I: PQVersion> {
69    data: &'a [u8],
70    version: I,
71}
72
73impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
74    /// Construct a new `VersionedPQVectorRef` around the provided data.
75    pub fn new(data: &'a [u8], version: I) -> Self {
76        Self { data, version }
77    }
78
79    /// Return the version associated with this vector.
80    pub fn version(&self) -> &I {
81        &self.version
82    }
83
84    /// Return the raw underlying data.
85    pub fn data(&self) -> &[u8] {
86        self.data
87    }
88}
89
/// A wrapper for `FixedChunkPQTable` that contains either one or two inner
/// `FixedChunkPQTables` with associated versions.
///
/// `T` is any handle that dereferences to a `FixedChunkPQTable` (the tests use
/// `&FixedChunkPQTable`); `I` is the version type.
#[derive(Debug, Clone)]
pub enum MultiTable<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// Only one table is present with an associated version.
    One { table: T, version: I },
    /// Two tables are present, an incoming "new" table and an outgoing "old" table.
    /// The versions of these tables are recorded respectively in `new_version` and
    /// `old_version`.
    Two {
        /// The incoming table.
        new: T,
        /// The outgoing table.
        old: T,
        /// Version of `new`.
        new_version: I,
        /// Version of `old`.
        old_version: I,
    },
}
110
/// Error returned by `MultiTable::two` when the new and old versions compare equal.
#[derive(Debug, Error)]
#[error("provided versions must not be equal")]
pub struct EqualVersionsError;
114
115impl<T, I> MultiTable<T, I>
116where
117    T: Deref<Target = FixedChunkPQTable>,
118    I: PQVersion,
119{
120    /// Construct a new `MultiTable` containing a single `FixedChunkPQTable`.
121    pub fn one(table: T, version: I) -> Self {
122        Self::One { table, version }
123    }
124
125    /// Construct a new `MultiTable` with two `FixedChunkPQTable`s.
126    ///
127    /// Returns an `Err` if the two provided versions are equal.
128    pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
129        if new_version == old_version {
130            Err(EqualVersionsError)
131        } else {
132            Ok(Self::Two {
133                new,
134                old,
135                new_version,
136                old_version,
137            })
138        }
139    }
140
141    /// Return the versions associated with the tables in this schema.
142    ///
143    /// The returned tuple depends on whether this table has one or two registerd schemas.
144    ///
145    /// * If there is only one schema, return the `(version, None)` where `version` is the
146    ///   version of the only schema.
147    /// * If there are two schema, return `(new_version, Some(old_version))` where
148    ///   `new_version` is the version of the most recently registered schema while
149    ///   `old_version` is the old version.
150    pub fn versions(&self) -> (&I, Option<&I>) {
151        match &self {
152            Self::One { version, .. } => (version, None),
153            Self::Two {
154                new_version,
155                old_version,
156                ..
157            } => (new_version, Some(old_version)),
158        }
159    }
160}
161
/// A distance computer implementing
///
/// * `DistanceFunction<&[f32], &VersionedPQVector, Option<f32>>`
/// * `DistanceFunction<&VersionedPQVector, &VersionedPQVector, Option<f32>>`
///
/// That can contain either one or two PQ schemas, disambiguating which schema to use based
/// on the version numbers contained in the `VersionedPQVectors`.
///
/// Since this struct stores at most two PQ tables, that means there is the possibility
/// that a PQ vector is provided that does not match either of the tables.
///
/// Returns `None` for distance computations when the version of the PQ vector does not
/// match with any version in the local table.
#[derive(Debug, Clone)]
pub struct MultiDistanceComputer<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// The one or two versioned PQ tables that distances are computed against.
    table: MultiTable<T, I>,
    /// Distance entry points (`distance_fn`, `distance_fn_qq`) built from the metric
    /// passed to `new`.
    vtable: VTable,
}
184
185impl<T, I> MultiDistanceComputer<T, I>
186where
187    T: Deref<Target = FixedChunkPQTable>,
188    I: PQVersion,
189{
190    /// Construct a `MultiDistanceComputer` from the provided table implementing the
191    /// requested metric.
192    pub fn new(table: MultiTable<T, I>, metric: Metric) -> Self {
193        Self {
194            table,
195            vtable: VTable::new(metric),
196        }
197    }
198
199    /// Return the versions associated with the tables in this schema.
200    ///
201    /// The returned tuple depends on whether this table has one or two registerd schemas.
202    ///
203    /// * If there is only one schema, return the `(version, None)` where `version` is the
204    ///   version of the only schema.
205    /// * If there are two schema, return `(new_version, Some(old_version))` where
206    ///   `new_version` is the version of the most recently registered schema while
207    ///   `old_version` is the old version.
208    pub fn versions(&self) -> (&I, Option<&I>) {
209        self.table.versions()
210    }
211}
212
213impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
214    for MultiDistanceComputer<T, I>
215where
216    T: Deref<Target = FixedChunkPQTable>,
217    I: PQVersion,
218{
219    #[inline(always)]
220    fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
221        self.evaluate_similarity(x, y.reborrow())
222    }
223}
224
225impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
226    for MultiDistanceComputer<T, I>
227where
228    T: Deref<Target = FixedChunkPQTable>,
229    I: PQVersion,
230{
231    fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
232        match &self.table {
233            MultiTable::One { table, version } => {
234                if version != &y.version {
235                    None
236                } else {
237                    Some((self.vtable.distance_fn)(table, x, y.data))
238                }
239            }
240            MultiTable::Two {
241                old,
242                new,
243                old_version,
244                new_version,
245            } => {
246                if old_version == &y.version {
247                    Some((self.vtable.distance_fn)(old, x, y.data))
248                } else if new_version == &y.version {
249                    Some((self.vtable.distance_fn)(new, x, y.data))
250                } else {
251                    None
252                }
253            }
254        }
255    }
256}
257
258impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
259    for MultiDistanceComputer<T, I>
260where
261    T: Deref<Target = FixedChunkPQTable>,
262    I: PQVersion,
263{
264    #[inline(always)]
265    fn evaluate_similarity(
266        &self,
267        x: &VersionedPQVector<I>,
268        y: &VersionedPQVector<I>,
269    ) -> Option<f32> {
270        self.evaluate_similarity(x.reborrow(), y.reborrow())
271    }
272}
273
274/// Compute the distance between two versioned quantized vectors.
275///
276/// If one schema is currently being used and at least one of the versions of the argument
277/// vectors does not match, return `None`.
278///
279/// If two schemas are used and at least one of the versions of the argument vectors is not
280/// recognized, then return `None`.
281impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
282    for MultiDistanceComputer<T, I>
283where
284    T: Deref<Target = FixedChunkPQTable>,
285    I: PQVersion,
286{
287    fn evaluate_similarity(
288        &self,
289        x: VersionedPQVectorRef<'_, I>,
290        y: VersionedPQVectorRef<'_, I>,
291    ) -> Option<f32> {
292        match &self.table {
293            MultiTable::One { table, version } => {
294                if (&x.version != version) || (&y.version != version) {
295                    None
296                } else {
297                    Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
298                }
299            }
300            MultiTable::Two {
301                new,
302                old,
303                new_version,
304                old_version,
305            } => {
306                let x_new = &x.version == new_version;
307                let x_old = &x.version == old_version;
308
309                let y_new = &y.version == new_version;
310                let y_old = &y.version == old_version;
311
312                if x_old {
313                    if y_old {
314                        // Both Old
315                        Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
316                    } else if y_new {
317                        let x_full = old.inflate_vector(x.data);
318                        // X Old, Y New
319                        Some((self.vtable.distance_fn)(new, &x_full, y.data))
320                    } else {
321                        None
322                    }
323                } else if x_new {
324                    if y_old {
325                        let y_full = old.inflate_vector(y.data);
326                        // X New, Y Old
327                        Some((self.vtable.distance_fn)(new, &y_full, x.data))
328                    } else if y_new {
329                        // Both new
330                        Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
331                    } else {
332                        None
333                    }
334                } else {
335                    None
336                }
337            }
338        }
339    }
340}
341
342////////////////////
343// Query Computer //
344////////////////////
345
/// A `PreprocessedDistanceFunction` containing either one or two PQ schemas, capable of
/// performing distance computations with either. Upon a version mismatch with a query,
/// `None` is returned.
#[derive(Debug)]
pub enum MultiQueryComputer<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// A single query computer together with the version of its table.
    One {
        computer: QueryComputer<T>,
        version: I,
    },
    /// Two query computers built from the same query, one for the table tagged
    /// `new_version` and one for the table tagged `old_version`.
    Two {
        new: QueryComputer<T>,
        old: QueryComputer<T>,
        new_version: I,
        old_version: I,
    },
}
366
367impl<T, I> MultiQueryComputer<T, I>
368where
369    T: Deref<Target = FixedChunkPQTable>,
370    I: PQVersion,
371{
372    /// Construct a new `MultiQueryComputer` with the requested metric and query.
373    pub fn new(table: MultiTable<T, I>, metric: Metric, query: &[f32]) -> ANNResult<Self> {
374        let s = match table {
375            MultiTable::One { table, version } => Self::One {
376                computer: { QueryComputer::new(table, metric, query, None)? },
377                version,
378            },
379            MultiTable::Two {
380                new,
381                old,
382                new_version,
383                old_version,
384            } => Self::Two {
385                new: { QueryComputer::new(new, metric, query, None)? },
386                old: { QueryComputer::new(old, metric, query, None)? },
387                new_version,
388                old_version,
389            },
390        };
391        Ok(s)
392    }
393
394    /// Return a tuple that is either:
395    /// 1. (The only table version, None)
396    /// 2. (New Table Version, Old Table Version)
397    pub fn versions(&self) -> (&I, Option<&I>) {
398        match &self {
399            Self::One { version, .. } => (version, None),
400            Self::Two {
401                new_version,
402                old_version,
403                ..
404            } => (new_version, Some(old_version)),
405        }
406    }
407}
408
409impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
410    for MultiQueryComputer<T, I>
411where
412    T: Deref<Target = FixedChunkPQTable>,
413    I: PQVersion,
414{
415    #[inline(always)]
416    fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
417        self.evaluate_similarity(x.reborrow())
418    }
419}
420
421impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
422    for MultiQueryComputer<T, I>
423where
424    T: Deref<Target = FixedChunkPQTable>,
425    I: PQVersion,
426{
427    fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
428        match &self {
429            Self::One { computer, version } => {
430                if version != &x.version {
431                    None
432                } else {
433                    Some(computer.evaluate_similarity(x.data))
434                }
435            }
436            Self::Two {
437                new,
438                old,
439                new_version,
440                old_version,
441            } => {
442                if old_version == &x.version {
443                    Some(old.evaluate_similarity(x.data))
444                } else if new_version == &x.version {
445                    Some(new.evaluate_similarity(x.data))
446                } else {
447                    None
448                }
449            }
450        }
451    }
452}
453
454/// # Testing Strategies.
455///
456/// ## Distance Computations
457///
/// At this point, we assume that the lower level distance functions are more-or-less
/// accurate. That is, a given `QueryComputer` or `VTable`-based distance works correctly.
460///
461/// The testing functions at this level are more designed for testing that versioned vectors
462/// get sent to the right location and that the error handling is correct.
463#[cfg(test)]
464mod tests {
465    use approx::assert_relative_eq;
466    use diskann::utils::{IntoUsize, VectorRepr};
467    use diskann_vector::PreprocessedDistanceFunction;
468    use rand::{Rng, SeedableRng, distr::Distribution};
469    use rstest::rstest;
470
471    use super::{
472        super::test_utils::{self, TestDistribution},
473        *,
474    };
475
476    /////////////////////////
477    // Versioned PQ Vector //
478    /////////////////////////
479
480    #[test]
481    fn test_versioned_pq_vector() {
482        let vec = vec![1, 2, 3];
483        let ptr = vec.as_ptr();
484        let pq = VersionedPQVector::<usize>::new(vec, 10);
485        assert_eq!(*pq.version(), 10);
486        assert_eq!(pq.data().len(), 3);
487
488        let data_ptr = pq.data().as_ptr();
489        let pq_ref = pq.as_ref();
490        assert_eq!(pq_ref.version(), pq.version());
491        assert_eq!(data_ptr, ptr);
492        assert_eq!(
493            pq_ref.data().as_ptr(),
494            data_ptr,
495            "expected VersionedPQVectorRef to have the same underlying data as the \
496             original VersionedPQVector"
497        );
498
499        let pq_ref = pq.reborrow();
500        assert_eq!(pq_ref.version(), pq.version());
501        assert_eq!(data_ptr, ptr);
502        assert_eq!(
503            pq_ref.data().as_ptr(),
504            data_ptr,
505            "expected VersionedPQVectorRef to have the same underlying data as the \
506             original VersionedPQVector"
507        );
508    }
509
510    ////////////////
511    // MultiTable //
512    ////////////////
513
514    #[test]
515    fn test_table_error() {
516        let config = test_utils::TableConfig {
517            dim: 17,
518            pq_chunks: 4,
519            num_pivots: 20,
520            start_value: 10.0,
521        };
522
523        let new = test_utils::seed_pivot_table(config);
524        let old = test_utils::seed_pivot_table(config);
525
526        let result = MultiTable::two(&new, &old, 0, 0);
527        assert!(
528            matches!(result, Err(EqualVersionsError)),
529            "MultiTable should now allow construction of the Two variant with equal versions"
530        );
531    }
532
533    ///////////////////////////////////
534    // Distance Computer - One Table //
535    ///////////////////////////////////
536
537    /// Test that the table works correctl where there is one inner PQ table.
538    fn test_distance_computer_multi_with_one<R>(
539        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
540        table: &FixedChunkPQTable,
541        config: &test_utils::TableConfig,
542        reference: &<f32 as VectorRepr>::Distance,
543        num_trials: usize,
544        rng: &mut R,
545    ) where
546        R: Rng,
547    {
548        // Check that there is just one version.
549        let (&version, should_be_none) = computer.versions();
550        assert!(
551            should_be_none.is_none(),
552            "expected just one schema in test computer"
553        );
554        let invalid_version = version.wrapping_add(1);
555
556        for _ in 0..num_trials {
557            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
558            let expected0 = test_utils::generate_expected_vector(
559                &code0,
560                table.get_chunk_offsets(),
561                config.start_value,
562            );
563
564            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
565            let expected1 = test_utils::generate_expected_vector(
566                &code1,
567                table.get_chunk_offsets(),
568                config.start_value,
569            );
570
571            let expected = reference.evaluate_similarity(&*expected0, &*expected1);
572
573            // Test full-precision/quant.
574            let got = computer
575                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
576                .expect("evaluate_similarity should return Some");
577            assert_eq!(got, expected);
578
579            let got = computer
580                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
581                .expect("evaluate_similarity should return Some");
582            assert_eq!(got, expected);
583
584            // Test quant/quant.
585            let got = computer
586                .evaluate_similarity(
587                    &VersionedPQVector::new(code0.clone(), version),
588                    &VersionedPQVector::new(code1.clone(), version),
589                )
590                .expect("evaluate_similarity should return Some");
591            assert_eq!(got, expected);
592
593            // Check that version mismatches return `None`.
594            let got = computer.evaluate_similarity(
595                &*expected0,
596                &VersionedPQVector::new(code0.clone(), invalid_version),
597            );
598            assert!(got.is_none(), "version mismatches should return `None`");
599
600            let got = computer.evaluate_similarity(
601                &VersionedPQVector::new(code0.clone(), invalid_version),
602                &VersionedPQVector::new(code1.clone(), version),
603            );
604            assert!(got.is_none(), "version mismatches should return `None`");
605
606            let got = computer.evaluate_similarity(
607                &VersionedPQVector::new(code0.clone(), version),
608                &VersionedPQVector::new(code1.clone(), invalid_version),
609            );
610            assert!(got.is_none(), "version mismatches should return `None`");
611        }
612    }
613
614    #[rstest]
615    fn test_multi_distance_computer_one(
616        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
617    ) {
618        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
619
620        let config = test_utils::TableConfig {
621            dim: 17,
622            pq_chunks: 4,
623            num_pivots: 20,
624            start_value: 10.0,
625        };
626
627        let table = test_utils::seed_pivot_table(config);
628
629        let version: usize = 0x625b215f82f38008;
630
631        let multi_table = MultiTable::one(&table, version);
632        let (n, o) = multi_table.versions();
633        assert_eq!(*n, version);
634        assert!(o.is_none());
635
636        let computer = MultiDistanceComputer::new(multi_table, metric);
637
638        test_distance_computer_multi_with_one(
639            &computer,
640            &table,
641            &config,
642            &f32::distance(metric, None),
643            100,
644            &mut rng,
645        );
646    }
647
648    ////////////////////////////////////
649    // Distance Computer - Two Tables //
650    ////////////////////////////////////
651
652    /// Test that the table works correctly when there are two inner PQ tables.
653    #[allow(clippy::too_many_arguments)]
654    fn test_distance_computer_multi_with_two<R>(
655        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
656        new: &FixedChunkPQTable,
657        old: &FixedChunkPQTable,
658        new_config: &test_utils::TableConfig,
659        old_config: &test_utils::TableConfig,
660        reference: &<f32 as VectorRepr>::Distance,
661        num_trials: usize,
662        rng: &mut R,
663    ) where
664        R: Rng,
665    {
666        // Check that there are indeed two versions registered.
667        let (&new_version, old_version) = computer.versions();
668        let &old_version = old_version.expect("expected two schemas in test computer");
669
670        for _ in 0..num_trials {
671            // Generate a code for the old schema
672            let old_code =
673                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
674            let old_expected = test_utils::generate_expected_vector(
675                &old_code,
676                old.get_chunk_offsets(),
677                old_config.start_value,
678            );
679
680            // Generate a code for the new schema
681            let new_code =
682                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
683            let new_expected = test_utils::generate_expected_vector(
684                &new_code,
685                new.get_chunk_offsets(),
686                new_config.start_value,
687            );
688
689            // Generate reference results.
690            let oo = reference.evaluate_similarity(&*old_expected, &*old_expected);
691            let nn = reference.evaluate_similarity(&*new_expected, &*new_expected);
692            let on = reference.evaluate_similarity(&*old_expected, &*new_expected);
693
694            // Quant + Quant
695            {
696                let got_oo_qq = computer.evaluate_similarity(
697                    &VersionedPQVector::new(old_code.clone(), old_version),
698                    &VersionedPQVector::new(old_code.clone(), old_version),
699                );
700                assert_eq!(got_oo_qq.unwrap(), oo);
701
702                let got_on_qq = computer.evaluate_similarity(
703                    &VersionedPQVector::new(old_code.clone(), old_version),
704                    &VersionedPQVector::new(new_code.clone(), new_version),
705                );
706                assert_eq!(got_on_qq.unwrap(), on);
707
708                let got_no_qq = computer.evaluate_similarity(
709                    &VersionedPQVector::new(new_code.clone(), new_version),
710                    &VersionedPQVector::new(old_code.clone(), old_version),
711                );
712                assert_eq!(got_no_qq.unwrap(), on);
713
714                let got_nn_qq = computer.evaluate_similarity(
715                    &VersionedPQVector::new(new_code.clone(), new_version),
716                    &VersionedPQVector::new(new_code.clone(), new_version),
717                );
718                assert_eq!(got_nn_qq.unwrap(), nn);
719            }
720
721            // Full Precision + Quant
722            {
723                let got_oo_qq = computer.evaluate_similarity(
724                    &*old_expected,
725                    &VersionedPQVector::new(old_code.clone(), old_version),
726                );
727                assert_eq!(got_oo_qq.unwrap(), oo);
728
729                let got_on_qq = computer.evaluate_similarity(
730                    &*old_expected,
731                    &VersionedPQVector::new(new_code.clone(), new_version),
732                );
733                assert_eq!(got_on_qq.unwrap(), on);
734
735                let got_no_qq = computer.evaluate_similarity(
736                    &*new_expected,
737                    &VersionedPQVector::new(old_code.clone(), old_version),
738                );
739                assert_eq!(got_no_qq.unwrap(), on);
740
741                let got_nn_qq = computer.evaluate_similarity(
742                    &*new_expected,
743                    &VersionedPQVector::new(new_code.clone(), new_version),
744                );
745                assert_eq!(got_nn_qq.unwrap(), nn);
746            }
747
748            // Ensure that version mismatches return `None` for all combinations.
749            let mut bad_version = old_version.wrapping_add(1);
750            if bad_version == new_version {
751                bad_version = bad_version.wrapping_add(1);
752            }
753
754            // mismatch for first argument.
755            let got = computer.evaluate_similarity(
756                VersionedPQVectorRef::new(&old_code, bad_version),
757                VersionedPQVectorRef::new(&new_code, new_version),
758            );
759            assert!(got.is_none());
760
761            // mismatch for second argument.
762            let got = computer.evaluate_similarity(
763                &VersionedPQVector::new(new_code.clone(), new_version),
764                &VersionedPQVector::new(old_code.clone(), bad_version),
765            );
766            assert!(got.is_none());
767
768            // mismatch for full precision.
769            let got = computer.evaluate_similarity(
770                &*new_expected,
771                &VersionedPQVector::new(old_code.clone(), bad_version),
772            );
773            assert!(got.is_none());
774        }
775    }
776
777    #[rstest]
778    fn test_multi_distance_computer_two(
779        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
780    ) {
781        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
782
783        let old_config = test_utils::TableConfig {
784            dim: 17,
785            pq_chunks: 4,
786            num_pivots: 20,
787            start_value: 10.0,
788        };
789
790        let new_config = test_utils::TableConfig {
791            dim: 17,
792            pq_chunks: 5,
793            num_pivots: 16,
794            start_value: 1.0,
795        };
796
797        let new = test_utils::seed_pivot_table(new_config);
798        let old = test_utils::seed_pivot_table(old_config);
799
800        let new_version: usize = 0x5a2b92a731766613;
801        let old_version: usize = 0x2fab58c9c8b73841;
802
803        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
804        let (n, o) = multi_table.versions();
805        assert_eq!(*n, new_version);
806        assert_eq!(*o.unwrap(), old_version);
807
808        let computer = MultiDistanceComputer::new(multi_table.clone(), metric);
809        test_distance_computer_multi_with_two(
810            &computer,
811            &new,
812            &old,
813            &new_config,
814            &old_config,
815            &f32::distance(metric, None),
816            100,
817            &mut rng,
818        );
819    }
820
821    ////////////////////////////////
822    // Query Computer - One Table //
823    ////////////////////////////////
824
825    #[allow(clippy::too_many_arguments)]
826    fn check_query_computer<R: Rng>(
827        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
828        table: &FixedChunkPQTable,
829        config: &test_utils::TableConfig,
830        query: &[f32],
831        version: usize,
832        rng: &mut R,
833        reference: &<f32 as VectorRepr>::Distance,
834        errors: test_utils::RelativeAndAbsolute,
835    ) {
836        // Generate a code for the old table.
837        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
838        let expected_vector = test_utils::generate_expected_vector(
839            &code,
840            table.get_chunk_offsets(),
841            config.start_value,
842        );
843        let got = computer
844            .evaluate_similarity(&VersionedPQVector {
845                data: code,
846                version,
847            })
848            .unwrap();
849        let expected = reference.evaluate_similarity(query, &expected_vector);
850        assert_relative_eq!(
851            got,
852            expected,
853            epsilon = errors.absolute,
854            max_relative = errors.relative
855        );
856    }
857
858    fn test_query_computer_multi_with_one<'a, R>(
859        mut create: impl FnMut(usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
860        table: &'a FixedChunkPQTable,
861        config: &test_utils::TableConfig,
862        reference: &<f32 as VectorRepr>::Distance,
863        num_trials: usize,
864        rng: &mut R,
865        errors: test_utils::RelativeAndAbsolute,
866    ) where
867        R: Rng,
868    {
869        let standard = rand::distr::StandardUniform {};
870        for _ in 0..num_trials {
871            let input_f32: Vec<f32> = f32::generate(config.dim, rng);
872
873            let version: u64 = standard.sample(rng);
874            let version: usize = version.into_usize();
875            let invalid_version = version.wrapping_add(1);
876
877            let computer = create(version, &input_f32);
878
879            assert_eq!(
880                computer.versions(),
881                (&version, None),
882                "expected the computer to only have one version"
883            );
884
885            for _ in 0..num_trials {
886                check_query_computer(
887                    &computer, table, config, &input_f32, version, rng, reference, errors,
888                );
889            }
890
891            // Check the error path on mismatched versions.
892            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
893            let got =
894                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
895            assert!(got.is_none(), "Expected `None` for unmatched versions");
896        }
897    }
898
899    #[rstest]
900    fn test_query_computer_one(
901        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
902    ) {
903        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);
904
905        let config = test_utils::TableConfig {
906            dim: 17,
907            pq_chunks: 4,
908            num_pivots: 20,
909            start_value: 10.0,
910        };
911
912        let table = test_utils::seed_pivot_table(config);
913        let num_trials = 20;
914
915        let errors = test_utils::RelativeAndAbsolute {
916            relative: 5.0e-5,
917            absolute: 0.0,
918        };
919
920        let create = |version: usize, query: &[f32]| {
921            let schema = MultiTable::one(&table, version);
922            MultiQueryComputer::new(schema, metric, query).unwrap()
923        };
924        test_query_computer_multi_with_one(
925            create,
926            &table,
927            &config,
928            &f32::distance(metric, None),
929            num_trials,
930            &mut rng,
931            errors,
932        );
933    }
934
935    /////////////////////////////////
936    // Query Computer - Two Tables //
937    /////////////////////////////////
938
939    #[allow(clippy::too_many_arguments)]
940    fn test_query_computer_multi_with_two<'a, R>(
941        create: impl Fn(usize, usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
942        new: &'a FixedChunkPQTable,
943        old: &'a FixedChunkPQTable,
944        new_config: &test_utils::TableConfig,
945        old_config: &test_utils::TableConfig,
946        reference: &<f32 as VectorRepr>::Distance,
947        num_trials: usize,
948        rng: &mut R,
949        errors: test_utils::RelativeAndAbsolute,
950    ) where
951        R: Rng,
952    {
953        let standard = rand::distr::StandardUniform {};
954        for _ in 0..num_trials {
955            let input_f32: Vec<f32> = f32::generate(old_config.dim, rng);
956
957            // Create a computer with two random versions.
958            let old_version: u64 = standard.sample(rng);
959            let mut new_version: u64 = standard.sample(rng);
960            while new_version == old_version {
961                new_version = standard.sample(rng);
962            }
963
964            let mut invalid_version: u64 = standard.sample(rng);
965            while invalid_version == old_version || invalid_version == new_version {
966                invalid_version = standard.sample(rng);
967            }
968
969            let old_version = old_version.into_usize();
970            let new_version = new_version.into_usize();
971            let invalid_version = invalid_version.into_usize();
972
973            let computer = create(new_version, old_version, &input_f32);
974
975            assert_eq!(
976                computer.versions(),
977                (&new_version, Some(&old_version)),
978                "versions were not propagated successfully",
979            );
980
981            for _ in 0..num_trials {
982                check_query_computer(
983                    &computer,
984                    old,
985                    old_config,
986                    &input_f32,
987                    old_version,
988                    rng,
989                    reference,
990                    errors,
991                );
992
993                check_query_computer(
994                    &computer,
995                    new,
996                    new_config,
997                    &input_f32,
998                    new_version,
999                    rng,
1000                    reference,
1001                    errors,
1002                );
1003
1004                let code = test_utils::generate_random_code(
1005                    old_config.num_pivots,
1006                    old_config.pq_chunks,
1007                    rng,
1008                );
1009                let got = computer.evaluate_similarity(&VersionedPQVector {
1010                    data: code,
1011                    version: invalid_version,
1012                });
1013                assert!(
1014                    got.is_none(),
1015                    "expected a distance computation with an invalid version to return None"
1016                );
1017            }
1018        }
1019    }
1020
1021    #[rstest]
1022    fn test_query_computer_two(
1023        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
1024    ) {
1025        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
1026
1027        let old_config = test_utils::TableConfig {
1028            dim: 17,
1029            pq_chunks: 4,
1030            num_pivots: 20,
1031            start_value: 10.0,
1032        };
1033
1034        let new_config = test_utils::TableConfig {
1035            dim: 17,
1036            pq_chunks: 5,
1037            num_pivots: 16,
1038            start_value: 1.0,
1039        };
1040
1041        let old = test_utils::seed_pivot_table(old_config);
1042        let new = test_utils::seed_pivot_table(new_config);
1043        let num_trials = 20;
1044
1045        let create = |new_version: usize, old_version: usize, query: &[f32]| {
1046            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
1047            MultiQueryComputer::new(schema, metric, query).unwrap()
1048        };
1049
1050        let errors = test_utils::RelativeAndAbsolute {
1051            relative: 5.0e-5,
1052            absolute: 0.0,
1053        };
1054
1055        test_query_computer_multi_with_two(
1056            create,
1057            &new,
1058            &old,
1059            &new_config,
1060            &old_config,
1061            &f32::distance(metric, None),
1062            num_trials,
1063            &mut rng,
1064            errors,
1065        );
1066    }
1067}