diskann_quantization/spherical/
iface.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! The main export of this module is the dyn compatible [`Quantizer`] trait, which provides
7//! a common interface for interacting with bit-width specific [`SphericalQuantizer`]s,
8//! compressing vectors, and computing distances between compressed vectors.
9//!
10//! This is offered as a convenience interface for interacting with the myriad of generics
11//! associated with the [`SphericalQuantizer`]. Better performance can be achieved by using
12//! the generic types directly if desired.
13//!
14//! The [`Quantizer`] uses the [`Opaque`] and [`OpaqueMut`] types for its compressed data
15//! representations. These are thin wrappers around raw byte slices.
16//!
//! Distance computation is performed using [`DistanceComputer`] and [`QueryComputer`].
18//!
19//! Concrete implementations of [`Quantizer`] are available via the generic struct [`Impl`].
20//!
21//! ## Compatibility Table
22//!
23//! Multiple [`QueryLayout`]s are supported when constructing a [`QueryComputer`], but not
24//! all layouts are supported for each back end. This table lists the valid instantiations
25//! of [`Impl`] (parameterized by data vector bit-width) and their supported query layouts.
26//!
27//! | Bits | Same As Data | Full Precision | Four-Bit Transposed | Scalar Quantized |
28//! |------|--------------|----------------|---------------------|------------------|
29//! |    1 |     Yes      |      Yes       |         Yes         |       No         |
30//! |    2 |     Yes      |      Yes       |          No         |       Yes        |
31//! |    4 |     Yes      |      Yes       |          No         |       Yes        |
32//! |    8 |     Yes      |      Yes       |          No         |       Yes        |
33//!
34//! # Example
35//!
36//! ```
37//! use diskann_quantization::{
38//!     alloc::{Poly, ScopedAllocator, AlignedAllocator, GlobalAllocator},
39//!     algorithms::TransformKind,
40//!     spherical::{iface, SupportedMetric, SphericalQuantizer, PreScale},
41//!     num::PowerOfTwo,
42//! };
43//! use diskann_utils::views::Matrix;
44//!
45//! // For illustration purposes, the dataset consists of just a single vector.
46//! let mut data = Matrix::new(1.0, 1, 4);
47//! let quantizer = SphericalQuantizer::train(
48//!     data.as_view(),
49//!     TransformKind::Null,
50//!     SupportedMetric::SquaredL2,
51//!     PreScale::None,
52//!     &mut rand::rng(),
53//!     GlobalAllocator
54//! ).unwrap();
55//!
56//! let quantizer: Box<dyn iface::Quantizer> = Box::new(
57//!     iface::Impl::<1>::new(quantizer).unwrap()
58//! );
59//!
60//! let alloc = AlignedAllocator::new(PowerOfTwo::new(1).unwrap());
61//! let mut buf = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
62//!
63//! quantizer.compress(
64//!     data.row(0),
65//!     iface::OpaqueMut::new(&mut buf),
66//!     ScopedAllocator::new(&alloc),
67//! ).unwrap();
68//!
69//! assert!(quantizer.is_supported(iface::QueryLayout::FullPrecision));
70//! assert!(!quantizer.is_supported(iface::QueryLayout::ScalarQuantized));
71//! ```
72
73/// # DevDocs
74///
/// This section provides developer documentation for the structure of the structs and
/// traits inside this file. The main goal of the [`DistanceComputer`] and [`QueryComputer`]
/// implementations is to introduce just a single level of indirection.
78///
79/// ## Distance Computer Philosophy
80///
/// The goal of this code (and in large part the reason for the somewhat spaghetti nature)
/// is to do all of the dispatches:
83///
84/// * Number of bits.
85/// * Query layout.
86/// * Distance Type (L2, Inner Product, Cosine).
87/// * Micro-architecture specific code-generation
88///
89/// Behind a **single** level of dynamic dispatch. This means we need to bake all of this
90/// information into a single type, facilitated through a combination of the `Reify` and
91/// `Curried` private structs.
92///
93/// ## Anatomy of the [`DistanceComputer`]
94///
95/// To provide a single level of indirection for all distance function implementations,
96/// the [`DistanceComputer`] is a thin wrapper around the [`DynDistanceComputer`] trait.
97///
98/// Concrete implementations of this trait consist of a base distance function like
99/// [`CompensatedIP`] or [`CompensatedSquaredL2`]. Because the data is passed through the
100/// [`Opaque`] type, these distance functions are embedded inside a [`Reify`] which first
101/// converts the [`Opaque`] to the appropriate fully-typed object before calling the inner
102/// distance function.
103///
104/// This full typing is supplied by the private [`FromOpaque`] helper trait.
105///
/// When returned from the [`Quantizer::distance_computer`] or
/// [`Quantizer::distance_computer_ref`] methods, the resulting computer will be specialized
/// to work solely on data vectors compressed through [`Quantizer::compress`].
109///
110/// When returned from [`Quantizer::query_computer`], the expected type of the query will
111/// depend on the [`QueryLayout`] supplied to that method.
112///
113/// The method [`QueryComputer::layout`] is provided to inspect at run time the query layout
114/// the object is meant for.
115///
116/// If at all possible, the [`QueryComputer`] should be preferred as it removes the
117/// possibility of providing an incorrect query layout and will be slightly faster since
118/// it does not require argument reification.
119///
120/// ## Anatomy of the [`QueryComputer`]
121///
/// This is similar to the [`DistanceComputer`] but has the extra duty of supporting
/// multiple different [`QueryLayout`]s (compressions for the query). To that end,
/// the stack of types used to implement the underlying [`DynQueryComputer`] trait is:
125///
126/// * Base [`DistanceFunction`] (e.g. [`CompensatedIP`]).
127///
/// * Embedded inside [`Curried`] - which also contains a heap-allocated representation of
///   the query using the selected layout. For example, this could be one of:
///
///   - [`diskann_quantization::spherical::Query`]
///   - [`diskann_quantization::spherical::FullQuery`]
///   - [`diskann_quantization::spherical::Data`]
134///
135/// * Embedded inside [`Reify`] to convert [`Opaque`] to the correct type.
136use std::marker::PhantomData;
137
138use diskann_utils::{Reborrow, ReborrowMut};
139use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction};
140use diskann_wide::{
141    arch::{Scalar, Target1, Target2},
142    Architecture,
143};
144#[cfg(feature = "flatbuffers")]
145use flatbuffers::FlatBufferBuilder;
146use thiserror::Error;
147
148#[cfg(target_arch = "x86_64")]
149use diskann_wide::arch::x86_64::{V3, V4};
150
151use super::{
152    quantizer, CompensatedCosine, CompensatedIP, CompensatedSquaredL2, Data, DataMut, DataRef,
153    FullQuery, FullQueryMut, FullQueryRef, Query, QueryMut, QueryRef, SphericalQuantizer,
154    SupportedMetric,
155};
156#[cfg(feature = "flatbuffers")]
157use crate::{alloc::CompoundError, flatbuffers as fb};
158use crate::{
159    alloc::{
160        Allocator, AllocatorCore, AllocatorError, GlobalAllocator, Poly, ScopedAllocator, TryClone,
161    },
162    bits::{self, Representation, Unsigned},
163    distances::{self, UnequalLengths},
164    error::InlineError,
165    meta,
166    num::PowerOfTwo,
167    poly, AsFunctor, CompressIntoWith,
168};
169
// A convenience alias for the `f32`-valued result type of the distance functions.
// It exists purely to shorten the extensive where-clauses present in this file.
type Rf32 = distances::Result<f32>;
172
173///////////////
174// Quantizer //
175///////////////
176
177/// A description of the buffer size (in bytes) and alignment required for a compressed query.
178#[derive(Debug, Clone)]
179pub struct QueryBufferDescription {
180    size: usize,
181    align: PowerOfTwo,
182}
183
184impl QueryBufferDescription {
185    /// Construct a new [`QueryBufferDescription`]
186    pub fn new(size: usize, align: PowerOfTwo) -> Self {
187        Self { size, align }
188    }
189
190    /// Return the number of bytes needed in a buffer for a compressed query.
191    pub fn bytes(&self) -> usize {
192        self.size
193    }
194
195    /// Return the necessary alignment of the base pointer for a query buffer.
196    pub fn align(&self) -> PowerOfTwo {
197        self.align
198    }
199}
200
/// A dyn-compatible trait providing a common interface for a bit-width specific
/// [`SphericalQuantizer`].
///
/// This allows us to have a single `dyn Quantizer` type without generics while still
/// supporting the range of bit-widths and query strategies we wish to support.
///
/// A level of indirection for each distance computation, unfortunately, is required to
/// support this. But we try to structure the code so there is only a single level of
/// indirection.
///
/// # Allocator
///
/// The quantizer is parameterized by the allocator used to acquire any necessary
/// memory for returned data structures. The contract is as follows:
///
/// 1. Any allocation made as part of a returned data structure from a function will be
///    performed through the allocator given to that function.
///
/// 2. If dynamic memory allocation for scratch space is required, a separate `scratch`
///    allocator will be required and all scratch space allocations will go through that
///    allocator.
pub trait Quantizer<A = GlobalAllocator>: Send + Sync
where
    A: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// The effective number of bits in the encoding.
    fn nbits(&self) -> usize;

    /// The number of bytes occupied by each compressed vector.
    fn bytes(&self) -> usize;

    /// The effective dimensionality of each compressed vector.
    fn dim(&self) -> usize;

    /// The dimensionality of the full-precision input vectors.
    fn full_dim(&self) -> usize;

    /// Return a distance computer capable of operating on validly initialized [`Opaque`]
    /// slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// The query layout associated with this computer will always be
    /// [`QueryLayout::SameAsData`].
    ///
    /// # Errors
    ///
    /// Returns an [`AllocatorError`] if `allocator` cannot provide the memory for the
    /// returned computer.
    fn distance_computer(&self, allocator: A) -> Result<DistanceComputer<A>, AllocatorError>;

    /// Return a scoped distance computer capable of operating on validly initialized
    /// [`Opaque`] slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer;

    /// A stand alone distance computer specialized for the specified query layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    ///
    /// # Note
    ///
    /// The returned object will **only** be compatible with queries compressed using
    /// [`Self::compress_query`] using the same layout. If possible, the API
    /// [`Self::fused_query_computer`] should be used to avoid this ambiguity.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: A,
    ) -> Result<DistanceComputer<A>, DistanceComputerError>;

    /// Return the number of bytes and alignment of a buffer used to contain a compressed
    /// query with the provided layout.
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>;

    /// Compress the query using the specified layout into `buffer`.
    ///
    /// This requires that `buffer` have the exact size and alignment as that returned from
    /// [`Self::query_buffer_description`].
    ///
    /// Only layouts for which [`Self::is_supported`] returns `true` are supported.
    ///
    /// NOTE(review): the effect of `allow_rescale` is determined by the implementation and
    /// is not visible from this declaration - confirm and document it.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError>;

    /// Return a query computer for the argument `x` capable of operating on validly
    /// initialized [`Opaque`] slices of length [`Self::bytes`].
    ///
    /// These slices should be initialized by [`Self::compress`].
    ///
    /// Note: Only layouts for which [`Self::is_supported`] returns `true` are supported.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: A,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<A>, QueryComputerError>;

    /// Return whether or not this plan supports the given [`QueryLayout`].
    fn is_supported(&self, layout: QueryLayout) -> bool;

    /// Compress the vector `x` into the opaque slice.
    ///
    /// # Note
    ///
    /// This requires the length of the slice to be exactly [`Self::bytes`]. There is no
    /// alignment restriction on the base pointer.
    fn compress(
        &self,
        x: &[f32],
        into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError>;

    /// Return the metric this plan was created with.
    fn metric(&self) -> SupportedMetric;

    /// Clone the backing object, placing the clone in memory obtained from `allocator`.
    fn try_clone_into(&self, allocator: A) -> Result<Poly<dyn Quantizer<A>, A>, AllocatorError>;

    crate::utils::features! {
        #![feature = "flatbuffers"]
        /// Serialize `self` into a flatbuffer, returning the flatbuffer. The function
        /// [`try_deserialize`] should undo this operation.
        fn serialize(&self, allocator: A) -> Result<Poly<[u8], A>, AllocatorError>;
    }
}
336
337#[derive(Debug, Error)]
338#[error("Layout {layout} is not supported for {desc}")]
339pub struct UnsupportedQueryLayout {
340    layout: QueryLayout,
341    desc: &'static str,
342}
343
344impl UnsupportedQueryLayout {
345    fn new(layout: QueryLayout, desc: &'static str) -> Self {
346        Self { layout, desc }
347    }
348}
349
/// Errors returned by [`Quantizer::query_computer`].
///
/// Every variant is `transparent`: its message and source are those of the wrapped error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DistanceComputerError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// The allocator reported a failure.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
358
/// Errors returned by [`Quantizer::compress_query`].
///
/// Every variant is `transparent`: its message and source are those of the wrapped error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryCompressionError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// Compression itself failed.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// An opaque argument did not have the required alignment or length.
    #[error(transparent)]
    NotCanonical(#[from] NotCanonical),
    /// The allocator reported a failure.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
371
/// Errors returned by [`Quantizer::fused_query_computer`].
///
/// Every variant is `transparent`: its message and source are those of the wrapped error.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum QueryComputerError {
    /// The requested [`QueryLayout`] is not supported.
    #[error(transparent)]
    UnsupportedQueryLayout(#[from] UnsupportedQueryLayout),
    /// Compression of the query failed.
    #[error(transparent)]
    CompressionError(#[from] CompressionError),
    /// The allocator reported a failure.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
382
383/// Errors that can occur during data compression
384#[derive(Debug, Error)]
385#[error("Error occured during query compression")]
386pub enum CompressionError {
387    /// The input buffer did not have the expected layout. This is an input error.
388    NotCanonical(#[source] InlineError<16>),
389
390    /// Forward any error that occurs during the compression process.
391    ///
392    /// See [`quantizer::CompressionError`] for the complete list.
393    CompressionError(#[source] quantizer::CompressionError),
394}
395
396impl CompressionError {
397    fn not_canonical<E>(error: E) -> Self
398    where
399        E: std::error::Error + Send + Sync + 'static,
400    {
401        Self::NotCanonical(InlineError::new(error))
402    }
403}
404
405#[derive(Debug, Error)]
406#[error("An opaque argument did not have the required alignment or length")]
407pub struct NotCanonical {
408    source: Box<dyn std::error::Error + Send + Sync>,
409}
410
411impl NotCanonical {
412    fn new<E>(err: E) -> Self
413    where
414        E: std::error::Error + Send + Sync + 'static,
415    {
416        Self {
417            source: Box::new(err),
418        }
419    }
420}
421
422////////////
423// Opaque //
424////////////
425
/// A borrowed, type-erased view of the bytes backing a spherically quantized vector.
///
/// Hiding the concrete representation lets every bit-width implementation share the same
/// argument type.
#[derive(Debug, Clone, Copy)]
#[repr(transparent)]
pub struct Opaque<'a>(&'a [u8]);

impl<'a> Opaque<'a> {
    /// Wrap `slice` in an `Opaque` without copying it.
    pub fn new(slice: &'a [u8]) -> Self {
        Opaque(slice)
    }

    /// Unwrap `self`, handing back the wrapped slice with its original lifetime.
    pub fn into_inner(self) -> &'a [u8] {
        let Opaque(bytes) = self;
        bytes
    }
}

impl std::ops::Deref for Opaque<'_> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.0
    }
}
450impl<'short> Reborrow<'short> for Opaque<'_> {
451    type Target = Opaque<'short>;
452    fn reborrow(&'short self) -> Self::Target {
453        *self
454    }
455}
456
/// A mutably borrowed, type-erased view of the bytes backing a spherically quantized
/// vector.
///
/// Hiding the concrete representation lets every bit-width implementation share the same
/// argument type.
#[derive(Debug)]
#[repr(transparent)]
pub struct OpaqueMut<'a>(&'a mut [u8]);

