1use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141 arch::{Scalar, Target1, Target2},
142 Architecture,
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151use super::{
152 quantizer, CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef,
153 FullQuery, FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer,
154 SupportedMetric,
155};
156#[cfg(feature = "flatbuffers")]
157use crate::{alloc::CompoundError, flatbuffers as fb};
158use crate::{
159 alloc::{
160 Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
161 },
162 bits::{self, Representation, Unsigned},
163 distances::{self, UnequalLengths},
164 error::InlineError,
165 meta,
166 num::PowerOfTwo,
167 poly, AsFunctor, CompressIntoWith,
168};
169
170type Rf32 = distances::Result<f32>;
172
173#[derive(Debug, Clone)]
179pub struct QueryBufferDescription {
180 size: usize,
181 align: PowerOfTwo,
182}
183
184impl QueryBufferDescription {
185 pub fn new(size: usize, align: PowerOfTwo) -> Self {
187 Self { size, align }
188 }
189
190 pub fn bytes(&self) -> usize {
192 self.size
193 }
194
195 pub fn align(&self) -> PowerOfTwo {
197 self.align
198 }
199}
200
201pub trait Quantizer<A = GlobalAllocator>: Send + Sync
223where
224 A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
225{
226 fn nbits(&self) -> usize;
228
229 fn bytes(&self) -> usize;
231
232 fn dim(&self) -> usize;
234
235 fn full_dim(&self) -> usize;
237
238 fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;
246
247 fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;
252
253 fn query_computer(
263 &self,
264 layout: QueryLayout,
265 allocator: A,
266 ) -> Result<DistanceComputer<A>, DistanceComputerError>;
267
268 fn query_buffer_description(
273 &self,
274 layout: QueryLayout,
275 ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;
276
277 fn compress_query(
284 &self,
285 x: &[f32],
286 layout: QueryLayout,
287 allow_rescale: bool,
288 buffer: OpaqueMut<'_>,
289 scratch: ScopedAllocator<'_>,
290 ) -> Result<(), QueryCompressionError>;
291
292 fn fused_query_computer(
299 &self,
300 x: &[f32],
301 layout: QueryLayout,
302 allow_rescale: bool,
303 allocator: A,
304 scratch: ScopedAllocator<'_>,
305 ) -> Result<QueryComputer<A>, QueryComputerError>;
306
307 fn is_supported(&self, layout: QueryLayout) -> bool;
309
310 fn compress(
317 &self,
318 x: &[f32],
319 into: OpaqueMut<'_>,
320 scratch: ScopedAllocator<'_>,
321 ) -> Result<(), CompressionError>;
322
323 fn metric(&self) -> SupportedMetric;
325
326 fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;
328
329 crate::utils::features! {
330 #![feature = "flatbuffers"]
331 fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
334 }
335}
336
337#[derive(Debug, Error)]
338#[error("Layout {layout} is not supported for {desc}")]
339pub struct UnsupportedQueryLayout {
340 layout: QueryLayout,
341 desc: &'static str,
342}
343
344impl UnsupportedQueryLayout {
345 fn new(layout: QueryLayout, desc: &'static str) -> Self {
346 Self { layout, desc }
347 }
348}
349
350#[derive(Debug, Error)]
351#[non_exhaustive]
352pub enum DistanceComputerError {
353 #[error(transparent)]
354 UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
355 #[error(transparent)]
356 AllocatorError(#[from] AllocatorError),
357}
358
359#[derive(Debug, Error)]
360#[non_exhaustive]
361pub enum QueryCompressionError {
362 #[error(transparent)]
363 UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
364 #[error(transparent)]
365 CompressionError(#[from] CompressionError),
366 #[error(transparent)]
367 NotCanonical(#[from] NotCanonical),
368 #[error(transparent)]
369 AllocatorError(#[from] AllocatorError),
370}
371
372#[derive(Debug, Error)]
373#[non_exhaustive]
374pub enum QueryComputerError {
375 #[error(transparent)]
376 UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
377 #[error(transparent)]
378 CompressionError(#[from] CompressionError),
379 #[error(transparent)]
380 AllocatorError(#[from] AllocatorError),
381}
382
383#[derive(Debug, Error)]
385#[error("Error occured during query compression")]
386pub enum CompressionError {
387 NotCanonical(#[source] InlineError<16>),
389
390 CompressionError(#[source] quantizer::CompressionError),
394}
395
396impl CompressionError {
397 fn not_canonical<E>(error: E) -> Self
398 where
399 E: std::error::Error + Send + Sync + 'static,
400 {
401 Self::NotCanonical(InlineError::new(error))
402 }
403}
404
405#[derive(Debug, Error)]
406#[error("An opaque argument did not have the required alignment or length")]
407pub struct NotCanonical {
408 source: Box<dyn std::error::Error + Send + Sync>,
409}
410
411impl NotCanonical {
412 fn new<E>(err: E) -> Self
413 where
414 E: std::error::Error + Send + Sync + 'static,
415 {
416 Self {
417 source: Box::new(err),
418 }
419 }
420}
421
/// A borrowed byte slice whose interpretation is left to the consumer.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Opaque<'a>(&'a [u8]);
431
432impl<'a> Opaque<'a> {
433 pub fn new(slice: &'a [u8]) -> Self {
435 Self(slice)
436 }
437
438 pub fn into_inner(self) -> &'a [u8] {
440 self.0
441 }
442}
443
444impl std::ops::Deref for Opaque<'_> {
445 type Target = [u8];
446 fn deref(&self) -> &[u8] {
447 self.0
448 }
449}
450impl<'short> Reborrow<'short> for Opaque<'_> {
451 type Target = Opaque<'short>;
452 fn reborrow(&'short self) -> Self::Target {
453 *self
454 }
455}
456
/// A mutably borrowed byte slice whose interpretation is left to the consumer.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);
462
463impl<'a> OpaqueMut<'a> {
464 pub fn new(slice: &'a mut [u8]) -> Self {
466 Self(slice)
467 }
468
469 pub fn inspect(&mut self) -> &mut [u8] {
471 self.0
472 }
473}
474
475impl std::ops::Deref for OpaqueMut<'_> {
476 type Target = [u8];
477 fn deref(&self) -> &[u8] {
478 self.0
479 }
480}
481
482impl std::ops::DerefMut for OpaqueMut<'_> {
483 fn deref_mut(&mut self) -> &mut [u8] {
484 self.0
485 }
486}
487
/// The in-memory representation used for a query vector.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryLayout {
    /// The query uses the same representation as the compressed data.
    SameAsData,

    /// The query is a 4-bit quantized vector stored with the bit-transposed permutation.
    FourBitTransposed,

    /// The query is a scalar quantized vector stored densely.
    ScalarQuantized,

    /// The query is kept at full precision.
    FullPrecision,
}
510
511impl QueryLayout {
512 #[cfg(test)]
513 fn all() -> [Self; 4] {
514 [
515 Self::SameAsData,
516 Self::FourBitTransposed,
517 Self::ScalarQuantized,
518 Self::FullPrecision,
519 ]
520 }
521}
522
523impl std::fmt::Display for QueryLayout {
524 fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525 <Self as std::fmt::Debug>::fmt(self, fmt)
526 }
527}
528
529trait ReportQueryLayout {
539 fn report_query_layout(&self) -> QueryLayout;
540}
541
542impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
543where
544 T: ReportQueryLayout,
545{
546 fn report_query_layout(&self) -> QueryLayout {
547 self.inner.report_query_layout()
548 }
549}
550
551impl<D, Q> ReportQueryLayout for Curried<D, Q>
552where
553 Q: ReportQueryLayout,
554{
555 fn report_query_layout(&self) -> QueryLayout {
556 self.query.report_query_layout()
557 }
558}
559
560impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
561where
562 Unsigned: Representation<NBITS>,
563 A: AllocatorCore,
564{
565 fn report_query_layout(&self) -> QueryLayout {
566 QueryLayout::SameAsData
567 }
568}
569
570impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
571where
572 Unsigned: Representation<NBITS>,
573 A: AllocatorCore,
574{
575 fn report_query_layout(&self) -> QueryLayout {
576 QueryLayout::ScalarQuantized
577 }
578}
579
580impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
581where
582 A: AllocatorCore,
583{
584 fn report_query_layout(&self) -> QueryLayout {
585 QueryLayout::FourBitTransposed
586 }
587}
588
589impl<A> ReportQueryLayout for FullQuery<A>
590where
591 A: AllocatorCore,
592{
593 fn report_query_layout(&self) -> QueryLayout {
594 QueryLayout::FullPrecision
595 }
596}
597
598trait FromOpaque: 'static + Send + Sync {
607 type Target<'a>;
608 type Error: std::error::Error + Send + Sync + 'static;
609
610 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
611}
612
613#[derive(Debug, Default)]
615pub(super) struct AsFull;
616
617#[derive(Debug, Default)]
619pub(super) struct AsData<const NBITS: usize>;
620
621#[derive(Debug)]
623pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
624 _marker: PhantomData<Perm>,
625}
626
627impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
629 fn default() -> Self {
630 Self {
631 _marker: PhantomData,
632 }
633 }
634}
635
636impl FromOpaque for AsFull {
637 type Target<'a> = FullQueryRef<'a>;
638 type Error = meta::slice::NotCanonical;
639
640 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
641 Self::Target::from_canonical(query.into_inner(), dim)
642 }
643}
644
645impl ReportQueryLayout for AsFull {
646 fn report_query_layout(&self) -> QueryLayout {
647 QueryLayout::FullPrecision
648 }
649}
650
651impl<const NBITS: usize> FromOpaque for AsData<NBITS>
652where
653 Unsigned: Representation<NBITS>,
654{
655 type Target<'a> = DataRef<'a, NBITS>;
656 type Error = meta::NotCanonical;
657
658 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
659 Self::Target::from_canonical_back(query.into_inner(), dim)
660 }
661}
662
663impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
664 fn report_query_layout(&self) -> QueryLayout {
665 QueryLayout::SameAsData
666 }
667}
668
669impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
670where
671 Unsigned: Representation<NBITS>,
672 Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
673{
674 type Target<'a> = QueryRef<'a, NBITS, Perm>;
675 type Error = meta::NotCanonical;
676
677 fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
678 Self::Target::from_canonical_back(query.into_inner(), dim)
679 }
680}
681
682impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
683 fn report_query_layout(&self) -> QueryLayout {
684 QueryLayout::ScalarQuantized
685 }
686}
687
688impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
689 fn report_query_layout(&self) -> QueryLayout {
690 QueryLayout::FourBitTransposed
691 }
692}
693
694pub(super) struct Reify<T, M, L, R> {
700 inner: T,
701 dim: usize,
702 arch: M,
703 _markers: PhantomData<(L, R)>,
704}
705
706impl<T, M, L, R> Reify<T, M, L, R> {
707 pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
708 Self {
709 inner,
710 dim,
711 arch,
712 _markers: PhantomData,
713 }
714 }
715}
716
717impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
718where
719 M: Architecture,
720 R: FromOpaque,
721 T: ReportQueryLayout + Send + Sync,
722 for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
723{
724 fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
725 self.arch.run2(
726 |this: &Self, x| {
727 let x = R::from_opaque(x, this.dim)
728 .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
729 this.arch
730 .run1(&this.inner, x)
731 .map_err(QueryDistanceError::UnequalLengths)
732 },
733 self,
734 x,
735 )
736 }
737
738 fn layout(&self) -> QueryLayout {
739 self.inner.report_query_layout()
740 }
741}
742
743impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
744where
745 M: Architecture,
746 Q: FromOpaque + Default + ReportQueryLayout,
747 R: FromOpaque,
748 T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
749{
750 fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
751 self.arch.run3(
752 |this: &Self, query, x| {
753 let query = Q::from_opaque(query, this.dim)
754 .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;
755
756 let x = R::from_opaque(x, this.dim)
757 .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;
758
759 this.arch
760 .run2_inline(this.inner, query, x)
761 .map_err(DistanceError::UnequalLengths)
762 },
763 self,
764 query,
765 x,
766 )
767 }
768
769 fn layout(&self) -> QueryLayout {
770 Q::default().report_query_layout()
771 }
772}
773
774#[derive(Debug, Error)]
780pub enum QueryDistanceError {
781 #[error("trouble trying to reify the argument")]
783 XReify(#[source] InlineError<16>),
784
785 #[error("encountered while trying to compute distances")]
787 UnequalLengths(#[source] UnequalLengths),
788}
789
790pub trait DynQueryComputer: Send + Sync {
791 fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
792 fn layout(&self) -> QueryLayout;
793}
794
795pub struct QueryComputer<A = GlobalAllocator>
804where
805 A: AllocatorCore,
806{
807 inner: Poly<dyn DynQueryComputer, A>,
808}
809
810impl<A> QueryComputer<A>
811where
812 A: AllocatorCore,
813{
814 fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
815 where
816 T: DynQueryComputer + 'static,
817 {
818 let inner = Poly::new(inner, allocator)?;
819 Ok(Self {
820 inner: poly!(DynQueryComputer, inner),
821 })
822 }
823
824 pub fn layout(&self) -> QueryLayout {
826 self.inner.layout()
827 }
828
829 pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
831 self.inner
832 }
833}
834
835impl<A> std::fmt::Debug for QueryComputer<A>
836where
837 A: AllocatorCore,
838{
839 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840 write!(
841 f,
842 "dynamic fused query computer with layout \"{}\"",
843 self.layout()
844 )
845 }
846}
847
848impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
849 for QueryComputer<A>
850where
851 A: AllocatorCore,
852{
853 fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
854 self.inner.evaluate(x)
855 }
856}
857
858pub(super) struct Curried<D, Q> {
865 inner: D,
866 query: Q,
867}
868
869impl<D, Q> Curried<D, Q> {
870 pub(super) fn new(inner: D, query: Q) -> Self {
871 Self { inner, query }
872 }
873}
874
875impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
876where
877 A: Architecture,
878 Q: for<'a> Reborrow<'a>,
879 D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
880{
881 fn run(self, arch: A, x: T) -> R {
882 self.inner.run(arch, self.query.reborrow(), x)
883 }
884}
885
886#[derive(Debug, Error)]
892pub enum DistanceError {
893 #[error("trouble trying to reify the left-hand argument")]
895 QueryReify(InlineError<24>),
896
897 #[error("trouble trying to reify the right-hand argument")]
899 XReify(InlineError<16>),
900
901 #[error("encountered while trying to compute distances")]
905 UnequalLengths(UnequalLengths),
906}
907
908pub trait DynDistanceComputer: Send + Sync {
909 fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
910 fn layout(&self) -> QueryLayout;
911}
912
913pub struct DistanceComputer<A = GlobalAllocator>
924where
925 A: AllocatorCore,
926{
927 inner: Poly<dyn DynDistanceComputer, A>,
928}
929
930impl<A> DistanceComputer<A>
931where
932 A: AllocatorCore,
933{
934 pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
935 where
936 T: DynDistanceComputer + 'static,
937 {
938 let inner = Poly::new(inner, allocator)?;
939 Ok(Self {
940 inner: poly!(DynDistanceComputer, inner),
941 })
942 }
943
944 pub fn layout(&self) -> QueryLayout {
946 self.inner.layout()
947 }
948
949 pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
950 self.inner
951 }
952}
953
954impl<A> std::fmt::Debug for DistanceComputer<A>
955where
956 A: AllocatorCore,
957{
958 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
959 write!(
960 f,
961 "dynamic distance computer with layout \"{}\"",
962 self.layout()
963 )
964 }
965}
966
967impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
968where
969 A: AllocatorCore,
970{
971 fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
972 self.inner.evaluate(query, x)
973 }
974}
975
/// Initial size, in bytes, of the buffer handed to the flatbuffer builder during
/// serialization (non-test builds).
#[cfg(all(not(test), feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1 << 10;

