1use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141 Architecture,
142 arch::{Scalar, Target1, Target2},
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151use super::{
152 CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery,
153 FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric,
154 quantizer,
155};
156use crate::{
157 AsFunctor, CompressIntoWith,
158 alloc::{
159 Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
160 },
161 bits::{self, Representation, Unsigned},
162 distances::{self, UnequalLengths},
163 error::InlineError,
164 meta,
165 num::PowerOfTwo,
166 poly,
167};
168#[cfg(feature = "flatbuffers")]
169use crate::{alloc::CompoundError, flatbuffers as fb};
170
/// Shorthand for the fallible `f32` result produced by the `distances` kernels.
///
/// Used as the return type in the `Target1`/`Target2` bounds on `Reify`'s
/// computer impls.
type Rf32 = distances::Result<f32>;
173
174#[derive(Debug, Clone)]
180pub struct QueryBufferDescription {
181 size: usize,
182 align: PowerOfTwo,
183}
184
185impl QueryBufferDescription {
186 pub fn new(size: usize, align: PowerOfTwo) -> Self {
188 Self { size, align }
189 }
190
191 pub fn bytes(&self) -> usize {
193 self.size
194 }
195
196 pub fn align(&self) -> PowerOfTwo {
198 self.align
199 }
200}
201
/// Object-safe interface over a spherical quantizer whose bit width is erased.
///
/// Compressed vectors and queries cross this interface as raw bytes ([`Opaque`] /
/// [`OpaqueMut`]), and distance computers are returned as type-erased objects, so
/// callers do not need to know the concrete bit width at compile time.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// Number of bits per compressed dimension (the 1-bit implementation returns 1).
    fn nbits(&self) -> usize;

    /// Size in bytes of one compressed data vector.
    fn bytes(&self) -> usize;

    /// Dimension of compressed vectors (the quantizer's output dimension).
    fn dim(&self) -> usize;

    /// Dimension of uncompressed input vectors (the quantizer's input dimension).
    fn full_dim(&self) -> usize;

    /// Build a new data-vs-data distance computer allocated in `allocator`.
    ///
    /// # Errors
    ///
    /// Returns [`AllocatorError`] if the computer cannot be allocated.
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;

    /// Borrow a pre-built data-vs-data distance computer owned by the quantizer.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;

    /// Build a distance computer whose left-hand argument is a query compressed with
    /// `layout` and whose right-hand argument is a compressed data vector.
    ///
    /// # Errors
    ///
    /// Fails with [`UnsupportedQueryLayout`] if `layout` is not supported, or with
    /// [`AllocatorError`] if the computer cannot be allocated.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;

    /// Size and alignment of the buffer needed by [`Quantizer::compress_query`] for
    /// `layout`.
    ///
    /// # Errors
    ///
    /// Returns [`UnsupportedQueryLayout`] if `layout` is not supported.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;

    /// Compress the full-precision query `x` into `buffer` using `layout`.
    ///
    /// `allow_rescale` permits the implementation to rescale a copy of `x` before
    /// compressing it. The 1-bit implementation only does so for inner-product
    /// metrics.
    /// `scratch` supplies temporary allocations.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;

    /// Compress the query `x` with `layout` and bind it into a [`QueryComputer`]
    /// allocated in `allocator`.
    ///
    /// `allow_rescale` has the same meaning as in [`Quantizer::compress_query`].
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;

    /// Whether `layout` is accepted by the query-related methods.
    fn is_supported(&self, layout: QueryLayout) -> bool;

    /// Compress the data vector `x` into the byte buffer `into`.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;

    /// Distance metric this quantizer was built for.
    fn metric(&self) -> SupportedMetric;

    /// Clone this quantizer into a new type-erased allocation from `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;

    // `serialize` exists only when the "flatbuffers" feature is enabled; it returns
    // the serialized quantizer as bytes allocated from `allocator`.
    crate::utils::features! {
        #![feature = "flatbuffers"]
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
337
/// Error returned when a quantizer does not support a requested [`QueryLayout`].
#[derive(Debug, Error)]
#[error("Layout {layout} is not supported for {desc}")]
pub struct UnsupportedQueryLayout {
    /// The rejected layout.
    layout: QueryLayout,
    /// Short description of the rejecting quantizer (e.g. "1-bit compression").
    desc: &'static str,
}

impl UnsupportedQueryLayout {
    fn new(layout: QueryLayout, desc: &'static str) -> Self {
        Self { layout, desc }
    }
}

/// Errors from [`Quantizer::query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DistanceComputerError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors from [`Quantizer::compress_query`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryCompressionError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// The output buffer could not be viewed as the layout's compressed type.
    #[error(transparent)]
    NotCanonical(#[from] NotCanonical),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors from [`Quantizer::fused_query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryComputerError {
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
383
384#[derive(Debug, Error)]
386#[error("Error occured during query compression")]
387pub enum CompressionError {
388 NotCanonical(#[source] InlineError<16>),
390
391 CompressionError(#[source] quantizer::CompressionError),
395}
396
397impl CompressionError {
398 fn not_canonical<E>(error: E) -> Self
399 where
400 E: std::error::Error + Send + Sync + 'static,
401 {
402 Self::NotCanonical(InlineError::new(error))
403 }
404}
405
406#[derive(Debug, Error)]
407#[error("An opaque argument did not have the required alignment or length")]
408pub struct NotCanonical {
409 source: Box<dyn std::error::Error + Send + Sync>,
410}
411
412impl NotCanonical {
413 fn new<E>(err: E) -> Self
414 where
415 E: std::error::Error + Send + Sync + 'static,
416 {
417 Self {
418 source: Box::new(err),
419 }
420 }
421}
422
/// A read-only, untyped view over the bytes of a compressed vector or query.
///
/// The bytes are reinterpreted into a typed view only inside the distance
/// computers.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct Opaque<'a>(&'a [u8]);

impl<'a> Opaque<'a> {
    /// Wrap `slice` without inspecting or validating its contents.
    pub fn new(slice: &'a [u8]) -> Self {
        Opaque(slice)
    }

    /// Return the wrapped slice with its full original lifetime.
    pub fn into_inner(self) -> &'a [u8] {
        let Opaque(bytes) = self;
        bytes
    }
}

impl std::ops::Deref for Opaque<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        Opaque::into_inner(*self)
    }
}
450}
451impl<'short> Reborrow<'short> for Opaque<'_> {
452 type Target = Opaque<'short>;
453 fn reborrow(&'short self) -> Self::Target {
454 *self
455 }
456}
457
/// A mutable, untyped view over the bytes that will receive a compressed vector
/// or query.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wrap `slice` without inspecting or validating its contents.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Mutably borrow the wrapped bytes.
    pub fn inspect(&mut self) -> &mut [u8] {
        std::ops::DerefMut::deref_mut(self)
    }
}

impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &*self.0
    }
}

impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}
488
/// The representation a query is compressed into before distance computation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// The query uses the same encoding as the data (`Data<NBITS>`).
    SameAsData,

    /// 4-bit query stored in bit-transposed form (`Query<4, BitTranspose>`).
    FourBitTransposed,

    /// Dense scalar-quantized query (`Query<NBITS, Dense>`).
    ScalarQuantized,

    /// Uncompressed full-precision query (`FullQuery`).
    FullPrecision,
}

impl QueryLayout {
    /// Every layout, in declaration order (test-only helper).
    #[cfg(test)]
    fn all() -> [Self; 4] {
        let every = [
            QueryLayout::SameAsData,
            QueryLayout::FourBitTransposed,
            QueryLayout::ScalarQuantized,
            QueryLayout::FullPrecision,
        ];
        every
    }
}

/// Displays the variant name, exactly as `Debug` does.
impl std::fmt::Display for QueryLayout {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, fmt)
    }
}
529
/// Reports which [`QueryLayout`] a query representation (or a computer holding
/// one) corresponds to.
///
/// Backs the `layout()` methods of the type-erased computers.
trait ReportQueryLayout {
    fn report_query_layout(&self) -> QueryLayout;
}

