1use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141 Architecture,
142 arch::{Scalar, Target1, Target2},
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151#[cfg(target_arch = "aarch64")]
152use diskann_wide::arch::aarch64::Neon;
153
154use super::{
155 CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery,
156 FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric,
157 quantizer,
158};
159use crate::{
160 AsFunctor, CompressIntoWith,
161 alloc::{
162 Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
163 },
164 bits::{self, Representation, Unsigned},
165 distances::{self, UnequalLengths},
166 error::InlineError,
167 meta,
168 num::PowerOfTwo,
169 poly,
170};
171#[cfg(feature = "flatbuffers")]
172use crate::{alloc::CompoundError, flatbuffers as fb};
173
/// Shorthand for `distances::Result<f32>`, the result type returned by the distance
/// kernels dispatched from this module. Its error case is mapped to the
/// `UnequalLengths` variants of `QueryDistanceError` / `DistanceError` below.
type Rf32 = distances::Result<f32>;
176
177#[derive(Debug, Clone)]
183pub struct QueryBufferDescription {
184 size: usize,
185 align: PowerOfTwo,
186}
187
188impl QueryBufferDescription {
189 pub fn new(size: usize, align: PowerOfTwo) -> Self {
191 Self { size, align }
192 }
193
194 pub fn bytes(&self) -> usize {
196 self.size
197 }
198
199 pub fn align(&self) -> PowerOfTwo {
201 self.align
202 }
203}
204
/// Object-safe interface to a spherical quantizer whose bit width is fixed by the
/// implementation.
///
/// Compressed data vectors and queries are exchanged as untyped [`Opaque`] /
/// [`OpaqueMut`] byte slices; how query bytes are interpreted is selected by a
/// [`QueryLayout`]. `A` is the allocator used for the objects returned by this trait.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// The number of bits per dimension used by compressed data vectors.
    fn nbits(&self) -> usize;

    /// The number of bytes occupied by one compressed data vector.
    fn bytes(&self) -> usize;

    /// The dimension of compressed vectors (the quantizer's output dimension).
    fn dim(&self) -> usize;

    /// The dimension of the uncompressed input vectors (the quantizer's input dimension).
    fn full_dim(&self) -> usize;

    /// Build, in `allocator`, a distance computer that compares two compressed data
    /// vectors.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if allocation fails.
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;

    /// Borrow a pre-built distance computer that compares two compressed data vectors.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;

    /// Build, in `allocator`, a distance computer whose first argument is a query
    /// stored in `layout` and whose second argument is a compressed data vector.
    ///
    /// # Errors
    ///
    /// Fails if `layout` is not supported or if allocation fails.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;

    /// The size and alignment of the buffer needed to hold a query compressed in
    /// `layout`.
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedQueryLayout`] if `layout` is not supported.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;

    /// Compress the query `x` into `buffer` using `layout`; `scratch` provides temporary
    /// allocations.
    ///
    /// `allow_rescale` permits the implementation to rescale `x` before compression
    /// (which metrics this applies to is implementation-defined).
    ///
    /// # Errors
    ///
    /// Fails if `layout` is unsupported, `buffer` has the wrong size/alignment, or the
    /// quantizer fails to compress `x`.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;

    /// Compress the query `x` using `layout` and bind it into a [`QueryComputer`]
    /// allocated in `allocator`, in one step.
    ///
    /// # Errors
    ///
    /// Fails if `layout` is unsupported, compression fails, or allocation fails.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;

    /// Whether queries in `layout` are supported by this quantizer.
    fn is_supported(&self, layout: QueryLayout) -> bool;

    /// Compress the data vector `x` into `into`; `scratch` provides temporary
    /// allocations.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;

    /// The distance metric this quantizer was configured with.
    fn metric(&self) -> SupportedMetric;

    /// Clone this quantizer into a new trait object allocated with `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;

    crate::utils::features! {
        #![feature = "flatbuffers"]
        // Serialize this quantizer into a byte buffer allocated with `allocator`
        // (only available with the "flatbuffers" feature).
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
340
341#[derive(Debug, Error)]
342#[error("Layout {layout} is not supported for {desc}")]
343pub struct UnsupportedQueryLayout {
344 layout: QueryLayout,
345 desc: &'static str,
346}
347
348impl UnsupportedQueryLayout {
349 fn new(layout: QueryLayout, desc: &'static str) -> Self {
350 Self { layout, desc }
351 }
352}
353
/// Errors returned by [`Quantizer::query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DistanceComputerError {
    /// The requested query layout is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// Allocating the computer failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors returned by [`Quantizer::compress_query`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryCompressionError {
    /// The requested query layout is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// The quantizer failed to compress the query.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// The output buffer could not be viewed in the requested layout.
    #[error(transparent)]
    NotCanonical(#[from] NotCanonical),
    /// An allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors returned by [`Quantizer::fused_query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryComputerError {
    /// The requested query layout is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// The quantizer failed to compress the query.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// An allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
386
387#[derive(Debug, Error)]
389#[error("Error occured during query compression")]
390pub enum CompressionError {
391 NotCanonical(#[source] InlineError<16>),
393
394 CompressionError(#[source] quantizer::CompressionError),
398}
399
400impl CompressionError {
401 fn not_canonical<E>(error: E) -> Self
402 where
403 E: std::error::Error + Send + Sync + 'static,
404 {
405 Self::NotCanonical(InlineError::new(error))
406 }
407}
408
409#[derive(Debug, Error)]
410#[error("An opaque argument did not have the required alignment or length")]
411pub struct NotCanonical {
412 source: Box<dyn std::error::Error + Send + Sync>,
413}
414
415impl NotCanonical {
416 fn new<E>(err: E) -> Self
417 where
418 E: std::error::Error + Send + Sync + 'static,
419 {
420 Self {
421 source: Box::new(err),
422 }
423 }
424}
425
/// A read-only, untyped view of the bytes of a compressed vector or query.
///
/// The wrapper performs no validation; the bytes are reinterpreted later by the
/// distance computers according to their expected layout.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct Opaque<'a>(&'a [u8]);

impl<'a> Opaque<'a> {
    /// Wrap `slice` without inspecting its contents.
    pub fn new(slice: &'a [u8]) -> Self {
        Opaque(slice)
    }

    /// Unwrap back into the underlying byte slice, keeping the full lifetime `'a`.
    pub fn into_inner(self) -> &'a [u8] {
        let Opaque(bytes) = self;
        bytes
    }
}

impl std::ops::Deref for Opaque<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        // `Opaque` is `Copy`, so this copies the view and unwraps it.
        self.into_inner()
    }
}
453}
454impl<'short> Reborrow<'short> for Opaque<'_> {
455 type Target = Opaque<'short>;
456 fn reborrow(&'short self) -> Self::Target {
457 *self
458 }
459}
460
/// A mutable, untyped view of bytes into which a compressed vector or query is written.
///
/// The wrapper performs no validation of its contents.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wrap `slice` without inspecting its contents.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Mutably borrow the underlying bytes for as long as `self` is borrowed.
    pub fn inspect(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}

impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &*self.0
    }
}

impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}
491
/// The representation a query takes when it is handed to a distance computer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// The query is compressed exactly like the data vectors.
    SameAsData,

    /// A 4-bit query stored bit-transposed (`Query<4, bits::BitTranspose>`).
    FourBitTransposed,

    /// A scalar-quantized query stored densely (`Query<NBITS, bits::Dense>`).
    ScalarQuantized,

    /// The query is kept in full precision (`FullQuery`).
    FullPrecision,
}