/// Initial size, in bytes, of the buffer handed to the flatbuffer builder during
/// serialization (test builds).
// NOTE(review): presumably kept tiny so tests exercise buffer growth — confirm.
#[cfg(all(test, feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1;
987
988pub struct Impl<const NBITS: usize, A = GlobalAllocator>
991where
992 A: Allocator,
993{
994 quantizer: SphericalQuantizer<A>,
995 distance: Poly<dyn DynDistanceComputer, A>,
996}
997
998pub trait Constructible<A = GlobalAllocator>
1001where
1002 A: Allocator,
1003{
1004 fn dispatch_distance(
1005 quantizer: &SphericalQuantizer<A>,
1006 ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
1007}
1008
1009impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1010where
1011 A: Allocator,
1012 AsData<NBITS>: FromOpaque,
1013 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1014{
1015 fn dispatch_distance(
1016 quantizer: &SphericalQuantizer<A>,
1017 ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1018 diskann_wide::arch::dispatch2_no_features(
1019 ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1020 quantizer,
1021 quantizer.allocator().clone(),
1022 )
1023 .map(|obj| obj.inner)
1024 }
1025}
1026
1027impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1028where
1029 A: Allocator,
1030 AsData<NBITS>: FromOpaque,
1031 SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1032{
1033 fn try_clone(&self) -> Result<Self, AllocatorError> {
1034 Self::new(self.quantizer.try_clone()?)
1035 }
1036}
1037
1038impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1039 pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1041 where
1042 Self: Constructible<A>,
1043 {
1044 let distance = Self::dispatch_distance(&quantizer)?;
1045 Ok(Self {
1046 quantizer,
1047 distance,
1048 })
1049 }
1050
1051 pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1053 &self.quantizer
1054 }
1055
1056 pub fn supports(layout: QueryLayout) -> bool {
1060 if const { NBITS == 1 } {
1061 [
1062 QueryLayout::SameAsData,
1063 QueryLayout::FourBitTransposed,
1064 QueryLayout::FullPrecision,
1065 ]
1066 .contains(&layout)
1067 } else {
1068 [
1069 QueryLayout::SameAsData,
1070 QueryLayout::ScalarQuantized,
1071 QueryLayout::FullPrecision,
1072 ]
1073 .contains(&layout)
1074 }
1075 }
1076
1077 fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1080 where
1081 Q: FromOpaque,
1082 B: AllocatorCore,
1083 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1084 {
1085 diskann_wide::arch::dispatch2_no_features(
1086 ComputerDispatcher::<Q, NBITS>::new(),
1087 &self.quantizer,
1088 allocator,
1089 )
1090 }
1091
1092 fn compress_query<'a, T>(
1093 &self,
1094 query: &'a [f32],
1095 storage: T,
1096 scratch: ScopedAllocator<'a>,
1097 ) -> Result<(), QueryCompressionError>
1098 where
1099 SphericalQuantizer<A>: CompressIntoWith<
1100 &'a [f32],
1101 T,
1102 ScopedAllocator<'a>,
1103 Error = quantizer::CompressionError,
1104 >,
1105 {
1106 self.quantizer
1107 .compress_into_with(query, storage, scratch)
1108 .map_err(|err| CompressionError::CompressionError(err).into())
1109 }
1110
1111 fn fused_query_computer<Q, T, B>(
1114 &self,
1115 query: &[f32],
1116 mut storage: T,
1117 allocator: B,
1118 scratch: ScopedAllocator<'_>,
1119 ) -> Result<QueryComputer<B>, QueryComputerError>
1120 where
1121 Q: FromOpaque,
1122 T: for<'a> ReborrowMut<'a>
1123 + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1124 + ReportQueryLayout
1125 + Send
1126 + Sync
1127 + 'static,
1128 B: AllocatorCore,
1129 SphericalQuantizer<A>: for<'a> CompressIntoWith<
1130 &'a [f32],
1131 <T as ReborrowMut<'a>>::Target,
1132 ScopedAllocator<'a>,
1133 Error = quantizer::CompressionError,
1134 >,
1135 SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1136 {
1137 if let Err(err) = self
1138 .quantizer
1139 .compress_into_with(query, storage.reborrow_mut(), scratch)
1140 {
1141 return Err(CompressionError::CompressionError(err).into());
1142 }
1143
1144 diskann_wide::arch::dispatch3_no_features(
1145 ComputerDispatcher::<Q, NBITS>::new(),
1146 &self.quantizer,
1147 storage,
1148 allocator,
1149 )
1150 .map_err(|e| e.into())
1151 }
1152
1153 #[cfg(feature = "flatbuffers")]
1154 fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1155 where
1156 B: Allocator + std::panic::UnwindSafe,
1157 A: std::panic::RefUnwindSafe,
1158 {
1159 let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1160 0u8,
1161 DEFAULT_SERIALIZED_BYTES,
1162 allocator.clone(),
1163 )?);
1164
1165 let quantizer = &self.quantizer;
1166
1167 let (root, mut buf) = match std::panic::catch_unwind(move || {
1168 let offset = quantizer.pack(&mut buf);
1169
1170 let root = fb::spherical::Quantizer::create(
1171 &mut buf,
1172 &fb::spherical::QuantizerArgs {
1173 quantizer: Some(offset),
1174 nbits: NBITS as u32,
1175 },
1176 );
1177 (root, buf)
1178 }) {
1179 Ok(ret) => ret,
1180 Err(err) => match err.downcast_ref::<String>() {
1181 Some(msg) => {
1182 if msg.contains("AllocatorError") {
1183 return Err(AllocatorError);
1184 } else {
1185 std::panic::resume_unwind(err);
1186 }
1187 }
1188 None => std::panic::resume_unwind(err),
1189 },
1190 };
1191
1192 fb::spherical::finish_quantizer_buffer(&mut buf, root);
1194 Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1195 }
1196}
1197
1198trait BuildComputer<M, Q, const N: usize>
1215where
1216 M: Architecture,
1217 Q: FromOpaque,
1218{
1219 fn build_computer<A>(
1224 &self,
1225 arch: M,
1226 allocator: A,
1227 ) -> Result<DistanceComputer<A>, AllocatorError>
1228 where
1229 A: AllocatorCore;
1230
1231 fn build_fused_computer<R, A>(
1236 &self,
1237 arch: M,
1238 query: R,
1239 allocator: A,
1240 ) -> Result<QueryComputer<A>, AllocatorError>
1241 where
1242 R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1243 A: AllocatorCore;
1244}
1245
/// Return the argument unchanged.
///
/// Used as the default architecture conversion in `dispatch_map!`.
fn identity<T>(value: T) -> T {
    value
}
1249
// Implements `BuildComputer<$arch, $Q, $N>` for `SphericalQuantizer<A>`.
//
// * `$N`    - bit width of the stored data vectors.
// * `$Q`    - `FromOpaque` marker for the left-hand (query) argument.
// * `$arch` - architecture type accepted by the implementation.
// * `$op`   - optional function applied to the incoming architecture before it is
//             used; defaults to `identity`.
macro_rules! dispatch_map {
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                // The right-hand argument is always a stored data vector.
                type Stored = AsData<$N>;

                let arch = ($op)(input_arch);
                let output_dim = self.output_dim();

                // Select the functor matching the quantizer's metric.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer = Reify::<CompensatedSquaredL2, _, $Q, Stored>::new(
                            self.as_functor(),
                            output_dim,
                            arch,
                        );
                        DistanceComputer::new(computer, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer = Reify::<CompensatedIP, _, $Q, Stored>::new(
                            self.as_functor(),
                            output_dim,
                            arch,
                        );
                        DistanceComputer::new(computer, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let computer = Reify::<CompensatedCosine, _, $Q, Stored>::new(
                            self.as_functor(),
                            output_dim,
                            arch,
                        );
                        DistanceComputer::new(computer, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                // The right-hand argument is always a stored data vector.
                type Stored = AsData<$N>;

                let arch = ($op)(input_arch);
                let output_dim = self.output_dim();

                // Select the functor matching the quantizer's metric and bind the
                // query to it.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let functor: CompensatedSquaredL2 = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let computer = Reify::<_, _, (), Stored>::new(bound, output_dim, arch);
                        QueryComputer::new(computer, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let functor: CompensatedIP = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let computer = Reify::<_, _, (), Stored>::new(bound, output_dim, arch);
                        QueryComputer::new(computer, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let functor: CompensatedCosine = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let computer = Reify::<_, _, (), Stored>::new(bound, output_dim, arch);
                        QueryComputer::new(computer, allocator)
                    }
                }
            }
        }
    };
}
1335
1336dispatch_map!(1, AsFull, Scalar);
1337dispatch_map!(2, AsFull, Scalar);
1338dispatch_map!(4, AsFull, Scalar);
1339dispatch_map!(8, AsFull, Scalar);
1340
1341dispatch_map!(1, AsData<1>, Scalar);
1342dispatch_map!(2, AsData<2>, Scalar);
1343dispatch_map!(4, AsData<4>, Scalar);
1344dispatch_map!(8, AsData<8>, Scalar);
1345
1346dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
1348dispatch_map!(2, AsQuery<2>, Scalar);
1349dispatch_map!(4, AsQuery<4>, Scalar);
1350dispatch_map!(8, AsQuery<8>, Scalar);
1351
1352cfg_if::cfg_if! {
1353 if #[cfg(target_arch = "x86_64")] {
1354 fn downcast_to_v3(arch: V4) -> V3 {
1355 arch.into()
1356 }
1357
1358 dispatch_map!(1, AsFull, V3);
1360 dispatch_map!(2, AsFull, V3);
1361 dispatch_map!(4, AsFull, V3);
1362 dispatch_map!(8, AsFull, V3);
1363
1364 dispatch_map!(1, AsData<1>, V3);
1365 dispatch_map!(2, AsData<2>, V3);
1366 dispatch_map!(4, AsData<4>, V3);
1367 dispatch_map!(8, AsData<8>, V3);
1368
1369 dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
1370 dispatch_map!(2, AsQuery<2>, V3);
1371 dispatch_map!(4, AsQuery<4>, V3);
1372 dispatch_map!(8, AsQuery<8>, V3);
1373
1374 dispatch_map!(1, AsFull, V4, downcast_to_v3);
1376 dispatch_map!(2, AsFull, V4, downcast_to_v3);
1377 dispatch_map!(4, AsFull, V4, downcast_to_v3);
1378 dispatch_map!(8, AsFull, V4, downcast_to_v3);
1379
1380 dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
1381 dispatch_map!(2, AsData<2>, V4); dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
1383 dispatch_map!(8, AsData<8>, V4, downcast_to_v3);
1384
1385 dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
1386 dispatch_map!(2, AsQuery<2>, V4); dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
1388 dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
1389 }
1390}
1391
/// Zero-sized dispatch target that builds computers for the left-hand marker `Q` and
/// bit width `N` once an architecture has been selected.
#[derive(Clone, Copy, Debug)]
struct ComputerDispatcher<Q, const N: usize> {
    /// Records the marker type without storing a value.
    _query_type: PhantomData<Q>,
}
1407
1408impl<Q, const N: usize> ComputerDispatcher<Q, N> {
1409 fn new() -> Self {
1410 Self {
1411 _query_type: std::marker::PhantomData,
1412 }
1413 }
1414}
1415
1416impl<M, const N: usize, A, B, Q>
1417 diskann_wide::arch::Target2<
1418 M,
1419 Result<DistanceComputer<B>, AllocatorError>,
1420 &SphericalQuantizer<A>,
1421 B,
1422 > for ComputerDispatcher<Q, N>
1423where
1424 M: Architecture,
1425 A: Allocator,
1426 B: AllocatorCore,
1427 Q: FromOpaque,
1428 SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1429{
1430 fn run(
1431 self,
1432 arch: M,
1433 quantizer: &SphericalQuantizer<A>,
1434 allocator: B,
1435 ) -> Result<DistanceComputer<B>, AllocatorError> {
1436 quantizer.build_computer(arch, allocator)
1437 }
1438}
1439
1440impl<M, const N: usize, A, R, B, Q>
1441 diskann_wide::arch::Target3<
1442 M,
1443 Result<QueryComputer<B>, AllocatorError>,
1444 &SphericalQuantizer<A>,
1445 R,
1446 B,
1447 > for ComputerDispatcher<Q, N>
1448where
1449 M: Architecture,
1450 A: Allocator,
1451 B: AllocatorCore,
1452 Q: FromOpaque,
1453 R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1454 SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1455{
1456 fn run(
1457 self,
1458 arch: M,
1459 quantizer: &SphericalQuantizer<A>,
1460 query: R,
1461 allocator: B,
1462 ) -> Result<QueryComputer<B>, AllocatorError> {
1463 quantizer.build_fused_computer(arch, query, allocator)
1464 }
1465}
1466
1467#[cfg(target_arch = "x86_64")]
1468trait Dispatchable<Q, const N: usize>:
1469 BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
1470where
1471 Q: FromOpaque,
1472{
1473}
1474
1475#[cfg(target_arch = "x86_64")]
1476impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1477where
1478 Q: FromOpaque,
1479 T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
1480{
1481}
1482
/// Shorthand for "can build computers on every architecture dispatched to"; outside
/// x86_64 only `Scalar` is required.
#[cfg(not(target_arch = "x86_64"))]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
where
    Q: FromOpaque,
{
}
1489
/// Blanket implementation: anything satisfying the supertrait is dispatchable.
#[cfg(not(target_arch = "x86_64"))]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    T: BuildComputer<Scalar, Q, N>,
    Q: FromOpaque,
{
}
1497
1498impl<A, B> Quantizer<B> for Impl<1, A>
1503where
1504 A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1505 B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1506{
1507 fn nbits(&self) -> usize {
1508 1
1509 }
1510
1511 fn dim(&self) -> usize {
1512 self.quantizer.output_dim()
1513 }
1514
1515 fn full_dim(&self) -> usize {
1516 self.quantizer.input_dim()
1517 }
1518
1519 fn bytes(&self) -> usize {
1520 DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
1521 }
1522
1523 fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
1524 self.query_computer::<AsData<1>, _>(allocator)
1525 }
1526
1527 fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1528 &*self.distance
1529 }
1530
1531 fn query_computer(
1532 &self,
1533 layout: QueryLayout,
1534 allocator: B,
1535 ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1536 match layout {
1537 QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
1538 QueryLayout::FourBitTransposed => {
1539 Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
1540 }
1541 QueryLayout::ScalarQuantized => {
1542 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1543 }
1544 QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1545 }
1546 }
1547
1548 fn query_buffer_description(
1549 &self,
1550 layout: QueryLayout,
1551 ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
1552 let dim = <Self as Quantizer<B>>::dim(self);
1553 match layout {
1554 QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1555 DataRef::<1>::canonical_bytes(dim),
1556 PowerOfTwo::alignment_of::<u8>(),
1557 )),
1558 QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
1559 QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
1560 PowerOfTwo::alignment_of::<u8>(),
1561 )),
1562 QueryLayout::ScalarQuantized => {
1563 Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
1564 }
1565 QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1566 FullQueryRef::canonical_bytes(dim),
1567 FullQueryRef::canonical_align(),
1568 )),
1569 }
1570 }
1571
1572 fn compress_query(
1573 &self,
1574 x: &[f32],
1575 layout: QueryLayout,
1576 allow_rescale: bool,
1577 mut buffer: OpaqueMut<'_>,
1578 scratch: ScopedAllocator<'_>,
1579 ) -> Result<(), QueryCompressionError> {
1580 let dim = <Self as Quantizer<B>>::dim(self);
1581 let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1582 match layout {
1583 QueryLayout::SameAsData => self.compress_query(
1584 v,
1585 DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
1586 .map_err(NotCanonical::new)?,
1587 scratch,
1588 ),
1589 QueryLayout::FourBitTransposed => self.compress_query(
1590 v,
1591 QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
1592 .map_err(NotCanonical::new)?,
1593 scratch,
1594 ),
1595 QueryLayout::ScalarQuantized => {
1596 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1597 }
1598 QueryLayout::FullPrecision => self.compress_query(
1599 v,
1600 FullQueryMut::from_canonical_mut(&mut buffer, dim)
1601 .map_err(NotCanonical::new)?,
1602 scratch,
1603 ),
1604 }
1605 };
1606
1607 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1608 let mut copy = x.to_owned();
1609 self.quantizer.rescale(&mut copy);
1610 finish(©)
1611 } else {
1612 finish(x)
1613 }
1614 }
1615
1616 fn fused_query_computer(
1617 &self,
1618 x: &[f32],
1619 layout: QueryLayout,
1620 allow_rescale: bool,
1621 allocator: B,
1622 scratch: ScopedAllocator<'_>,
1623 ) -> Result<QueryComputer<B>, QueryComputerError> {
1624 let dim = <Self as Quantizer<B>>::dim(self);
1625 let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
1626 match layout {
1627 QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
1628 v,
1629 Data::new_in(dim, allocator.clone())?,
1630 allocator,
1631 scratch,
1632 ),
1633 QueryLayout::FourBitTransposed => self
1634 .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
1635 v,
1636 Query::new_in(dim, allocator.clone())?,
1637 allocator,
1638 scratch,
1639 ),
1640 QueryLayout::ScalarQuantized => {
1641 Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1642 }
1643 QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
1644 v,
1645 FullQuery::empty(dim, allocator.clone())?,
1646 allocator,
1647 scratch,
1648 ),
1649 }
1650 };
1651
1652 if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1653 let mut copy = x.to_owned();
1654 self.quantizer.rescale(&mut copy);
1655 finish(©, allocator)
1656 } else {
1657 finish(x, allocator)
1658 }
1659 }
1660
    /// Report whether `layout` is supported, by delegating to the associated
    /// function `Self::supports`.
    fn is_supported(&self, layout: QueryLayout) -> bool {
        Self::supports(layout)
    }