impl<'a> OpaqueMut<'a> {
    /// Wrap `slice` in an `OpaqueMut` without copying it.
    pub fn new(slice: &'a mut [u8]) -> Self {
        OpaqueMut(slice)
    }

    /// Borrow the referenced bytes for inspection or modification.
    pub fn inspect(&mut self) -> &mut [u8] {
        &mut *self.0
    }
}

impl std::ops::Deref for OpaqueMut<'_> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}

impl std::ops::DerefMut for OpaqueMut<'_> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.0
    }
}
487
488//////////////////
489// Query Layout //
490//////////////////
491
/// Selects how the query is encoded for a [`DistanceComputer`] or [`QueryComputer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryLayout {
    /// Use the same compression strategy as the data vectors.
    ///
    /// This may result in slow compression if high bit-widths are used.
    SameAsData,

    /// Use 4-bits for the query vector using a bitwise transpose layout.
    FourBitTransposed,

    /// Use scalar quantization for the query using the same number of bits per dimension
    /// as the dataset.
    ScalarQuantized,

    /// Use `f32` to encode the query.
    FullPrecision,
}

impl QueryLayout {
    /// Every layout, in declaration order. Only available to tests.
    #[cfg(test)]
    fn all() -> [Self; 4] {
        const ALL: [QueryLayout; 4] = [
            QueryLayout::SameAsData,
            QueryLayout::FourBitTransposed,
            QueryLayout::ScalarQuantized,
            QueryLayout::FullPrecision,
        ];
        ALL
    }
}

// The user-facing rendering is simply the variant name, so defer to the derived `Debug`.
impl std::fmt::Display for QueryLayout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
528
529//////////////////////
530// Layout Reporting //
531//////////////////////
532
/// Because dynamic dispatch is used heavily in the implementation, it can be easy to lose
/// track of the actual layout used for the [`DistanceComputer`] and [`QueryComputer`].
///
/// This trait provides a mechanism by which we ensure the correct runtime layout is always
/// reported without requiring manual tracking: leaf types report a fixed layout and the
/// wrapper types forward to whatever they wrap.
trait ReportQueryLayout {
    /// The [`QueryLayout`] this object is associated with.
    fn report_query_layout(&self) -> QueryLayout;
}
541
542impl<T, M, L, R> ReportQueryLayout for Reify<T, M, L, R>
543where
544    T: ReportQueryLayout,
545{
546    fn report_query_layout(&self) -> QueryLayout {
547        self.inner.report_query_layout()
548    }
549}
550
551impl<D, Q> ReportQueryLayout for Curried<D, Q>
552where
553    Q: ReportQueryLayout,
554{
555    fn report_query_layout(&self) -> QueryLayout {
556        self.query.report_query_layout()
557    }
558}
559
// An owned data vector used as a query corresponds to `QueryLayout::SameAsData`.
impl<const NBITS: usize, A> ReportQueryLayout for Data<NBITS, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::SameAsData
    }
}

// A densely packed query corresponds to `QueryLayout::ScalarQuantized`.
impl<const NBITS: usize, A> ReportQueryLayout for Query<NBITS, bits::Dense, A>
where
    Unsigned: Representation<NBITS>,
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::ScalarQuantized
    }
}

// A 4-bit bit-transposed query corresponds to `QueryLayout::FourBitTransposed`.
impl<A> ReportQueryLayout for Query<4, bits::BitTranspose, A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FourBitTransposed
    }
}

// A full query corresponds to `QueryLayout::FullPrecision`.
impl<A> ReportQueryLayout for FullQuery<A>
where
    A: AllocatorCore,
{
    fn report_query_layout(&self) -> QueryLayout {
        QueryLayout::FullPrecision
    }
}
597
598//-----------------------//
599// Reification Utilities //
600//-----------------------//
601
/// An adaptor trait defining how to go from an `Opaque` slice to a fully reified type.
///
/// This is the building block for building distance computers with the reification code
/// inlined into the callsite.
trait FromOpaque: 'static + Send + Sync {
    /// The fully-typed view produced from an [`Opaque`] borrowed for `'a`.
    type Target<'a>;
    /// The error produced when the opaque bytes cannot be reified.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Attempt to reinterpret `query` as [`Self::Target`]. The implementations in this file
    /// forward `dim` to the target's own constructor.
    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error>;
}
612
/// Reify as full-precision.
///
/// Zero-sized marker type; see its [`FromOpaque`] implementation.
#[derive(Debug, Default)]
pub(super) struct AsFull;

