1use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{QueryComputer, dynamic::VTable};
14use crate::model::FixedChunkPQTable;
15
/// Marker trait for version tags attached to PQ codes.
///
/// A version tag only needs to be cheap to copy and comparable for exact equality, so every
/// `Copy + Eq` type is a `PQVersion` through the blanket implementation below.
pub trait PQVersion: Copy + Eq {}

impl<V: Copy + Eq> PQVersion for V {}
18
/// An owned PQ code tagged with a version identifying the PQ schema it belongs to.
///
/// [`MultiDistanceComputer`] and [`MultiQueryComputer`] use the version to select the
/// matching table. A code whose version matches no table makes those computers return `None`.
#[derive(Debug, Clone, PartialEq)]
pub struct VersionedPQVector<I: PQVersion> {
    /// Compressed PQ code bytes.
    data: Vec<u8>,
    /// Version tag of the PQ schema this code belongs to.
    version: I,
}
25
26impl<I> VersionedPQVector<I>
27where
28 I: PQVersion,
29{
30 pub fn new(data: Vec<u8>, version: I) -> Self {
32 Self { data, version }
33 }
34
35 pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
37 VersionedPQVectorRef::new(&self.data, self.version)
38 }
39
40 pub fn version(&self) -> &I {
42 &self.version
43 }
44
45 pub fn data(&self) -> &[u8] {
47 &self.data
48 }
49
50 pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
52 (&mut self.data, &mut self.version)
53 }
54}
55
56impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
57where
58 I: PQVersion,
59{
60 type Target = VersionedPQVectorRef<'a, I>;
61 fn reborrow(&'a self) -> Self::Target {
62 self.as_ref()
63 }
64}
65
/// Borrowed counterpart of [`VersionedPQVector`]: a slice of PQ code bytes plus its version.
///
/// The type is `Copy`, so it can be passed by value to the distance computers in this module.
#[derive(Debug, Clone, Copy)]
pub struct VersionedPQVectorRef<'a, I: PQVersion> {
    /// Borrowed PQ code bytes.
    data: &'a [u8],
    /// Version tag of the PQ schema the code belongs to.
    version: I,
}
72
73impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
74 pub fn new(data: &'a [u8], version: I) -> Self {
76 Self { data, version }
77 }
78
79 pub fn version(&self) -> &I {
81 &self.version
82 }
83
84 pub fn data(&self) -> &[u8] {
86 self.data
87 }
88}
89
/// One or two PQ tables, each tagged with a version.
///
/// `Two` holds a `new` and an `old` table, so codes compressed under either schema can be
/// handled side by side. Prefer [`MultiTable::two`] for building it: that constructor rejects
/// equal versions. Constructing the variant directly skips that check.
#[derive(Debug, Clone)]
pub enum MultiTable<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// A single table, used for codes whose version equals `version`.
    One { table: T, version: I },
    /// Two tables distinguished by their version tags.
    Two {
        /// Table for codes tagged with `new_version`.
        new: T,
        /// Table for codes tagged with `old_version`.
        old: T,
        /// Version tag of the `new` table.
        new_version: I,
        /// Version tag of the `old` table.
        old_version: I,
    },
}
110
/// Error returned by [`MultiTable::two`] when `new_version == old_version`.
///
/// With equal versions, the two tables could not be told apart by a code's version tag.
#[derive(Debug, Error)]
#[error("provided versions must not be equal")]
pub struct EqualVersionsError;
114
115impl<T, I> MultiTable<T, I>
116where
117 T: Deref<Target = FixedChunkPQTable>,
118 I: PQVersion,
119{
120 pub fn one(table: T, version: I) -> Self {
122 Self::One { table, version }
123 }
124
125 pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
129 if new_version == old_version {
130 Err(EqualVersionsError)
131 } else {
132 Ok(Self::Two {
133 new,
134 old,
135 new_version,
136 old_version,
137 })
138 }
139 }
140
141 pub fn versions(&self) -> (&I, Option<&I>) {
151 match &self {
152 Self::One { version, .. } => (version, None),
153 Self::Two {
154 new_version,
155 old_version,
156 ..
157 } => (new_version, Some(old_version)),
158 }
159 }
160}
161
/// Computes distances involving versioned PQ codes, using the table from a [`MultiTable`]
/// whose version matches each code.
///
/// The `DistanceFunction` implementations below return `None` when a code's version matches
/// none of the held tables.
#[derive(Debug, Clone)]
pub struct MultiDistanceComputer<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// The versioned table(s) used to score codes.
    table: MultiTable<T, I>,
    /// Metric-specific distance kernels created by `VTable::new(metric)`.
    vtable: VTable,
}
184
185impl<T, I> MultiDistanceComputer<T, I>
186where
187 T: Deref<Target = FixedChunkPQTable>,
188 I: PQVersion,
189{
190 pub fn new(table: MultiTable<T, I>, metric: Metric) -> Self {
193 Self {
194 table,
195 vtable: VTable::new(metric),
196 }
197 }
198
199 pub fn versions(&self) -> (&I, Option<&I>) {
209 self.table.versions()
210 }
211}
212
213impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
214 for MultiDistanceComputer<T, I>
215where
216 T: Deref<Target = FixedChunkPQTable>,
217 I: PQVersion,
218{
219 #[inline(always)]
220 fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
221 self.evaluate_similarity(x, y.reborrow())
222 }
223}
224
225impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
226 for MultiDistanceComputer<T, I>
227where
228 T: Deref<Target = FixedChunkPQTable>,
229 I: PQVersion,
230{
231 fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
232 match &self.table {
233 MultiTable::One { table, version } => {
234 if version != &y.version {
235 None
236 } else {
237 Some((self.vtable.distance_fn)(table, x, y.data))
238 }
239 }
240 MultiTable::Two {
241 old,
242 new,
243 old_version,
244 new_version,
245 } => {
246 if old_version == &y.version {
247 Some((self.vtable.distance_fn)(old, x, y.data))
248 } else if new_version == &y.version {
249 Some((self.vtable.distance_fn)(new, x, y.data))
250 } else {
251 None
252 }
253 }
254 }
255 }
256}
257
258impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
259 for MultiDistanceComputer<T, I>
260where
261 T: Deref<Target = FixedChunkPQTable>,
262 I: PQVersion,
263{
264 #[inline(always)]
265 fn evaluate_similarity(
266 &self,
267 x: &VersionedPQVector<I>,
268 y: &VersionedPQVector<I>,
269 ) -> Option<f32> {
270 self.evaluate_similarity(x.reborrow(), y.reborrow())
271 }
272}
273
274impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
282 for MultiDistanceComputer<T, I>
283where
284 T: Deref<Target = FixedChunkPQTable>,
285 I: PQVersion,
286{
287 fn evaluate_similarity(
288 &self,
289 x: VersionedPQVectorRef<'_, I>,
290 y: VersionedPQVectorRef<'_, I>,
291 ) -> Option<f32> {
292 match &self.table {
293 MultiTable::One { table, version } => {
294 if (&x.version != version) || (&y.version != version) {
295 None
296 } else {
297 Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
298 }
299 }
300 MultiTable::Two {
301 new,
302 old,
303 new_version,
304 old_version,
305 } => {
306 let x_new = &x.version == new_version;
307 let x_old = &x.version == old_version;
308
309 let y_new = &y.version == new_version;
310 let y_old = &y.version == old_version;
311
312 if x_old {
313 if y_old {
314 Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
316 } else if y_new {
317 let x_full = old.inflate_vector(x.data);
318 Some((self.vtable.distance_fn)(new, &x_full, y.data))
320 } else {
321 None
322 }
323 } else if x_new {
324 if y_old {
325 let y_full = old.inflate_vector(y.data);
326 Some((self.vtable.distance_fn)(new, &y_full, x.data))
328 } else if y_new {
329 Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
331 } else {
332 None
333 }
334 } else {
335 None
336 }
337 }
338 }
339 }
340}
341
/// A query preprocessed against every table in a [`MultiTable`].
///
/// The version tags are kept so that each code can be scored by the computer built for the
/// table matching its version.
#[derive(Debug)]
pub enum MultiQueryComputer<T, I>
where
    T: Deref<Target = FixedChunkPQTable>,
    I: PQVersion,
{
    /// Query prepared for a single table.
    One {
        /// Computer for the only table.
        computer: QueryComputer<T>,
        /// Version tag of that table.
        version: I,
    },
    /// Query prepared for both a `new` and an `old` table.
    Two {
        /// Computer for codes tagged with `new_version`.
        new: QueryComputer<T>,
        /// Computer for codes tagged with `old_version`.
        old: QueryComputer<T>,
        /// Version tag of the `new` table.
        new_version: I,
        /// Version tag of the `old` table.
        old_version: I,
    },
}
366
367impl<T, I> MultiQueryComputer<T, I>
368where
369 T: Deref<Target = FixedChunkPQTable>,
370 I: PQVersion,
371{
372 pub fn new<U>(table: MultiTable<T, I>, metric: Metric, query: &[U]) -> ANNResult<Self>
374 where
375 U: Into<f32> + Copy,
376 {
377 let s = match table {
378 MultiTable::One { table, version } => Self::One {
379 computer: { QueryComputer::new(table, metric, query, None)? },
380 version,
381 },
382 MultiTable::Two {
383 new,
384 old,
385 new_version,
386 old_version,
387 } => Self::Two {
388 new: { QueryComputer::new(new, metric, query, None)? },
389 old: { QueryComputer::new(old, metric, query, None)? },
390 new_version,
391 old_version,
392 },
393 };
394 Ok(s)
395 }
396
397 pub fn versions(&self) -> (&I, Option<&I>) {
401 match &self {
402 Self::One { version, .. } => (version, None),
403 Self::Two {
404 new_version,
405 old_version,
406 ..
407 } => (new_version, Some(old_version)),
408 }
409 }
410}
411
412impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
413 for MultiQueryComputer<T, I>
414where
415 T: Deref<Target = FixedChunkPQTable>,
416 I: PQVersion,
417{
418 #[inline(always)]
419 fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
420 self.evaluate_similarity(x.reborrow())
421 }
422}
423
424impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
425 for MultiQueryComputer<T, I>
426where
427 T: Deref<Target = FixedChunkPQTable>,
428 I: PQVersion,
429{
430 fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
431 match &self {
432 Self::One { computer, version } => {
433 if version != &x.version {
434 None
435 } else {
436 Some(computer.evaluate_similarity(x.data))
437 }
438 }
439 Self::Two {
440 new,
441 old,
442 new_version,
443 old_version,
444 } => {
445 if old_version == &x.version {
446 Some(old.evaluate_similarity(x.data))
447 } else if new_version == &x.version {
448 Some(new.evaluate_similarity(x.data))
449 } else {
450 None
451 }
452 }
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use approx::assert_relative_eq;
    use diskann::utils::{IntoUsize, VectorRepr};
    use diskann_vector::{Half, PreprocessedDistanceFunction};
    use rand::{Rng, SeedableRng, distr::Distribution};
    use rstest::rstest;

    use super::{
        super::test_utils::{self, TestDistribution},
        *,
    };

    /// Convert a slice of `f32`-convertible values into an owned `Vec<f32>`.
    fn to_f32<T>(x: &[T]) -> Vec<f32>
    where
        T: Into<f32> + Copy,
    {
        x.iter().map(|i| (*i).into()).collect()
    }

    /// The accessors of `VersionedPQVector`, and both borrowing paths (`as_ref` and
    /// `reborrow`), must expose the original version and the original (uncopied) buffer.
    #[test]
    fn test_versioned_pq_vector() {
        let vec = vec![1, 2, 3];
        let ptr = vec.as_ptr();
        let pq = VersionedPQVector::<usize>::new(vec, 10);
        assert_eq!(*pq.version(), 10);
        assert_eq!(pq.data().len(), 3);

        let data_ptr = pq.data().as_ptr();
        let pq_ref = pq.as_ref();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );

        let pq_ref = pq.reborrow();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );
    }

    /// `MultiTable::two` must reject equal versions.
    #[test]
    fn test_table_error() {
        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new = test_utils::seed_pivot_table(config);
        let old = test_utils::seed_pivot_table(config);

        let result = MultiTable::two(&new, &old, 0, 0);
        assert!(
            matches!(result, Err(EqualVersionsError)),
            "MultiTable should not allow construction of the Two variant with equal versions"
        );
    }

    /// Check a single-table `MultiDistanceComputer`.
    ///
    /// Full-vs-quantized and quantized-vs-quantized distances must equal `reference`
    /// evaluated on the decompressed vectors. Any operand with a mismatched version must
    /// yield `None`.
    fn test_distance_computer_multi_with_one<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&version, should_be_none) = computer.versions();
        assert!(
            should_be_none.is_none(),
            "expected just one schema in test computer"
        );
        let invalid_version = version.wrapping_add(1);

        for _ in 0..num_trials {
            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected0 = test_utils::generate_expected_vector(
                &code0,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected1 = test_utils::generate_expected_vector(
                &code1,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let expected = reference.evaluate_similarity(&expected0, &expected1);

            // Full-precision query against a quantized code, in both directions.
            let got = computer
                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            let got = computer
                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Quantized against quantized.
            let got = computer
                .evaluate_similarity(
                    &VersionedPQVector::new(code0.clone(), version),
                    &VersionedPQVector::new(code1.clone(), version),
                )
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Any operand tagged with an unknown version must produce `None`.
            let got = computer.evaluate_similarity(
                &*expected0,
                &VersionedPQVector::new(code0.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), invalid_version),
                &VersionedPQVector::new(code1.clone(), version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), version),
                &VersionedPQVector::new(code1.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");
        }
    }

    #[rstest]
    fn test_multi_distance_computer_one(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let table = test_utils::seed_pivot_table(config);

        let version: usize = 0x625b215f82f38008;

        let multi_table = MultiTable::one(&table, version);
        let (n, o) = multi_table.versions();
        assert_eq!(*n, version);
        assert!(o.is_none());

        let computer = MultiDistanceComputer::new(multi_table, metric);

        test_distance_computer_multi_with_one(
            &computer,
            &table,
            &config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// Check a two-table `MultiDistanceComputer`.
    ///
    /// Every old/new pairing, both quantized-vs-quantized and full-vs-quantized, must match
    /// `reference` on the decompressed vectors. Unknown versions must yield `None`.
    // The helper needs both tables, both configs, the reference, and the RNG, which is why
    // the clippy lint is allowed.
    #[allow(clippy::too_many_arguments)]
    fn test_distance_computer_multi_with_two<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        new: &FixedChunkPQTable,
        old: &FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&new_version, old_version) = computer.versions();
        let &old_version = old_version.expect("expected two schemas in test computer");

        for _ in 0..num_trials {
            let old_code =
                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
            let old_expected = test_utils::generate_expected_vector(
                &old_code,
                old.get_chunk_offsets(),
                old_config.start_value,
            );

            let new_code =
                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
            let new_expected = test_utils::generate_expected_vector(
                &new_code,
                new.get_chunk_offsets(),
                new_config.start_value,
            );

            // Reference distances on the decompressed vectors.
            let oo = reference.evaluate_similarity(&old_expected, &old_expected);
            let nn = reference.evaluate_similarity(&new_expected, &new_expected);
            let on = reference.evaluate_similarity(&old_expected, &new_expected);

            {
                // Quantized against quantized, covering every schema pairing.
                let got_oo_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            {
                // Full-precision query against a quantized code, covering every pairing.
                let got_oo_fq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_fq.unwrap(), oo);

                let got_on_fq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_fq.unwrap(), on);

                let got_no_fq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_fq.unwrap(), on);

                let got_nn_fq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_fq.unwrap(), nn);
            }

            // Pick a version that matches neither table.
            let mut bad_version = old_version.wrapping_add(1);
            if bad_version == new_version {
                bad_version = bad_version.wrapping_add(1);
            }

            let got = computer.evaluate_similarity(
                VersionedPQVectorRef::new(&old_code, bad_version),
                VersionedPQVectorRef::new(&new_code, new_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(new_code.clone(), new_version),
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &*new_expected,
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());
        }
    }

    #[rstest]
    fn test_multi_distance_computer_two(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
        };

        let new = test_utils::seed_pivot_table(new_config);
        let old = test_utils::seed_pivot_table(old_config);

        let new_version: usize = 0x5a2b92a731766613;
        let old_version: usize = 0x2fab58c9c8b73841;

        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
        let (n, o) = multi_table.versions();
        assert_eq!(*n, new_version);
        assert_eq!(*o.unwrap(), old_version);

        let computer = MultiDistanceComputer::new(multi_table.clone(), metric);
        test_distance_computer_multi_with_two(
            &computer,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// Score one random code tagged with `version` through `computer`.
    ///
    /// The result is compared against `reference(query, decompressed code)` within the
    /// `errors` tolerances.
    // The check needs the table, config, query, version, RNG, reference, and tolerances,
    // which is why the clippy lint is allowed.
    #[allow(clippy::too_many_arguments)]
    fn check_query_computer<R: Rng>(
        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        query: &[f32],
        version: usize,
        rng: &mut R,
        reference: &<f32 as VectorRepr>::Distance,
        errors: test_utils::RelativeAndAbsolute,
    ) {
        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
        let expected_vector = test_utils::generate_expected_vector(
            &code,
            table.get_chunk_offsets(),
            config.start_value,
        );
        let got = computer
            .evaluate_similarity(&VersionedPQVector {
                data: code,
                version,
            })
            .unwrap();
        let expected = reference.evaluate_similarity(query, &expected_vector);
        assert_relative_eq!(
            got,
            expected,
            epsilon = errors.absolute,
            max_relative = errors.relative
        );
    }

    /// Check single-table `MultiQueryComputer`s built by `create` for random versions.
    ///
    /// Matching codes must score like the reference. A code with an unmatched version must
    /// yield `None`.
    fn test_query_computer_multi_with_one<'a, T, R>(
        mut create: impl FnMut(usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        table: &'a FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        T: Into<f32> + TestDistribution,
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input: Vec<T> = T::generate(config.dim, rng);
            let input_f32 = to_f32(&input);

            let version: u64 = standard.sample(rng);
            let version: usize = version.into_usize();
            let invalid_version = version.wrapping_add(1);

            let computer = create(version, &input);

            assert_eq!(
                computer.versions(),
                (&version, None),
                "expected the computer to only have one version"
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer, table, config, &input_f32, version, rng, reference, errors,
                );
            }

            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let got =
                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
            assert!(got.is_none(), "Expected `None` for unmatched versions");
        }
    }

    #[rstest]
    fn test_query_computer_one<T>(
        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
        _datatype: PhantomData<T>,
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) where
        T: Into<f32> + TestDistribution,
    {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let table = test_utils::seed_pivot_table(config);
        let num_trials = 20;

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        let create = |version: usize, query: &[T]| {
            let schema = MultiTable::one(&table, version);
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };
        test_query_computer_multi_with_one(
            create,
            &table,
            &config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }

    /// Check two-table `MultiQueryComputer`s built by `create` for random distinct versions.
    ///
    /// Codes of either schema must score like the reference. A code with a third, unknown
    /// version must yield `None`.
    // The helper needs both tables, both configs, the reference, RNG, and tolerances, which
    // is why the clippy lint is allowed.
    #[allow(clippy::too_many_arguments)]
    fn test_query_computer_multi_with_two<'a, T, R>(
        create: impl Fn(usize, usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        new: &'a FixedChunkPQTable,
        old: &'a FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        T: Into<f32> + TestDistribution,
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input: Vec<T> = T::generate(old_config.dim, rng);
            let input_f32: Vec<f32> = to_f32(&input);

            // Draw three pairwise-distinct versions: old, new, and one matching neither.
            let old_version: u64 = standard.sample(rng);
            let mut new_version: u64 = standard.sample(rng);
            while new_version == old_version {
                new_version = standard.sample(rng);
            }

            let mut invalid_version: u64 = standard.sample(rng);
            while invalid_version == old_version || invalid_version == new_version {
                invalid_version = standard.sample(rng);
            }

            let old_version = old_version.into_usize();
            let new_version = new_version.into_usize();
            let invalid_version = invalid_version.into_usize();

            let computer = create(new_version, old_version, &input);

            assert_eq!(
                computer.versions(),
                (&new_version, Some(&old_version)),
                "versions were not propagated successfully",
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer,
                    old,
                    old_config,
                    &input_f32,
                    old_version,
                    rng,
                    reference,
                    errors,
                );

                check_query_computer(
                    &computer,
                    new,
                    new_config,
                    &input_f32,
                    new_version,
                    rng,
                    reference,
                    errors,
                );

                let code = test_utils::generate_random_code(
                    old_config.num_pivots,
                    old_config.pq_chunks,
                    rng,
                );
                let got = computer.evaluate_similarity(&VersionedPQVector {
                    data: code,
                    version: invalid_version,
                });
                assert!(
                    got.is_none(),
                    "expected a distance computation with an invalid version to return None"
                );
            }
        }
    }

    #[rstest]
    fn test_query_computer_two<T>(
        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
        _datatype: PhantomData<T>,
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) where
        T: Into<f32> + TestDistribution,
    {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
        };

        let old = test_utils::seed_pivot_table(old_config);
        let new = test_utils::seed_pivot_table(new_config);
        let num_trials = 20;

        let create = |new_version: usize, old_version: usize, query: &[T]| {
            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        test_query_computer_multi_with_two(
            create,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }
}