1664
1665 fn compress(
1666 &self,
1667 x: &[f32],
1668 mut into: OpaqueMut<'_>,
1669 scratch: ScopedAllocator<'_>,
1670 ) -> Result<(), CompressionError> {
1671 let dim = <Self as Quantizer<B>>::dim(self);
1672 let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
1673 .map_err(CompressionError::not_canonical)?;
1674 self.quantizer
1675 .compress_into_with(x, into, scratch)
1676 .map_err(CompressionError::CompressionError)
1677 }
1678
    /// Return the metric reported by the wrapped quantizer.
    fn metric(&self) -> SupportedMetric {
        self.quantizer.metric()
    }
1682
    /// Clone this plan with `try_clone` and wrap the clone as a
    /// `Poly<dyn Quantizer<B>, B>` through the `poly!` macro, using `allocator`.
    ///
    /// # Errors
    ///
    /// Forwards the error of `try_clone`; `poly!` presumably reports allocation
    /// failure for the wrapper itself — TODO confirm against the macro definition.
    fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
        let clone = (*self).try_clone()?;
        poly!({ Quantizer<B> }, clone, allocator)
    }
1687
    /// Serialize this plan into a byte buffer obtained from `allocator`.
    ///
    /// Forwards to the inherent `Impl::<1, A>::serialize`; only available with the
    /// `flatbuffers` feature.
    #[cfg(feature = "flatbuffers")]
    fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
        Impl::<1, A>::serialize(self, allocator)
    }
