Skip to main content

diskann_quantization/spherical/
iface.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! The main export of this module is the dyn compatible [`Quantizer`] trait, which provides
7//! a common interface for interacting with bit-width specific [`SphericalQuantizer`]s,
8//! compressing vectors, and computing distances between compressed vectors.
9//!
10//! This is offered as a convenience interface for interacting with the myriad of generics
11//! associated with the [`SphericalQuantizer`]. Better performance can be achieved by using
12//! the generic types directly if desired.
13//!
14//! The [`Quantizer`] uses the [`Opaque`] and [`OpaqueMut`] types for its compressed data
15//! representations. These are thin wrappers around raw byte slices.
16//!
//! Distance computation is performed using [`DistanceComputer`] and [`QueryComputer`].
18//!
19//! Concrete implementations of [`Quantizer`] are available via the generic struct [`Impl`].
20//!
21//! ## Compatibility Table
22//!
23//! Multiple [`QueryLayout`]s are supported when constructing a [`QueryComputer`], but not
24//! all layouts are supported for each back end. This table lists the valid instantiations
25//! of [`Impl`] (parameterized by data vector bit-width) and their supported query layouts.
26//!
27//! | Bits | Same As Data | Full Precision | Four-Bit Transposed | Scalar Quantized |
28//! |------|--------------|----------------|---------------------|------------------|
29//! |    1 |     Yes      |      Yes       |         Yes         |       No         |
30//! |    2 |     Yes      |      Yes       |          No         |       Yes        |
31//! |    4 |     Yes      |      Yes       |          No         |       Yes        |
32//! |    8 |     Yes      |      Yes       |          No         |       Yes        |
33//!
34//! # Example
35//!
36//! ```
37//! use diskann_quantization::{
38//!     alloc::{Poly, ScopedAllocator, AlignedAllocator, GlobalAllocator},
39//!     algorithms::TransformKind,
40//!     spherical::{iface, SupportedMetric, SphericalQuantizer, PreScale},
41//!     num::PowerOfTwo,
42//! };
43//! use diskann_utils::views::Matrix;
44//!
45//! // For illustration purposes, the dataset consists of just a single vector.
46//! let mut data = Matrix::new(1.0, 1, 4);
47//! let quantizer = SphericalQuantizer::train(
48//!     data.as_view(),
49//!     TransformKind::Null,
50//!     SupportedMetric::SquaredL2,
51//!     PreScale::None,
52//!     &mut rand::rng(),
53//!     GlobalAllocator
54//! ).unwrap();
55//!
56//! let quantizer: Box<dyn iface::Quantizer> = Box::new(
57//!     iface::Impl::<1>::new(quantizer).unwrap()
58//! );
59//!
60//! let alloc = AlignedAllocator::new(PowerOfTwo::new(1).unwrap());
61//! let mut buf = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
62//!
63//! quantizer.compress(
64//!     data.row(0),
65//!     iface::OpaqueMut::new(&mut buf),
66//!     ScopedAllocator::new(&alloc),
67//! ).unwrap();
68//!
69//! assert!(quantizer.is_supported(iface::QueryLayout::FullPrecision));
70//! assert!(!quantizer.is_supported(iface::QueryLayout::ScalarQuantized));
71//! ```
72
73/// # DevDocs
74///
75/// This section provides developer documentation for the structure of the structs and
76/// traits inside this file. The main goal of the [`DistanceComputer`] and [`QueryComputer`]
/// implementations is to introduce just a single level of indirection.
78///
79/// ## Distance Computer Philosophy
80///
81/// The goal of this code (and in large part the reason for the somewhat spaghetti nature)
/// is to perform all of the following dispatches:
83///
84/// * Number of bits.
85/// * Query layout.
86/// * Distance Type (L2, Inner Product, Cosine).
/// * Micro-architecture specific code generation.
88///
89/// Behind a **single** level of dynamic dispatch. This means we need to bake all of this
90/// information into a single type, facilitated through a combination of the `Reify` and
91/// `Curried` private structs.
92///
93/// ## Anatomy of the [`DistanceComputer`]
94///
95/// To provide a single level of indirection for all distance function implementations,
96/// the [`DistanceComputer`] is a thin wrapper around the [`DynDistanceComputer`] trait.
97///
98/// Concrete implementations of this trait consist of a base distance function like
99/// [`CompensatedIP`] or [`CompensatedSquaredL2`]. Because the data is passed through the
100/// [`Opaque`] type, these distance functions are embedded inside a [`Reify`] which first
101/// converts the [`Opaque`] to the appropriate fully-typed object before calling the inner
102/// distance function.
103///
104/// This full typing is supplied by the private [`FromOpaque`] helper trait.
105///
106/// When returned from the [`Quantizer::distance_computer`] or
/// [`Quantizer::distance_computer_ref`] methods, the resulting computer will be specialized
108/// to work solely on data vectors compressed through [`Quantizer::compress`].
109///
110/// When returned from [`Quantizer::query_computer`], the expected type of the query will
111/// depend on the [`QueryLayout`] supplied to that method.
112///
113/// The method [`QueryComputer::layout`] is provided to inspect at run time the query layout
114/// the object is meant for.
115///
116/// If at all possible, the [`QueryComputer`] should be preferred as it removes the
117/// possibility of providing an incorrect query layout and will be slightly faster since
118/// it does not require argument reification.
119///
120/// ## Anatomy of the [`QueryComputer`]
121///
122/// This is similar to the [`DistanceComputer`] but has the extra duty of supporting
/// multiple different [`QueryLayout`]s (compressions for the query). To that end,
124/// the stack of types used to implement the underlying [`DynDistanceComputer`] trait is:
125///
126/// * Base [`DistanceFunction`] (e.g. [`CompensatedIP`]).
127///
128/// * Embedded inside [`Curried`] - which also contains a heap-allocated representation of
///   the query using the selected layout. For example, this could be one of:
130///
131///   - [`diskann_quantization::spherical::Query`]
132///   - [`diskann_quantization::spherical::FullQuery`]
///   - [`diskann_quantization::spherical::Data`]
134///
135/// * Embedded inside [`Reify`] to convert [`Opaque`] to the correct type.
136use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141    Architecture,
142    arch::{Scalar, Target1, Target2},
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151#[cfg(target_arch = "aarch64")]
152use diskann_wide::arch::aarch64::Neon;
153
154use super::{
155    CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery,
156    FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric,
157    quantizer,
158};
159use crate::{
160    AsFunctor, CompressIntoWith,
161    alloc::{
162        Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
163    },
164    bits::{self, Representation, Unsigned},
165    distances::{self, UnequalLengths},
166    error::InlineError,
167    meta,
168    num::PowerOfTwo,
169    poly,
170};
171#[cfg(feature = "flatbuffers")]
172use crate::{alloc::CompoundError, flatbuffers as fb};
173
// Shorthand for the fallible `f32` result produced by the distance functions dispatched in
// this file; it keeps the extensive where-clauses below readable.
type Rf32 = distances::Result<f32>;
176
177///////////////
178// Quantizer //
179///////////////
180
181/// A description of the buffer size (in bytes) and alignment required for a compressed query.
182#[derive(Debug, Clone)]
183pub struct QueryBufferDescription {
184    size: usize,
185    align: PowerOfTwo,
186}
187
188impl QueryBufferDescription {
189    /// Construct a new [`QueryBufferDescription`]
190    pub fn new(size: usize, align: PowerOfTwo) -> Self {
191        Self { size, align }
192    }
193
194    /// Return the number of bytes needed in a buffer for a compressed query.
195    pub fn bytes(&self) -> usize {
196        self.size
197    }
198
199    /// Return the necessary alignment of the base pointer for a query buffer.
200    pub fn align(&self) -> PowerOfTwo {
201        self.align
202    }
203}
204
/// A dyn-compatible trait providing a common interface for a bit-width specific
/// [`SphericalQuantizer`].
///
/// This allows us to have a single `dyn Quantizer` type without generics while still
/// supporting the range of bit-widths and query strategies we wish to support.
///
/// A level of indirection for each distance computation, unfortunately, is required to
/// support this. But we try to structure the code so there is only a single level of
/// indirection.
///
/// # Allocator
///
/// The quantizer is parameterized by an allocator `A`, which is used to acquire any memory
/// needed by returned data structures. The contract is as follows:
///
/// 1. Any allocation made as part of a returned data structure from a function will be
///    performed through the allocator given to that function.
///
/// 2. If dynamic memory allocation for scratch space is required, a separate `scratch`
///    allocator will be required and all scratch space allocations will go through that
///    allocator.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// The effective number of bits in the encoding.
    fn nbits(&self) -> usize;

    /// The number of bytes occupied by each compressed vector.
    fn bytes(&self) -> usize;

    /// The effective dimensionality of each compressed vector.
    fn dim(&self) -> usize;

    /// The dimensionality of the full-precision input vectors.
    fn full_dim(&self) -> usize;

    /// Return a distance computer capable of operating on validly initialized [`Opaque`]
    /// slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// The query layout associated with this computer will always be
    /// [`QueryLayout::SameAsData`].
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;

    /// Return a distance computer, borrowed from `self`, capable of operating on validly
    /// initialized [`Opaque`] slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;

    /// A standalone distance computer specialized for the specified query layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    ///
    /// # Note
    ///
    /// The returned object will **only** be compatible with queries compressed using
    /// [`Self::compress_query`] using the same layout. If possible, the API
    /// [`Self::fused_query_computer`] should be used to avoid this ambiguity.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;

    /// Return the number of bytes and alignment of a buffer used to contain a compressed
    /// query with the provided layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;

    /// Compress the query using the specified layout into `buffer`.
    ///
    /// This requires that `buffer` have the exact size and alignment returned from
    /// [`Self::query_buffer_description`] for the same layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;

    /// Return a query computer with the query `x` (compressed using `layout`) bound into
    /// it, capable of operating on validly initialized [`Opaque`] slices of length
    /// [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// Note: Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;

    /// Return whether or not this quantizer supports the given [`QueryLayout`].
    fn is_supported(&self, layout: QueryLayout) -> bool;

    /// Compress the vector `x` into the opaque slice.
    ///
    /// # Note
    ///
    /// This requires the length of the slice to be exactly [`Self::bytes`]. There is no
    /// alignment restriction on the base pointer.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;

    /// Return the metric this quantizer was created with.
    fn metric(&self) -> SupportedMetric;

    /// Clone the backing object into memory obtained from `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;

    crate::utils::features! {
        #![feature = "flatbuffers"]
        /// Serialize `self` into a flatbuffer, returning the flatbuffer. The function
        /// [`try_deserialize`] should undo this operation.
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
340
341#[derive(Debug, Error)]
342#[error("Layout {layout} is not supported for {desc}")]
343pub struct UnsupportedQueryLayout {
344    layout: QueryLayout,
345    desc: &'static str,
346}
347
348impl UnsupportedQueryLayout {
349    fn new(layout: QueryLayout, desc: &'static str) -> Self {
350        Self { layout, desc }
351    }
352}
353
/// Errors returned by [`Quantizer::query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DistanceComputerError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// A memory allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors returned by [`Quantizer::compress_query`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryCompressionError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// Compressing the query failed.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// An opaque argument (such as the output buffer) did not have the required alignment
    /// or length.
    #[error(transparent)]
    NotCanonical(#[from] NotCanonical),
    /// A memory allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}

/// Errors returned by [`Quantizer::fused_query_computer`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryComputerError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// Compressing the query failed.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// A memory allocation failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
386
387/// Errors that can occur during data compression
388#[derive(Debug, Error)]
389#[error("Error occured during query compression")]
390pub enum CompressionError {
391    /// The input buffer did not have the expected layout. This is an input error.
392    NotCanonical(#[source] InlineError<16>),
393
394    /// Forward any error that occurs during the compression process.
395    ///
396    /// See [`quantizer::CompressionError`] for the complete list.
397    CompressionError(#[source] quantizer::CompressionError),
398}
399
400impl CompressionError {
401    fn not_canonical<E>(error: E) -> Self
402    where
403        E: std::error::Error + Send + Sync + 'static,
404    {
405        Self::NotCanonical(InlineError::new(error))
406    }
407}
408
409#[derive(Debug, Error)]
410#[error("An opaque argument did not have the required alignment or length")]
411pub struct NotCanonical {
412    source: Box<dyn std::error::Error + Send + Sync>,
413}
414
415impl NotCanonical {
416    fn new<E>(err: E) -> Self
417    where
418        E: std::error::Error + Send + Sync + 'static,
419    {
420        Self {
421            source: Box::new(err),
422        }
423    }
424}
425
426////////////
427// Opaque //
428////////////
429
430/// A type-erased slice wrapper used to hide the implementation of spherically quantized
431/// vectors. This allows multiple bit-width implementations to share the same type.
432#[derive(Debug, Clone, Copy)]
433#[repr(transparent)]
434pub struct Opaque<'a>(&'a [u8]);
435
436impl<'a> Opaque<'a> {
437    /// Construct a new `Opaque` referencing `slice`.
438    pub fn new(slice: &'a [u8]) -> Self {
439        Self(slice)
440    }
441
442    /// Consume `self`, returning the wrapped slice.
443    pub fn into_inner(self) -> &'a [u8] {
444        self.0
445    }
446}
447
448impl std::ops::Deref for Opaque<'_> {
449    type Target = [u8];
450    fn deref(&self) -> &[u8] {
451        self.0
452    }
453}
454impl<'short> Reborrow<'short> for Opaque<'_> {
455    type Target = Opaque<'short>;
456    fn reborrow(&'short self) -> Self::Target {
457        *self
458    }
459}
460
/// Type-erased, mutable view of a spherically quantized vector.
///
/// This is the output-side counterpart of `Opaque`: every bit-width implementation writes
/// its compressed representation through this single type.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wrap `slice` without copying it.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Mutably borrow the wrapped bytes.
    pub fn inspect(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}

impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];
    fn deref(&self) -> &[u8] {
        &self.0[..]
    }
}

impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut self.0[..]
    }
}
491
492//////////////////
493// Query Layout //
494//////////////////
495
/// The layout to use for the query in [`DistanceComputer`] and [`QueryComputer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// Use the same compression strategy as the data vectors.
    ///
    /// This may result in slow compression if high bit-widths are used.
    SameAsData,

    /// Use 4-bits for the query vector using a bitwise transpose layout.
    FourBitTransposed,

    /// Use scalar quantization for the query using the same number of bits per dimension
    /// as the dataset.
    ScalarQuantized,

    /// Use `f32` to encode the query.
    FullPrecision,
}

impl QueryLayout {
    /// Every query layout, in declaration order (test-only helper).
    #[cfg(test)]
    fn all() -> [Self; 4] {
        [
            Self::SameAsData,
            Self::FourBitTransposed,
            Self::ScalarQuantized,
            Self::FullPrecision,
        ]
    }
}

/// Formats the variant name (the same text as the derived `Debug` output).
///
/// Unlike forwarding to `Debug`, this goes through [`std::fmt::Formatter::pad`], so width,
/// fill, and alignment flags (e.g. `{:>20}`) are honored when laying out tables or logs.
impl std::fmt::Display for QueryLayout {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must also give it a display name.
        let name = match self {
            QueryLayout::SameAsData => "SameAsData",
            QueryLayout::FourBitTransposed => "FourBitTransposed",
            QueryLayout::ScalarQuantized => "ScalarQuantized",
            QueryLayout::FullPrecision => "FullPrecision",
        };
        fmt.pad(name)
    }
}
532
533//////////////////////
534// Layout Reporting //
535//////////////////////
536
537/// Because dynamic dispatch is used heavily in the implementation, it can be easy to lose
538/// track of the actual layout used for the [`DistanceComputer`] and [`QueryComputer`].
539///
540/// This trait provides a mechanism by which we ensure the correct runtime layout is always
541/// reported without requiring manual tracking.
542trait ReportQueryLayout {
543    fn report_query_layout(&self) -> QueryLayout;
544}
545
546impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
547where
548    T: ReportQueryLayout,
549{
550    fn report_query_layout(&self) -> QueryLayout {
551        self.inner.report_query_layout()
552    }
553}
554
555impl<D, Q> ReportQueryLayout for Curried<D, Q>
556where
557    Q: ReportQueryLayout,
558{
559    fn report_query_layout(&self) -> QueryLayout {
560        self.query.report_query_layout()
561    }
562}
563
564impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
565where
566    Unsigned: Representation<NBITS>,
567    A: AllocatorCore,
568{
569    fn report_query_layout(&self) -> QueryLayout {
570        QueryLayout::SameAsData
571    }
572}
573
574impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
575where
576    Unsigned: Representation<NBITS>,
577    A: AllocatorCore,
578{
579    fn report_query_layout(&self) -> QueryLayout {
580        QueryLayout::ScalarQuantized
581    }
582}
583
584impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
585where
586    A: AllocatorCore,
587{
588    fn report_query_layout(&self) -> QueryLayout {
589        QueryLayout::FourBitTransposed
590    }
591}
592
593impl<A> ReportQueryLayout for FullQuery<A>
594where
595    A: AllocatorCore,
596{
597    fn report_query_layout(&self) -> QueryLayout {
598        QueryLayout::FullPrecision
599    }
600}
601
602//-----------------------//
603// Reification Utilities //
604//-----------------------//
605
/// An adaptor trait defining how to go from an [`Opaque`] slice to a fully reified type.
///
/// This is the building block for building distance computers with the reification code
/// inlined into the call site.
trait FromOpaque: 'static + Send + Sync {
    /// The fully-typed view produced from an [`Opaque`] slice with lifetime `'a`.
    type Target<'a>;
    /// The error returned when the slice cannot be reinterpreted as [`Self::Target`].
    type Error: std::error::Error + Send + Sync + 'static;

    /// Reinterpret the bytes of `query` as [`Self::Target`] with logical dimension `dim`.
    ///
    /// Despite the parameter name, this is used for both query and data arguments.
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
}
616
617/// Reify as full-precision.
618#[derive(Debug, Default)]
619pub(super) struct AsFull;
620
621/// Reify as data.
622#[derive(Debug, Default)]
623pub(super) struct AsData<const NBITS: usize>;
624
625/// Reify as scalar quantized query.
626#[derive(Debug)]
627pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
628    _marker: PhantomData<Perm>,
629}
630
631// This impelmentation works around the `derive` impl requiring `Perm: Default`.
632impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
633    fn default() -> Self {
634        Self {
635            _marker: PhantomData,
636        }
637    }
638}
639
// Full-precision queries are reinterpreted as a `FullQueryRef`.
impl FromOpaque for AsFull {
    type Target<'a> = FullQueryRef<'a>;
    type Error = meta::slice::NotCanonical;

    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical(query.into_inner(), dim)
    }
}

impl ReportQueryLayout for AsFull {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FullPrecision
    }
}

// Compressed data vectors (and `SameAsData` queries) are reinterpreted as a `DataRef`.
impl<const NBITS: usize> FromOpaque for AsData<NBITS>
where
    Unsigned: Representation<NBITS>,
{
    type Target<'a> = DataRef<'a, NBITS>;
    type Error = meta::NotCanonical;

    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical_back(query.into_inner(), dim)
    }
}

impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::SameAsData
    }
}

// Quantized queries are reinterpreted as a `QueryRef` using the permutation `Perm`.
impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
where
    Unsigned: Representation<NBITS>,
    Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
{
    type Target<'a> = QueryRef<'a, NBITS, Perm>;
    type Error = meta::NotCanonical;

    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
        Self::Target::from_canonical_back(query.into_inner(), dim)
    }
}

// The dense permutation corresponds to plain scalar quantization of the query.
impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::ScalarQuantized
    }
}

