1use std::ops::Deref;
7
8use diskann::ANNResult;
9use diskann_utils::Reborrow;
10use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric};
11use thiserror::Error;
12
13use super::{
14 QueryComputer,
15 dynamic::{DistanceComputerConstructionError, VTable},
16};
17use crate::model::FixedChunkPQTable;
18
/// Marker trait for the version tags attached to PQ vectors and PQ tables.
///
/// Versions are only ever copied and compared for equality in this module, so every
/// `Copy + Eq` type qualifies automatically through the blanket implementation below.
pub trait PQVersion: Copy + Eq {}

impl<V> PQVersion for V where V: Copy + Eq {}
21
22#[derive(Debug, Clone, PartialEq)]
24pub struct VersionedPQVector<I: PQVersion> {
25 data: Vec<u8>,
26 version: I,
27}
28
29impl<I> VersionedPQVector<I>
30where
31 I: PQVersion,
32{
33 pub fn new(data: Vec<u8>, version: I) -> Self {
35 Self { data, version }
36 }
37
38 pub fn as_ref(&self) -> VersionedPQVectorRef<'_, I> {
40 VersionedPQVectorRef::new(&self.data, self.version)
41 }
42
43 pub fn version(&self) -> &I {
45 &self.version
46 }
47
48 pub fn data(&self) -> &[u8] {
50 &self.data
51 }
52
53 pub fn raw_mut(&mut self) -> (&mut Vec<u8>, &mut I) {
55 (&mut self.data, &mut self.version)
56 }
57}
58
59impl<'a, I> Reborrow<'a> for VersionedPQVector<I>
60where
61 I: PQVersion,
62{
63 type Target = VersionedPQVectorRef<'a, I>;
64 fn reborrow(&'a self) -> Self::Target {
65 self.as_ref()
66 }
67}
68
69#[derive(Debug, Clone, Copy)]
71pub struct VersionedPQVectorRef<'a, I: PQVersion> {
72 data: &'a [u8],
73 version: I,
74}
75
76impl<'a, I: PQVersion> VersionedPQVectorRef<'a, I> {
77 pub fn new(data: &'a [u8], version: I) -> Self {
79 Self { data, version }
80 }
81
82 pub fn version(&self) -> &I {
84 &self.version
85 }
86
87 pub fn data(&self) -> &[u8] {
89 self.data
90 }
91}
92
93#[derive(Debug, Clone)]
96pub enum MultiTable<T, I>
97where
98 T: Deref<Target = FixedChunkPQTable>,
99 I: PQVersion,
100{
101 One { table: T, version: I },
103 Two {
107 new: T,
108 old: T,
109 new_version: I,
110 old_version: I,
111 },
112}
113
114#[derive(Debug, Error)]
115#[error("provided versions must not be equal")]
116pub struct EqualVersionsError;
117
118impl<T, I> MultiTable<T, I>
119where
120 T: Deref<Target = FixedChunkPQTable>,
121 I: PQVersion,
122{
123 pub fn one(table: T, version: I) -> Self {
125 Self::One { table, version }
126 }
127
128 pub fn two(new: T, old: T, new_version: I, old_version: I) -> Result<Self, EqualVersionsError> {
132 if new_version == old_version {
133 Err(EqualVersionsError)
134 } else {
135 Ok(Self::Two {
136 new,
137 old,
138 new_version,
139 old_version,
140 })
141 }
142 }
143
144 pub fn versions(&self) -> (&I, Option<&I>) {
154 match &self {
155 Self::One { version, .. } => (version, None),
156 Self::Two {
157 new_version,
158 old_version,
159 ..
160 } => (new_version, Some(old_version)),
161 }
162 }
163}
164
165#[derive(Debug, Clone)]
179pub struct MultiDistanceComputer<T, I>
180where
181 T: Deref<Target = FixedChunkPQTable>,
182 I: PQVersion,
183{
184 table: MultiTable<T, I>,
185 vtable: VTable,
186}
187
188impl<T, I> MultiDistanceComputer<T, I>
189where
190 T: Deref<Target = FixedChunkPQTable>,
191 I: PQVersion,
192{
193 pub fn new(
196 table: MultiTable<T, I>,
197 metric: Metric,
198 ) -> Result<Self, DistanceComputerConstructionError> {
199 match &table {
201 MultiTable::One { table, .. } => {
202 if table.has_opq() {
203 return Err(DistanceComputerConstructionError::OPQNotSupported);
204 }
205 }
206 MultiTable::Two { new, old, .. } => {
207 if new.has_opq() || old.has_opq() {
208 return Err(DistanceComputerConstructionError::OPQNotSupported);
209 }
210 }
211 };
212 Ok(Self {
213 table,
214 vtable: VTable::new(metric),
215 })
216 }
217
218 pub fn versions(&self) -> (&I, Option<&I>) {
228 self.table.versions()
229 }
230}
231
232impl<T, I> DistanceFunction<&[f32], &VersionedPQVector<I>, Option<f32>>
233 for MultiDistanceComputer<T, I>
234where
235 T: Deref<Target = FixedChunkPQTable>,
236 I: PQVersion,
237{
238 #[inline(always)]
239 fn evaluate_similarity(&self, x: &[f32], y: &VersionedPQVector<I>) -> Option<f32> {
240 self.evaluate_similarity(x, y.reborrow())
241 }
242}
243
244impl<T, I> DistanceFunction<&[f32], VersionedPQVectorRef<'_, I>, Option<f32>>
245 for MultiDistanceComputer<T, I>
246where
247 T: Deref<Target = FixedChunkPQTable>,
248 I: PQVersion,
249{
250 fn evaluate_similarity(&self, x: &[f32], y: VersionedPQVectorRef<'_, I>) -> Option<f32> {
251 match &self.table {
252 MultiTable::One { table, version } => {
253 if version != &y.version {
254 None
255 } else {
256 Some((self.vtable.distance_fn)(table, x, y.data))
257 }
258 }
259 MultiTable::Two {
260 old,
261 new,
262 old_version,
263 new_version,
264 } => {
265 if old_version == &y.version {
266 Some((self.vtable.distance_fn)(old, x, y.data))
267 } else if new_version == &y.version {
268 Some((self.vtable.distance_fn)(new, x, y.data))
269 } else {
270 None
271 }
272 }
273 }
274 }
275}
276
277impl<T, I> DistanceFunction<&VersionedPQVector<I>, &VersionedPQVector<I>, Option<f32>>
278 for MultiDistanceComputer<T, I>
279where
280 T: Deref<Target = FixedChunkPQTable>,
281 I: PQVersion,
282{
283 #[inline(always)]
284 fn evaluate_similarity(
285 &self,
286 x: &VersionedPQVector<I>,
287 y: &VersionedPQVector<I>,
288 ) -> Option<f32> {
289 self.evaluate_similarity(x.reborrow(), y.reborrow())
290 }
291}
292
293impl<T, I> DistanceFunction<VersionedPQVectorRef<'_, I>, VersionedPQVectorRef<'_, I>, Option<f32>>
301 for MultiDistanceComputer<T, I>
302where
303 T: Deref<Target = FixedChunkPQTable>,
304 I: PQVersion,
305{
306 fn evaluate_similarity(
307 &self,
308 x: VersionedPQVectorRef<'_, I>,
309 y: VersionedPQVectorRef<'_, I>,
310 ) -> Option<f32> {
311 match &self.table {
312 MultiTable::One { table, version } => {
313 if (&x.version != version) || (&y.version != version) {
314 None
315 } else {
316 Some((self.vtable.distance_fn_qq)(table, x.data, y.data))
317 }
318 }
319 MultiTable::Two {
320 new,
321 old,
322 new_version,
323 old_version,
324 } => {
325 let x_new = &x.version == new_version;
326 let x_old = &x.version == old_version;
327
328 let y_new = &y.version == new_version;
329 let y_old = &y.version == old_version;
330
331 if x_old {
332 if y_old {
333 Some((self.vtable.distance_fn_qq)(old, x.data, y.data))
335 } else if y_new {
336 let x_full = old.inflate_vector(x.data);
337 Some((self.vtable.distance_fn)(new, &x_full, y.data))
339 } else {
340 None
341 }
342 } else if x_new {
343 if y_old {
344 let y_full = old.inflate_vector(y.data);
345 Some((self.vtable.distance_fn)(new, &y_full, x.data))
347 } else if y_new {
348 Some((self.vtable.distance_fn_qq)(new, x.data, y.data))
350 } else {
351 None
352 }
353 } else {
354 None
355 }
356 }
357 }
358 }
359}
360
361#[derive(Debug)]
369pub enum MultiQueryComputer<T, I>
370where
371 T: Deref<Target = FixedChunkPQTable>,
372 I: PQVersion,
373{
374 One {
375 computer: QueryComputer<T>,
376 version: I,
377 },
378 Two {
379 new: QueryComputer<T>,
380 old: QueryComputer<T>,
381 new_version: I,
382 old_version: I,
383 },
384}
385
386impl<T, I> MultiQueryComputer<T, I>
387where
388 T: Deref<Target = FixedChunkPQTable>,
389 I: PQVersion,
390{
391 pub fn new<U>(table: MultiTable<T, I>, metric: Metric, query: &[U]) -> ANNResult<Self>
393 where
394 U: Into<f32> + Copy,
395 {
396 let s = match table {
397 MultiTable::One { table, version } => Self::One {
398 computer: { QueryComputer::new(table, metric, query, None)? },
399 version,
400 },
401 MultiTable::Two {
402 new,
403 old,
404 new_version,
405 old_version,
406 } => Self::Two {
407 new: { QueryComputer::new(new, metric, query, None)? },
408 old: { QueryComputer::new(old, metric, query, None)? },
409 new_version,
410 old_version,
411 },
412 };
413 Ok(s)
414 }
415
416 pub fn versions(&self) -> (&I, Option<&I>) {
420 match &self {
421 Self::One { version, .. } => (version, None),
422 Self::Two {
423 new_version,
424 old_version,
425 ..
426 } => (new_version, Some(old_version)),
427 }
428 }
429}
430
431impl<T, I> PreprocessedDistanceFunction<&VersionedPQVector<I>, Option<f32>>
432 for MultiQueryComputer<T, I>
433where
434 T: Deref<Target = FixedChunkPQTable>,
435 I: PQVersion,
436{
437 #[inline(always)]
438 fn evaluate_similarity(&self, x: &VersionedPQVector<I>) -> Option<f32> {
439 self.evaluate_similarity(x.reborrow())
440 }
441}
442
443impl<T, I> PreprocessedDistanceFunction<VersionedPQVectorRef<'_, I>, Option<f32>>
444 for MultiQueryComputer<T, I>
445where
446 T: Deref<Target = FixedChunkPQTable>,
447 I: PQVersion,
448{
449 fn evaluate_similarity(&self, x: VersionedPQVectorRef<'_, I>) -> Option<f32> {
450 match &self {
451 Self::One { computer, version } => {
452 if version != &x.version {
453 None
454 } else {
455 Some(computer.evaluate_similarity(x.data))
456 }
457 }
458 Self::Two {
459 new,
460 old,
461 new_version,
462 old_version,
463 } => {
464 if old_version == &x.version {
465 Some(old.evaluate_similarity(x.data))
466 } else if new_version == &x.version {
467 Some(new.evaluate_similarity(x.data))
468 } else {
469 None
470 }
471 }
472 }
473 }
474}
475
#[cfg(test)]
mod tests {
    use std::marker::PhantomData;

