// diskann_quantization/meta/slice.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7 ops::{Deref, DerefMut},
8 ptr::NonNull,
9};
10
11use thiserror::Error;
12
13use crate::{
14 alloc::{AllocatorCore, AllocatorError, Poly},
15 num::PowerOfTwo,
16 ownership::{Mut, Owned, Ref},
17};
18
19/// A wrapper for a traditional Rust slice that provides the addition of arbitrary metadata.
20///
21/// # Examples
22///
23/// The `Slice` has several named variants that should be used instead of `Slice` directly:
24/// * [`PolySlice`]: An owning, independently allocated `Slice`.
25/// * [`SliceMut`]: A mutable, reference-like type.
26/// * [`SliceRef`]: A const, reference-like type.
27///
28/// ```
29/// use diskann_quantization::{
30/// alloc::GlobalAllocator,
31/// meta::slice,
32/// bits::Unsigned,
33/// };
34///
35/// use diskann_utils::{Reborrow, ReborrowMut};
36///
37/// #[derive(Debug, Default, Clone, Copy, PartialEq)]
38/// struct Metadata {
39/// value: f32,
40/// }
41///
/// // Create a new heap-allocated slice capable of holding 3 elements.
/// //
/// // In this case, the associated metadata and element types are inferred from
/// // how `v` is used below.
46/// let mut v = slice::PolySlice::new_in(3, GlobalAllocator).unwrap();
47///
48/// // We can inspect the underlying bitslice.
49/// let data = v.vector();
50/// assert_eq!(&data, &[0, 0, 0]);
51/// assert_eq!(*v.meta(), Metadata::default(), "expected default metadata value");
52///
53/// // If we want, we can mutably borrow the bitslice and mutate its components.
54/// let mut data = v.vector_mut();
55/// assert_eq!(data.len(), 3);
56/// data[0] = 1;
57/// data[1] = 2;
58/// data[2] = 3;
59///
/// // Changes to the metadata will be visible in the original allocation.
61/// *v.meta_mut() = Metadata { value: 10.5 };
62///
63/// // Check that the changes are visible.
64/// assert_eq!(v.meta().value, 10.5);
65/// assert_eq!(&v.vector(), &[1, 2, 3]);
66/// ```
67///
68/// ## Constructing a `SliceMut` From Components
69///
70/// The following example shows how to assemble a `SliceMut` from raw parts.
71/// ```
72/// use diskann_quantization::meta::slice;
73///
74/// // For exposition purposes, we will use a slice of `u8` and `f32` as the metadata.
75/// let mut data = vec![0u8; 4];
76/// let mut metadata: f32 = 0.0;
77/// {
78/// let mut v = slice::SliceMut::new(data.as_mut_slice(), &mut metadata);
79///
/// // Through `v`, we can set all the components of the slice and the metadata.
81/// *v.meta_mut() = 123.4;
82/// let mut data = v.vector_mut();
83/// data[0] = 1;
84/// data[1] = 2;
85/// data[2] = 3;
86/// data[3] = 4;
87/// }
88///
89/// // Now we can check that the changes made internally are visible.
90/// assert_eq!(&data, &[1, 2, 3, 4]);
91/// assert_eq!(metadata, 123.4);
92/// ```
93///
94/// ## Canonical Layout
95///
96/// When the slice element type `T` and metadata type `M` are both
97/// [`bytemuck::Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html), [`SliceRef`]
98/// and [`SliceMut`] support layout canonicalization, where a raw slice can be used as the
99/// backing store for such vectors, enabling inline storage.
100///
101/// The layout is specified by:
102///
103/// * A base alignment of the maximum alignments of `T` and `M`.
104/// * The first `M` bytes contain the metadata.
105/// * Padding if necessary to reach the alignment of `T`.
106/// * The values of type `T` stored contiguously.
107///
108/// The canonical layout needs the following properties:
109///
/// * `T: bytemuck::Pod` and `M: bytemuck::Pod`: For safely storing and retrieving.
111/// * The length for a vector with `N` dimensions must be equal to the value returned
112/// from [`SliceRef::canonical_bytes`].
113/// * The **alignment** of the base pointer must be equal to [`SliceRef::canonical_align()`].
114///
115/// The following functions can be used to construct slices from raw slices:
116///
117/// * [`SliceRef::from_canonical`]
118/// * [`SliceMut::from_canonical_mut`]
119///
120/// An example is shown below.
121/// ```rust
122/// use diskann_quantization::{
123/// alloc::{AlignedAllocator, Poly},
124/// meta::slice,
125/// num::PowerOfTwo,
126/// };
127///
128/// let dim = 3;
129///
130/// // Since we don't control the alignment of the returned pointer, we need to oversize it.
131/// let bytes = slice::SliceRef::<u16, f32>::canonical_bytes(dim);
132/// let align = slice::SliceRef::<u16, f32>::canonical_align();
133/// let mut data = Poly::broadcast(
134/// 0u8,
135/// bytes,
136/// AlignedAllocator::new(align)
137/// ).unwrap();
138///
/// // Construct a mutable metadata-carrying slice over the byte buffer.
140/// let mut v = slice::SliceMut::<u16, f32>::from_canonical_mut(&mut data, dim).unwrap();
141/// *v.meta_mut() = 1.0;
142/// v.vector_mut().copy_from_slice(&[1, 2, 3]);
143///
/// // Reconstruct a constant `SliceRef` over the same bytes.
145/// let cv = slice::SliceRef::<u16, f32>::from_canonical(&data, dim).unwrap();
146/// assert_eq!(*cv.meta(), 1.0);
147/// assert_eq!(&cv.vector(), &[1, 2, 3]);
148/// ```
#[derive(Debug, Clone, Copy)]
pub struct Slice<T, M> {
    // The slice-like data component (e.g. `&[U]`, `&mut [U]`, or `Poly<[U], A>`).
    slice: T,
    // The metadata component, wrapped in an ownership marker (`Ref`, `Mut`, or `Owned`).
    meta: M,
}
154
155// Use the maximum alignment of `T` and `M` to ensure that no runtime padding is needed.
156//
157// For example, if `T` had a stricter alignment than `M` and we required an alignment of
158// `M`, then the number of padding bytes necessary would depend on the runtime alignment
159// of `M`, which is pretty useless for a storage format.
160const fn canonical_align<T, M>() -> PowerOfTwo {
161 let m_align = PowerOfTwo::alignment_of::<M>();
162 let t_align = PowerOfTwo::alignment_of::<T>();
163
164 // Poor man's `const`-compatible `max`.
165 if m_align.raw() > t_align.raw() {
166 m_align
167 } else {
168 t_align
169 }
170}
171
// The number of bytes occupied by the metadata prefix: the size of `M` rounded up so
// that the slice data following it is aligned for `T`.
//
// A zero-sized `M` contributes zero bytes. This is sound because the canonical base
// alignment is at least the alignment of `T`, so no leading padding is required.
const fn canonical_metadata_bytes<T, M>() -> usize {
    match std::mem::size_of::<M>() {
        0 => 0,
        m_size => m_size.next_multiple_of(std::mem::align_of::<T>()),
    }
}
185
186// A simple computation consisting of the bytes for the metadata, followed by the bytes
187// needed for the slice itself.
188const fn canonical_bytes<T, M>(count: usize) -> usize {
189 canonical_metadata_bytes::<T, M>() + std::mem::size_of::<T>() * count
190}
191
192impl<T, M> Slice<T, M> {
193 /// Construct a new `Slice` over the components.
194 pub fn new<U>(slice: T, meta: U) -> Self
195 where
196 U: Into<M>,
197 {
198 Self {
199 slice,
200 meta: meta.into(),
201 }
202 }
203
204 /// Return the metadata value for this vector.
205 pub fn meta(&self) -> &M::Target
206 where
207 M: Deref,
208 {
209 &self.meta
210 }
211
212 /// Get a mutable reference to the metadata component.
213 pub fn meta_mut(&mut self) -> &mut M::Target
214 where
215 M: DerefMut,
216 {
217 &mut self.meta
218 }
219}
220
221impl<T, M, U, V> Slice<T, M>
222where
223 T: Deref<Target = [U]>,
224 M: Deref<Target = V>,
225{
226 /// Return the number of dimensions of in the slice
227 pub fn len(&self) -> usize {
228 self.slice.len()
229 }
230
231 /// Return whether or not the vector is empty.
232 pub fn is_empty(&self) -> bool {
233 self.slice.is_empty()
234 }
235
236 /// Borrow the data slice.
237 pub fn vector(&self) -> &[U] {
238 &self.slice
239 }
240
241 /// Borrow the integer compressed vector.
242 pub fn vector_mut(&mut self) -> &mut [U]
243 where
244 T: DerefMut,
245 {
246 &mut self.slice
247 }
248
249 /// Return the necessary alignment for the base pointer required for
250 /// [`SliceRef::from_canonical`] and [`SliceMut::from_canonical_mut`].
251 ///
252 /// The return value is guaranteed to be a power of two.
253 pub const fn canonical_align() -> PowerOfTwo {
254 canonical_align::<U, V>()
255 }
256
257 /// Return the number of bytes required to store `count` elements plus metadata in a
258 /// canonical layout.
259 ///
260 /// See: [`SliceRef::from_canonical`], [`SliceMut::from_canonical_mut`].
261 pub const fn canonical_bytes(count: usize) -> usize {
262 canonical_bytes::<U, V>(count)
263 }
264}
265
266impl<T, A, M> Slice<Poly<[T], A>, Owned<M>>
267where
268 A: AllocatorCore,
269 T: Default,
270 M: Default,
271{
272 /// Create a new owned `VectorBase` with its metadata default initialized.
273 pub fn new_in(len: usize, allocator: A) -> Result<Self, AllocatorError> {
274 Ok(Self {
275 slice: Poly::from_iter((0..len).map(|_| T::default()), allocator)?,
276 meta: Owned::default(),
277 })
278 }
279}
280
/// A reference to a slice and associated metadata.
///
/// This is the `&`-like, read-only variant of [`Slice`].
pub type SliceRef<'a, T, M> = Slice<&'a [T], Ref<'a, M>>;

