1use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{QueryComputer, dynamic::VTable};
14use crate::model::FixedChunkPQTable;
15
/// Tag type identifying which PQ pivot table a compressed vector was encoded with.
///
/// Versions only need to be cheap to copy and comparable for equality, so every
/// `Copy + Eq` type (integers, small enums, tuples of those, ...) implements this
/// trait automatically through the blanket implementation below.
pub trait PQVersion: Copy + Eq {}

impl<V: Copy + Eq> PQVersion for V {}
18
19#[derive(Debug, Clone, PartialEq)]
21pub struct VersionedPQVector<I: PQVersion> {
22 data: Vec<u8>,
23 version: I,
24}
25
26impl<I> VersionedPQVector<I>
27where
28 I: PQVersion,
29{
30 pub fn new(data: Vec<u8>, version: I) -> Self {
32 Self { data, version }
33 }
34
35 pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
37 VersionedPQVectorRef::new(&self.data, self.version)
38 }
39
40 pub fn version(&self) -> &I {
42 &self.version
43 }
44
45 pub fn data(&self) -> &[u8] {
47 &self.data
48 }
49
50 pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
52 (&mut self.data, &mut self.version)
53 }
54}
55
56impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
57where
58 I: PQVersion,
59{
60 type Target = VersionedPQVectorRef<'a, I>;
61 fn reborrow(&'a self) -> Self::Target {
62 self.as_ref()
63 }
64}
65
66#[derive(Debug, Clone, Copy)]
68pub struct VersionedPQVectorRef<'a, I: PQVersion> {
69 data: &'a [u8],
70 version: I,
71}
72
73impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
74 pub fn new(data: &'a [u8], version: I) -> Self {
76 Self { data, version }
77 }
78
79 pub fn version(&self) -> &I {
81 &self.version
82 }
83
84 pub fn data(&self) -> &[u8] {
86 self.data
87 }
88}
89
90#[derive(Debug, Clone)]
93pub enum MultiTable<T, I>
94where
95 T: Deref<Target = FixedChunkPQTable>,
96 I: PQVersion,
97{
98 One { table: T, version: I },
100 Two {
104 new: T,
105 old: T,
106 new_version: I,
107 old_version: I,
108 },
109}
110
111#[derive(Debug, Error)]
112#[error("provided versions must not be equal")]
113pub struct EqualVersionsError;
114
115impl<T, I> MultiTable<T, I>
116where
117 T: Deref<Target = FixedChunkPQTable>,
118 I: PQVersion,
119{
120 pub fn one(table: T, version: I) -> Self {
122 Self::One { table, version }
123 }
124
125 pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
129 if new_version == old_version {
130 Err(EqualVersionsError)
131 } else {
132 Ok(Self::Two {
133 new,
134 old,
135 new_version,
136 old_version,
137 })
138 }
139 }
140
141 pub fn versions(&self) -> (&I, Option<&I>) {
151 match &self {
152 Self::One { version, .. } => (version, None),
153 Self::Two {
154 new_version,
155 old_version,
156 ..
157 } => (new_version, Some(old_version)),
158 }
159 }
160}
161
162#[derive(Debug, Clone)]
176pub struct MultiDistanceComputer<T, I>
177where
178 T: Deref<Target = FixedChunkPQTable>,
179 I: PQVersion,
180{
181 table: MultiTable<T, I>,
182 vtable: VTable,
183}
184
185impl<T, I> MultiDistanceComputer<T, I>
186where
187 T: Deref<Target = FixedChunkPQTable>,
188 I: PQVersion,
189{
190 pub fn new(table: MultiTable<T, I>, metric: Metric) -> Self {
193 Self {
194 table,
195 vtable: VTable::new(metric),
196 }
197 }
198
199 pub fn versions(&self) -> (&I, Option<&I>) {
209 self.table.versions()
210 }
211}
212
213impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
214 for MultiDistanceComputer<T, I>
215where
216 T: Deref<Target = FixedChunkPQTable>,
217 I: PQVersion,
218{
219 #[inline(always)]
220 fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
221 self.evaluate_similarity(x, y.reborrow())
222 }
223}
224
225impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
226 for MultiDistanceComputer<T, I>
227where
228 T: Deref<Target = FixedChunkPQTable>,
229 I: PQVersion,
230{
231 fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
232 match &self.table {
233 MultiTable::One { table, version } => {
234 if version != &y.version {
235 None
236 } else {
237 Some((self.vtable.distance_fn)(table, x, y.data))
238 }
239 }
240 MultiTable::Two {
241 old,
242 new,
243 old_version,
244 new_version,
245 } => {
246 if old_version == &y.version {
247 Some((self.vtable.distance_fn)(old, x, y.data))
248 } else if new_version == &y.version {
249 Some((self.vtable.distance_fn)(new, x, y.data))
250 } else {
251 None
252 }
253 }
254 }
255 }
256}
257
258impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
259 for MultiDistanceComputer<T, I>
260where
261 T: Deref<Target = FixedChunkPQTable>,
262 I: PQVersion,
263{
264 #[inline(always)]
265 fn evaluate_similarity(
266 &self,
267 x: &VersionedPQVector<I>,
268 y: &VersionedPQVector<I>,
269 ) -> Option<f32> {
270 self.evaluate_similarity(x.reborrow(), y.reborrow())
271 }
272}
273
274impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
282 for MultiDistanceComputer<T, I>
283where
284 T: Deref<Target = FixedChunkPQTable>,
285 I: PQVersion,
286{
287 fn evaluate_similarity(
288 &self,
289 x: VersionedPQVectorRef<'_, I>,
290 y: VersionedPQVectorRef<'_, I>,
291 ) -> Option<f32> {
292 match &self.table {
293 MultiTable::One { table, version } => {
294 if (&x.version != version) || (&y.version != version) {
295 None
296 } else {
297 Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
298 }
299 }
300 MultiTable::Two {
301 new,
302 old,
303 new_version,
304 old_version,
305 } => {
306 let x_new = &x.version == new_version;
307 let x_old = &x.version == old_version;
308
309 let y_new = &y.version == new_version;
310 let y_old = &y.version == old_version;
311
312 if x_old {
313 if y_old {
314 Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
316 } else if y_new {
317 let x_full = old.inflate_vector(x.data);
318 Some((self.vtable.distance_fn)(new, &x_full, y.data))
320 } else {
321 None
322 }
323 } else if x_new {
324 if y_old {
325 let y_full = old.inflate_vector(y.data);
326 Some((self.vtable.distance_fn)(new, &y_full, x.data))
328 } else if y_new {
329 Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
331 } else {
332 None
333 }
334 } else {
335 None
336 }
337 }
338 }
339 }
340}
341
342#[derive(Debug)]
350pub enum MultiQueryComputer<T, I>
351where
352 T: Deref<Target = FixedChunkPQTable>,
353 I: PQVersion,
354{
355 One {
356 computer: QueryComputer<T>,
357 version: I,
358 },
359 Two {
360 new: QueryComputer<T>,
361 old: QueryComputer<T>,
362 new_version: I,
363 old_version: I,
364 },
365}
366
367impl<T, I> MultiQueryComputer<T, I>
368where
369 T: Deref<Target = FixedChunkPQTable>,
370 I: PQVersion,
371{
372 pub fn new(table: MultiTable<T, I>, metric: Metric, query: &[f32]) -> ANNResult<Self> {
374 let s = match table {
375 MultiTable::One { table, version } => Self::One {
376 computer: { QueryComputer::new(table, metric, query, None)? },
377 version,
378 },
379 MultiTable::Two {
380 new,
381 old,
382 new_version,
383 old_version,
384 } => Self::Two {
385 new: { QueryComputer::new(new, metric, query, None)? },
386 old: { QueryComputer::new(old, metric, query, None)? },
387 new_version,
388 old_version,
389 },
390 };
391 Ok(s)
392 }
393
394 pub fn versions(&self) -> (&I, Option<&I>) {
398 match &self {
399 Self::One { version, .. } => (version, None),
400 Self::Two {
401 new_version,
402 old_version,
403 ..
404 } => (new_version, Some(old_version)),
405 }
406 }
407}
408
409impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
410 for MultiQueryComputer<T, I>
411where
412 T: Deref<Target = FixedChunkPQTable>,
413 I: PQVersion,
414{
415 #[inline(always)]
416 fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
417 self.evaluate_similarity(x.reborrow())
418 }
419}
420
421impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
422 for MultiQueryComputer<T, I>
423where
424 T: Deref<Target = FixedChunkPQTable>,
425 I: PQVersion,
426{
427 fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
428 match &self {
429 Self::One { computer, version } => {
430 if version != &x.version {
431 None
432 } else {
433 Some(computer.evaluate_similarity(x.data))
434 }
435 }
436 Self::Two {
437 new,
438 old,
439 new_version,
440 old_version,
441 } => {
442 if old_version == &x.version {
443 Some(old.evaluate_similarity(x.data))
444 } else if new_version == &x.version {
445 Some(new.evaluate_similarity(x.data))
446 } else {
447 None
448 }
449 }
450 }
451 }
452}
453
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use diskann::utils::{IntoUsize, VectorRepr};
    use diskann_vector::PreprocessedDistanceFunction;
    use rand::{Rng, SeedableRng, distr::Distribution};
    use rstest::rstest;

    use super::{
        super::test_utils::{self, TestDistribution},
        *,
    };

    /// `VersionedPQVector` accessors report what was stored. Both `as_ref` and
    /// `reborrow` yield views over the same allocation.
    #[test]
    fn test_versioned_pq_vector() {
        let vec = vec![1, 2, 3];
        let ptr = vec.as_ptr();
        let pq = VersionedPQVector::<usize>::new(vec, 10);
        assert_eq!(*pq.version(), 10);
        assert_eq!(pq.data().len(), 3);

        let data_ptr = pq.data().as_ptr();
        let pq_ref = pq.as_ref();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );

        let pq_ref = pq.reborrow();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );
    }

    /// `MultiTable::two` must reject identical new/old versions.
    #[test]
    fn test_table_error() {
        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new = test_utils::seed_pivot_table(config);
        let old = test_utils::seed_pivot_table(config);

        let result = MultiTable::two(&new, &old, 0, 0);
        assert!(
            matches!(result, Err(EqualVersionsError)),
            "MultiTable should not allow construction of the Two variant with equal versions"
        );
    }

    /// Exercises a single-table `MultiDistanceComputer`.
    ///
    /// Random codes are decoded through the seeded table's known layout, and
    /// every overload is compared with `reference` evaluated on the decoded
    /// vectors. Any operand carrying an unknown version must yield `None`.
    fn test_distance_computer_multi_with_one<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&version, should_be_none) = computer.versions();
        assert!(
            should_be_none.is_none(),
            "expected just one schema in test computer"
        );
        let invalid_version = version.wrapping_add(1);

        for _ in 0..num_trials {
            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected0 = test_utils::generate_expected_vector(
                &code0,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected1 = test_utils::generate_expected_vector(
                &code1,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let expected = reference.evaluate_similarity(&*expected0, &*expected1);

            // Query vs. code, in both directions.
            let got = computer
                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            let got = computer
                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Code vs. code.
            let got = computer
                .evaluate_similarity(
                    &VersionedPQVector::new(code0.clone(), version),
                    &VersionedPQVector::new(code1.clone(), version),
                )
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Version mismatches in every position.
            let got = computer.evaluate_similarity(
                &*expected0,
                &VersionedPQVector::new(code0.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), invalid_version),
                &VersionedPQVector::new(code1.clone(), version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), version),
                &VersionedPQVector::new(code1.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");
        }
    }

    #[rstest]
    fn test_multi_distance_computer_one(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let table = test_utils::seed_pivot_table(config);

        let version: usize = 0x625b215f82f38008;

        let multi_table = MultiTable::one(&table, version);
        let (n, o) = multi_table.versions();
        assert_eq!(*n, version);
        assert!(o.is_none());

        let computer = MultiDistanceComputer::new(multi_table, metric);

        test_distance_computer_multi_with_one(
            &computer,
            &table,
            &config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// Exercises a two-table `MultiDistanceComputer`.
    ///
    /// Covers every old/new pairing for code-vs-code and query-vs-code
    /// evaluation against `reference` on the decoded vectors. Codes with a
    /// version held by neither table must yield `None`.
    #[allow(clippy::too_many_arguments)]
    fn test_distance_computer_multi_with_two<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        new: &FixedChunkPQTable,
        old: &FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&new_version, old_version) = computer.versions();
        let &old_version = old_version.expect("expected two schemas in test computer");

        for _ in 0..num_trials {
            let old_code =
                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
            let old_expected = test_utils::generate_expected_vector(
                &old_code,
                old.get_chunk_offsets(),
                old_config.start_value,
            );

            let new_code =
                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
            let new_expected = test_utils::generate_expected_vector(
                &new_code,
                new.get_chunk_offsets(),
                new_config.start_value,
            );

            // Reference distances between the decoded vectors.
            let oo = reference.evaluate_similarity(&*old_expected, &*old_expected);
            let nn = reference.evaluate_similarity(&*new_expected, &*new_expected);
            let on = reference.evaluate_similarity(&*old_expected, &*new_expected);

            {
                // Code vs. code for every pairing.
                let got_oo_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            {
                // Full-precision query vs. code for every pairing.
                let got_oo_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            // Pick a version held by neither table.
            let mut bad_version = old_version.wrapping_add(1);
            if bad_version == new_version {
                bad_version = bad_version.wrapping_add(1);
            }

            let got = computer.evaluate_similarity(
                VersionedPQVectorRef::new(&old_code, bad_version),
                VersionedPQVectorRef::new(&new_code, new_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(new_code.clone(), new_version),
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &*new_expected,
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());
        }
    }

    #[rstest]
    fn test_multi_distance_computer_two(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
        };

        let new = test_utils::seed_pivot_table(new_config);
        let old = test_utils::seed_pivot_table(old_config);

        let new_version: usize = 0x5a2b92a731766613;
        let old_version: usize = 0x2fab58c9c8b73841;

        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
        let (n, o) = multi_table.versions();
        assert_eq!(*n, new_version);
        assert_eq!(*o.unwrap(), old_version);

        let computer = MultiDistanceComputer::new(multi_table.clone(), metric);
        test_distance_computer_multi_with_two(
            &computer,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// Checks one random code tagged with `version` against `reference`.
    ///
    /// The code is evaluated by `computer` and compared, within `errors`
    /// tolerances, to `reference` applied to `query` and the code's decoded
    /// vector.
    #[allow(clippy::too_many_arguments)]
    fn check_query_computer<R: Rng>(
        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        query: &[f32],
        version: usize,
        rng: &mut R,
        reference: &<f32 as VectorRepr>::Distance,
        errors: test_utils::RelativeAndAbsolute,
    ) {
        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
        let expected_vector = test_utils::generate_expected_vector(
            &code,
            table.get_chunk_offsets(),
            config.start_value,
        );
        let got = computer
            .evaluate_similarity(&VersionedPQVector {
                data: code,
                version,
            })
            .unwrap();
        let expected = reference.evaluate_similarity(query, &expected_vector);
        assert_relative_eq!(
            got,
            expected,
            epsilon = errors.absolute,
            max_relative = errors.relative
        );
    }

    /// Builds single-table query computers via `create` for random queries and
    /// versions.
    ///
    /// Checks version reporting, accuracy against `reference`, and rejection
    /// of a mismatched version.
    fn test_query_computer_multi_with_one<'a, R>(
        mut create: impl FnMut(usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        table: &'a FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input_f32: Vec<f32> = f32::generate(config.dim, rng);

            let version: u64 = standard.sample(rng);
            let version: usize = version.into_usize();
            let invalid_version = version.wrapping_add(1);

            let computer = create(version, &input_f32);

            assert_eq!(
                computer.versions(),
                (&version, None),
                "expected the computer to only have one version"
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer, table, config, &input_f32, version, rng, reference, errors,
                );
            }

            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let got =
                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
            assert!(got.is_none(), "Expected `None` for unmatched versions");
        }
    }

    #[rstest]
    fn test_query_computer_one(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let table = test_utils::seed_pivot_table(config);
        let num_trials = 20;

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        let create = |version: usize, query: &[f32]| {
            let schema = MultiTable::one(&table, version);
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };
        test_query_computer_multi_with_one(
            create,
            &table,
            &config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }

    /// Builds two-table query computers via `create` with random distinct
    /// versions.
    ///
    /// Checks version reporting, accuracy for codes from both tables, and
    /// rejection of a third, unknown version.
    #[allow(clippy::too_many_arguments)]
    fn test_query_computer_multi_with_two<'a, R>(
        create: impl Fn(usize, usize, &[f32]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        new: &'a FixedChunkPQTable,
        old: &'a FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input_f32: Vec<f32> = f32::generate(old_config.dim, rng);

            // Draw three pairwise-distinct versions.
            let old_version: u64 = standard.sample(rng);
            let mut new_version: u64 = standard.sample(rng);
            while new_version == old_version {
                new_version = standard.sample(rng);
            }

            let mut invalid_version: u64 = standard.sample(rng);
            while invalid_version == old_version || invalid_version == new_version {
                invalid_version = standard.sample(rng);
            }

            let old_version = old_version.into_usize();
            let new_version = new_version.into_usize();
            let invalid_version = invalid_version.into_usize();

            let computer = create(new_version, old_version, &input_f32);

            assert_eq!(
                computer.versions(),
                (&new_version, Some(&old_version)),
                "versions were not propagated successfully",
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer,
                    old,
                    old_config,
                    &input_f32,
                    old_version,
                    rng,
                    reference,
                    errors,
                );

                check_query_computer(
                    &computer,
                    new,
                    new_config,
                    &input_f32,
                    new_version,
                    rng,
                    reference,
                    errors,
                );

                let code = test_utils::generate_random_code(
                    old_config.num_pivots,
                    old_config.pq_chunks,
                    rng,
                );
                let got = computer.evaluate_similarity(&VersionedPQVector {
                    data: code,
                    version: invalid_version,
                });
                assert!(
                    got.is_none(),
                    "expected a distance computation with an invalid version to return None"
                );
            }
        }
    }

    #[rstest]
    fn test_query_computer_two(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
        };

        let old = test_utils::seed_pivot_table(old_config);
        let new = test_utils::seed_pivot_table(new_config);
        let num_trials = 20;

        let create = |new_version: usize, old_version: usize, query: &[f32]| {
            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        test_query_computer_multi_with_two(
            create,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }
}