// A `Reify` reports the layout of the computer it wraps.
impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
where
    T: ReportQueryLayout,
{
    fn report_query_layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}

// A curried functor reports the layout of its bound query.
impl<D, Q> ReportQueryLayout for Curried<D, Q>
where
    Q: ReportQueryLayout,
{
    fn report_query_layout(&self) -> QueryLayout {
        self.query.report_query_layout()
    }
}

// A query stored in the data encoding.
impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::SameAsData
    }
}

// A dense (non-transposed) quantized query.
impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::ScalarQuantized
    }
}

// A 4-bit bit-transposed query.
impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FourBitTransposed
    }
}

// An uncompressed full-precision query.
impl<A> ReportQueryLayout for FullQuery<A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FullPrecision
    }
}
598
/// Reinterprets an [`Opaque`] byte slice as a typed, borrowed view.
///
/// Implemented by the zero-sized markers [`AsFull`], [`AsData`] and [`AsQuery`],
/// which select the target representation at the type level.
trait FromOpaque: 'static + Send + Sync {
    /// The borrowed, typed view produced from the bytes.
    type Target<'a>;
    /// Error returned when the bytes cannot be viewed as `Target` for the given
    /// dimension.
    type Error: std::error::Error + Send + Sync + 'static;

    /// View `query` as `Self::Target` for a vector of dimension `dim`.
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
}

/// Marker: interpret bytes as a full-precision query (`FullQueryRef`).
#[derive(Debug, Default)]
pub(super) struct AsFull;

/// Marker: interpret bytes as `NBITS`-bit compressed data (`DataRef`).
#[derive(Debug, Default)]
pub(super) struct AsData<const NBITS: usize>;

/// Marker: interpret bytes as an `NBITS`-bit query with permutation `Perm`
/// (`QueryRef`).
#[derive(Debug)]
pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
    _marker: PhantomData<Perm>,
}

// Written by hand because `#[derive(Default)]` would add a `Perm: Default` bound.
impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
    fn default() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}
636
637impl FromOpaque for AsFull {
638 type Target<'a> = FullQueryRef<'a>;
639 type Error = meta::slice::NotCanonical;
640
641 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
642 Self::Target::from_canonical(query.into_inner(), dim)
643 }
644}
645
646impl ReportQueryLayout for AsFull {
647 fn report_query_layout(&self) -> QueryLayout {
648 QueryLayout::FullPrecision
649 }
650}
651
652impl<const NBITS: usize> FromOpaque for AsData<NBITS>
653where
654 Unsigned: Representation<NBITS>,
655{
656 type Target<'a> = DataRef<'a, NBITS>;
657 type Error = meta::NotCanonical;
658
659 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
660 Self::Target::from_canonical_back(query.into_inner(), dim)
661 }
662}
663
664impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
665 fn report_query_layout(&self) -> QueryLayout {
666 QueryLayout::SameAsData
667 }
668}
669
670impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
671where
672 Unsigned: Representation<NBITS>,
673 Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
674{
675 type Target<'a> = QueryRef<'a, NBITS, Perm>;
676 type Error = meta::NotCanonical;
677
678 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
679 Self::Target::from_canonical_back(query.into_inner(), dim)
680 }
681}
682
683impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
684 fn report_query_layout(&self) -> QueryLayout {
685 QueryLayout::ScalarQuantized
686 }
687}
688
689impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
690 fn report_query_layout(&self) -> QueryLayout {
691 QueryLayout::FourBitTransposed
692 }
693}
694
/// Adapts a typed distance functor `inner` to the byte-oriented, type-erased
/// computer traits.
///
/// * `M` – architecture token the kernels are run with.
/// * `L` – [`FromOpaque`] marker for the left-hand (query) bytes, or `()` when the
///   query is already bound inside `inner` (see [`Curried`]).
/// * `R` – [`FromOpaque`] marker for the right-hand (data) bytes.
pub(super) struct Reify<T, M, L, R> {
    inner: T,
    // Dimension passed to `FromOpaque::from_opaque` when reinterpreting arguments.
    dim: usize,
    arch: M,
    _markers: PhantomData<(L, R)>,
}

impl<T, M, L, R> Reify<T, M, L, R> {
    pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
        Self {
            inner,
            dim,
            arch,
            _markers: PhantomData,
        }
    }
}
717
// Fused computer: the query is already bound inside `inner` (`L = ()`), so only the
// data argument arrives as bytes.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // Both reification and the kernel call run inside `arch.run2`.
        // NOTE(review): presumably so both are compiled for `M`; confirm against the
        // diskann_wide `Architecture` docs.
        self.arch.run2(
            |this: &Self, x| {
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }

    fn layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}
743
// Two-argument computer: both the query (`Q`) and the data (`R`) arrive as bytes and
// are reinterpreted before the kernel runs.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        self.arch.run3(
            |this: &Self, query, x| {
                // The inline-error capacities match the `DistanceError` variants.
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;

                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;

                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }

    // The layout is a property of the query marker type, so a default instance is
    // enough to report it.
    fn layout(&self) -> QueryLayout {
        Q::default().report_query_layout()
    }
}
774
/// Errors returned by [`DynQueryComputer::evaluate`].
#[derive(Debug, Error)]
pub enum QueryDistanceError {
    /// The data bytes could not be reinterpreted as the expected compressed type.
    #[error("trouble trying to reify the argument")]
    XReify(#[source] InlineError<16>),

    /// The kernel rejected its arguments because their lengths differ.
    #[error("encountered while trying to compute distances")]
    UnequalLengths(#[source] UnequalLengths),
}

/// Object-safe distance function with a query already bound in.
pub trait DynQueryComputer: Send + Sync {
    /// Distance between the bound query and the compressed data vector `x`.
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
    /// Layout of the bound query.
    fn layout(&self) -> QueryLayout;
}
795
796pub struct QueryComputer<A = GlobalAllocator>
805where
806 A: AllocatorCore,
807{
808 inner: Poly<dyn DynQueryComputer, A>,
809}
810
811impl<A> QueryComputer<A>
812where
813 A: AllocatorCore,
814{
815 fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
816 where
817 T: DynQueryComputer + 'static,
818 {
819 let inner = Poly::new(inner, allocator)?;
820 Ok(Self {
821 inner: poly!(DynQueryComputer, inner),
822 })
823 }
824
825 pub fn layout(&self) -> QueryLayout {
827 self.inner.layout()
828 }
829
830 pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
832 self.inner
833 }
834}
835
836impl<A> std::fmt::Debug for QueryComputer<A>
837where
838 A: AllocatorCore,
839{
840 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
841 write!(
842 f,
843 "dynamic fused query computer with layout \"{}\"",
844 self.layout()
845 )
846 }
847}
848
849impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
850 for QueryComputer<A>
851where
852 A: AllocatorCore,
853{
854 fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
855 self.inner.evaluate(x)
856 }
857}
858
/// A two-argument distance functor `inner` with its first (query) argument bound.
pub(super) struct Curried<D, Q> {
    inner: D,
    query: Q,
}

impl<D, Q> Curried<D, Q> {
    pub(super) fn new(inner: D, query: Q) -> Self {
        Self { inner, query }
    }
}

