Skip to main content

diskann_quantization/meta/
vector.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ptr::NonNull;
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use crate::{
12    alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
13    bits::{
14        AsMutPtr, AsPtr, BitSlice, BitSliceBase, Dense, MutBitSlice, MutSlicePtr,
15        PermutationStrategy, Representation, SlicePtr,
16    },
17    ownership::{CopyMut, CopyRef, Mut, Owned, Ref},
18};
19
20/// A wrapper for [`BitSliceBase`] that provides the addition of arbitrary metadata.
21///
22/// # Examples
23///
24/// The `VectorBase` has several named variants that are commonly used:
25/// * [`Vector`]: An owning, independently allocated `VectorBase`.
26/// * [`VectorMut`]: A mutable, reference-like type to a `VectorBase`.
27/// * [`VectorRef`]: A const, reference-like type to a `VectorBase`.
28///
29/// ```
30/// use diskann_quantization::{
31///     meta::{Vector, VectorMut, VectorRef},
32///     bits::Unsigned,
33/// };
34///
35/// use diskann_utils::{Reborrow, ReborrowMut};
36///
37/// #[derive(Debug, Default, Clone, Copy, PartialEq)]
38/// struct Metadata {
39///     value: f32,
40/// }
41///
42/// // Create a new heap-allocated Vector for 4-bit compressions capable of
43/// // holding 3 elements.
44/// //
/// // In this case, the associated metadata type is `Metadata`, which is default-initialized.
46/// let mut v = Vector::<4, Unsigned, Metadata>::new_boxed(3);
47///
48/// // We can inspect the underlying bitslice.
49/// let bitslice = v.vector();
50/// assert_eq!(bitslice.get(0).unwrap(), 0);
51/// assert_eq!(bitslice.get(1).unwrap(), 0);
52/// assert_eq!(v.meta(), Metadata::default(), "expected default metadata value");
53///
54/// // If we want, we can mutably borrow the bitslice and mutate its components.
55/// let mut bitslice = v.vector_mut();
56/// bitslice.set(0, 1).unwrap();
57/// bitslice.set(1, 2).unwrap();
58/// bitslice.set(2, 3).unwrap();
59///
60/// assert!(bitslice.set(3, 4).is_err(), "out-of-bounds access");
61///
62/// // Get the underlying pointer for comparison.
63/// let ptr = bitslice.as_ptr();
64///
65/// // Vectors can be converted to a generalized reference.
66/// let mut v_ref = v.reborrow_mut();
67///
68/// // The generalized reference preserves the underlying pointer.
69/// assert_eq!(v_ref.vector().as_ptr(), ptr);
70/// let mut bitslice = v_ref.vector_mut();
71/// bitslice.set(0, 10).unwrap();
72///
/// // Setting the metadata through the reborrow will be visible in the original allocation.
74/// v_ref.set_meta(Metadata { value: 10.5 });
75///
76/// // Check that the changes are visible.
77/// assert_eq!(v.meta().value, 10.5);
78/// assert_eq!(v.vector().get(0).unwrap(), 10);
79///
80/// // Finally, the immutable ref also maintains pointer compatibility.
81/// let v_ref = v.reborrow();
82/// assert_eq!(v_ref.vector().as_ptr(), ptr);
83/// ```
84///
85/// ## Constructing a `VectorMut` From Components
86///
87/// The following example shows how to assemble a `VectorMut` from raw memory.
88/// ```
89/// use diskann_quantization::{bits::{Unsigned, MutBitSlice}, meta::VectorMut};
90///
91/// // Start with 2 bytes of memory. We will impose a 4-bit scalar quantization on top of
92/// // these 2 bytes.
93/// let mut data = vec![0u8; 2];
94/// let mut metadata: f32 = 0.0;
95/// {
96///     // First, we need to construct a bit-slice over the data.
97///     // This will check that it is sized properly for 4, 4-bit values.
98///     let mut slice = MutBitSlice::<4, Unsigned>::new(data.as_mut_slice(), 4).unwrap();
99///
100///     // Next, we construct the `VectorMut`.
101///     let mut v = VectorMut::new(slice, &mut metadata);
102///
///     // Through `v`, we can set all the components in `slice` as well as the metadata.
104///     v.set_meta(123.4);
105///     let mut from_v = v.vector_mut();
106///     from_v.set(0, 1).unwrap();
107///     from_v.set(1, 2).unwrap();
108///     from_v.set(2, 3).unwrap();
109///     from_v.set(3, 4).unwrap();
110/// }
111///
112/// // Now we can check that the changes made internally are visible.
113/// assert_eq!(&data, &[0x21, 0x43]);
114/// assert_eq!(metadata, 123.4);
115/// ```
116///
117/// ## Canonical Layout
118///
119/// When the metadata type `T` is
120/// [`bytemuck::Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html), [`VectorRef`]
121/// and [`VectorMut`] support layout canonicalization, where a raw slice can be used as the
122/// backing store for such vectors, enabling inline storage.
123///
/// There are two supported schemes for the canonical layout, depending on whether the
/// metadata is located at the beginning of the slice or at the end of the slice.
126///
127/// If the metadata is at the front, then the layout consists of a slice `&[u8]` where the
128/// first `std::mem::size_of::<T>()` bytes are the metadata and the remainder compose the
129/// [`BitSlice`] codes.
130///
/// If the metadata is at the back, then the layout consists of a slice `&[u8]` where the
132/// last `std::mem::size_of::<T>()` bytes are the metadata and the prefix is the
133/// [`BitSlice`] codes.
134///
135/// The canonical layout needs the following properties:
136///
137/// * `T: bytemuck::Pod`: For safely storing and retrieving.
/// * The length for a vector with `N` dimensions must be equal to the value returned from
///   [`Vector::canonical_bytes`].
140///
141/// The following functions can be used to construct [`VectorBase`]s from raw slices:
142///
143/// * [`VectorRef::from_canonical_front`]
144/// * [`VectorRef::from_canonical_back`]
145/// * [`VectorMut::from_canonical_front_mut`]
146/// * [`VectorMut::from_canonical_back_mut`]
147///
148/// An example is shown below.
149/// ```rust
150/// use diskann_quantization::{bits, meta::{Vector, VectorRef, VectorMut}};
151///
152/// type CVRef<'a, const NBITS: usize> = VectorRef<'a, NBITS, bits::Unsigned, f32>;
153/// type MutCV<'a, const NBITS: usize> = VectorMut<'a, NBITS, bits::Unsigned, f32>;
154///
155/// let dim = 3;
156///
/// // The canonical layout has no alignment requirements, so a tightly-sized byte buffer works.
158/// let bytes = CVRef::<4>::canonical_bytes(dim);
159/// let mut data: Box<[u8]> = (0..bytes).map(|_| u8::default()).collect();
160///
161/// // Construct a mutable compensated vector over the slice.
162/// let mut mut_cv = MutCV::<4>::from_canonical_front_mut(&mut data, dim).unwrap();
163/// mut_cv.set_meta(1.0);
164/// let mut v = mut_cv.vector_mut();
165/// v.set(0, 1).unwrap();
166/// v.set(1, 2).unwrap();
167/// v.set(2, 3).unwrap();
168///
/// // Reconstruct an immutable `VectorRef` over the same bytes.
170/// let cv = CVRef::<4>::from_canonical_front(&data, dim).unwrap();
171/// assert_eq!(cv.meta(), 1.0);
172/// let v = cv.vector();
173/// assert_eq!(v.get(0).unwrap(), 1);
174/// assert_eq!(v.get(1).unwrap(), 2);
175/// assert_eq!(v.get(2).unwrap(), 3);
176/// ```
#[derive(Debug, Clone, Copy)]
pub struct VectorBase<const NBITS: usize, Repr, Ptr, T, Perm = Dense>
where
    Ptr: AsPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
{
    // The quantized codes: `NBITS`-bit elements packed into the byte storage `Ptr`.
    bits: BitSliceBase<NBITS, Repr, Ptr, Perm>,
    // Ownership-polymorphic metadata slot (e.g. `Owned<T>`, `Ref<'a, T>`, `Mut<'a, T>`).
    meta: T,
}
187
188impl<const NBITS: usize, Repr, Ptr, T, Perm> VectorBase<NBITS, Repr, Ptr, T, Perm>
189where
190    Ptr: AsPtr<Type = u8>,
191    Repr: Representation<NBITS>,
192    Perm: PermutationStrategy<NBITS>,
193{
194    /// Return the number of bytes required for the underlying `BitSlice`.
195    pub fn slice_bytes(count: usize) -> usize {
196        BitSliceBase::<NBITS, Repr, Ptr, Perm>::bytes_for(count)
197    }
198
199    /// Return the number of bytes required for the canonical representation of a
200    /// `Vector`.
201    ///
202    /// See: [`VectorRef::from_canonical_back`], [`VectorMut::from_canonical_back_mut`].
203    pub fn canonical_bytes(count: usize) -> usize
204    where
205        T: CopyRef,
206        T::Target: bytemuck::Pod,
207    {
208        Self::slice_bytes(count) + std::mem::size_of::<T::Target>()
209    }
210
211    /// Construct a new `VectorBase` over the bit-slice.
212    pub fn new<M>(bits: BitSliceBase<NBITS, Repr, Ptr, Perm>, meta: M) -> Self
213    where
214        M: Into<T>,
215    {
216        Self {
217            bits,
218            meta: meta.into(),
219        }
220    }
221
222    /// Return the number of dimensions of in the vector.
223    pub fn len(&self) -> usize {
224        self.bits.len()
225    }
226
227    /// Return whether or not the vector is empty.
228    pub fn is_empty(&self) -> bool {
229        self.bits.is_empty()
230    }
231
232    /// Return the metadata value for this vector.
233    pub fn meta(&self) -> T::Target
234    where
235        T: CopyRef,
236    {
237        self.meta.copy_ref()
238    }
239
240    /// Borrow the integer compressed vector.
241    pub fn vector(&self) -> BitSlice<'_, NBITS, Repr, Perm> {
242        self.bits.reborrow()
243    }
244
245    /// Mutably borrow the integer compressed vector.
246    pub fn vector_mut(&mut self) -> MutBitSlice<'_, NBITS, Repr, Perm>
247    where
248        Ptr: AsMutPtr,
249    {
250        self.bits.reborrow_mut()
251    }
252
253    /// Get a mutable reference to the metadata component.
254    ///
255    /// In addition to a mutable reference, this also requires `Ptr: AsMutPtr` to prevent
256    /// accidental misuse where the `VectorBase` is mutable but the underlying
257    /// `BitSlice` is not.
258    pub fn set_meta(&mut self, value: T::Target)
259    where
260        Ptr: AsMutPtr,
261        T: CopyMut,
262    {
263        self.meta.copy_mut(value)
264    }
265}
266
267impl<const NBITS: usize, Repr, Perm, T>
268    VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>
269where
270    Repr: Representation<NBITS>,
271    Perm: PermutationStrategy<NBITS>,
272    T: Default,
273{
274    /// Create a new owned `VectorBase` with its metadata default initialized.
275    pub fn new_boxed(len: usize) -> Self {
276        Self {
277            bits: BitSliceBase::new_boxed(len),
278            meta: Owned::default(),
279        }
280    }
281}
282
283impl<const NBITS: usize, Repr, Perm, T, A> VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>
284where
285    Repr: Representation<NBITS>,
286    Perm: PermutationStrategy<NBITS>,
287    T: Default,
288    A: AllocatorCore,
289{
290    /// Create a new owned `VectorBase` with its metadata default initialized.
291    pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
292        Ok(Self {
293            bits: BitSliceBase::new_in(len, allocator)?,
294            meta: Owned::default(),
295        })
296    }
297}
298
/// A borrowed `Vector`.
///
/// See: [`VectorBase`].
pub type VectorRef<'a, const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, SlicePtr<'a, u8>, Ref<'a, T>, Perm>;

