diskann_quantization/meta/
slice.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    ops::{Deref, DerefMut},
8    ptr::NonNull,
9};
10
11use thiserror::Error;
12
13use crate::{
14    alloc::{AllocatorCore, AllocatorError, Poly},
15    num::PowerOfTwo,
16    ownership::{Mut, Owned, Ref},
17};
18
19/// A wrapper for a traditional Rust slice that provides the addition of arbitrary metadata.
20///
21/// # Examples
22///
23/// The `Slice` has several named variants that should be used instead of `Slice` directly:
24/// * [`PolySlice`]: An owning, independently allocated `Slice`.
25/// * [`SliceMut`]: A mutable, reference-like type.
26/// * [`SliceRef`]: A const, reference-like type.
27///
28/// ```
29/// use diskann_quantization::{
30///     alloc::GlobalAllocator,
31///     meta::slice,
32///     bits::Unsigned,
33/// };
34///
35/// use diskann_utils::{Reborrow, ReborrowMut};
36///
37/// #[derive(Debug, Default, Clone, Copy, PartialEq)]
38/// struct Metadata {
39///     value: f32,
40/// }
41///
/// // Create a new heap-allocated slice capable of holding 3 elements.
/// //
/// // In this case, the associated metadata type is inferred from the later
/// // uses of `v.meta()` as `Metadata`.
46/// let mut v = slice::PolySlice::new_in(3, GlobalAllocator).unwrap();
47///
48/// // We can inspect the underlying bitslice.
49/// let data = v.vector();
50/// assert_eq!(&data, &[0, 0, 0]);
51/// assert_eq!(*v.meta(), Metadata::default(), "expected default metadata value");
52///
53/// // If we want, we can mutably borrow the bitslice and mutate its components.
54/// let mut data = v.vector_mut();
55/// assert_eq!(data.len(), 3);
56/// data[0] = 1;
57/// data[1] = 2;
58/// data[2] = 3;
59///
/// // Setting the metadata through the borrow will be visible in the original allocation.
61/// *v.meta_mut() = Metadata { value: 10.5 };
62///
63/// // Check that the changes are visible.
64/// assert_eq!(v.meta().value, 10.5);
65/// assert_eq!(&v.vector(), &[1, 2, 3]);
66/// ```
67///
68/// ## Constructing a `SliceMut` From Components
69///
70/// The following example shows how to assemble a `SliceMut` from raw parts.
71/// ```
72/// use diskann_quantization::meta::slice;
73///
74/// // For exposition purposes, we will use a slice of `u8` and `f32` as the metadata.
75/// let mut data = vec![0u8; 4];
76/// let mut metadata: f32 = 0.0;
77/// {
78///     let mut v = slice::SliceMut::new(data.as_mut_slice(), &mut metadata);
79///
///     // Through `v`, we can set all the components in `data` and the metadata.
81///     *v.meta_mut() = 123.4;
82///     let mut data = v.vector_mut();
83///     data[0] = 1;
84///     data[1] = 2;
85///     data[2] = 3;
86///     data[3] = 4;
87/// }
88///
89/// // Now we can check that the changes made internally are visible.
90/// assert_eq!(&data, &[1, 2, 3, 4]);
91/// assert_eq!(metadata, 123.4);
92/// ```
93///
94/// ## Canonical Layout
95///
96/// When the slice element type `T` and metadata type `M` are both
97/// [`bytemuck::Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html), [`SliceRef`]
98/// and [`SliceMut`] support layout canonicalization, where a raw slice can be used as the
99/// backing store for such vectors, enabling inline storage.
100///
101/// The layout is specified by:
102///
103/// * A base alignment of the maximum alignments of `T` and `M`.
104/// * The first `M` bytes contain the metadata.
105/// * Padding if necessary to reach the alignment of `T`.
106/// * The values of type `T` stored contiguously.
107///
108/// The canonical layout needs the following properties:
109///
/// * `T: bytemuck::Pod` and `M: bytemuck::Pod`: For safely storing and retrieving.
111/// * The length for a vector with `N` dimensions must be equal to the value returned
112///   from [`SliceRef::canonical_bytes`].
113/// * The **alignment** of the base pointer must be equal to [`SliceRef::canonical_align()`].
114///
115/// The following functions can be used to construct slices from raw slices:
116///
117/// * [`SliceRef::from_canonical`]
118/// * [`SliceMut::from_canonical_mut`]
119///
120/// An example is shown below.
121/// ```rust
122/// use diskann_quantization::{
123///     alloc::{AlignedAllocator, Poly},
124///     meta::slice,
125///     num::PowerOfTwo,
126/// };
127///
128/// let dim = 3;
129///
130/// // Since we don't control the alignment of the returned pointer, we need to oversize it.
131/// let bytes = slice::SliceRef::<u16, f32>::canonical_bytes(dim);
132/// let align = slice::SliceRef::<u16, f32>::canonical_align();
133/// let mut data = Poly::broadcast(
134///     0u8,
135///     bytes,
136///     AlignedAllocator::new(align)
137/// ).unwrap();
138///
139/// // Construct a mutable compensated vector over the slice.
140/// let mut v = slice::SliceMut::<u16, f32>::from_canonical_mut(&mut data, dim).unwrap();
141/// *v.meta_mut() = 1.0;
142/// v.vector_mut().copy_from_slice(&[1, 2, 3]);
143///
/// // Reconstruct a constant `SliceRef`.
145/// let cv = slice::SliceRef::<u16, f32>::from_canonical(&data, dim).unwrap();
146/// assert_eq!(*cv.meta(), 1.0);
147/// assert_eq!(&cv.vector(), &[1, 2, 3]);
148/// ```
#[derive(Debug, Clone, Copy)]
pub struct Slice<T, M> {
    // The slice-like payload. Instantiated as `&[U]` ([`SliceRef`]), `&mut [U]`
    // ([`SliceMut`]), or `Poly<[U], A>` ([`PolySlice`]).
    slice: T,
    // The metadata handle. Instantiated as `Ref<'a, V>`, `Mut<'a, V>`, or `Owned<V>`
    // to match the ownership of `slice`.
    meta: M,
}
154
155// Use the maximum alignment of `T` and `M` to ensure that no runtime padding is needed.
156//
157// For example, if `T` had a stricter alignment than `M` and we required an alignment of
158// `M`, then the number of padding bytes necessary would depend on the runtime alignment
159// of `M`, which is pretty useless for a storage format.
160const fn canonical_align<T, M>() -> PowerOfTwo {
161    let m_align = PowerOfTwo::alignment_of::<M>();
162    let t_align = PowerOfTwo::alignment_of::<T>();
163
164    // Poor man's `const`-compatible `max`.
165    if m_align.raw() > t_align.raw() {
166        m_align
167    } else {
168        t_align
169    }
170}
171
// The number of bytes required for the metadata prefix. This consists of the bytes
// required for `M` plus any padding needed so that the slice data that follows is
// aligned for `T`.
//
// If `M` is a zero-sized type the result is zero with no special casing, since
// `0.next_multiple_of(n) == 0` for any non-zero `n`. This is sound because the base
// alignment is at least the alignment of `T`, so a zero offset is always properly
// aligned for the slice data.
const fn canonical_metadata_bytes<T, M>() -> usize {
    // `align_of` is always non-zero, so `next_multiple_of` cannot panic here.
    std::mem::size_of::<M>().next_multiple_of(std::mem::align_of::<T>())
}
185
186// A simple computation consisting of the bytes for the metadata, followed by the bytes
187// needed for the slice itself.
188const fn canonical_bytes<T, M>(count: usize) -> usize {
189    canonical_metadata_bytes::<T, M>() + std::mem::size_of::<T>() * count
190}
191
192impl<T, M> Slice<T, M> {
193    /// Construct a new `Slice` over the components.
194    pub fn new<U>(slice: T, meta: U) -> Self
195    where
196        U: Into<M>,
197    {
198        Self {
199            slice,
200            meta: meta.into(),
201        }
202    }
203
204    /// Return the metadata value for this vector.
205    pub fn meta(&self) -> &M::Target
206    where
207        M: Deref,
208    {
209        &self.meta
210    }
211
212    /// Get a mutable reference to the metadata component.
213    pub fn meta_mut(&mut self) -> &mut M::Target
214    where
215        M: DerefMut,
216    {
217        &mut self.meta
218    }
219}
220
221impl<T, M, U, V> Slice<T, M>
222where
223    T: Deref<Target = [U]>,
224    M: Deref<Target = V>,
225{
226    /// Return the number of dimensions of in the slice
227    pub fn len(&self) -> usize {
228        self.slice.len()
229    }
230
231    /// Return whether or not the vector is empty.
232    pub fn is_empty(&self) -> bool {
233        self.slice.is_empty()
234    }
235
236    /// Borrow the data slice.
237    pub fn vector(&self) -> &[U] {
238        &self.slice
239    }
240
241    /// Borrow the integer compressed vector.
242    pub fn vector_mut(&mut self) -> &mut [U]
243    where
244        T: DerefMut,
245    {
246        &mut self.slice
247    }
248
249    /// Return the necessary alignment for the base pointer required for
250    /// [`SliceRef::from_canonical`] and [`SliceMut::from_canonical_mut`].
251    ///
252    /// The return value is guaranteed to be a power of two.
253    pub const fn canonical_align() -> PowerOfTwo {
254        canonical_align::<U, V>()
255    }
256
257    /// Return the number of bytes required to store `count` elements plus metadata in a
258    /// canonical layout.
259    ///
260    /// See: [`SliceRef::from_canonical`], [`SliceMut::from_canonical_mut`].
261    pub const fn canonical_bytes(count: usize) -> usize {
262        canonical_bytes::<U, V>(count)
263    }
264}
265
266impl<T, A, M> Slice<Poly<[T], A>, Owned<M>>
267where
268    A: AllocatorCore,
269    T: Default,
270    M: Default,
271{
272    /// Create a new owned `VectorBase` with its metadata default initialized.
273    pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
274        Ok(Self {
275            slice: Poly::from_iter((0..len).map(|_| T::default()), allocator)?,
276            meta: Owned::default(),
277        })
278    }
279}
280
/// A reference to a slice and associated metadata.
///
/// The shared-borrow variant of [`Slice`].
pub type SliceRef<'a, T, M> = Slice<&'a [T], Ref<'a, M>>;