// Calling a `&Curried` with `x` forwards to `inner.run(arch, query.reborrow(), x)`,
// turning a `Target2` kernel into a `Target1` one.
impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
where
    A: Architecture,
    Q: for<'a> Reborrow<'a>,
    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
{
    fn run(self, arch: A, x: T) -> R {
        self.inner.run(arch, self.query.reborrow(), x)
    }
}
886
887#[derive(Debug, Error)]
893pub enum DistanceError {
894 #[error("trouble trying to reify the left-hand argument")]
896 QueryReify(InlineError<24>),
897
898 #[error("trouble trying to reify the right-hand argument")]
900 XReify(InlineError<16>),
901
902 #[error("encountered while trying to compute distances")]
906 UnequalLengths(UnequalLengths),
907}
908
909pub trait DynDistanceComputer: Send + Sync {
910 fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
911 fn layout(&self) -> QueryLayout;
912}
913
914pub struct DistanceComputer<A = GlobalAllocator>
925where
926 A: AllocatorCore,
927{
928 inner: Poly<dyn DynDistanceComputer, A>,
929}
930
931impl<A> DistanceComputer<A>
932where
933 A: AllocatorCore,
934{
935 pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
936 where
937 T: DynDistanceComputer + 'static,
938 {
939 let inner = Poly::new(inner, allocator)?;
940 Ok(Self {
941 inner: poly!(DynDistanceComputer, inner),
942 })
943 }
944
945 pub fn layout(&self) -> QueryLayout {
947 self.inner.layout()
948 }
949
950 pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
951 self.inner
952 }
953}
954
955impl<A> std::fmt::Debug for DistanceComputer<A>
956where
957 A: AllocatorCore,
958{
959 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
960 write!(
961 f,
962 "dynamic distance computer with layout \"{}\"",
963 self.layout()
964 )
965 }
966}
967
968impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
969where
970 A: AllocatorCore,
971{
972 fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
973 self.inner.evaluate(query, x)
974 }
975}
976
/// Initial size, in bytes, of the buffer handed to `FlatBufferBuilder` by
/// `Impl::serialize`.
#[cfg(all(not(test), feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1024;

/// Test builds start from a 1-byte buffer.
/// NOTE(review): presumably to force the builder to grow its buffer and exercise
/// that path; confirm.
#[cfg(all(test, feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1;
988
/// `NBITS`-bit spherical quantizer exposed through the type-erased [`Quantizer`]
/// interface.
///
/// Holds the quantizer together with a data-vs-data distance computer that is built
/// once in `Impl::new`.
pub struct Impl<const NBITS: usize, A = GlobalAllocator>
where
    A: Allocator,
{
    quantizer: SphericalQuantizer<A>,
    distance: Poly<dyn DynDistanceComputer, A>,
}

/// Builds the data-vs-data distance computer stored in [`Impl`].
pub trait Constructible<A = GlobalAllocator>
where
    A: Allocator,
{
    fn dispatch_distance(
        quantizer: &SphericalQuantizer<A>,
    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
}
1009
1010impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1011where
1012 A: Allocator,
1013 AsData<NBITS>: FromOpaque,
1014 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1015{
1016 fn dispatch_distance(
1017 quantizer: &SphericalQuantizer<A>,
1018 ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1019 diskann_wide::arch::dispatch2_no_features(
1020 ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1021 quantizer,
1022 quantizer.allocator().clone(),
1023 )
1024 .map(|obj| obj.inner)
1025 }
1026}
1027
1028impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1029where
1030 A: Allocator,
1031 AsData<NBITS>: FromOpaque,
1032 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1033{
1034 fn try_clone(&self) -> Result<Self, AllocatorError> {
1035 Self::new(self.quantizer.try_clone()?)
1036 }
1037}
1038
1039impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1040 pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1042 where
1043 Self: Constructible<A>,
1044 {
1045 let distance = Self::dispatch_distance(&quantizer)?;
1046 Ok(Self {
1047 quantizer,
1048 distance,
1049 })
1050 }
1051
1052 pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1054 &self.quantizer
1055 }
1056
1057 pub fn supports(layout: QueryLayout) -> bool {
1061 if const { NBITS == 1 } {
1062 [
1063 QueryLayout::SameAsData,
1064 QueryLayout::FourBitTransposed,
1065 QueryLayout::FullPrecision,
1066 ]
1067 .contains(&layout)
1068 } else {
1069 [
1070 QueryLayout::SameAsData,
1071 QueryLayout::ScalarQuantized,
1072 QueryLayout::FullPrecision,
1073 ]
1074 .contains(&layout)
1075 }
1076 }
1077
1078 fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1081 where
1082 Q: FromOpaque,
1083 B: AllocatorCore,
1084 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1085 {
1086 diskann_wide::arch::dispatch2_no_features(
1087 ComputerDispatcher::<Q, NBITS>::new(),
1088 &self.quantizer,
1089 allocator,
1090 )
1091 }
1092
1093 fn compress_query<'a, T>(
1094 &self,
1095 query: &'a [f32],
1096 storage: T,
1097 scratch: ScopedAllocator<'a>,
1098 ) -> Result<(), QueryCompressionError>
1099 where
1100 SphericalQuantizer<A>: CompressIntoWith<&'a [f32], T, ScopedAllocator<'a>, Error = quantizer::CompressionError>,
1101 {
1102 self.quantizer
1103 .compress_into_with(query, storage, scratch)
1104 .map_err(|err| CompressionError::CompressionError(err).into())
1105 }
1106
1107 fn fused_query_computer<Q, T, B>(
1110 &self,
1111 query: &[f32],
1112 mut storage: T,
1113 allocator: B,
1114 scratch: ScopedAllocator<'_>,
1115 ) -> Result<QueryComputer<B>, QueryComputerError>
1116 where
1117 Q: FromOpaque,
1118 T: for<'a> ReborrowMut<'a>
1119 + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1120 + ReportQueryLayout
1121 + Send
1122 + Sync
1123 + 'static,
1124 B: AllocatorCore,
1125 SphericalQuantizer<A>: for<'a> CompressIntoWith<
1126 &'a [f32],
1127 <T as ReborrowMut<'a>>::Target,
1128 ScopedAllocator<'a>,
1129 Error = quantizer::CompressionError,
1130 >,
1131 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1132 {
1133 if let Err(err) = self
1134 .quantizer
1135 .compress_into_with(query, storage.reborrow_mut(), scratch)
1136 {
1137 return Err(CompressionError::CompressionError(err).into());
1138 }
1139
1140 diskann_wide::arch::dispatch3_no_features(
1141 ComputerDispatcher::<Q, NBITS>::new(),
1142 &self.quantizer,
1143 storage,
1144 allocator,
1145 )
1146 .map_err(|e| e.into())
1147 }
1148
1149 #[cfg(feature = "flatbuffers")]
1150 fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1151 where
1152 B: Allocator + std::panic::UnwindSafe,
1153 A: std::panic::RefUnwindSafe,
1154 {
1155 let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1156 0u8,
1157 DEFAULT_SERIALIZED_BYTES,
1158 allocator.clone(),
1159 )?);
1160
1161 let quantizer = &self.quantizer;
1162
1163 let (root, mut buf) = match std::panic::catch_unwind(move || {
1164 let offset = quantizer.pack(&mut buf);
1165
1166 let root = fb::spherical::Quantizer::create(
1167 &mut buf,
1168 &fb::spherical::QuantizerArgs {
1169 quantizer: Some(offset),
1170 nbits: NBITS as u32,
1171 },
1172 );
1173 (root, buf)
1174 }) {
1175 Ok(ret) => ret,
1176 Err(err) => match err.downcast_ref::<String>() {
1177 Some(msg) => {
1178 if msg.contains("AllocatorError") {
1179 return Err(AllocatorError);
1180 } else {
1181 std::panic::resume_unwind(err);
1182 }
1183 }
1184 None => std::panic::resume_unwind(err),
1185 },
1186 };
1187
1188 fb::spherical::finish_quantizer_buffer(&mut buf, root);
1190 Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1191 }
1192}
1193
/// Builds type-erased computers for one (architecture `M`, query marker `Q`, data
/// bit width `N`) combination.
///
/// Implemented for `SphericalQuantizer<A>` by the `dispatch_map!` invocations below.
trait BuildComputer<M, Q, const N: usize>
where
    M: Architecture,
    Q: FromOpaque,
{
    /// Build a two-argument computer (query bytes interpreted via `Q`, data via
    /// `AsData<N>`).
    fn build_computer<A>(
        &self,
        arch: M,
        allocator: A,
    ) -> Result<DistanceComputer<A>, AllocatorError>
    where
        A: AllocatorCore;

    /// Build a fused computer with the already-compressed `query` bound in.
    fn build_fused_computer<R, A>(
        &self,
        arch: M,
        query: R,
        allocator: A,
    ) -> Result<QueryComputer<A>, AllocatorError>
    where
        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
        A: AllocatorCore;
}