/// A mutably borrowed `Vector`.
///
/// See: [`VectorBase`].
pub type VectorMut<'a, const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, MutSlicePtr<'a, u8>, Mut<'a, T>, Perm>;

/// An owning `VectorBase` backed by the global allocator.
///
/// See: [`VectorBase`].
pub type Vector<const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>;

/// An owning `VectorBase` backed by a caller-provided allocator `A`.
///
/// See: [`VectorBase`].
pub type PolyVector<const NBITS: usize, Repr, T, Perm, A> =
    VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>;
322
323// Reborrow
324impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> Reborrow<'this>
325    for VectorBase<NBITS, Repr, Ptr, T, Perm>
326where
327    Ptr: AsPtr<Type = u8>,
328    Repr: Representation<NBITS>,
329    Perm: PermutationStrategy<NBITS>,
330    T: CopyRef + Reborrow<'this, Target = Ref<'this, <T as CopyRef>::Target>>,
331{
332    type Target = VectorRef<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;
333
334    fn reborrow(&'this self) -> Self::Target {
335        Self::Target {
336            bits: self.bits.reborrow(),
337            meta: self.meta.reborrow(),
338        }
339    }
340}
341
342// ReborrowMut
343impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> ReborrowMut<'this>
344    for VectorBase<NBITS, Repr, Ptr, T, Perm>
345where
346    Ptr: AsMutPtr<Type = u8>,
347    Repr: Representation<NBITS>,
348    Perm: PermutationStrategy<NBITS>,
349    T: CopyMut + ReborrowMut<'this, Target = Mut<'this, <T as CopyRef>::Target>>,
350{
351    type Target = VectorMut<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;
352
353    fn reborrow_mut(&'this mut self) -> Self::Target {
354        Self::Target {
355            bits: self.bits.reborrow_mut(),
356            meta: self.meta.reborrow_mut(),
357        }
358    }
359}
360
361//////////////////////
362// Canonical Layout //
363//////////////////////
364
/// Error returned when a raw byte slice does not match the canonical `Vector` layout.
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
    /// The slice length (second field) differs from the required canonical length in
    /// bytes (first field).
    #[error("expected a slice length of {0} bytes but instead got {1} bytes")]
    WrongLength(usize, usize),
}
370
impl<'a, const NBITS: usize, Repr, T, Perm> VectorRef<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the front-canonical layout for
    /// a vector. The front canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<T>()` bytes for the metadata coefficient.
    /// * `Self::slice_bytes(dim)` bytes for the underlying bit-slice.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] if
    /// `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_front(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        if data.len() != expected {
            Err(NotCanonical::WrongLength(expected, data.len()))
        } else {
            // SAFETY: We have just verified `data.len() == Self::canonical_bytes(dim)`,
            // which is the only precondition of `from_canonical_unchecked`.
            Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
        }
    }

    /// Construct an instance of `Self` viewing `data` as the back-canonical layout for a
    /// vector. The back canonical layout is as follows:
    ///
    /// * `Self::slice_bytes(dim)` bytes for the underlying bit-slice.
    /// * `std::mem::size_of::<T>()` bytes for the metadata coefficient.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] if
    /// `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_back(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        if data.len() != expected {
            Err(NotCanonical::WrongLength(expected, data.len()))
        } else {
            // SAFETY: We have just verified `data.len() == Self::canonical_bytes(dim)`,
            // which is the only precondition of `from_canonical_back_unchecked`.
            Ok(unsafe { Self::from_canonical_back_unchecked(data, dim) })
        }
    }

    /// Construct a `VectorRef` from raw data in the front-canonical layout (metadata
    /// first, then the bit-slice codes).
    ///
    /// # Safety
    ///
    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
    ///
    /// This invariant is checked in debug builds and will panic if not satisfied.
    pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));

        // SAFETY: `BitSlice` has no alignment requirements, and the length precondition
        // for this function (i.e., `data.len() == Self::canonical_bytes(dim)`) implies
        // that `Self::slice_bytes(dim)` bytes are available beginning at an offset of
        // `std::mem::size_of::<T>()`.
        let bits =
            unsafe { BitSlice::new_unchecked(data.get_unchecked(std::mem::size_of::<T>()..), dim) };

        // SAFETY: The pointer is valid and non-null because `data` is a slice, and its
        // length is at least `std::mem::size_of::<T>()` (from the length precondition
        // for this function).
        //
        // NOTE(review): the metadata pointer is not necessarily aligned for `T`; `Ref`
        // presumably performs unaligned reads (sound since `T: bytemuck::Pod`) — confirm.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }

    /// Construct a `VectorRef` from raw data in the back-canonical layout (bit-slice
    /// codes first, then the metadata).
    ///
    /// # Safety
    ///
    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
    ///
    /// This invariant is checked in debug builds and will panic if not satisfied.
    pub unsafe fn from_canonical_back_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        // SAFETY: The caller asserts that
        // `data.len() == Self::canonical_bytes(dim) >= std::mem::size_of::<T>()`,
        // so the split point is in-bounds.
        let (data, meta) =
            unsafe { data.split_at_unchecked(data.len() - std::mem::size_of::<T>()) };

        // SAFETY: `BitSlice` has no alignment requirements, and after the split the
        // front half holds exactly `Self::slice_bytes(dim)` bytes (from the length
        // precondition for this function).
        let bits = unsafe { BitSlice::new_unchecked(data, dim) };

        // SAFETY: The pointer is valid and non-null because `meta` is a (sub)slice of
        // exactly `std::mem::size_of::<T>()` bytes.
        //
        // NOTE(review): the metadata pointer is not necessarily aligned for `T`; `Ref`
        // presumably performs unaligned reads (sound since `T: bytemuck::Pod`) — confirm.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(meta.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }
}
464
impl<'a, const NBITS: usize, Repr, T, Perm> VectorMut<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the front-canonical layout for
    /// a vector. The front canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<T>()` bytes for the metadata coefficient.
    /// * `Self::slice_bytes(dim)` bytes for the underlying bit-slice.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] if
    /// `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_front_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        if data.len() != expected {
            Err(NotCanonical::WrongLength(expected, data.len()))
        } else {
            // SAFETY: We have just verified `data.len() == Self::canonical_bytes(dim)`,
            // which is the only precondition of `from_canonical_front_mut_unchecked`.
            Ok(unsafe { Self::from_canonical_front_mut_unchecked(data, dim) })
        }
    }

    /// Construct a `VectorMut` from raw data in the front-canonical layout (metadata
    /// first, then the bit-slice codes).
    ///
    /// # Safety
    ///
    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
    ///
    /// This invariant is checked in debug builds and will panic if not satisfied.
    pub unsafe fn from_canonical_front_mut_unchecked(data: &'a mut [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));

        // SAFETY: The length precondition for this function guarantees the split point
        // `std::mem::size_of::<T>()` is in-bounds.
        let (front, back) = unsafe { data.split_at_mut_unchecked(std::mem::size_of::<T>()) };

        // SAFETY: After the split, `back` holds exactly `Self::slice_bytes(dim)` bytes
        // (length precondition); `BitSlice` has no alignment requirements.
        let bits = unsafe { MutBitSlice::new_unchecked(back, dim) };

        // SAFETY: `front` points to a valid slice of `std::mem::size_of::<T>()` bytes.
        //
        // NOTE(review): the metadata pointer is not necessarily aligned for `T`; `Mut`
        // presumably performs unaligned accesses (sound since `T: bytemuck::Pod`) —
        // confirm.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(front.as_mut_ptr()).cast::<T>()) };
        Self { bits, meta }
    }

    /// Construct an instance of `Self` viewing `data` as the back-canonical layout for a
    /// vector. The back canonical layout is as follows:
    ///
    /// * `Self::slice_bytes(dim)` bytes for the underlying bit-slice.
    /// * `std::mem::size_of::<T>()` bytes for the metadata coefficient.
    ///
    /// # Errors
    ///
    /// Returns [`NotCanonical::WrongLength`] if
    /// `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_back_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let len = data.len();
        // Computed lazily: only needed on the error paths.
        let expected = || Self::canonical_bytes(dim);
        // Split into the bit-slice prefix and the metadata suffix, failing if the slice
        // is too short to contain even the bit-slice portion.
        let (front, back) = match data.split_at_mut_checked(Self::slice_bytes(dim)) {
            Some(v) => v,
            None => {
                return Err(NotCanonical::WrongLength(expected(), len));
            }
        };

        // The suffix must be exactly the metadata size; anything else means `data` had a
        // non-canonical total length.
        if back.len() != std::mem::size_of::<T>() {
            return Err(NotCanonical::WrongLength(expected(), len));
        }

        // SAFETY: Since `split_at_mut_checked` was successful, `front` holds exactly
        // `Self::slice_bytes(dim)` bytes.
        let bits = unsafe { MutBitSlice::new_unchecked(front, dim) };

        // SAFETY: `split_at_mut_checked` was successful and `back` was checked for
        // length, so `back` points to a valid slice of `std::mem::size_of::<T>()` bytes.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(back.as_mut_ptr()).cast::<T>()) };
        Ok(Self { bits, meta })
    }
}
540
541///////////
542// Tests //
543///////////
544
545#[cfg(test)]
546mod tests {
547    use diskann_utils::{Reborrow, ReborrowMut};
548    use rand::{
549        Rng, SeedableRng,
550        distr::{Distribution, StandardUniform, Uniform},
551        rngs::StdRng,
552    };
553
554    use super::*;
555    use crate::bits::{BoxedBitSlice, Representation, Unsigned};
556
557    ////////////////////////
558    // Compensated Vector //
559    ////////////////////////
560
    /// Two-field POD metadata payload used to exercise `VectorBase` in the tests.
    #[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Metadata {
        a: u32,
        b: u32,
    }

    impl Metadata {
        /// Shorthand constructor for test fixtures.
        fn new(a: u32, b: u32) -> Metadata {
            Self { a, b }
        }
    }
