diskann_quantization/meta/
vector.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::ptr::NonNull;
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use crate::{
12    alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
13    bits::{
14        AsMutPtr, AsPtr, BitSlice, BitSliceBase, Dense, MutBitSlice, MutSlicePtr,
15        PermutationStrategy, Representation, SlicePtr,
16    },
17    ownership::{CopyMut, CopyRef, Mut, Owned, Ref},
18};
19
20/// A wrapper for [`BitSliceBase`] that provides the addition of arbitrary metadata.
21///
22/// # Examples
23///
24/// The `VectorBase` has several named variants that are commonly used:
25/// * [`Vector`]: An owning, independently allocated `VectorBase`.
26/// * [`VectorMut`]: A mutable, reference-like type to a `VectorBase`.
27/// * [`VectorRef`]: A const, reference-like type to a `VectorBase`.
28///
29/// ```
30/// use diskann_quantization::{
31///     meta::{Vector, VectorMut, VectorRef},
32///     bits::Unsigned,
33/// };
34///
35/// use diskann_utils::{Reborrow, ReborrowMut};
36///
37/// #[derive(Debug, Default, Clone, Copy, PartialEq)]
38/// struct Metadata {
39///     value: f32,
40/// }
41///
42/// // Create a new heap-allocated Vector for 4-bit compressions capable of
43/// // holding 3 elements.
44/// //
/// // In this case, the associated metadata is default-initialized.
46/// let mut v = Vector::<4, Unsigned, Metadata>::new_boxed(3);
47///
48/// // We can inspect the underlying bitslice.
49/// let bitslice = v.vector();
50/// assert_eq!(bitslice.get(0).unwrap(), 0);
51/// assert_eq!(bitslice.get(1).unwrap(), 0);
52/// assert_eq!(v.meta(), Metadata::default(), "expected default metadata value");
53///
54/// // If we want, we can mutably borrow the bitslice and mutate its components.
55/// let mut bitslice = v.vector_mut();
56/// bitslice.set(0, 1).unwrap();
57/// bitslice.set(1, 2).unwrap();
58/// bitslice.set(2, 3).unwrap();
59///
60/// assert!(bitslice.set(3, 4).is_err(), "out-of-bounds access");
61///
62/// // Get the underlying pointer for comparison.
63/// let ptr = bitslice.as_ptr();
64///
65/// // Vectors can be converted to a generalized reference.
66/// let mut v_ref = v.reborrow_mut();
67///
68/// // The generalized reference preserves the underlying pointer.
69/// assert_eq!(v_ref.vector().as_ptr(), ptr);
70/// let mut bitslice = v_ref.vector_mut();
71/// bitslice.set(0, 10).unwrap();
72///
/// // Setting the underlying metadata will be visible in the original allocation.
74/// v_ref.set_meta(Metadata { value: 10.5 });
75///
76/// // Check that the changes are visible.
77/// assert_eq!(v.meta().value, 10.5);
78/// assert_eq!(v.vector().get(0).unwrap(), 10);
79///
80/// // Finally, the immutable ref also maintains pointer compatibility.
81/// let v_ref = v.reborrow();
82/// assert_eq!(v_ref.vector().as_ptr(), ptr);
83/// ```
84///
85/// ## Constructing a `VectorMut` From Components
86///
87/// The following example shows how to assemble a `VectorMut` from raw memory.
88/// ```
89/// use diskann_quantization::{bits::{Unsigned, MutBitSlice}, meta::VectorMut};
90///
91/// // Start with 2 bytes of memory. We will impose a 4-bit scalar quantization on top of
92/// // these 2 bytes.
93/// let mut data = vec![0u8; 2];
94/// let mut metadata: f32 = 0.0;
95/// {
96///     // First, we need to construct a bit-slice over the data.
97///     // This will check that it is sized properly for 4, 4-bit values.
98///     let mut slice = MutBitSlice::<4, Unsigned>::new(data.as_mut_slice(), 4).unwrap();
99///
100///     // Next, we construct the `VectorMut`.
101///     let mut v = VectorMut::new(slice, &mut metadata);
102///
///     // Through `v`, we can set all the components in `slice` and the metadata.
104///     v.set_meta(123.4);
105///     let mut from_v = v.vector_mut();
106///     from_v.set(0, 1).unwrap();
107///     from_v.set(1, 2).unwrap();
108///     from_v.set(2, 3).unwrap();
109///     from_v.set(3, 4).unwrap();
110/// }
111///
112/// // Now we can check that the changes made internally are visible.
113/// assert_eq!(&data, &[0x21, 0x43]);
114/// assert_eq!(metadata, 123.4);
115/// ```
116///
117/// ## Canonical Layout
118///
119/// When the metadata type `T` is
120/// [`bytemuck::Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html), [`VectorRef`]
121/// and [`VectorMut`] support layout canonicalization, where a raw slice can be used as the
122/// backing store for such vectors, enabling inline storage.
123///
/// There are two supported schemes for the canonical layout, depending on whether the
125/// metadata is located at the beginning of the slice or at the end of the slice.
126///
127/// If the metadata is at the front, then the layout consists of a slice `&[u8]` where the
128/// first `std::mem::size_of::<T>()` bytes are the metadata and the remainder compose the
129/// [`BitSlice`] codes.
130///
/// If the metadata is at the back, then the layout consists of a slice `&[u8]` where the
132/// last `std::mem::size_of::<T>()` bytes are the metadata and the prefix is the
133/// [`BitSlice`] codes.
134///
135/// The canonical layout needs the following properties:
136///
137/// * `T: bytemuck::Pod`: For safely storing and retrieving.
/// * The length for a vector with `N` dimensions must be equal to the value returned from
139///   [`Vector::canonical_bytes`].
140///
141/// The following functions can be used to construct [`VectorBase`]s from raw slices:
142///
143/// * [`VectorRef::from_canonical_front`]
144/// * [`VectorRef::from_canonical_back`]
145/// * [`VectorMut::from_canonical_front_mut`]
146/// * [`VectorMut::from_canonical_back_mut`]
147///
148/// An example is shown below.
149/// ```rust
150/// use diskann_quantization::{bits, meta::{Vector, VectorRef, VectorMut}};
151///
152/// type CVRef<'a, const NBITS: usize> = VectorRef<'a, NBITS, bits::Unsigned, f32>;
153/// type MutCV<'a, const NBITS: usize> = VectorMut<'a, NBITS, bits::Unsigned, f32>;
154///
155/// let dim = 3;
156///
/// // Allocate exactly the number of bytes required by the canonical layout.
158/// let bytes = CVRef::<4>::canonical_bytes(dim);
159/// let mut data: Box<[u8]> = (0..bytes).map(|_| u8::default()).collect();
160///
161/// // Construct a mutable compensated vector over the slice.
162/// let mut mut_cv = MutCV::<4>::from_canonical_front_mut(&mut data, dim).unwrap();
163/// mut_cv.set_meta(1.0);
164/// let mut v = mut_cv.vector_mut();
165/// v.set(0, 1).unwrap();
166/// v.set(1, 2).unwrap();
167/// v.set(2, 3).unwrap();
168///
/// // Reconstruct an immutable `VectorRef` over the same bytes.
170/// let cv = CVRef::<4>::from_canonical_front(&data, dim).unwrap();
171/// assert_eq!(cv.meta(), 1.0);
172/// let v = cv.vector();
173/// assert_eq!(v.get(0).unwrap(), 1);
174/// assert_eq!(v.get(1).unwrap(), 2);
175/// assert_eq!(v.get(2).unwrap(), 3);
176/// ```
#[derive(Debug, Clone, Copy)]
pub struct VectorBase<const NBITS: usize, Repr, Ptr, T, Perm = Dense>
where
    Ptr: AsPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
{
    /// The underlying bit-packed vector codes.
    bits: BitSliceBase<NBITS, Repr, Ptr, Perm>,
    /// The attached metadata. The ownership model (e.g. `Owned`, `Ref`, or `Mut`)
    /// is selected by the `T` type parameter.
    meta: T,
}
187
impl<const NBITS: usize, Repr, Ptr, T, Perm> VectorBase<NBITS, Repr, Ptr, T, Perm>
where
    Ptr: AsPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
{
    /// Return the number of bytes required for the underlying `BitSlice` holding
    /// `count` elements.
    pub fn slice_bytes(count: usize) -> usize {
        BitSliceBase::<NBITS, Repr, Ptr, Perm>::bytes_for(count)
    }

    /// Return the number of bytes required for the canonical representation of a
    /// `Vector`: the bit-slice bytes plus `std::mem::size_of::<T::Target>()` bytes
    /// for the metadata.
    ///
    /// See: [`VectorRef::from_canonical_back`], [`VectorMut::from_canonical_back_mut`].
    pub fn canonical_bytes(count: usize) -> usize
    where
        T: CopyRef,
        T::Target: bytemuck::Pod,
    {
        Self::slice_bytes(count) + std::mem::size_of::<T::Target>()
    }

    /// Construct a new `VectorBase` over the bit-slice with the given metadata.
    pub fn new<M>(bits: BitSliceBase<NBITS, Repr, Ptr, Perm>, meta: M) -> Self
    where
        M: Into<T>,
    {
        Self {
            bits,
            meta: meta.into(),
        }
    }

    /// Return the number of dimensions in the vector.
    pub fn len(&self) -> usize {
        self.bits.len()
    }

    /// Return whether or not the vector is empty.
    pub fn is_empty(&self) -> bool {
        self.bits.is_empty()
    }

    /// Return a copy of the metadata value for this vector.
    pub fn meta(&self) -> T::Target
    where
        T: CopyRef,
    {
        self.meta.copy_ref()
    }

    /// Borrow the integer compressed vector.
    pub fn vector(&self) -> BitSlice<'_, NBITS, Repr, Perm> {
        self.bits.reborrow()
    }

    /// Mutably borrow the integer compressed vector.
    pub fn vector_mut(&mut self) -> MutBitSlice<'_, NBITS, Repr, Perm>
    where
        Ptr: AsMutPtr,
    {
        self.bits.reborrow_mut()
    }

    /// Set the metadata component to `value`.
    ///
    /// In addition to a mutable reference, this also requires `Ptr: AsMutPtr` to prevent
    /// accidental misuse where the `VectorBase` is mutable but the underlying
    /// `BitSlice` is not.
    pub fn set_meta(&mut self, value: T::Target)
    where
        Ptr: AsMutPtr,
        T: CopyMut,
    {
        self.meta.copy_mut(value)
    }
}
266
impl<const NBITS: usize, Repr, Perm, T>
    VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: Default,
{
    /// Create a new owned `VectorBase` (backed by the [`GlobalAllocator`]) capable of
    /// holding `len` elements, with its metadata default initialized.
    pub fn new_boxed(len: usize) -> Self {
        Self {
            bits: BitSliceBase::new_boxed(len),
            meta: Owned::default(),
        }
    }
}
282
impl<const NBITS: usize, Repr, Perm, T, A> VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: Default,
    A: AllocatorCore,
{
    /// Create a new owned `VectorBase` in the provided allocator, with its metadata
    /// default initialized.
    ///
    /// # Errors
    ///
    /// Returns an error if the allocator fails to provide the backing storage.
    pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
        Ok(Self {
            bits: BitSliceBase::new_in(len, allocator)?,
            meta: Owned::default(),
        })
    }
}
298
/// A borrowed `Vector`.
///
/// See: [`VectorBase`].
pub type VectorRef<'a, const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, SlicePtr<'a, u8>, Ref<'a, T>, Perm>;