// Elsewhere in this file transposed queries only appear as `Query<4, bits::BitTranspose, _>`;
// note that this impl itself does not restrict `NBITS`.
impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FourBitTransposed
    }
}
697
698//-------//
699// Reify //
700//-------//
701
702/// Helper struct to convert an [`Opaque`] to a fully-typed [`DataRef`].
703pub(super) struct Reify<T, M, L, R> {
704    inner: T,
705    dim: usize,
706    arch: M,
707    _markers: PhantomData<(L, R)>,
708}
709
710impl<T, M, L, R> Reify<T, M, L, R> {
711    pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
712        Self {
713            inner,
714            dim,
715            arch,
716            _markers: PhantomData,
717        }
718    }
719}
720
// A `Reify` whose left-hand marker is `()` wraps a function that already holds its query
// (such as a `Curried`), so only the data argument `x` needs to be reified per call.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // Reification and the distance computation run together under a single dispatch
        // through `arch`.
        self.arch.run2(
            |this: &Self, x| {
                // A malformed data argument is reported as `XReify` rather than panicking.
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }

    fn layout(&self) -> QueryLayout {
        // The bound query determines the layout.
        self.inner.report_query_layout()
    }
}
746
// Two-argument computer: both the query (`Q`) and the data vector (`R`) arrive as `Opaque`
// slices and are reified on every call.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        // Reification and the distance computation run together under a single dispatch
        // through `arch`.
        self.arch.run3(
            |this: &Self, query, x| {
                // The inline capacities (24 / 16) match the `DistanceError` variant payloads.
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;

                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;

                // `T: Copy`, so the inner function is passed by value.
                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }

    fn layout(&self) -> QueryLayout {
        // The layout is a property of the marker type `Q`, so a default instance suffices.
        Q::default().report_query_layout()
    }
}
777
778///////////////////////
779// Query Computation //
780///////////////////////
781
/// Errors that can occur while performing distance calculations on opaque vectors with a
/// [`QueryComputer`].
#[derive(Debug, Error)]
pub enum QueryDistanceError {
    /// The right-hand data argument appears to be malformed.
    #[error("trouble trying to reify the argument")]
    XReify(#[source] InlineError<16>),

    /// Distance computation failed because the logical lengths of the two vectors differ.
    #[error("encountered while trying to compute distances")]
    UnequalLengths(#[source] UnequalLengths),
}

/// Dyn-compatible interface behind [`QueryComputer`]: a distance function with the query
/// already bound, so only the data argument is supplied per call.
pub trait DynQueryComputer: Send + Sync {
    /// Compute the distance between the bound query and the data vector `x`.
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
    /// Report the [`QueryLayout`] of the bound query.
    fn layout(&self) -> QueryLayout;
}
798
799/// An opaque [`PreprocessedDistanceFunction`] for the [`Quantizer`] trait object.
800///
801/// # Note
802///
803/// This is only valid to call on [`Opaque`] slices compressed by the same [`Quantizer`] that
804/// created the computer.
805///
806/// Otherwise, distance computations may return garbage values or panic.
807pub struct QueryComputer<A = GlobalAllocator>
808where
809    A: AllocatorCore,
810{
811    inner: Poly<dyn DynQueryComputer, A>,
812}
813
814impl<A> QueryComputer<A>
815where
816    A: AllocatorCore,
817{
818    fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
819    where
820        T: DynQueryComputer + 'static,
821    {
822        let inner = Poly::new(inner, allocator)?;
823        Ok(Self {
824            inner: poly!(DynQueryComputer, inner),
825        })
826    }
827
828    /// Report the layout used by the query computer.
829    pub fn layout(&self) -> QueryLayout {
830        self.inner.layout()
831    }
832
833    /// This is a temporary function until custom allocator support fully comes on line.
834    pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
835        self.inner
836    }
837}
838
839impl<A> std::fmt::Debug for QueryComputer<A>
840where
841    A: AllocatorCore,
842{
843    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
844        write!(
845            f,
846            "dynamic fused query computer with layout \"{}\"",
847            self.layout()
848        )
849    }
850}
851
852impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
853    for QueryComputer<A>
854where
855    A: AllocatorCore,
856{
857    fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
858        self.inner.evaluate(x)
859    }
860}
861
/// To handle multiple query bit-widths, we use type erasure on the actual distance
/// function implementation.
///
/// This struct represents the partial application of the `inner` distance function with
/// `query` in a generic way so we only have one level of dynamic dispatch when computing
/// distances.
pub(super) struct Curried<D, Q> {
    // The two-argument distance function.
    inner: D,
    // The query bound as the left-hand argument of `inner`; reborrowed on every call.
    query: Q,
}

impl<D, Q> Curried<D, Q> {
    /// Bind `query` as the left-hand argument of `inner`.
    pub(super) fn new(inner: D, query: Q) -> Self {
        Self { inner, query }
    }
}