1692}
1693
// Implements `Quantizer<B>` for `Impl<N, A>` for each listed bit width.
//
// The generated implementation supports the `SameAsData`, `ScalarQuantized` and
// `FullPrecision` query layouts and rejects `FourBitTransposed` with an
// `UnsupportedQueryLayout` error whose description names the bit width.
//
// Usage: `plan!(2, 4, 8);` expands to one `impl` block per literal.
macro_rules! plan {
    ($N:literal) => {
        impl<A, B> Quantizer<B> for Impl<$N, A>
        where
            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
        {
            /// The number of bits this implementation was generated for.
            fn nbits(&self) -> usize {
                $N
            }

            /// The output dimension of the wrapped quantizer.
            fn dim(&self) -> usize {
                self.quantizer.output_dim()
            }

            /// The input dimension of the wrapped quantizer.
            fn full_dim(&self) -> usize {
                self.quantizer.input_dim()
            }

            /// The canonical number of bytes of one compressed data vector.
            fn bytes(&self) -> usize {
                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
            }

            /// Build a distance computer whose query uses the data layout.
            fn distance_computer(
                &self,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError> {
                self.query_computer::<AsData<$N>, _>(allocator)
            }

            /// Borrow the distance computer stored in this plan.
            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
                &*self.distance
            }

            /// Build a distance computer for queries stored with `layout`.
            ///
            /// `FourBitTransposed` is rejected with an `UnsupportedQueryLayout` error.
            fn query_computer(
                &self,
                layout: QueryLayout,
                allocator: B,
            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
                match layout {
                    QueryLayout::SameAsData => {
                        Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
                    }
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                        layout,
                        concat!($N, "-bit compression"),
                    )
                    .into()),
                    QueryLayout::ScalarQuantized => {
                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
                    }
                    QueryLayout::FullPrecision => {
                        Ok(self.query_computer::<AsFull, _>(allocator)?)
                    }
                }
            }

            /// Describe the buffer (bytes and alignment) needed for a query compressed
            /// with `layout`.
            ///
            /// `FourBitTransposed` is rejected with an `UnsupportedQueryLayout` error.
            fn query_buffer_description(
                &self,
                layout: QueryLayout,
            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
                let dim = <Self as Quantizer<B>>::dim(self);
                match layout {
                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                        DataRef::<$N>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    // Built through the constructor, like every other rejection site.
                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                        layout,
                        concat!($N, "-bit compression"),
                    )),
                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
                        PowerOfTwo::alignment_of::<u8>(),
                    )),
                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                        FullQueryRef::canonical_bytes(dim),
                        FullQueryRef::canonical_align(),
                    )),
                }
            }

            /// Compress the query `x` into `buffer` using `layout`.
            ///
            /// When `allow_rescale` is set and the metric is inner product, a private
            /// copy of `x` is rescaled by the quantizer before compression.
            fn compress_query(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                mut buffer: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), QueryCompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);

                // Shared tail of the rescaled and non-rescaled paths.
                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
                    match layout {
                        QueryLayout::SameAsData => self.compress_query(
                            v,
                            DataMut::<$N>::from_canonical_back_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => self.compress_query(
                            v,
                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
                                &mut buffer,
                                dim,
                            )
                            .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                        QueryLayout::FullPrecision => self.compress_query(
                            v,
                            FullQueryMut::from_canonical_mut(&mut buffer, dim)
                                .map_err(NotCanonical::new)?,
                            scratch,
                        ),
                    }
                };

                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
                    // Rescale a copy so the caller's slice is left untouched.
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }

            /// Compress the query `x` with `layout` and return a query computer that
            /// owns the compressed query.
            ///
            /// When `allow_rescale` is set and the metric is inner product, a private
            /// copy of `x` is rescaled by the quantizer before compression.
            fn fused_query_computer(
                &self,
                x: &[f32],
                layout: QueryLayout,
                allow_rescale: bool,
                allocator: B,
                scratch: ScopedAllocator<'_>,
            ) -> Result<QueryComputer<B>, QueryComputerError> {
                let dim = <Self as Quantizer<B>>::dim(self);

                // Takes ownership of `allocator`, so this closure is called exactly
                // once (from one of the two branches below).
                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
                    match layout {
                        QueryLayout::SameAsData => {
                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
                                v,
                                Data::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
                            layout,
                            concat!($N, "-bit compression"),
                        )
                        .into()),
                        QueryLayout::ScalarQuantized => {
                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
                                v,
                                Query::new_in(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                        QueryLayout::FullPrecision => {
                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
                                v,
                                FullQuery::empty(dim, allocator.clone())?,
                                allocator,
                                scratch,
                            )
                        }
                    }
                };

                let metric = <Self as Quantizer<B>>::metric(self);
                if allow_rescale && metric == SupportedMetric::InnerProduct {
                    // Rescale a copy so the caller's slice is left untouched.
                    let mut copy = x.to_owned();
                    self.quantizer.rescale(&mut copy);
                    finish(&copy)
                } else {
                    finish(x)
                }
            }

            /// Report whether `layout` is supported, via `Self::supports`.
            fn is_supported(&self, layout: QueryLayout) -> bool {
                Self::supports(layout)
            }

            /// Compress the data vector `x` into the opaque buffer `into`.
            ///
            /// Fails if `into` is not a canonical data buffer for this dimension or if
            /// the underlying quantizer reports an error.
            fn compress(
                &self,
                x: &[f32],
                mut into: OpaqueMut<'_>,
                scratch: ScopedAllocator<'_>,
            ) -> Result<(), CompressionError> {
                let dim = <Self as Quantizer<B>>::dim(self);
                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
                    .map_err(CompressionError::not_canonical)?;

                self.quantizer
                    .compress_into_with(x, into, scratch)
                    .map_err(CompressionError::CompressionError)
            }

            /// Return the metric reported by the wrapped quantizer.
            fn metric(&self) -> SupportedMetric {
                self.quantizer.metric()
            }

            /// Clone this plan and wrap the clone as a `Poly<dyn Quantizer<B>, B>`.
            fn try_clone_into(
                &self,
                allocator: B,
            ) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
                let clone = (&*self).try_clone()?;
                poly!({ Quantizer<B> }, clone, allocator)
            }

            /// Serialize this plan by forwarding to the inherent `serialize`.
            #[cfg(feature = "flatbuffers")]
            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
                Impl::<$N, A>::serialize(self, allocator)
            }
        }
    };
    // Several bit widths: expand the single-width arm once per literal.
    ($N:literal, $($Ns:literal),*) => {
        plan!($N);
        $(plan!($Ns);)*
    }
}
1920
1921plan!(2, 4, 8);
1922
#[cfg(feature = "flatbuffers")]
/// Errors that can occur while deserializing a quantizer with [`try_deserialize`].
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer does not carry the expected flatbuffer file identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The serialized bit width is not one of the widths handled by
    /// [`try_deserialize`]; the payload is the offending width.
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// The embedded quantizer could not be reconstructed.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The flatbuffers library rejected the buffer.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// An allocation failed during deserialization.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1947
#[cfg(feature = "flatbuffers")]
/// Reconstruct a type-erased [`Quantizer`] from the flatbuffer bytes in `data`,
/// placing the resulting object in `alloc`.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if `data` lacks the expected file
///   identifier.
/// * [`DeserializationError::InvalidFlatBuffer`] if the root table cannot be read.
/// * [`DeserializationError::UnsupportedBitWidth`] if the stored bit width is not
///   1, 2, 4 or 8.
/// * [`DeserializationError::InvalidQuantizer`] / [`DeserializationError::AllocatorError`]
///   if rebuilding the quantizer or allocating it fails.
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    // Builds the concrete `Impl<NBITS, A>` inside `alloc` and erases it to a
    // `dyn Quantizer<O>`.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let built = Poly::new_with(
            #[inline(never)]
            |alloc| -> Result<_, super::quantizer::DeserializationError> {
                let quantizer = SphericalQuantizer::try_unpack(alloc, proto)?;
                Ok(Impl::new(quantizer)?)
            },
            alloc,
        );

        // Both failure modes map onto a dedicated variant of our error type.
        let imp = built.map_err(|failure| match failure {
            CompoundError::Allocator(err) => DeserializationError::from(err),
            CompoundError::Constructor(err) => DeserializationError::from(err),
        })?;

        Ok(poly!({ Quantizer<O> }, imp))
    }

    let has_identifier = fb::spherical::quantizer_buffer_has_identifier(data);
    if !has_identifier {
        return Err(DeserializationError::InvalidIdentifier);
    }

    let root = fb::spherical::root_as_quantizer(data)?;
    let width = root.nbits();
    let proto = root.quantizer();

    // Dispatch on the stored bit width to the matching const-generic instantiation.
    match width {
        1 => unpack_bits::<1, _, _>(proto, alloc),
        2 => unpack_bits::<2, _, _>(proto, alloc),
        4 => unpack_bits::<4, _, _>(proto, alloc),
        8 => unpack_bits::<8, _, _>(proto, alloc),
        unsupported => Err(DeserializationError::UnsupportedBitWidth(unsupported)),
    }
}
2015
2016#[cfg(test)]
2021mod tests {
2022 use diskann_utils::views::{Matrix, MatrixView};
2023 use rand::{rngs::StdRng, SeedableRng};
2024
2025 use super::*;
2026 use crate::{
2027 algorithms::{transforms::TargetDim, TransformKind},
2028 alloc::{AlignedAllocator, GlobalAllocator, Poly},
2029 num::PowerOfTwo,
2030 spherical::PreScale,
2031 };
2032
2033 fn test_plan_1_bit(plan: &dyn Quantizer) {
2038 assert_eq!(
2039 plan.nbits(),
2040 1,
2041 "this test only applies to 1-bit quantization"
2042 );
2043
2044 for layout in QueryLayout::all() {
2046 match layout {
2047 QueryLayout::SameAsData
2048 | QueryLayout::FourBitTransposed
2049 | QueryLayout::FullPrecision => assert!(
2050 plan.is_supported(layout),
2051 "expected {} to be supported",
2052 layout
2053 ),
2054 QueryLayout::ScalarQuantized => assert!(
2055 !plan.is_supported(layout),
2056 "expected {} to not be supported",
2057 layout
2058 ),
2059 }
2060 }
2061 }
2062
2063 fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
2064 assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
2065 assert_eq!(
2066 plan.nbits(),
2067 nbits,
2068 "this test only applies to 1-bit quantization"
2069 );
2070
2071 for layout in QueryLayout::all() {
2073 match layout {
2074 QueryLayout::SameAsData
2075 | QueryLayout::ScalarQuantized
2076 | QueryLayout::FullPrecision => assert!(
2077 plan.is_supported(layout),
2078 "expected {} to be supported",
2079 layout
2080 ),
2081 QueryLayout::FourBitTransposed => assert!(
2082 !plan.is_supported(layout),
2083 "expected {} to not be supported",
2084 layout
2085 ),
2086 }
2087 }
2088 }
2089
    /// Shared end-to-end check of a type-erased plan.
    ///
    /// Exercises layout support, data compression, both distance-computer flavors,
    /// the error paths for unsupported layouts, agreement between the fused and the
    /// standalone query APIs, and rejection of wrongly sized inputs.
    #[inline(never)]
    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
        // Layout support depends on whether this is the 1-bit specialization.
        if nbits == 1 {
            test_plan_1_bit(plan);
        } else {
            test_plan_n_bit(plan, nbits);
        }

        // The plan must accept vectors of the dataset's dimensionality.
        assert_eq!(plan.full_dim(), dataset.ncols());

        // Two compressed-vector buffers of exactly `plan.bytes()` bytes.
        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let scoped_global = ScopedAllocator::global();

        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
            .unwrap();
        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
            .unwrap();

        // The owned distance computer must evaluate two well-formed vectors.
        let f = plan.distance_computer(GlobalAllocator).unwrap();
        let _: f32 = f
            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
            .unwrap();

        // A wrongly sized first argument must fail as `QueryReify`, a wrongly sized
        // second argument as `XReify` — both for too-short and too-long buffers.
        let test_errors = |f: &dyn DynDistanceComputer| {
            let err = f
                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            let err = f
                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));

            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));
        };

        test_errors(&*f.inner);

        // The borrowed distance computer must behave the same way.
        let f = plan.distance_computer_ref();
        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
        test_errors(f);

        for layout in QueryLayout::all() {
            // Unsupported layouts: every entry point must fail with a message that
            // names both the layout and the bit width.
            if !plan.is_supported(layout) {
                let check_message = |msg: &str| {
                    assert!(
                        msg.contains(&(layout.to_string())),
                        "error message ({}) should contain the layout \"{}\"",
                        msg,
                        layout
                    );
                    assert!(
                        msg.contains(&format!("{}", nbits)),
                        "error message ({}) should contain the number of bits \"{}\"",
                        msg,
                        nbits
                    );
                };

                // Fused query computer.
                {
                    let err = plan
                        .fused_query_computer(
                            dataset.row(1),
                            layout,
                            false,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap_err();

                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query buffer description.
                {
                    let err = plan.query_buffer_description(layout).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query compression (the empty buffer is never reached).
                {
                    let buffer = &mut [];
                    let err = plan
                        .compress_query(
                            dataset.row(1),
                            layout,
                            true,
                            OpaqueMut::new(buffer),
                            scoped_global,
                        )
                        .unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Standalone query computer.
                {
                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                continue;
            }

            // Supported layouts: the fused computer owns the compressed query.
            let g = plan
                .fused_query_computer(
                    dataset.row(1),
                    layout,
                    false,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            assert_eq!(
                g.layout(),
                layout,
                "the query computer should faithfully preserve the requested layout"
            );

            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();

            // Wrongly sized data vectors must be rejected as `XReify`.
            {
                let err = g
                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));

                let err = g
                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));
            }

            // Same query through the two-step API: allocate a buffer as described,
            // compress into it, then evaluate with a standalone computer.
            let sizes = plan.query_buffer_description(layout).unwrap();
            let mut buf =
                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();

            plan.compress_query(
                dataset.row(1),
                layout,
                false,
                OpaqueMut::new(&mut buf),
                scoped_global,
            )
            .unwrap();

            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();

            assert_eq!(
                standalone.layout(),
                layout,
                "the standalone computer did not preserve the requested layout",
            );

            let indirect: f32 = standalone
                .evaluate_similarity(Opaque(&buf), Opaque(&a))
                .unwrap();

            assert_eq!(
                direct, indirect,
                "the two different query computation APIs did not return the same result"
            );

            // A query with one dimension too few must be rejected.
            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
            assert!(plan
                .fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                .is_err());
        }

        // Data compression must reject wrongly sized buffers and wrongly sized input.
        {
            let mut too_small = vec![u8::default(); plan.bytes() - 1];
            assert!(plan
                .compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                .is_err());

            let mut too_big = vec![u8::default(); plan.bytes() + 1];
            assert!(plan
                .compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                .is_err());

            let mut just_right = vec![u8::default(); plan.bytes()];
            assert!(plan
                .compress(
                    &dataset.row(0)[..dataset.ncols() - 1],
                    OpaqueMut(&mut just_right),
                    scoped_global
                )
                .is_err());
        }
    }
2306
2307 fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
2308 where
2309 Impl<NBITS>: Constructible,
2310 {
2311 let data = test_dataset();
2312 let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);
2313
2314 let quantizer = SphericalQuantizer::train(
2315 data.as_view(),
2316 TransformKind::PaddingHadamard {
2317 target_dim: TargetDim::Natural,
2318 },
2319 metric,
2320 PreScale::None,
2321 &mut rng,
2322 GlobalAllocator,
2323 )
2324 .unwrap();
2325
2326 (Impl::<NBITS>::new(quantizer).unwrap(), data)
2327 }
2328
2329 #[test]
2330 fn test_plan_1bit_l2() {
2331 let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2332 test_plan(&plan, 1, data.as_view());
2333 }
2334
2335 #[test]
2336 fn test_plan_1bit_ip() {
2337 let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2338 test_plan(&plan, 1, data.as_view());
2339 }
2340
2341 #[test]
2342 fn test_plan_1bit_cosine() {
2343 let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
2344 test_plan(&plan, 1, data.as_view());
2345 }
2346
2347 #[test]
2348 fn test_plan_2bit_l2() {
2349 let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2350 test_plan(&plan, 2, data.as_view());
2351 }
2352
2353 #[test]
2354 fn test_plan_2bit_ip() {
2355 let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2356 test_plan(&plan, 2, data.as_view());
2357 }
2358
2359 #[test]
2360 fn test_plan_2bit_cosine() {
2361 let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
2362 test_plan(&plan, 2, data.as_view());
2363 }
2364
2365 #[test]
2366 fn test_plan_4bit_l2() {
2367 let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2368 test_plan(&plan, 4, data.as_view());
2369 }
2370
2371 #[test]
2372 fn test_plan_4bit_ip() {
2373 let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2374 test_plan(&plan, 4, data.as_view());
2375 }
2376
2377 #[test]
2378 fn test_plan_4bit_cosine() {
2379 let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
2380 test_plan(&plan, 4, data.as_view());
2381 }
2382
2383 #[test]
2384 fn test_plan_8bit_l2() {
2385 let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2386 test_plan(&plan, 8, data.as_view());
2387 }
2388
2389 #[test]
2390 fn test_plan_8bit_ip() {
2391 let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2392 test_plan(&plan, 8, data.as_view());
2393 }
2394
2395 #[test]
2396 fn test_plan_8bit_cosine() {
2397 let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
2398 test_plan(&plan, 8, data.as_view());
2399 }
2400
    /// Fixed dataset used by the tests in this module: 128 literals handed to
    /// `Matrix::try_from` with the arguments `16, 8` (16 rows of 8 values, going by
    /// the grouping below).
    fn test_dataset() -> Matrix<f32> {
        let data = vec![
            0.28657,
            -0.0318168,
            0.0666847,
            0.0329265,
            -0.00829283,
            0.168735,
            -0.000846311,
            -0.360779, // end of row 0
            -0.0968938,
            0.161921,
            -0.0979579,
            0.102228,
            -0.259928,
            -0.139634,
            0.165384,
            -0.293443, // end of row 1
            0.130205,
            0.265737,
            0.401816,
            -0.407552,
            0.13012,
            -0.0475244,
            0.511723,
            -0.4372, // end of row 2
            -0.0979126,
            0.135861,
            -0.0154144,
            -0.14047,
            -0.0250029,
            -0.190279,
            0.407283,
            -0.389184, // end of row 3
            -0.264153,
            0.0696822,
            -0.145585,
            0.370284,
            0.186825,
            -0.140736,
            0.274703,
            -0.334563, // end of row 4
            0.247613,
            0.513165,
            -0.0845867,
            0.0532264,
            -0.00480601,
            -0.122408,
            0.47227,
            -0.268301, // end of row 5
            0.103198,
            0.30756,
            -0.316293,
            -0.0686877,
            -0.330729,
            -0.461997,
            0.550857,
            -0.240851, // end of row 6
            0.128258,
            0.786291,
            -0.0268103,
            0.111763,
            -0.308962,
            -0.17407,
            0.437154,
            -0.159879, // end of row 7
            0.00374063,
            0.490301,
            0.0327826,
            -0.0340962,
            -0.118605,
            0.163879,
            0.2737,
            -0.299942, // end of row 8
            -0.284077,
            0.249377,
            -0.0307734,
            -0.0661631,
            0.233854,
            0.427987,
            0.614132,
            -0.288649, // end of row 9
            -0.109492,
            0.203939,
            -0.73956,
            -0.130748,
            0.22072,
            0.0647836,
            0.328726,
            -0.374602, // end of row 10
            -0.223114,
            0.0243489,
            0.109195,
            -0.416914,
            0.0201052,
            -0.0190542,
            0.947078,
            -0.333229, // end of row 11
            -0.165869,
            -0.00296729,
            -0.414378,
            0.231321,
            0.205365,
            0.161761,
            0.148608,
            -0.395063, // end of row 12
            -0.0498255,
            0.193279,
            -0.110946,
            -0.181174,
            -0.274578,
            -0.227511,
            0.190208,
            -0.256174, // end of row 13
            -0.188106,
            -0.0292958,
            0.0930939,
            0.0558456,
            0.257437,
            0.685481,
            0.307922,
            -0.320006, // end of row 14
            0.250035,
            0.275942,
            -0.0856306,
            -0.352027,
            -0.103509,
            -0.00890859,
            0.276121,
            -0.324718, // end of row 15
        ];

        Matrix::try_from(data.into(), 16, 8).unwrap()
    }