/// A mutably borrowed `Vector`.
///
/// See: [`VectorBase`].
pub type VectorMut<'a, const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, MutSlicePtr<'a, u8>, Mut<'a, T>, Perm>;

/// An owning `VectorBase` backed by the [`GlobalAllocator`].
///
/// See: [`VectorBase`].
pub type Vector<const NBITS: usize, Repr, T, Perm = Dense> =
    VectorBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Owned<T>, Perm>;

/// An owning `VectorBase` backed by a custom allocator `A`.
///
/// See: [`VectorBase`].
pub type PolyVector<const NBITS: usize, Repr, T, Perm, A> =
    VectorBase<NBITS, Repr, Poly<[u8], A>, Owned<T>, Perm>;
322
// Reborrow: convert any `VectorBase` into an immutable `VectorRef` that aliases
// the same underlying bytes and metadata.
impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> Reborrow<'this>
    for VectorBase<NBITS, Repr, Ptr, T, Perm>
where
    Ptr: AsPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: CopyRef + Reborrow<'this, Target = Ref<'this, <T as CopyRef>::Target>>,
{
    type Target = VectorRef<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;

    fn reborrow(&'this self) -> Self::Target {
        // Reborrow both components; the resulting `VectorRef` views the same storage.
        Self::Target {
            bits: self.bits.reborrow(),
            meta: self.meta.reborrow(),
        }
    }
}
341
// ReborrowMut: convert a mutable `VectorBase` into a `VectorMut` that aliases
// the same underlying bytes and metadata.
impl<'this, const NBITS: usize, Repr, Ptr, T, Perm> ReborrowMut<'this>
    for VectorBase<NBITS, Repr, Ptr, T, Perm>