// Calling the curried function with `x` on `arch` forwards to `inner(query, x)`.
// `D: Copy` lets `inner` be copied out of the shared reference for the by-value `run`.
impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
where
    A: Architecture,
    Q: for<'a> Reborrow<'a>,
    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
{
    fn run(self, arch: A, x: T) -> R {
        self.inner.run(arch, self.query.reborrow(), x)
    }
}
889
890///////////////////////
891// Distance Computer //
892///////////////////////
893
894/// Errors that can occur while perfoming distance cacluations on opaque vectors.
895#[derive(Debug, Error)]
896pub enum DistanceError {
897    /// The left-hand data argument appears to be malformed.
898    #[error("trouble trying to reify the left-hand argument")]
899    QueryReify(InlineError<24>),
900
901    /// The right-hand data argument appears to be malformed.
902    #[error("trouble trying to reify the right-hand argument")]
903    XReify(InlineError<16>),
904
905    /// Distance computation failed because the logical lengths of the two vectors differ.
906    ///
907    /// If vector reificiation occurs successfully, then this should not be returned.
908    #[error("encountered while trying to compute distances")]
909    UnequalLengths(UnequalLengths),
910}
911
912pub trait DynDistanceComputer: Send + Sync {
913    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
914    fn layout(&self) -> QueryLayout;
915}
916
917/// An opaque [`DistanceFunction`] for the [`Quantizer`] trait object.
918///
919/// # Note
920///
921/// Left-hand arguments must be [`Opaque`] slices compressed using
922/// [`Quantizer::compress_query`] using [`Self::layout`].
923///
924/// Right-hand arguments must be [`Opaque`] slices compressed using [`Quantizer::compress`].
925///
926/// Otherwise, distance computations may return garbage values or panic.
927pub struct DistanceComputer<A = GlobalAllocator>
928where
929    A: AllocatorCore,
930{
931    inner: Poly<dyn DynDistanceComputer, A>,
932}
933
934impl<A> DistanceComputer<A>
935where
936    A: AllocatorCore,
937{
938    pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
939    where
940        T: DynDistanceComputer + 'static,
941    {
942        let inner = Poly::new(inner, allocator)?;
943        Ok(Self {
944            inner: poly!(DynDistanceComputer, inner),
945        })
946    }
947
948    /// Report the layout used by the query computer.
949    pub fn layout(&self) -> QueryLayout {
950        self.inner.layout()
951    }
952
953    pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
954        self.inner
955    }
956}
957
958impl<A> std::fmt::Debug for DistanceComputer<A>
959where
960    A: AllocatorCore,
961{
962    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
963        write!(
964            f,
965            "dynamic distance computer with layout \"{}\"",
966            self.layout()
967        )
968    }
969}
970
971impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
972where
973    A: AllocatorCore,
974{
975    fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
976        self.inner.evaluate(query, x)
977    }
978}
979
980//////////
981// Impl //
982//////////
983
/// The base number of bytes to allocate when attempting to serialize a quantizer.
///
/// Test builds use a single byte so that the builder's reallocation logic is exercised.
#[cfg(feature = "flatbuffers")]
const DEFAULT_SERIALIZED_BYTES: usize = if cfg!(test) { 1 } else { 1024 };
991
992/// Implementation for [`Quantizer`] specializing on the number of bits used for data
993/// compression.
994pub struct Impl<const NBITS: usize, A = GlobalAllocator>
995where
996    A: Allocator,
997{
998    quantizer: SphericalQuantizer<A>,
999    distance: Poly<dyn DynDistanceComputer, A>,
1000}
1001
1002/// Pre-dispatch distance functions between compressed data vectors from `quantizer`
1003/// specialized for the current run-time mciro architecture.
1004pub trait Constructible<A = GlobalAllocator>
1005where
1006    A: Allocator,
1007{
1008    fn dispatch_distance(
1009        quantizer: &SphericalQuantizer<A>,
1010    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
1011}
1012
1013impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1014where
1015    A: Allocator,
1016    AsData<NBITS>: FromOpaque,
1017    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1018{
1019    fn dispatch_distance(
1020        quantizer: &SphericalQuantizer<A>,
1021    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1022        diskann_wide::arch::dispatch2_no_features(
1023            ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1024            quantizer,
1025            quantizer.allocator().clone(),
1026        )
1027        .map(|obj| obj.inner)
1028    }
1029}
1030
1031impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1032where
1033    A: Allocator,
1034    AsData<NBITS>: FromOpaque,
1035    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1036{
1037    fn try_clone(&self) -> Result<Self, AllocatorError> {
1038        Self::new(self.quantizer.try_clone()?)
1039    }
1040}
1041
1042impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1043    /// Construct a new plan around `quantizer` providing distance computers for `metric`.
1044    pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1045    where
1046        Self: Constructible<A>,
1047    {
1048        let distance = Self::dispatch_distance(&quantizer)?;
1049        Ok(Self {
1050            quantizer,
1051            distance,
1052        })
1053    }
1054
1055    /// Return the underlying [`SphericalQuantizer`].
1056    pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1057        &self.quantizer
1058    }
1059
1060    /// Return `true` if this plan supports `layout` for query computers.
1061    ///
1062    /// Otherwise, return `false`.
1063    pub fn supports(layout: QueryLayout) -> bool {
1064        if const { NBITS == 1 } {
1065            [
1066                QueryLayout::SameAsData,
1067                QueryLayout::FourBitTransposed,
1068                QueryLayout::FullPrecision,
1069            ]
1070            .contains(&layout)
1071        } else {
1072            [
1073                QueryLayout::SameAsData,
1074                QueryLayout::ScalarQuantized,
1075                QueryLayout::FullPrecision,
1076            ]
1077            .contains(&layout)
1078        }
1079    }
1080
1081    /// Return a [`DistanceComputer`] that is specialized for the most specific runtime
1082    /// architecture.
1083    fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1084    where
1085        Q: FromOpaque,
1086        B: AllocatorCore,
1087        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1088    {
1089        diskann_wide::arch::dispatch2_no_features(
1090            ComputerDispatcher::<Q, NBITS>::new(),
1091            &self.quantizer,
1092            allocator,
1093        )
1094    }
1095
1096    fn compress_query<'a, T>(
1097        &self,
1098        query: &'a [f32],
1099        storage: T,
1100        scratch: ScopedAllocator<'a>,
1101    ) -> Result<(), QueryCompressionError>
1102    where
1103        SphericalQuantizer<A>: CompressIntoWith<&'a [f32], T, ScopedAllocator<'a>, Error = quantizer::CompressionError>,
1104    {
1105        self.quantizer
1106            .compress_into_with(query, storage, scratch)
1107            .map_err(|err| CompressionError::CompressionError(err).into())
1108    }
1109
1110    /// Return a [`QueryComputer`] that is specialized for the most specific runtime
1111    /// architecture.
1112    fn fused_query_computer<Q, T, B>(
1113        &self,
1114        query: &[f32],
1115        mut storage: T,
1116        allocator: B,
1117        scratch: ScopedAllocator<'_>,
1118    ) -> Result<QueryComputer<B>, QueryComputerError>
1119    where
1120        Q: FromOpaque,
1121        T: for<'a> ReborrowMut<'a>
1122            + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1123            + ReportQueryLayout
1124            + Send
1125            + Sync
1126            + 'static,
1127        B: AllocatorCore,
1128        SphericalQuantizer<A>: for<'a> CompressIntoWith<
1129                &'a [f32],
1130                <T as ReborrowMut<'a>>::Target,
1131                ScopedAllocator<'a>,
1132                Error = quantizer::CompressionError,
1133            >,
1134        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1135    {
1136        if let Err(err) = self
1137            .quantizer
1138            .compress_into_with(query, storage.reborrow_mut(), scratch)
1139        {
1140            return Err(CompressionError::CompressionError(err).into());
1141        }
1142
1143        diskann_wide::arch::dispatch3_no_features(
1144            ComputerDispatcher::<Q, NBITS>::new(),
1145            &self.quantizer,
1146            storage,
1147            allocator,
1148        )
1149        .map_err(|e| e.into())
1150    }
1151
1152    #[cfg(feature = "flatbuffers")]
1153    fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1154    where
1155        B: Allocator + std::panic::UnwindSafe,
1156        A: std::panic::RefUnwindSafe,
1157    {
1158        let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1159            0u8,
1160            DEFAULT_SERIALIZED_BYTES,
1161            allocator.clone(),
1162        )?);
1163
1164        let quantizer = &self.quantizer;
1165
1166        let (root, mut buf) = match std::panic::catch_unwind(move || {
1167            let offset = quantizer.pack(&mut buf);
1168
1169            let root = fb::spherical::Quantizer::create(
1170                &mut buf,
1171                &fb::spherical::QuantizerArgs {
1172                    quantizer: Some(offset),
1173                    nbits: NBITS as u32,
1174                },
1175            );
1176            (root, buf)
1177        }) {
1178            Ok(ret) => ret,
1179            Err(err) => match err.downcast_ref::<String>() {
1180                Some(msg) => {
1181                    if msg.contains("AllocatorError") {
1182                        return Err(AllocatorError);
1183                    } else {
1184                        std::panic::resume_unwind(err);
1185                    }
1186                }
1187                None => std::panic::resume_unwind(err),
1188            },
1189        };
1190
1191        // Finish serializing and then copy out the finished data into a newly allocated buffer.
1192        fb::spherical::finish_quantizer_buffer(&mut buf, root);
1193        Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1194    }
1195}
1196
1197//----------------------//
1198// Distance Dispatching //
1199//----------------------//
1200
1201/// This trait and [`ComputerDispatcher`] are the glue for pre-dispatching
1202/// micro-architecture compatibility of distance computers.
1203///
1204/// This trait takes
1205///
1206/// * `M`: The target micro-architecture.
1207/// * `Q`: The target query type
1208///
1209/// And generates a specialized `DistanceComputer` and `QueryComputer`.
1210///
1211/// The [`ComputerDispatcher`] struct implements the [`diskann_wide::arch::Target2`] and
1212/// [`diskann_wide::arch::Target3`] traits to do the architecture-dispatching.
1213trait BuildComputer<M, Q, const N: usize>
1214where
1215    M: Architecture,
1216    Q: FromOpaque,
1217{
1218    /// Build a [`DistanceComputer`] targeting the micro-architecture `M`.
1219    ///
1220    /// The resulting object should implement distance calculations using just a single
1221    /// level of indirection.
1222    fn build_computer<A>(
1223        &self,
1224        arch: M,
1225        allocator: A,
1226    ) -> Result<DistanceComputer<A>, AllocatorError>
1227    where
1228        A: AllocatorCore;
1229
1230    /// Build a [`DistanceComputer`] with `query` targeting the micro-architecture `M`.
1231    ///
1232    /// The resulting object should implement distance calculations using just a single
1233    /// level of indirection.
1234    fn build_fused_computer<R, A>(
1235        &self,
1236        arch: M,
1237        query: R,
1238        allocator: A,
1239    ) -> Result<QueryComputer<A>, AllocatorError>
1240    where
1241        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1242        A: AllocatorCore;
1243}
1244
/// Return `value` unchanged.
///
/// This is the default "no architecture down-casting" mapping used by `dispatch_map!`.
fn identity<T>(value: T) -> T {
    value
}
1248
/// Implement [`BuildComputer`] for `SphericalQuantizer<A>` for one combination of data
/// bit-width (`$N`), query representation (`$Q`), and micro-architecture (`$arch`).
///
/// The optional fourth argument `$op` names a function applied to the incoming
/// architecture token before the distance kernels are built (for example, lowering `V4`
/// to `V3`). When omitted, it defaults to [`identity`], i.e. no down-casting.
macro_rules! dispatch_map {
    // Three-argument form: forward with `identity` as the architecture mapping.
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                // Right-hand (data) arguments always use the `$N`-bit data representation.
                type D = AsData<$N>;

                // Perform any architecture down-casting.
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // The metric arms differ only in the compensated functor type given to `Reify`.
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let reify = Reify::<CompensatedSquaredL2, _, $Q, D>::new(
                            self.as_functor(),
                            dim,
                            arch,
                        );
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let reify =
                            Reify::<CompensatedIP, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let reify =
                            Reify::<CompensatedCosine, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // Each arm binds the owned `query` to the metric's functor via `Curried`.
                // `Reify` then receives `()` in its query slot; presumably this is because
                // the query is no longer supplied per call (confirm against `Reify`).
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer: CompensatedSquaredL2 = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer: CompensatedIP = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let computer: CompensatedCosine = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
1334
1335dispatch_map!(1, AsFull, Scalar);
1336dispatch_map!(2, AsFull, Scalar);
1337dispatch_map!(4, AsFull, Scalar);
1338dispatch_map!(8, AsFull, Scalar);
1339
1340dispatch_map!(1, AsData<1>, Scalar);
1341dispatch_map!(2, AsData<2>, Scalar);
1342dispatch_map!(4, AsData<4>, Scalar);
1343dispatch_map!(8, AsData<8>, Scalar);
1344
1345// Special Cases
1346dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
1347dispatch_map!(2, AsQuery<2>, Scalar);
1348dispatch_map!(4, AsQuery<4>, Scalar);
1349dispatch_map!(8, AsQuery<8>, Scalar);
1350
1351cfg_if::cfg_if! {
1352    if #[cfg(target_arch = "x86_64")] {
1353        fn downcast_to_v3(arch: V4) -> V3 {
1354            arch.into()
1355        }
1356
1357        // V3
1358        dispatch_map!(1, AsFull, V3);
1359        dispatch_map!(2, AsFull, V3);
1360        dispatch_map!(4, AsFull, V3);
1361        dispatch_map!(8, AsFull, V3);
1362
1363        dispatch_map!(1, AsData<1>, V3);
1364        dispatch_map!(2, AsData<2>, V3);
1365        dispatch_map!(4, AsData<4>, V3);
1366        dispatch_map!(8, AsData<8>, V3);
1367
1368        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
1369        dispatch_map!(2, AsQuery<2>, V3);
1370        dispatch_map!(4, AsQuery<4>, V3);
1371        dispatch_map!(8, AsQuery<8>, V3);
1372
1373        // V4
1374        dispatch_map!(1, AsFull, V4, downcast_to_v3);
1375        dispatch_map!(2, AsFull, V4, downcast_to_v3);
1376        dispatch_map!(4, AsFull, V4, downcast_to_v3);
1377        dispatch_map!(8, AsFull, V4, downcast_to_v3);
1378
1379        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
1380        dispatch_map!(2, AsData<2>, V4); // specialized
1381        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
1382        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);
1383
1384        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
1385        dispatch_map!(2, AsQuery<2>, V4); // specialized
1386        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
1387        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
1388    } else if #[cfg(target_arch = "aarch64")] {
1389        fn downcast(arch: Neon) -> Scalar {
1390            arch.retarget()
1391        }
1392
1393        dispatch_map!(1, AsFull, Neon, downcast);
1394        dispatch_map!(2, AsFull, Neon, downcast);
1395        dispatch_map!(4, AsFull, Neon, downcast);
1396        dispatch_map!(8, AsFull, Neon, downcast);
1397
1398        dispatch_map!(1, AsData<1>, Neon, downcast);
1399        dispatch_map!(2, AsData<2>, Neon, downcast);
1400        dispatch_map!(4, AsData<4>, Neon, downcast);
1401        dispatch_map!(8, AsData<8>, Neon, downcast);
1402
1403        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Neon, downcast);
1404        dispatch_map!(2, AsQuery<2>, Neon, downcast);
1405        dispatch_map!(4, AsQuery<4>, Neon, downcast);
1406        dispatch_map!(8, AsQuery<8>, Neon, downcast);
1407    }
1408}
1409
/// This struct and the [`BuildComputer`] trait are the glue for pre-dispatching
/// micro-architecture compatibility of distance computers.
///
/// This struct is parameterized by
///
/// * `Q`: The target query type.
/// * `N`: The number of data bits to target.
///
/// It implements the [`diskann_wide::arch::Target2`] and [`diskann_wide::arch::Target3`]
/// traits to do the architecture-dispatching, relying on the
/// `SphericalQuantizer<A>: BuildComputer<M, Q, N>` implementations for the actual
/// construction.
#[derive(Debug, Clone, Copy)]
struct ComputerDispatcher<Q, const N: usize> {
    /// Zero-sized marker recording the query type `Q`.
    _query_type: std::marker::PhantomData<Q>,
}