impl QueryLayout {
    /// Every variant, in declaration order (test helper).
    #[cfg(test)]
    fn all() -> [Self; 4] {
        use QueryLayout::*;
        [SameAsData, FourBitTransposed, ScalarQuantized, FullPrecision]
    }
}

impl std::fmt::Display for QueryLayout {
    /// Formats as the variant name, identical to the `Debug` output.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, fmt)
    }
}
532
533trait ReportQueryLayout {
543 fn report_query_layout(&self) -> QueryLayout;
544}
545
546impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
547where
548 T: ReportQueryLayout,
549{
550 fn report_query_layout(&self) -> QueryLayout {
551 self.inner.report_query_layout()
552 }
553}
554
555impl<D, Q> ReportQueryLayout for Curried<D, Q>
556where
557 Q: ReportQueryLayout,
558{
559 fn report_query_layout(&self) -> QueryLayout {
560 self.query.report_query_layout()
561 }
562}
563
564impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
565where
566 Unsigned: Representation<NBITS>,
567 A: AllocatorCore,
568{
569 fn report_query_layout(&self) -> QueryLayout {
570 QueryLayout::SameAsData
571 }
572}
573
574impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
575where
576 Unsigned: Representation<NBITS>,
577 A: AllocatorCore,
578{
579 fn report_query_layout(&self) -> QueryLayout {
580 QueryLayout::ScalarQuantized
581 }
582}
583
584impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
585where
586 A: AllocatorCore,
587{
588 fn report_query_layout(&self) -> QueryLayout {
589 QueryLayout::FourBitTransposed
590 }
591}
592
593impl<A> ReportQueryLayout for FullQuery<A>
594where
595 A: AllocatorCore,
596{
597 fn report_query_layout(&self) -> QueryLayout {
598 QueryLayout::FullPrecision
599 }
600}
601
602trait FromOpaque: 'static + Send + Sync {
611 type Target<'a>;
612 type Error: std::error::Error + Send + Sync + 'static;
613
614 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
615}
616
617#[derive(Debug, Default)]
619pub(super) struct AsFull;
620
621#[derive(Debug, Default)]
623pub(super) struct AsData<const NBITS: usize>;
624
625#[derive(Debug)]
627pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
628 _marker: PhantomData<Perm>,
629}
630
631impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
633 fn default() -> Self {
634 Self {
635 _marker: PhantomData,
636 }
637 }
638}
639
640impl FromOpaque for AsFull {
641 type Target<'a> = FullQueryRef<'a>;
642 type Error = meta::slice::NotCanonical;
643
644 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
645 Self::Target::from_canonical(query.into_inner(), dim)
646 }
647}
648
649impl ReportQueryLayout for AsFull {
650 fn report_query_layout(&self) -> QueryLayout {
651 QueryLayout::FullPrecision
652 }
653}
654
655impl<const NBITS: usize> FromOpaque for AsData<NBITS>
656where
657 Unsigned: Representation<NBITS>,
658{
659 type Target<'a> = DataRef<'a, NBITS>;
660 type Error = meta::NotCanonical;
661
662 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
663 Self::Target::from_canonical_back(query.into_inner(), dim)
664 }
665}
666
667impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
668 fn report_query_layout(&self) -> QueryLayout {
669 QueryLayout::SameAsData
670 }
671}
672
673impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
674where
675 Unsigned: Representation<NBITS>,
676 Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
677{
678 type Target<'a> = QueryRef<'a, NBITS, Perm>;
679 type Error = meta::NotCanonical;
680
681 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
682 Self::Target::from_canonical_back(query.into_inner(), dim)
683 }
684}
685
686impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
687 fn report_query_layout(&self) -> QueryLayout {
688 QueryLayout::ScalarQuantized
689 }
690}
691
692impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
693 fn report_query_layout(&self) -> QueryLayout {
694 QueryLayout::FourBitTransposed
695 }
696}
697
698pub(super) struct Reify<T, M, L, R> {
704 inner: T,
705 dim: usize,
706 arch: M,
707 _markers: PhantomData<(L, R)>,
708}
709
710impl<T, M, L, R> Reify<T, M, L, R> {
711 pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
712 Self {
713 inner,
714 dim,
715 arch,
716 _markers: PhantomData,
717 }
718 }
719}
720
/// A `Reify` with `L = ()` wraps a functor that already holds its query (for example a
/// `Curried`), so only the data argument has to be reinterpreted, via `R`.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // The reinterpretation and the kernel call are both performed inside the closure
        // passed to `arch.run2` (presumably so the whole body executes under `M`).
        self.arch.run2(
            |this: &Self, x| {
                // Failure to view `x` in the data layout is reported as `XReify`.
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }

    fn layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}
746
/// Two-argument type-erased distance computation: the query is reinterpreted via `Q`,
/// the data argument via `R`, and then the functor `T` is run for architecture `M`.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        // Both reinterpretations and the kernel call happen inside the closure passed to
        // `arch.run3`.
        self.arch.run3(
            |this: &Self, query, x| {
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;

                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;

                // `T: Copy`, so the functor is passed by value.
                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }

    // The query layout is a property of the marker type `Q`, so a default instance is
    // enough to report it.
    fn layout(&self) -> QueryLayout {
        Q::default().report_query_layout()
    }
}
777
/// Errors produced when evaluating a fused query computer ([`DynQueryComputer`]).
#[derive(Debug, Error)]
pub enum QueryDistanceError {
    /// The data argument could not be viewed in the expected layout.
    #[error("trouble trying to reify the argument")]
    XReify(#[source] InlineError<16>),

    /// The distance kernel reported mismatched argument lengths.
    #[error("encountered while trying to compute distances")]
    UnequalLengths(#[source] UnequalLengths),
}

/// Object-safe interface for a distance function whose query is already bound.
pub trait DynQueryComputer: Send + Sync {
    /// Compute the distance between the bound query and the data vector `x`.
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
    /// The layout of the bound query.
    fn layout(&self) -> QueryLayout;
}
798
799pub struct QueryComputer<A = GlobalAllocator>
808where
809 A: AllocatorCore,
810{
811 inner: Poly<dyn DynQueryComputer, A>,
812}
813
814impl<A> QueryComputer<A>
815where
816 A: AllocatorCore,
817{
818 fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
819 where
820 T: DynQueryComputer + 'static,
821 {
822 let inner = Poly::new(inner, allocator)?;
823 Ok(Self {
824 inner: poly!(DynQueryComputer, inner),
825 })
826 }
827
828 pub fn layout(&self) -> QueryLayout {
830 self.inner.layout()
831 }
832
833 pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
835 self.inner
836 }
837}
838
839impl<A> std::fmt::Debug for QueryComputer<A>
840where
841 A: AllocatorCore,
842{
843 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844 write!(
845 f,
846 "dynamic fused query computer with layout \"{}\"",
847 self.layout()
848 )
849 }
850}
851
852impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
853 for QueryComputer<A>
854where
855 A: AllocatorCore,
856{
857 fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
858 self.inner.evaluate(x)
859 }
860}
861
862pub(super) struct Curried<D, Q> {
869 inner: D,
870 query: Q,
871}
872
873impl<D, Q> Curried<D, Q> {
874 pub(super) fn new(inner: D, query: Q) -> Self {
875 Self { inner, query }
876 }
877}
878
/// Calling a `&Curried` with one argument `x` reborrows the bound query and runs the
/// inner two-argument functor as `inner(query, x)` on architecture `A`.
impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
where
    A: Architecture,
    Q: for<'a> Reborrow<'a>,
    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
{
    fn run(self, arch: A, x: T) -> R {
        // `D: Copy`, so the functor is copied out of the shared reference.
        self.inner.run(arch, self.query.reborrow(), x)
    }
}
889
890#[derive(Debug, Error)]
896pub enum DistanceError {
897 #[error("trouble trying to reify the left-hand argument")]
899 QueryReify(InlineError<24>),
900
901 #[error("trouble trying to reify the right-hand argument")]
903 XReify(InlineError<16>),
904
905 #[error("encountered while trying to compute distances")]
909 UnequalLengths(UnequalLengths),
910}
911
912pub trait DynDistanceComputer: Send + Sync {
913 fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
914 fn layout(&self) -> QueryLayout;
915}
916
917pub struct DistanceComputer<A = GlobalAllocator>
928where
929 A: AllocatorCore,
930{
931 inner: Poly<dyn DynDistanceComputer, A>,
932}
933
934impl<A> DistanceComputer<A>
935where
936 A: AllocatorCore,
937{
938 pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
939 where
940 T: DynDistanceComputer + 'static,
941 {
942 let inner = Poly::new(inner, allocator)?;
943 Ok(Self {
944 inner: poly!(DynDistanceComputer, inner),
945 })
946 }
947
948 pub fn layout(&self) -> QueryLayout {
950 self.inner.layout()
951 }
952
953 pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
954 self.inner
955 }
956}
957
958impl<A> std::fmt::Debug for DistanceComputer<A>
959where
960 A: AllocatorCore,
961{
962 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
963 write!(
964 f,
965 "dynamic distance computer with layout \"{}\"",
966 self.layout()
967 )
968 }
969}
970
971impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
972where
973 A: AllocatorCore,
974{
975 fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
976 self.inner.evaluate(query, x)
977 }
978}
979
// Initial size, in bytes, of the zero-filled buffer handed to the flatbuffer builder in
// `Impl::serialize`.
#[cfg(all(not(test), feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1024;

