Skip to main content

diskann_providers/model/pq/distance/
multi.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{QueryComputer, dynamic::VTable};
14use crate::model::FixedChunkPQTable;
15
/// Marker trait for the version tag attached to PQ vectors and PQ tables.
///
/// Every type that is both `Copy` and `Eq` satisfies it through the blanket
/// implementation below; nothing needs to be implemented by hand.
pub trait PQVersion: Copy + Eq {}

impl<V: Copy + Eq> PQVersion for V {}
18
19/// A PQ vector with an associated version.
20#[derive(Debug, Clone, PartialEq)]
21pub struct VersionedPQVector<I: PQVersion> {
22    data: Vec<u8>,
23    version: I,
24}
25
26impl<I> VersionedPQVector<I>
27where
28    I: PQVersion,
29{
30    /// Construct a new `VersionedPQVector` taking ownership of the provided data and version.
31    pub fn new(data: Vec<u8>, version: I) -> Self {
32        Self { data, version }
33    }
34
35    /// Return a `VersionedPQVectorRef` over the data owned by this vector.
36    pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
37        VersionedPQVectorRef::new(&self.data, self.version)
38    }
39
40    /// Return the version associated with this vector.
41    pub fn version(&self) -> &I {
42        &self.version
43    }
44
45    /// Return the raw underlying data.
46    pub fn data(&self) -> &[u8] {
47        &self.data
48    }
49
50    /// Return the components of the vector. This is a low-level API.
51    pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
52        (&mut self.data, &mut self.version)
53    }
54}
55
56impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
57where
58    I: PQVersion,
59{
60    type Target = VersionedPQVectorRef<'a, I>;
61    fn reborrow(&'a self) -> Self::Target {
62        self.as_ref()
63    }
64}
65
/// A borrowed view of a `VersionedPQVector`: PQ code bytes together with their version.
///
/// The type is `Copy`, so it is cheap to pass by value.
#[derive(Debug, Clone, Copy)]
pub struct VersionedPQVectorRef<'a, I: PQVersion> {
    /// Borrowed PQ code bytes.
    data: &'a [u8],
    /// Version of the PQ schema the bytes were encoded with.
    version: I,
}
72
73impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
74    /// Construct a new `VersionedPQVectorRef` around the provided data.
75    pub fn new(data: &'a [u8], version: I) -> Self {
76        Self { data, version }
77    }
78
79    /// Return the version associated with this vector.
80    pub fn version(&self) -> &I {
81        &self.version
82    }
83
84    /// Return the raw underlying data.
85    pub fn data(&self) -> &[u8] {
86        self.data
87    }
88}
89
/// A wrapper for `FixedChunkPQTable` that contains either one or two inner
/// `FixedChunkPQTable`s, each with an associated version.
///
/// [`MultiTable::two`] builds the two-table variant and rejects equal versions.
/// Constructing the `Two` variant directly bypasses that check.
#[derive(Debug, Clone)]
pub enum MultiTable<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// Only one table is present with an associated version.
    One { table: T, version: I },
    /// Two tables are present, an incoming "new" table and an outgoing "old" table.
    /// The versions of these tables are recorded respectively in `new_version` and
    /// `old_version`.
    Two {
        new: T,
        old: T,
        new_version: I,
        old_version: I,
    },
}
110
111#[derive(Debug, Error)]
112#[error("provided versions must not be equal")]
113pub struct EqualVersionsError;
114
115impl<T, I> MultiTable<T, I>
116where
117    T: Deref<Target = FixedChunkPQTable>,
118    I: PQVersion,
119{
120    /// Construct a new `MultiTable` containing a single `FixedChunkPQTable`.
121    pub fn one(table: T, version: I) -> Self {
122        Self::One { table, version }
123    }
124
125    /// Construct a new `MultiTable` with two `FixedChunkPQTable`s.
126    ///
127    /// Returns an `Err` if the two provided versions are equal.
128    pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
129        if new_version == old_version {
130            Err(EqualVersionsError)
131        } else {
132            Ok(Self::Two {
133                new,
134                old,
135                new_version,
136                old_version,
137            })
138        }
139    }
140
141    /// Return the versions associated with the tables in this schema.
142    ///
143    /// The returned tuple depends on whether this table has one or two registerd schemas.
144    ///
145    /// * If there is only one schema, return the `(version, None)` where `version` is the
146    ///   version of the only schema.
147    /// * If there are two schema, return `(new_version, Some(old_version))` where
148    ///   `new_version` is the version of the most recently registered schema while
149    ///   `old_version` is the old version.
150    pub fn versions(&self) -> (&I, Option<&I>) {
151        match &self {
152            Self::One { version, .. } => (version, None),
153            Self::Two {
154                new_version,
155                old_version,
156                ..
157            } => (new_version, Some(old_version)),
158        }
159    }
160}
161
162/// A distance computer implementing
163///
164/// * `DistanceFunction<&[f32], &VersionedPQVector, Option<f32>`
165/// * `DistanceFunction<&VersionedPQVector, &VersionedPQVector, Option<f32>`
166///
167/// That can contain either one or two PQ schemas, disambiguating which schema to use based
168/// on the version numbers contained in the `VersionedPQVectors`.
169///
170/// Since this struct stores at most two PQ tables, that means there is the possibility
171/// that a PQ vector is provided that does not match either of the tables.
172///
173/// Returns `None` for distance computations when the version of the PQ vector does not
174/// match with any version in the local table.
175#[derive(Debug, Clone)]
176pub struct MultiDistanceComputer<T, I>
177where
178    T: Deref<Target = FixedChunkPQTable>,
179    I: PQVersion,
180{
181    table: MultiTable<T, I>,
182    vtable: VTable,
183}
184
185impl<T, I> MultiDistanceComputer<T, I>
186where
187    T: Deref<Target = FixedChunkPQTable>,
188    I: PQVersion,
189{
190    /// Construct a `MultiDistanceComputer` from the provided table implementing the
191    /// requested metric.
192    pub fn new(table: MultiTable<T, I>, metric: Metric) -> Self {
193        Self {
194            table,
195            vtable: VTable::new(metric),
196        }
197    }
198
199    /// Return the versions associated with the tables in this schema.
200    ///
201    /// The returned tuple depends on whether this table has one or two registerd schemas.
202    ///
203    /// * If there is only one schema, return the `(version, None)` where `version` is the
204    ///   version of the only schema.
205    /// * If there are two schema, return `(new_version, Some(old_version))` where
206    ///   `new_version` is the version of the most recently registered schema while
207    ///   `old_version` is the old version.
208    pub fn versions(&self) -> (&I, Option<&I>) {
209        self.table.versions()
210    }
211}
212
213impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
214    for MultiDistanceComputer<T, I>
215where
216    T: Deref<Target = FixedChunkPQTable>,
217    I: PQVersion,
218{
219    #[inline(always)]
220    fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
221        self.evaluate_similarity(x, y.reborrow())
222    }
223}
224
225impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
226    for MultiDistanceComputer<T, I>
227where
228    T: Deref<Target = FixedChunkPQTable>,
229    I: PQVersion,
230{
231    fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
232        match &self.table {
233            MultiTable::One { table, version } => {
234                if version != &y.version {
235                    None
236                } else {
237                    Some((self.vtable.distance_fn)(table, x, y.data))
238                }
239            }
240            MultiTable::Two {
241                old,
242                new,
243                old_version,
244                new_version,
245            } => {
246                if old_version == &y.version {
247                    Some((self.vtable.distance_fn)(old, x, y.data))
248                } else if new_version == &y.version {
249                    Some((self.vtable.distance_fn)(new, x, y.data))
250                } else {
251                    None
252                }
253            }
254        }
255    }
256}
257
258impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
259    for MultiDistanceComputer<T, I>
260where
261    T: Deref<Target = FixedChunkPQTable>,
262    I: PQVersion,
263{
264    #[inline(always)]
265    fn evaluate_similarity(
266        &self,
267        x: &VersionedPQVector<I>,
268        y: &VersionedPQVector<I>,
269    ) -> Option<f32> {
270        self.evaluate_similarity(x.reborrow(), y.reborrow())
271    }
272}
273
274/// Compute the distance between two versioned quantized vectors.
275///
276/// If one schema is currently being used and at least one of the versions of the argument
277/// vectors does not match, return `None`.
278///
279/// If two schemas are used and at least one of the versions of the argument vectors is not
280/// recognized, then return `None`.
281impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
282    for MultiDistanceComputer<T, I>
283where
284    T: Deref<Target = FixedChunkPQTable>,
285    I: PQVersion,
286{
287    fn evaluate_similarity(
288        &self,
289        x: VersionedPQVectorRef<'_, I>,
290        y: VersionedPQVectorRef<'_, I>,
291    ) -> Option<f32> {
292        match &self.table {
293            MultiTable::One { table, version } => {
294                if (&x.version != version) || (&y.version != version) {
295                    None
296                } else {
297                    Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
298                }
299            }
300            MultiTable::Two {
301                new,
302                old,
303                new_version,
304                old_version,
305            } => {
306                let x_new = &x.version == new_version;
307                let x_old = &x.version == old_version;
308
309                let y_new = &y.version == new_version;
310                let y_old = &y.version == old_version;
311
312                if x_old {
313                    if y_old {
314                        // Both Old
315                        Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
316                    } else if y_new {
317                        let x_full = old.inflate_vector(x.data);
318                        // X Old, Y New
319                        Some((self.vtable.distance_fn)(new, &x_full, y.data))
320                    } else {
321                        None
322                    }
323                } else if x_new {
324                    if y_old {
325                        let y_full = old.inflate_vector(y.data);
326                        // X New, Y Old
327                        Some((self.vtable.distance_fn)(new, &y_full, x.data))
328                    } else if y_new {
329                        // Both new
330                        Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
331                    } else {
332                        None
333                    }
334                } else {
335                    None
336                }
337            }
338        }
339    }
340}
341
342////////////////////
343// Query Computer //
344////////////////////
345
/// A `PreprocessedDistanceFunction` containing either one or two PQ schemas, capable of
/// performing distance computations with either. When the evaluated vector's version
/// matches none of the stored schemas, `None` is returned.
#[derive(Debug)]
pub enum MultiQueryComputer<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// A single schema: the query computer for its table and that table's version.
    One {
        computer: QueryComputer<T>,
        version: I,
    },
    /// Two schemas: query computers for the incoming (`new`) and outgoing (`old`) tables,
    /// with their versions recorded in `new_version` and `old_version` respectively.
    Two {
        new: QueryComputer<T>,
        old: QueryComputer<T>,
        new_version: I,
        old_version: I,
    },
}
366
367impl<T, I> MultiQueryComputer<T, I>
368where
369    T: Deref<Target = FixedChunkPQTable>,
370    I: PQVersion,
371{
372    /// Construct a new `MultiQueryComputer` with the requested metric and query.
373    pub fn new<U>(table: MultiTable<T, I>, metric: Metric, query: &[U]) -> ANNResult<Self>
374    where
375        U: Into<f32> + Copy,
376    {
377        let s = match table {
378            MultiTable::One { table, version } => Self::One {
379                computer: { QueryComputer::new(table, metric, query, None)? },
380                version,
381            },
382            MultiTable::Two {
383                new,
384                old,
385                new_version,
386                old_version,
387            } => Self::Two {
388                new: { QueryComputer::new(new, metric, query, None)? },
389                old: { QueryComputer::new(old, metric, query, None)? },
390                new_version,
391                old_version,
392            },
393        };
394        Ok(s)
395    }
396
397    /// Return a tuple that is either:
398    /// 1. (The only table version, None)
399    /// 2. (New Table Version, Old Table Version)
400    pub fn versions(&self) -> (&I, Option<&I>) {
401        match &self {
402            Self::One { version, .. } => (version, None),
403            Self::Two {
404                new_version,
405                old_version,
406                ..
407            } => (new_version, Some(old_version)),
408        }
409    }
410}
411
412impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
413    for MultiQueryComputer<T, I>
414where
415    T: Deref<Target = FixedChunkPQTable>,
416    I: PQVersion,
417{
418    #[inline(always)]
419    fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
420        self.evaluate_similarity(x.reborrow())
421    }
422}
423
424impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
425    for MultiQueryComputer<T, I>
426where
427    T: Deref<Target = FixedChunkPQTable>,
428    I: PQVersion,
429{
430    fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
431        match &self {
432            Self::One { computer, version } => {
433                if version != &x.version {
434                    None
435                } else {
436                    Some(computer.evaluate_similarity(x.data))
437                }
438            }
439            Self::Two {
440                new,
441                old,
442                new_version,
443                old_version,
444            } => {
445                if old_version == &x.version {
446                    Some(old.evaluate_similarity(x.data))
447                } else if new_version == &x.version {
448                    Some(new.evaluate_similarity(x.data))
449                } else {
450                    None
451                }
452            }
453        }
454    }
455}
456
457/// # Testing Strategies.
458///
459/// ## Distance Computations
460///
/// At this level, we assume that the lower-level distance functions are (more or less)
/// accurate; that is, a given `QueryComputer` or VTable-based distance function works
/// correctly.
///
/// The tests here are instead designed to check that versioned vectors are dispatched to
/// the correct PQ table and that the error handling is correct.
466#[cfg(test)]
467mod tests {
468    use std::marker::PhantomData;
469
470    use approx::assert_relative_eq;
471    use diskann::utils::{IntoUsize, VectorRepr};
472    use diskann_vector::{Half, PreprocessedDistanceFunction};
473    use rand::{Rng, SeedableRng, distr::Distribution};
474    use rstest::rstest;
475
476    use super::{
477        super::test_utils::{self, TestDistribution},
478        *,
479    };
480
481    fn to_f32<T>(x: &[T]) -> Vec<f32>
482    where
483        T: Into<f32> + Copy,
484    {
485        x.iter().map(|i| (*i).into()).collect()
486    }
487
488    /////////////////////////
489    // Versioned PQ Vector //
490    /////////////////////////
491
492    #[test]
493    fn test_versioned_pq_vector() {
494        let vec = vec![1, 2, 3];
495        let ptr = vec.as_ptr();
496        let pq = VersionedPQVector::<usize>::new(vec, 10);
497        assert_eq!(*pq.version(), 10);
498        assert_eq!(pq.data().len(), 3);
499
500        let data_ptr = pq.data().as_ptr();
501        let pq_ref = pq.as_ref();
502        assert_eq!(pq_ref.version(), pq.version());
503        assert_eq!(data_ptr, ptr);
504        assert_eq!(
505            pq_ref.data().as_ptr(),
506            data_ptr,
507            "expected VersionedPQVectorRef to have the same underlying data as the \
508             original VersionedPQVector"
509        );
510
511        let pq_ref = pq.reborrow();
512        assert_eq!(pq_ref.version(), pq.version());
513        assert_eq!(data_ptr, ptr);
514        assert_eq!(
515            pq_ref.data().as_ptr(),
516            data_ptr,
517            "expected VersionedPQVectorRef to have the same underlying data as the \
518             original VersionedPQVector"
519        );
520    }
521
522    ////////////////
523    // MultiTable //
524    ////////////////
525
526    #[test]
527    fn test_table_error() {
528        let config = test_utils::TableConfig {
529            dim: 17,
530            pq_chunks: 4,
531            num_pivots: 20,
532            start_value: 10.0,
533        };
534
535        let new = test_utils::seed_pivot_table(config);
536        let old = test_utils::seed_pivot_table(config);
537
538        let result = MultiTable::two(&new, &old, 0, 0);
539        assert!(
540            matches!(result, Err(EqualVersionsError)),
541            "MultiTable should now allow construction of the Two variant with equal versions"
542        );
543    }
544
545    ///////////////////////////////////
546    // Distance Computer - One Table //
547    ///////////////////////////////////
548
    /// Test that the computer works correctly when there is one inner PQ table.
    ///
    /// For each of `num_trials` random code pairs, checks full-precision/quantized and
    /// quantized/quantized results for exact equality with `reference` evaluated on the
    /// vectors from `test_utils::generate_expected_vector`. Also checks that any version
    /// mismatch yields `None`.
    fn test_distance_computer_multi_with_one<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        // Check that there is just one version.
        let (&version, should_be_none) = computer.versions();
        assert!(
            should_be_none.is_none(),
            "expected just one schema in test computer"
        );
        // Any version other than the single registered one must be rejected.
        let invalid_version = version.wrapping_add(1);

        for _ in 0..num_trials {
            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected0 = test_utils::generate_expected_vector(
                &code0,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected1 = test_utils::generate_expected_vector(
                &code1,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let expected = reference.evaluate_similarity(&expected0, &expected1);

            // Test full-precision/quant.
            let got = computer
                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            let got = computer
                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Test quant/quant.
            let got = computer
                .evaluate_similarity(
                    &VersionedPQVector::new(code0.clone(), version),
                    &VersionedPQVector::new(code1.clone(), version),
                )
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Check that version mismatches return `None`.
            let got = computer.evaluate_similarity(
                &*expected0,
                &VersionedPQVector::new(code0.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), invalid_version),
                &VersionedPQVector::new(code1.clone(), version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), version),
                &VersionedPQVector::new(code1.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");
        }
    }
625
    // Build a single-table `MultiDistanceComputer` for each metric and run the one-table
    // checks against `f32::distance(metric, None)` as the reference.
    #[rstest]
    fn test_multi_distance_computer_one(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let table = test_utils::seed_pivot_table(config);

        // NOTE: this literal does not fit in 32 bits, so the test assumes a 64-bit `usize`.
        let version: usize = 0x625b215f82f38008;

        let multi_table = MultiTable::one(&table, version);
        let (n, o) = multi_table.versions();
        assert_eq!(*n, version);
        assert!(o.is_none());

        let computer = MultiDistanceComputer::new(multi_table, metric);

        test_distance_computer_multi_with_one(
            &computer,
            &table,
            &config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }
659
660    ////////////////////////////////////
661    // Distance Computer - Two Tables //
662    ////////////////////////////////////
663
    /// Test that the computer works correctly when there are two inner PQ tables.
    ///
    /// For each of `num_trials` random old/new code pairs, checks quantized/quantized and
    /// full-precision/quantized results for every schema pairing against `reference`. The
    /// reference is evaluated on the vectors from `test_utils::generate_expected_vector`.
    /// Then checks that an unknown version yields `None`.
    #[allow(clippy::too_many_arguments)]
    fn test_distance_computer_multi_with_two<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        new: &FixedChunkPQTable,
        old: &FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        // Check that there are indeed two versions registered.
        let (&new_version, old_version) = computer.versions();
        let &old_version = old_version.expect("expected two schemas in test computer");

        for _ in 0..num_trials {
            // Generate a code for the old schema
            let old_code =
                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
            let old_expected = test_utils::generate_expected_vector(
                &old_code,
                old.get_chunk_offsets(),
                old_config.start_value,
            );

            // Generate a code for the new schema
            let new_code =
                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
            let new_expected = test_utils::generate_expected_vector(
                &new_code,
                new.get_chunk_offsets(),
                new_config.start_value,
            );

            // Generate reference results.
            let oo = reference.evaluate_similarity(&old_expected, &old_expected);
            let nn = reference.evaluate_similarity(&new_expected, &new_expected);
            let on = reference.evaluate_similarity(&old_expected, &new_expected);

            // Quant + Quant
            {
                let got_oo_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                // Argument order is swapped; the same `on` reference is expected.
                let got_no_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            // Full Precision + Quant
            // (The `_qq` suffixes below mirror the block above; these are full-precision
            // first arguments against quantized second arguments.)
            {
                let got_oo_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            // Ensure that version mismatches return `None` for a sample of argument
            // positions. Step past `new_version` if needed so that `bad_version` matches
            // neither registered table.
            let mut bad_version = old_version.wrapping_add(1);
            if bad_version == new_version {
                bad_version = bad_version.wrapping_add(1);
            }

            // mismatch for first argument.
            let got = computer.evaluate_similarity(
                VersionedPQVectorRef::new(&old_code, bad_version),
                VersionedPQVectorRef::new(&new_code, new_version),
            );
            assert!(got.is_none());

            // mismatch for second argument.
            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(new_code.clone(), new_version),
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());

            // mismatch for full precision.
            let got = computer.evaluate_similarity(
                &*new_expected,
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());
        }
    }
788
    // Build a two-table `MultiDistanceComputer` for each metric and run the two-table
    // checks. The old and new tables share `dim` but differ in `pq_chunks`, `num_pivots`
    // and `start_value`, so the mixed-schema paths operate on different tables.
    #[rstest]
    fn test_multi_distance_computer_two(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
        };

        let new = test_utils::seed_pivot_table(new_config);
        let old = test_utils::seed_pivot_table(old_config);

        // NOTE: these literals do not fit in 32 bits, so the test assumes a 64-bit `usize`.
        let new_version: usize = 0x5a2b92a731766613;
        let old_version: usize = 0x2fab58c9c8b73841;

        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
        let (n, o) = multi_table.versions();
        assert_eq!(*n, new_version);
        assert_eq!(*o.unwrap(), old_version);

        let computer = MultiDistanceComputer::new(multi_table.clone(), metric);
        test_distance_computer_multi_with_two(
            &computer,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }
832
833    ////////////////////////////////
834    // Query Computer - One Table //
835    ////////////////////////////////
836
    /// Draw one random code for `table` and evaluate it through `computer` under
    /// `version`. Compare the result with `reference` applied to `query` and the code's
    /// expected vector, within the tolerances in `errors`.
    ///
    /// Panics (via `unwrap`) if `computer` returns `None`, i.e. on a version mismatch.
    #[allow(clippy::too_many_arguments)]
    fn check_query_computer<R: Rng>(
        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        query: &[f32],
        version: usize,
        rng: &mut R,
        reference: &<f32 as VectorRepr>::Distance,
        errors: test_utils::RelativeAndAbsolute,
    ) {
        // Generate a random code for `table`.
        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
        let expected_vector = test_utils::generate_expected_vector(
            &code,
            table.get_chunk_offsets(),
            config.start_value,
        );
        let got = computer
            .evaluate_similarity(&VersionedPQVector {
                data: code,
                version,
            })
            .unwrap();
        let expected = reference.evaluate_similarity(query, &expected_vector);
        assert_relative_eq!(
            got,
            expected,
            epsilon = errors.absolute,
            max_relative = errors.relative
        );
    }
869
    /// Drive single-table `MultiQueryComputer`s built by `create`.
    ///
    /// For each of `num_trials` outer iterations:
    /// 1. Draw a random query and a random version, and build a computer via `create`.
    /// 2. Check that the computer reports exactly that one version.
    /// 3. Run `num_trials` accuracy checks through `check_query_computer`, so
    ///    `num_trials * num_trials` checks run in total.
    /// 4. Confirm that a mismatched version yields `None`.
    fn test_query_computer_multi_with_one<'a, T, R>(
        mut create: impl FnMut(usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        table: &'a FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        T: Into<f32> + TestDistribution,
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input: Vec<T> = T::generate(config.dim, rng);
            let input_f32 = to_f32(&input);

            // Sample a `u64` and convert it, so the version spans the full width of `usize`.
            let version: u64 = standard.sample(rng);
            let version: usize = version.into_usize();
            let invalid_version = version.wrapping_add(1);

            let computer = create(version, &input);

            assert_eq!(
                computer.versions(),
                (&version, None),
                "expected the computer to only have one version"
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer, table, config, &input_f32, version, rng, reference, errors,
                );
            }

            // Check the error path on mismatched versions.
            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let got =
                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
            assert!(got.is_none(), "Expected `None` for unmatched versions");
        }
    }
912
913    #[rstest]
914    fn test_query_computer_one<T>(
915        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
916        _datatype: PhantomData<T>,
917        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
918    ) where
919        T: Into<f32> + TestDistribution,
920    {
921        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);
922
923        let config = test_utils::TableConfig {
924            dim: 17,
925            pq_chunks: 4,
926            num_pivots: 20,
927            start_value: 10.0,
928        };
929
930        let table = test_utils::seed_pivot_table(config);
931        let num_trials = 20;
932
933        let errors = test_utils::RelativeAndAbsolute {
934            relative: 5.0e-5,
935            absolute: 0.0,
936        };
937
938        let create = |version: usize, query: &[T]| {
939            let schema = MultiTable::one(&table, version);
940            MultiQueryComputer::new(schema, metric, query).unwrap()
941        };
942        test_query_computer_multi_with_one(
943            create,
944            &table,
945            &config,
946            &f32::distance(metric, None),
947            num_trials,
948            &mut rng,
949            errors,
950        );
951    }
952
953    /////////////////////////////////
954    // Query Computer - Two Tables //
955    /////////////////////////////////
956
957    #[allow(clippy::too_many_arguments)]
958    fn test_query_computer_multi_with_two<'a, T, R>(
959        create: impl Fn(usize, usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
960        new: &'a FixedChunkPQTable,
961        old: &'a FixedChunkPQTable,
962        new_config: &test_utils::TableConfig,
963        old_config: &test_utils::TableConfig,
964        reference: &<f32 as VectorRepr>::Distance,
965        num_trials: usize,
966        rng: &mut R,
967        errors: test_utils::RelativeAndAbsolute,
968    ) where
969        T: Into<f32> + TestDistribution,
970        R: Rng,
971    {
972        let standard = rand::distr::StandardUniform {};
973        for _ in 0..num_trials {
974            let input: Vec<T> = T::generate(old_config.dim, rng);
975            let input_f32: Vec<f32> = to_f32(&input);
976
977            // Create a computer with two random versions.
978            let old_version: u64 = standard.sample(rng);
979            let mut new_version: u64 = standard.sample(rng);
980            while new_version == old_version {
981                new_version = standard.sample(rng);
982            }
983
984            let mut invalid_version: u64 = standard.sample(rng);
985            while invalid_version == old_version || invalid_version == new_version {
986                invalid_version = standard.sample(rng);
987            }
988
989            let old_version = old_version.into_usize();
990            let new_version = new_version.into_usize();
991            let invalid_version = invalid_version.into_usize();
992
993            let computer = create(new_version, old_version, &input);
994
995            assert_eq!(
996                computer.versions(),
997                (&new_version, Some(&old_version)),
998                "versions were not propagated successfully",
999            );
1000
1001            for _ in 0..num_trials {
1002                check_query_computer(
1003                    &computer,
1004                    old,
1005                    old_config,
1006                    &input_f32,
1007                    old_version,
1008                    rng,
1009                    reference,
1010                    errors,
1011                );
1012
1013                check_query_computer(
1014                    &computer,
1015                    new,
1016                    new_config,
1017                    &input_f32,
1018                    new_version,
1019                    rng,
1020                    reference,
1021                    errors,
1022                );
1023
1024                let code = test_utils::generate_random_code(
1025                    old_config.num_pivots,
1026                    old_config.pq_chunks,
1027                    rng,
1028                );
1029                let got = computer.evaluate_similarity(&VersionedPQVector {
1030                    data: code,
1031                    version: invalid_version,
1032                });
1033                assert!(
1034                    got.is_none(),
1035                    "expected a distance computation with an invalid version to return None"
1036                );
1037            }
1038        }
1039    }
1040
1041    #[rstest]
1042    fn test_query_computer_two<T>(
1043        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
1044        _datatype: PhantomData<T>,
1045        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
1046    ) where
1047        T: Into<f32> + TestDistribution,
1048    {
1049        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);
1050
1051        let old_config = test_utils::TableConfig {
1052            dim: 17,
1053            pq_chunks: 4,
1054            num_pivots: 20,
1055            start_value: 10.0,
1056        };
1057
1058        let new_config = test_utils::TableConfig {
1059            dim: 17,
1060            pq_chunks: 5,
1061            num_pivots: 16,
1062            start_value: 1.0,
1063        };
1064
1065        let old = test_utils::seed_pivot_table(old_config);
1066        let new = test_utils::seed_pivot_table(new_config);
1067        let num_trials = 20;
1068
1069        let create = |new_version: usize, old_version: usize, query: &[T]| {
1070            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
1071            MultiQueryComputer::new(schema, metric, query).unwrap()
1072        };
1073
1074        let errors = test_utils::RelativeAndAbsolute {
1075            relative: 5.0e-5,
1076            absolute: 0.0,
1077        };
1078
1079        test_query_computer_multi_with_two(
1080            create,
1081            &new,
1082            &old,
1083            &new_config,
1084            &old_config,
1085            &f32::distance(metric, None),
1086            num_trials,
1087            &mut rng,
1088            errors,
1089        );
1090    }
1091}