impl<Q, const N: usize> ComputerDispatcher<Q, N> {
    /// Create a new (zero-sized) dispatcher.
    fn new() -> Self {
        Self {
            _query_type: std::marker::PhantomData,
        }
    }
}
1433
1434impl<M, const N: usize, A, B, Q>
1435    diskann_wide::arch::Target2<
1436        M,
1437        Result<DistanceComputer<B>, AllocatorError>,
1438        &SphericalQuantizer<A>,
1439        B,
1440    > for ComputerDispatcher<Q, N>
1441where
1442    M: Architecture,
1443    A: Allocator,
1444    B: AllocatorCore,
1445    Q: FromOpaque,
1446    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1447{
1448    fn run(
1449        self,
1450        arch: M,
1451        quantizer: &SphericalQuantizer<A>,
1452        allocator: B,
1453    ) -> Result<DistanceComputer<B>, AllocatorError> {
1454        quantizer.build_computer(arch, allocator)
1455    }
1456}
1457
1458impl<M, const N: usize, A, R, B, Q>
1459    diskann_wide::arch::Target3<
1460        M,
1461        Result<QueryComputer<B>, AllocatorError>,
1462        &SphericalQuantizer<A>,
1463        R,
1464        B,
1465    > for ComputerDispatcher<Q, N>
1466where
1467    M: Architecture,
1468    A: Allocator,
1469    B: AllocatorCore,
1470    Q: FromOpaque,
1471    R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1472    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1473{
1474    fn run(
1475        self,
1476        arch: M,
1477        quantizer: &SphericalQuantizer<A>,
1478        query: R,
1479        allocator: B,
1480    ) -> Result<QueryComputer<B>, AllocatorError> {
1481        quantizer.build_fused_computer(arch, query, allocator)
1482    }
1483}
1484
1485#[cfg(target_arch = "x86_64")]
1486trait Dispatchable<Q, const N: usize>:
1487    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
1488where
1489    Q: FromOpaque,
1490{
1491}
1492
1493#[cfg(target_arch = "x86_64")]
1494impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1495where
1496    Q: FromOpaque,
1497    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
1498{
1499}
1500
1501#[cfg(target_arch = "aarch64")]
1502trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>
1503where
1504    Q: FromOpaque,
1505{
1506}
1507
1508#[cfg(target_arch = "aarch64")]
1509impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1510where
1511    Q: FromOpaque,
1512    T: BuildComputer<Scalar, Q, N> + BuildComputer<Neon, Q, N>,
1513{
1514}
1515
1516#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1517trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
1518where
1519    Q: FromOpaque,
1520{
1521}
1522
1523#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
1524impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1525where
1526    Q: FromOpaque,
1527    T: BuildComputer<Scalar, Q, N>,
1528{
1529}
1530
1531//---------------------------//
1532// Quantizer Implementations //
1533//---------------------------//
1534
1535impl<A, B> Quantizer<B> for Impl<1, A>
1536where
1537    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1538    B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1539{
1540    fn nbits(&self) -> usize {
1541        1
1542    }
1543
1544    fn dim(&self) -> usize {
1545        self.quantizer.output_dim()
1546    }
1547
1548    fn full_dim(&self) -> usize {
1549        self.quantizer.input_dim()
1550    }
1551
1552    fn bytes(&self) -> usize {
1553        DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
1554    }
1555
1556    fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
1557        self.query_computer::<AsData<1>, _>(allocator)
1558    }
1559
1560    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1561        &*self.distance
1562    }
1563
1564    fn query_computer(
1565        &self,
1566        layout: QueryLayout,
1567        allocator: B,
1568    ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1569        match layout {
1570            QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
1571            QueryLayout::FourBitTransposed => {
1572                Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
1573            }
1574            QueryLayout::ScalarQuantized => {
1575                Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1576            }
1577            QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1578        }
1579    }
1580
1581    fn query_buffer_description(
1582        &self,
1583        layout: QueryLayout,
1584    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
1585        let dim = <Self as Quantizer<B>>::dim(self);
1586        match layout {
1587            QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1588                DataRef::<1>::canonical_bytes(dim),
1589                PowerOfTwo::alignment_of::<u8>(),
1590            )),
1591            QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
1592                QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
1593                PowerOfTwo::alignment_of::<u8>(),
1594            )),
1595            QueryLayout::ScalarQuantized => {
1596                Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
1597            }
1598            QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1599                FullQueryRef::canonical_bytes(dim),
1600                FullQueryRef::canonical_align(),
1601            )),
1602        }
1603    }
1604
1605    fn compress_query(
1606        &self,
1607        x: &[f32],
1608        layout: QueryLayout,
1609        allow_rescale: bool,
1610        mut buffer: OpaqueMut<'_>,
1611        scratch: ScopedAllocator<'_>,
1612    ) -> Result<(), QueryCompressionError> {
1613        let dim = <Self as Quantizer<B>>::dim(self);
1614        let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1615            match layout {
1616                QueryLayout::SameAsData => self.compress_query(
1617                    v,
1618                    DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
1619                        .map_err(NotCanonical::new)?,
1620                    scratch,
1621                ),
1622                QueryLayout::FourBitTransposed => self.compress_query(
1623                    v,
1624                    QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
1625                        .map_err(NotCanonical::new)?,
1626                    scratch,
1627                ),
1628                QueryLayout::ScalarQuantized => {
1629                    Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1630                }
1631                QueryLayout::FullPrecision => self.compress_query(
1632                    v,
1633                    FullQueryMut::from_canonical_mut(&mut buffer, dim)
1634                        .map_err(NotCanonical::new)?,
1635                    scratch,
1636                ),
1637            }
1638        };
1639
1640        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1641            let mut copy = x.to_owned();
1642            self.quantizer.rescale(&mut copy);
1643            finish(&copy)
1644        } else {
1645            finish(x)
1646        }
1647    }
1648
1649    fn fused_query_computer(
1650        &self,
1651        x: &[f32],
1652        layout: QueryLayout,
1653        allow_rescale: bool,
1654        allocator: B,
1655        scratch: ScopedAllocator<'_>,
1656    ) -> Result<QueryComputer<B>, QueryComputerError> {
1657        let dim = <Self as Quantizer<B>>::dim(self);
1658        let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
1659            match layout {
1660                    QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
1661                        v,
1662                        Data::new_in(dim, allocator.clone())?,
1663                        allocator,
1664                        scratch,
1665                    ),
1666                    QueryLayout::FourBitTransposed => self
1667                        .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
1668                            v,
1669                            Query::new_in(dim, allocator.clone())?,
1670                            allocator,
1671                            scratch,
1672                        ),
1673                    QueryLayout::ScalarQuantized => {
1674                        Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
1675                    }
1676                    QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
1677                        v,
1678                        FullQuery::empty(dim, allocator.clone())?,
1679                        allocator,
1680                        scratch,
1681                    ),
1682                }
1683        };
1684
1685        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1686            let mut copy = x.to_owned();
1687            self.quantizer.rescale(&mut copy);
1688            finish(&copy, allocator)
1689        } else {
1690            finish(x, allocator)
1691        }
1692    }
1693
1694    fn is_supported(&self, layout: QueryLayout) -> bool {
1695        Self::supports(layout)
1696    }
1697
1698    fn compress(
1699        &self,
1700        x: &[f32],
1701        mut into: OpaqueMut<'_>,
1702        scratch: ScopedAllocator<'_>,
1703    ) -> Result<(), CompressionError> {
1704        let dim = <Self as Quantizer<B>>::dim(self);
1705        let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
1706            .map_err(CompressionError::not_canonical)?;
1707        self.quantizer
1708            .compress_into_with(x, into, scratch)
1709            .map_err(CompressionError::CompressionError)
1710    }
1711
1712    fn metric(&self) -> SupportedMetric {
1713        self.quantizer.metric()
1714    }
1715
1716    fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
1717        let clone = (*self).try_clone()?;
1718        poly!({ Quantizer<B> }, clone, allocator)
1719    }
1720
1721    #[cfg(feature = "flatbuffers")]
1722    fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
1723        Impl::<1, A>::serialize(self, allocator)
1724    }
1725}
1726
1727macro_rules! plan {
1728    ($N:literal) => {
1729        impl<A, B> Quantizer<B> for Impl<$N, A>
1730        where
1731            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1732            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1733        {
1734            fn nbits(&self) -> usize {
1735                $N
1736            }
1737
1738            fn dim(&self) -> usize {
1739                self.quantizer.output_dim()
1740            }
1741
1742            fn full_dim(&self) -> usize {
1743                self.quantizer.input_dim()
1744            }
1745
1746            fn bytes(&self) -> usize {
1747                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
1748            }
1749
1750            fn distance_computer(
1751                &self,
1752                allocator: B
1753            ) -> Result<DistanceComputer<B>, AllocatorError> {
1754                self.query_computer::<AsData<$N>, _>(allocator)
1755            }
1756
1757            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1758                &*self.distance
1759            }
1760
1761            fn query_computer(
1762                &self,
1763                layout: QueryLayout,
1764                allocator: B,
1765            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1766                match layout {
1767                    QueryLayout::SameAsData => Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
1768                    ,
1769                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
1770                        layout,
1771                        concat!($N, "-bit compression"),
1772                    ).into()),
1773                    QueryLayout::ScalarQuantized => {
1774                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
1775                    },
1776                    QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1777
1778                }
1779            }
1780
1781            fn query_buffer_description(
1782                &self,
1783                layout: QueryLayout
1784            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>
1785            {
1786                let dim = <Self as Quantizer<B>>::dim(self);
1787                match layout {
1788                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1789                        DataRef::<$N>::canonical_bytes(dim),
1790                        PowerOfTwo::alignment_of::<u8>(),
1791                    )),
1792                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout {
1793                        layout,
1794                        desc: concat!($N, "-bit compression"),
1795                    }),
1796                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
1797                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
1798                        PowerOfTwo::alignment_of::<u8>(),
1799                    )),
1800                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1801                        FullQueryRef::canonical_bytes(dim),
1802                        FullQueryRef::canonical_align(),
1803                    )),
1804                }
1805            }
1806
1807            fn compress_query(
1808                &self,
1809                x: &[f32],
1810                layout: QueryLayout,
1811                allow_rescale: bool,
1812                mut buffer: OpaqueMut<'_>,
1813                scratch: ScopedAllocator<'_>,
1814            ) -> Result<(), QueryCompressionError> {
1815                let dim = <Self as Quantizer<B>>::dim(self);
1816                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1817                    match layout {
1818                        QueryLayout::SameAsData => self.compress_query(
1819                            v,
1820                            DataMut::<$N>::from_canonical_back_mut(
1821                                &mut buffer,
1822                                dim,
1823                            ).map_err(NotCanonical::new)?,
1824                            scratch,
1825                        ),
1826                        QueryLayout::FourBitTransposed => {
1827                            Err(UnsupportedQueryLayout::new(
1828                                layout,
1829                                concat!($N, "-bit compression"),
1830                            ).into())
1831                        },
1832                        QueryLayout::ScalarQuantized => self.compress_query(
1833                            v,
1834                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
1835                                &mut buffer,
1836                                dim,
1837                            ).map_err(NotCanonical::new)?,
1838                            scratch,
1839                        ),
1840                        QueryLayout::FullPrecision => self.compress_query(
1841                            v,
1842                            FullQueryMut::from_canonical_mut(
1843                                &mut buffer,
1844                                dim,
1845                            ).map_err(NotCanonical::new)?,
1846                            scratch,
1847                        ),
1848                    }
1849                };
1850
1851                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1852                    let mut copy = x.to_owned();
1853                    self.quantizer.rescale(&mut copy);
1854                    finish(&copy)
1855                } else {
1856                    finish(x)
1857                }
1858            }
1859
1860            fn fused_query_computer(
1861                &self,
1862                x: &[f32],
1863                layout: QueryLayout,
1864                allow_rescale: bool,
1865                allocator: B,
1866                scratch: ScopedAllocator<'_>,
1867            ) -> Result<QueryComputer<B>, QueryComputerError>
1868            {
1869                let dim = <Self as Quantizer<B>>::dim(self);
1870                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
1871                    match layout {
1872                        QueryLayout::SameAsData => {
1873                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
1874                                v,
1875                                Data::new_in(dim, allocator.clone())?,
1876                                allocator,
1877                                scratch,
1878                            )
1879                        },
1880                        QueryLayout::FourBitTransposed => {
1881                            Err(UnsupportedQueryLayout::new(
1882                                layout,
1883                                concat!($N, "-bit compression"),
1884                            ).into())
1885                        },
1886                        QueryLayout::ScalarQuantized => {
1887                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
1888                                v,
1889                                Query::new_in(dim, allocator.clone())?,
1890                                allocator,
1891                                scratch,
1892                            )
1893                        },
1894                        QueryLayout::FullPrecision => {
1895                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
1896                                v,
1897                                FullQuery::empty(dim, allocator.clone())?,
1898                                allocator,
1899                                scratch,
1900                            )
1901                        },
1902                    }
1903                };
1904
1905                let metric = <Self as Quantizer<B>>::metric(self);
1906                if allow_rescale && metric == SupportedMetric::InnerProduct {
1907                    let mut copy = x.to_owned();
1908                    self.quantizer.rescale(&mut copy);
1909                    finish(&copy)
1910                } else {
1911                    finish(x)
1912                }
1913            }
1914
            /// Report whether `layout` is usable with this quantizer by forwarding to the
            /// associated `Self::supports` function.
            fn is_supported(&self, layout: QueryLayout) -> bool {
                Self::supports(layout)
            }