// Under test the initial buffer is a single byte. NOTE(review): presumably this forces
// the builder to grow the buffer during serialization so that path gets exercised —
// confirm intent.
#[cfg(all(test, feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1;
991
992pub struct Impl<const NBITS: usize, A = GlobalAllocator>
995where
996 A: Allocator,
997{
998 quantizer: SphericalQuantizer<A>,
999 distance: Poly<dyn DynDistanceComputer, A>,
1000}
1001
1002pub trait Constructible<A = GlobalAllocator>
1005where
1006 A: Allocator,
1007{
1008 fn dispatch_distance(
1009 quantizer: &SphericalQuantizer<A>,
1010 ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
1011}
1012
1013impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1014where
1015 A: Allocator,
1016 AsData<NBITS>: FromOpaque,
1017 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1018{
1019 fn dispatch_distance(
1020 quantizer: &SphericalQuantizer<A>,
1021 ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1022 diskann_wide::arch::dispatch2_no_features(
1023 ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1024 quantizer,
1025 quantizer.allocator().clone(),
1026 )
1027 .map(|obj| obj.inner)
1028 }
1029}
1030
1031impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1032where
1033 A: Allocator,
1034 AsData<NBITS>: FromOpaque,
1035 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1036{
1037 fn try_clone(&self) -> Result<Self, AllocatorError> {
1038 Self::new(self.quantizer.try_clone()?)
1039 }
1040}
1041
1042impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1043 pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1045 where
1046 Self: Constructible<A>,
1047 {
1048 let distance = Self::dispatch_distance(&quantizer)?;
1049 Ok(Self {
1050 quantizer,
1051 distance,
1052 })
1053 }
1054
1055 pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1057 &self.quantizer
1058 }
1059
1060 pub fn supports(layout: QueryLayout) -> bool {
1064 if const { NBITS == 1 } {
1065 [
1066 QueryLayout::SameAsData,
1067 QueryLayout::FourBitTransposed,
1068 QueryLayout::FullPrecision,
1069 ]
1070 .contains(&layout)
1071 } else {
1072 [
1073 QueryLayout::SameAsData,
1074 QueryLayout::ScalarQuantized,
1075 QueryLayout::FullPrecision,
1076 ]
1077 .contains(&layout)
1078 }
1079 }
1080
1081 fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1084 where
1085 Q: FromOpaque,
1086 B: AllocatorCore,
1087 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1088 {
1089 diskann_wide::arch::dispatch2_no_features(
1090 ComputerDispatcher::<Q, NBITS>::new(),
1091 &self.quantizer,
1092 allocator,
1093 )
1094 }
1095
1096 fn compress_query<'a, T>(
1097 &self,
1098 query: &'a [f32],
1099 storage: T,
1100 scratch: ScopedAllocator<'a>,
1101 ) -> Result<(), QueryCompressionError>
1102 where
1103 SphericalQuantizer<A>: CompressIntoWith<&'a [f32], T, ScopedAllocator<'a>, Error = quantizer::CompressionError>,
1104 {
1105 self.quantizer
1106 .compress_into_with(query, storage, scratch)
1107 .map_err(|err| CompressionError::CompressionError(err).into())
1108 }
1109
1110 fn fused_query_computer<Q, T, B>(
1113 &self,
1114 query: &[f32],
1115 mut storage: T,
1116 allocator: B,
1117 scratch: ScopedAllocator<'_>,
1118 ) -> Result<QueryComputer<B>, QueryComputerError>
1119 where
1120 Q: FromOpaque,
1121 T: for<'a> ReborrowMut<'a>
1122 + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1123 + ReportQueryLayout
1124 + Send
1125 + Sync
1126 + 'static,
1127 B: AllocatorCore,
1128 SphericalQuantizer<A>: for<'a> CompressIntoWith<
1129 &'a [f32],
1130 <T as ReborrowMut<'a>>::Target,
1131 ScopedAllocator<'a>,
1132 Error = quantizer::CompressionError,
1133 >,
1134 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1135 {
1136 if let Err(err) = self
1137 .quantizer
1138 .compress_into_with(query, storage.reborrow_mut(), scratch)
1139 {
1140 return Err(CompressionError::CompressionError(err).into());
1141 }
1142
1143 diskann_wide::arch::dispatch3_no_features(
1144 ComputerDispatcher::<Q, NBITS>::new(),
1145 &self.quantizer,
1146 storage,
1147 allocator,
1148 )
1149 .map_err(|e| e.into())
1150 }
1151
1152 #[cfg(feature = "flatbuffers")]
1153 fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1154 where
1155 B: Allocator + std::panic::UnwindSafe,
1156 A: std::panic::RefUnwindSafe,
1157 {
1158 let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1159 0u8,
1160 DEFAULT_SERIALIZED_BYTES,
1161 allocator.clone(),
1162 )?);
1163
1164 let quantizer = &self.quantizer;
1165
1166 let (root, mut buf) = match std::panic::catch_unwind(move || {
1167 let offset = quantizer.pack(&mut buf);
1168
1169 let root = fb::spherical::Quantizer::create(
1170 &mut buf,
1171 &fb::spherical::QuantizerArgs {
1172 quantizer: Some(offset),
1173 nbits: NBITS as u32,
1174 },
1175 );
1176 (root, buf)
1177 }) {
1178 Ok(ret) => ret,
1179 Err(err) => match err.downcast_ref::<String>() {
1180 Some(msg) => {
1181 if msg.contains("AllocatorError") {
1182 return Err(AllocatorError);
1183 } else {
1184 std::panic::resume_unwind(err);
1185 }
1186 }
1187 None => std::panic::resume_unwind(err),
1188 },
1189 };
1190
1191 fb::spherical::finish_quantizer_buffer(&mut buf, root);
1193 Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1194 }
1195}
1196
/// Builds distance computers for architecture `M`, viewing queries as `Q` and data
/// vectors as `N`-bit compressed data.
///
/// Implemented for `SphericalQuantizer` by the `dispatch_map!` macro.
trait BuildComputer<M, Q, const N: usize>
where
    M: Architecture,
    Q: FromOpaque,
{
    /// Build, in `allocator`, a two-argument (query, data) distance computer for the
    /// quantizer's metric.
    fn build_computer<A>(
        &self,
        arch: M,
        allocator: A,
    ) -> Result<DistanceComputer<A>, AllocatorError>
    where
        A: AllocatorCore;

    /// Build, in `allocator`, a fused computer that takes ownership of the
    /// already-compressed `query` and compares it against data vectors.
    fn build_fused_computer<R, A>(
        &self,
        arch: M,
        query: R,
        allocator: A,
    ) -> Result<QueryComputer<A>, AllocatorError>
    where
        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
        A: AllocatorCore;
}
1244
/// Return the argument unchanged; `dispatch_map!` uses it as the architecture conversion
/// when no explicit conversion function is given.
fn identity<T>(value: T) -> T {
    value
}
1248
/// Implements `BuildComputer<$arch, $Q, $N>` for `SphericalQuantizer<A>`.
///
/// * `$N`: bits per dimension of the data vectors (always viewed as `AsData<$N>`).
/// * `$Q`: the `FromOpaque` marker describing the query representation.
/// * `$arch`: the architecture handle this implementation accepts.
/// * `$op` (optional, defaults to `identity`): converts the incoming `$arch` into the
///   architecture the kernels are actually run with.
macro_rules! dispatch_map {
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                type D = AsData<$N>;

                let arch = ($op)(input_arch);
                // Arguments are reinterpreted using the quantizer's output dimension.
                let dim = self.output_dim();
                // Select the compensated distance functor matching the metric.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let reify = Reify::<CompensatedSquaredL2, _, $Q, D>::new(
                            self.as_functor(),
                            dim,
                            arch,
                        );
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let reify =
                            Reify::<CompensatedIP, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let reify =
                            Reify::<CompensatedCosine, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // The query is bound into the functor with `Curried`, so `Reify` only
                // reinterprets the data argument (its query type parameter is `()`).
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer: CompensatedSquaredL2 = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer: CompensatedIP = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let computer: CompensatedCosine = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
1334
// Portable `Scalar` implementations, available on every target. Each line implements
// `BuildComputer<Scalar, Q, N>` for `SphericalQuantizer`.