2535
2536 #[cfg(feature = "flatbuffers")]
2537 mod serialization {
2538 use std::sync::{
2539 atomic::{AtomicBool, Ordering},
2540 Arc,
2541 };
2542
2543 use super::*;
2544 use crate::alloc::{BumpAllocator, GlobalAllocator};
2545
        /// Round-trips `quantizer` through `serialize` / `try_deserialize` and checks
        /// that the deserialized plan is indistinguishable from the original:
        /// same metadata, same layout support, same compressed bytes and same
        /// distances from every computer flavor.
        #[inline(never)]
        fn test_plan_serialization(
            quantizer: &dyn Quantizer,
            nbits: usize,
            dataset: MatrixView<f32>,
        ) {
            assert_eq!(quantizer.full_dim(), dataset.ncols());
            let scoped_global = ScopedAllocator::global();

            let serialized = quantizer.serialize(GlobalAllocator).unwrap();
            let deserialized =
                try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();

            // Metadata must survive the round trip.
            assert_eq!(deserialized.nbits(), nbits);
            assert_eq!(deserialized.bytes(), quantizer.bytes());
            assert_eq!(deserialized.dim(), quantizer.dim());
            assert_eq!(deserialized.full_dim(), quantizer.full_dim());
            assert_eq!(deserialized.metric(), quantizer.metric());

            for layout in QueryLayout::all() {
                assert_eq!(
                    deserialized.is_supported(layout),
                    quantizer.is_supported(layout)
                );
            }

            let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
            // Both plans must produce byte-identical compressed vectors.
            {
                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                for row in dataset.row_iter() {
                    quantizer
                        .compress(row, OpaqueMut::new(&mut a), scoped_global)
                        .unwrap();
                    deserialized
                        .compress(row, OpaqueMut::new(&mut b), scoped_global)
                        .unwrap();

                    assert_eq!(a, b);
                }
            }

            // Owned and borrowed distance computers of both plans must agree on every
            // pair of rows.
            {
                let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
                let q_computer_ref = quantizer.distance_computer_ref();
                let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
                let d_computer_ref = deserialized.distance_computer_ref();

                for r0 in dataset.row_iter() {
                    quantizer
                        .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
                        .unwrap();
                    deserialized
                        .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
                        .unwrap();
                    for r1 in dataset.row_iter() {
                        quantizer
                            .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
                            .unwrap();
                        deserialized
                            .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
                            .unwrap();

                        // NOTE(review): only the vectors compressed by the original
                        // plan (`a0`, `a1`) are fed to the computers; `b0`/`b1` are
                        // written but not read here.
                        let a0 = Opaque::new(&a0);
                        let a1 = Opaque::new(&a1);

                        let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
                        let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();

                        assert_eq!(q_computer_dist, d_computer_dist);

                        let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();

                        assert_eq!(q_computer_dist, q_computer_ref_dist);

                        let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
                        assert_eq!(d_computer_dist, d_computer_ref_dist);
                    }
                }
            }

            // Fused query computers of both plans must agree for every supported
            // layout, query row and data row.
            {
                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();

                for layout in QueryLayout::all() {
                    if !quantizer.is_supported(layout) {
                        continue;
                    }

                    for r in dataset.row_iter() {
                        let q_computer = quantizer
                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                            .unwrap();
                        let d_computer = deserialized
                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
                            .unwrap();

                        for u in dataset.row_iter() {
                            quantizer
                                .compress(u, OpaqueMut::new(&mut a), scoped_global)
                                .unwrap();
                            deserialized
                                .compress(u, OpaqueMut::new(&mut b), scoped_global)
                                .unwrap();

                            assert_eq!(
                                q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
                                d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
                            );
                        }
                    }
                }
            }
        }
2672
        /// Test allocator that lets exactly one allocation succeed and fails every
        /// later one; the shared flag records whether an allocation was attempted.
        #[derive(Debug, Clone)]
        struct FlakyAllocator {
            // Shared with the test so it can observe that `allocate` was called.
            have_allocated: Arc<AtomicBool>,
        }
2678
        impl FlakyAllocator {
            /// Create an allocator that reports its first allocation through
            /// `have_allocated`.
            fn new(have_allocated: Arc<AtomicBool>) -> Self {
                Self { have_allocated }
            }
        }
2684
        // SAFETY: every successful allocation comes from `GlobalAllocator::allocate`
        // and every deallocation is forwarded to `GlobalAllocator::deallocate`.
        // NOTE(review): the exact `AllocatorCore` contract is declared elsewhere —
        // confirm that forwarding to `GlobalAllocator` is sufficient to uphold it.
        unsafe impl AllocatorCore for FlakyAllocator {
            /// Succeeds (through `GlobalAllocator`) only while the flag is still
            /// `false`; the flag is set to `true` on every call.
            fn allocate(
                &self,
                layout: std::alloc::Layout,
            ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
                // `swap` returns the previous value: `true` means an earlier call
                // already consumed the single permitted allocation.
                if self.have_allocated.swap(true, Ordering::Relaxed) {
                    Err(AllocatorError)
                } else {
                    GlobalAllocator.allocate(layout)
                }
            }

            unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
                // SAFETY: memory handed out by this allocator was obtained from
                // `GlobalAllocator` (see `allocate`); the caller's obligations on
                // `ptr` and `layout` are passed through unchanged.
                unsafe { GlobalAllocator.deallocate(ptr, layout) }
            }
        }