1918
1919            fn compress(
1920                &self,
1921                x: &[f32],
1922                mut into: OpaqueMut<'_>,
1923                scratch: ScopedAllocator<'_>,
1924            ) -> Result<(), CompressionError> {
1925                let dim = <Self as Quantizer<B>>::dim(self);
1926                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
1927                    .map_err(CompressionError::not_canonical)?;
1928
1929                self.quantizer.compress_into_with(x, into, scratch)
1930                    .map_err(CompressionError::CompressionError)
1931            }
1932
            /// Return the distance metric configured on the wrapped spherical quantizer.
            fn metric(&self) -> SupportedMetric {
                self.quantizer.metric()
            }
1936
            /// Clone this quantizer and store the clone, as a `dyn Quantizer<B>`, in a `Poly`
            /// allocated from `allocator`.
            ///
            /// Failures from `try_clone` or from the `Poly` allocation are both reported as
            /// `AllocatorError`.
            fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
                // NOTE(review): `&*self` is an explicit reborrow, presumably to select the
                // `try_clone` implementation for `&Impl` — confirm.
                let clone = (&*self).try_clone()?;
                poly!({ Quantizer<B> }, clone, allocator)
            }
1941
            /// Serialize this quantizer into a byte buffer allocated from `allocator`.
            ///
            /// This is the inverse of `try_deserialize` and is only available with the
            /// `flatbuffers` feature. Delegates to the `Impl::serialize` associated function.
            #[cfg(feature = "flatbuffers")]
            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
                Impl::<$N, A>::serialize(self, allocator)
            }