/// A mutable reference to a slice and associated metadata.
///
/// The exclusive-borrow variant of [`Slice`].
pub type SliceMut<'a, T, M> = Slice<&'a mut [T], Mut<'a, M>>;

/// An owning slice and associated metadata.
///
/// The slice storage is independently allocated in the allocator `A`.
pub type PolySlice<T, M, A> = Slice<Poly<[T], A>, Owned<M>>;
289
//////////////////////
// Canonical Layout //
//////////////////////

/// Error returned when a raw byte slice cannot be viewed using the canonical layout.
///
/// See [`SliceRef::from_canonical`] and [`SliceMut::from_canonical_mut`].
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
    /// The byte slice had the wrong length: `(expected, actual)` in bytes.
    #[error("expected a slice length of {0} bytes but instead got {1} bytes")]
    WrongLength(usize, usize),
    /// The base pointer did not have the required (power-of-two) alignment.
    #[error("expected a base pointer alignment of at least {0}")]
    NotAligned(usize),
}
301
302impl<'a, T, M> SliceRef<'a, T, M>
303where
304    T: bytemuck::Pod,
305    M: bytemuck::Pod,
306{
307    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
308    /// The canonical layout is as follows:
309    ///
310    /// * `std::mem::size_of::<T>().max(std::mem::size_of::<M>())` for the metadata.
311    /// * Necessary additional padding to achieve the alignment requirements for `T`.
312    /// * `std::mem::size_of::<T>() * dim` for the slice.
313    ///
314    /// Returns an error if:
315    ///
316    /// * `data` is not aligned to `Self::canonical_align()`.
317    /// * `data.len() != `Self::canonical_bytes(dim)`.
318    pub fn from_canonical(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
319        let expected_align = Self::canonical_align().raw();
320        let expected_len = Self::canonical_bytes(dim);
321
322        if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
323            Err(NotCanonical::NotAligned(expected_align))
324        } else if data.len() != expected_len {
325            Err(NotCanonical::WrongLength(expected_len, data.len()))
326        } else {
327            // SAFETY: We have checked both the length and alignment of `data`.
328            Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
329        }
330    }
331
332    /// Construct a `VectorRef` from the raw data.
333    ///
334    /// # Safety
335    ///
336    /// * `data.as_ptr()` must be aligned to `Self::canonical_align()`.
337    /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
338    ///
339    /// This invariant is checked in debug builds and will panic if not satisfied.
340    pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
341        debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
342        let offset = canonical_metadata_bytes::<T, M>();
343
344        // SAFETY: The length pre-condition of this function implies that the offset region
345        // `[offset, offset + size_of::<T>() * dim]` is valid for reading.
346        //
347        // Additionally, the alignment requirment of the base pointer ensures that after
348        // applying `offset`, we still have proper alignment for `T`.
349        //
350        // The `bytemuck::Pod` bound ensures we don't have malformed types after the type cast.
351        let slice =
352            unsafe { std::slice::from_raw_parts(data.as_ptr().add(offset).cast::<T>(), dim) };
353
354        // SAFETY: The pointer is valid and non-null because `data` is a slice, its length
355        // must be at least `std::mem::size_of::<M>()` (from the length precondition for
356        // this function).
357        //
358        // The alignemnt pre-condition ensures that the pointer is suitable aligned.
359        //
360        // THe `bytemuck::Pod` bound ensures that the resulting type is valid.
361        let meta =
362            unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<M>()) };
363        Self { slice, meta }
364    }
365}
366
impl<'a, T, M> SliceMut<'a, T, M>
where
    T: bytemuck::Pod,
    M: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<M>()` bytes for the metadata.
    /// * Necessary additional padding to achieve the alignment requirements for `T`.
    /// * `std::mem::size_of::<T>() * dim` for the slice.
    ///
    /// Returns an error if:
    ///
    /// * `data` is not aligned to `Self::canonical_align()`.
    /// * `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected_align = Self::canonical_align().raw();
        let expected_len = Self::canonical_bytes(dim);

        if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
            return Err(NotCanonical::NotAligned(expected_align));
        } else if data.len() != expected_len {
            return Err(NotCanonical::WrongLength(expected_len, data.len()));
        }

        let offset = canonical_metadata_bytes::<T, M>();

        // SAFETY: `offset <= expected_len` (with equality when `dim == 0`) and
        // `data.len() == expected_len`, so `offset` is a valid split point for `data`.
        let (meta, slice) = unsafe { data.split_at_mut_unchecked(offset) };

        // SAFETY: `data.as_ptr()` when offset by `offset` will have an alignment suitable
        // for type `T`.
        //
        // We have checked that `data.len() == expected_len`, which implies that the region
        // of memory between `offset` and `data.len()` covers exactly `size_of::<T>() * dim`
        // bytes.
        //
        // The `bytemuck::Pod` requirement on `T` ensures the resulting values are valid.
        let slice = unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), dim) };

        // SAFETY: `data.as_ptr()` has an alignment of at least that required by `M`.
        //
        // Since `data` is a slice, its base pointer is `NonNull`.
        //
        // The `bytemuck::Pod` requirement ensures we have a valid instance.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(meta.as_mut_ptr()).cast::<M>()) };

        Ok(Self { slice, meta })
    }
}
419
420///////////
421// Tests //
422///////////
423
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use rand::{
        distr::{Distribution, Uniform},
        rngs::StdRng,
        SeedableRng,
    };

    use super::*;
    use crate::{
        alloc::{AlignedAllocator, GlobalAllocator},
        num::PowerOfTwo,
    };

    ///////////////
    // PolySlice //
    ///////////////

    // Example metadata payload for the owned-slice tests below.
    #[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Metadata {
        a: u32,
        b: u32,
    }

    impl Metadata {
        fn new(a: u32, b: u32) -> Metadata {
            Self { a, b }
        }
    }

    // Exercises the owned `PolySlice`: construction with default-initialized contents,
    // then mutation through `meta_mut`/`vector_mut` and verification that the changes
    // are visible through the const accessors.
    #[test]
    fn test_vector() {
        let len = 20;
        let mut base = PolySlice::<f32, Metadata, _>::new_in(len, GlobalAllocator).unwrap();

        assert_eq!(base.len(), len);
        assert_eq!(*base.meta(), Metadata::default());
        assert!(!base.is_empty());

        // Ensure that if we reborrow mutably that changes are visible.
        {
            *base.meta_mut() = Metadata::new(1, 2);
            let v = base.vector_mut();

            assert_eq!(v.len(), len);
            v.iter_mut().enumerate().for_each(|(i, v)| *v = i as f32);
        }

        // Are the changes visible?
        {
            let expected_metadata = Metadata::new(1, 2);
            assert_eq!(*base.meta(), expected_metadata);
            assert_eq!(base.len(), len);
            let v = base.vector();
            v.iter().enumerate().for_each(|(i, v)| {
                assert_eq!(*v, i as f32);
            })
        }
    }

    //////////////////////
    // Canonicalization //
    //////////////////////

    // A test zero-sized type with non-strict alignment.
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Zst;

    // `TryFrom<usize>` lets `check_canonicalization` assign values generically; for a
    // ZST the conversion trivially succeeds.
    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for Zst {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }

    // A test zero-sized type with a strict alignment.
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C, align(16))]
    struct ZstAligned;

    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for ZstAligned {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }

    // Round-trip and error-path checks for the canonical layout with element type `T`
    // and metadata type `M`.
    //
    // * `dim`: number of `T` elements.
    // * `align`: the expected value of `canonical_align()`.
    // * `slope`/`offset`: `canonical_bytes(dim)` is expected to equal
    //   `slope * dim + offset`.
    // * `ntrials`: number of randomized write/read round trips.
    fn check_canonicalization<T, M>(
        dim: usize,
        align: usize,
        slope: usize,
        offset: usize,
        ntrials: usize,
        rng: &mut StdRng,
    ) where
        T: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
        M: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
    {
        let bytes = SliceRef::<T, M>::canonical_bytes(dim);

        assert_eq!(
            bytes,
            slope * dim + offset,
            "computed bytes did not match the expected formula"
        );

        let expected_align = std::mem::align_of::<T>().max(std::mem::align_of::<M>());
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), align);
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), expected_align);

        // Over-allocate by `expected_align` bytes so the misalignment tests at the end
        // can slice at interior offsets without running past the allocation, and so the
        // "too long" length test has one spare byte.
        let mut buffer = Poly::broadcast(
            0u8,
            bytes + expected_align,
            AlignedAllocator::new(PowerOfTwo::new(expected_align).unwrap()),
        )
        .unwrap();

        // Expected metadata and vector encoding. Values are drawn from the half-open
        // range [0, 255) so they fit in any of the tested element types.
        let mut expected = vec![usize::default(); dim];
        let dist = Uniform::new(0, 255).unwrap();

        for _ in 0..ntrials {
            let m: usize = dist.sample(rng);
            expected.iter_mut().for_each(|i| *i = dist.sample(rng));
            {
                let mut v =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes], dim).unwrap();
                *v.meta_mut() = m.try_into().unwrap();

                assert_eq!(v.vector().len(), dim);
                assert_eq!(v.vector_mut().len(), dim);
                std::iter::zip(v.vector_mut().iter_mut(), expected.iter_mut()).for_each(
                    |(v, e)| {
                        *v = (*e).try_into().unwrap();
                    },
                );
            }

            // Make sure the reconstruction is valid.
            {
                let v = SliceRef::<T, M>::from_canonical(&buffer[..bytes], dim).unwrap();
                assert_eq!(*v.meta(), m.try_into().unwrap());

                assert_eq!(v.vector().len(), dim);
                std::iter::zip(v.vector().iter(), expected.iter()).for_each(|(v, e)| {
                    assert_eq!(*v, (*e).try_into().unwrap());
                });
            }
        }

        // Length Errors
        {
            for len in 0..bytes {
                // Too short
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));

                // Too short
                let err = SliceRef::<T, M>::from_canonical(&buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
            }

            // Too long
            let err =
                SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes + 1], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = SliceRef::<T, M>::from_canonical(&buffer[..bytes + 1], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }

        // Alignment
        //
        // The buffer base is aligned to `expected_align`, so every interior offset in
        // `1..expected_align` is guaranteed to be misaligned.
        {
            for offset in 1..expected_align {
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[offset..offset + bytes], dim)
                        .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));

                let err = SliceRef::<T, M>::from_canonical(&buffer[offset..offset + bytes], dim)
                    .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));
            }
        }
    }

    // Scale down the search space under Miri, where execution is much slower.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 10;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }

    // Instantiate `check_canonicalization` over all dimensions in `0..MAX_DIM` for a
    // given (metadata, element) type pair and expected (align, slope, offset) triple.
    macro_rules! test_canonical {
        ($name:ident, $M:ty, $T:ty, $align:literal, $slope:literal, $offset:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 0..MAX_DIM {
                    check_canonicalization::<$T, $M>(
                        dim,
                        $align,
                        $slope,
                        $offset,
                        TRIALS_PER_DIM,
                        &mut rng,
                    );
                }
            }
        };
    }

    test_canonical!(canonical_u8_u32, u8, u32, 4, 4, 4, 0x60884b7a4ca28f49);
    test_canonical!(canonical_u32_u8, u32, u8, 4, 1, 4, 0x874aa5d8f40ec5ef);
    test_canonical!(canonical_u32_u32, u32, u32, 4, 4, 4, 0x516c550e7be19acc);

    // Zero-sized metadata contributes no prefix; zero-sized elements contribute no slope.
    test_canonical!(canonical_zst_u32, Zst, u32, 4, 4, 0, 0x908682ebda7c0fb9);
    test_canonical!(canonical_u32_zst, u32, Zst, 4, 0, 4, 0xf223385881819c1c);

    // An over-aligned ZST raises the canonical alignment (and, as metadata, the
    // padded prefix size) to 16 even though it occupies no bytes itself.
    test_canonical!(
        canonical_zstaligned_u32,
        ZstAligned,
        u32,
        16,
        4,
        0,
        0x1811ee0fd078a173
    );
    test_canonical!(
        canonical_u32_zstaligned,
        u32,
        ZstAligned,
        16,
        0,
        16,
        0x6c9a67b09c0b6c0f
    );
}