Skip to main content

diskann_quantization/spherical/
iface.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! The main export of this module is the dyn compatible [`Quantizer`] trait, which provides
7//! a common interface for interacting with bit-width specific [`SphericalQuantizer`]s,
8//! compressing vectors, and computing distances between compressed vectors.
9//!
10//! This is offered as a convenience interface for interacting with the myriad of generics
11//! associated with the [`SphericalQuantizer`]. Better performance can be achieved by using
12//! the generic types directly if desired.
13//!
14//! The [`Quantizer`] uses the [`Opaque`] and [`OpaqueMut`] types for its compressed data
15//! representations. These are thin wrappers around raw byte slices.
16//!
//! Distance computation is performed using [`DistanceComputer`] and [`QueryComputer`].
18//!
19//! Concrete implementations of [`Quantizer`] are available via the generic struct [`Impl`].
20//!
21//! ## Compatibility Table
22//!
23//! Multiple [`QueryLayout`]s are supported when constructing a [`QueryComputer`], but not
24//! all layouts are supported for each back end. This table lists the valid instantiations
25//! of [`Impl`] (parameterized by data vector bit-width) and their supported query layouts.
26//!
27//! | Bits | Same As Data | Full Precision | Four-Bit Transposed | Scalar Quantized |
28//! |------|--------------|----------------|---------------------|------------------|
29//! |    1 |     Yes      |      Yes       |         Yes         |       No         |
30//! |    2 |     Yes      |      Yes       |          No         |       Yes        |
31//! |    4 |     Yes      |      Yes       |          No         |       Yes        |
32//! |    8 |     Yes      |      Yes       |          No         |       Yes        |
33//!
34//! # Example
35//!
36//! ```
37//! use diskann_quantization::{
38//!     alloc::{Poly, ScopedAllocator, AlignedAllocator, GlobalAllocator},
39//!     algorithms::TransformKind,
40//!     spherical::{iface, SupportedMetric, SphericalQuantizer, PreScale},
41//!     num::PowerOfTwo,
42//! };
43//! use diskann_utils::views::Matrix;
44//!
45//! // For illustration purposes, the dataset consists of just a single vector.
46//! let mut data = Matrix::new(1.0, 1, 4);
47//! let quantizer = SphericalQuantizer::train(
48//!     data.as_view(),
49//!     TransformKind::Null,
50//!     SupportedMetric::SquaredL2,
51//!     PreScale::None,
52//!     &mut rand::rng(),
53//!     GlobalAllocator
54//! ).unwrap();
55//!
56//! let quantizer: Box<dyn iface::Quantizer> = Box::new(
57//!     iface::Impl::<1>::new(quantizer).unwrap()
58//! );
59//!
60//! let alloc = AlignedAllocator::new(PowerOfTwo::new(1).unwrap());
61//! let mut buf = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
62//!
63//! quantizer.compress(
64//!     data.row(0),
65//!     iface::OpaqueMut::new(&mut buf),
66//!     ScopedAllocator::new(&alloc),
67//! ).unwrap();
68//!
69//! assert!(quantizer.is_supported(iface::QueryLayout::FullPrecision));
70//! assert!(!quantizer.is_supported(iface::QueryLayout::ScalarQuantized));
71//! ```
72
73/// # DevDocs
74///
/// This section provides developer documentation for the structure of the structs and
/// traits inside this file. The main goal of the [`DistanceComputer`] and [`QueryComputer`]
/// implementations is to introduce just a single level of indirection.
78///
79/// ## Distance Computer Philosophy
80///
/// The goal of this code (and, in large part, the reason for its somewhat spaghetti
/// nature) is to perform all of the following dispatches behind a **single** level of
/// dynamic dispatch:
///
/// * Number of bits.
/// * Query layout.
/// * Distance type (L2, inner product, cosine).
/// * Micro-architecture specific code generation.
///
/// This means we need to bake all of this information into a single type, facilitated
/// through a combination of the private `Reify` and `Curried` structs.
92///
93/// ## Anatomy of the [`DistanceComputer`]
94///
95/// To provide a single level of indirection for all distance function implementations,
96/// the [`DistanceComputer`] is a thin wrapper around the [`DynDistanceComputer`] trait.
97///
98/// Concrete implementations of this trait consist of a base distance function like
99/// [`CompensatedIP`] or [`CompensatedSquaredL2`]. Because the data is passed through the
100/// [`Opaque`] type, these distance functions are embedded inside a [`Reify`] which first
101/// converts the [`Opaque`] to the appropriate fully-typed object before calling the inner
102/// distance function.
103///
104/// This full typing is supplied by the private [`FromOpaque`] helper trait.
105///
/// When returned from the [`Quantizer::distance_computer`] or
/// [`Quantizer::distance_computer_ref`] methods, the resulting computer will be specialized
/// to work solely on data vectors compressed through [`Quantizer::compress`].
109///
110/// When returned from [`Quantizer::query_computer`], the expected type of the query will
111/// depend on the [`QueryLayout`] supplied to that method.
112///
113/// The method [`QueryComputer::layout`] is provided to inspect at run time the query layout
114/// the object is meant for.
115///
116/// If at all possible, the [`QueryComputer`] should be preferred as it removes the
117/// possibility of providing an incorrect query layout and will be slightly faster since
118/// it does not require argument reification.
119///
120/// ## Anatomy of the [`QueryComputer`]
121///
/// This is similar to the [`DistanceComputer`] but has the extra duty of supporting
/// multiple different [`QueryLayout`]s (compressions for the query). To that end,
/// the stack of types used to implement the underlying [`DynQueryComputer`] trait is:
///
/// * Base [`DistanceFunction`] (e.g. [`CompensatedIP`]).
///
/// * Embedded inside [`Curried`] - which also contains a heap-allocated representation of
///   the query using the selected layout. For example, this could be one of:
///
///   - [`diskann_quantization::spherical::Query`]
///   - [`diskann_quantization::spherical::FullQuery`]
///   - [`diskann_quantization::spherical::Data`]
134///
135/// * Embedded inside [`Reify`] to convert [`Opaque`] to the correct type.
136use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141    Architecture,
142    arch::{Scalar, Target1, Target2},
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151use super::{
152    CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef, FullQuery,
153    FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer, SupportedMetric,
154    quantizer,
155};
156use crate::{
157    AsFunctor, CompressIntoWith,
158    alloc::{
159        Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
160    },
161    bits::{self, Representation, Unsigned},
162    distances::{self, UnequalLengths},
163    error::InlineError,
164    meta,
165    num::PowerOfTwo,
166    poly,
167};
168#[cfg(feature = "flatbuffers")]
169use crate::{alloc::CompoundError, flatbuffers as fb};
170
171// A convenience definition to shorten the extensive where-clauses present in this file.
172type Rf32 = distances::Result<f32>;
173
174///////////////
175// Quantizer //
176///////////////
177
178/// A description of the buffer size (in bytes) and alignment required for a compressed query.
179#[derive(Debug, Clone)]
180pub struct QueryBufferDescription {
181    size: usize,
182    align: PowerOfTwo,
183}
184
185impl QueryBufferDescription {
186    /// Construct a new [`QueryBufferDescription`]
187    pub fn new(size: usize, align: PowerOfTwo) -> Self {
188        Self { size, align }
189    }
190
191    /// Return the number of bytes needed in a buffer for a compressed query.
192    pub fn bytes(&self) -> usize {
193        self.size
194    }
195
196    /// Return the necessary alignment of the base pointer for a query buffer.
197    pub fn align(&self) -> PowerOfTwo {
198        self.align
199    }
200}
201
/// A dyn-compatible trait providing a common interface for a bit-width specific
/// [`SphericalQuantizer`].
///
/// This allows us to have a single `dyn Quantizer` type without generics while still
/// supporting the range of bit-widths and query strategies we wish to support.
///
/// Unfortunately, a level of indirection is required for each distance computation to
/// support this. The code is structured so that there is only a single level of
/// indirection.
///
/// # Allocator
///
/// The quantizer is parameterized by an allocator type `A`, which is used to acquire any
/// memory needed by returned data structures. The contract is as follows:
///
/// 1. Any allocation made as part of a data structure returned from a function will be
///    performed through the allocator given to that function.
///
/// 2. If dynamic memory allocation for scratch space is required, a separate `scratch`
///    allocator will be required and all scratch space allocations will go through that
///    allocator.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// The effective number of bits in the encoding.
    fn nbits(&self) -> usize;

    /// The number of bytes occupied by each compressed vector.
    fn bytes(&self) -> usize;

    /// The effective dimensionality of each compressed vector.
    fn dim(&self) -> usize;

    /// The dimensionality of the full-precision input vectors.
    fn full_dim(&self) -> usize;

    /// Return a distance computer capable of operating on validly initialized [`Opaque`]
    /// slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// The query layout associated with this computer will always be
    /// [`QueryLayout::SameAsData`].
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;

    /// Return a scoped distance computer capable of operating on validly initialized
    /// [`Opaque`] slices of length [`Self::bytes`].
    ///
    /// The returned computer is borrowed from `self`; unlike [`Self::distance_computer`],
    /// no allocator is taken and no owned object is returned.
    ///
    /// These slices should be initialized by [`Self::compress`].
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;

    /// A stand-alone distance computer specialized for the specified query layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    ///
    /// # Note
    ///
    /// The returned object will **only** be compatible with queries compressed using
    /// [`Self::compress_query`] with the same layout. If possible, the API
    /// [`Self::fused_query_computer`] should be used to avoid this ambiguity.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;

    /// Return the number of bytes and the alignment of a buffer used to contain a
    /// compressed query with the provided layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;

    /// Compress the query `x` using the specified layout into `buffer`.
    ///
    /// `buffer` must have exactly the size and alignment returned from
    /// [`Self::query_buffer_description`] for the same layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;

    /// Return a query computer for the argument `x` capable of operating on validly
    /// initialized [`Opaque`] slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// Note: Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;

    /// Return whether or not this quantizer supports the given [`QueryLayout`].
    fn is_supported(&self, layout: QueryLayout) -> bool;

    /// Compress the vector `x` into the opaque slice.
    ///
    /// # Note
    ///
    /// This requires the length of the slice to be exactly [`Self::bytes`]. There is no
    /// alignment restriction on the base pointer.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;

    /// Return the metric this quantizer was created with.
    fn metric(&self) -> SupportedMetric;

    /// Clone the backing object into a new trait object allocated through `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;

    crate::utils::features! {
        #![feature = "flatbuffers"]
        /// Serialize `self` into a flatbuffer, returning the flatbuffer. The function
        /// [`try_deserialize`] should undo this operation.
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
337
338#[derive(Debug, Error)]
339#[error("Layout {layout} is not supported for {desc}")]
340pub struct UnsupportedQueryLayout {
341    layout: QueryLayout,
342    desc: &'static str,
343}
344
345impl UnsupportedQueryLayout {
346    fn new(layout: QueryLayout, desc: &'static str) -> Self {
347        Self { layout, desc }
348    }
349}
350
351#[derive(Debug, Error)]
352#[non_exhaustive]
353pub enum DistanceComputerError {
354    #[error(transparent)]
355    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
356    #[error(transparent)]
357    AllocatorError(#[from] AllocatorError),
358}
359
360#[derive(Debug, Error)]
361#[non_exhaustive]
362pub enum QueryCompressionError {
363    #[error(transparent)]
364    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
365    #[error(transparent)]
366    CompressionError(#[from] CompressionError),
367    #[error(transparent)]
368    NotCanonical(#[from] NotCanonical),
369    #[error(transparent)]
370    AllocatorError(#[from] AllocatorError),
371}
372
373#[derive(Debug, Error)]
374#[non_exhaustive]
375pub enum QueryComputerError {
376    #[error(transparent)]
377    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
378    #[error(transparent)]
379    CompressionError(#[from] CompressionError),
380    #[error(transparent)]
381    AllocatorError(#[from] AllocatorError),
382}
383
384/// Errors that can occur during data compression
385#[derive(Debug, Error)]
386#[error("Error occured during query compression")]
387pub enum CompressionError {
388    /// The input buffer did not have the expected layout. This is an input error.
389    NotCanonical(#[source] InlineError<16>),
390
391    /// Forward any error that occurs during the compression process.
392    ///
393    /// See [`quantizer::CompressionError`] for the complete list.
394    CompressionError(#[source] quantizer::CompressionError),
395}
396
397impl CompressionError {
398    fn not_canonical<E>(error: E) -> Self
399    where
400        E: std::error::Error + Send + Sync + 'static,
401    {
402        Self::NotCanonical(InlineError::new(error))
403    }
404}
405
406#[derive(Debug, Error)]
407#[error("An opaque argument did not have the required alignment or length")]
408pub struct NotCanonical {
409    source: Box<dyn std::error::Error + Send + Sync>,
410}
411
412impl NotCanonical {
413    fn new<E>(err: E) -> Self
414    where
415        E: std::error::Error + Send + Sync + 'static,
416    {
417        Self {
418            source: Box::new(err),
419        }
420    }
421}
422
423////////////
424// Opaque //
425////////////
426
427/// A type-erased slice wrapper used to hide the implementation of spherically quantized
428/// vectors. This allows multiple bit-width implementations to share the same type.
429#[derive(Debug, Clone, Copy)]
430#[repr(transparent)]
431pub struct Opaque<'a>(&'a [u8]);
432
433impl<'a> Opaque<'a> {
434    /// Construct a new `Opaque` referencing `slice`.
435    pub fn new(slice: &'a [u8]) -> Self {
436        Self(slice)
437    }
438
439    /// Consume `self`, returning the wrapped slice.
440    pub fn into_inner(self) -> &'a [u8] {
441        self.0
442    }
443}
444
445impl std::ops::Deref for Opaque<'_> {
446    type Target = [u8];
447    fn deref(&self) -> &[u8] {
448        self.0
449    }
450}
451impl<'short> Reborrow<'short> for Opaque<'_> {
452    type Target = Opaque<'short>;
453    fn reborrow(&'short self) -> Self::Target {
454        *self
455    }
456}
457
/// Mutable, type-erased view of a buffer holding a spherically quantized vector.
///
/// Like [`Opaque`], this hides the bit-width specific representation behind raw bytes so
/// that all bit-width implementations can share one type.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wrap `slice` without copying it.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Mutably borrow the underlying bytes for as long as `self` is borrowed.
    pub fn inspect(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}

impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        &*self.0
    }
}

impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}
488
489//////////////////
490// Query Layout //
491//////////////////
492
/// Selects how the query is represented inside a [`DistanceComputer`] or
/// [`QueryComputer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// Compress the query exactly the way data vectors are compressed.
    ///
    /// Compression may be slow when high bit-widths are used.
    SameAsData,

    /// Encode the query with 4 bits per dimension using a bitwise-transposed layout.
    FourBitTransposed,

    /// Scalar-quantize the query with the same number of bits per dimension as the
    /// dataset.
    ScalarQuantized,

    /// Keep the query in full precision as `f32` values.
    FullPrecision,
}

impl QueryLayout {
    /// Every layout, in declaration order (test-only helper).
    #[cfg(test)]
    fn all() -> [Self; 4] {
        use QueryLayout::*;
        [SameAsData, FourBitTransposed, ScalarQuantized, FullPrecision]
    }
}

impl std::fmt::Display for QueryLayout {
    /// Displays the variant name; identical to the `Debug` representation.
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, fmt)
    }
}
529
530//////////////////////
531// Layout Reporting //
532//////////////////////
533
534/// Because dynamic dispatch is used heavily in the implementation, it can be easy to lose
535/// track of the actual layout used for the [`DistanceComputer`] and [`QueryComputer`].
536///
537/// This trait provides a mechanism by which we ensure the correct runtime layout is always
538/// reported without requiring manual tracking.
539trait ReportQueryLayout {
540    fn report_query_layout(&self) -> QueryLayout;
541}
542
543impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
544where
545    T: ReportQueryLayout,
546{
547    fn report_query_layout(&self) -> QueryLayout {
548        self.inner.report_query_layout()
549    }
550}
551
552impl<D, Q> ReportQueryLayout for Curried<D, Q>
553where
554    Q: ReportQueryLayout,
555{
556    fn report_query_layout(&self) -> QueryLayout {
557        self.query.report_query_layout()
558    }
559}
560
561impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
562where
563    Unsigned: Representation<NBITS>,
564    A: AllocatorCore,
565{
566    fn report_query_layout(&self) -> QueryLayout {
567        QueryLayout::SameAsData
568    }
569}
570
571impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
572where
573    Unsigned: Representation<NBITS>,
574    A: AllocatorCore,
575{
576    fn report_query_layout(&self) -> QueryLayout {
577        QueryLayout::ScalarQuantized
578    }
579}
580
581impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
582where
583    A: AllocatorCore,
584{
585    fn report_query_layout(&self) -> QueryLayout {
586        QueryLayout::FourBitTransposed
587    }
588}
589
590impl<A> ReportQueryLayout for FullQuery<A>
591where
592    A: AllocatorCore,
593{
594    fn report_query_layout(&self) -> QueryLayout {
595        QueryLayout::FullPrecision
596    }
597}
598
599//-----------------------//
600// Reification Utilities //
601//-----------------------//
602
603/// An adaptor trait defining how to go from an `Opaque` slice to a fully reified type.
604///
605/// THis is the building block for building distance computers with the reificiation code
606/// inlined into the callsite.
607trait FromOpaque: 'static + Send + Sync {
608    type Target<'a>;
609    type Error: std::error::Error + Send + Sync + 'static;
610
611    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
612}
613
614/// Reify as full-precision.
615#[derive(Debug, Default)]
616pub(super) struct AsFull;
617
618/// Reify as data.
619#[derive(Debug, Default)]
620pub(super) struct AsData<const NBITS: usize>;
621
622/// Reify as scalar quantized query.
623#[derive(Debug)]
624pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
625    _marker: PhantomData<Perm>,
626}
627
628// This impelmentation works around the `derive` impl requiring `Perm: Default`.
629impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
630    fn default() -> Self {
631        Self {
632            _marker: PhantomData,
633        }
634    }
635}
636
637impl FromOpaque for AsFull {
638    type Target<'a> = FullQueryRef<'a>;
639    type Error = meta::slice::NotCanonical;
640
641    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
642        Self::Target::from_canonical(query.into_inner(), dim)
643    }
644}
645
646impl ReportQueryLayout for AsFull {
647    fn report_query_layout(&self) -> QueryLayout {
648        QueryLayout::FullPrecision
649    }
650}
651
652impl<const NBITS: usize> FromOpaque for AsData<NBITS>
653where
654    Unsigned: Representation<NBITS>,
655{
656    type Target<'a> = DataRef<'a, NBITS>;
657    type Error = meta::NotCanonical;
658
659    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
660        Self::Target::from_canonical_back(query.into_inner(), dim)
661    }
662}
663
664impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
665    fn report_query_layout(&self) -> QueryLayout {
666        QueryLayout::SameAsData
667    }
668}
669
670impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
671where
672    Unsigned: Representation<NBITS>,
673    Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
674{
675    type Target<'a> = QueryRef<'a, NBITS, Perm>;
676    type Error = meta::NotCanonical;
677
678    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
679        Self::Target::from_canonical_back(query.into_inner(), dim)
680    }
681}
682
683impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
684    fn report_query_layout(&self) -> QueryLayout {
685        QueryLayout::ScalarQuantized
686    }
687}
688
689impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
690    fn report_query_layout(&self) -> QueryLayout {
691        QueryLayout::FourBitTransposed
692    }
693}
694
695//-------//
696// Reify //
697//-------//
698
699/// Helper struct to convert an [`Opaque`] to a fully-typed [`DataRef`].
700pub(super) struct Reify<T, M, L, R> {
701    inner: T,
702    dim: usize,
703    arch: M,
704    _markers: PhantomData<(L, R)>,
705}
706
707impl<T, M, L, R> Reify<T, M, L, R> {
708    pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
709        Self {
710            inner,
711            dim,
712            arch,
713            _markers: PhantomData,
714        }
715    }
716}
717
/// Query-computer form of [`Reify`]: the query-reifier slot is `()` because the query is
/// already held inside `T` (see [`Curried`]); only the right-hand data argument arrives as
/// an [`Opaque`] and is reified with `R`.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // Both the reification of `x` and the inner distance call run inside the closure
        // given to `self.arch.run2`. Presumably this makes the whole body compile for
        // architecture `M`; the exact semantics of `run2`/`run1` live in `diskann_wide`.
        self.arch.run2(
            |this: &Self, x| {
                // A malformed `x` is reported as `XReify` instead of panicking.
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }

    /// The layout is whatever the wrapped (curried) computation reports.
    fn layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}