/// Reify as data.
///
/// Zero-sized marker type parameterized by the data bit-width; see its [`FromOpaque`]
/// implementation.
#[derive(Debug, Default)]
pub(super) struct AsData<const NBITS: usize>;
620
621/// Reify as scalar quantized query.
622#[derive(Debug)]
623pub(super) struct AsQuery<const NBITS: usize, Perm = bits::Dense> {
624    _marker: PhantomData<Perm>,
625}
626
627// This impelmentation works around the `derive` impl requiring `Perm: Default`.
628impl<const NBITS: usize, Perm> Default for AsQuery<NBITS, Perm> {
629    fn default() -> Self {
630        Self {
631            _marker: PhantomData,
632        }
633    }
634}
635
636impl FromOpaque for AsFull {
637    type Target<'a> = FullQueryRef<'a>;
638    type Error = meta::slice::NotCanonical;
639
640    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
641        Self::Target::from_canonical(query.into_inner(), dim)
642    }
643}
644
645impl ReportQueryLayout for AsFull {
646    fn report_query_layout(&self) -> QueryLayout {
647        QueryLayout::FullPrecision
648    }
649}
650
651impl<const NBITS: usize> FromOpaque for AsData<NBITS>
652where
653    Unsigned: Representation<NBITS>,
654{
655    type Target<'a> = DataRef<'a, NBITS>;
656    type Error = meta::NotCanonical;
657
658    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
659        Self::Target::from_canonical_back(query.into_inner(), dim)
660    }
661}
662
663impl<const NBITS: usize> ReportQueryLayout for AsData<NBITS> {
664    fn report_query_layout(&self) -> QueryLayout {
665        QueryLayout::SameAsData
666    }
667}
668
669impl<const NBITS: usize, Perm> FromOpaque for AsQuery<NBITS, Perm>
670where
671    Unsigned: Representation<NBITS>,
672    Perm: bits::PermutationStrategy<NBITS> + Send + Sync + 'static,
673{
674    type Target<'a> = QueryRef<'a, NBITS, Perm>;
675    type Error = meta::NotCanonical;
676
677    fn from_opaque<'a>(query: Opaque<'a>, dim: usize) -> Result<Self::Target<'a>, Self::Error> {
678        Self::Target::from_canonical_back(query.into_inner(), dim)
679    }
680}
681
682impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::Dense> {
683    fn report_query_layout(&self) -> QueryLayout {
684        QueryLayout::ScalarQuantized
685    }
686}
687
688impl<const NBITS: usize> ReportQueryLayout for AsQuery<NBITS, bits::BitTranspose> {
689    fn report_query_layout(&self) -> QueryLayout {
690        QueryLayout::FourBitTransposed
691    }
692}
693
694//-------//
695// Reify //
696//-------//
697
698/// Helper struct to convert an [`Opaque`] to a fully-typed [`DataRef`].
699pub(super) struct Reify<T, M, L, R> {
700    inner: T,
701    dim: usize,
702    arch: M,
703    _markers: PhantomData<(L, R)>,
704}
705
706impl<T, M, L, R> Reify<T, M, L, R> {
707    pub(super) fn new(inner: T, dim: usize, arch: M) -> Self {
708        Self {
709            inner,
710            dim,
711            arch,
712            _markers: PhantomData,
713        }
714    }
715}
716
// Single-argument flavour: the `()` in the left-hand marker slot means there is no opaque
// query to reify, so only the data argument `x` is converted before `inner` is invoked.
impl<M, T, R> DynQueryComputer for Reify<T, M, (), R>
where
    M: Architecture,
    R: FromOpaque,
    T: ReportQueryLayout + Send + Sync,
    for<'a> &'a T: Target1<M, Rf32, R::Target<'a>>,
{
    /// Reify `x` via `R` and evaluate `inner` on it.
    ///
    /// A reification failure is reported as [`QueryDistanceError::XReify`]; an error from
    /// `inner` is reported as [`QueryDistanceError::UnequalLengths`].
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        // The whole body (reification included) is executed through `arch.run2`.
        // NOTE(review): presumably so the conversion is compiled together with the distance
        // kernel for the selected architecture - confirm against `diskann_wide`.
        self.arch.run2(
            |this: &Self, x| {
                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| QueryDistanceError::XReify(InlineError::new(err)))?;
                this.arch
                    .run1(&this.inner, x)
                    .map_err(QueryDistanceError::UnequalLengths)
            },
            self,
            x,
        )
    }

    /// The layout is whatever the wrapped object reports.
    fn layout(&self) -> QueryLayout {
        self.inner.report_query_layout()
    }
}
742
// Two-argument flavour: both the query (via `Q`) and the data vector (via `R`) arrive as
// opaque bytes and are reified before `inner` is invoked.
impl<T, M, Q, R> DynDistanceComputer for Reify<T, M, Q, R>
where
    M: Architecture,
    Q: FromOpaque + Default + ReportQueryLayout,
    R: FromOpaque,
    T: for<'a> Target2<M, Rf32, Q::Target<'a>, R::Target<'a>> + Copy + Send + Sync,
{
    /// Reify `query` and `x`, then evaluate `inner` on the pair.
    ///
    /// Reification failures are reported as [`DistanceError::QueryReify`] and
    /// [`DistanceError::XReify`] respectively; an error from `inner` is reported as
    /// [`DistanceError::UnequalLengths`].
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
        // The whole body (reification included) is executed through `arch.run3`.
        // NOTE(review): presumably so the conversions are compiled together with the
        // distance kernel for the selected architecture - confirm against `diskann_wide`.
        self.arch.run3(
            |this: &Self, query, x| {
                let query = Q::from_opaque(query, this.dim)
                    .map_err(|err| DistanceError::QueryReify(InlineError::<24>::new(err)))?;

                let x = R::from_opaque(x, this.dim)
                    .map_err(|err| DistanceError::XReify(InlineError::<16>::new(err)))?;

                this.arch
                    .run2_inline(this.inner, query, x)
                    .map_err(DistanceError::UnequalLengths)
            },
            self,
            query,
            x,
        )
    }

    /// The layout is determined statically by the query marker type `Q`.
    fn layout(&self) -> QueryLayout {
        Q::default().report_query_layout()
    }
}
773
774///////////////////////
775// Query Computation //
776///////////////////////
777
/// Errors that can occur while performing distance calculations on opaque vectors with a
/// [`QueryComputer`].
#[derive(Debug, Error)]
pub enum QueryDistanceError {
    /// The right-hand data argument appears to be malformed.
    #[error("trouble trying to reify the argument")]
    XReify(#[source] InlineError<16>),

    /// Distance computation failed because the logical lengths of the two vectors differ.
    #[error("encountered while trying to compute distances")]
    UnequalLengths(#[source] UnequalLengths),
}
789
/// The dyn-compatible interface behind [`QueryComputer`]: a distance function whose query
/// argument is already bound.
pub trait DynQueryComputer: Send + Sync {
    /// Compute the distance between the bound query and the opaque data vector `x`.
    fn evaluate(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError>;
    /// Report the [`QueryLayout`] associated with this computer.
    fn layout(&self) -> QueryLayout;
}
794
795/// An opaque [`PreprocessedDistanceFunction`] for the [`Quantizer`] trait object.
796///
797/// # Note
798///
799/// This is only valid to call on [`Opaque`] slices compressed by the same [`Quantizer`] that
800/// created the computer.
801///
802/// Otherwise, distance computations may return garbage values or panic.
803pub struct QueryComputer<A = GlobalAllocator>
804where
805    A: AllocatorCore,
806{
807    inner: Poly<dyn DynQueryComputer, A>,
808}
809
810impl<A> QueryComputer<A>
811where
812    A: AllocatorCore,
813{
814    fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
815    where
816        T: DynQueryComputer + 'static,
817    {
818        let inner = Poly::new(inner, allocator)?;
819        Ok(Self {
820            inner: poly!(DynQueryComputer, inner),
821        })
822    }
823
824    /// Report the layout used by the query computer.
825    pub fn layout(&self) -> QueryLayout {
826        self.inner.layout()
827    }
828
829    /// This is a temporary function until custom allocator support fully comes on line.
830    pub fn into_inner(self) -> Poly<dyn DynQueryComputer, A> {
831        self.inner
832    }
833}
834
835impl<A> std::fmt::Debug for QueryComputer<A>
836where
837    A: AllocatorCore,
838{
839    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840        write!(
841            f,
842            "dynamic fused query computer with layout \"{}\"",
843            self.layout()
844        )
845    }
846}
847
// Expose the computer through the generic `PreprocessedDistanceFunction` interface by
// forwarding to the type-erased implementation.
impl<A> PreprocessedDistanceFunction<Opaque<'_>, Result<f32, QueryDistanceError>>
    for QueryComputer<A>
where
    A: AllocatorCore,
{
    /// Compute the distance between the bound query and the opaque data vector `x`.
    fn evaluate_similarity(&self, x: Opaque<'_>) -> Result<f32, QueryDistanceError> {
        self.inner.evaluate(x)
    }
}
857
858/// To handle multiple query bit-widths, we use type erasure on the actual distance
859/// function implementation.
860///
861/// This struct represents the partial application of the `inner` distance function with
862/// `query` in a generic way so we only have one level of dynamic dispatch when computing
863/// distances.
864pub(super) struct Curried<D, Q> {
865    inner: D,
866    query: Q,
867}
868
869impl<D, Q> Curried<D, Q> {
870    pub(super) fn new(inner: D, query: Q) -> Self {
871        Self { inner, query }
872    }
873}
874
// Running a `Curried` with one argument runs the wrapped two-argument function with a
// reborrow of the stored query supplied as the first argument.
impl<A, D, Q, T, R> Target1<A, R, T> for &Curried<D, Q>
where
    A: Architecture,
    Q: for<'a> Reborrow<'a>,
    D: for<'a> Target2<A, R, <Q as Reborrow<'a>>::Target, T> + Copy,
{
    fn run(self, arch: A, x: T) -> R {
        // `D: Copy`, so the distance function is copied out of the shared reference.
        self.inner.run(arch, self.query.reborrow(), x)
    }
}
885
886///////////////////////
887// Distance Computer //
888///////////////////////
889
890/// Errors that can occur while perfoming distance cacluations on opaque vectors.
891#[derive(Debug, Error)]
892pub enum DistanceError {
893    /// The left-hand data argument appears to be malformed.
894    #[error("trouble trying to reify the left-hand argument")]
895    QueryReify(InlineError<24>),
896
897    /// The right-hand data argument appears to be malformed.
898    #[error("trouble trying to reify the right-hand argument")]
899    XReify(InlineError<16>),
900
901    /// Distance computation failed because the logical lengths of the two vectors differ.
902    ///
903    /// If vector reificiation occurs successfully, then this should not be returned.
904    #[error("encountered while trying to compute distances")]
905    UnequalLengths(UnequalLengths),
906}
907
/// The dyn-compatible interface behind [`DistanceComputer`]: a distance function taking
/// both of its arguments as [`Opaque`] slices.
pub trait DynDistanceComputer: Send + Sync {
    /// Compute the distance between the opaque `query` and the opaque data vector `x`.
    fn evaluate(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError>;
    /// Report the [`QueryLayout`] expected of the `query` argument.
    fn layout(&self) -> QueryLayout;
}
912
/// An opaque [`DistanceFunction`] for the [`Quantizer`] trait object.
///
/// # Note
///
/// Left-hand arguments must be [`Opaque`] slices compressed using
/// [`Quantizer::compress_query`] using [`Self::layout`].
///
/// Right-hand arguments must be [`Opaque`] slices compressed using [`Quantizer::compress`].
///
/// Otherwise, distance computations may return garbage values or panic.
pub struct DistanceComputer<A = GlobalAllocator>
where
    A: AllocatorCore,
{
    // Type-erased implementation that all distance evaluations are forwarded to.
    inner: Poly<dyn DynDistanceComputer, A>,
}
929
930impl<A> DistanceComputer<A>
931where
932    A: AllocatorCore,
933{
934    pub(super) fn new<T>(inner: T, allocator: A) -> Result<Self, AllocatorError>
935    where
936        T: DynDistanceComputer + 'static,
937    {
938        let inner = Poly::new(inner, allocator)?;
939        Ok(Self {
940            inner: poly!(DynDistanceComputer, inner),
941        })
942    }
943
944    /// Report the layout used by the query computer.
945    pub fn layout(&self) -> QueryLayout {
946        self.inner.layout()
947    }
948
949    pub fn into_inner(self) -> Poly<dyn DynDistanceComputer, A> {
950        self.inner
951    }
952}
953
954impl<A> std::fmt::Debug for DistanceComputer<A>
955where
956    A: AllocatorCore,
957{
958    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
959        write!(
960            f,
961            "dynamic distance computer with layout \"{}\"",
962            self.layout()
963        )
964    }
965}
966
967impl<A> DistanceFunction<Opaque<'_>, Opaque<'_>, Result<f32, DistanceError>> for DistanceComputer<A>
968where
969    A: AllocatorCore,
970{
971    fn evaluate_similarity(&self, query: Opaque<'_>, x: Opaque<'_>) -> Result<f32, DistanceError> {
972        self.inner.evaluate(query, x)
973    }
974}
975
976//////////
977// Impl //
978//////////
979
/// The base number of bytes to allocate when attempting to serialize a quantizer.
#[cfg(all(not(test), feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1024;

/// The base number of bytes to allocate when attempting to serialize a quantizer.
///
/// When testing, use a small value so we trigger the reallocation logic.
#[cfg(all(test, feature = "flatbuffers"))]
const DEFAULT_SERIALIZED_BYTES: usize = 1;
987
/// Implementation for [`Quantizer`] specializing on the number of bits used for data
/// compression.
pub struct Impl<const NBITS: usize, A = GlobalAllocator>
where
    A: Allocator,
{
    // The wrapped quantizer that performs the actual compression.
    quantizer: SphericalQuantizer<A>,
    // Distance computer produced once by `Constructible::dispatch_distance` at construction
    // time; handed out by reference via `distance_computer_ref`.
    distance: Poly<dyn DynDistanceComputer, A>,
}
997
/// Pre-dispatch distance functions between compressed data vectors from `quantizer`
/// specialized for the current run-time micro-architecture.
pub trait Constructible<A = GlobalAllocator>
where
    A: Allocator,
{
    /// Build the type-erased distance computer for `quantizer`.
    ///
    /// Returns an [`AllocatorError`] if the computer cannot be allocated.
    fn dispatch_distance(
        quantizer: &SphericalQuantizer<A>,
    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError>;
}
1008
1009impl<const NBITS: usize, A: Allocator> Constructible<A> for Impl<NBITS, A>
1010where
1011    A: Allocator,
1012    AsData<NBITS>: FromOpaque,
1013    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1014{
1015    fn dispatch_distance(
1016        quantizer: &SphericalQuantizer<A>,
1017    ) -> Result<Poly<dyn DynDistanceComputer, A>, AllocatorError> {
1018        diskann_wide::arch::dispatch2_no_features(
1019            ComputerDispatcher::<AsData<NBITS>, NBITS>::new(),
1020            quantizer,
1021            quantizer.allocator().clone(),
1022        )
1023        .map(|obj| obj.inner)
1024    }
1025}
1026
1027impl<const NBITS: usize, A> TryClone for Impl<NBITS, A>
1028where
1029    A: Allocator,
1030    AsData<NBITS>: FromOpaque,
1031    SphericalQuantizer<A>: Dispatchable<AsData<NBITS>, NBITS>,
1032{
1033    fn try_clone(&self) -> Result<Self, AllocatorError> {
1034        Self::new(self.quantizer.try_clone()?)
1035    }
1036}
1037
1038impl<const NBITS: usize, A: Allocator> Impl<NBITS, A> {
1039    /// Construct a new plan around `quantizer` providing distance computers for `metric`.
1040    pub fn new(quantizer: SphericalQuantizer<A>) -> Result<Self, AllocatorError>
1041    where
1042        Self: Constructible<A>,
1043    {
1044        let distance = Self::dispatch_distance(&quantizer)?;
1045        Ok(Self {
1046            quantizer,
1047            distance,
1048        })
1049    }
1050
1051    /// Return the underlying [`SphericalQuantizer`].
1052    pub fn quantizer(&self) -> &SphericalQuantizer<A> {
1053        &self.quantizer
1054    }
1055
1056    /// Return `true` if this plan supports `layout` for query computers.
1057    ///
1058    /// Otherwise, return `false`.
1059    pub fn supports(layout: QueryLayout) -> bool {
1060        if const { NBITS == 1 } {
1061            [
1062                QueryLayout::SameAsData,
1063                QueryLayout::FourBitTransposed,
1064                QueryLayout::FullPrecision,
1065            ]
1066            .contains(&layout)
1067        } else {
1068            [
1069                QueryLayout::SameAsData,
1070                QueryLayout::ScalarQuantized,
1071                QueryLayout::FullPrecision,
1072            ]
1073            .contains(&layout)
1074        }
1075    }
1076
1077    /// Return a [`DistanceComputer`] that is specialized for the most specific runtime
1078    /// architecture.
1079    fn query_computer<Q, B>(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError>
1080    where
1081        Q: FromOpaque,
1082        B: AllocatorCore,
1083        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1084    {
1085        diskann_wide::arch::dispatch2_no_features(
1086            ComputerDispatcher::<Q, NBITS>::new(),
1087            &self.quantizer,
1088            allocator,
1089        )
1090    }
1091
1092    fn compress_query<'a, T>(
1093        &self,
1094        query: &'a [f32],
1095        storage: T,
1096        scratch: ScopedAllocator<'a>,
1097    ) -> Result<(), QueryCompressionError>
1098    where
1099        SphericalQuantizer<A>: CompressIntoWith<
1100            &'a [f32],
1101            T,
1102            ScopedAllocator<'a>,
1103            Error = quantizer::CompressionError,
1104        >,
1105    {
1106        self.quantizer
1107            .compress_into_with(query, storage, scratch)
1108            .map_err(|err| CompressionError::CompressionError(err).into())
1109    }
1110
1111    /// Return a [`QueryComputer`] that is specialized for the most specific runtime
1112    /// architecture.
1113    fn fused_query_computer<Q, T, B>(
1114        &self,
1115        query: &[f32],
1116        mut storage: T,
1117        allocator: B,
1118        scratch: ScopedAllocator<'_>,
1119    ) -> Result<QueryComputer<B>, QueryComputerError>
1120    where
1121        Q: FromOpaque,
1122        T: for<'a> ReborrowMut<'a>
1123            + for<'a> Reborrow<'a, Target = Q::Target<'a>>
1124            + ReportQueryLayout
1125            + Send
1126            + Sync
1127            + 'static,
1128        B: AllocatorCore,
1129        SphericalQuantizer<A>: for<'a> CompressIntoWith<
1130            &'a [f32],
1131            <T as ReborrowMut<'a>>::Target,
1132            ScopedAllocator<'a>,
1133            Error = quantizer::CompressionError,
1134        >,
1135        SphericalQuantizer<A>: Dispatchable<Q, NBITS>,
1136    {
1137        if let Err(err) = self
1138            .quantizer
1139            .compress_into_with(query, storage.reborrow_mut(), scratch)
1140        {
1141            return Err(CompressionError::CompressionError(err).into());
1142        }
1143
1144        diskann_wide::arch::dispatch3_no_features(
1145            ComputerDispatcher::<Q, NBITS>::new(),
1146            &self.quantizer,
1147            storage,
1148            allocator,
1149        )
1150        .map_err(|e| e.into())
1151    }
1152
1153    #[cfg(feature = "flatbuffers")]
1154    fn serialize<B>(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError>
1155    where
1156        B: Allocator + std::panic::UnwindSafe,
1157        A: std::panic::RefUnwindSafe,
1158    {
1159        let mut buf = FlatBufferBuilder::new_in(Poly::broadcast(
1160            0u8,
1161            DEFAULT_SERIALIZED_BYTES,
1162            allocator.clone(),
1163        )?);
1164
1165        let quantizer = &self.quantizer;
1166
1167        let (root, mut buf) = match std::panic::catch_unwind(move || {
1168            let offset = quantizer.pack(&mut buf);
1169
1170            let root = fb::spherical::Quantizer::create(
1171                &mut buf,
1172                &fb::spherical::QuantizerArgs {
1173                    quantizer: Some(offset),
1174                    nbits: NBITS as u32,
1175                },
1176            );
1177            (root, buf)
1178        }) {
1179            Ok(ret) => ret,
1180            Err(err) => match err.downcast_ref::<String>() {
1181                Some(msg) => {
1182                    if msg.contains("AllocatorError") {
1183                        return Err(AllocatorError);
1184                    } else {
1185                        std::panic::resume_unwind(err);
1186                    }
1187                }
1188                None => std::panic::resume_unwind(err),
1189            },
1190        };
1191
1192        // Finish serializing and then copy out the finished data into a newly allocated buffer.
1193        fb::spherical::finish_quantizer_buffer(&mut buf, root);
1194        Poly::from_iter(buf.finished_data().iter().copied(), allocator)
1195    }
1196}
1197
1198//----------------------//
1199// Distance Dispatching //
1200//----------------------//
1201
/// This trait and [`ComputerDispatcher`] are the glue for pre-dispatching
/// micro-architecture compatibility of distance computers.
///
/// This trait takes
///
/// * `M`: The target micro-architecture.
/// * `Q`: The target query type.
/// * `N`: The number of bits used for the data vectors.
///
/// And generates a specialized `DistanceComputer` and `QueryComputer`.
///
/// The [`ComputerDispatcher`] struct implements the [`diskann_wide::arch::Target2`] and
/// [`diskann_wide::arch::Target3`] traits to do the architecture-dispatching.
trait BuildComputer<M, Q, const N: usize>
where
    M: Architecture,
    Q: FromOpaque,
{
    /// Build a [`DistanceComputer`] targeting the micro-architecture `M`.
    ///
    /// The resulting object should implement distance calculations using just a single
    /// level of indirection.
    fn build_computer<A>(
        &self,
        arch: M,
        allocator: A,
    ) -> Result<DistanceComputer<A>, AllocatorError>
    where
        A: AllocatorCore;

    /// Build a [`QueryComputer`] with `query` targeting the micro-architecture `M`.
    ///
    /// The resulting object should implement distance calculations using just a single
    /// level of indirection.
    fn build_fused_computer<R, A>(
        &self,
        arch: M,
        query: R,
        allocator: A,
    ) -> Result<QueryComputer<A>, AllocatorError>
    where
        R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
        A: AllocatorCore;
}
1245
/// Return `x` unchanged.
///
/// This is the architecture mapping `dispatch_map!` falls back to when no explicit
/// down-casting operation is supplied.
fn identity<T>(x: T) -> T {
    std::convert::identity(x)
}
1249
// Implement `BuildComputer<$arch, $Q, $N>` for `SphericalQuantizer<A>`.
//
// * `$N`: The number of data bits; data vectors are reified as `AsData<$N>`.
// * `$Q`: The query type.
// * `$arch`: The architecture type accepted by the generated implementation.
// * `$op` (optional): A function applied to the incoming architecture before it is used,
//   allowing one architecture to be mapped onto another. Defaults to `identity`.
//
// Both generated methods select the distance functor from the quantizer's metric.
macro_rules! dispatch_map {
    // Three-argument form: forward to the four-argument form without remapping `$arch`.
    ($N:literal, $Q:ty, $arch:ty) => {
        dispatch_map!($N, $Q, $arch, identity);
    };
    ($N:literal, $Q:ty, $arch:ty, $op:ident) => {
        impl<A> BuildComputer<$arch, $Q, $N> for SphericalQuantizer<A>
        where
            A: Allocator,
        {
            fn build_computer<B>(
                &self,
                input_arch: $arch,
                allocator: B,
            ) -> Result<DistanceComputer<B>, AllocatorError>
            where
                B: AllocatorCore,
            {
                type D = AsData<$N>;

                // Perform any architecture down-casting.
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let reify = Reify::<CompensatedSquaredL2, _, $Q, D>::new(
                            self.as_functor(),
                            dim,
                            arch,
                        );
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::InnerProduct => {
                        let reify =
                            Reify::<CompensatedIP, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                    SupportedMetric::Cosine => {
                        let reify =
                            Reify::<CompensatedCosine, _, $Q, D>::new(self.as_functor(), dim, arch);
                        DistanceComputer::new(reify, allocator)
                    }
                }
            }

            fn build_fused_computer<R, B>(
                &self,
                input_arch: $arch,
                query: R,
                allocator: B,
            ) -> Result<QueryComputer<B>, AllocatorError>
            where
                R: ReportQueryLayout
                    + for<'a> Reborrow<'a, Target = <$Q as FromOpaque>::Target<'a>>
                    + Send
                    + Sync
                    + 'static,
                B: AllocatorCore,
            {
                type D = AsData<$N>;
                // Perform any architecture down-casting.
                let arch = ($op)(input_arch);
                let dim = self.output_dim();
                // In every arm the query is bound to the functor with `Curried`, so the
                // reified computer only takes the data argument (query type `()`).
                match self.metric() {
                    SupportedMetric::SquaredL2 => {
                        let computer: CompensatedSquaredL2 = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::InnerProduct => {
                        let computer: CompensatedIP = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                    SupportedMetric::Cosine => {
                        let computer: CompensatedCosine = self.as_functor();
                        let curried = Curried::new(computer, query);
                        let reify = Reify::<_, _, (), D>::new(curried, dim, arch);
                        Ok(QueryComputer::new(reify, allocator)?)
                    }
                }
            }
        }
    };
}
1335
// `Scalar` architecture implementations.

// Full-precision queries.
dispatch_map!(1, AsFull, Scalar);
dispatch_map!(2, AsFull, Scalar);
dispatch_map!(4, AsFull, Scalar);
dispatch_map!(8, AsFull, Scalar);

// Queries using the same representation as the data.
dispatch_map!(1, AsData<1>, Scalar);
dispatch_map!(2, AsData<2>, Scalar);
dispatch_map!(4, AsData<4>, Scalar);
dispatch_map!(8, AsData<8>, Scalar);

// Special Cases
//
// One-bit data is paired with four-bit bit-transposed queries; the other bit widths are
// paired with an `AsQuery` of the same width.
dispatch_map!(1, AsQuery<4, bits::BitTranspose>, Scalar);
dispatch_map!(2, AsQuery<2>, Scalar);
dispatch_map!(4, AsQuery<4>, Scalar);
dispatch_map!(8, AsQuery<8>, Scalar);
1351
cfg_if::cfg_if! {
    if #[cfg(target_arch = "x86_64")] {
        // Convert a `V4` architecture token into `V3`. Used below for the `V4` entries that
        // do not have a dedicated `V4` implementation.
        fn downcast_to_v3(arch: V4) -> V3 {
            arch.into()
        }

        // V3
        dispatch_map!(1, AsFull, V3);
        dispatch_map!(2, AsFull, V3);
        dispatch_map!(4, AsFull, V3);
        dispatch_map!(8, AsFull, V3);

        dispatch_map!(1, AsData<1>, V3);
        dispatch_map!(2, AsData<2>, V3);
        dispatch_map!(4, AsData<4>, V3);
        dispatch_map!(8, AsData<8>, V3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V3);
        dispatch_map!(2, AsQuery<2>, V3);
        dispatch_map!(4, AsQuery<4>, V3);
        dispatch_map!(8, AsQuery<8>, V3);

        // V4
        //
        // Entries passing `downcast_to_v3` build their computers with the `V3` architecture.
        // Entries marked "specialized" omit it and keep the `V4` architecture.
        dispatch_map!(1, AsFull, V4, downcast_to_v3);
        dispatch_map!(2, AsFull, V4, downcast_to_v3);
        dispatch_map!(4, AsFull, V4, downcast_to_v3);
        dispatch_map!(8, AsFull, V4, downcast_to_v3);

        dispatch_map!(1, AsData<1>, V4, downcast_to_v3);
        dispatch_map!(2, AsData<2>, V4); // specialized
        dispatch_map!(4, AsData<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsData<8>, V4, downcast_to_v3);

        dispatch_map!(1, AsQuery<4, bits::BitTranspose>, V4, downcast_to_v3);
        dispatch_map!(2, AsQuery<2>, V4); // specialized
        dispatch_map!(4, AsQuery<4>, V4, downcast_to_v3);
        dispatch_map!(8, AsQuery<8>, V4, downcast_to_v3);
    }
}
1391
/// This struct and the [`BuildComputer`] trait are the glue for pre-dispatching
/// micro-architecture compatibility of distance computers.
///
/// This struct takes
///
/// * `Q`: The target query type.
/// * `N`: The number of data bits to target.
///
/// This struct implements [`diskann_wide::arch::Target2`] and
/// [`diskann_wide::arch::Target3`] traits to do the architecture-dispatching, relying on
/// `SphericalQuantizer<A> as BuildComputer` for the implementation.
#[derive(Debug, Clone, Copy)]
struct ComputerDispatcher<Q, const N: usize> {
    // Zero-sized marker recording the query type `Q`; no value of `Q` is stored.
    _query_type: std::marker::PhantomData<Q>,
}
1407
1408impl<Q, const N: usize> ComputerDispatcher<Q, N> {
1409    fn new() -> Self {
1410        Self {
1411            _query_type: std::marker::PhantomData,
1412        }
1413    }
1414}
1415
1416impl<M, const N: usize, A, B, Q>
1417    diskann_wide::arch::Target2<
1418        M,
1419        Result<DistanceComputer<B>, AllocatorError>,
1420        &SphericalQuantizer<A>,
1421        B,
1422    > for ComputerDispatcher<Q, N>
1423where
1424    M: Architecture,
1425    A: Allocator,
1426    B: AllocatorCore,
1427    Q: FromOpaque,
1428    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1429{
1430    fn run(
1431        self,
1432        arch: M,
1433        quantizer: &SphericalQuantizer<A>,
1434        allocator: B,
1435    ) -> Result<DistanceComputer<B>, AllocatorError> {
1436        quantizer.build_computer(arch, allocator)
1437    }
1438}
1439
1440impl<M, const N: usize, A, R, B, Q>
1441    diskann_wide::arch::Target3<
1442        M,
1443        Result<QueryComputer<B>, AllocatorError>,
1444        &SphericalQuantizer<A>,
1445        R,
1446        B,
1447    > for ComputerDispatcher<Q, N>
1448where
1449    M: Architecture,
1450    A: Allocator,
1451    B: AllocatorCore,
1452    Q: FromOpaque,
1453    R: ReportQueryLayout + for<'a> Reborrow<'a, Target = Q::Target<'a>> + Send + Sync + 'static,
1454    SphericalQuantizer<A>: BuildComputer<M, Q, N>,
1455{
1456    fn run(
1457        self,
1458        arch: M,
1459        quantizer: &SphericalQuantizer<A>,
1460        query: R,
1461        allocator: B,
1462    ) -> Result<QueryComputer<B>, AllocatorError> {
1463        quantizer.build_fused_computer(arch, query, allocator)
1464    }
1465}
1466
/// Shorthand bound requiring [`BuildComputer`] for every architecture dispatched over on
/// x86_64: `Scalar`, `V3`, and `V4`.
#[cfg(target_arch = "x86_64")]
trait Dispatchable<Q, const N: usize>:
    BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>
where
    Q: FromOpaque,
{
}

// Blanket implementation: any type providing all of the required `BuildComputer`
// implementations is `Dispatchable`.
#[cfg(target_arch = "x86_64")]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N> + BuildComputer<V3, Q, N> + BuildComputer<V4, Q, N>,
{
}
1482
/// Shorthand bound requiring [`BuildComputer`] for the `Scalar` architecture, the only one
/// required on targets other than x86_64.
#[cfg(not(target_arch = "x86_64"))]
trait Dispatchable<Q, const N: usize>: BuildComputer<Scalar, Q, N>
where
    Q: FromOpaque,
{
}

// Blanket implementation: any type providing the required `BuildComputer` implementation
// is `Dispatchable`.
#[cfg(not(target_arch = "x86_64"))]
impl<Q, const N: usize, T> Dispatchable<Q, N> for T
where
    Q: FromOpaque,
    T: BuildComputer<Scalar, Q, N>,
{
}
1497
1498//---------------------------//
1499// Quantizer Implementations //
1500//---------------------------//
1501
// `Quantizer` implementation for one-bit data.
//
// Supported query layouts are `SameAsData`, `FourBitTransposed`, and `FullPrecision`;
// `ScalarQuantized` is rejected with an `UnsupportedQueryLayout` error.
impl<A, B> Quantizer<B> for Impl<1, A>
where
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
    B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
{
    /// The number of bits used for data compression: always 1 for this implementation.
    fn nbits(&self) -> usize {
        1
    }

    /// The output dimension of the wrapped quantizer.
    fn dim(&self) -> usize {
        self.quantizer.output_dim()
    }

    /// The input dimension of the wrapped quantizer.
    fn full_dim(&self) -> usize {
        self.quantizer.input_dim()
    }

    /// The canonical number of bytes of a one-bit data vector with `dim` dimensions.
    fn bytes(&self) -> usize {
        DataRef::<1>::canonical_bytes(self.quantizer.output_dim())
    }

    /// Build a newly allocated computer for distances between two compressed data vectors.
    fn distance_computer(&self, allocator: B) -> Result<DistanceComputer<B>, AllocatorError> {
        self.query_computer::<AsData<1>, _>(allocator)
    }

    /// Borrow the distance computer that was pre-dispatched when `self` was constructed.
    fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
        &*self.distance
    }

    /// Build a distance computer whose left-hand argument uses `layout`.
    ///
    /// Returns an error for `QueryLayout::ScalarQuantized` or on allocation failure.
    fn query_computer(
        &self,
        layout: QueryLayout,
        allocator: B,
    ) -> Result<DistanceComputer<B>, DistanceComputerError> {
        // The calls below pass an explicit query type and resolve to the inherent
        // `Impl::query_computer` helper rather than to this trait method.
        match layout {
            QueryLayout::SameAsData => Ok(self.query_computer::<AsData<1>, _>(allocator)?),
            QueryLayout::FourBitTransposed => {
                Ok(self.query_computer::<AsQuery<4, bits::BitTranspose>, _>(allocator)?)
            }
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
            }
            QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
        }
    }

    /// Describe the buffer (canonical byte count and alignment) needed to hold a query
    /// compressed with `layout`.
    ///
    /// Returns an error for `QueryLayout::ScalarQuantized`.
    fn query_buffer_description(
        &self,
        layout: QueryLayout,
    ) -> Result<QueryBufferDescription, UnsupportedQueryLayout> {
        let dim = <Self as Quantizer<B>>::dim(self);
        match layout {
            QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
                DataRef::<1>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::FourBitTransposed => Ok(QueryBufferDescription::new(
                QueryRef::<4, bits::BitTranspose>::canonical_bytes(dim),
                PowerOfTwo::alignment_of::<u8>(),
            )),
            QueryLayout::ScalarQuantized => {
                Err(UnsupportedQueryLayout::new(layout, "1-bit compression"))
            }
            QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
                FullQueryRef::canonical_bytes(dim),
                FullQueryRef::canonical_align(),
            )),
        }
    }

    /// Compress the query `x` into `buffer` using `layout`.
    ///
    /// When `allow_rescale` is set and the metric is inner product, a copy of `x` is
    /// rescaled by the quantizer before compression; otherwise `x` is compressed as given.
    ///
    /// Returns an error if `layout` is unsupported, if `buffer` cannot be interpreted as
    /// the canonical representation for `layout`, or if compression fails.
    fn compress_query(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        mut buffer: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), QueryCompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // The three-argument `compress_query` calls below target the inherent helper on
        // `Impl`, not this trait method.
        let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
            match layout {
                QueryLayout::SameAsData => self.compress_query(
                    v,
                    DataMut::<1>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::FourBitTransposed => self.compress_query(
                    v,
                    QueryMut::<4, bits::BitTranspose>::from_canonical_back_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
                QueryLayout::ScalarQuantized => {
                    Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                }
                QueryLayout::FullPrecision => self.compress_query(
                    v,
                    FullQueryMut::from_canonical_mut(&mut buffer, dim)
                        .map_err(NotCanonical::new)?,
                    scratch,
                ),
            }
        };

        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            // Rescale an owned copy so the caller's slice is left untouched.
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy)
        } else {
            finish(x)
        }
    }

    /// Compress the query `x` with `layout` and return a [`QueryComputer`] that owns the
    /// compressed query.
    ///
    /// Storage for the compressed query and for the computer is obtained from `allocator`.
    /// When `allow_rescale` is set and the metric is inner product, a copy of `x` is
    /// rescaled by the quantizer before compression.
    ///
    /// Returns an error for `QueryLayout::ScalarQuantized`, on allocation failure, or if
    /// compression fails.
    fn fused_query_computer(
        &self,
        x: &[f32],
        layout: QueryLayout,
        allow_rescale: bool,
        allocator: B,
        scratch: ScopedAllocator<'_>,
    ) -> Result<QueryComputer<B>, QueryComputerError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        // The calls below pass explicit type arguments and target the inherent
        // `Impl::fused_query_computer` helper, not this trait method.
        let finish = |v: &[f32], allocator: B| -> Result<QueryComputer<B>, QueryComputerError> {
            match layout {
                    QueryLayout::SameAsData => self.fused_query_computer::<AsData<1>, Data<1, _>, _>(
                        v,
                        Data::new_in(dim, allocator.clone())?,
                        allocator,
                        scratch,
                    ),
                    QueryLayout::FourBitTransposed => self
                        .fused_query_computer::<AsQuery<4, bits::BitTranspose>, Query<4, bits::BitTranspose, _>, _>(
                            v,
                            Query::new_in(dim, allocator.clone())?,
                            allocator,
                            scratch,
                        ),
                    QueryLayout::ScalarQuantized => {
                        Err(UnsupportedQueryLayout::new(layout, "1-bit compression").into())
                    }
                    QueryLayout::FullPrecision => self.fused_query_computer::<AsFull, FullQuery<_>, _>(
                        v,
                        FullQuery::empty(dim, allocator.clone())?,
                        allocator,
                        scratch,
                    ),
                }
        };

        if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
            // Rescale an owned copy so the caller's slice is left untouched.
            let mut copy = x.to_owned();
            self.quantizer.rescale(&mut copy);
            finish(&copy, allocator)
        } else {
            finish(x, allocator)
        }
    }

    /// Return `true` if `layout` is supported for query computers; see [`Impl::supports`].
    fn is_supported(&self, layout: QueryLayout) -> bool {
        Self::supports(layout)
    }

    /// Compress the data vector `x` into `into`.
    ///
    /// Returns an error if `into` cannot be interpreted as a canonical one-bit data vector
    /// with `dim` dimensions or if compression fails.
    fn compress(
        &self,
        x: &[f32],
        mut into: OpaqueMut<'_>,
        scratch: ScopedAllocator<'_>,
    ) -> Result<(), CompressionError> {
        let dim = <Self as Quantizer<B>>::dim(self);
        let into = DataMut::<1>::from_canonical_back_mut(into.inspect(), dim)
            .map_err(CompressionError::not_canonical)?;
        self.quantizer
            .compress_into_with(x, into, scratch)
            .map_err(CompressionError::CompressionError)
    }

    /// The metric of the wrapped quantizer.
    fn metric(&self) -> SupportedMetric {
        self.quantizer.metric()
    }

    /// Clone `self` and return the clone as a `Quantizer<B>` trait object allocated from
    /// `allocator`.
    fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
        let clone = (*self).try_clone()?;
        poly!({ Quantizer<B> }, clone, allocator)
    }

    /// Serialize `self` into bytes allocated from `allocator` by forwarding to the
    /// inherent `Impl::serialize`.
    #[cfg(feature = "flatbuffers")]
    fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
        Impl::<1, A>::serialize(self, allocator)
    }
}
1693
1694macro_rules! plan {
1695    ($N:literal) => {
1696        impl<A, B> Quantizer<B> for Impl<$N, A>
1697        where
1698            A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
1699            B: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
1700        {
1701            fn nbits(&self) -> usize {
1702                $N
1703            }
1704
1705            fn dim(&self) -> usize {
1706                self.quantizer.output_dim()
1707            }
1708
1709            fn full_dim(&self) -> usize {
1710                self.quantizer.input_dim()
1711            }
1712
1713            fn bytes(&self) -> usize {
1714                DataRef::<$N>::canonical_bytes(<Self as Quantizer<B>>::dim(self))
1715            }
1716
1717            fn distance_computer(
1718                &self,
1719                allocator: B
1720            ) -> Result<DistanceComputer<B>, AllocatorError> {
1721                self.query_computer::<AsData<$N>, _>(allocator)
1722            }
1723
1724            fn distance_computer_ref(&self) -> &dyn DynDistanceComputer {
1725                &*self.distance
1726            }
1727
1728            fn query_computer(
1729                &self,
1730                layout: QueryLayout,
1731                allocator: B,
1732            ) -> Result<DistanceComputer<B>, DistanceComputerError> {
1733                match layout {
1734                    QueryLayout::SameAsData => Ok(self.query_computer::<AsData<$N>, _>(allocator)?)
1735                    ,
1736                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout::new(
1737                        layout,
1738                        concat!($N, "-bit compression"),
1739                    ).into()),
1740                    QueryLayout::ScalarQuantized => {
1741                        Ok(self.query_computer::<AsQuery<$N, bits::Dense>, _>(allocator)?)
1742                    },
1743                    QueryLayout::FullPrecision => Ok(self.query_computer::<AsFull, _>(allocator)?),
1744
1745                }
1746            }
1747
1748            fn query_buffer_description(
1749                &self,
1750                layout: QueryLayout
1751            ) -> Result<QueryBufferDescription, UnsupportedQueryLayout>
1752            {
1753                let dim = <Self as Quantizer<B>>::dim(self);
1754                match layout {
1755                    QueryLayout::SameAsData => Ok(QueryBufferDescription::new(
1756                        DataRef::<$N>::canonical_bytes(dim),
1757                        PowerOfTwo::alignment_of::<u8>(),
1758                    )),
1759                    QueryLayout::FourBitTransposed => Err(UnsupportedQueryLayout {
1760                        layout,
1761                        desc: concat!($N, "-bit compression"),
1762                    }),
1763                    QueryLayout::ScalarQuantized => Ok(QueryBufferDescription::new(
1764                        QueryRef::<$N, bits::Dense>::canonical_bytes(dim),
1765                        PowerOfTwo::alignment_of::<u8>(),
1766                    )),
1767                    QueryLayout::FullPrecision => Ok(QueryBufferDescription::new(
1768                        FullQueryRef::canonical_bytes(dim),
1769                        FullQueryRef::canonical_align(),
1770                    )),
1771                }
1772            }
1773
1774            fn compress_query(
1775                &self,
1776                x: &[f32],
1777                layout: QueryLayout,
1778                allow_rescale: bool,
1779                mut buffer: OpaqueMut<'_>,
1780                scratch: ScopedAllocator<'_>,
1781            ) -> Result<(), QueryCompressionError> {
1782                let dim = <Self as Quantizer<B>>::dim(self);
1783                let mut finish = |v: &[f32]| -> Result<(), QueryCompressionError> {
1784                    match layout {
1785                        QueryLayout::SameAsData => self.compress_query(
1786                            v,
1787                            DataMut::<$N>::from_canonical_back_mut(
1788                                &mut buffer,
1789                                dim,
1790                            ).map_err(NotCanonical::new)?,
1791                            scratch,
1792                        ),
1793                        QueryLayout::FourBitTransposed => {
1794                            Err(UnsupportedQueryLayout::new(
1795                                layout,
1796                                concat!($N, "-bit compression"),
1797                            ).into())
1798                        },
1799                        QueryLayout::ScalarQuantized => self.compress_query(
1800                            v,
1801                            QueryMut::<$N, bits::Dense>::from_canonical_back_mut(
1802                                &mut buffer,
1803                                dim,
1804                            ).map_err(NotCanonical::new)?,
1805                            scratch,
1806                        ),
1807                        QueryLayout::FullPrecision => self.compress_query(
1808                            v,
1809                            FullQueryMut::from_canonical_mut(
1810                                &mut buffer,
1811                                dim,
1812                            ).map_err(NotCanonical::new)?,
1813                            scratch,
1814                        ),
1815                    }
1816                };
1817
1818                if allow_rescale && self.quantizer.metric() == SupportedMetric::InnerProduct {
1819                    let mut copy = x.to_owned();
1820                    self.quantizer.rescale(&mut copy);
1821                    finish(&copy)
1822                } else {
1823                    finish(x)
1824                }
1825            }
1826
1827            fn fused_query_computer(
1828                &self,
1829                x: &[f32],
1830                layout: QueryLayout,
1831                allow_rescale: bool,
1832                allocator: B,
1833                scratch: ScopedAllocator<'_>,
1834            ) -> Result<QueryComputer<B>, QueryComputerError>
1835            {
1836                let dim = <Self as Quantizer<B>>::dim(self);
1837                let finish = |v: &[f32]| -> Result<QueryComputer<B>, QueryComputerError> {
1838                    match layout {
1839                        QueryLayout::SameAsData => {
1840                            self.fused_query_computer::<AsData<$N>, Data<$N, _>, B>(
1841                                v,
1842                                Data::new_in(dim, allocator.clone())?,
1843                                allocator,
1844                                scratch,
1845                            )
1846                        },
1847                        QueryLayout::FourBitTransposed => {
1848                            Err(UnsupportedQueryLayout::new(
1849                                layout,
1850                                concat!($N, "-bit compression"),
1851                            ).into())
1852                        },
1853                        QueryLayout::ScalarQuantized => {
1854                            self.fused_query_computer::<AsQuery<$N, bits::Dense>, Query<$N, bits::Dense, _>, B>(
1855                                v,
1856                                Query::new_in(dim, allocator.clone())?,
1857                                allocator,
1858                                scratch,
1859                            )
1860                        },
1861                        QueryLayout::FullPrecision => {
1862                            self.fused_query_computer::<AsFull, FullQuery<_>, B>(
1863                                v,
1864                                FullQuery::empty(dim, allocator.clone())?,
1865                                allocator,
1866                                scratch,
1867                            )
1868                        },
1869                    }
1870                };
1871
1872                let metric = <Self as Quantizer<B>>::metric(self);
1873                if allow_rescale && metric == SupportedMetric::InnerProduct {
1874                    let mut copy = x.to_owned();
1875                    self.quantizer.rescale(&mut copy);
1876                    finish(&copy)
1877                } else {
1878                    finish(x)
1879                }
1880            }
1881
            /// Report whether `layout` can be used with this implementation.
            ///
            /// The answer does not depend on the instance: it is forwarded to the
            /// associated function `Self::supports`.
            fn is_supported(&self, layout: QueryLayout) -> bool {
                Self::supports(layout)
            }
1885
1886            fn compress(
1887                &self,
1888                x: &[f32],
1889                mut into: OpaqueMut<'_>,
1890                scratch: ScopedAllocator<'_>,
1891            ) -> Result<(), CompressionError> {
1892                let dim = <Self as Quantizer<B>>::dim(self);
1893                let into = DataMut::<$N>::from_canonical_back_mut(into.inspect(), dim)
1894                    .map_err(CompressionError::not_canonical)?;
1895
1896                self.quantizer.compress_into_with(x, into, scratch)
1897                    .map_err(CompressionError::CompressionError)
1898            }
1899
            /// Return the metric reported by the wrapped quantizer.
            fn metric(&self) -> SupportedMetric {
                self.quantizer.metric()
            }
1903
            /// Make a fallible copy of `self` with `try_clone` and hand the copy, together
            /// with `allocator`, to `poly!` to produce the returned
            /// `Poly<dyn Quantizer<B>, B>`.
            ///
            /// # Errors
            ///
            /// A failure of `try_clone` is propagated with `?`; the result of `poly!` is
            /// returned unchanged.
            fn try_clone_into(&self, allocator: B) -> Result<Poly<dyn Quantizer<B>, B>, AllocatorError> {
                let clone = (&*self).try_clone()?;
                poly!({ Quantizer<B> }, clone, allocator)
            }
1908
            /// Serialize this quantizer into a byte buffer obtained from `allocator`.
            ///
            /// Only available with the `flatbuffers` feature. The work is done by the
            /// inherent `Impl::<$N, A>::serialize`; this trait method just forwards to it.
            #[cfg(feature = "flatbuffers")]
            fn serialize(&self, allocator: B) -> Result<Poly<[u8], B>, AllocatorError> {
                Impl::<$N, A>::serialize(self, allocator)
            }
1913        }
1914    };
1915    ($N:literal, $($Ns:literal),*) => {
1916        plan!($N);
1917        $(plan!($Ns);)*
1918    }
1919}
1920
// Expand `plan!` for the 2-, 4-, and 8-bit widths. The list arm of the macro turns this
// into one single-literal `plan!(N)` expansion per listed width.
plan!(2, 4, 8);
1922
1923////////////////
1924// Flatbuffer //
1925////////////////
1926
/// Errors that can occur while turning a serialized buffer back into a `Quantizer`
/// with [`try_deserialize`].
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum DeserializationError {
    /// The buffer failed the file-identifier check performed before any parsing.
    #[error("unhandled file identifier in flatbuffer")]
    InvalidIdentifier,

    /// The bit width stored in the buffer is not one of the handled widths
    /// (1, 2, 4, or 8). The payload is the width that was found.
    #[error("unsupported number of bits ({0})")]
    UnsupportedBitWidth(u32),

    /// Rebuilding the inner quantizer from its serialized form failed.
    #[error(transparent)]
    InvalidQuantizer(#[from] super::quantizer::DeserializationError),

    /// The `flatbuffers` crate reported the buffer as invalid.
    #[error(transparent)]
    InvalidFlatBuffer(#[from] flatbuffers::InvalidFlatbuffer),

    /// An allocation needed during deserialization failed.
    #[error(transparent)]
    AllocatorError(#[from] AllocatorError),
}
1947
/// Attempt to deserialize a `spherical::Quantizer` flatbuffer into one of the concrete
/// implementations of `Quantizer`.
///
/// The bit width stored in the buffer selects which `Impl` instantiation (1, 2, 4, or
/// 8 bits) is constructed.
///
/// This function guarantees that the returned `Poly` is the first object allocated through
/// `alloc`.
///
/// # Errors
///
/// * [`DeserializationError::InvalidIdentifier`] if the identifier check on `data` fails.
/// * [`DeserializationError::UnsupportedBitWidth`] if the stored bit width is not 1, 2,
///   4, or 8.
/// * Any other [`DeserializationError`] produced while parsing the buffer, allocating, or
///   rebuilding the quantizer.
#[cfg(feature = "flatbuffers")]
#[cfg_attr(docsrs, doc(cfg(feature = "flatbuffers")))]
pub fn try_deserialize<O, A>(
    data: &[u8],
    alloc: A,
) -> Result<Poly<dyn Quantizer<O>, A>, DeserializationError>
where
    O: Allocator + std::panic::UnwindSafe + Send + Sync + 'static,
    A: Allocator + std::panic::RefUnwindSafe + Send + Sync + 'static,
{
    // Builds the implementation for a single bit width.
    //
    // Construction goes through `Poly::new_with` so that the returned `Poly` is allocated
    // before any of the allocations needed by its members. With a bump allocator this
    // places the root object first.
    fn unpack_bits<'a, const NBITS: usize, O, A>(
        proto: fb::spherical::SphericalQuantizer<'_>,
        alloc: A,
    ) -> Result<Poly<dyn Quantizer<O> + 'a, A>, DeserializationError>
    where
        O: Allocator + Send + Sync + std::panic::UnwindSafe + 'static,
        A: Allocator + Send + Sync + 'a,
        Impl<NBITS, A>: Quantizer<O> + Constructible<A>,
    {
        let imp = Poly::new_with(
            #[inline(never)]
            |alloc| -> Result<_, super::quantizer::DeserializationError> {
                let quantizer = SphericalQuantizer::try_unpack(alloc, proto)?;
                Ok(Impl::new(quantizer)?)
            },
            alloc,
        )
        .map_err(|failure| match failure {
            CompoundError::Allocator(inner) => DeserializationError::from(inner),
            CompoundError::Constructor(inner) => DeserializationError::from(inner),
        })?;

        Ok(poly!({ Quantizer<O> }, imp))
    }

    // Reject buffers without a known identifier before parsing anything.
    if !fb::spherical::quantizer_buffer_has_identifier(data) {
        return Err(DeserializationError::InvalidIdentifier);
    }

    // Read everything that can be read without allocating, then dispatch on the
    // bit width.
    let root = fb::spherical::root_as_quantizer(data)?;
    let bit_width = root.nbits();
    let packed = root.quantizer();

    match bit_width {
        1 => unpack_bits::<1, _, _>(packed, alloc),
        2 => unpack_bits::<2, _, _>(packed, alloc),
        4 => unpack_bits::<4, _, _>(packed, alloc),
        8 => unpack_bits::<8, _, _>(packed, alloc),
        unsupported => Err(DeserializationError::UnsupportedBitWidth(unsupported)),
    }
}
2015
2016///////////
2017// Tests //
2018///////////
2019
2020#[cfg(test)]
2021mod tests {
2022    use diskann_utils::views::{Matrix, MatrixView};
2023    use rand::{rngs::StdRng, SeedableRng};
2024
2025    use super::*;
2026    use crate::{
2027        algorithms::{transforms::TargetDim, TransformKind},
2028        alloc::{AlignedAllocator, GlobalAllocator, Poly},
2029        num::PowerOfTwo,
2030        spherical::PreScale,
2031    };
2032
2033    ////////////////////
2034    // Test Quantizer //
2035    ////////////////////
2036
2037    fn test_plan_1_bit(plan: &dyn Quantizer) {
2038        assert_eq!(
2039            plan.nbits(),
2040            1,
2041            "this test only applies to 1-bit quantization"
2042        );
2043
2044        // Check Layouts.
2045        for layout in QueryLayout::all() {
2046            match layout {
2047                QueryLayout::SameAsData
2048                | QueryLayout::FourBitTransposed
2049                | QueryLayout::FullPrecision => assert!(
2050                    plan.is_supported(layout),
2051                    "expected {} to be supported",
2052                    layout
2053                ),
2054                QueryLayout::ScalarQuantized => assert!(
2055                    !plan.is_supported(layout),
2056                    "expected {} to not be supported",
2057                    layout
2058                ),
2059            }
2060        }
2061    }
2062
2063    fn test_plan_n_bit(plan: &dyn Quantizer, nbits: usize) {
2064        assert_ne!(nbits, 1, "there is another test for 1-bit quantizers");
2065        assert_eq!(
2066            plan.nbits(),
2067            nbits,
2068            "this test only applies to 1-bit quantization"
2069        );
2070
2071        // Check Layouts.
2072        for layout in QueryLayout::all() {
2073            match layout {
2074                QueryLayout::SameAsData
2075                | QueryLayout::ScalarQuantized
2076                | QueryLayout::FullPrecision => assert!(
2077                    plan.is_supported(layout),
2078                    "expected {} to be supported",
2079                    layout
2080                ),
2081                QueryLayout::FourBitTransposed => assert!(
2082                    !plan.is_supported(layout),
2083                    "expected {} to not be supported",
2084                    layout
2085                ),
2086            }
2087        }
2088    }
2089
    /// Exercise a `Quantizer` through its dynamic interface.
    ///
    /// Runs the bit-width specific layout checks, then compresses rows 0 and 1 of
    /// `dataset` and checks, for every query layout, that:
    ///
    /// * unsupported layouts are rejected with a message naming the layout and `nbits`,
    /// * the fused and standalone query computers agree on the computed similarity,
    /// * wrongly sized inputs are reported as errors rather than accepted.
    #[inline(never)]
    fn test_plan(plan: &dyn Quantizer, nbits: usize, dataset: MatrixView<f32>) {
        // Perform the bit-specific test.
        if nbits == 1 {
            test_plan_1_bit(plan);
        } else {
            test_plan_n_bit(plan, nbits);
        }

        // Run bit-width agnostic tests.
        assert_eq!(plan.full_dim(), dataset.ncols());

        // Use the correct alignment for the base pointers.
        let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
        let mut a = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let mut b = Poly::broadcast(u8::default(), plan.bytes(), alloc).unwrap();
        let scoped_global = ScopedAllocator::global();

        plan.compress(dataset.row(0), OpaqueMut::new(&mut a), scoped_global)
            .unwrap();
        plan.compress(dataset.row(1), OpaqueMut::new(&mut b), scoped_global)
            .unwrap();

        let f = plan.distance_computer(GlobalAllocator).unwrap();
        let _: f32 = f
            .evaluate_similarity(Opaque::new(&a), Opaque::new(&b))
            .unwrap();

        // Shared by the owned and the borrowed distance computer below: each argument is
        // made one byte too short and one byte too long in turn.
        let test_errors = |f: &dyn DynDistanceComputer| {
            // `a` too short
            let err = f
                .evaluate(Opaque::new(&a[..a.len() - 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // `a` too long
            let err = f
                .evaluate(Opaque::new(&vec![0u8; a.len() + 1]), Opaque::new(&b))
                .unwrap_err();
            assert!(matches!(err, DistanceError::QueryReify(_)));

            // `b` too short
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&b[..b.len() - 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));

            // `b` too long
            let err = f
                .evaluate(Opaque::new(&a), Opaque::new(&vec![0u8; b.len() + 1]))
                .unwrap_err();
            assert!(matches!(err, DistanceError::XReify(_)));
        };

        test_errors(&*f.inner);

        let f = plan.distance_computer_ref();
        let _: f32 = f.evaluate(Opaque::new(&a), Opaque::new(&b)).unwrap();
        test_errors(f);

        // Test all supported flavors of `QueryComputer`.
        for layout in QueryLayout::all() {
            if !plan.is_supported(layout) {
                // Every entry point must reject the layout with a message that names
                // both the layout and the bit width.
                let check_message = |msg: &str| {
                    assert!(
                        msg.contains(&(layout.to_string())),
                        "error message ({}) should contain the layout \"{}\"",
                        msg,
                        layout
                    );
                    assert!(
                        msg.contains(&format!("{}", nbits)),
                        "error message ({}) should contain the number of bits \"{}\"",
                        msg,
                        nbits
                    );
                };

                // Error for query computer
                {
                    let err = plan
                        .fused_query_computer(
                            dataset.row(1),
                            layout,
                            false,
                            GlobalAllocator,
                            scoped_global,
                        )
                        .unwrap_err();

                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Query buffer
                {
                    let err = plan.query_buffer_description(layout).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Compress Query Into
                {
                    let buffer = &mut [];
                    let err = plan
                        .compress_query(
                            dataset.row(1),
                            layout,
                            true,
                            OpaqueMut::new(buffer),
                            scoped_global,
                        )
                        .unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                // Standalone Query Computer
                {
                    let err = plan.query_computer(layout, GlobalAllocator).unwrap_err();
                    let msg = err.to_string();
                    check_message(&msg);
                }

                continue;
            }

            let g = plan
                .fused_query_computer(
                    dataset.row(1),
                    layout,
                    false,
                    GlobalAllocator,
                    scoped_global,
                )
                .unwrap();
            assert_eq!(
                g.layout(),
                layout,
                "the query computer should faithfully preserve the requested layout"
            );

            let direct: f32 = g.evaluate_similarity(Opaque(&a)).unwrap();

            // Check that the fused computer correctly returns errors for invalid inputs.
            {
                let err = g
                    .evaluate_similarity(Opaque::new(&a[..a.len() - 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));

                let err = g
                    .evaluate_similarity(Opaque::new(&vec![0u8; a.len() + 1]))
                    .unwrap_err();
                assert!(matches!(err, QueryDistanceError::XReify(_)));
            }

            // Compress the same query into a standalone buffer sized and aligned as the
            // plan describes, then evaluate it with the standalone computer.
            let sizes = plan.query_buffer_description(layout).unwrap();
            let mut buf =
                Poly::broadcast(0u8, sizes.bytes(), AlignedAllocator::new(sizes.align())).unwrap();

            plan.compress_query(
                dataset.row(1),
                layout,
                false,
                OpaqueMut::new(&mut buf),
                scoped_global,
            )
            .unwrap();

            let standalone = plan.query_computer(layout, GlobalAllocator).unwrap();

            assert_eq!(
                standalone.layout(),
                layout,
                "the standalone computer did not preserve the requested layout",
            );

            let indirect: f32 = standalone
                .evaluate_similarity(Opaque(&buf), Opaque(&a))
                .unwrap();

            assert_eq!(
                direct, indirect,
                "the two different query computation APIs did not return the same result"
            );

            // Errors
            let too_small = &dataset.row(0)[..dataset.ncols() - 1];
            assert!(plan
                .fused_query_computer(too_small, layout, false, GlobalAllocator, scoped_global)
                .is_err());
        }

        // Errors
        {
            let mut too_small = vec![u8::default(); plan.bytes() - 1];
            assert!(plan
                .compress(dataset.row(0), OpaqueMut(&mut too_small), scoped_global)
                .is_err());

            let mut too_big = vec![u8::default(); plan.bytes() + 1];
            assert!(plan
                .compress(dataset.row(0), OpaqueMut(&mut too_big), scoped_global)
                .is_err());

            // Correctly sized output buffer, but the input vector is one element short.
            let mut just_right = vec![u8::default(); plan.bytes()];
            assert!(plan
                .compress(
                    &dataset.row(0)[..dataset.ncols() - 1],
                    OpaqueMut(&mut just_right),
                    scoped_global
                )
                .is_err());
        }
    }
2306
2307    fn make_impl<const NBITS: usize>(metric: SupportedMetric) -> (Impl<NBITS>, Matrix<f32>)
2308    where
2309        Impl<NBITS>: Constructible,
2310    {
2311        let data = test_dataset();
2312        let mut rng = StdRng::seed_from_u64(0x7d535118722ff197);
2313
2314        let quantizer = SphericalQuantizer::train(
2315            data.as_view(),
2316            TransformKind::PaddingHadamard {
2317                target_dim: TargetDim::Natural,
2318            },
2319            metric,
2320            PreScale::None,
2321            &mut rng,
2322            GlobalAllocator,
2323        )
2324        .unwrap();
2325
2326        (Impl::<NBITS>::new(quantizer).unwrap(), data)
2327    }
2328
2329    #[test]
2330    fn test_plan_1bit_l2() {
2331        let (plan, data) = make_impl::<1>(SupportedMetric::SquaredL2);
2332        test_plan(&plan, 1, data.as_view());
2333    }
2334
2335    #[test]
2336    fn test_plan_1bit_ip() {
2337        let (plan, data) = make_impl::<1>(SupportedMetric::InnerProduct);
2338        test_plan(&plan, 1, data.as_view());
2339    }
2340
2341    #[test]
2342    fn test_plan_1bit_cosine() {
2343        let (plan, data) = make_impl::<1>(SupportedMetric::Cosine);
2344        test_plan(&plan, 1, data.as_view());
2345    }
2346
2347    #[test]
2348    fn test_plan_2bit_l2() {
2349        let (plan, data) = make_impl::<2>(SupportedMetric::SquaredL2);
2350        test_plan(&plan, 2, data.as_view());
2351    }
2352
2353    #[test]
2354    fn test_plan_2bit_ip() {
2355        let (plan, data) = make_impl::<2>(SupportedMetric::InnerProduct);
2356        test_plan(&plan, 2, data.as_view());
2357    }
2358
2359    #[test]
2360    fn test_plan_2bit_cosine() {
2361        let (plan, data) = make_impl::<2>(SupportedMetric::Cosine);
2362        test_plan(&plan, 2, data.as_view());
2363    }
2364
2365    #[test]
2366    fn test_plan_4bit_l2() {
2367        let (plan, data) = make_impl::<4>(SupportedMetric::SquaredL2);
2368        test_plan(&plan, 4, data.as_view());
2369    }
2370
2371    #[test]
2372    fn test_plan_4bit_ip() {
2373        let (plan, data) = make_impl::<4>(SupportedMetric::InnerProduct);
2374        test_plan(&plan, 4, data.as_view());
2375    }
2376
2377    #[test]
2378    fn test_plan_4bit_cosine() {
2379        let (plan, data) = make_impl::<4>(SupportedMetric::Cosine);
2380        test_plan(&plan, 4, data.as_view());
2381    }
2382
2383    #[test]
2384    fn test_plan_8bit_l2() {
2385        let (plan, data) = make_impl::<8>(SupportedMetric::SquaredL2);
2386        test_plan(&plan, 8, data.as_view());
2387    }
2388
2389    #[test]
2390    fn test_plan_8bit_ip() {
2391        let (plan, data) = make_impl::<8>(SupportedMetric::InnerProduct);
2392        test_plan(&plan, 8, data.as_view());
2393    }
2394
2395    #[test]
2396    fn test_plan_8bit_cosine() {
2397        let (plan, data) = make_impl::<8>(SupportedMetric::Cosine);
2398        test_plan(&plan, 8, data.as_view());
2399    }
2400
2401    fn test_dataset() -> Matrix<f32> {
2402        let data = vec![
2403            0.28657,
2404            -0.0318168,
2405            0.0666847,
2406            0.0329265,
2407            -0.00829283,
2408            0.168735,
2409            -0.000846311,
2410            -0.360779, // row 0
2411            -0.0968938,
2412            0.161921,
2413            -0.0979579,
2414            0.102228,
2415            -0.259928,
2416            -0.139634,
2417            0.165384,
2418            -0.293443, // row 1
2419            0.130205,
2420            0.265737,
2421            0.401816,
2422            -0.407552,
2423            0.13012,
2424            -0.0475244,
2425            0.511723,
2426            -0.4372, // row 2
2427            -0.0979126,
2428            0.135861,
2429            -0.0154144,
2430            -0.14047,
2431            -0.0250029,
2432            -0.190279,
2433            0.407283,
2434            -0.389184, // row 3
2435            -0.264153,
2436            0.0696822,
2437            -0.145585,
2438            0.370284,
2439            0.186825,
2440            -0.140736,
2441            0.274703,
2442            -0.334563, // row 4
2443            0.247613,
2444            0.513165,
2445            -0.0845867,
2446            0.0532264,
2447            -0.00480601,
2448            -0.122408,
2449            0.47227,
2450            -0.268301, // row 5
2451            0.103198,
2452            0.30756,
2453            -0.316293,
2454            -0.0686877,
2455            -0.330729,
2456            -0.461997,
2457            0.550857,
2458            -0.240851, // row 6
2459            0.128258,
2460            0.786291,
2461            -0.0268103,
2462            0.111763,
2463            -0.308962,
2464            -0.17407,
2465            0.437154,
2466            -0.159879, // row 7
2467            0.00374063,
2468            0.490301,
2469            0.0327826,
2470            -0.0340962,
2471            -0.118605,
2472            0.163879,
2473            0.2737,
2474            -0.299942, // row 8
2475            -0.284077,
2476            0.249377,
2477            -0.0307734,
2478            -0.0661631,
2479            0.233854,
2480            0.427987,
2481            0.614132,
2482            -0.288649, // row 9
2483            -0.109492,
2484            0.203939,
2485            -0.73956,
2486            -0.130748,
2487            0.22072,
2488            0.0647836,
2489            0.328726,
2490            -0.374602, // row 10
2491            -0.223114,
2492            0.0243489,
2493            0.109195,
2494            -0.416914,
2495            0.0201052,
2496            -0.0190542,
2497            0.947078,
2498            -0.333229, // row 11
2499            -0.165869,
2500            -0.00296729,
2501            -0.414378,
2502            0.231321,
2503            0.205365,
2504            0.161761,
2505            0.148608,
2506            -0.395063, // row 12
2507            -0.0498255,
2508            0.193279,
2509            -0.110946,
2510            -0.181174,
2511            -0.274578,
2512            -0.227511,
2513            0.190208,
2514            -0.256174, // row 13
2515            -0.188106,
2516            -0.0292958,
2517            0.0930939,
2518            0.0558456,
2519            0.257437,
2520            0.685481,
2521            0.307922,
2522            -0.320006, // row 14
2523            0.250035,
2524            0.275942,
2525            -0.0856306,
2526            -0.352027,
2527            -0.103509,
2528            -0.00890859,
2529            0.276121,
2530            -0.324718, // row 15
2531        ];
2532
2533        Matrix::try_from(data.into(), 16, 8).unwrap()
2534    }
2535
2536    #[cfg(feature = "flatbuffers")]
2537    mod serialization {
2538        use std::sync::{
2539            atomic::{AtomicBool, Ordering},
2540            Arc,
2541        };
2542
2543        use super::*;
2544        use crate::alloc::{BumpAllocator, GlobalAllocator};
2545
2546        #[inline(never)]
2547        fn test_plan_serialization(
2548            quantizer: &dyn Quantizer,
2549            nbits: usize,
2550            dataset: MatrixView<f32>,
2551        ) {
2552            // Run bit-width agnostic tests.
2553            assert_eq!(quantizer.full_dim(), dataset.ncols());
2554            let scoped_global = ScopedAllocator::global();
2555
2556            let serialized = quantizer.serialize(GlobalAllocator).unwrap();
2557            let deserialized =
2558                try_deserialize::<GlobalAllocator, _>(&serialized, GlobalAllocator).unwrap();
2559
2560            assert_eq!(deserialized.nbits(), nbits);
2561            assert_eq!(deserialized.bytes(), quantizer.bytes());
2562            assert_eq!(deserialized.dim(), quantizer.dim());
2563            assert_eq!(deserialized.full_dim(), quantizer.full_dim());
2564            assert_eq!(deserialized.metric(), quantizer.metric());
2565
2566            for layout in QueryLayout::all() {
2567                assert_eq!(
2568                    deserialized.is_supported(layout),
2569                    quantizer.is_supported(layout)
2570                );
2571            }
2572
2573            // Use the correct alignment for the base pointers.
2574            let alloc = AlignedAllocator::new(PowerOfTwo::new(4).unwrap());
2575            {
2576                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2577                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2578
2579                for row in dataset.row_iter() {
2580                    quantizer
2581                        .compress(row, OpaqueMut::new(&mut a), scoped_global)
2582                        .unwrap();
2583                    deserialized
2584                        .compress(row, OpaqueMut::new(&mut b), scoped_global)
2585                        .unwrap();
2586
2587                    // Compressed representation should be identical.
2588                    assert_eq!(a, b);
2589                }
2590            }
2591
2592            // Distance Computer
2593            {
2594                let mut a0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2595                let mut a1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2596                let mut b0 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2597                let mut b1 = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2598
2599                let q_computer = quantizer.distance_computer(GlobalAllocator).unwrap();
2600                let q_computer_ref = quantizer.distance_computer_ref();
2601                let d_computer = deserialized.distance_computer(GlobalAllocator).unwrap();
2602                let d_computer_ref = deserialized.distance_computer_ref();
2603
2604                for r0 in dataset.row_iter() {
2605                    quantizer
2606                        .compress(r0, OpaqueMut::new(&mut a0), scoped_global)
2607                        .unwrap();
2608                    deserialized
2609                        .compress(r0, OpaqueMut::new(&mut b0), scoped_global)
2610                        .unwrap();
2611                    for r1 in dataset.row_iter() {
2612                        quantizer
2613                            .compress(r1, OpaqueMut::new(&mut a1), scoped_global)
2614                            .unwrap();
2615                        deserialized
2616                            .compress(r1, OpaqueMut::new(&mut b1), scoped_global)
2617                            .unwrap();
2618
2619                        let a0 = Opaque::new(&a0);
2620                        let a1 = Opaque::new(&a1);
2621
2622                        let q_computer_dist = q_computer.evaluate_similarity(a0, a1).unwrap();
2623                        let d_computer_dist = d_computer.evaluate_similarity(a0, a1).unwrap();
2624
2625                        assert_eq!(q_computer_dist, d_computer_dist);
2626
2627                        let q_computer_ref_dist = q_computer_ref.evaluate(a0, a1).unwrap();
2628
2629                        assert_eq!(q_computer_dist, q_computer_ref_dist);
2630
2631                        let d_computer_ref_dist = d_computer_ref.evaluate(a0, a1).unwrap();
2632                        assert_eq!(d_computer_dist, d_computer_ref_dist);
2633                    }
2634                }
2635            }
2636
2637            // Query Computer
2638            {
2639                let mut a = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2640                let mut b = Poly::broadcast(u8::default(), quantizer.bytes(), alloc).unwrap();
2641
2642                for layout in QueryLayout::all() {
2643                    if !quantizer.is_supported(layout) {
2644                        continue;
2645                    }
2646
2647                    for r in dataset.row_iter() {
2648                        let q_computer = quantizer
2649                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2650                            .unwrap();
2651                        let d_computer = deserialized
2652                            .fused_query_computer(r, layout, false, GlobalAllocator, scoped_global)
2653                            .unwrap();
2654
2655                        for u in dataset.row_iter() {
2656                            quantizer
2657                                .compress(u, OpaqueMut::new(&mut a), scoped_global)
2658                                .unwrap();
2659                            deserialized
2660                                .compress(u, OpaqueMut::new(&mut b), scoped_global)
2661                                .unwrap();
2662
2663                            assert_eq!(
2664                                q_computer.evaluate_similarity(Opaque::new(&a)).unwrap(),
2665                                d_computer.evaluate_similarity(Opaque::new(&b)).unwrap(),
2666                            );
2667                        }
2668                    }
2669                }
2670            }
2671        }
2672
2673        // An allocator that succeeds on its first allocation but fails on its second.
2674        #[derive(Debug, Clone)]
2675        struct FlakyAllocator {
2676            have_allocated: Arc<AtomicBool>,
2677        }
2678
2679        impl FlakyAllocator {
2680            fn new(have_allocated: Arc<AtomicBool>) -> Self {
2681                Self { have_allocated }
2682            }
2683        }
2684
2685        // SAFETY: This is a wrapper around GlobalAllocator that only succeeds once.
2686        unsafe impl AllocatorCore for FlakyAllocator {
2687            fn allocate(
2688                &self,
2689                layout: std::alloc::Layout,
2690            ) -> Result<std::ptr::NonNull<[u8]>, AllocatorError> {
2691                if self.have_allocated.swap(true, Ordering::Relaxed) {
2692                    Err(AllocatorError)
2693                } else {
2694                    GlobalAllocator.allocate(layout)
2695                }
2696            }
2697
2698            unsafe fn deallocate(&self, ptr: std::ptr::NonNull<[u8]>, layout: std::alloc::Layout) {
2699                // SAFETY: Inherited from caller.
2700                unsafe { GlobalAllocator.deallocate(ptr, layout) }
2701            }
2702        }
2703
2704        fn test_plan_panic_boundary<const NBITS: usize>(v: &Impl<NBITS>)
2705        where
2706            Impl<NBITS>: Quantizer,
2707        {
2708            // Ensure that we do not panic if reallocation returns an error.
2709            let have_allocated = Arc::new(AtomicBool::new(false));
2710            let _: AllocatorError = v
2711                .serialize(FlakyAllocator::new(have_allocated.clone()))
2712                .unwrap_err();
2713            assert!(have_allocated.load(Ordering::Relaxed));
2714        }
2715
// 1-bit quantizer, squared-L2 metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_1bit_l2() {
    let (quantizer, dataset) = make_impl::<1>(SupportedMetric::SquaredL2);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 1, dataset.as_view());
}
2722
// 1-bit quantizer, inner-product metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_1bit_ip() {
    let (quantizer, dataset) = make_impl::<1>(SupportedMetric::InnerProduct);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 1, dataset.as_view());
}
2729
// 2-bit quantizer, squared-L2 metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_2bit_l2() {
    let (quantizer, dataset) = make_impl::<2>(SupportedMetric::SquaredL2);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 2, dataset.as_view());
}
2736
// 2-bit quantizer, inner-product metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_2bit_ip() {
    let (quantizer, dataset) = make_impl::<2>(SupportedMetric::InnerProduct);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 2, dataset.as_view());
}
2743
// 4-bit quantizer, squared-L2 metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_4bit_l2() {
    let (quantizer, dataset) = make_impl::<4>(SupportedMetric::SquaredL2);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 4, dataset.as_view());
}
2750
// 4-bit quantizer, inner-product metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_4bit_ip() {
    let (quantizer, dataset) = make_impl::<4>(SupportedMetric::InnerProduct);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 4, dataset.as_view());
}
2757
// 8-bit quantizer, squared-L2 metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_8bit_l2() {
    let (quantizer, dataset) = make_impl::<8>(SupportedMetric::SquaredL2);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 8, dataset.as_view());
}
2764
// 8-bit quantizer, inner-product metric: allocation-failure check followed by the
// serialization check.
#[test]
fn test_plan_8bit_ip() {
    let (quantizer, dataset) = make_impl::<8>(SupportedMetric::InnerProduct);
    test_plan_panic_boundary(&quantizer);
    test_plan_serialization(&quantizer, 8, dataset.as_view());
}
2771
// Deserializing into a fresh `BumpAllocator` must hand back a `Poly` whose pointer equals
// `allocator.as_ptr()`.
//
// NOTE(review): `as_ptr()` is presumably the start of the bump arena, which would make
// this a check that the returned box is the first allocation made -- confirm against
// `BumpAllocator`.
#[test]
fn test_allocation_order() {
    // Serialize a 1-bit quantizer with the global allocator to obtain the input bytes.
    let (quantizer, _) = make_impl::<1>(SupportedMetric::SquaredL2);
    let serialized = quantizer.serialize(GlobalAllocator).unwrap();

    let alignment = PowerOfTwo::new(64).unwrap();
    let allocator = BumpAllocator::new(8192, alignment).unwrap();

    // Hand a clone to `try_deserialize` so `allocator` stays available for the comparison.
    let for_deserialize = allocator.clone();
    let deserialized = try_deserialize::<GlobalAllocator, _>(&serialized, for_deserialize).unwrap();

    let box_ptr = Poly::as_ptr(&deserialized).cast::<u8>();
    let arena_ptr = allocator.as_ptr();
    assert_eq!(
        box_ptr, arena_ptr,
        "expected the returned box to be allocated first",
    );
}
2786    }
2787}