/// A mutable reference to a slice and associated metadata.
///
/// This is the `&mut`-like variant of [`Slice`].
pub type SliceMut<'a, T, M> = Slice<&'a mut [T], Mut<'a, M>>;

/// An owning slice and associated metadata, allocated from allocator `A`.
pub type PolySlice<T, M, A> = Slice<Poly<[T], A>, Owned<M>>;
289
290//////////////////////
291// Canonical Layout //
292//////////////////////
293
/// Errors returned when a raw byte slice cannot be viewed through the canonical layout.
///
/// See [`SliceRef::from_canonical`] and [`SliceMut::from_canonical_mut`].
#[derive(Debug, Error, PartialEq, Clone, Copy)]
pub enum NotCanonical {
    /// The byte slice's length did not equal the canonical size: `(expected, actual)`.
    #[error("expected a slice length of {0} bytes but instead got {1} bytes")]
    WrongLength(usize, usize),
    /// The base pointer was not aligned to the required canonical alignment `{0}`.
    #[error("expected a base pointer alignment of at least {0}")]
    NotAligned(usize),
}
301
302impl<'a, T, M> SliceRef<'a, T, M>
303where
304 T: bytemuck::Pod,
305 M: bytemuck::Pod,
306{
307 /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
308 /// The canonical layout is as follows:
309 ///
310 /// * `std::mem::size_of::<T>().max(std::mem::size_of::<M>())` for the metadata.
311 /// * Necessary additional padding to achieve the alignment requirements for `T`.
312 /// * `std::mem::size_of::<T>() * dim` for the slice.
313 ///
314 /// Returns an error if:
315 ///
316 /// * `data` is not aligned to `Self::canonical_align()`.
317 /// * `data.len() != `Self::canonical_bytes(dim)`.
318 pub fn from_canonical(data: &'a [u8], dim: usize) -> Result<Self, NotCanonical> {
319 let expected_align = Self::canonical_align().raw();
320 let expected_len = Self::canonical_bytes(dim);
321
322 if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
323 Err(NotCanonical::NotAligned(expected_align))
324 } else if data.len() != expected_len {
325 Err(NotCanonical::WrongLength(expected_len, data.len()))
326 } else {
327 // SAFETY: We have checked both the length and alignment of `data`.
328 Ok(unsafe { Self::from_canonical_unchecked(data, dim) })
329 }
330 }
331
332 /// Construct a `VectorRef` from the raw data.
333 ///
334 /// # Safety
335 ///
336 /// * `data.as_ptr()` must be aligned to `Self::canonical_align()`.
337 /// * `data.len()` must be equal to `Self::canonical_bytes(dim)`.
338 ///
339 /// This invariant is checked in debug builds and will panic if not satisfied.
340 pub unsafe fn from_canonical_unchecked(data: &'a [u8], dim: usize) -> Self {
341 debug_assert_eq!(data.len(), Self::canonical_bytes(dim));
342 let offset = canonical_metadata_bytes::<T, M>();
343
344 // SAFETY: The length pre-condition of this function implies that the offset region
345 // `[offset, offset + size_of::<T>() * dim]` is valid for reading.
346 //
347 // Additionally, the alignment requirment of the base pointer ensures that after
348 // applying `offset`, we still have proper alignment for `T`.
349 //
350 // The `bytemuck::Pod` bound ensures we don't have malformed types after the type cast.
351 let slice =
352 unsafe { std::slice::from_raw_parts(data.as_ptr().add(offset).cast::<T>(), dim) };
353
354 // SAFETY: The pointer is valid and non-null because `data` is a slice, its length
355 // must be at least `std::mem::size_of::<M>()` (from the length precondition for
356 // this function).
357 //
358 // The alignemnt pre-condition ensures that the pointer is suitable aligned.
359 //
360 // THe `bytemuck::Pod` bound ensures that the resulting type is valid.
361 let meta =
362 unsafe { Ref::new(NonNull::new_unchecked(data.as_ptr().cast_mut()).cast::<M>()) };
363 Self { slice, meta }
364 }
365}
366
impl<'a, T, M> SliceMut<'a, T, M>
where
    T: bytemuck::Pod,
    M: bytemuck::Pod,
{
    /// Construct an instance of `Self` viewing `data` as the canonical layout for a vector.
    /// The canonical layout is as follows:
    ///
    /// * `std::mem::size_of::<M>()` bytes for the metadata.
    /// * Necessary additional padding to achieve the alignment requirements for `T`.
    /// * `std::mem::size_of::<T>() * dim` for the slice.
    ///
    /// Returns an error if:
    ///
    /// * `data` is not aligned to `Self::canonical_align()`.
    /// * `data.len() != Self::canonical_bytes(dim)`.
    pub fn from_canonical_mut(data: &'a mut [u8], dim: usize) -> Result<Self, NotCanonical> {
        let expected_align = Self::canonical_align().raw();
        let expected_len = Self::canonical_bytes(dim);

        if !(data.as_ptr() as usize).is_multiple_of(expected_align) {
            return Err(NotCanonical::NotAligned(expected_align));
        } else if data.len() != expected_len {
            return Err(NotCanonical::WrongLength(expected_len, data.len()));
        }

        let offset = canonical_metadata_bytes::<T, M>();

        // SAFETY: `offset <= expected_len` (equality occurs when `dim == 0` or `T` is a
        // ZST) and `data.len() == expected_len`, so `offset` is a valid split point for
        // `data` (`split_at_mut_unchecked` requires `mid <= len`).
        let (meta, slice) = unsafe { data.split_at_mut_unchecked(offset) };

        // SAFETY: `data.as_ptr()` when offset by `offset` will have an alignment suitable
        // for type `T`.
        //
        // We have checked that `data.len() == expected_len`, which implies that the region
        // of memory between `offset` and `data.len()` covers exactly `size_of::<T>() * dim`
        // bytes.
        //
        // The `bytemuck::Pod` requirement on `T` ensures the resulting values are valid.
        let slice = unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), dim) };

        // SAFETY: `data.as_ptr()` has an alignment of at least that required by `M`.
        //
        // Since `data` is a slice, its base pointer is `NonNull`.
        //
        // The `bytemuck::Pod` requirement ensures we have a valid instance.
        let meta = unsafe { Mut::new(NonNull::new_unchecked(meta.as_mut_ptr()).cast::<M>()) };

        Ok(Self { slice, meta })
    }
}
419
420///////////
421// Tests //
422///////////
423
424#[cfg(test)]
425mod tests {
426 use std::fmt::Debug;
427
428 use rand::{
429 distr::{Distribution, Uniform},
430 rngs::StdRng,
431 SeedableRng,
432 };
433
434 use super::*;
435 use crate::{
436 alloc::{AlignedAllocator, GlobalAllocator},
437 num::PowerOfTwo,
438 };
439
440 ////////////////////////
441 // Compensated Vector //
442 ////////////////////////
443
444 #[derive(Default, Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
445 #[repr(C)]
446 struct Metadata {
447 a: u32,
448 b: u32,
449 }
450
451 impl Metadata {
452 fn new(a: u32, b: u32) -> Metadata {
453 Self { a, b }
454 }
455 }
456
457 #[test]
458 fn test_vector() {
459 let len = 20;
460 let mut base = PolySlice::<f32, Metadata, _>::new_in(len, GlobalAllocator).unwrap();
461
462 assert_eq!(base.len(), len);
463 assert_eq!(*base.meta(), Metadata::default());
464 assert!(!base.is_empty());
465
466 // Ensure that if we reborrow mutably that changes are visible.
467 {
468 *base.meta_mut() = Metadata::new(1, 2);
469 let v = base.vector_mut();
470
471 assert_eq!(v.len(), len);
472 v.iter_mut().enumerate().for_each(|(i, v)| *v = i as f32);
473 }
474
475 // Are the changes visible?
476 {
477 let expected_metadata = Metadata::new(1, 2);
478 assert_eq!(*base.meta(), expected_metadata);
479 assert_eq!(base.len(), len);
480 let v = base.vector();
481 v.iter().enumerate().for_each(|(i, v)| {
482 assert_eq!(*v, i as f32);
483 })
484 }
485 }
486
487 //////////////////////
488 // Canonicalization //
489 //////////////////////
490
    // A test zero-sized type with non-strict alignment (align 1, size 0).
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C)]
    struct Zst;

    // `TryFrom<usize>` lets `check_canonicalization` fill vectors generically; for a
    // ZST the conversion trivially succeeds and carries no information.
    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for Zst {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }

    // A test zero-sized type with a strict alignment (align 16, size 0), used to
    // exercise canonical-layout padding driven purely by alignment.
    #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Zeroable, bytemuck::Pod)]
    #[repr(C, align(16))]
    struct ZstAligned;

    #[expect(clippy::infallible_try_from)]
    impl TryFrom<usize> for ZstAligned {
        type Error = std::convert::Infallible;
        fn try_from(_: usize) -> Result<Self, Self::Error> {
            Ok(Self)
        }
    }