/// Default architecture conversion for `dispatch_map!`: use the incoming token as-is.
fn identity<T>(x: T) -> T {
    x
}
1245
/// Implements [`BuildComputer`] for `SphericalQuantizer<A>` for one combination of
/// data bit width, query marker and architecture.
///
/// * `$N` – bit width of the compressed data. The data side is always `AsData<$N>`.
/// * `$Q` – [`FromOpaque`] marker for the query representation.
/// * `$arch` – architecture token this impl is selected for.
/// * `$op` – optional function applied to the incoming token before it is stored in
///   the `Reify`. Defaults to `identity`.
macro_rules! dispatch_map {
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                type D = AsData<$N>;

                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // One arm per metric: each uses a different functor type, so each
                // produces a different concrete `Reify` before type erasure.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let reify = Reify::<CompensatedSquaredL2, _, $Q, D>::new(
                            self.as_functor(),
                            dim,
                            arch,
                        );
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let reify =
                            Reify::<CompensatedIP, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let reify =
                            Reify::<CompensatedCosine, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // The query is bound into the functor with `Curried`, so the resulting
                // `Reify` has no left-hand marker (`()`).
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer: CompensatedSquaredL2 = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer: CompensatedIP = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let computer: CompensatedCosine = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
1331
// Portable (scalar) implementations for every supported data width.

// Full-precision queries.
dispatch_map!(1, AsFull, Scalar);
dispatch_map!(2, AsFull, Scalar);
dispatch_map!(4, AsFull, Scalar);
dispatch_map!(8, AsFull, Scalar);

// Queries in the same encoding as the data.
dispatch_map!(1, AsData<1>, Scalar);
dispatch_map!(2, AsData<2>, Scalar);
dispatch_map!(4, AsData<4>, Scalar);
dispatch_map!(8, AsData<8>, Scalar);

// Quantized queries. 1-bit data pairs with 4-bit bit-transposed queries (matching
// `Impl::supports`); the other widths use dense queries of their own width.
dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
dispatch_map!(2, AsQuery<2>, Scalar);
dispatch_map!(4, AsQuery<4>, Scalar);
dispatch_map!(8, AsQuery<8>, Scalar);
1347
cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        /// Convert a V4 token to V3 so V4 dispatch can reuse the V3 kernels.
        fn downcast_to_v3(arch: V4) -> V3 {
            arch.into()
        }

        // V3 implementations.
        dispatch_map!(1, AsFull, V3);
        dispatch_map!(2, AsFull, V3);
        dispatch_map!(4, AsFull, V3);
        dispatch_map!(8, AsFull, V3);

        dispatch_map!(1, AsData<1>, V3);
        dispatch_map!(2, AsData<2>, V3);
        dispatch_map!(4, AsData<4>, V3);
        dispatch_map!(8, AsData<8>, V3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
        dispatch_map!(2, AsQuery<2>, V3);
        dispatch_map!(4, AsQuery<4>, V3);
        dispatch_map!(8, AsQuery<8>, V3);

        // V4 implementations. Most reuse the V3 kernels via `downcast_to_v3`.
        dispatch_map!(1, AsFull, V4, downcast_to_v3);
        dispatch_map!(2, AsFull, V4, downcast_to_v3);
        dispatch_map!(4, AsFull, V4, downcast_to_v3);
        dispatch_map!(8, AsFull, V4, downcast_to_v3);

        // The 2-bit entries below keep the V4 token (no `downcast_to_v3`).
        // NOTE(review): presumably because 2-bit data has V4-specific kernels; confirm
        // before making these match their neighbours.
        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
        dispatch_map!(2, AsData<2>, V4);
        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
        dispatch_map!(2, AsQuery<2>, V4);
        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
    }
}
1387
/// Zero-sized dispatch target that builds computers for query marker `Q` and data
/// bit width `N`.
///
/// It is handed to `diskann_wide::arch::dispatch*_no_features`, and its `run`
/// methods forward to [`BuildComputer`] for the selected architecture.
#[derive(Debug, Clone, Copy)]
struct ComputerDispatcher<Q, const N: usize> {
    _query_type: std::marker::PhantomData<Q>,
}

impl<Q, const N: usize> ComputerDispatcher<Q, N> {
    fn new() -> Self {
        Self {
            _query_type: std::marker::PhantomData,
        }
    }
}

// Two-argument dispatch: builds a `DistanceComputer`.
impl<M, const N: usize, A, B, Q>
    diskann_wide::arch::Target2<
        M,
        Result<DistanceComputer<B>, AllocatorError>,
        &SphericalQuantizer<A>,
        B,
    > for ComputerDispatcher<Q, N>
where
    M: Architecture,
    A: Allocator,
    B: AllocatorCore,
    Q: FromOpaque,
    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
{
    fn run(
        self,
        arch: M,
        quantizer: &SphericalQuantizer<A>,
        allocator: B,
    ) -> Result<DistanceComputer<B>, AllocatorError> {
        quantizer.build_computer(arch, allocator)
    }
}

// Three-argument dispatch: builds a fused `QueryComputer` around `query`.
impl<M, const N: usize, A, R, B, Q>
    diskann_wide::arch::Target3<
        M,
        Result<QueryComputer<B>, AllocatorError>,
        &SphericalQuantizer<A>,
        R,
        B,
    > for ComputerDispatcher<Q, N>
where
    M: Architecture,
    A: Allocator,
    B: AllocatorCore,
    Q: FromOpaque,
    R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
{
    fn run(
        self,
        arch: M,
        quantizer: &SphericalQuantizer<A>,
        query: R,
        allocator: B,
    ) -> Result<QueryComputer<B>, AllocatorError> {
        quantizer.build_fused_computer(arch, query, allocator)
    }
}
1462
/// Requires a [`BuildComputer`] impl for every architecture the target can dispatch
/// to: Scalar, V3 and V4 on x86_64.
#[cfg(target_arch = "x86_64")]
trait Dispatchable<Q, const N: usize>:
    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
where
    Q: FromOpaque,
{
}

// Blanket impl: anything with all required `BuildComputer` impls is `Dispatchable`.
#[cfg(target_arch = "x86_64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
{
}

/// On non-x86_64 targets only the scalar implementation is required.
#[cfg(not(target_arch = "x86_64"))]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
where
    Q: FromOpaque,
{
}

#[cfg(not(target_arch = "x86_64"))]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N>,
{
}
1493
1494impl<A, B> Quantizer<B> for Impl<1, A>
1499where
1500 A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1501 B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1502{
    /// Always 1: this impl is for 1-bit data.
    fn nbits(&self) -> usize {
        1
    }

    /// The quantizer's output (compressed) dimension.
    fn dim(&self) -> usize {
        self.quantizer.output_dim()
    }

    /// The quantizer's input (uncompressed) dimension.
    fn full_dim(&self) -> usize {
        self.quantizer.input_dim()
    }

    /// Canonical byte size of one 1-bit compressed vector.
    fn bytes(&self) -> usize {
        DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
    }

    /// Fresh data-vs-data computer (queries in the data encoding).
    fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
        self.query_computer::<AsData<1>, _>(allocator)
    }

    /// The data-vs-data computer built in `Impl::new`.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
        &*self.distance
    }