// Full-precision queries against N-bit data.
dispatch_map!(1, AsFull, Scalar);
dispatch_map!(2, AsFull, Scalar);
dispatch_map!(4, AsFull, Scalar);
dispatch_map!(8, AsFull, Scalar);

// Queries compressed exactly like the data.
dispatch_map!(1, AsData<1>, Scalar);
dispatch_map!(2, AsData<2>, Scalar);
dispatch_map!(4, AsData<4>, Scalar);
dispatch_map!(8, AsData<8>, Scalar);

// Dedicated query layouts: 1-bit data pairs with 4-bit bit-transposed queries; the other
// widths use dense queries of the same width.
dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
dispatch_map!(2, AsQuery<2>, Scalar);
dispatch_map!(4, AsQuery<4>, Scalar);
dispatch_map!(8, AsQuery<8>, Scalar);
1350
// Register `BuildComputer` implementations for the SIMD architectures of the compilation
// target. On x86_64, `V3` runs its own kernels and most `V4` entries are routed to the
// `V3` kernels through `downcast_to_v3`. On aarch64, every `Neon` entry is retargeted to
// the `Scalar` kernels through `downcast`.
cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        // Convert a `V4` architecture handle into a `V3` handle.
        fn downcast_to_v3(arch: V4) -> V3 {
            arch.into()
        }

        dispatch_map!(1, AsFull, V3);
        dispatch_map!(2, AsFull, V3);
        dispatch_map!(4, AsFull, V3);
        dispatch_map!(8, AsFull, V3);

        dispatch_map!(1, AsData<1>, V3);
        dispatch_map!(2, AsData<2>, V3);
        dispatch_map!(4, AsData<4>, V3);
        dispatch_map!(8, AsData<8>, V3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
        dispatch_map!(2, AsQuery<2>, V3);
        dispatch_map!(4, AsQuery<4>, V3);
        dispatch_map!(8, AsQuery<8>, V3);

        dispatch_map!(1, AsFull, V4, downcast_to_v3);
        dispatch_map!(2, AsFull, V4, downcast_to_v3);
        dispatch_map!(4, AsFull, V4, downcast_to_v3);
        dispatch_map!(8, AsFull, V4, downcast_to_v3);

        // NOTE(review): the 2-bit `AsData` and `AsQuery` entries on `V4` are registered
        // without `downcast_to_v3`, so their kernels run with the `V4` handle itself —
        // presumably to use V4-specific 2-bit kernels. Confirm this is intentional.
        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
        dispatch_map!(2, AsData<2>, V4);
        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
        dispatch_map!(2, AsQuery<2>, V4);
        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
    } else if #[cfg(target_arch = "aarch64")] {
        // Retarget a `Neon` architecture handle to `Scalar`.
        fn downcast(arch: Neon) -> Scalar {
            arch.retarget()
        }

        dispatch_map!(1, AsFull, Neon, downcast);
        dispatch_map!(2, AsFull, Neon, downcast);
        dispatch_map!(4, AsFull, Neon, downcast);
        dispatch_map!(8, AsFull, Neon, downcast);

        dispatch_map!(1, AsData<1>, Neon, downcast);
        dispatch_map!(2, AsData<2>, Neon, downcast);
        dispatch_map!(4, AsData<4>, Neon, downcast);
        dispatch_map!(8, AsData<8>, Neon, downcast);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Neon, downcast);
        dispatch_map!(2, AsQuery<2>, Neon, downcast);
        dispatch_map!(4, AsQuery<4>, Neon, downcast);
        dispatch_map!(8, AsQuery<8>, Neon, downcast);
    }
}
1409
1410#[derive(Debug, Clone, Copy)]
1422struct ComputerDispatcher<Q, const N: usize> {
1423 _query_type: std::marker::PhantomData<Q>,
1424}
1425
1426impl<Q, const N: usize> ComputerDispatcher<Q, N> {
1427 fn new() -> Self {
1428 Self {
1429 _query_type: std::marker::PhantomData,
1430 }
1431 }
1432}
1433
1434impl<M, const N: usize, A, B, Q>
1435 diskann_wide::arch::Target2<
1436 M,
1437 Result<DistanceComputer<B>, AllocatorError>,
1438 &SphericalQuantizer<A>,
1439 B,
1440 > for ComputerDispatcher<Q, N>
1441where
1442 M: Architecture,
1443 A: Allocator,
1444 B: AllocatorCore,
1445 Q: FromOpaque,
1446 SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1447{
1448 fn run(
1449 self,
1450 arch: M,
1451 quantizer: &SphericalQuantizer<A>,
1452 allocator: B,
1453 ) -> Result<DistanceComputer<B>, AllocatorError> {
1454 quantizer.build_computer(arch, allocator)
1455 }
1456}
1457
1458impl<M, const N: usize, A, R, B, Q>
1459 diskann_wide::arch::Target3<
1460 M,
1461 Result<QueryComputer<B>, AllocatorError>,
1462 &SphericalQuantizer<A>,
1463 R,
1464 B,
1465 > for ComputerDispatcher<Q, N>
1466where
1467 M: Architecture,
1468 A: Allocator,
1469 B: AllocatorCore,
1470 Q: FromOpaque,
1471 R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1472 SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1473{
1474 fn run(
1475 self,
1476 arch: M,
1477 quantizer: &SphericalQuantizer<A>,
1478 query: R,
1479 allocator: B,
1480 ) -> Result<QueryComputer<B>, AllocatorError> {
1481 quantizer.build_fused_computer(arch, query, allocator)
1482 }
1483}
1484
// `Dispatchable<Q, N>` bundles every `BuildComputer` implementation required for the
// architectures listed for the current target. Each variant is blanket-implemented, so
// any type providing those implementations (here: `SphericalQuantizer`) qualifies.