2703
2704 fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
2705 where
2706 Impl<NBITS>: Quantizer,
2707 {
2708 let have_allocated = Arc::new(AtomicBool::new(false));
2710 let _: AllocatorError = v
2711 .serialize(FlakyAllocator::new(have_allocated.clone()))
2712 .unwrap_err();
2713 assert!(have_allocated.load(Ordering::Relaxed));
2714 }
2715
2716 #[test]
2717 fn test_plan_1bit_l2() {
2718 let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2719 test_plan_panic_boundary(&plan);
2720 test_plan_serialization(&plan, 1, data.as_view());
2721 }
2722
2723 #[test]
2724 fn test_plan_1bit_ip() {
2725 let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2726 test_plan_panic_boundary(&plan);
2727 test_plan_serialization(&plan, 1, data.as_view());
2728 }
2729
2730 #[test]
2731 fn test_plan_2bit_l2() {
2732 let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2733 test_plan_panic_boundary(&plan);
2734 test_plan_serialization(&plan, 2, data.as_view());
2735 }
2736
2737 #[test]
2738 fn test_plan_2bit_ip() {
2739 let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2740 test_plan_panic_boundary(&plan);
2741 test_plan_serialization(&plan, 2, data.as_view());
2742 }
2743
2744 #[test]
2745 fn test_plan_4bit_l2() {
2746 let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2747 test_plan_panic_boundary(&plan);
2748 test_plan_serialization(&plan, 4, data.as_view());
2749 }
2750
2751 #[test]
2752 fn test_plan_4bit_ip() {
2753 let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2754 test_plan_panic_boundary(&plan);
2755 test_plan_serialization(&plan, 4, data.as_view());
2756 }
2757
2758 #[test]
2759 fn test_plan_8bit_l2() {
2760 let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2761 test_plan_panic_boundary(&plan);
2762 test_plan_serialization(&plan, 8, data.as_view());
2763 }
2764
2765 #[test]
2766 fn test_plan_8bit_ip() {
2767 let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2768 test_plan_panic_boundary(&plan);
2769 test_plan_serialization(&plan, 8, data.as_view());
2770 }
2771
2772 #[test]
2773 fn test_allocation_order() {
2774 let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
2775 let buf = plan.serialize(GlobalAllocator).unwrap();
2776
2777 let allocator = BumpAllocator::new(8192, PowerOfTwo::new(64).unwrap()).unwrap();
2778 let deserialized =
2779 try_deserialize::<GlobalAllocator, _>(&buf, allocator.clone()).unwrap();
2780 assert_eq!(
2781 Poly::as_ptr(&deserialized).cast::<u8>(),
2782 allocator.as_ptr(),
2783 "expected the returned box to be allocated first",
2784 );
2785 }
2786 }
2787}