516
    // Exhaustively exercise the canonical layout for an element type `T` and metadata
    // type `M` at dimension `dim`:
    //
    // * `align`, `slope`, `offset` encode the expected layout contract:
    //   `canonical_align() == align` and `canonical_bytes(dim) == slope * dim + offset`.
    // * `ntrials` random round-trips are written via `SliceMut` and read back via
    //   `SliceRef`.
    // * Every too-short/too-long length and every misaligned base offset is checked to
    //   produce the matching `NotCanonical` error.
    fn check_canonicalization<T, M>(
        dim: usize,
        align: usize,
        slope: usize,
        offset: usize,
        ntrials: usize,
        rng: &mut StdRng,
    ) where
        T: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
        M: bytemuck::Pod + TryFrom<usize, Error: Debug> + Debug + PartialEq,
    {
        let bytes = SliceRef::<T, M>::canonical_bytes(dim);

        assert_eq!(
            bytes,
            slope * dim + offset,
            "computed bytes did not match the expected formula"
        );

        let expected_align = std::mem::align_of::<T>().max(std::mem::align_of::<M>());
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), align);
        assert_eq!(SliceRef::<T, M>::canonical_align().raw(), expected_align);

        // Over-allocate by `expected_align` so the alignment-error loop below can slice
        // at misaligned offsets while staying in bounds.
        let mut buffer = Poly::broadcast(
            0u8,
            bytes + expected_align,
            AlignedAllocator::new(PowerOfTwo::new(expected_align).unwrap()),
        )
        .unwrap();

        // Expected metadata and vector encoding.
        let mut expected = vec![usize::default(); dim];
        let dist = Uniform::new(0, 255).unwrap();

        for _ in 0..ntrials {
            // Write random values (as `usize`) through a `SliceMut` view...
            let m: usize = dist.sample(rng);
            expected.iter_mut().for_each(|i| *i = dist.sample(rng));
            {
                let mut v =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes], dim).unwrap();
                *v.meta_mut() = m.try_into().unwrap();

                assert_eq!(v.vector().len(), dim);
                assert_eq!(v.vector_mut().len(), dim);
                std::iter::zip(v.vector_mut().iter_mut(), expected.iter_mut()).for_each(
                    |(v, e)| {
                        *v = (*e).try_into().unwrap();
                    },
                );
            }

            // ...then make sure the reconstruction through a `SliceRef` is valid.
            {
                let v = SliceRef::<T, M>::from_canonical(&buffer[..bytes], dim).unwrap();
                assert_eq!(*v.meta(), m.try_into().unwrap());

                assert_eq!(v.vector().len(), dim);
                std::iter::zip(v.vector().iter(), expected.iter()).for_each(|(v, e)| {
                    assert_eq!(*v, (*e).try_into().unwrap());
                });
            }
        }

        // Length Errors
        {
            for len in 0..bytes {
                // Too short
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));

                // Too short
                let err = SliceRef::<T, M>::from_canonical(&buffer[..len], dim).unwrap_err();
                assert!(matches!(err, NotCanonical::WrongLength(_, _)));
            }

            // Too long
            let err =
                SliceMut::<T, M>::from_canonical_mut(&mut buffer[..bytes + 1], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));

            let err = SliceRef::<T, M>::from_canonical(&buffer[..bytes + 1], dim).unwrap_err();

            assert!(matches!(err, NotCanonical::WrongLength(_, _)));
        }

        // Alignment
        {
            // The buffer base is aligned to `expected_align`, so any offset in
            // `1..expected_align` is guaranteed misaligned. (Empty when align == 1.)
            for offset in 1..expected_align {
                let err =
                    SliceMut::<T, M>::from_canonical_mut(&mut buffer[offset..offset + bytes], dim)
                        .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));

                let err = SliceRef::<T, M>::from_canonical(&buffer[offset..offset + bytes], dim)
                    .unwrap_err();
                assert!(matches!(err, NotCanonical::NotAligned(_)));
            }
        }
    }