where
    Ptr: AsMutPtr<Type = u8>,
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: CopyMut + ReborrowMut<'this, Target = Mut<'this, <T as CopyRef>::Target>>,
{
    type Target = VectorMut<'this, NBITS, Repr, <T as CopyRef>::Target, Perm>;

    fn reborrow_mut(&'this mut self) -> Self::Target {
        // Mutably reborrow both components; writes through the result are visible
        // in `self`.
        Self::Target {
            bits: self.bits.reborrow_mut(),
            meta: self.meta.reborrow_mut(),
        }
    }
}
360
361//////////////////////
362// Canonical Layout //
363//////////////////////
364
/// Errors returned when a raw byte slice cannot be viewed as the canonical
/// vector layout.
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
    /// The provided slice length (second value) did not match the expected
    /// canonical length (first value).
    #[error("expected a slice length of {0} bytes but instead got {1} bytes")]
    WrongLength(usize, usize),
}
370
impl<'a, const NBITS: usize, Repr, T, Perm> VectorRef<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The front canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<T>()` for the metadata coefficient.
    /// * `Self::slice_bytes(dim)` for the underlying bit-slice.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_front(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        if data.len() != expected {
            Err(NotCanonical::WrongLength(expected, data.len()))
        } else {
            // SAFETY: We have checked that `data` has exactly the canonical length.
            Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
        }
    }

    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The back canonical layout is as follows:
    ///
    /// * `Self::slice_bytes(dim)` for the underlying bit-slice.
    /// * `std::mem::size_of::<T>()` for the metadata coefficient.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_back(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        if data.len() != expected {
            Err(NotCanonical::WrongLength(expected, data.len()))
        } else {
            // SAFETY: We have checked that `data` has exactly the canonical length.
            Ok(unsafe { Self::from_canonical_back_unchecked(data, dim) })
        }
    }

    /// Construct a `VectorRef` from the raw data in the front canonical layout
    /// (metadata first, then codes).
    ///
    /// # Safety
    ///
    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
    ///
    /// This invariant is checked in debug builds and will panic if not satisfied.
    pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));

        // SAFETY: `BitSlice` has no alignment requirements, but the length precondition
        // for this function (i.e., `data.len() == Self::canonical_bytes(dim)`) implies
        // that `Self::slice_bytes(dim)` is valid beginning at an offset of
        // `std::mem::size_of::<T>()`.
        let bits =
            unsafe { BitSlice::new_unchecked(data.get_unchecked(std::mem::size_of::<T>()..), dim) };

        // SAFETY: The pointer is valid and non-null because `data` is a slice, its length
        // must be at least `std::mem::size_of::<T>()` (from the length precondition for
        // this function). `cast_mut` only satisfies `NonNull`'s `*mut` parameter.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }

    /// Construct a `VectorRef` from the raw data in the back canonical layout
    /// (codes first, then metadata).
    ///
    /// # Safety
    ///
    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
    ///
    /// This invariant is checked in debug builds and will panic if not satisfied.
    pub unsafe fn from_canonical_back_unchecked(data: &'a [u8], dim: usize) -> Self {
        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
        // SAFETY: The caller asserts that
        // `data.len() == Self::canonical_bytes(dim) >= std::mem::size_of::<T>()`.
        let (data, meta) =
            unsafe { data.split_at_unchecked(data.len() - std::mem::size_of::<T>()) };

        // SAFETY: `BitSlice` has no alignment requirements, and after the split above
        // the `data` prefix holds exactly `Self::slice_bytes(dim)` bytes (from the
        // length precondition `data.len() == Self::canonical_bytes(dim)`).
        let bits = unsafe { BitSlice::new_unchecked(data, dim) };

        // SAFETY: The pointer is valid and non-null because `meta` is a subslice of
        // exactly `std::mem::size_of::<T>()` bytes produced by the split above.
        // `cast_mut` only satisfies `NonNull`'s `*mut` parameter.
        let meta =
            unsafe { Ref::new(NonNull::new_unchecked(meta.as_ptr().cast_mut()).cast::<T>()) };
        Self { bits, meta }
    }
}
464
impl<'a, const NBITS: usize, Repr, T, Perm> VectorMut<'a, NBITS, Repr, T, Perm>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    T: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The front canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<T>()` for the metadata coefficient.
    /// * `Self::slice_bytes(dim)` for the underlying bit-slice.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_front_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected = Self::canonical_bytes(dim);
        let bytes = data.len();
        // Peel off the leading metadata bytes; failure means the slice is too short.
        let (front, back) = match data.split_at_mut_checked(std::mem::size_of::<T>()) {
            Some(v) => v,
            None => {
                return Err(NotCanonical::WrongLength(expected, bytes));
            }
        };

        // `MutBitSlice::new` validates that `back` is exactly `Self::slice_bytes(dim)`
        // bytes, which together with the split above enforces the canonical length.
        let bits =
            MutBitSlice::new(back, dim).map_err(|_| NotCanonical::WrongLength(expected, bytes))?;

        // SAFETY: `split_at_mut_checked` was successful, so `front` points to a valid
        // slice of exactly `std::mem::size_of::<T>()` bytes and its pointer is non-null.
        // NOTE(review): no alignment check is performed on `front`; the unit tests in
        // this module place vectors at arbitrary byte offsets, so `Mut` is presumed to
        // tolerate unaligned `T` pointers — confirm against `Mut::new`'s contract.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(front.as_mut_ptr()).cast::<T>()) };
        Ok(Self { bits, meta })
    }

    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The back canonical layout is as follows:
    ///
    /// * `Self::slice_bytes(dim)` for the underlying bit-slice.
    /// * `std::mem::size_of::<T>()` for the metadata coefficient.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_back_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let len = data.len();
        // Lazily computed so the expected length is only evaluated on the error paths.
        let expected = || Self::canonical_bytes(dim);
        // Peel off the leading code bytes; failure means the slice is too short.
        let (front, back) = match data.split_at_mut_checked(Self::slice_bytes(dim)) {
            Some(v) => v,
            None => {
                return Err(NotCanonical::WrongLength(expected(), len));
            }
        };

        // The trailing portion must hold exactly the metadata — this rejects
        // over-length slices.
        if back.len() != std::mem::size_of::<T>() {
            return Err(NotCanonical::WrongLength(expected(), len));
        }

        // SAFETY: Since `split_at_mut_checked` was successful, we know that the underlying
        // slice is the correct size.
        let bits = unsafe { MutBitSlice::new_unchecked(front, dim) };

        // SAFETY: `split_at_mut_checked` was successful and `back` was checked for length,
        // so `back` points to a valid slice of `std::mem::size_of::<T>()` bytes.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(back.as_mut_ptr()).cast::<T>()) };
        Ok(Self { bits, meta })
    }
}
531
532///////////
533// Tests //
534///////////
535
536#[cfg(test)]
537mod tests {
538    use diskann_utils::{Reborrow, ReborrowMut};
539    use rand::{
540        distr::{Distribution, StandardUniform, Uniform},
541        rngs::StdRng,
542        Rng, SeedableRng,
543    };
544
545    use super::*;
546    use crate::bits::{BoxedBitSlice, Representation, Unsigned};
547
548    ////////////////////////
549    // Compensated Vector //
550    ////////////////////////
551
    /// Simple two-field POD metadata type used to exercise the canonical layout.
    #[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Metadata {
        a: u32,
        b: u32,
    }

    impl Metadata {
        /// Convenience constructor.
        fn new(a: u32, b: u32) -> Metadata {
            Self { a, b }
        }
    }
564
565    #[test]
566    fn test_vector() {
567        let len = 20;
568        let mut base = Vector::<7, Unsigned, Metadata>::new_boxed(len);
569        assert_eq!(base.len(), len);
570        assert_eq!(base.meta(), Metadata::default());
571        assert!(!base.is_empty());
572        // Ensure that if we reborrow mutably that changes are visible.
573        {
574            let mut rb = base.reborrow_mut();
575            assert_eq!(rb.len(), len);
576            rb.set_meta(Metadata::new(1, 2));
577            let mut v = rb.vector_mut();
578
579            assert_eq!(v.len(), len);
580            for i in 0..v.len() {
581                v.set(i, i as i64).unwrap();
582            }
583        }
584
585        // Are the changes visible?
586        let expected_metadata = Metadata::new(1, 2);
587        assert_eq!(base.meta(), expected_metadata);
588        assert_eq!(base.len(), len);
589        let v = base.vector();
590        for i in 0..v.len() {
591            assert_eq!(v.get(i).unwrap(), i as i64);
592        }
593
594        // Are the changes still visible if we reborrow?
595        {
596            let rb = base.reborrow();
597            assert_eq!(rb.len(), len);
598            assert_eq!(rb.meta(), expected_metadata);
599            let v = rb.vector();
600            for i in 0..v.len() {
601                assert_eq!(v.get(i).unwrap(), i as i64);
602            }
603        }
604    }
605
606    #[test]
607    fn test_compensated_mut() {
608        let len = 30;
609        let mut v = BoxedBitSlice::<7, Unsigned>::new_boxed(len);
610        let mut m = Metadata::default();
611
612        // borrowed duration
613        let mut vector = VectorMut::new(v.reborrow_mut(), &mut m);
614        assert_eq!(vector.len(), len);
615        vector.set_meta(Metadata::new(200, 5));
616        for i in 0..vector.len() {
617            vector.vector_mut().set(i, i as i64).unwrap();
618        }
619
620        // ensure changes are visible
621        assert_eq!(m.a, 200);
622        assert_eq!(m.b, 5);
623        for i in 0..len {
624            assert_eq!(v.get(i).unwrap(), i as i64);
625        }
626    }
627
628    //////////////////////
629    // Canonicalization //
630    //////////////////////
631
    /// Convenience aliases binding the `Metadata` type used throughout these tests.
    type TestVectorRef<'a, const NBITS: usize> = VectorRef<'a, NBITS, Unsigned, Metadata>;
    type TestVectorMut<'a, const NBITS: usize> = VectorMut<'a, NBITS, Unsigned, Metadata>;
634
    /// Round-trip randomized metadata and codes through both the front and back
    /// canonical layouts at randomized buffer offsets (exercising unaligned
    /// metadata placement), then verify the reconstruction and that wrong-sized
    /// slices are rejected.
    fn check_canonicalization<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
    where
        Unsigned: Representation<NBITS>,
        R: Rng,
    {
        let bytes = TestVectorRef::<NBITS>::canonical_bytes(dim);
        assert_eq!(
            bytes,
            std::mem::size_of::<Metadata>() + BitSlice::<NBITS, Unsigned>::bytes_for(dim)
        );

        // Over-size the buffers so the canonical view can start at a random offset.
        let mut buffer_front = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];
        let mut buffer_back = vec![u8::default(); bytes + std::mem::size_of::<Metadata>() + 1];

        // Expected metadata and vector encoding.
        let mut expected = vec![i64::default(); dim];

        let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();

        for _ in 0..ntrials {
            // Random starting offset within one `Metadata` width — exercises
            // possibly-unaligned metadata placement.
            let offset = Uniform::new(0, std::mem::size_of::<Metadata>())
                .unwrap()
                .sample(rng);
            let a: u32 = StandardUniform.sample(rng);
            let b: u32 = StandardUniform.sample(rng);

            expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
            {
                // Write the metadata and every code through a mutable canonical view.
                let set = |mut cv: TestVectorMut<NBITS>| {
                    cv.set_meta(Metadata::new(a, b));
                    let mut vector = cv.vector_mut();
                    for (i, e) in expected.iter().enumerate() {
                        vector.set(i, *e).unwrap();
                    }
                };

                // Front
                let cv = TestVectorMut::<NBITS>::from_canonical_front_mut(
                    &mut buffer_front[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                set(cv);

                // Back
                let cv = TestVectorMut::<NBITS>::from_canonical_back_mut(
                    &mut buffer_back[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                set(cv);
            }

            // Make sure the reconstruction is valid.
            {
                let check = |cv: TestVectorRef<NBITS>| {
                    assert_eq!(cv.meta(), Metadata::new(a, b));
                    let vector = cv.vector();
                    for (i, e) in expected.iter().enumerate() {
                        assert_eq!(vector.get(i).unwrap(), *e);
                    }
                };

                let cv = TestVectorRef::<NBITS>::from_canonical_front(
                    &buffer_front[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                check(cv);

                let cv = TestVectorRef::<NBITS>::from_canonical_back(
                    &buffer_back[offset..offset + bytes],
                    dim,
                )
                .unwrap();
                check(cv);
            }
        }

        // Check Errors - Mut
        {
            // Too short
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
                &mut buffer_front[..bytes - 1],
                dim,
            )
            .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err =
                TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes - 1], dim)
                    .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Empty
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(&mut [], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorMut::<NBITS>::from_canonical_back_mut(&mut [], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Too long
            let err = TestVectorMut::<NBITS>::from_canonical_front_mut(
                &mut buffer_front[..bytes + 1],
                dim,
            )
            .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err =
                TestVectorMut::<NBITS>::from_canonical_back_mut(&mut buffer_back[..bytes + 1], dim)
                    .unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }

        // Check Errors - Const
        {
            // Too short
            let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes - 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes - 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Empty
            let err = TestVectorRef::<NBITS>::from_canonical_front(&[], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&[], dim).unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            // Too long
            let err = TestVectorRef::<NBITS>::from_canonical_front(&buffer_front[..bytes + 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = TestVectorRef::<NBITS>::from_canonical_back(&buffer_back[..bytes + 1], dim)
                .unwrap_err();
            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }
    }
784
785    fn check_canonicalization_zst<const NBITS: usize, R>(dim: usize, ntrials: usize, rng: &mut R)
786    where
787        Unsigned: Representation<NBITS>,
788        R: Rng,
789    {
790        let bytes = VectorRef::<NBITS, Unsigned, ()>::canonical_bytes(dim);
791        assert_eq!(bytes, BitSlice::<NBITS, Unsigned>::bytes_for(dim));
792
793        let max_offset = 10;
794        let mut buffer_front = vec![u8::default(); bytes + max_offset];
795        let mut buffer_back = vec![u8::default(); bytes + max_offset];
796
797        // Expected metadata and vector encoding.
798        let mut expected = vec![i64::default(); dim];
799
800        let uniform = Uniform::try_from(Unsigned::domain_const::<NBITS>()).unwrap();
801
802        for _ in 0..ntrials {
803            let offset = Uniform::new(0, max_offset).unwrap().sample(rng);
804            expected.iter_mut().for_each(|i| *i = uniform.sample(rng));
805            {
806                let set = |mut cv: VectorMut<NBITS, Unsigned, ()>| {
807                    cv.set_meta(());
808                    let mut vector = cv.vector_mut();
809                    for (i, e) in expected.iter().enumerate() {
810                        vector.set(i, *e).unwrap();
811                    }
812                };
813
814                let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
815                    &mut buffer_front[offset..offset + bytes],
816                    dim,
817                )
818                .unwrap();
819                set(cv);
820
821                let cv = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
822                    &mut buffer_back[offset..offset + bytes],
823                    dim,
824                )
825                .unwrap();
826                set(cv);
827            }
828
829            // Make sure the reconstruction is valid.
830            {
831                let check = |cv: VectorRef<NBITS, Unsigned, ()>| {
832                    let vector = cv.vector();
833                    for (i, e) in expected.iter().enumerate() {
834                        assert_eq!(vector.get(i).unwrap(), *e);
835                    }
836                };
837
838                let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
839                    &buffer_front[offset..offset + bytes],
840                    dim,
841                )
842                .unwrap();
843                check(cv);
844
845                let cv = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
846                    &buffer_back[offset..offset + bytes],
847                    dim,
848                )
849                .unwrap();
850                check(cv);
851            }
852        }
853
854        // Check Errors - Mut
855        {
856            // Too short
857            if dim >= 1 {
858                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
859                    &mut buffer_front[..bytes - 1],
860                    dim,
861                )
862                .unwrap_err();
863                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
864
865                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
866                    &mut buffer_back[..bytes - 1],
867                    dim,
868                )
869                .unwrap_err();
870                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
871            }
872
873            // Empty
874            if dim >= 1 {
875                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(&mut [], dim)
876                    .unwrap_err();
877                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
878
879                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(&mut [], dim)
880                    .unwrap_err();
881                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
882            }
883
884            // Too long
885            {
886                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
887                    &mut buffer_front[..bytes + 1],
888                    dim,
889                )
890                .unwrap_err();
891
892                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
893
894                let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
895                    &mut buffer_back[..bytes + 1],
896                    dim,
897                )
898                .unwrap_err();
899
900                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
901            }
902        }
903
904        // Check Errors - Const
905        {
906            // Too short
907            if dim >= 1 {
908                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
909                    &buffer_front[..bytes - 1],
910                    dim,
911                )
912                .unwrap_err();
913
914                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
915
916                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
917                    &buffer_back[..bytes - 1],
918                    dim,
919                )
920                .unwrap_err();
921
922                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
923            }
924
925            // Too long
926            let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_front_mut(
927                &mut buffer_front[..bytes + 1],
928                dim,
929            )
930            .unwrap_err();
931
932            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
933
934            let err = VectorMut::<NBITS, Unsigned, ()>::from_canonical_back_mut(
935                &mut buffer_back[..bytes + 1],
936                dim,
937            )
938            .unwrap_err();
939
940            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
941        }
942
943        // Check Errors - Const
944        {
945            // Too short
946            if dim >= 1 {
947                let err =
948                    VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(&[], dim).unwrap_err();
949
950                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
951
952                let err =
953                    VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(&[], dim).unwrap_err();
954
955                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
956            }
957
958            // Too long
959            {
960                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_front(
961                    &buffer_front[..bytes + 1],
962                    dim,
963                )
964                .unwrap_err();
965
966                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
967
968                let err = VectorRef::<NBITS, Unsigned, ()>::from_canonical_back(
969                    &buffer_back[..bytes + 1],
970                    dim,
971                )
972                .unwrap_err();
973
974                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
975            }
976        }
977    }
978
979    cfg_if::cfg_if! {
980        if #[cfg(miri)] {
981            // The max dim does not need to be as high for `CompensatedVectors` because they
982            // defer their distance function implementation to `BitSlice`, which is more
983            // heavily tested.
984            const MAX_DIM: usize = 37;
985            const TRIALS_PER_DIM: usize = 1;
986        } else {
987            const MAX_DIM: usize = 256;
988            const TRIALS_PER_DIM: usize = 20;
989        }
990    }
991
992    macro_rules! test_canonical {
993        ($name:ident, $nbits:literal, $seed:literal) => {
994            #[test]
995            fn $name() {
996                let mut rng = StdRng::seed_from_u64($seed);
997                for dim in 0..MAX_DIM {
998                    check_canonicalization::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
999                    check_canonicalization_zst::<$nbits, _>(dim, TRIALS_PER_DIM, &mut rng);
1000                }
1001            }
1002        };
1003    }
1004
    // One test per supported bit-width (1 through 8). Seeds are arbitrary but fixed
    // so each test's random sweep is reproducible across runs.
    test_canonical!(canonical_8bit, 8, 0xe64518a00ee99e2f);
    test_canonical!(canonical_7bit, 7, 0x3907123f8c38def2);
    test_canonical!(canonical_6bit, 6, 0xeccaeb83965ff6a1);
    test_canonical!(canonical_5bit, 5, 0x9691fe59e49bfb96);
    test_canonical!(canonical_4bit, 4, 0xc4d3e9bc699a7e6f);
    test_canonical!(canonical_3bit, 3, 0x8a01b2ccdca8fb2b);
    test_canonical!(canonical_2bit, 2, 0x3a07429e8184b67f);
    test_canonical!(canonical_1bit, 1, 0x93fddb26059c115c);
1013}