// x86_64: `Scalar`, `V3` and `V4`.
#[cfg(target_arch = "x86_64")]
trait Dispatchable<Q, const N: usize>:
    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
where
    Q: FromOpaque,
{
}

#[cfg(target_arch = "x86_64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
{
}

// aarch64: `Scalar` and `Neon`.
#[cfg(target_arch = "aarch64")]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>
where
    Q: FromOpaque,
{
}

#[cfg(target_arch = "aarch64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>,
{
}

// Every other target: only the portable `Scalar` implementation.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
where
    Q: FromOpaque,
{
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N>,
{
}
1530
1531impl<A, B> Quantizer<B> for Impl<1, A>
1536where
1537 A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1538 B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1539{
1540 fn nbits(&self) -> usize {
1541 1
1542 }
1543
1544 fn dim(&self) -> usize {
1545 self.quantizer.output_dim()
1546 }
1547
1548 fn full_dim(&self) -> usize {
1549 self.quantizer.input_dim()
1550 }
1551
1552 fn bytes(&self) -> usize {
1553 DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
1554 }
1555
1556 fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
1557 self.query_computer::<AsData<1>, _>(allocator)
1558 }
1559
1560 fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1561 &*self.distance
1562 }
1563
1564 fn query_computer(
1565 &self,
1566 layout: QueryLayout,
1567 allocator: B,
1568 ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1569 match layout {
1570 QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
1571 QueryLayout::FourBitTransposed => {
1572 Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
1573 }
1574 QueryLayout::ScalarQuantized => {
1575 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1576 }
1577 QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1578 }
1579 }
1580
1581 fn query_buffer_description(
1582 &self,
1583 layout: QueryLayout,
1584 ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
1585 let dim = <Self as Quantizer<B>>::dim(self);
1586 match layout {
1587 QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1588 DataRef::<1>::canonical_bytes(dim),
1589 PowerOfTwo::alignment_of::<u8>(),
1590 )),
1591 QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
1592 QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
1593 PowerOfTwo::alignment_of::<u8>(),
1594 )),
1595 QueryLayout::ScalarQuantized => {
1596 Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
1597 }
1598 QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1599 FullQueryRef::canonical_bytes(dim),
1600 FullQueryRef::canonical_align(),
1601 )),
1602 }
1603 }
1604
1605 fn compress_query(
1606 &self,
1607 x: &[f32],
1608 layout: QueryLayout,
1609 allow_rescale: bool,
1610 mut buffer: OpaqueMut<'_>,
1611 scratch: ScopedAllocator<'_>,
1612 ) -> Result<(), QueryCompressionError> {
1613 let dim = <Self as Quantizer<B>>::dim(self);
1614 let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1615 match layout {
1616 QueryLayout::SameAsData => self.compress_query(
1617 v,
1618 DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
1619 .map_err(NotCanonical::new)?,
1620 scratch,
1621 ),
1622 QueryLayout::FourBitTransposed => self.compress_query(
1623 v,
1624 QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
1625 .map_err(NotCanonical::new)?,
1626 scratch,
1627 ),
1628 QueryLayout::ScalarQuantized => {
1629 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1630 }
1631 QueryLayout::FullPrecision => self.compress_query(
1632 v,
1633 FullQueryMut::from_canonical_mut(&mut buffer, dim)
1634 .map_err(NotCanonical::new)?,
1635 scratch,
1636 ),
1637 }
1638 };
1639
1640 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1641 let mut copy = x.to_owned();
1642 self.quantizer.rescale(&mut copy);
1643 finish(©)
1644 } else {
1645 finish(x)
1646 }
1647 }
1648
1649 fn fused_query_computer(
1650 &self,
1651 x: &[f32],
1652 layout: QueryLayout,
1653 allow_rescale: bool,
1654 allocator: B,
1655 scratch: ScopedAllocator<'_>,
1656 ) -> Result<QueryComputer<B>, QueryComputerError> {
1657 let dim = <Self as Quantizer<B>>::dim(self);
1658 let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
1659 match layout {
1660 QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
1661 v,
1662 Data::new_in(dim, allocator.clone())?,
1663 allocator,
1664 scratch,
1665 ),
1666 QueryLayout::FourBitTransposed => self
1667 .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
1668 v,
1669 Query::new_in(dim, allocator.clone())?,
1670 allocator,
1671 scratch,
1672 ),
1673 QueryLayout::ScalarQuantized => {
1674 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1675 }
1676 QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
1677 v,
1678 FullQuery::empty(dim, allocator.clone())?,
1679 allocator,
1680 scratch,
1681 ),
1682 }
1683 };
1684
1685 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1686 let mut copy = x.to_owned();
1687 self.quantizer.rescale(&mut copy);
1688 finish(©, allocator)
1689 } else {
1690 finish(x, allocator)
1691 }
1692 }
1693
1694 fn is_supported(&self, layout: QueryLayout) -> bool {
1695 Self::supports(layout)
1696 }
1697
1698 fn compress(
1699 &self,
1700 x: &[f32],
1701 mut into: OpaqueMut<'_>,
1702 scratch: ScopedAllocator<'_>,
1703 ) -> Result<(), CompressionError> {
1704 let dim = <Self as Quantizer<B>>::dim(self);
1705 let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
1706 .map_err(CompressionError::not_canonical)?;
1707 self.quantizer
1708 .compress_into_with(x, into, scratch)
1709 .map_err(CompressionError::CompressionError)
1710 }
1711
1712 fn metric(&self) -> SupportedMetric {
1713 self.quantizer.metric()
1714 }
1715
1716 fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
1717 let clone = (*self).try_clone()?;
1718 poly!({ Quantizer<B> }, clone, allocator)
1719 }
1720
1721 #[cfg(feature = "flatbuffers")]
1722 fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
1723 Impl::<1, A>::serialize(self, allocator)
1724 }
1725}
1726
/// Implements [`Quantizer`] for `Impl<N, A>` for every bit-width literal `N` passed in.
///
/// The generated impls support the `SameAsData`, `ScalarQuantized`, and `FullPrecision`
/// query layouts and reject `FourBitTransposed`. The 1-bit plan is implemented by hand
/// because its set of supported layouts differs.
macro_rules! plan {
    ($N:literal) => {
        impl<A, B> Quantizer<B> for Impl<$N, A>
        where
            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
        {
            fn nbits(&self) -> usize {
                $N
            }

            fn dim(&self) -> usize {
                self.quantizer.output_dim()
            }

            fn full_dim(&self) -> usize {
                self.quantizer.input_dim()
            }

            fn bytes(&self) -> usize {
                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
            }

            fn distance_computer(
                &self,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError> {
                // Data-vs-data distances treat the query argument with the data layout.
                self.query_computer::<AsData<$N>, _>(allocator)
            }

            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
                &*self.distance
            }

            fn query_computer(
                &self,
                layout: QueryLayout,
                allocator: B,
            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
                match layout {
                    QueryLayout::SameAsData => {
                        Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
                    }
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                        layout,
                        concat!($N, "-bit compression"),
                    )
                    .into()),
                    QueryLayout::ScalarQuantized => {
                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
                    }
                    QueryLayout::FullPrecision => {
                        Ok(self.query_computer::<AsFull, _>(allocator)?)
                    }
                }
            }

            fn query_buffer_description(
                &self,
                layout: QueryLayout,
            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
                let dim = <Self as Quantizer<B>>::dim(self);
                match layout {
                    // The quantized layouts only require byte alignment.
                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                        DataRef::<$N>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    // Construct the error the same way as every other unsupported-layout path.
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                        layout,
                        concat!($N, "-bit compression"),
                    )),
                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    // Full-precision queries carry their own alignment requirement.
                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                        FullQueryRef::canonical_bytes(dim),
                        FullQueryRef::canonical_align(),
                    )),
                }
            }

            fn compress_query(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                mut buffer: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), QueryCompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);

                // Reinterpret `buffer` according to `layout` and encode `v` into it using the
                // three-argument `compress_query` defined on `Impl`.
                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
                    match layout {
                        QueryLayout::SameAsData => self.compress_query(
                            v,
                            DataMut::<$N>::from_canonical_back_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => self.compress_query(
                            v,
                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
                                &mut buffer,
                                dim,
                            )
                            .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FullPrecision => self.compress_query(
                            v,
                            FullQueryMut::from_canonical_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                    }
                };

                // `rescale` mutates its argument and `x` is borrowed immutably, so rescale a
                // copy.
                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }

            fn fused_query_computer(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                allocator: B,
                scratch: ScopedAllocator<'_>,
            ) -> Result<QueryComputer<B>, QueryComputerError> {
                let dim = <Self as Quantizer<B>>::dim(self);

                // Allocate query storage for `layout` and build a computer with `v` compressed
                // into it. The closure takes ownership of `allocator` and is called exactly once.
                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
                    match layout {
                        QueryLayout::SameAsData => {
                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
                                v,
                                Data::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => {
                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
                                v,
                                Query::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FullPrecision => {
                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
                                v,
                                FullQuery::empty(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                    }
                };

                // `rescale` mutates its argument and `x` is borrowed immutably, so rescale a
                // copy.
                let metric = <Self as Quantizer<B>>::metric(self);
                if allow_rescale && metric == SupportedMetric::InnerProduct {
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }

            fn is_supported(&self, layout: QueryLayout) -> bool {
                Self::supports(layout)
            }

            fn compress(
                &self,
                x: &[f32],
                mut into: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), CompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);
                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
                    .map_err(CompressionError::not_canonical)?;

                self.quantizer
                    .compress_into_with(x, into, scratch)
                    .map_err(CompressionError::CompressionError)
            }

            fn metric(&self) -> SupportedMetric {
                self.quantizer.metric()
            }

            fn try_clone_into(
                &self,
                allocator: B,
            ) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
                let clone = (*self).try_clone()?;
                poly!({ Quantizer<B> }, clone, allocator)
            }

            #[cfg(feature = "flatbuffers")]
            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
                Impl::<$N, A>::serialize(self, allocator)
            }
        }
    };
    // Expand a comma-separated list (optionally with a trailing comma) one bit-width at a time.
    ($N:literal, $($Ns:literal),* $(,)?) => {
        plan!($N);
        $(plan!($Ns);)*
    };
}
1953
1954plan!(2, 4, 8);
1955
#[cfg(feature = "flatbuffers")]
/// Errors returned by [`try_deserialize`] when reconstructing a quantizer from a flatbuffer.
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer does not carry the spherical-quantizer flatbuffer file identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The serialized bit-width is not one of the supported values (1, 2, 4, or 8).
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// The embedded quantizer could not be unpacked or used to construct a plan.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The buffer is not a valid flatbuffer for the expected root type.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// Allocating the deserialized quantizer failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1980
#[cfg(feature = "flatbuffers")]
/// Reconstruct a type-erased spherical [`Quantizer`] from bytes produced by
/// `Quantizer::serialize`.
///
/// The returned quantizer itself lives in `alloc`. The type parameter `O` is the allocator
/// type that the resulting `dyn Quantizer<O>` accepts in its own allocating methods.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if `data` lacks the expected file identifier.
/// * [`DeserializationError::InvalidFlatBuffer`] if `data` is not a valid flatbuffer.
/// * [`DeserializationError::UnsupportedBitWidth`] if the encoded bit-width is not 1, 2, 4, or 8.
/// * [`DeserializationError::InvalidQuantizer`] if the embedded quantizer cannot be rebuilt.
/// * [`DeserializationError::AllocatorError`] if allocating the result fails.
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    /// Build an `Impl<NBITS, A>` from `proto` inside `alloc` and erase it behind
    /// `dyn Quantizer<O>`.
    fn build_for_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let boxed = Poly::new_with(
            #[inline(never)]
            |inner| -> Result<_, super::quantizer::DeserializationError> {
                let unpacked = SphericalQuantizer::try_unpack(inner, proto)?;
                Ok(Impl::new(unpacked)?)
            },
            alloc,
        )
        .map_err(|failure| match failure {
            CompoundError::Allocator(err) => DeserializationError::from(err),
            CompoundError::Constructor(err) => DeserializationError::from(err),
        })?;

        Ok(poly!({ Quantizer<O> }, boxed))
    }

    // Reject foreign buffers before attempting flatbuffer verification.
    if !fb::spherical::quantizer_buffer_has_identifier(data) {
        return Err(DeserializationError::InvalidIdentifier);
    }

    let root = fb::spherical::root_as_quantizer(data)?;
    let (nbits, proto) = (root.nbits(), root.quantizer());

    // Dispatch the runtime bit-width to the matching monomorphized plan.
    match nbits {
        1 => build_for_bits::<1, O, A>(proto, alloc),
        2 => build_for_bits::<2, O, A>(proto, alloc),
        4 => build_for_bits::<4, O, A>(proto, alloc),
        8 => build_for_bits::<8, O, A>(proto, alloc),
        unsupported => Err(DeserializationError::UnsupportedBitWidth(unsupported)),
    }
}
2048
#[cfg(test)]
mod tests {
    use diskann_utils::views::{Matrix, MatrixView};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::{
        algorithms::{TransformKind, transforms::TargetDim},
        alloc::{AlignedAllocator, GlobalAllocator, Poly},
        num::PowerOfTwo,
        spherical::PreScale,
    };