573
574    #[test]
575    fn test_vector() {
576        let len = 20;
577        let mut base = Vector::<7, Unsigned, Metadata>::new_boxed(len);
578        assert_eq!(base.len(), len);
579        assert_eq!(base.meta(), Metadata::default());
580        assert!(!base.is_empty());
581        // Ensure that if we reborrow mutably that changes are visible.
582        {
583            let mut rb = base.reborrow_mut();
584            assert_eq!(rb.len(), len);
585            rb.set_meta(Metadata::new(1, 2));
586            let mut v = rb.vector_mut();
587
588            assert_eq!(v.len(), len);
589            for i in 0..v.len() {
590                v.set(i, i as i64).unwrap();
591            }
592        }
593
594        // Are the changes visible?
595        let expected_metadata = Metadata::new(1, 2);
596        assert_eq!(base.meta(), expected_metadata);
597        assert_eq!(base.len(), len);
598        let v = base.vector();
599        for i in 0..v.len() {
600            assert_eq!(v.get(i).unwrap(), i as i64);
601        }
602
603        // Are the changes still visible if we reborrow?
604        {
605            let rb = base.reborrow();
606            assert_eq!(rb.len(), len);
607            assert_eq!(rb.meta(), expected_metadata);
608            let v = rb.vector();
609            for i in 0..v.len() {
610                assert_eq!(v.get(i).unwrap(), i as i64);
611            }
612        }
613    }
614
615    #[test]
616    fn test_compensated_mut() {
617        let len = 30;
618        let mut v = BoxedBitSlice::<7, Unsigned>::new_boxed(len);
619        let mut m = Metadata::default();
620
621        // borrowed duration
622        let mut vector = VectorMut::new(v.reborrow_mut(), &mut m);
623        assert_eq!(vector.len(), len);
624        vector.set_meta(Metadata::new(200, 5));
625        for i in 0..vector.len() {
626            vector.vector_mut().set(i, i as i64).unwrap();
627        }
628
629        // ensure changes are visible
630        assert_eq!(m.a, 200);
631        assert_eq!(m.b, 5);
632        for i in 0..len {
633            assert_eq!(v.get(i).unwrap(), i as i64);
634        }
635    }
636
637    //////////////////////
638    // Canonicalization //
639    //////////////////////
640
    // `VectorRef`/`VectorMut` specialized with `Metadata`, used by the canonical tests.
    type TestVectorRef<'a, const NBITS: usize> = VectorRef<'a, NBITS, Unsigned, Metadata>;
    type TestVectorMut<'a, const NBITS: usize> = VectorMut<'a, NBITS, Unsigned, Metadata>;