1946        }
1947    };
1948    ($N:literal, $($Ns:literal),*) => {
1949        plan!($N);
1950        $(plan!($Ns);)*
1951    }
1952}
1953
// Generate the `Quantizer` implementations for the 2-, 4-, and 8-bit `Impl`s.
// The 1-bit `Impl` is not instantiated by this invocation.
plan!(2, 4, 8);
1955
1956////////////////
1957// Flatbuffer //
1958////////////////
1959
/// Errors returned by [`try_deserialize`] when rebuilding a [`Quantizer`] from a
/// serialized flatbuffer.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer does not carry the spherical quantizer file identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The serialized quantizer records a bit width with no concrete implementation.
    ///
    /// The payload is the bit width read from the buffer.
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// The embedded spherical quantizer could not be unpacked or constructed.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The buffer could not be read as a quantizer flatbuffer root.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// Allocating the deserialized quantizer failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1980
/// Rebuild a concrete [`Quantizer`] implementation from a serialized
/// `spherical::Quantizer` flatbuffer, choosing the implementation by the bit width stored
/// in the buffer.
///
/// The returned `Poly` is guaranteed to be the first object allocated through `alloc`.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if `data` lacks the quantizer file
///   identifier.
/// * [`DeserializationError::InvalidFlatBuffer`] if `data` cannot be read as a quantizer
///   root.
/// * [`DeserializationError::UnsupportedBitWidth`] if the stored bit width is not 1, 2, 4,
///   or 8.
/// * [`DeserializationError::InvalidQuantizer`] if the embedded quantizer cannot be
///   reconstructed.
/// * [`DeserializationError::AllocatorError`] if an allocation through `alloc` fails.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    // Reject foreign buffers before any parsing or allocation takes place.
    if !fb::spherical::quantizer_buffer_has_identifier(data) {
        return Err(DeserializationError::InvalidIdentifier);
    }

    // Build the `NBITS`-bit implementation directly inside its `Poly`. The `Poly` is
    // allocated before any member of the quantizer, so with a bump allocator the root
    // object comes first.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let imp = Poly::new_with(
            #[inline(never)]
            |alloc| -> Result<_, super::quantizer::DeserializationError> {
                let quantizer = SphericalQuantizer::try_unpack(alloc, proto)?;
                Ok(Impl::new(quantizer)?)
            },
            alloc,
        )
        .map_err(|err| match err {
            CompoundError::Allocator(err) => DeserializationError::from(err),
            CompoundError::Constructor(err) => DeserializationError::from(err),
        })?;

        Ok(poly!({ Quantizer<O> }, imp))
    }

    // Everything up to the bit-width dispatch happens without allocating.
    let root = fb::spherical::root_as_quantizer(data)?;
    let nbits = root.nbits();
    let proto = root.quantizer();

    match nbits {
        1 => unpack_bits::<1, _, _>(proto, alloc),
        2 => unpack_bits::<2, _, _>(proto, alloc),
        4 => unpack_bits::<4, _, _>(proto, alloc),
        8 => unpack_bits::<8, _, _>(proto, alloc),
        n => Err(DeserializationError::UnsupportedBitWidth(n)),
    }
}
2048
2049///////////
2050// Tests //
2051///////////
2052
2053#[cfg(test)]
2054mod tests {
2055    use diskann_utils::views::{Matrix, MatrixView};
2056    use rand::{SeedableRng, rngs::StdRng};
2057
2058    use super::*;
2059    use crate::{
2060        algorithms::{TransformKind, transforms::TargetDim},
2061        alloc::{AlignedAllocator, GlobalAllocator, Poly},
2062        num::PowerOfTwo,
2063        spherical::PreScale,
2064    };
2065
2066    ////////////////////
2067    // Test Quantizer //
2068    ////////////////////
2069
2070    fn test_plan_1_bit(plan: &dyn Quantizer) {
2071        assert_eq!(
2072            plan.nbits(),
2073            1,
2074            "this test only applies to 1-bit quantization"
2075        );
2076
2077        // Check Layouts.
2078        for layout in QueryLayout::all() {
2079            match layout {
2080                QueryLayout::SameAsData
2081                | QueryLayout::FourBitTransposed
2082                | QueryLayout::FullPrecision => assert!(
2083                    plan.is_supported(layout),
2084                    "expected {} to be supported",
2085                    layout
2086                ),
2087                QueryLayout::ScalarQuantized => assert!(
2088                    !plan.is_supported(layout),
2089                    "expected {} to not be supported",
2090                    layout
2091                ),
2092            }
2093        }
2094    }
2095
2096    fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
2097        assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
2098        assert_eq!(
2099            plan.nbits(),
2100            nbits,
2101            "this test only applies to 1-bit quantization"
2102        );
2103
2104        // Check Layouts.
2105        for layout in QueryLayout::all() {
2106            match layout {
2107                QueryLayout::SameAsData
2108                | QueryLayout::ScalarQuantized
2109                | QueryLayout::FullPrecision => assert!(
2110                    plan.is_supported(layout),
2111                    "expected {} to be supported",
2112                    layout
2113                ),
2114                QueryLayout::FourBitTransposed => assert!(
2115                    !plan.is_supported(layout),
2116                    "expected {} to not be supported",
2117                    layout
2118                ),
2119            }
2120        }
2121    }
2122
2123    #[inline(never)]
2124    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
2125        // Perform the bit-specific test.
2126        if nbits == 1 {
2127            test_plan_1_bit(plan);
2128        } else {
2129            test_plan_n_bit(plan, nbits);
2130        }
2131
2132        // Run bit-width agnostic tests.
2133        assert_eq!(plan.full_dim(), dataset.ncols());
2134
2135        // Use the correct alignment for the base pointers.
2136        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
2137        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
2138        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
2139        let scoped_global = ScopedAllocator::global();
2140
2141        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
2142            .unwrap();
2143        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
2144            .unwrap();
2145
2146        let f = plan.distance_computer(GlobalAllocator).unwrap();
2147        let _: f32 = f
2148            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
2149            .unwrap();
2150
2151        let test_errors = |f: &dyn DynDistanceComputer| {
2152            // `a` too short
2153            let err = f
2154                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
2155                .unwrap_err();
2156            assert!(matches!(err, DistanceError::QueryReify(_)));
2157
2158            // `a` too long
2159            let err = f
2160                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
2161                .unwrap_err();
2162            assert!(matches!(err, DistanceError::QueryReify(_)));
2163
2164            // `b` too short
2165            let err = f
2166                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
2167                .unwrap_err();
2168            assert!(matches!(err, DistanceError::XReify(_)));
2169
2170            // `a` too long
2171            let err = f
2172                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
2173                .unwrap_err();
2174            assert!(matches!(err, DistanceError::XReify(_)));
2175        };
2176
2177        test_errors(&*f.inner);
2178
2179        let f = plan.distance_computer_ref();
2180        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
2181        test_errors(f);
2182
2183        // Test all supported flavors of `QueryComputer`.
2184        for layout in QueryLayout::all() {
2185            if !plan.is_supported(layout) {
2186                let check_message = |msg: &str| {
2187                    assert!(
2188                        msg.contains(&(layout.to_string())),
2189                        "error message ({}) should contain the layout \"{}\"",
2190                        msg,
2191                        layout
2192                    );
2193                    assert!(
2194                        msg.contains(&format!("{}", nbits)),
2195                        "error message ({}) should contain the number of bits \"{}\"",
2196                        msg,
2197                        nbits
2198                    );
2199                };
2200
2201                // Error for query computer
2202                {
2203                    let err = plan
2204                        .fused_query_computer(
2205                            dataset.row(1),
2206                            layout,
2207                            false,
2208                            GlobalAllocator,
2209                            scoped_global,
2210                        )
2211                        .unwrap_err();
2212
2213                    let msg = err.to_string();
2214                    check_message(&msg);
2215                }
2216
2217                // Query buffer
2218                {
2219                    let err = plan.query_buffer_description(layout).unwrap_err();
2220                    let msg = err.to_string();
2221                    check_message(&msg);
2222                }
2223
2224                // Compresss Query Into
2225                {
2226                    let buffer = &mut [];
2227                    let err = plan
2228                        .compress_query(
2229                            dataset.row(1),
2230                            layout,
2231                            true,
2232                            OpaqueMut::new(buffer),
2233                            scoped_global,
2234                        )
2235                        .unwrap_err();
2236                    let msg = err.to_string();
2237                    check_message(&msg);
2238                }
2239
2240                // Standalone Query Computer
2241                {
2242                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
2243                    let msg = err.to_string();
2244                    check_message(&msg);
2245                }
2246
2247                continue;
2248            }
2249
2250            let g = plan
2251                .fused_query_computer(
2252                    dataset.row(1),
2253                    layout,
2254                    false,
2255                    GlobalAllocator,
2256                    scoped_global,
2257                )
2258                .unwrap();
2259            assert_eq!(
2260                g.layout(),
2261                layout,
2262                "the query computer should faithfully preserve the requested layout"
2263            );
2264
2265            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();
2266
2267            // Check that the fused computer correctly returns errors for invalid inputs.
2268            {
2269                let err = g
2270                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
2271                    .unwrap_err();
2272                assert!(matches!(err, QueryDistanceError::XReify(_)));
2273
2274                let err = g
2275                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
2276                    .unwrap_err();
2277                assert!(matches!(err, QueryDistanceError::XReify(_)));
2278            }
2279
2280            let sizes = plan.query_buffer_description(layout).unwrap();
2281            let mut buf =
2282                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();
2283
2284            plan.compress_query(
2285                dataset.row(1),
2286                layout,
2287                false,
2288                OpaqueMut::new(&mut buf),
2289                scoped_global,
2290            )
2291            .unwrap();
2292
2293            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();
2294
2295            assert_eq!(
2296                standalone.layout(),
2297                layout,
2298                "the standalone computer did not preserve the requested layout",
2299            );
2300
2301            let indirect: f32 = standalone
2302                .evaluate_similarity(Opaque(&buf), Opaque(&a))
2303                .unwrap();
2304
2305            assert_eq!(
2306                direct, indirect,
2307                "the two different query computation APIs did not return the same result"
2308            );
2309
2310            // Errors
2311            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
2312            assert!(
2313                plan.fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
2314                    .is_err()
2315            );
2316        }
2317
2318        // Errors
2319        {
2320            let mut too_small = vec![u8::default(); plan.bytes() - 1];
2321            assert!(
2322                plan.compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
2323                    .is_err()
2324            );
2325
2326            let mut too_big = vec![u8::default(); plan.bytes() + 1];
2327            assert!(
2328                plan.compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
2329                    .is_err()
2330            );
2331
2332            let mut just_right = vec![u8::default(); plan.bytes()];
2333            assert!(
2334                plan.compress(
2335                    &dataset.row(0)[..dataset.ncols() - 1],
2336                    OpaqueMut(&mut just_right),
2337                    scoped_global
2338                )
2339                .is_err()
2340            );
2341        }
2342    }
2343
2344    fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
2345    where
2346        Impl<NBITS>: Constructible,
2347    {
2348        let data = test_dataset();
2349        let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);
2350
2351        let quantizer = SphericalQuantizer::train(
2352            data.as_view(),
2353            TransformKind::PaddingHadamard {
2354                target_dim: TargetDim::Natural,
2355            },
2356            metric,
2357            PreScale::None,
2358            &mut rng,
2359            GlobalAllocator,
2360        )
2361        .unwrap();
2362
2363        (Impl::<NBITS>::new(quantizer).unwrap(), data)
2364    }
2365
2366    #[test]
2367    fn test_plan_1bit_l2() {
2368        let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2369        test_plan(&plan, 1, data.as_view());
2370    }
2371
2372    #[test]
2373    fn test_plan_1bit_ip() {
2374        let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2375        test_plan(&plan, 1, data.as_view());
2376    }
2377
2378    #[test]
2379    fn test_plan_1bit_cosine() {
2380        let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
2381        test_plan(&plan, 1, data.as_view());
2382    }
2383
2384    #[test]
2385    fn test_plan_2bit_l2() {
2386        let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2387        test_plan(&plan, 2, data.as_view());
2388    }
2389
2390    #[test]
2391    fn test_plan_2bit_ip() {
2392        let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2393        test_plan(&plan, 2, data.as_view());
2394    }
2395
2396    #[test]
2397    fn test_plan_2bit_cosine() {
2398        let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
2399        test_plan(&plan, 2, data.as_view());
2400    }
2401
2402    #[test]
2403    fn test_plan_4bit_l2() {
2404        let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2405        test_plan(&plan, 4, data.as_view());
2406    }
2407
2408    #[test]
2409    fn test_plan_4bit_ip() {
2410        let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2411        test_plan(&plan, 4, data.as_view());
2412    }
2413
2414    #[test]
2415    fn test_plan_4bit_cosine() {
2416        let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
2417        test_plan(&plan, 4, data.as_view());
2418    }
2419
2420    #[test]
2421    fn test_plan_8bit_l2() {
2422        let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2423        test_plan(&plan, 8, data.as_view());
2424    }
2425
2426    #[test]
2427    fn test_plan_8bit_ip() {
2428        let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2429        test_plan(&plan, 8, data.as_view());
2430    }
2431
2432    #[test]
2433    fn test_plan_8bit_cosine() {
2434        let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
2435        test_plan(&plan, 8, data.as_view());
2436    }
2437
2438    fn test_dataset() -> Matrix<f32> {
2439        let data = vec![
2440            0.28657,
2441            -0.0318168,
2442            0.0666847,
2443            0.0329265,
2444            -0.00829283,
2445            0.168735,
2446            -0.000846311,
2447            -0.360779, // row 0
2448            -0.0968938,
2449            0.161921,
2450            -0.0979579,
2451            0.102228,
2452            -0.259928,
2453            -0.139634,
2454            0.165384,
2455            -0.293443, // row 1
2456            0.130205,
2457            0.265737,
2458            0.401816,
2459            -0.407552,
2460            0.13012,
2461            -0.0475244,
2462            0.511723,
2463            -0.4372, // row 2
2464            -0.0979126,
2465            0.135861,
2466            -0.0154144,
2467            -0.14047,
2468            -0.0250029,
2469            -0.190279,
2470            0.407283,
2471            -0.389184, // row 3
2472            -0.264153,
2473            0.0696822,
2474            -0.145585,
2475            0.370284,
2476            0.186825,
2477            -0.140736,
2478            0.274703,
2479            -0.334563, // row 4
2480            0.247613,
2481            0.513165,
2482            -0.0845867,
2483            0.0532264,
2484            -0.00480601,
2485            -0.122408,
2486            0.47227,
2487            -0.268301, // row 5
2488            0.103198,
2489            0.30756,
2490            -0.316293,
2491            -0.0686877,
2492            -0.330729,
2493            -0.461997,
2494            0.550857,
2495            -0.240851, // row 6
2496            0.128258,
2497            0.786291,
2498            -0.0268103,
2499            0.111763,
2500            -0.308962,
2501            -0.17407,
2502            0.437154,
2503            -0.159879, // row 7
2504            0.00374063,
2505            0.490301,
2506            0.0327826,
2507            -0.0340962,
2508            -0.118605,
2509            0.163879,
2510            0.2737,
2511            -0.299942, // row 8
2512            -0.284077,
2513            0.249377,
2514            -0.0307734,
2515            -0.0661631,
2516            0.233854,
2517            0.427987,
2518            0.614132,
2519            -0.288649, // row 9
2520            -0.109492,
2521            0.203939,
2522            -0.73956,
2523            -0.130748,
2524            0.22072,
2525            0.0647836,
2526            0.328726,
2527            -0.374602, // row 10
2528            -0.223114,
2529            0.0243489,
2530            0.109195,
2531            -0.416914,
2532            0.0201052,
2533            -0.0190542,
2534            0.947078,
2535            -0.333229, // row 11
2536            -0.165869,
2537            -0.00296729,
2538            -0.414378,
2539            0.231321,
2540            0.205365,
2541            0.161761,
2542            0.148608,
2543            -0.395063, // row 12
2544            -0.0498255,
2545            0.193279,
2546            -0.110946,
2547            -0.181174,
2548            -0.274578,
2549            -0.227511,
2550            0.190208,
2551            -0.256174, // row 13
2552            -0.188106,
2553            -0.0292958,
2554            0.0930939,
2555            0.0558456,
2556            0.257437,
2557            0.685481,
2558            0.307922,
2559            -0.320006, // row 14
2560            0.250035,
2561            0.275942,
2562            -0.0856306,
2563            -0.352027,
2564            -0.103509,
2565            -0.00890859,
2566            0.276121,
2567            -0.324718, // row 15
2568        ];
2569
2570        Matrix::try_from(data.into(), 16, 8).unwrap()
2571    }
2572
2573    #[cfg(feature = "flatbuffers")]
2574    mod serialization {
2575        use std::sync::{
2576            Arc,
2577            atomic::{AtomicBool, Ordering},
2578        };
2579
2580        use super::*;
2581        use crate::alloc::{BumpAllocator, GlobalAllocator};
2582
2583        #[inline(never)]
2584        fn test_plan_serialization(
2585            quantizer: &dyn Quantizer,
2586            nbits: usize,
2587            dataset: MatrixView<f32>,
2588        ) {
2589            // Run bit-width agnostic tests.
2590            assert_eq!(quantizer.full_dim(), dataset.ncols());
2591            let scoped_global = ScopedAllocator::global();
2592
2593            let serialized = quantizer.serialize(GlobalAllocator).unwrap();
2594            let deserialized =
2595                try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();
2596
2597            assert_eq!(deserialized.nbits(), nbits);
2598            assert_eq!(deserialized.bytes(), quantizer.bytes());
2599            assert_eq!(deserialized.dim(), quantizer.dim());
2600            assert_eq!(deserialized.full_dim(), quantizer.full_dim());
2601            assert_eq!(deserialized.metric(), quantizer.metric());
2602
2603            for layout in QueryLayout::all() {
2604                assert_eq!(
2605                    deserialized.is_supported(layout),
2606                    quantizer.is_supported(layout)
2607                );
2608            }
2609
2610            // Use the correct alignment for the base pointers.
2611            let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
2612            {
2613                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2614                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2615
2616                for row in dataset.row_iter() {
2617                    quantizer
2618                        .compress(row, OpaqueMut::new(&mut a), scoped_global)
2619                        .unwrap();
2620                    deserialized
2621                        .compress(row, OpaqueMut::new(&mut b), scoped_global)
2622                        .unwrap();
2623
2624                    // Compressed representation should be identical.
2625                    assert_eq!(a, b);
2626                }
2627            }
2628
2629            // Distance Computer
2630            {
2631                let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2632                let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2633                let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2634                let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2635
2636                let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
2637                let q_computer_ref = quantizer.distance_computer_ref();
2638                let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
2639                let d_computer_ref = deserialized.distance_computer_ref();
2640
2641                for r0 in dataset.row_iter() {
2642                    quantizer
2643                        .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
2644                        .unwrap();
2645                    deserialized
2646                        .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
2647                        .unwrap();
2648                    for r1 in dataset.row_iter() {
2649                        quantizer
2650                            .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
2651                            .unwrap();
2652                        deserialized
2653                            .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
2654                            .unwrap();
2655
2656                        let a0 = Opaque::new(&a0);
2657                        let a1 = Opaque::new(&a1);
2658
2659                        let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
2660                        let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();
2661
2662                        assert_eq!(q_computer_dist, d_computer_dist);
2663
2664                        let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();
2665
2666                        assert_eq!(q_computer_dist, q_computer_ref_dist);
2667
2668                        let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
2669                        assert_eq!(d_computer_dist, d_computer_ref_dist);
2670                    }
2671                }
2672            }
2673
2674            // Query Computer
2675            {
2676                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2677                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2678
2679                for layout in QueryLayout::all() {
2680                    if !quantizer.is_supported(layout) {
2681                        continue;
2682                    }
2683
2684                    for r in dataset.row_iter() {
2685                        let q_computer = quantizer
2686                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2687                            .unwrap();
2688                        let d_computer = deserialized
2689                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2690                            .unwrap();
2691
2692                        for u in dataset.row_iter() {
2693                            quantizer
2694                                .compress(u, OpaqueMut::new(&mut a), scoped_global)
2695                                .unwrap();
2696                            deserialized
2697                                .compress(u, OpaqueMut::new(&mut b), scoped_global)
2698                                .unwrap();
2699
2700                            assert_eq!(
2701                                q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
2702                                d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
2703                            );
2704                        }
2705                    }
2706                }
2707            }
2708        }
2709
2710        // An allocator that succeeds on its first allocation but fails on its second.
2711        #[derive(Debug, Clone)]
2712        struct FlakyAllocator {
2713            have_allocated: Arc<AtomicBool>,
2714        }
2715
2716        impl FlakyAllocator {
2717            fn new(have_allocated: Arc<AtomicBool>) -> Self {
2718                Self { have_allocated }
2719            }
2720        }
2721
2722        // SAFETY: This is a wrapper around GlobalAllocator that only succeeds once.
2723        unsafe impl AllocatorCore for FlakyAllocator {
2724            fn allocate(
2725                &self,
2726                layout: std::alloc::Layout,
2727            ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
2728                if self.have_allocated.swap(true, Ordering::Relaxed) {
2729                    Err(AllocatorError)
2730                } else {
2731                    GlobalAllocator.allocate(layout)
2732                }
2733            }
2734
2735            unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
2736                // SAFETY: Inherited from caller.
2737                unsafe { GlobalAllocator.deallocate(ptr, layout) }
2738            }
2739        }
2740
2741        fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
2742        where
2743            Impl<NBITS>: Quantizer,
2744        {
2745            // Ensure that we do not panic if reallocation returns an error.
2746            let have_allocated = Arc::new(AtomicBool::new(false));
2747            let _: AllocatorError = v
2748                .serialize(FlakyAllocator::new(have_allocated.clone()))
2749                .unwrap_err();
2750            assert!(have_allocated.load(Ordering::Relaxed));
2751        }
2752
/// Generates a `#[test]` named `$name` that builds an `Impl<$nbits>` for
/// `SupportedMetric::$metric`, then runs the allocation-failure check followed by
/// the serialization round-trip check on it.
///
/// `$nbits` is taken as a literal so the same token works both as the const
/// generic argument and as the plain bit-width argument.
macro_rules! plan_test {
    ($name:ident, $nbits:literal, $metric:ident) => {
        #[test]
        fn $name() {
            let (plan, data) = make_impl::<$nbits>(SupportedMetric::$metric);
            test_plan_panic_boundary(&plan);
            test_plan_serialization(&plan, $nbits, data.as_view());
        }
    };
}