1526
1527 fn query_computer(
1528 &self,
1529 layout: QueryLayout,
1530 allocator: B,
1531 ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1532 match layout {
1533 QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
1534 QueryLayout::FourBitTransposed => {
1535 Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
1536 }
1537 QueryLayout::ScalarQuantized => {
1538 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1539 }
1540 QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1541 }
1542 }
1543
1544 fn query_buffer_description(
1545 &self,
1546 layout: QueryLayout,
1547 ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
1548 let dim = <Self as Quantizer<B>>::dim(self);
1549 match layout {
1550 QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1551 DataRef::<1>::canonical_bytes(dim),
1552 PowerOfTwo::alignment_of::<u8>(),
1553 )),
1554 QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
1555 QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
1556 PowerOfTwo::alignment_of::<u8>(),
1557 )),
1558 QueryLayout::ScalarQuantized => {
1559 Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
1560 }
1561 QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1562 FullQueryRef::canonical_bytes(dim),
1563 FullQueryRef::canonical_align(),
1564 )),
1565 }
1566 }
1567
1568 fn compress_query(
1569 &self,
1570 x: &[f32],
1571 layout: QueryLayout,
1572 allow_rescale: bool,
1573 mut buffer: OpaqueMut<'_>,
1574 scratch: ScopedAllocator<'_>,
1575 ) -> Result<(), QueryCompressionError> {
1576 let dim = <Self as Quantizer<B>>::dim(self);
1577 let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1578 match layout {
1579 QueryLayout::SameAsData => self.compress_query(
1580 v,
1581 DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
1582 .map_err(NotCanonical::new)?,
1583 scratch,
1584 ),
1585 QueryLayout::FourBitTransposed => self.compress_query(
1586 v,
1587 QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
1588 .map_err(NotCanonical::new)?,
1589 scratch,
1590 ),
1591 QueryLayout::ScalarQuantized => {
1592 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1593 }
1594 QueryLayout::FullPrecision => self.compress_query(
1595 v,
1596 FullQueryMut::from_canonical_mut(&mut buffer, dim)
1597 .map_err(NotCanonical::new)?,
1598 scratch,
1599 ),
1600 }
1601 };
1602
1603 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1604 let mut copy = x.to_owned();
1605 self.quantizer.rescale(&mut copy);
1606 finish(©)
1607 } else {
1608 finish(x)
1609 }
1610 }
1611
1612 fn fused_query_computer(
1613 &self,
1614 x: &[f32],
1615 layout: QueryLayout,
1616 allow_rescale: bool,
1617 allocator: B,
1618 scratch: ScopedAllocator<'_>,
1619 ) -> Result<QueryComputer<B>, QueryComputerError> {
1620 let dim = <Self as Quantizer<B>>::dim(self);
1621 let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
1622 match layout {
1623 QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
1624 v,
1625 Data::new_in(dim, allocator.clone())?,
1626 allocator,
1627 scratch,
1628 ),
1629 QueryLayout::FourBitTransposed => self
1630 .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
1631 v,
1632 Query::new_in(dim, allocator.clone())?,
1633 allocator,
1634 scratch,
1635 ),
1636 QueryLayout::ScalarQuantized => {
1637 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1638 }
1639 QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
1640 v,
1641 FullQuery::empty(dim, allocator.clone())?,
1642 allocator,
1643 scratch,
1644 ),
1645 }
1646 };
1647
1648 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1649 let mut copy = x.to_owned();
1650 self.quantizer.rescale(&mut copy);
1651 finish(©, allocator)
1652 } else {
1653 finish(x, allocator)
1654 }
1655 }
1656
1657 fn is_supported(&self, layout: QueryLayout) -> bool {
1658 Self::supports(layout)
1659 }
1660
1661 fn compress(
1662 &self,
1663 x: &[f32],
1664 mut into: OpaqueMut<'_>,
1665 scratch: ScopedAllocator<'_>,
1666 ) -> Result<(), CompressionError> {
1667 let dim = <Self as Quantizer<B>>::dim(self);
1668 let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
1669 .map_err(CompressionError::not_canonical)?;
1670 self.quantizer
1671 .compress_into_with(x, into, scratch)
1672 .map_err(CompressionError::CompressionError)
1673 }
1674
1675 fn metric(&self) -> SupportedMetric {
1676 self.quantizer.metric()
1677 }
1678
1679 fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
1680 let clone = (*self).try_clone()?;
1681 poly!({ Quantizer<B> }, clone, allocator)
1682 }
1683
1684 #[cfg(feature = "flatbuffers")]
1685 fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
1686 Impl::<1, A>::serialize(self, allocator)
1687 }
1688}
1689
1690macro_rules! plan {
1691 ($N:literal) => {
1692 impl<A, B> Quantizer<B> for Impl<$N, A>
1693 where
1694 A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1695 B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1696 {
1697 fn nbits(&self) -> usize {
1698 $N
1699 }
1700
1701 fn dim(&self) -> usize {
1702 self.quantizer.output_dim()
1703 }
1704
1705 fn full_dim(&self) -> usize {
1706 self.quantizer.input_dim()
1707 }
1708
1709 fn bytes(&self) -> usize {
1710 DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
1711 }
1712
1713 fn distance_computer(
1714 &self,
1715 allocator: B
1716 ) -> Result<DistanceComputer<B>, AllocatorError> {
1717 self.query_computer::<AsData<$N>, _>(allocator)
1718 }
1719
1720 fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1721 &*self.distance
1722 }
1723
1724 fn query_computer(
1725 &self,
1726 layout: QueryLayout,
1727 allocator: B,
1728 ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1729 match layout {
1730 QueryLayout::SameAsData => Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
1731 ,
1732 QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
1733 layout,
1734 concat!($N, "-bit compression"),
1735 ).into()),
1736 QueryLayout::ScalarQuantized => {
1737 Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
1738 },
1739 QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1740
1741 }
1742 }
1743
1744 fn query_buffer_description(
1745 &self,
1746 layout: QueryLayout
1747 ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>
1748 {
1749 let dim = <Self as Quantizer<B>>::dim(self);
1750 match layout {
1751 QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1752 DataRef::<$N>::canonical_bytes(dim),
1753 PowerOfTwo::alignment_of::<u8>(),
1754 )),
1755 QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout {
1756 layout,
1757 desc: concat!($N, "-bit compression"),
1758 }),
1759 QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
1760 QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
1761 PowerOfTwo::alignment_of::<u8>(),
1762 )),
1763 QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1764 FullQueryRef::canonical_bytes(dim),
1765 FullQueryRef::canonical_align(),
1766 )),
1767 }
1768 }
1769
1770 fn compress_query(
1771 &self,
1772 x: &[f32],
1773 layout: QueryLayout,
1774 allow_rescale: bool,
1775 mut buffer: OpaqueMut<'_>,
1776 scratch: ScopedAllocator<'_>,
1777 ) -> Result<(), QueryCompressionError> {
1778 let dim = <Self as Quantizer<B>>::dim(self);
1779 let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1780 match layout {
1781 QueryLayout::SameAsData => self.compress_query(
1782 v,
1783 DataMut::<$N>::from_canonical_back_mut(
1784 &mut buffer,
1785 dim,
1786 ).map_err(NotCanonical::new)?,
1787 scratch,
1788 ),
1789 QueryLayout::FourBitTransposed => {
1790 Err(UnsupportedQueryLayout::new(
1791 layout,
1792 concat!($N, "-bit compression"),
1793 ).into())
1794 },
1795 QueryLayout::ScalarQuantized => self.compress_query(
1796 v,
1797 QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
1798 &mut buffer,
1799 dim,
1800 ).map_err(NotCanonical::new)?,
1801 scratch,
1802 ),
1803 QueryLayout::FullPrecision => self.compress_query(
1804 v,
1805 FullQueryMut::from_canonical_mut(
1806 &mut buffer,
1807 dim,
1808 ).map_err(NotCanonical::new)?,
1809 scratch,
1810 ),
1811 }
1812 };
1813
1814 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1815 let mut copy = x.to_owned();
1816 self.quantizer.rescale(&mut copy);
1817 finish(©)
1818 } else {
1819 finish(x)
1820 }
1821 }
1822
1823 fn fused_query_computer(
1824 &self,
1825 x: &[f32],
1826 layout: QueryLayout,
1827 allow_rescale: bool,
1828 allocator: B,
1829 scratch: ScopedAllocator<'_>,
1830 ) -> Result<QueryComputer<B>, QueryComputerError>
1831 {
1832 let dim = <Self as Quantizer<B>>::dim(self);
1833 let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
1834 match layout {
1835 QueryLayout::SameAsData => {
1836 self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
1837 v,
1838 Data::new_in(dim, allocator.clone())?,
1839 allocator,
1840 scratch,
1841 )
1842 },
1843 QueryLayout::FourBitTransposed => {
1844 Err(UnsupportedQueryLayout::new(
1845 layout,
1846 concat!($N, "-bit compression"),
1847 ).into())
1848 },
1849 QueryLayout::ScalarQuantized => {
1850 self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
1851 v,
1852 Query::new_in(dim, allocator.clone())?,
1853 allocator,
1854 scratch,
1855 )
1856 },
1857 QueryLayout::FullPrecision => {
1858 self.fused_query_computer::<AsFull, FullQuery<_>, B>(
1859 v,
1860 FullQuery::empty(dim, allocator.clone())?,
1861 allocator,
1862 scratch,
1863 )
1864 },
1865 }
1866 };
1867
1868 let metric = <Self as Quantizer<B>>::metric(self);
1869 if allow_rescale && metric == SupportedMetric::InnerProduct {
1870 let mut copy = x.to_owned();
1871 self.quantizer.rescale(&mut copy);
1872 finish(©)
1873 } else {
1874 finish(x)
1875 }
1876 }
1877
1878 fn is_supported(&self, layout: QueryLayout) -> bool {
1879 Self::supports(layout)
1880 }
1881
1882 fn compress(
1883 &self,
1884 x: &[f32],
1885 mut into: OpaqueMut<'_>,
1886 scratch: ScopedAllocator<'_>,
1887 ) -> Result<(), CompressionError> {
1888 let dim = <Self as Quantizer<B>>::dim(self);
1889 let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
1890 .map_err(CompressionError::not_canonical)?;
1891
1892 self.quantizer.compress_into_with(x, into, scratch)
1893 .map_err(CompressionError::CompressionError)
1894 }
1895
1896 fn metric(&self) -> SupportedMetric {
1897 self.quantizer.metric()
1898 }
1899
1900 fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
1901 let clone = (&*self).try_clone()?;
1902 poly!({ Quantizer<B> }, clone, allocator)
1903 }
1904
1905 #[cfg(feature = "flatbuffers")]
1906 fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
1907 Impl::<$N, A>::serialize(self, allocator)
1908 }
1909 }
1910 };
1911 ($N:literal, $($Ns:literal),*) => {
1912 plan!($N);
1913 $(plan!($Ns);)*
1914 }
1915}
1916
1917plan!(2, 4, 8);
1918
#[cfg(feature = "flatbuffers")]
/// Errors returned by [`try_deserialize`] when reconstructing a quantizer from a flatbuffer.
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer does not carry the spherical quantizer file identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The serialized bit width is not one of the widths `try_deserialize` dispatches on
    /// (1, 2, 4, or 8).
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// The embedded spherical quantizer could not be unpacked, or the plan could not be
    /// constructed from it.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The buffer could not be read as a valid quantizer flatbuffer.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// An allocation failed while building the deserialized quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1943