    use approx::assert_relative_eq;
    use diskann::utils::{IntoUsize, VectorRepr};
    use diskann_vector::{Half, PreprocessedDistanceFunction};
    use rand::{Rng, SeedableRng, distr::Distribution};
    use rstest::rstest;

    use super::{
        super::test_utils::{self, TestDistribution},
        *,
    };

    /// Convert a slice of query elements into an owned `f32` vector.
    fn to_f32<T>(x: &[T]) -> Vec<f32>
    where
        T: Into<f32> + Copy,
    {
        x.iter().map(|i| (*i).into()).collect()
    }

    /// `VersionedPQVector` accessors report the constructor arguments.
    ///
    /// Both `as_ref` and `reborrow` must yield views that share the owned vector's buffer.
    #[test]
    fn test_versioned_pq_vector() {
        let vec = vec![1, 2, 3];
        let ptr = vec.as_ptr();
        let pq = VersionedPQVector::<usize>::new(vec, 10);
        assert_eq!(*pq.version(), 10);
        assert_eq!(pq.data().len(), 3);

        let data_ptr = pq.data().as_ptr();
        let pq_ref = pq.as_ref();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );

        let pq_ref = pq.reborrow();
        assert_eq!(pq_ref.version(), pq.version());
        assert_eq!(data_ptr, ptr);
        assert_eq!(
            pq_ref.data().as_ptr(),
            data_ptr,
            "expected VersionedPQVectorRef to have the same underlying data as the \
             original VersionedPQVector"
        );
    }

    /// `MultiTable::two` must reject a pair of equal versions.
    #[test]
    fn test_table_error() {
        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let new = test_utils::seed_pivot_table(config);
        let old = test_utils::seed_pivot_table(config);

        let result = MultiTable::two(&new, &old, 0, 0);
        assert!(
            matches!(result, Err(EqualVersionsError)),
            "MultiTable should not allow construction of the Two variant with equal versions"
        );
    }

    /// Exercise the `MultiDistanceComputer` evaluation paths for a single-table computer.
    ///
    /// For `num_trials` random code pairs, each result must equal `reference` evaluated on the
    /// vectors built by `test_utils::generate_expected_vector`. Any operand whose version
    /// differs from the computer's must yield `None`.
    fn test_distance_computer_multi_with_one<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&version, should_be_none) = computer.versions();
        assert!(
            should_be_none.is_none(),
            "expected just one schema in test computer"
        );
        let invalid_version = version.wrapping_add(1);

        for _ in 0..num_trials {
            let code0 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected0 = test_utils::generate_expected_vector(
                &code0,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let code1 = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let expected1 = test_utils::generate_expected_vector(
                &code1,
                table.get_chunk_offsets(),
                config.start_value,
            );

            let expected = reference.evaluate_similarity(&expected0, &expected1);

            // Full-precision query vs. quantized vector, in both directions.
            let got = computer
                .evaluate_similarity(&*expected0, &VersionedPQVector::new(code1.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            let got = computer
                .evaluate_similarity(&*expected1, &VersionedPQVector::new(code0.clone(), version))
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Quantized vs. quantized.
            let got = computer
                .evaluate_similarity(
                    &VersionedPQVector::new(code0.clone(), version),
                    &VersionedPQVector::new(code1.clone(), version),
                )
                .expect("evaluate_similarity should return Some");
            assert_eq!(got, expected);

            // Unknown versions in either operand position.
            let got = computer.evaluate_similarity(
                &*expected0,
                &VersionedPQVector::new(code0.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), invalid_version),
                &VersionedPQVector::new(code1.clone(), version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(code0.clone(), version),
                &VersionedPQVector::new(code1.clone(), invalid_version),
            );
            assert!(got.is_none(), "version mismatches should return `None`");
        }
    }

    /// Single-table `MultiDistanceComputer` agrees with full-precision distances for each metric.
    #[rstest]
    fn test_multi_distance_computer_one(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let table = test_utils::seed_pivot_table(config);

        let version: usize = 0x625b215f82f38008;

        let multi_table = MultiTable::one(&table, version);
        let (n, o) = multi_table.versions();
        assert_eq!(*n, version);
        assert!(o.is_none());

        let computer = MultiDistanceComputer::new(multi_table, metric).unwrap();

        test_distance_computer_multi_with_one(
            &computer,
            &table,
            &config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// Exercise the `MultiDistanceComputer` evaluation paths for a two-table computer.
    ///
    /// Covers same-version pairs, mixed-version pairs in both argument orders, and
    /// query-vs-code evaluations. Each is compared with `reference` on the vectors built by
    /// `test_utils::generate_expected_vector`. A version known to neither table must yield
    /// `None`.
    // Test helper: bundling these arguments into a struct would only add noise.
    #[allow(clippy::too_many_arguments)]
    fn test_distance_computer_multi_with_two<R>(
        computer: &MultiDistanceComputer<&'_ FixedChunkPQTable, usize>,
        new: &FixedChunkPQTable,
        old: &FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
    ) where
        R: Rng,
    {
        let (&new_version, old_version) = computer.versions();
        let &old_version = old_version.expect("expected two schemas in test computer");

        for _ in 0..num_trials {
            let old_code =
                test_utils::generate_random_code(old_config.num_pivots, old_config.pq_chunks, rng);
            let old_expected = test_utils::generate_expected_vector(
                &old_code,
                old.get_chunk_offsets(),
                old_config.start_value,
            );

            let new_code =
                test_utils::generate_random_code(new_config.num_pivots, new_config.pq_chunks, rng);
            let new_expected = test_utils::generate_expected_vector(
                &new_code,
                new.get_chunk_offsets(),
                new_config.start_value,
            );

            let oo = reference.evaluate_similarity(&old_expected, &old_expected);
            let nn = reference.evaluate_similarity(&new_expected, &new_expected);
            let on = reference.evaluate_similarity(&old_expected, &new_expected);

            // Quantized vs. quantized, covering every version pairing.
            {
                let got_oo_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(old_code.clone(), old_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &VersionedPQVector::new(new_code.clone(), new_version),
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            // Full-precision query vs. quantized, covering every version pairing.
            {
                let got_oo_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_oo_qq.unwrap(), oo);

                let got_on_qq = computer.evaluate_similarity(
                    &*old_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_on_qq.unwrap(), on);

                let got_no_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(old_code.clone(), old_version),
                );
                assert_eq!(got_no_qq.unwrap(), on);

                let got_nn_qq = computer.evaluate_similarity(
                    &*new_expected,
                    &VersionedPQVector::new(new_code.clone(), new_version),
                );
                assert_eq!(got_nn_qq.unwrap(), nn);
            }

            // Pick a version distinct from both known versions.
            let mut bad_version = old_version.wrapping_add(1);
            if bad_version == new_version {
                bad_version = bad_version.wrapping_add(1);
            }

            let got = computer.evaluate_similarity(
                VersionedPQVectorRef::new(&old_code, bad_version),
                VersionedPQVectorRef::new(&new_code, new_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &VersionedPQVector::new(new_code.clone(), new_version),
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());

            let got = computer.evaluate_similarity(
                &*new_expected,
                &VersionedPQVector::new(old_code.clone(), bad_version),
            );
            assert!(got.is_none());
        }
    }

    /// Two-table `MultiDistanceComputer` agrees with full-precision distances for each metric.
    ///
    /// The two tables use different chunk counts, pivot counts and start values.
    #[rstest]
    fn test_multi_distance_computer_two(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
            use_opq: false,
        };

        let new = test_utils::seed_pivot_table(new_config);
        let old = test_utils::seed_pivot_table(old_config);

        let new_version: usize = 0x5a2b92a731766613;
        let old_version: usize = 0x2fab58c9c8b73841;

        let multi_table = MultiTable::two(&new, &old, new_version, old_version).unwrap();
        let (n, o) = multi_table.versions();
        assert_eq!(*n, new_version);
        assert_eq!(*o.unwrap(), old_version);

        let computer = MultiDistanceComputer::new(multi_table.clone(), metric).unwrap();
        test_distance_computer_multi_with_two(
            &computer,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            100,
            &mut rng,
        );
    }

    /// `MultiDistanceComputer::new` must reject OPQ tables in every position.
    #[rstest]
    fn test_multi_distance_computer_opq_error(
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) {
        let config_with_opq = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: true,
        };

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let expected_err = (DistanceComputerConstructionError::OPQNotSupported).to_string();
        let table_with_opq = test_utils::seed_pivot_table(config_with_opq);
        let table = test_utils::seed_pivot_table(config);

        let schema = MultiTable::one(&table_with_opq, 0);
        let result = MultiDistanceComputer::new(schema, metric);
        assert!(result.is_err(), "expected OPQ to not be supported");
        assert_eq!(result.unwrap_err().to_string(), expected_err);

        let schema = MultiTable::two(&table_with_opq, &table, 0, 1).unwrap();
        let result = MultiDistanceComputer::new(schema, metric);
        assert!(result.is_err(), "expected OPQ to not be supported");
        assert_eq!(result.unwrap_err().to_string(), expected_err);

        let schema = MultiTable::two(&table, &table_with_opq, 0, 1).unwrap();
        let result = MultiDistanceComputer::new(schema, metric);
        assert!(result.is_err(), "expected OPQ to not be supported");
        assert_eq!(result.unwrap_err().to_string(), expected_err);

        let schema = MultiTable::two(&table_with_opq, &table_with_opq, 0, 1).unwrap();
        let result = MultiDistanceComputer::new(schema, metric);
        assert!(result.is_err(), "expected OPQ to not be supported");
        assert_eq!(result.unwrap_err().to_string(), expected_err);
    }

    /// Check one query-computer evaluation against a random code.
    ///
    /// Draws one random code for `table` and tags it with `version`. `computer`'s distance to
    /// that code must match `reference(query, expected)` within `errors`.
    // Test helper: bundling these arguments into a struct would only add noise.
    #[allow(clippy::too_many_arguments)]
    fn check_query_computer<R: Rng>(
        computer: &MultiQueryComputer<&'_ FixedChunkPQTable, usize>,
        table: &FixedChunkPQTable,
        config: &test_utils::TableConfig,
        query: &[f32],
        version: usize,
        rng: &mut R,
        reference: &<f32 as VectorRepr>::Distance,
        errors: test_utils::RelativeAndAbsolute,
    ) {
        let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
        let expected_vector = test_utils::generate_expected_vector(
            &code,
            table.get_chunk_offsets(),
            config.start_value,
        );
        let got = computer
            .evaluate_similarity(&VersionedPQVector {
                data: code,
                version,
            })
            .unwrap();
        let expected = reference.evaluate_similarity(query, &expected_vector);
        assert_relative_eq!(
            got,
            expected,
            epsilon = errors.absolute,
            max_relative = errors.relative
        );
    }

    /// Drive a single-table `MultiQueryComputer` built by `create` over random queries and
    /// versions.
    ///
    /// Checks matching-version evaluations against `reference` and checks that an unmatched
    /// version yields `None`.
    fn test_query_computer_multi_with_one<'a, T, R>(
        mut create: impl FnMut(usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        table: &'a FixedChunkPQTable,
        config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        T: Into<f32> + TestDistribution,
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input: Vec<T> = T::generate(config.dim, rng);
            let input_f32 = to_f32(&input);

            let version: u64 = standard.sample(rng);
            let version: usize = version.into_usize();
            let invalid_version = version.wrapping_add(1);

            let computer = create(version, &input);

            assert_eq!(
                computer.versions(),
                (&version, None),
                "expected the computer to only have one version"
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer, table, config, &input_f32, version, rng, reference, errors,
                );
            }

            let code = test_utils::generate_random_code(config.num_pivots, config.pq_chunks, rng);
            let got =
                computer.evaluate_similarity(VersionedPQVectorRef::new(&code, invalid_version));
            assert!(got.is_none(), "Expected `None` for unmatched versions");
        }
    }

    /// Single-table `MultiQueryComputer` across query element types and metrics.
    #[rstest]
    fn test_query_computer_one<T>(
        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
        _datatype: PhantomData<T>,
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) where
        T: Into<f32> + TestDistribution,
    {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x6b53bef1bc26571e);

        let config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let table = test_utils::seed_pivot_table(config);
        let num_trials = 20;

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        let create = |version: usize, query: &[T]| {
            let schema = MultiTable::one(&table, version);
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };
        test_query_computer_multi_with_one(
            create,
            &table,
            &config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }

    /// Drive a two-table `MultiQueryComputer` built by `create` over random queries and
    /// three mutually distinct random versions.
    ///
    /// Codes for both tables must match `reference`. A code tagged with the third version must
    /// yield `None`.
    // Test helper: bundling these arguments into a struct would only add noise.
    #[allow(clippy::too_many_arguments)]
    fn test_query_computer_multi_with_two<'a, T, R>(
        create: impl Fn(usize, usize, &[T]) -> MultiQueryComputer<&'a FixedChunkPQTable, usize>,
        new: &'a FixedChunkPQTable,
        old: &'a FixedChunkPQTable,
        new_config: &test_utils::TableConfig,
        old_config: &test_utils::TableConfig,
        reference: &<f32 as VectorRepr>::Distance,
        num_trials: usize,
        rng: &mut R,
        errors: test_utils::RelativeAndAbsolute,
    ) where
        T: Into<f32> + TestDistribution,
        R: Rng,
    {
        let standard = rand::distr::StandardUniform {};
        for _ in 0..num_trials {
            let input: Vec<T> = T::generate(old_config.dim, rng);
            let input_f32: Vec<f32> = to_f32(&input);

            // Draw three mutually distinct versions.
            let old_version: u64 = standard.sample(rng);
            let mut new_version: u64 = standard.sample(rng);
            while new_version == old_version {
                new_version = standard.sample(rng);
            }

            let mut invalid_version: u64 = standard.sample(rng);
            while invalid_version == old_version || invalid_version == new_version {
                invalid_version = standard.sample(rng);
            }

            let old_version = old_version.into_usize();
            let new_version = new_version.into_usize();
            let invalid_version = invalid_version.into_usize();

            let computer = create(new_version, old_version, &input);

            assert_eq!(
                computer.versions(),
                (&new_version, Some(&old_version)),
                "versions were not propagated successfully",
            );

            for _ in 0..num_trials {
                check_query_computer(
                    &computer,
                    old,
                    old_config,
                    &input_f32,
                    old_version,
                    rng,
                    reference,
                    errors,
                );

                check_query_computer(
                    &computer,
                    new,
                    new_config,
                    &input_f32,
                    new_version,
                    rng,
                    reference,
                    errors,
                );

                let code = test_utils::generate_random_code(
                    old_config.num_pivots,
                    old_config.pq_chunks,
                    rng,
                );
                let got = computer.evaluate_similarity(&VersionedPQVector {
                    data: code,
                    version: invalid_version,
                });
                assert!(
                    got.is_none(),
                    "expected a distance computation with an invalid version to return None"
                );
            }
        }
    }

    /// Two-table `MultiQueryComputer` across query element types and metrics.
    #[rstest]
    fn test_query_computer_two<T>(
        #[values(PhantomData::<f32>, PhantomData::<Half>, PhantomData::<u8>, PhantomData::<i8>)]
        _datatype: PhantomData<T>,
        #[values(Metric::L2, Metric::InnerProduct, Metric::Cosine)] metric: Metric,
    ) where
        T: Into<f32> + TestDistribution,
    {
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xc8da1164a88cef0f);

        let old_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 4,
            num_pivots: 20,
            start_value: 10.0,
            use_opq: false,
        };

        let new_config = test_utils::TableConfig {
            dim: 17,
            pq_chunks: 5,
            num_pivots: 16,
            start_value: 1.0,
            use_opq: false,
        };

        let old = test_utils::seed_pivot_table(old_config);
        let new = test_utils::seed_pivot_table(new_config);
        let num_trials = 20;

        let create = |new_version: usize, old_version: usize, query: &[T]| {
            let schema = MultiTable::two(&new, &old, new_version, old_version).unwrap();
            MultiQueryComputer::new(schema, metric, query).unwrap()
        };

        let errors = test_utils::RelativeAndAbsolute {
            relative: 5.0e-5,
            absolute: 0.0,
        };

        test_query_computer_multi_with_two(
            create,
            &new,
            &old,
            &new_config,
            &old_config,
            &f32::distance(metric, None),
            num_trials,
            &mut rng,
            errors,
        );
    }
}