618
    // Miri executes far more slowly, so shrink the problem sizes under it.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 10;
            const TRIALS_PER_DIM: usize = 1;
        } else {
            const MAX_DIM: usize = 256;
            const TRIALS_PER_DIM: usize = 20;
        }
    }

    // Generate a `#[test]` running `check_canonicalization` over all `dim` in
    // `0..MAX_DIM` for the given `(M, T)` pair.
    //
    // Note the argument order: the metadata type `$M` comes before the element type
    // `$T`. `$align`, `$slope`, and `$offset` are the expected layout constants
    // (see `check_canonicalization`); `$seed` makes the RNG deterministic.
    macro_rules! test_canonical {
        ($name:ident, $M:ty, $T:ty, $align:literal, $slope:literal, $offset:literal, $seed:literal) => {
            #[test]
            fn $name() {
                let mut rng = StdRng::seed_from_u64($seed);
                for dim in 0..MAX_DIM {
                    check_canonicalization::<$T, $M>(
                        dim,
                        $align,
                        $slope,
                        $offset,
                        TRIALS_PER_DIM,
                        &mut rng,
                    );
                }
            }
        };
    }
647
    // Cases are named `canonical_<M>_<T>`. The three numeric arguments are the expected
    // canonical alignment, the bytes-per-dimension slope, and the constant metadata
    // prefix size (including padding).
    test_canonical!(canonical_u8_u32, u8, u32, 4, 4, 4, 0x60884b7a4ca28f49);
    test_canonical!(canonical_u32_u8, u32, u8, 4, 1, 4, 0x874aa5d8f40ec5ef);
    test_canonical!(canonical_u32_u32, u32, u32, 4, 4, 4, 0x516c550e7be19acc);

    // Zero-sized metadata contributes no prefix; a zero-sized element type contributes
    // no per-dimension bytes.
    test_canonical!(canonical_zst_u32, Zst, u32, 4, 4, 0, 0x908682ebda7c0fb9);
    test_canonical!(canonical_u32_zst, u32, Zst, 4, 0, 4, 0xf223385881819c1c);

    // Alignment-only ZSTs: the strict 16-byte alignment propagates into the canonical
    // alignment and (when `ZstAligned` is the element type) into the prefix padding.
    test_canonical!(
        canonical_zstaligned_u32,
        ZstAligned,
        u32,
        16,
        4,
        0,
        0x1811ee0fd078a173
    );
    test_canonical!(
        canonical_u32_zstaligned,
        u32,
        ZstAligned,
        16,
        0,
        16,
        0x6c9a67b09c0b6c0f
    );
673}