#[cfg(feature = "flatbuffers")]
/// Reconstruct a type-erased [`Quantizer`] from a buffer produced by `Quantizer::serialize`.
///
/// The returned quantizer is boxed in `alloc`. `O` is the allocator type parameter of the
/// returned `dyn Quantizer<O>`.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if `data` lacks the spherical quantizer
///   file identifier.
/// * [`DeserializationError::InvalidFlatBuffer`] if `data` cannot be read as a quantizer
///   flatbuffer.
/// * [`DeserializationError::UnsupportedBitWidth`] if the encoded bit width is not 1, 2, 4,
///   or 8.
/// * [`DeserializationError::InvalidQuantizer`] if the embedded quantizer cannot be rebuilt.
/// * [`DeserializationError::AllocatorError`] if allocating the result fails.
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    /// Rebuild an `Impl<NBITS, A>` from `proto` directly inside a `Poly` owned by `alloc`,
    /// then erase it to `dyn Quantizer<O>`.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let constructed = Poly::new_with(
            #[inline(never)]
            |alloc| -> Result<_, super::quantizer::DeserializationError> {
                let quantizer = SphericalQuantizer::try_unpack(alloc, proto)?;
                Ok(Impl::new(quantizer)?)
            },
            alloc,
        )
        .map_err(|failure| match failure {
            CompoundError::Allocator(err) => DeserializationError::from(err),
            CompoundError::Constructor(err) => DeserializationError::from(err),
        })?;

        Ok(poly!({ Quantizer<O> }, constructed))
    }

    if !fb::spherical::quantizer_buffer_has_identifier(data) {
        return Err(DeserializationError::InvalidIdentifier);
    }

    let root = fb::spherical::root_as_quantizer(data)?;
    let bit_width = root.nbits();
    let proto = root.quantizer();

    match bit_width {
        1 => unpack_bits::<1, _, _>(proto, alloc),
        2 => unpack_bits::<2, _, _>(proto, alloc),
        4 => unpack_bits::<4, _, _>(proto, alloc),
        8 => unpack_bits::<8, _, _>(proto, alloc),
        unsupported => Err(DeserializationError::UnsupportedBitWidth(unsupported)),
    }
}
2011
2012#[cfg(test)]
2017mod tests {
2018 use diskann_utils::views::{Matrix, MatrixView};
2019 use rand::{SeedableRng, rngs::StdRng};
2020
2021 use super::*;
2022 use crate::{
2023 algorithms::{TransformKind, transforms::TargetDim},
2024 alloc::{AlignedAllocator, GlobalAllocator, Poly},
2025 num::PowerOfTwo,
2026 spherical::PreScale,
2027 };
2028
2029 fn test_plan_1_bit(plan: &dyn Quantizer) {
2034 assert_eq!(
2035 plan.nbits(),
2036 1,
2037 "this test only applies to 1-bit quantization"
2038 );
2039
2040 for layout in QueryLayout::all() {
2042 match layout {
2043 QueryLayout::SameAsData
2044 | QueryLayout::FourBitTransposed
2045 | QueryLayout::FullPrecision => assert!(
2046 plan.is_supported(layout),
2047 "expected {} to be supported",
2048 layout
2049 ),
2050 QueryLayout::ScalarQuantized => assert!(
2051 !plan.is_supported(layout),
2052 "expected {} to not be supported",
2053 layout
2054 ),
2055 }
2056 }
2057 }
2058
2059 fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
2060 assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
2061 assert_eq!(
2062 plan.nbits(),
2063 nbits,
2064 "this test only applies to 1-bit quantization"
2065 );
2066
2067 for layout in QueryLayout::all() {
2069 match layout {
2070 QueryLayout::SameAsData
2071 | QueryLayout::ScalarQuantized
2072 | QueryLayout::FullPrecision => assert!(
2073 plan.is_supported(layout),
2074 "expected {} to be supported",
2075 layout
2076 ),
2077 QueryLayout::FourBitTransposed => assert!(
2078 !plan.is_supported(layout),
2079 "expected {} to not be supported",
2080 layout
2081 ),
2082 }
2083 }
2084 }
2085
    /// Exercise every [`Quantizer`] entry point of `plan` using rows of `dataset`.
    ///
    /// Checks:
    /// * the layout support matrix and data compression;
    /// * the owned and borrowed distance computers, including rejection of mis-sized buffers;
    /// * the error messages produced for unsupported query layouts;
    /// * agreement between the fused and two-step query paths for supported layouts;
    /// * rejection of mis-sized inputs and outputs by `compress`.
    #[inline(never)]
    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
        // The set of supported layouts differs between 1-bit and multi-bit plans.
        if nbits == 1 {
            test_plan_1_bit(plan);
        } else {
            test_plan_n_bit(plan, nbits);
        }

        assert_eq!(plan.full_dim(), dataset.ncols());

        // Compress the first two rows so there is a valid pair of encodings to compare.
        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let scoped_global = ScopedAllocator::global();

        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
            .unwrap();
        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
            .unwrap();

        let f = plan.distance_computer(GlobalAllocator).unwrap();
        let _: f32 = f
            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
            .unwrap();

        // A buffer one byte too short or too long must be rejected. Errors in the first
        // argument surface as `QueryReify`; errors in the second surface as `XReify`.
        let test_errors = |f: &dyn DynDistanceComputer| {
            let err = f
                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            let err = f
                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));

            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));
        };

        // Check both the computer wrapped by the owned `DistanceComputer` and the borrowed one.
        test_errors(&*f.inner);

        let f = plan.distance_computer_ref();
        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
        test_errors(f);

        for layout in QueryLayout::all() {
            if !plan.is_supported(layout) {
                // Every query API must reject an unsupported layout with a message naming
                // both the layout and the bit width.
                let check_message = |msg: &str| {
                    assert!(
                        msg.contains(&(layout.to_string())),
                        "error message ({}) should contain the layout \"{}\"",
                        msg,
                        layout
                    );
                    assert!(
                        msg.contains(&format!("{}", nbits)),
                        "error message ({}) should contain the number of bits \"{}\"",
                        msg,
                        nbits
                    );
                };

                {
                    // Fused query computer.
                    let err = plan
                        .fused_query_computer(
                            dataset.row(1),
                            layout,
                            false,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap_err();

                    let msg = err.to_string();
                    check_message(&msg);
                }

                {
                    // Query buffer description.
                    let err = plan.query_buffer_description(layout).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                {
                    // Query compression. The buffer is empty; the check expects the layout
                    // error message rather than a buffer-size error.
                    let buffer = &mut [];
                    let err = plan
                        .compress_query(
                            dataset.row(1),
                            layout,
                            true,
                            OpaqueMut::new(buffer),
                            scoped_global,
                        )
                        .unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                {
                    // Standalone query computer.
                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                continue;
            }

            // Supported layout: the fused computer must report the layout it was built with.
            let g = plan
                .fused_query_computer(
                    dataset.row(1),
                    layout,
                    false,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            assert_eq!(
                g.layout(),
                layout,
                "the query computer should faithfully preserve the requested layout"
            );

            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();

            {
                // Mis-sized data buffers must be rejected by the fused computer.
                let err = g
                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));

                let err = g
                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));
            }

            // Two-step path: compress the same query into a buffer sized and aligned per
            // `query_buffer_description`, then evaluate it with a standalone computer.
            let sizes = plan.query_buffer_description(layout).unwrap();
            let mut buf =
                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();

            plan.compress_query(
                dataset.row(1),
                layout,
                false,
                OpaqueMut::new(&mut buf),
                scoped_global,
            )
            .unwrap();

            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();

            assert_eq!(
                standalone.layout(),
                layout,
                "the standalone computer did not preserve the requested layout",
            );

            let indirect: f32 = standalone
                .evaluate_similarity(Opaque(&buf), Opaque(&a))
                .unwrap();

            // Both query paths must agree exactly.
            assert_eq!(
                direct, indirect,
                "the two different query computation APIs did not return the same result"
            );

            // A query shorter than `full_dim` must be rejected.
            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
            assert!(
                plan.fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                    .is_err()
            );
        }

        {
            // `compress` must reject output buffers of the wrong size and inputs of the
            // wrong length.
            let mut too_small = vec![u8::default(); plan.bytes() - 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                    .is_err()
            );

            let mut too_big = vec![u8::default(); plan.bytes() + 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                    .is_err()
            );

            let mut just_right = vec![u8::default(); plan.bytes()];
            assert!(
                plan.compress(
                    &dataset.row(0)[..dataset.ncols() - 1],
                    OpaqueMut(&mut just_right),
                    scoped_global
                )
                .is_err()
            );
        }
    }