643
    /// Round-trip `ntrials` random vectors of `dim` `NBITS`-bit elements through both
    /// the front- and back-canonical layouts, then verify the length-error paths.
    fn check_canonicalization<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        R: Rng,
    {
        let bytes = TestVectorRef::<NBITS>::canonical_bytes(dim);
        assert_eq!(
            bytes,
            std::mem::size_of::<Metadata>() + BitSlice::<NBITS, Unsigned>::bytes_for(dim)
        );

        // Oversize the buffers so the canonical window can start at a random `offset`,
        // exercising every byte-alignment class for the metadata.
        let mut buffer_front = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];
        let mut buffer_back = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];

        // Expected metadata and vector encoding.
        let mut expected = vec![i64::default(); dim];

        // Samples uniformly from the representable domain of `NBITS`-bit unsigned codes.
        let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();

        for _ in 0..ntrials {
            // Random window start in `[0, size_of::<Metadata>())`.
            let offset = Uniform::new(0, std::mem::size_of::<Metadata>())
                .unwrap()
                .sample(rng);
            let a: u32 = StandardUniform.sample(rng);
            let b: u32 = StandardUniform.sample(rng);

            expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
            {
                // Write the metadata and all codes through a mutable canonical view.
                let set = |mut cv: TestVectorMut<NBITS>| {
                    cv.set_meta(Metadata::new(a, b));
                    let mut vector = cv.vector_mut();
                    for (i, e) in expected.iter().enumerate() {
                        vector.set(i, *e).unwrap();
                    }
                };

                // Front
                let cv = TestVectorMut::<NBITS>::from_canonical_front_mut(
                    &mut buffer_front[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                set(cv);

                // Back
                let cv = TestVectorMut::<NBITS>::from_canonical_back_mut(
                    &mut buffer_back[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                set(cv);
            }

            // Make sure the reconstruction is valid.
            {
                let check = |cv: TestVectorRef<NBITS>| {
                    assert_eq!(cv.meta(), Metadata::new(a, b));
                    let vector = cv.vector();
                    for (i, e) in expected.iter().enumerate() {
                        assert_eq!(vector.get(i).unwrap(), *e);
                    }
                };

                let cv = TestVectorRef::<NBITS>::from_canonical_front(
                    &buffer_front[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                check(cv);

                let cv = TestVectorRef::<NBITS>::from_canonical_back(
                    &buffer_back[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                check(cv);
            }
        }

        // Check Errors - Mut
        {
            // Too short
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
                &mut buffer_front[..bytes - 1],
                dim,
            )
            .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err =
                TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes - 1], dim)
                    .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Empty
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(&mut [], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorMut::<NBITS>::from_canonical_back_mut(&mut [], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Too long
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
                &mut buffer_front[..bytes + 1],
                dim,
            )
            .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err =
                TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes + 1], dim)
                    .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }

        // Check Errors - Const
        {
            // Too short
            let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes - 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes - 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Empty
            let err = TestVectorRef::<NBITS>::from_canonical_front(&[], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&[], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Too long
            let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes + 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes + 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }
    }
793
794    fn check_canonicalization_zst<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
795    where
796        Unsigned: Representation<NBITS>,
797        R: Rng,
798    {
799        let bytes = VectorRef::<NBITS, Unsigned, ()>::canonical_bytes(dim);
800        assert_eq!(bytes, BitSlice::<NBITS, Unsigned>::bytes_for(dim));
801
802        let max_offset = 10;
803        let mut buffer_front = vec![u8::default(); bytes + max_offset];
804        let mut buffer_back = vec![u8::default(); bytes + max_offset];
805
806        // Expected metadata and vector encoding.
807        let mut expected = vec![i64::default(); dim];
808
809        let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
810
811        for _ in 0..ntrials {
812            let offset = Uniform::new(0, max_offset).unwrap().sample(rng);
813            expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
814            {
815                let set = |mut cv: VectorMut<NBITS, Unsigned, ()>| {
816                    cv.set_meta(());
817                    let mut vector = cv.vector_mut();
818                    for (i, e) in expected.iter().enumerate() {
819                        vector.set(i, *e).unwrap();
820                    }
821                };
822
823                let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
824                    &mut buffer_front[offset..offset + bytes],
825                    dim,
826                )
827                .unwrap();
828                set(cv);
829
830                let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
831                    &mut buffer_back[offset..offset + bytes],
832                    dim,
833                )
834                .unwrap();
835                set(cv);
836            }
837
838            // Make sure the reconstruction is valid.
839            {
840                let check = |cv: VectorRef<NBITS, Unsigned, ()>| {
841                    let vector = cv.vector();
842                    for (i, e) in expected.iter().enumerate() {
843                        assert_eq!(vector.get(i).unwrap(), *e);
844                    }
845                };
846
847                let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
848                    &buffer_front[offset..offset + bytes],
849                    dim,
850                )
851                .unwrap();
852                check(cv);
853
854                let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
855                    &buffer_back[offset..offset + bytes],
856                    dim,
857                )
858                .unwrap();
859                check(cv);
860            }
861        }
862
863        // Check Errors - Mut
864        {
865            // Too short
866            if dim >= 1 {
867                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
868                    &mut buffer_front[..bytes - 1],
869                    dim,
870                )
871                .unwrap_err();
872                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
873
874                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
875                    &mut buffer_back[..bytes - 1],
876                    dim,
877                )
878                .unwrap_err();
879                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
880            }
881
882            // Empty
883            if dim >= 1 {
884                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(&mut [], dim)
885                    .unwrap_err();
886                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
887
888                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(&mut [], dim)
889                    .unwrap_err();
890                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
891            }
892
893            // Too long
894            {
895                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
896                    &mut buffer_front[..bytes + 1],
897                    dim,
898                )
899                .unwrap_err();
900
901                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
902
903                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
904                    &mut buffer_back[..bytes + 1],
905                    dim,
906                )
907                .unwrap_err();
908
909                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
910            }
911        }
912
913        // Check Errors - Const
914        {
915            // Too short
916            if dim >= 1 {
917                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
918                    &buffer_front[..bytes - 1],
919                    dim,
920                )
921                .unwrap_err();
922
923                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
924
925                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
926                    &buffer_back[..bytes - 1],
927                    dim,
928                )
929                .unwrap_err();
930
931                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
932            }
933
934            // Too long
935            let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
936                &mut buffer_front[..bytes + 1],
937                dim,
938            )
939            .unwrap_err();
940
941            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
942
943            let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
944                &mut buffer_back[..bytes + 1],
945                dim,
946            )
947            .unwrap_err();
948
949            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
950        }
951
952        // Check Errors - Const
953        {
954            // Too short
955            if dim >= 1 {
956                let err =
957                    VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(&[], dim).unwrap_err();
958
959                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
960
961                let err =
962                    VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(&[], dim).unwrap_err();
963
964                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
965            }
966
967            // Too long
968            {
969                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
970                    &buffer_front[..bytes + 1],
971                    dim,
972                )
973                .unwrap_err();
974
975                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
976
977                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
978                    &buffer_back[..bytes + 1],
979                    dim,
980                )
981                .unwrap_err();
982
983                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
984            }
985        }
986    }
987
    // Test-size constants, scaled down under Miri: Miri interprets every memory
    // access, so the full 256-dimension sweep would be prohibitively slow.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            // The max dim does not need to be as high for `CompensatedVectors` because they
            // defer their distance function implementation to `BitSlice`, which is more
            // heavily tested.
            const MAX_DIM: usize = 37;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            // Native runs can afford a wider sweep and more random trials per dimension.
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }
1000
1001    macro_rules! test_canonical {
1002        ($name:ident, $nbits:literal, $seed:literal) => {
1003            #[test]
1004            fn $name() {
1005                let mut rng = StdRng::seed_from_u64($seed);
1006                for dim in 0..MAX_DIM {
1007                    check_canonicalization::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
1008                    check_canonicalization_zst::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
1009                }
1010            }
1011        };
1012    }
1013
    // Instantiate the canonicalization test for every supported bit width
    // (1 through 8), each with its own fixed seed for reproducibility.
    test_canonical!(canonical_8bit, 8, 0xe64518a00ee99e2f);
    test_canonical!(canonical_7bit, 7, 0x3907123f8c38def2);
    test_canonical!(canonical_6bit, 6, 0xeccaeb83965ff6a1);
    test_canonical!(canonical_5bit, 5, 0x9691fe59e49bfb96);
    test_canonical!(canonical_4bit, 4, 0xc4d3e9bc699a7e6f);
    test_canonical!(canonical_3bit, 3, 0x8a01b2ccdca8fb2b);
    test_canonical!(canonical_2bit, 2, 0x3a07429e8184b67f);
    test_canonical!(canonical_1bit, 1, 0x93fddb26059c115c);
1022}