    /// Check which query layouts a 1-bit plan reports as supported.
    fn test_plan_1_bit(plan: &dyn Quantizer) {
        assert_eq!(
            plan.nbits(),
            1,
            "this test only applies to 1-bit quantization"
        );

        // Exhaustive `match` so that adding a new layout forces this test to be revisited.
        for layout in QueryLayout::all() {
            match layout {
                QueryLayout::SameAsData
                | QueryLayout::FourBitTransposed
                | QueryLayout::FullPrecision => assert!(
                    plan.is_supported(layout),
                    "expected {} to be supported",
                    layout
                ),
                QueryLayout::ScalarQuantized => assert!(
                    !plan.is_supported(layout),
                    "expected {} to not be supported",
                    layout
                ),
            }
        }
    }

    /// Check which query layouts an `nbits`-bit plan (with `nbits != 1`) reports as supported.
    fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
        assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
        assert_eq!(
            plan.nbits(),
            nbits,
            "this test only applies to {}-bit quantization",
            nbits
        );

        // Exhaustive `match` so that adding a new layout forces this test to be revisited.
        for layout in QueryLayout::all() {
            match layout {
                QueryLayout::SameAsData
                | QueryLayout::ScalarQuantized
                | QueryLayout::FullPrecision => assert!(
                    plan.is_supported(layout),
                    "expected {} to be supported",
                    layout
                ),
                QueryLayout::FourBitTransposed => assert!(
                    !plan.is_supported(layout),
                    "expected {} to not be supported",
                    layout
                ),
            }
        }
    }

    /// Exercise every `Quantizer` entry point of `plan` on rows of `dataset`, including the
    /// error paths for unsupported layouts and mis-sized buffers.
    #[inline(never)]
    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
        // Layout support depends on the bit-width.
        if nbits == 1 {
            test_plan_1_bit(plan);
        } else {
            test_plan_n_bit(plan, nbits);
        }

        assert_eq!(plan.full_dim(), dataset.ncols());

        // Compress the first two rows into 4-byte aligned buffers.
        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let scoped_global = ScopedAllocator::global();

        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
            .unwrap();
        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
            .unwrap();

        let f = plan.distance_computer(GlobalAllocator).unwrap();
        let _: f32 = f
            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
            .unwrap();

        // Arguments that are one byte too short or too long must be rejected, and the error
        // must identify which argument was malformed.
        let test_errors = |f: &dyn DynDistanceComputer| {
            // Query too short.
            let err = f
                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // Query too long.
            let err = f
                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // Data too short.
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));

            // Data too long.
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));
        };

        // Both the owned computer and the borrowed one must validate their inputs.
        test_errors(&*f.inner);

        let f = plan.distance_computer_ref();
        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
        test_errors(f);

        for layout in QueryLayout::all() {
            // Unsupported layouts: every layout-taking API must fail with a message that names
            // both the layout and the bit-width.
            if !plan.is_supported(layout) {
                let check_message = |msg: &str| {
                    assert!(
                        msg.contains(&(layout.to_string())),
                        "error message ({}) should contain the layout \"{}\"",
                        msg,
                        layout
                    );
                    assert!(
                        msg.contains(&format!("{}", nbits)),
                        "error message ({}) should contain the number of bits \"{}\"",
                        msg,
                        nbits
                    );
                };

                // Fused query computer.
                {
                    let err = plan
                        .fused_query_computer(
                            dataset.row(1),
                            layout,
                            false,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap_err();

                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query buffer description.
                {
                    let err = plan.query_buffer_description(layout).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query compression. The buffer is deliberately empty: the unsupported-layout
                // error must be reported regardless of the buffer.
                {
                    let buffer = &mut [];
                    let err = plan
                        .compress_query(
                            dataset.row(1),
                            layout,
                            true,
                            OpaqueMut::new(buffer),
                            scoped_global,
                        )
                        .unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Standalone query computer.
                {
                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                continue;
            }

            // Supported layouts: the fused computer must preserve the requested layout.
            let g = plan
                .fused_query_computer(
                    dataset.row(1),
                    layout,
                    false,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            assert_eq!(
                g.layout(),
                layout,
                "the query computer should faithfully preserve the requested layout"
            );

            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();

            // Mis-sized data vectors must be rejected.
            {
                let err = g
                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));

                let err = g
                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));
            }

            // Compress the same query into a standalone buffer sized by the plan and check
            // that the two-step API matches the fused computer exactly.
            let sizes = plan.query_buffer_description(layout).unwrap();
            let mut buf =
                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();

            plan.compress_query(
                dataset.row(1),
                layout,
                false,
                OpaqueMut::new(&mut buf),
                scoped_global,
            )
            .unwrap();

            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();

            assert_eq!(
                standalone.layout(),
                layout,
                "the standalone computer did not preserve the requested layout",
            );

            let indirect: f32 = standalone
                .evaluate_similarity(Opaque(&buf), Opaque(&a))
                .unwrap();

            assert_eq!(
                direct, indirect,
                "the two different query computation APIs did not return the same result"
            );

            // A query with the wrong dimension must be rejected.
            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
            assert!(
                plan.fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                    .is_err()
            );
        }

        // `compress` must reject output buffers of the wrong size and inputs of the wrong
        // dimension.
        {
            let mut too_small = vec![u8::default(); plan.bytes() - 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                    .is_err()
            );

            let mut too_big = vec![u8::default(); plan.bytes() + 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                    .is_err()
            );

            let mut just_right = vec![u8::default(); plan.bytes()];
            assert!(
                plan.compress(
                    &dataset.row(0)[..dataset.ncols() - 1],
                    OpaqueMut(&mut just_right),
                    scoped_global
                )
                .is_err()
            );
        }
    }

    /// Train an `NBITS`-bit plan for `metric` on the fixed test dataset (deterministic seed)
    /// and return it together with that dataset.
    fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
    where
        Impl<NBITS>: Constructible,
    {
        let data = test_dataset();
        let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);

        let quantizer = SphericalQuantizer::train(
            data.as_view(),
            TransformKind::PaddingHadamard {
                target_dim: TargetDim::Natural,
            },
            metric,
            PreScale::None,
            &mut rng,
            GlobalAllocator,
        )
        .unwrap();

        (Impl::<NBITS>::new(quantizer).unwrap(), data)
    }

    #[test]
    fn test_plan_1bit_l2() {
        let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
        test_plan(&plan, 1, data.as_view());
    }

    #[test]
    fn test_plan_1bit_ip() {
        let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
        test_plan(&plan, 1, data.as_view());
    }

    #[test]
    fn test_plan_1bit_cosine() {
        let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
        test_plan(&plan, 1, data.as_view());
    }

    #[test]
    fn test_plan_2bit_l2() {
        let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
        test_plan(&plan, 2, data.as_view());
    }

    #[test]
    fn test_plan_2bit_ip() {
        let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
        test_plan(&plan, 2, data.as_view());
    }

    #[test]
    fn test_plan_2bit_cosine() {
        let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
        test_plan(&plan, 2, data.as_view());
    }

    #[test]
    fn test_plan_4bit_l2() {
        let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
        test_plan(&plan, 4, data.as_view());
    }

    #[test]
    fn test_plan_4bit_ip() {
        let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
        test_plan(&plan, 4, data.as_view());
    }

    #[test]
    fn test_plan_4bit_cosine() {
        let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
        test_plan(&plan, 4, data.as_view());
    }

    #[test]
    fn test_plan_8bit_l2() {
        let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
        test_plan(&plan, 8, data.as_view());
    }

    #[test]
    fn test_plan_8bit_ip() {
        let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
        test_plan(&plan, 8, data.as_view());
    }

    #[test]
    fn test_plan_8bit_cosine() {
        let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
        test_plan(&plan, 8, data.as_view());
    }

    /// A fixed 16 x 8 dataset used to train and exercise the plans.
    fn test_dataset() -> Matrix<f32> {
        // One dataset row per line.
        #[rustfmt::skip]
        let data = vec![
            0.28657, -0.0318168, 0.0666847, 0.0329265, -0.00829283, 0.168735, -0.000846311, -0.360779,
            -0.0968938, 0.161921, -0.0979579, 0.102228, -0.259928, -0.139634, 0.165384, -0.293443,
            0.130205, 0.265737, 0.401816, -0.407552, 0.13012, -0.0475244, 0.511723, -0.4372,
            -0.0979126, 0.135861, -0.0154144, -0.14047, -0.0250029, -0.190279, 0.407283, -0.389184,
            -0.264153, 0.0696822, -0.145585, 0.370284, 0.186825, -0.140736, 0.274703, -0.334563,
            0.247613, 0.513165, -0.0845867, 0.0532264, -0.00480601, -0.122408, 0.47227, -0.268301,
            0.103198, 0.30756, -0.316293, -0.0686877, -0.330729, -0.461997, 0.550857, -0.240851,
            0.128258, 0.786291, -0.0268103, 0.111763, -0.308962, -0.17407, 0.437154, -0.159879,
            0.00374063, 0.490301, 0.0327826, -0.0340962, -0.118605, 0.163879, 0.2737, -0.299942,
            -0.284077, 0.249377, -0.0307734, -0.0661631, 0.233854, 0.427987, 0.614132, -0.288649,
            -0.109492, 0.203939, -0.73956, -0.130748, 0.22072, 0.0647836, 0.328726, -0.374602,
            -0.223114, 0.0243489, 0.109195, -0.416914, 0.0201052, -0.0190542, 0.947078, -0.333229,
            -0.165869, -0.00296729, -0.414378, 0.231321, 0.205365, 0.161761, 0.148608, -0.395063,
            -0.0498255, 0.193279, -0.110946, -0.181174, -0.274578, -0.227511, 0.190208, -0.256174,
            -0.188106, -0.0292958, 0.0930939, 0.0558456, 0.257437, 0.685481, 0.307922, -0.320006,
            0.250035, 0.275942, -0.0856306, -0.352027, -0.103509, -0.00890859, 0.276121, -0.324718,
        ];

        Matrix::try_from(data.into(), 16, 8).unwrap()
    }

    #[cfg(feature = "flatbuffers")]
    mod serialization {
        use std::sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        };

        use super::*;
        use crate::alloc::{BumpAllocator, GlobalAllocator};

        /// Serialize `quantizer`, deserialize it, and check that the two copies agree on
        /// metadata, compression output, and every distance computation.
        #[inline(never)]
        fn test_plan_serialization(
            quantizer: &dyn Quantizer,
            nbits: usize,
            dataset: MatrixView<f32>,
        ) {
            assert_eq!(quantizer.full_dim(), dataset.ncols());
            let scoped_global = ScopedAllocator::global();

            let serialized = quantizer.serialize(GlobalAllocator).unwrap();
            let deserialized =
                try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();

            // Metadata must round-trip.
            assert_eq!(deserialized.nbits(), nbits);
            assert_eq!(deserialized.bytes(), quantizer.bytes());
            assert_eq!(deserialized.dim(), quantizer.dim());
            assert_eq!(deserialized.full_dim(), quantizer.full_dim());
            assert_eq!(deserialized.metric(), quantizer.metric());

            for layout in QueryLayout::all() {
                assert_eq!(
                    deserialized.is_supported(layout),
                    quantizer.is_supported(layout)
                );
            }

            let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
            // Compression must be bit-identical.
            {
                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                for row in dataset.row_iter() {
                    quantizer
                        .compress(row, OpaqueMut::new(&mut a), scoped_global)
                        .unwrap();
                    deserialized
                        .compress(row, OpaqueMut::new(&mut b), scoped_global)
                        .unwrap();

                    assert_eq!(a, b);
                }
            }

            // Data-vs-data distances must agree. The deserialized computers are evaluated on
            // vectors compressed by the deserialized quantizer so the whole round trip is
            // exercised.
            {
                let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
                let q_computer_ref = quantizer.distance_computer_ref();
                let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
                let d_computer_ref = deserialized.distance_computer_ref();

                for r0 in dataset.row_iter() {
                    quantizer
                        .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
                        .unwrap();
                    deserialized
                        .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
                        .unwrap();
                    for r1 in dataset.row_iter() {
                        quantizer
                            .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
                            .unwrap();
                        deserialized
                            .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
                            .unwrap();

                        let a0 = Opaque::new(&a0);
                        let a1 = Opaque::new(&a1);
                        let b0 = Opaque::new(&b0);
                        let b1 = Opaque::new(&b1);

                        let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
                        let d_computer_dist = d_computer.evaluate_similarity(b0, b1).unwrap();

                        assert_eq!(q_computer_dist, d_computer_dist);

                        let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();

                        assert_eq!(q_computer_dist, q_computer_ref_dist);

                        let d_computer_ref_dist = d_computer_ref.evaluate(b0, b1).unwrap();
                        assert_eq!(d_computer_dist, d_computer_ref_dist);
                    }
                }
            }

            // Fused query computers must agree for every supported layout.
            {
                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                for layout in QueryLayout::all() {
                    if !quantizer.is_supported(layout) {
                        continue;
                    }

                    for r in dataset.row_iter() {
                        let q_computer = quantizer
                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                            .unwrap();
                        let d_computer = deserialized
                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                            .unwrap();

                        for u in dataset.row_iter() {
                            quantizer
                                .compress(u, OpaqueMut::new(&mut a), scoped_global)
                                .unwrap();
                            deserialized
                                .compress(u, OpaqueMut::new(&mut b), scoped_global)
                                .unwrap();

                            assert_eq!(
                                q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
                                d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
                            );
                        }
                    }
                }
            }
        }

        /// Allocator that succeeds on its first allocation and fails on every later one.
        #[derive(Debug, Clone)]
        struct FlakyAllocator {
            have_allocated: Arc<AtomicBool>,
        }

        impl FlakyAllocator {
            fn new(have_allocated: Arc<AtomicBool>) -> Self {
                Self { have_allocated }
            }
        }

        // SAFETY: every successful allocation is delegated to `GlobalAllocator`, and
        // `deallocate` forwards back to `GlobalAllocator`. Memory is therefore always
        // returned to the allocator that produced it.
        unsafe impl AllocatorCore for FlakyAllocator {
            fn allocate(
                &self,
                layout: std::alloc::Layout,
            ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
                if self.have_allocated.swap(true, Ordering::Relaxed) {
                    Err(AllocatorError)
                } else {
                    GlobalAllocator.allocate(layout)
                }
            }

            unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
                // SAFETY: per the `deallocate` contract, `ptr`/`layout` came from `allocate`,
                // which only ever returns allocations made by `GlobalAllocator`.
                unsafe { GlobalAllocator.deallocate(ptr, layout) }
            }
        }

        /// Serializing with an allocator that fails after its first allocation must return
        /// an `AllocatorError` instead of panicking.
        fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
        where
            Impl<NBITS>: Quantizer,
        {
            let have_allocated = Arc::new(AtomicBool::new(false));
            let _: AllocatorError = v
                .serialize(FlakyAllocator::new(have_allocated.clone()))
                .unwrap_err();
            assert!(have_allocated.load(Ordering::Relaxed));
        }

        #[test]
        fn test_plan_1bit_l2() {
            let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 1, data.as_view());
        }

        #[test]
        fn test_plan_1bit_ip() {
            let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 1, data.as_view());
        }

        #[test]
        fn test_plan_2bit_l2() {
            let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 2, data.as_view());
        }

        #[test]
        fn test_plan_2bit_ip() {
            let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 2, data.as_view());
        }

        #[test]
        fn test_plan_4bit_l2() {
            let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 4, data.as_view());
        }

        #[test]
        fn test_plan_4bit_ip() {
            let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 4, data.as_view());
        }

        #[test]
        fn test_plan_8bit_l2() {
            let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 8, data.as_view());
        }

        #[test]
        fn test_plan_8bit_ip() {
            let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, 8, data.as_view());
        }

        /// The deserialized quantizer's outer box must be the first allocation made in the
        /// allocator passed to `try_deserialize`.
        #[test]
        fn test_allocation_order() {
            let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
            let buf = plan.serialize(GlobalAllocator).unwrap();

            let allocator = BumpAllocator::new(8192, PowerOfTwo::new(64).unwrap()).unwrap();
            let deserialized =
                try_deserialize::<GlobalAllocator, _>(&buf, allocator.clone()).unwrap();
            assert_eq!(
                Poly::as_ptr(&deserialized).cast::<u8>(),
                allocator.as_ptr(),
                "expected the returned box to be allocated first",
            );
        }
    }
}