2306
2307 fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
2308 where
2309 Impl<NBITS>: Constructible,
2310 {
2311 let data = test_dataset();
2312 let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);
2313
2314 let quantizer = SphericalQuantizer::train(
2315 data.as_view(),
2316 TransformKind::PaddingHadamard {
2317 target_dim: TargetDim::Natural,
2318 },
2319 metric,
2320 PreScale::None,
2321 &mut rng,
2322 GlobalAllocator,
2323 )
2324 .unwrap();
2325
2326 (Impl::<NBITS>::new(quantizer).unwrap(), data)
2327 }
2328
2329 #[test]
2330 fn test_plan_1bit_l2() {
2331 let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2332 test_plan(&plan, 1, data.as_view());
2333 }
2334
2335 #[test]
2336 fn test_plan_1bit_ip() {
2337 let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2338 test_plan(&plan, 1, data.as_view());
2339 }
2340
2341 #[test]
2342 fn test_plan_1bit_cosine() {
2343 let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
2344 test_plan(&plan, 1, data.as_view());
2345 }
2346
2347 #[test]
2348 fn test_plan_2bit_l2() {
2349 let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2350 test_plan(&plan, 2, data.as_view());
2351 }
2352
2353 #[test]
2354 fn test_plan_2bit_ip() {
2355 let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2356 test_plan(&plan, 2, data.as_view());
2357 }
2358
2359 #[test]
2360 fn test_plan_2bit_cosine() {
2361 let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
2362 test_plan(&plan, 2, data.as_view());
2363 }
2364
2365 #[test]
2366 fn test_plan_4bit_l2() {
2367 let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2368 test_plan(&plan, 4, data.as_view());
2369 }
2370
2371 #[test]
2372 fn test_plan_4bit_ip() {
2373 let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2374 test_plan(&plan, 4, data.as_view());
2375 }
2376
2377 #[test]
2378 fn test_plan_4bit_cosine() {
2379 let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
2380 test_plan(&plan, 4, data.as_view());
2381 }
2382
2383 #[test]
2384 fn test_plan_8bit_l2() {
2385 let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2386 test_plan(&plan, 8, data.as_view());
2387 }
2388
2389 #[test]
2390 fn test_plan_8bit_ip() {
2391 let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2392 test_plan(&plan, 8, data.as_view());
2393 }
2394
2395 #[test]
2396 fn test_plan_8bit_cosine() {
2397 let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
2398 test_plan(&plan, 8, data.as_view());
2399 }
2400
    /// Small fixed dataset used to train and exercise every plan in these tests.
    ///
    /// The 128 values are laid out in groups of 8 and handed to `Matrix::try_from` with
    /// `16, 8`. Presumably that means 16 rows of 8 dimensions; confirm against
    /// `Matrix::try_from`'s argument order.
    fn test_dataset() -> Matrix<f32> {
        let data = vec![
            0.28657,
            -0.0318168,
            0.0666847,
            0.0329265,
            -0.00829283,
            0.168735,
            -0.000846311,
            -0.360779,

            -0.0968938,
            0.161921,
            -0.0979579,
            0.102228,
            -0.259928,
            -0.139634,
            0.165384,
            -0.293443,

            0.130205,
            0.265737,
            0.401816,
            -0.407552,
            0.13012,
            -0.0475244,
            0.511723,
            -0.4372,

            -0.0979126,
            0.135861,
            -0.0154144,
            -0.14047,
            -0.0250029,
            -0.190279,
            0.407283,
            -0.389184,

            -0.264153,
            0.0696822,
            -0.145585,
            0.370284,
            0.186825,
            -0.140736,
            0.274703,
            -0.334563,

            0.247613,
            0.513165,
            -0.0845867,
            0.0532264,
            -0.00480601,
            -0.122408,
            0.47227,
            -0.268301,

            0.103198,
            0.30756,
            -0.316293,
            -0.0686877,
            -0.330729,
            -0.461997,
            0.550857,
            -0.240851,

            0.128258,
            0.786291,
            -0.0268103,
            0.111763,
            -0.308962,
            -0.17407,
            0.437154,
            -0.159879,

            0.00374063,
            0.490301,
            0.0327826,
            -0.0340962,
            -0.118605,
            0.163879,
            0.2737,
            -0.299942,

            -0.284077,
            0.249377,
            -0.0307734,
            -0.0661631,
            0.233854,
            0.427987,
            0.614132,
            -0.288649,

            -0.109492,
            0.203939,
            -0.73956,
            -0.130748,
            0.22072,
            0.0647836,
            0.328726,
            -0.374602,

            -0.223114,
            0.0243489,
            0.109195,
            -0.416914,
            0.0201052,
            -0.0190542,
            0.947078,
            -0.333229,

            -0.165869,
            -0.00296729,
            -0.414378,
            0.231321,
            0.205365,
            0.161761,
            0.148608,
            -0.395063,

            -0.0498255,
            0.193279,
            -0.110946,
            -0.181174,
            -0.274578,
            -0.227511,
            0.190208,
            -0.256174,

            -0.188106,
            -0.0292958,
            0.0930939,
            0.0558456,
            0.257437,
            0.685481,
            0.307922,
            -0.320006,

            0.250035,
            0.275942,
            -0.0856306,
            -0.352027,
            -0.103509,
            -0.00890859,
            0.276121,
            -0.324718,
        ];

        Matrix::try_from(data.into(), 16, 8).unwrap()
    }