plan_test!(test_plan_1bit_l2, 1, SquaredL2);
plan_test!(test_plan_1bit_ip, 1, InnerProduct);
plan_test!(test_plan_2bit_l2, 2, SquaredL2);
plan_test!(test_plan_2bit_ip, 2, InnerProduct);
plan_test!(test_plan_4bit_l2, 4, SquaredL2);
plan_test!(test_plan_4bit_ip, 4, InnerProduct);
plan_test!(test_plan_8bit_l2, 8, SquaredL2);
plan_test!(test_plan_8bit_ip, 8, InnerProduct);
2808
/// Deserializing into a freshly created bump allocator must place the returned
/// `Poly` box at the allocator's starting address, i.e. the box is allocated first.
#[test]
fn test_allocation_order() {
    let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
    let serialized = plan.serialize(GlobalAllocator).unwrap();

    // NOTE(review): arguments appear to be arena capacity (bytes) and alignment;
    // confirm against `BumpAllocator::new`.
    let alignment = PowerOfTwo::new(64).unwrap();
    let arena = BumpAllocator::new(8192, alignment).unwrap();

    let restored = try_deserialize::<GlobalAllocator, _>(&serialized, arena.clone()).unwrap();
    let box_start = Poly::as_ptr(&restored).cast::<u8>();
    assert_eq!(
        box_start,
        arena.as_ptr(),
        "expected the returned box to be allocated first",
    );
}
2823    }
2824}