743
/// Two-argument form of [`Reify`]: both the query (reified with `Q`) and the data
/// argument (reified with `R`) arrive as [`Opaque`] slices before `T` is evaluated.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        // Reification and the distance call both happen inside the closure given to
        // `self.arch.run3`. Presumably this generates the whole body for architecture `M`
        // (see `diskann_wide` for the exact semantics).
        self.arch.run3(
            |this: &Self, query, x| {
                // The inline-error capacities (24 / 16) must match the `DistanceError`
                // variant payload types.
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;

                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;

                // `T: Copy`, so the inner function is passed by value.
                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }

    /// The reifiers in this file (`AsFull`, `AsData`, `AsQuery`) are marker types, so a
    /// default instance is enough to ask which layout `Q` decodes.
    fn layout(&self) -> QueryLayout {
        Q::default().report_query_layout()
    }
}
774
775///////////////////////
776// Query Computation //
777///////////////////////
778
779/// Errors that can occur while perfoming distance cacluations on opaque vectors.
780#[derive(Debug, Error)]
781pub enum QueryDistanceError {
782    /// The right-hand data argument appears to be malformed.
783    #[error("trouble trying to reify the argument")]
784    XReify(#[source] InlineError<16>),
785
786    /// Distance computation failed because the logical lengths of the two vectors differ.
787    #[error("encountered while trying to compute distances")]
788    UnequalLengths(#[source] UnequalLengths),
789}
790
791pub trait DynQueryComputer: Send + Sync {
792    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
793    fn layout(&self) -> QueryLayout;
794}
795
796/// An opaque [`PreprocessedDistanceFunction`] for the [`Quantizer`] trait object.
797///
798/// # Note
799///
800/// This is only valid to call on [`Opaque`] slices compressed by the same [`Quantizer`] that
801/// created the computer.
802///
803/// Otherwise, distance computations may return garbage values or panic.
804pub struct QueryComputer<A = GlobalAllocator>
805where
806    A: AllocatorCore,
807{
808    inner: Poly<dyn DynQueryComputer, A>,
809}
810
811impl<A> QueryComputer<A>
812where
813    A: AllocatorCore,
814{
815    fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
816    where
817        T: DynQueryComputer + 'static,
818    {
819        let inner = Poly::new(inner, allocator)?;
820        Ok(Self {
821            inner: poly!(DynQueryComputer, inner),
822        })
823    }
824
825    /// Report the layout used by the query computer.
826    pub fn layout(&self) -> QueryLayout {
827        self.inner.layout()
828    }
829
830    /// This is a temporary function until custom allocator support fully comes on line.
831    pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
832        self.inner
833    }
834}
835
836impl<A> std::fmt::Debug for QueryComputer<A>
837where
838    A: AllocatorCore,
839{
840    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
841        write!(
842            f,
843            "dynamic fused query computer with layout \"{}\"",
844            self.layout()
845        )
846    }
847}
848
849impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
850    for QueryComputer<A>
851where
852    A: AllocatorCore,
853{
854    fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
855        self.inner.evaluate(x)
856    }
857}
858
859/// To handle multiple query bit-widths, we use type erasure on the actual distance
860/// function implementation.
861///
862/// This struct represents the partial application of the `inner` distance function with
863/// `query` in a generic way so we only have one level of dynamic dispatch when computing
864/// distances.
865pub(super) struct Curried<D, Q> {
866    inner: D,
867    query: Q,
868}
869
870impl<D, Q> Curried<D, Q> {
871    pub(super) fn new(inner: D, query: Q) -> Self {
872        Self { inner, query }
873    }
874}
875
876impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
877where
878    A: Architecture,
879    Q: for<'a> Reborrow<'a>,
880    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
881{
882    fn run(self, arch: A, x: T) -> R {
883        self.inner.run(arch, self.query.reborrow(), x)
884    }
885}
886
887///////////////////////
888// Distance Computer //
889///////////////////////
890
891/// Errors that can occur while perfoming distance cacluations on opaque vectors.
892#[derive(Debug, Error)]
893pub enum DistanceError {
894    /// The left-hand data argument appears to be malformed.
895    #[error("trouble trying to reify the left-hand argument")]
896    QueryReify(InlineError<24>),
897
898    /// The right-hand data argument appears to be malformed.
899    #[error("trouble trying to reify the right-hand argument")]
900    XReify(InlineError<16>),
901
902    /// Distance computation failed because the logical lengths of the two vectors differ.
903    ///
904    /// If vector reificiation occurs successfully, then this should not be returned.
905    #[error("encountered while trying to compute distances")]
906    UnequalLengths(UnequalLengths),
907}
908
909pub trait DynDistanceComputer: Send + Sync {
910    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
911    fn layout(&self) -> QueryLayout;
912}
913
914/// An opaque [`DistanceFunction`] for the [`Quantizer`] trait object.
915///
916/// # Note
917///
918/// Left-hand arguments must be [`Opaque`] slices compressed using
919/// [`Quantizer::compress_query`] using [`Self::layout`].
920///
921/// Right-hand arguments must be [`Opaque`] slices compressed using [`Quantizer::compress`].
922///
923/// Otherwise, distance computations may return garbage values or panic.
924pub struct DistanceComputer<A = GlobalAllocator>
925where
926    A: AllocatorCore,
927{
928    inner: Poly<dyn DynDistanceComputer, A>,
929}
930
931impl<A> DistanceComputer<A>
932where
933    A: AllocatorCore,
934{
935    pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
936    where
937        T: DynDistanceComputer + 'static,
938    {
939        let inner = Poly::new(inner, allocator)?;
940        Ok(Self {
941            inner: poly!(DynDistanceComputer, inner),
942        })
943    }
944
945    /// Report the layout used by the query computer.
946    pub fn layout(&self) -> QueryLayout {
947        self.inner.layout()
948    }
949
950    pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
951        self.inner
952    }
953}
954
955impl<A> std::fmt::Debug for DistanceComputer<A>
956where
957    A: AllocatorCore,
958{
959    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
960        write!(
961            f,
962            "dynamic distance computer with layout \"{}\"",
963            self.layout()
964        )
965    }
966}
967
968impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
969where
970    A: AllocatorCore,
971{
972    fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
973        self.inner.evaluate(query, x)
974    }
975}
976
977//////////
978// Impl //
979//////////
980
/// The base number of bytes to allocate when attempting to serialize a quantizer.
///
/// Test builds use a deliberately tiny value so that the serializer's reallocation logic
/// is exercised.
#[cfg(feature = "flatbuffers")]
const DEFAULT_SERIALIZED_BYTES: usize = if cfg!(test) { 1 } else { 1024 };
988
989/// Implementation for [`Quantizer`] specializing on the number of bits used for data
990/// compression.
991pub struct Impl<const NBITS: usize, A = GlobalAllocator>
992where
993    A: Allocator,
994{
995    quantizer: SphericalQuantizer<A>,
996    distance: Poly<dyn DynDistanceComputer, A>,
997}
998
999/// Pre-dispatch distance functions between compressed data vectors from `quantizer`
1000/// specialized for the current run-time mciro architecture.
1001pub trait Constructible<A = GlobalAllocator>
1002where
1003    A: Allocator,
1004{
1005    fn dispatch_distance(
1006        quantizer: &SphericalQuantizer<A>,
1007    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
1008}
1009
1010impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1011where
1012    A: Allocator,
1013    AsData<NBITS>: FromOpaque,
1014    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1015{
1016    fn dispatch_distance(
1017        quantizer: &SphericalQuantizer<A>,
1018    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1019        diskann_wide::arch::dispatch2_no_features(
1020            ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1021            quantizer,
1022            quantizer.allocator().clone(),
1023        )
1024        .map(|obj| obj.inner)
1025    }
1026}
1027
1028impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1029where
1030    A: Allocator,
1031    AsData<NBITS>: FromOpaque,
1032    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1033{
1034    fn try_clone(&self) -> Result<Self, AllocatorError> {
1035        Self::new(self.quantizer.try_clone()?)
1036    }
1037}
1038
1039impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1040    /// Construct a new plan around `quantizer` providing distance computers for `metric`.
1041    pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1042    where
1043        Self: Constructible<A>,
1044    {
1045        let distance = Self::dispatch_distance(&quantizer)?;
1046        Ok(Self {
1047            quantizer,
1048            distance,
1049        })
1050    }
1051
1052    /// Return the underlying [`SphericalQuantizer`].
1053    pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1054        &self.quantizer
1055    }
1056
1057    /// Return `true` if this plan supports `layout` for query computers.
1058    ///
1059    /// Otherwise, return `false`.
1060    pub fn supports(layout: QueryLayout) -> bool {
1061        if const { NBITS == 1 } {
1062            [
1063                QueryLayout::SameAsData,
1064                QueryLayout::FourBitTransposed,
1065                QueryLayout::FullPrecision,
1066            ]
1067            .contains(&layout)
1068        } else {
1069            [
1070                QueryLayout::SameAsData,
1071                QueryLayout::ScalarQuantized,
1072                QueryLayout::FullPrecision,
1073            ]
1074            .contains(&layout)
1075        }
1076    }
1077
1078    /// Return a [`DistanceComputer`] that is specialized for the most specific runtime
1079    /// architecture.
1080    fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1081    where
1082        Q: FromOpaque,
1083        B: AllocatorCore,
1084        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1085    {
1086        diskann_wide::arch::dispatch2_no_features(
1087            ComputerDispatcher::<Q, NBITS>::new(),
1088            &self.quantizer,
1089            allocator,
1090        )
1091    }
1092
1093    fn compress_query<'a, T>(
1094        &self,
1095        query: &'a [f32],
1096        storage: T,
1097        scratch: ScopedAllocator<'a>,
1098    ) -> Result<(), QueryCompressionError>
1099    where
1100        SphericalQuantizer<A>: CompressIntoWith<&'a [f32], T, ScopedAllocator<'a>, Error = quantizer::CompressionError>,
1101    {
1102        self.quantizer
1103            .compress_into_with(query, storage, scratch)
1104            .map_err(|err| CompressionError::CompressionError(err).into())
1105    }
1106
1107    /// Return a [`QueryComputer`] that is specialized for the most specific runtime
1108    /// architecture.
1109    fn fused_query_computer<Q, T, B>(
1110        &self,
1111        query: &[f32],
1112        mut storage: T,
1113        allocator: B,
1114        scratch: ScopedAllocator<'_>,
1115    ) -> Result<QueryComputer<B>, QueryComputerError>
1116    where
1117        Q: FromOpaque,
1118        T: for<'a> ReborrowMut<'a>
1119            + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1120            + ReportQueryLayout
1121            + Send
1122            + Sync
1123            + 'static,
1124        B: AllocatorCore,
1125        SphericalQuantizer<A>: for<'a> CompressIntoWith<
1126                &'a [f32],
1127                <T as ReborrowMut<'a>>::Target,
1128                ScopedAllocator<'a>,
1129                Error = quantizer::CompressionError,
1130            >,
1131        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1132    {
1133        if let Err(err) = self
1134            .quantizer
1135            .compress_into_with(query, storage.reborrow_mut(), scratch)
1136        {
1137            return Err(CompressionError::CompressionError(err).into());
1138        }
1139
1140        diskann_wide::arch::dispatch3_no_features(
1141            ComputerDispatcher::<Q, NBITS>::new(),
1142            &self.quantizer,
1143            storage,
1144            allocator,
1145        )
1146        .map_err(|e| e.into())
1147    }
1148
1149    #[cfg(feature = "flatbuffers")]
1150    fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1151    where
1152        B: Allocator + std::panic::UnwindSafe,
1153        A: std::panic::RefUnwindSafe,
1154    {
1155        let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1156            0u8,
1157            DEFAULT_SERIALIZED_BYTES,
1158            allocator.clone(),
1159        )?);
1160
1161        let quantizer = &self.quantizer;
1162
1163        let (root, mut buf) = match std::panic::catch_unwind(move || {
1164            let offset = quantizer.pack(&mut buf);
1165
1166            let root = fb::spherical::Quantizer::create(
1167                &mut buf,
1168                &fb::spherical::QuantizerArgs {
1169                    quantizer: Some(offset),
1170                    nbits: NBITS as u32,
1171                },
1172            );
1173            (root, buf)
1174        }) {
1175            Ok(ret) => ret,
1176            Err(err) => match err.downcast_ref::<String>() {
1177                Some(msg) => {
1178                    if msg.contains("AllocatorError") {
1179                        return Err(AllocatorError);
1180                    } else {
1181                        std::panic::resume_unwind(err);
1182                    }
1183                }
1184                None => std::panic::resume_unwind(err),
1185            },
1186        };
1187
1188        // Finish serializing and then copy out the finished data into a newly allocated buffer.
1189        fb::spherical::finish_quantizer_buffer(&mut buf, root);
1190        Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1191    }
1192}
1193
1194//----------------------//
1195// Distance Dispatching //
1196//----------------------//
1197
1198/// This trait and [`ComputerDispatcher`] are the glue for pre-dispatching
1199/// micro-architecture compatibility of distance computers.
1200///
1201/// This trait takes
1202///
1203/// * `M`: The target micro-architecture.
1204/// * `Q`: The target query type
1205///
1206/// And generates a specialized `DistanceComputer` and `QueryComputer`.
1207///
1208/// The [`ComputerDispatcher`] struct implements the [`diskann_wide::arch::Target2`] and
1209/// [`diskann_wide::arch::Target3`] traits to do the architecture-dispatching.
1210trait BuildComputer<M, Q, const N: usize>
1211where
1212    M: Architecture,
1213    Q: FromOpaque,
1214{
1215    /// Build a [`DistanceComputer`] targeting the micro-architecture `M`.
1216    ///
1217    /// The resulting object should implement distance calculations using just a single
1218    /// level of indirection.
1219    fn build_computer<A>(
1220        &self,
1221        arch: M,
1222        allocator: A,
1223    ) -> Result<DistanceComputer<A>, AllocatorError>
1224    where
1225        A: AllocatorCore;
1226
1227    /// Build a [`DistanceComputer`] with `query` targeting the micro-architecture `M`.
1228    ///
1229    /// The resulting object should implement distance calculations using just a single
1230    /// level of indirection.
1231    fn build_fused_computer<R, A>(
1232        &self,
1233        arch: M,
1234        query: R,
1235        allocator: A,
1236    ) -> Result<QueryComputer<A>, AllocatorError>
1237    where
1238        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1239        A: AllocatorCore;
1240}
1241
/// Return the argument unchanged.
///
/// Used by `dispatch_map!` as the default architecture conversion, i.e. when no
/// down-casting is needed.
fn identity<T>(value: T) -> T {
    value
}
1245
/// Implement `BuildComputer` for `SphericalQuantizer` for one combination of data
/// bit-width `$N`, query representation `$Q`, and micro-architecture `$arch`.
///
/// The optional fourth argument names a function that converts `$arch` into the
/// architecture the kernels are instantiated for (for example, a down-cast). When it is
/// omitted, `identity` is used.
macro_rules! dispatch_map {
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                native: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                type D = AsData<$N>;

                // Convert to the architecture the kernels are instantiated for.
                let target = $op(native);
                let dim = self.output_dim();
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let functor: CompensatedSquaredL2 = self.as_functor();
                        let reify = Reify::<_, _, $Q, D>::new(functor, dim, target);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let functor: CompensatedIP = self.as_functor();
                        let reify = Reify::<_, _, $Q, D>::new(functor, dim, target);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let functor: CompensatedCosine = self.as_functor();
                        let reify = Reify::<_, _, $Q, D>::new(functor, dim, target);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                native: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;

                // Convert to the architecture the kernels are instantiated for.
                let target = $op(native);
                let dim = self.output_dim();
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let functor: CompensatedSquaredL2 = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let reify = Reify::<_, _, (), D>::new(bound, dim, target);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let functor: CompensatedIP = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let reify = Reify::<_, _, (), D>::new(bound, dim, target);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let functor: CompensatedCosine = self.as_functor();
                        let bound = Curried::new(functor, query);
                        let reify = Reify::<_, _, (), D>::new(bound, dim, target);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
1331
1332dispatch_map!(1, AsFull, Scalar);
1333dispatch_map!(2, AsFull, Scalar);
1334dispatch_map!(4, AsFull, Scalar);
1335dispatch_map!(8, AsFull, Scalar);
1336
1337dispatch_map!(1, AsData<1>, Scalar);
1338dispatch_map!(2, AsData<2>, Scalar);
1339dispatch_map!(4, AsData<4>, Scalar);
1340dispatch_map!(8, AsData<8>, Scalar);
1341
1342// Special Cases
1343dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
1344dispatch_map!(2, AsQuery<2>, Scalar);
1345dispatch_map!(4, AsQuery<4>, Scalar);
1346dispatch_map!(8, AsQuery<8>, Scalar);
1347
1348cfg_if::cfg_if! {
1349    if #[cfg(target_arch = "x86_64")] {
1350        fn downcast_to_v3(arch: V4) -> V3 {
1351            arch.into()
1352        }
1353
1354        // V3
1355        dispatch_map!(1, AsFull, V3);
1356        dispatch_map!(2, AsFull, V3);
1357        dispatch_map!(4, AsFull, V3);
1358        dispatch_map!(8, AsFull, V3);
1359
1360        dispatch_map!(1, AsData<1>, V3);
1361        dispatch_map!(2, AsData<2>, V3);
1362        dispatch_map!(4, AsData<4>, V3);
1363        dispatch_map!(8, AsData<8>, V3);
1364
1365        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
1366        dispatch_map!(2, AsQuery<2>, V3);
1367        dispatch_map!(4, AsQuery<4>, V3);
1368        dispatch_map!(8, AsQuery<8>, V3);
1369
1370        // V4
1371        dispatch_map!(1, AsFull, V4, downcast_to_v3);
1372        dispatch_map!(2, AsFull, V4, downcast_to_v3);
1373        dispatch_map!(4, AsFull, V4, downcast_to_v3);
1374        dispatch_map!(8, AsFull, V4, downcast_to_v3);
1375
1376        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
1377        dispatch_map!(2, AsData<2>, V4); // specialized
1378        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
1379        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);
1380
1381        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
1382        dispatch_map!(2, AsQuery<2>, V4); // specialized
1383        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
1384        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
1385    }
1386}
1387
/// This struct and the [`BuildComputer`] trait are the glue for pre-dispatching
/// micro-architecture compatibility of distance computers.
///
/// This struct takes
///
/// * `Q`: The target query type.
/// * `N`: The number of data bits to target.
///
/// It implements the [`diskann_wide::arch::Target2`] and [`diskann_wide::arch::Target3`]
/// traits to do the architecture-dispatching, relying on
/// `SphericalQuantizer<A>: BuildComputer<M, Q, N>` for the implementation.
#[derive(Debug, Clone, Copy)]
struct ComputerDispatcher<Q, const N: usize> {
    /// Records the query type `Q` without storing a value of it.
    _query_type: std::marker::PhantomData<Q>,
}
1403
1404impl<Q, const N: usize> ComputerDispatcher<Q, N> {
1405    fn new() -> Self {
1406        Self {
1407            _query_type: std::marker::PhantomData,
1408        }
1409    }
1410}
1411
1412impl<M, const N: usize, A, B, Q>
1413    diskann_wide::arch::Target2<
1414        M,
1415        Result<DistanceComputer<B>, AllocatorError>,
1416        &SphericalQuantizer<A>,
1417        B,
1418    > for ComputerDispatcher<Q, N>
1419where
1420    M: Architecture,
1421    A: Allocator,
1422    B: AllocatorCore,
1423    Q: FromOpaque,
1424    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1425{
1426    fn run(
1427        self,
1428        arch: M,
1429        quantizer: &SphericalQuantizer<A>,
1430        allocator: B,
1431    ) -> Result<DistanceComputer<B>, AllocatorError> {
1432        quantizer.build_computer(arch, allocator)
1433    }
1434}
1435
1436impl<M, const N: usize, A, R, B, Q>
1437    diskann_wide::arch::Target3<
1438        M,
1439        Result<QueryComputer<B>, AllocatorError>,
1440        &SphericalQuantizer<A>,
1441        R,
1442        B,
1443    > for ComputerDispatcher<Q, N>
1444where
1445    M: Architecture,
1446    A: Allocator,
1447    B: AllocatorCore,
1448    Q: FromOpaque,
1449    R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1450    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1451{
1452    fn run(
1453        self,
1454        arch: M,
1455        quantizer: &SphericalQuantizer<A>,
1456        query: R,
1457        allocator: B,
1458    ) -> Result<QueryComputer<B>, AllocatorError> {
1459        quantizer.build_fused_computer(arch, query, allocator)
1460    }
1461}
1462
1463#[cfg(target_arch = "x86_64")]
1464trait Dispatchable<Q, const N: usize>:
1465    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
1466where
1467    Q: FromOpaque,
1468{
1469}
1470
1471#[cfg(target_arch = "x86_64")]
1472impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1473where
1474    Q: FromOpaque,
1475    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
1476{
1477}
1478
1479#[cfg(not(target_arch = "x86_64"))]
1480trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
1481where
1482    Q: FromOpaque,
1483{
1484}
1485
1486#[cfg(not(target_arch = "x86_64"))]
1487impl<Q, const N: usize, T> Dispatchable<Q, N> for T
1488where
1489    Q: FromOpaque,
1490    T: BuildComputer<Scalar, Q, N>,
1491{
1492}
1493
1494//---------------------------//
1495// Quantizer Implementations //
1496//---------------------------//
1497
/// [`Quantizer`] implementation for 1-bit data compression.
///
/// Supported query layouts: [`QueryLayout::SameAsData`], [`QueryLayout::FourBitTransposed`]
/// and [`QueryLayout::FullPrecision`]. [`QueryLayout::ScalarQuantized`] is rejected with
/// [`UnsupportedQueryLayout`].
///
/// Inside this impl, calls like `self.query_computer::<..>(..)`, the three-argument
/// `self.compress_query(..)` and `self.fused_query_computer::<..>(..)` resolve to the
/// generic inherent methods on [`Impl`], not recursively to these trait methods. Rust
/// prefers inherent methods over trait methods of the same name.
impl<A, B> Quantizer<B> for Impl<1, A>
where
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
    B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// Number of bits per compressed data dimension (always 1 here).
    fn nbits(&self) -> usize {
        1
    }

    /// Dimension of compressed vectors: the quantizer's output dimension.
    fn dim(&self) -> usize {
        self.quantizer.output_dim()
    }

    /// Dimension of uncompressed input vectors: the quantizer's input dimension.
    fn full_dim(&self) -> usize {
        self.quantizer.input_dim()
    }

    /// Size in bytes of one canonical 1-bit compressed data vector.
    fn bytes(&self) -> usize {
        DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
    }

    /// Build a fresh data-vs-data computer (query layout same as data) allocated in
    /// `allocator`.
    fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
        self.query_computer::<AsData<1>, _>(allocator)
    }

    /// Borrow the data-vs-data computer that was pre-dispatched when `self` was built.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
        &*self.distance
    }

    /// Build a computer whose left-hand arguments use `layout`.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: B,
    ) -> Result<DistanceComputer<B>, DistanceComputerError> {
        match layout {
            QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
            QueryLayout::FourBitTransposed => {
                Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
            }
            // Not available for 1-bit data (see the compatibility table).
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
            }
            QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
        }
    }

    /// Report the byte size and alignment a caller must provide for a query compressed
    /// with `layout`.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
        let dim = <Self as Quantizer<B>>::dim(self);
        match layout {
            QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                DataRef::<1>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
                QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
            }
            // Full-precision queries carry their own alignment requirement.
            QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                FullQueryRef::canonical_bytes(dim),
                FullQueryRef::canonical_align(),
            )),
        }
    }

    /// Compress the query `x` into `buffer` using `layout`.
    ///
    /// If `buffer` cannot be viewed in the canonical form for `layout` and `dim`, a
    /// `NotCanonical` error is returned. When `allow_rescale` is set and the metric is
    /// inner product, a rescaled copy of `x` is compressed instead of `x` itself.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        mut buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // Interpret `buffer` for the requested layout and compress `v` into it.
        let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
            match layout {
                QueryLayout::SameAsData => self.compress_query(
                    v,
                    DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::FourBitTransposed => self.compress_query(
                    v,
                    QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::ScalarQuantized => {
                    Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                }
                // NOTE(review): uses `from_canonical_mut` rather than the `_back_` variant
                // used above; presumably related to the stricter alignment reported by
                // `FullQueryRef::canonical_align()` — confirm.
                QueryLayout::FullPrecision => self.compress_query(
                    v,
                    FullQueryMut::from_canonical_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
            }
        };

        // Rescaling mutates the data, so it is applied to an owned copy of `x`.
        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy)
        } else {
            finish(x)
        }
    }

    /// Compress `x` with `layout` into freshly allocated storage and return a computer
    /// that owns the compressed query.
    ///
    /// Query storage and the computer are allocated from `allocator`. The same rescaling
    /// rule as `compress_query` applies.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: B,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<B>, QueryComputerError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // Allocate query storage for `layout`, then compress `v` into it and dispatch.
        let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
            match layout {
                    QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
                        v,
                        Data::new_in(dim, allocator.clone())?,
                        allocator,
                        scratch,
                    ),
                    QueryLayout::FourBitTransposed => self
                        .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
                            v,
                            Query::new_in(dim, allocator.clone())?,
                            allocator,
                            scratch,
                        ),
                    QueryLayout::ScalarQuantized => {
                        Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                    }
                    QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
                        v,
                        FullQuery::empty(dim, allocator.clone())?,
                        allocator,
                        scratch,
                    ),
                }
        };

        // Rescaling mutates the data, so it is applied to an owned copy of `x`.
        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy, allocator)
        } else {
            finish(x, allocator)
        }
    }

    /// Return whether `layout` is supported; delegates to `Impl::supports`.
    fn is_supported(&self, layout: QueryLayout) -> bool {
        Self::supports(layout)
    }

    /// Compress the data vector `x` into `into`.
    ///
    /// `into` must be viewable as a canonical 1-bit data vector of length `dim`;
    /// otherwise a not-canonical `CompressionError` is returned.
    fn compress(
        &self,
        x: &[f32],
        mut into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
            .map_err(CompressionError::not_canonical)?;
        self.quantizer
            .compress_into_with(x, into, scratch)
            .map_err(CompressionError::CompressionError)
    }

    /// Report the metric configured on the underlying quantizer.
    fn metric(&self) -> SupportedMetric {
        self.quantizer.metric()
    }

    /// Clone `self` (rebuilding its pre-dispatched computer) and box the clone as a
    /// `dyn Quantizer` in `allocator`.
    fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
        let clone = (*self).try_clone()?;
        poly!({ Quantizer<B> }, clone, allocator)
    }

    /// Forwards to the inherent `Impl::serialize`.
    #[cfg(feature = "flatbuffers")]
    fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
        Impl::<1, A>::serialize(self, allocator)
    }
}
1689
1690macro_rules! plan {
1691    ($N:literal) => {
1692        impl<A, B> Quantizer<B> for Impl<$N, A>
1693        where
1694            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1695            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1696        {
1697            fn nbits(&self) -> usize {
1698                $N
1699            }
1700
1701            fn dim(&self) -> usize {
1702                self.quantizer.output_dim()
1703            }
1704
1705            fn full_dim(&self) -> usize {
1706                self.quantizer.input_dim()
1707            }
1708
1709            fn bytes(&self) -> usize {
1710                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
1711            }
1712
1713            fn distance_computer(
1714                &self,
1715                allocator: B
1716            ) -> Result<DistanceComputer<B>, AllocatorError> {
1717                self.query_computer::<AsData<$N>, _>(allocator)
1718            }
1719
1720            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1721                &*self.distance
1722            }
1723
1724            fn query_computer(
1725                &self,
1726                layout: QueryLayout,
1727                allocator: B,
1728            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1729                match layout {
1730                    QueryLayout::SameAsData => Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
1731                    ,
1732                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
1733                        layout,
1734                        concat!($N, "-bit compression"),
1735                    ).into()),
1736                    QueryLayout::ScalarQuantized => {
1737                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
1738                    },
1739                    QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1740
1741                }
1742            }
1743
1744            fn query_buffer_description(
1745                &self,
1746                layout: QueryLayout
1747            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>
1748            {
1749                let dim = <Self as Quantizer<B>>::dim(self);
1750                match layout {
1751                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1752                        DataRef::<$N>::canonical_bytes(dim),
1753                        PowerOfTwo::alignment_of::<u8>(),
1754                    )),
1755                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout {
1756                        layout,
1757                        desc: concat!($N, "-bit compression"),
1758                    }),
1759                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
1760                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
1761                        PowerOfTwo::alignment_of::<u8>(),
1762                    )),
1763                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1764                        FullQueryRef::canonical_bytes(dim),
1765                        FullQueryRef::canonical_align(),
1766                    )),
1767                }
1768            }
1769
1770            fn compress_query(
1771                &self,
1772                x: &[f32],
1773                layout: QueryLayout,
1774                allow_rescale: bool,
1775                mut buffer: OpaqueMut<'_>,
1776                scratch: ScopedAllocator<'_>,
1777            ) -> Result<(), QueryCompressionError> {
1778                let dim = <Self as Quantizer<B>>::dim(self);
1779                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1780                    match layout {
1781                        QueryLayout::SameAsData => self.compress_query(
1782                            v,
1783                            DataMut::<$N>::from_canonical_back_mut(
1784                                &mut buffer,
1785                                dim,
1786                            ).map_err(NotCanonical::new)?,
1787                            scratch,
1788                        ),
1789                        QueryLayout::FourBitTransposed => {
1790                            Err(UnsupportedQueryLayout::new(
1791                                layout,
1792                                concat!($N, "-bit compression"),
1793                            ).into())
1794                        },
1795                        QueryLayout::ScalarQuantized => self.compress_query(
1796                            v,
1797                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
1798                                &mut buffer,
1799                                dim,
1800                            ).map_err(NotCanonical::new)?,
1801                            scratch,
1802                        ),
1803                        QueryLayout::FullPrecision => self.compress_query(
1804                            v,
1805                            FullQueryMut::from_canonical_mut(
1806                                &mut buffer,
1807                                dim,
1808                            ).map_err(NotCanonical::new)?,
1809                            scratch,
1810                        ),
1811                    }
1812                };
1813
1814                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1815                    let mut copy = x.to_owned();
1816                    self.quantizer.rescale(&mut copy);
1817                    finish(&copy)
1818                } else {
1819                    finish(x)
1820                }
1821            }
1822
1823            fn fused_query_computer(
1824                &self,
1825                x: &[f32],
1826                layout: QueryLayout,
1827                allow_rescale: bool,
1828                allocator: B,
1829                scratch: ScopedAllocator<'_>,
1830            ) -> Result<QueryComputer<B>, QueryComputerError>
1831            {
1832                let dim = <Self as Quantizer<B>>::dim(self);
1833                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
1834                    match layout {
1835                        QueryLayout::SameAsData => {
1836                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
1837                                v,
1838                                Data::new_in(dim, allocator.clone())?,
1839                                allocator,
1840                                scratch,
1841                            )
1842                        },
1843                        QueryLayout::FourBitTransposed => {
1844                            Err(UnsupportedQueryLayout::new(
1845                                layout,
1846                                concat!($N, "-bit compression"),
1847                            ).into())
1848                        },
1849                        QueryLayout::ScalarQuantized => {
1850                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
1851                                v,
1852                                Query::new_in(dim, allocator.clone())?,
1853                                allocator,
1854                                scratch,
1855                            )
1856                        },
1857                        QueryLayout::FullPrecision => {
1858                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
1859                                v,
1860                                FullQuery::empty(dim, allocator.clone())?,
1861                                allocator,
1862                                scratch,
1863                            )
1864                        },
1865                    }
1866                };
1867
1868                let metric = <Self as Quantizer<B>>::metric(self);
1869                if allow_rescale && metric == SupportedMetric::InnerProduct {
1870                    let mut copy = x.to_owned();
1871                    self.quantizer.rescale(&mut copy);
1872                    finish(&copy)
1873                } else {
1874                    finish(x)
1875                }
1876            }
1877
1878            fn is_supported(&self, layout: QueryLayout) -> bool {
1879                Self::supports(layout)
1880            }
1881
1882            fn compress(
1883                &self,
1884                x: &[f32],
1885                mut into: OpaqueMut<'_>,
1886                scratch: ScopedAllocator<'_>,
1887            ) -> Result<(), CompressionError> {
1888                let dim = <Self as Quantizer<B>>::dim(self);
1889                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
1890                    .map_err(CompressionError::not_canonical)?;
1891
1892                self.quantizer.compress_into_with(x, into, scratch)
1893                    .map_err(CompressionError::CompressionError)
1894            }
1895
1896            fn metric(&self) -> SupportedMetric {
1897                self.quantizer.metric()
1898            }
1899
1900            fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
1901                let clone = (&*self).try_clone()?;
1902                poly!({ Quantizer<B> }, clone, allocator)
1903            }
1904
1905            #[cfg(feature = "flatbuffers")]
1906            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
1907                Impl::<$N, A>::serialize(self, allocator)
1908            }
1909        }
1910    };
1911    ($N:literal, $($Ns:literal),*) => {
1912        plan!($N);
1913        $(plan!($Ns);)*
1914    }
1915}
1916
// Generate the macro-based `Quantizer` implementation for the 2-, 4-, and 8-bit `Impl`s.
// These widths share a single query-layout matrix: `FourBitTransposed` is rejected while
// `SameAsData`, `ScalarQuantized`, and `FullPrecision` are accepted. The 1-bit variant is
// not generated here; per the compatibility table in the module docs its supported layouts
// differ, so it is presumably implemented separately (NOTE(review): confirm location).
plan!(2, 4, 8);
1918
1919////////////////
1920// Flatbuffer //
1921////////////////
1922
/// Errors that can be returned by [`try_deserialize`].
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The input buffer does not carry the spherical quantizer flatbuffer identifier.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The serialized bit-width is not one of the widths handled by [`try_deserialize`].
    /// The payload is the rejected bit-width.
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// Unpacking the inner spherical quantizer (or constructing its `Impl`) failed.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The flatbuffer root could not be parsed/verified.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// An allocation failed while building the deserialized quantizer.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1943
/// Deserialize a `spherical::Quantizer` flatbuffer into the matching bit-width
/// specialization of [`Impl`], returned behind the dyn-compatible [`Quantizer`] interface.
///
/// The returned `Poly` is guaranteed to be the first object allocated through `alloc`, so
/// when `alloc` is a bump allocator the root object is placed before any of its members.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if `data` lacks the spherical quantizer
///   file identifier.
/// * [`DeserializationError::InvalidFlatBuffer`] if the flatbuffer root cannot be parsed.
/// * [`DeserializationError::UnsupportedBitWidth`] if the stored bit-width is not 1, 2, 4,
///   or 8.
/// * [`DeserializationError::InvalidQuantizer`] or [`DeserializationError::AllocatorError`]
///   if unpacking the quantizer or allocating its storage fails.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    // Monomorphized per bit-width. The concrete `Impl` is built *inside* `Poly::new_with`
    // so that the outer `Poly` allocation happens before any allocation made for its
    // members; with a bump allocator this keeps the root object first.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let built = Poly::new_with(
            #[inline(never)]
            |member_alloc| -> Result<_, super::quantizer::DeserializationError> {
                let unpacked = SphericalQuantizer::try_unpack(member_alloc, proto)?;
                Ok(Impl::new(unpacked)?)
            },
            alloc,
        );

        // Collapse both failure sources (allocation vs. construction) into this
        // module's error type.
        let concrete = built.map_err(|failure| -> DeserializationError {
            match failure {
                CompoundError::Allocator(err) => err.into(),
                CompoundError::Constructor(err) => err.into(),
            }
        })?;

        Ok(poly!({ Quantizer<O> }, concrete))
    }

    // Refuse anything not tagged as a spherical quantizer before parsing it.
    let recognized = fb::spherical::quantizer_buffer_has_identifier(data);
    if !recognized {
        return Err(DeserializationError::InvalidIdentifier);
    }

    // Parse as much as possible without allocating; only the bit-width specific
    // unpacking below touches `alloc`.
    let root = fb::spherical::root_as_quantizer(data)?;
    let bit_width = root.nbits();
    let proto = root.quantizer();

    match bit_width {
        1 => unpack_bits::<1, _, _>(proto, alloc),
        2 => unpack_bits::<2, _, _>(proto, alloc),
        4 => unpack_bits::<4, _, _>(proto, alloc),
        8 => unpack_bits::<8, _, _>(proto, alloc),
        other => Err(DeserializationError::UnsupportedBitWidth(other)),
    }
}
2011
2012///////////
2013// Tests //
2014///////////
2015
2016#[cfg(test)]
2017mod tests {
2018    use diskann_utils::views::{Matrix, MatrixView};
2019    use rand::{SeedableRng, rngs::StdRng};
2020
2021    use super::*;
2022    use crate::{
2023        algorithms::{TransformKind, transforms::TargetDim},
2024        alloc::{AlignedAllocator, GlobalAllocator, Poly},
2025        num::PowerOfTwo,
2026        spherical::PreScale,
2027    };
2028
2029    ////////////////////
2030    // Test Quantizer //
2031    ////////////////////
2032
2033    fn test_plan_1_bit(plan: &dyn Quantizer) {
2034        assert_eq!(
2035            plan.nbits(),
2036            1,
2037            "this test only applies to 1-bit quantization"
2038        );
2039
2040        // Check Layouts.
2041        for layout in QueryLayout::all() {
2042            match layout {
2043                QueryLayout::SameAsData
2044                | QueryLayout::FourBitTransposed
2045                | QueryLayout::FullPrecision => assert!(
2046                    plan.is_supported(layout),
2047                    "expected {} to be supported",
2048                    layout
2049                ),
2050                QueryLayout::ScalarQuantized => assert!(
2051                    !plan.is_supported(layout),
2052                    "expected {} to not be supported",
2053                    layout
2054                ),
2055            }
2056        }
2057    }
2058
2059    fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
2060        assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
2061        assert_eq!(
2062            plan.nbits(),
2063            nbits,
2064            "this test only applies to 1-bit quantization"
2065        );
2066
2067        // Check Layouts.
2068        for layout in QueryLayout::all() {
2069            match layout {
2070                QueryLayout::SameAsData
2071                | QueryLayout::ScalarQuantized
2072                | QueryLayout::FullPrecision => assert!(
2073                    plan.is_supported(layout),
2074                    "expected {} to be supported",
2075                    layout
2076                ),
2077                QueryLayout::FourBitTransposed => assert!(
2078                    !plan.is_supported(layout),
2079                    "expected {} to not be supported",
2080                    layout
2081                ),
2082            }
2083        }
2084    }
2085
    /// Exercise the [`Quantizer`] interface of `plan` end to end.
    ///
    /// First runs the bit-width specific layout checks, then the bit-width agnostic ones.
    /// Rows 0 and 1 of `dataset` are compressed and used to check the distance computers,
    /// the fused and standalone query computers, and the error paths for mis-sized inputs
    /// and buffers. `dataset` therefore needs at least two rows of length
    /// `plan.full_dim()`.
    #[inline(never)]
    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
        // Perform the bit-specific test.
        if nbits == 1 {
            test_plan_1_bit(plan);
        } else {
            test_plan_n_bit(plan, nbits);
        }

        // Run bit-width agnostic tests.
        assert_eq!(plan.full_dim(), dataset.ncols());

        // Use the correct alignment for the base pointers.
        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let scoped_global = ScopedAllocator::global();

        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
            .unwrap();
        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
            .unwrap();

        let f = plan.distance_computer(GlobalAllocator).unwrap();
        let _: f32 = f
            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
            .unwrap();

        // Shared by the owned and borrowed distance computers: a mis-sized first argument
        // must fail as `QueryReify`, a mis-sized second argument as `XReify`.
        let test_errors = |f: &dyn DynDistanceComputer| {
            // `a` too short
            let err = f
                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // `a` too long
            let err = f
                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // `b` too short
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));

            // `b` too long
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));
        };

        test_errors(&*f.inner);

        let f = plan.distance_computer_ref();
        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
        test_errors(f);

        // Test all supported flavors of `QueryComputer`.
        for layout in QueryLayout::all() {
            if !plan.is_supported(layout) {
                // Every API must reject an unsupported layout with a message naming both
                // the layout and the bit-width.
                let check_message = |msg: &str| {
                    assert!(
                        msg.contains(&(layout.to_string())),
                        "error message ({}) should contain the layout \"{}\"",
                        msg,
                        layout
                    );
                    assert!(
                        msg.contains(&format!("{}", nbits)),
                        "error message ({}) should contain the number of bits \"{}\"",
                        msg,
                        nbits
                    );
                };

                // Error for query computer
                {
                    let err = plan
                        .fused_query_computer(
                            dataset.row(1),
                            layout,
                            false,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap_err();

                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query buffer
                {
                    let err = plan.query_buffer_description(layout).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Compress Query Into
                {
                    let buffer = &mut [];
                    let err = plan
                        .compress_query(
                            dataset.row(1),
                            layout,
                            true,
                            OpaqueMut::new(buffer),
                            scoped_global,
                        )
                        .unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Standalone Query Computer
                {
                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                continue;
            }

            let g = plan
                .fused_query_computer(
                    dataset.row(1),
                    layout,
                    false,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            assert_eq!(
                g.layout(),
                layout,
                "the query computer should faithfully preserve the requested layout"
            );

            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();

            // Check that the fused computer correctly returns errors for invalid inputs.
            {
                let err = g
                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));

                let err = g
                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));
            }

            // Standalone path: compress the same query into a buffer sized and aligned
            // per `query_buffer_description`, then evaluate it against the same `a`.
            let sizes = plan.query_buffer_description(layout).unwrap();
            let mut buf =
                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();

            plan.compress_query(
                dataset.row(1),
                layout,
                false,
                OpaqueMut::new(&mut buf),
                scoped_global,
            )
            .unwrap();

            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();

            assert_eq!(
                standalone.layout(),
                layout,
                "the standalone computer did not preserve the requested layout",
            );

            let indirect: f32 = standalone
                .evaluate_similarity(Opaque(&buf), Opaque(&a))
                .unwrap();

            // The fused and standalone paths must agree exactly, not just approximately.
            assert_eq!(
                direct, indirect,
                "the two different query computation APIs did not return the same result"
            );

            // Errors
            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
            assert!(
                plan.fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                    .is_err()
            );
        }

        // Errors: `compress` rejects output buffers that are not exactly `plan.bytes()`
        // long, and inputs whose length does not match the plan's dimension.
        {
            let mut too_small = vec![u8::default(); plan.bytes() - 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                    .is_err()
            );

            let mut too_big = vec![u8::default(); plan.bytes() + 1];
            assert!(
                plan.compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                    .is_err()
            );

            // Correctly sized buffer but a truncated input vector.
            let mut just_right = vec![u8::default(); plan.bytes()];
            assert!(
                plan.compress(
                    &dataset.row(0)[..dataset.ncols() - 1],
                    OpaqueMut(&mut just_right),
                    scoped_global
                )
                .is_err()
            );
        }
    }
2306
2307    fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
2308    where
2309        Impl<NBITS>: Constructible,
2310    {
2311        let data = test_dataset();
2312        let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);
2313
2314        let quantizer = SphericalQuantizer::train(
2315            data.as_view(),
2316            TransformKind::PaddingHadamard {
2317                target_dim: TargetDim::Natural,
2318            },
2319            metric,
2320            PreScale::None,
2321            &mut rng,
2322            GlobalAllocator,
2323        )
2324        .unwrap();
2325
2326        (Impl::<NBITS>::new(quantizer).unwrap(), data)
2327    }
2328
2329    #[test]
2330    fn test_plan_1bit_l2() {
2331        let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2332        test_plan(&plan, 1, data.as_view());
2333    }
2334
2335    #[test]
2336    fn test_plan_1bit_ip() {
2337        let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2338        test_plan(&plan, 1, data.as_view());
2339    }
2340
2341    #[test]
2342    fn test_plan_1bit_cosine() {
2343        let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
2344        test_plan(&plan, 1, data.as_view());
2345    }
2346
2347    #[test]
2348    fn test_plan_2bit_l2() {
2349        let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2350        test_plan(&plan, 2, data.as_view());
2351    }
2352
2353    #[test]
2354    fn test_plan_2bit_ip() {
2355        let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2356        test_plan(&plan, 2, data.as_view());
2357    }
2358
2359    #[test]
2360    fn test_plan_2bit_cosine() {
2361        let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
2362        test_plan(&plan, 2, data.as_view());
2363    }
2364
2365    #[test]
2366    fn test_plan_4bit_l2() {
2367        let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2368        test_plan(&plan, 4, data.as_view());
2369    }
2370
2371    #[test]
2372    fn test_plan_4bit_ip() {
2373        let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2374        test_plan(&plan, 4, data.as_view());
2375    }
2376
2377    #[test]
2378    fn test_plan_4bit_cosine() {
2379        let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
2380        test_plan(&plan, 4, data.as_view());
2381    }
2382
2383    #[test]
2384    fn test_plan_8bit_l2() {
2385        let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2386        test_plan(&plan, 8, data.as_view());
2387    }
2388
2389    #[test]
2390    fn test_plan_8bit_ip() {
2391        let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2392        test_plan(&plan, 8, data.as_view());
2393    }
2394
2395    #[test]
2396    fn test_plan_8bit_cosine() {
2397        let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
2398        test_plan(&plan, 8, data.as_view());
2399    }
2400
2401    fn test_dataset() -> Matrix<f32> {
2402        let data = vec![
2403            0.28657,
2404            -0.0318168,
2405            0.0666847,
2406            0.0329265,
2407            -0.00829283,
2408            0.168735,
2409            -0.000846311,
2410            -0.360779, // row 0
2411            -0.0968938,
2412            0.161921,
2413            -0.0979579,
2414            0.102228,
2415            -0.259928,
2416            -0.139634,
2417            0.165384,
2418            -0.293443, // row 1
2419            0.130205,
2420            0.265737,
2421            0.401816,
2422            -0.407552,
2423            0.13012,
2424            -0.0475244,
2425            0.511723,
2426            -0.4372, // row 2
2427            -0.0979126,
2428            0.135861,
2429            -0.0154144,
2430            -0.14047,
2431            -0.0250029,
2432            -0.190279,
2433            0.407283,
2434            -0.389184, // row 3
2435            -0.264153,
2436            0.0696822,
2437            -0.145585,
2438            0.370284,
2439            0.186825,
2440            -0.140736,
2441            0.274703,
2442            -0.334563, // row 4
2443            0.247613,
2444            0.513165,
2445            -0.0845867,
2446            0.0532264,
2447            -0.00480601,
2448            -0.122408,
2449            0.47227,
2450            -0.268301, // row 5
2451            0.103198,
2452            0.30756,
2453            -0.316293,
2454            -0.0686877,
2455            -0.330729,
2456            -0.461997,
2457            0.550857,
2458            -0.240851, // row 6
2459            0.128258,
2460            0.786291,
2461            -0.0268103,
2462            0.111763,
2463            -0.308962,
2464            -0.17407,
2465            0.437154,
2466            -0.159879, // row 7
2467            0.00374063,
2468            0.490301,
2469            0.0327826,
2470            -0.0340962,
2471            -0.118605,
2472            0.163879,
2473            0.2737,
2474            -0.299942, // row 8
2475            -0.284077,
2476            0.249377,
2477            -0.0307734,
2478            -0.0661631,
2479            0.233854,
2480            0.427987,
2481            0.614132,
2482            -0.288649, // row 9
2483            -0.109492,
2484            0.203939,
2485            -0.73956,
2486            -0.130748,
2487            0.22072,
2488            0.0647836,
2489            0.328726,
2490            -0.374602, // row 10
2491            -0.223114,
2492            0.0243489,
2493            0.109195,
2494            -0.416914,
2495            0.0201052,
2496            -0.0190542,
2497            0.947078,
2498            -0.333229, // row 11
2499            -0.165869,
2500            -0.00296729,
2501            -0.414378,
2502            0.231321,
2503            0.205365,
2504            0.161761,
2505            0.148608,
2506            -0.395063, // row 12
2507            -0.0498255,
2508            0.193279,
2509            -0.110946,
2510            -0.181174,
2511            -0.274578,
2512            -0.227511,
2513            0.190208,
2514            -0.256174, // row 13
2515            -0.188106,
2516            -0.0292958,
2517            0.0930939,
2518            0.0558456,
2519            0.257437,
2520            0.685481,
2521            0.307922,
2522            -0.320006, // row 14
2523            0.250035,
2524            0.275942,
2525            -0.0856306,
2526            -0.352027,
2527            -0.103509,
2528            -0.00890859,
2529            0.276121,
2530            -0.324718, // row 15
2531        ];
2532
2533        Matrix::try_from(data.into(), 16, 8).unwrap()
2534    }
2535
2536    #[cfg(feature = "flatbuffers")]
2537    mod serialization {
2538        use std::sync::{
2539            Arc,
2540            atomic::{AtomicBool, Ordering},
2541        };
2542
2543        use super::*;
2544        use crate::alloc::{BumpAllocator, GlobalAllocator};
2545
2546        #[inline(never)]
2547        fn test_plan_serialization(
2548            quantizer: &dyn Quantizer,
2549            nbits: usize,
2550            dataset: MatrixView<f32>,
2551        ) {
2552            // Run bit-width agnostic tests.
2553            assert_eq!(quantizer.full_dim(), dataset.ncols());
2554            let scoped_global = ScopedAllocator::global();
2555
2556            let serialized = quantizer.serialize(GlobalAllocator).unwrap();
2557            let deserialized =
2558                try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();
2559
2560            assert_eq!(deserialized.nbits(), nbits);
2561            assert_eq!(deserialized.bytes(), quantizer.bytes());
2562            assert_eq!(deserialized.dim(), quantizer.dim());
2563            assert_eq!(deserialized.full_dim(), quantizer.full_dim());
2564            assert_eq!(deserialized.metric(), quantizer.metric());
2565
2566            for layout in QueryLayout::all() {
2567                assert_eq!(
2568                    deserialized.is_supported(layout),
2569                    quantizer.is_supported(layout)
2570                );
2571            }
2572
2573            // Use the correct alignment for the base pointers.
2574            let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
2575            {
2576                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2577                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2578
2579                for row in dataset.row_iter() {
2580                    quantizer
2581                        .compress(row, OpaqueMut::new(&mut a), scoped_global)
2582                        .unwrap();
2583                    deserialized
2584                        .compress(row, OpaqueMut::new(&mut b), scoped_global)
2585                        .unwrap();
2586
2587                    // Compressed representation should be identical.
2588                    assert_eq!(a, b);
2589                }
2590            }
2591
2592            // Distance Computer
2593            {
2594                let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2595                let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2596                let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2597                let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2598
2599                let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
2600                let q_computer_ref = quantizer.distance_computer_ref();
2601                let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
2602                let d_computer_ref = deserialized.distance_computer_ref();
2603
2604                for r0 in dataset.row_iter() {
2605                    quantizer
2606                        .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
2607                        .unwrap();
2608                    deserialized
2609                        .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
2610                        .unwrap();
2611                    for r1 in dataset.row_iter() {
2612                        quantizer
2613                            .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
2614                            .unwrap();
2615                        deserialized
2616                            .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
2617                            .unwrap();
2618
2619                        let a0 = Opaque::new(&a0);
2620                        let a1 = Opaque::new(&a1);
2621
2622                        let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
2623                        let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();
2624
2625                        assert_eq!(q_computer_dist, d_computer_dist);
2626
2627                        let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();
2628
2629                        assert_eq!(q_computer_dist, q_computer_ref_dist);
2630
2631                        let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
2632                        assert_eq!(d_computer_dist, d_computer_ref_dist);
2633                    }
2634                }
2635            }
2636
2637            // Query Computer
2638            {
2639                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2640                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2641
2642                for layout in QueryLayout::all() {
2643                    if !quantizer.is_supported(layout) {
2644                        continue;
2645                    }
2646
2647                    for r in dataset.row_iter() {
2648                        let q_computer = quantizer
2649                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2650                            .unwrap();
2651                        let d_computer = deserialized
2652                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2653                            .unwrap();
2654
2655                        for u in dataset.row_iter() {
2656                            quantizer
2657                                .compress(u, OpaqueMut::new(&mut a), scoped_global)
2658                                .unwrap();
2659                            deserialized
2660                                .compress(u, OpaqueMut::new(&mut b), scoped_global)
2661                                .unwrap();
2662
2663                            assert_eq!(
2664                                q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
2665                                d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
2666                            );
2667                        }
2668                    }
2669                }
2670            }
2671        }
2672
2673        // An allocator that succeeds on its first allocation but fails on its second.
2674        #[derive(Debug, Clone)]
2675        struct FlakyAllocator {
2676            have_allocated: Arc<AtomicBool>,
2677        }
2678
2679        impl FlakyAllocator {
2680            fn new(have_allocated: Arc<AtomicBool>) -> Self {
2681                Self { have_allocated }
2682            }
2683        }
2684
2685        // SAFETY: This is a wrapper around GlobalAllocator that only succeeds once.
2686        unsafe impl AllocatorCore for FlakyAllocator {
2687            fn allocate(
2688                &self,
2689                layout: std::alloc::Layout,
2690            ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
2691                if self.have_allocated.swap(true, Ordering::Relaxed) {
2692                    Err(AllocatorError)
2693                } else {
2694                    GlobalAllocator.allocate(layout)
2695                }
2696            }
2697
2698            unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
2699                // SAFETY: Inherited from caller.
2700                unsafe { GlobalAllocator.deallocate(ptr, layout) }
2701            }
2702        }
2703
2704        fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
2705        where
2706            Impl<NBITS>: Quantizer,
2707        {
2708            // Ensure that we do not panic if reallocation returns an error.
2709            let have_allocated = Arc::new(AtomicBool::new(false));
2710            let _: AllocatorError = v
2711                .serialize(FlakyAllocator::new(have_allocated.clone()))
2712                .unwrap_err();
2713            assert!(have_allocated.load(Ordering::Relaxed));
2714        }
2715
2716        #[test]
2717        fn test_plan_1bit_l2() {
2718            let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2719            test_plan_panic_boundary(&plan);
2720            test_plan_serialization(&plan, 1, data.as_view());
2721        }
2722
2723        #[test]
2724        fn test_plan_1bit_ip() {
2725            let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2726            test_plan_panic_boundary(&plan);
2727            test_plan_serialization(&plan, 1, data.as_view());
2728        }
2729
2730        #[test]
2731        fn test_plan_2bit_l2() {
2732            let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2733            test_plan_panic_boundary(&plan);
2734            test_plan_serialization(&plan, 2, data.as_view());
2735        }
2736
2737        #[test]
2738        fn test_plan_2bit_ip() {
2739            let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2740            test_plan_panic_boundary(&plan);
2741            test_plan_serialization(&plan, 2, data.as_view());
2742        }
2743
2744        #[test]
2745        fn test_plan_4bit_l2() {
2746            let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2747            test_plan_panic_boundary(&plan);
2748            test_plan_serialization(&plan, 4, data.as_view());
2749        }
2750
2751        #[test]
2752        fn test_plan_4bit_ip() {
2753            let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2754            test_plan_panic_boundary(&plan);
2755            test_plan_serialization(&plan, 4, data.as_view());
2756        }
2757
2758        #[test]
2759        fn test_plan_8bit_l2() {
2760            let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2761            test_plan_panic_boundary(&plan);
2762            test_plan_serialization(&plan, 8, data.as_view());
2763        }
2764
2765        #[test]
2766        fn test_plan_8bit_ip() {
2767            let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2768            test_plan_panic_boundary(&plan);
2769            test_plan_serialization(&plan, 8, data.as_view());
2770        }
2771
2772        #[test]
2773        fn test_allocation_order() {
2774            let (plan, _) = make_impl::<1>(SupportedMetric::SquaredL2);
2775            let buf = plan.serialize(GlobalAllocator).unwrap();
2776
2777            let allocator = BumpAllocator::new(8192, PowerOfTwo::new(64).unwrap()).unwrap();
2778            let deserialized =
2779                try_deserialize::<GlobalAllocator, _>(&buf, allocator.clone()).unwrap();
2780            assert_eq!(
2781                Poly::as_ptr(&deserialized).cast::<u8>(),
2782                allocator.as_ptr(),
2783                "expected the returned box to be allocated first",
2784            );
2785        }
2786    }
2787}