2535
2536 #[cfg(feature = "flatbuffers")]
2537 mod serialization {
2538 use std::sync::{
2539 Arc,
2540 atomic::{AtomicBool, Ordering},
2541 };
2542
2543 use super::*;
2544 use crate::alloc::{BumpAllocator, GlobalAllocator};
2545
2546 #[inline(never)]
2547 fn test_plan_serialization(
2548 quantizer: &dyn Quantizer,
2549 nbits: usize,
2550 dataset: MatrixView<f32>,
2551 ) {
2552 assert_eq!(quantizer.full_dim(), dataset.ncols());
2554 let scoped_global = ScopedAllocator::global();
2555
2556 let serialized = quantizer.serialize(GlobalAllocator).unwrap();
2557 let deserialized =
2558 try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();
2559
2560 assert_eq!(deserialized.nbits(), nbits);
2561 assert_eq!(deserialized.bytes(), quantizer.bytes());
2562 assert_eq!(deserialized.dim(), quantizer.dim());
2563 assert_eq!(deserialized.full_dim(), quantizer.full_dim());
2564 assert_eq!(deserialized.metric(), quantizer.metric());
2565
2566 for layout in QueryLayout::all() {
2567 assert_eq!(
2568 deserialized.is_supported(layout),
2569 quantizer.is_supported(layout)
2570 );
2571 }
2572
2573 let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
2575 {
2576 let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2577 let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2578
2579 for row in dataset.row_iter() {
2580 quantizer
2581 .compress(row, OpaqueMut::new(&mut a), scoped_global)
2582 .unwrap();
2583 deserialized
2584 .compress(row, OpaqueMut::new(&mut b), scoped_global)
2585 .unwrap();
2586
2587 assert_eq!(a, b);
2589 }
2590 }
2591
2592 {
2594 let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2595 let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2596 let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2597 let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2598
2599 let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
2600 let q_computer_ref = quantizer.distance_computer_ref();
2601 let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
2602 let d_computer_ref = deserialized.distance_computer_ref();
2603
2604 for r0 in dataset.row_iter() {
2605 quantizer
2606 .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
2607 .unwrap();
2608 deserialized
2609 .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
2610 .unwrap();
2611 for r1 in dataset.row_iter() {
2612 quantizer
2613 .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
2614 .unwrap();
2615 deserialized
2616 .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
2617 .unwrap();
2618
2619 let a0 = Opaque::new(&a0);
2620 let a1 = Opaque::new(&a1);
2621
2622 let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
2623 let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();
2624
2625 assert_eq!(q_computer_dist, d_computer_dist);
2626
2627 let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();
2628
2629 assert_eq!(q_computer_dist, q_computer_ref_dist);
2630
2631 let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
2632 assert_eq!(d_computer_dist, d_computer_ref_dist);
2633 }
2634 }
2635 }
2636
2637 {
2639 let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2640 let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2641
2642 for layout in QueryLayout::all() {
2643 if !quantizer.is_supported(layout) {
2644 continue;
2645 }
2646
2647 for r in dataset.row_iter() {
2648 let q_computer = quantizer
2649 .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2650 .unwrap();
2651 let d_computer = deserialized
2652 .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2653 .unwrap();
2654
2655 for u in dataset.row_iter() {
2656 quantizer
2657 .compress(u, OpaqueMut::new(&mut a), scoped_global)
2658 .unwrap();
2659 deserialized
2660 .compress(u, OpaqueMut::new(&mut b), scoped_global)
2661 .unwrap();
2662
2663 assert_eq!(
2664 q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
2665 d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
2666 );
2667 }
2668 }
2669 }
2670 }
2671 }
2672
2673 #[derive(Debug, Clone)]
2675 struct FlakyAllocator {
2676 have_allocated: Arc<AtomicBool>,
2677 }
2678
2679 impl FlakyAllocator {
2680 fn new(have_allocated: Arc<AtomicBool>) -> Self {
2681 Self { have_allocated }
2682 }
2683 }
2684
2685 unsafe impl AllocatorCore for FlakyAllocator {
2687 fn allocate(
2688 &self,
2689 layout: std::alloc::Layout,
2690 ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
2691 if self.have_allocated.swap(true, Ordering::Relaxed) {
2692 Err(AllocatorError)
2693 } else {
2694 GlobalAllocator.allocate(layout)
2695 }
2696 }
2697
2698 unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
2699 unsafe { GlobalAllocator.deallocate(ptr, layout) }
2701 }
2702 }
2703
2704 fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
2705 where
2706 Impl<NBITS>: Quantizer,
2707 {
2708 let have_allocated = Arc::new(AtomicBool::new(false));
2710 let _: AllocatorError = v
2711 .serialize(FlakyAllocator::new(have_allocated.clone()))
2712 .unwrap_err();
2713 assert!(have_allocated.load(Ordering::Relaxed));
2714 }
2715
2716 #[test]
2717 fn test_plan_1bit_l2() {
2718 let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2719 test_plan_panic_boundary(&plan);
2720 test_plan_serialization(&plan, 1, data.as_view());
2721 }
2722
2723 #[test]
2724 fn test_plan_1bit_ip() {
2725 let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2726 test_plan_panic_boundary(&plan);
2727 test_plan_serialization(&plan, 1, data.as_view());
2728 }
2729
2730 #[test]
2731 fn test_plan_2bit_l2() {
2732 let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2733 test_plan_panic_boundary(&plan);
2734 test_plan_serialization(&plan, 2, data.as_view());
2735 }
2736
2737 #[test]
2738 fn test_plan_2bit_ip() {
2739 let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2740 test_plan_panic_boundary(&plan);
2741 test_plan_serialization(&plan, 2, data.as_view());
2742 }
2743
2744 #[test]
2745 fn test_plan_4bit_l2() {
2746 let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2747 test_plan_panic_boundary(&plan);
2748 test_plan_serialization(&plan, 4, data.as_view());
2749 }
2750
2751 #[test]
2752 fn test_plan_4bit_ip() {
2753 let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2754 test_plan_panic_boundary(&plan);
2755 test_plan_serialization(&plan, 4, data.as_view());
2756 }
2757
2758 #[test]
2759 fn test_plan_8bit_l2() {
2760 let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2761 test_plan_panic_boundary(&plan);
2762 test_plan_serialization(&plan, 8, data.as_view());
2763 }
2764
2765 #[test]
2766 fn test_plan_8bit_ip() {
2767 let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2768 test_plan_panic_boundary(&plan);
2769 test_plan_serialization(&plan, 8, data.as_view());
2770 }
2771
2772 #[test]
2773 fn test_allocation_order() {
2774 let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
2775 let buf = plan.serialize(GlobalAllocator).unwrap();
2776
2777 let allocator = BumpAllocator::new(8192, PowerOfTwo::new(64).unwrap()).unwrap();
2778 let deserialized =
2779 try_deserialize::<GlobalAllocator, _>(&buf, allocator.clone()).unwrap();
2780 assert_eq!(
2781 Poly::as_ptr(&deserialized).cast::<u8>(),
2782 allocator.as_ptr(),
2783 "expected the returned box to be allocated first",
2784 );
2785 }
2786 }
2787}