Skip to main content

diskann_quantization/bits/
slice.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{marker::PhantomData, ops::RangeInclusive, ptr::NonNull};
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use super::{
12    length::{Dynamic, Length},
13    packing,
14    ptr::{AsMutPtr, AsPtr, MutSlicePtr, Precursor, SlicePtr},
15};
16use crate::{
17    alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
18    utils,
19};
20
21//////////////////////
22// Retrieval Traits //
23//////////////////////
24
25/// Representation of `NBITS` bit numbers in the associated domain.
26pub trait Representation<const NBITS: usize> {
27    /// The type of the domain accepted by this representation.
28    type Domain: Iterator<Item = i64>;
29
30    /// Encode `value` into the lower order bits of a byte. Returns the encoded value on
31    /// success, or an `EncodingError` if the value is unencodable.
32    fn encode(value: i64) -> Result<u8, EncodingError>;
33
34    /// Encode `value` into the lower order bits of a byte without checking if `value`
35    /// is encodable. This function is not marked as unsafe because in-and-of itself, it
36    /// won't cause memory safety issues.
37    ///
38    /// This may panic in debug mode when `value` is outside of this representation's
39    /// domain.
40    fn encode_unchecked(value: i64) -> u8;
41
42    /// Decode a previously encoded value. The result will be in the range
43    /// `[Self::MIN, Self::MAX]`.
44    ///
45    /// # Panics
46    ///
47    /// May panic in debug builds if `raw` is not a valid pattern emitted by `encode`.
48    fn decode(raw: u8) -> i64;
49
50    /// Check whether or not the argument is in the domain.
51    fn check(value: i64) -> bool;
52
53    /// Return an iterator over the domain of representable values.
54    fn domain() -> Self::Domain;
55}
56
57#[derive(Debug, Error, Clone, Copy)]
58#[error("value {} is not in the encodable range of {}", got, domain)]
59pub struct EncodingError {
60    got: i64,
61    // Question: Why is this a ref-ref??
62    //
63    // Answer: I have a personal vendetta to keep this struct within 16-bytes with a
64    // niche-optimization. A `&'static src` is 16 bytes in and of itself. But a
65    // `&'static &'static str`, now *that's* just 8 bytes.
66    domain: &'static &'static str,
67}
68
69impl EncodingError {
70    fn new(got: i64, domain: &'static &'static str) -> Self {
71        Self { got, domain }
72    }
73}
74
75//////////////
76// Unsigned //
77//////////////
78
/// Representation that stores unsigned integers in bit slices.
///
/// For a bit count of `NBITS`, the `Unsigned` type can store unsigned integers in
/// the range `[0, 2^NBITS - 1]`.
#[derive(Debug, Clone, Copy)]
pub struct Unsigned;

impl Unsigned {
    /// Return the dynamic range of an `Unsigned` encoding for `NBITS`, namely
    /// `0..=2^NBITS - 1`.
    pub const fn domain_const<const NBITS: usize>() -> std::ops::RangeInclusive<i64> {
        0..=2i64.pow(NBITS as u32) - 1
    }

    /// Return a human-readable description of the domain for an `nbits`-bit encoding,
    /// suitable for embedding in an [`EncodingError`].
    ///
    /// # Panics
    ///
    /// Panics if `nbits` is not in `[1, 8]`.
    // The panic is acceptable: within this file the only callers pass literal bit widths
    // in `[1, 8]` (the widths for which `Representation` is implemented for `Unsigned`).
    #[allow(clippy::panic)]
    const fn domain_str(nbits: usize) -> &'static &'static str {
        match nbits {
            8 => &"[0, 255]",
            7 => &"[0, 127]",
            6 => &"[0, 63]",
            5 => &"[0, 31]",
            4 => &"[0, 15]",
            3 => &"[0, 7]",
            2 => &"[0, 3]",
            1 => &"[0, 1]",
            _ => panic!("unimplemented"),
        }
    }
}
107
108macro_rules! repr_unsigned {
109    ($N:literal) => {
110        impl Representation<$N> for Unsigned {
111            type Domain = RangeInclusive<i64>;
112
113            fn encode(value: i64) -> Result<u8, EncodingError> {
114                if !<Self as Representation<$N>>::check(value) {
115                    // Even with the macro gymnastics - we still have to manually inline
116                    // this computation :(
117                    let domain = Self::domain_str($N);
118                    Err(EncodingError::new(value, domain))
119                } else {
120                    Ok(<Self as Representation<$N>>::encode_unchecked(value))
121                }
122            }
123
124            fn encode_unchecked(value: i64) -> u8 {
125                debug_assert!(<Self as Representation<$N>>::check(value));
126                value as u8
127            }
128
129            fn decode(raw: u8) -> i64 {
130                // Feed through the value un-modified.
131                let raw: i64 = raw.into();
132                debug_assert!(<Self as Representation<$N>>::check(raw));
133                raw
134            }
135
136            fn check(value: i64) -> bool {
137                <Self as Representation<$N>>::domain().contains(&value)
138            }
139
140            fn domain() -> Self::Domain {
141                Self::domain_const::<$N>()
142            }
143        }
144    };
145    ($N:literal, $($Ns:literal),+) => {
146        repr_unsigned!($N);
147        $(repr_unsigned!($Ns);)+
148    };
149}
150
151repr_unsigned!(1, 2, 3, 4, 5, 6, 7, 8);
152
153////////////
154// Binary //
155////////////
156
157/// A 1-bit binary quantization mapping `-1` to `0` and `1` to `1`.
158#[derive(Debug, Clone, Copy)]
159pub struct Binary;
160
161impl Representation<1> for Binary {
162    type Domain = std::array::IntoIter<i64, 2>;
163
164    fn encode(value: i64) -> Result<u8, EncodingError> {
165        if !Self::check(value) {
166            const DOMAIN: &str = "{-1, 1}";
167            Err(EncodingError::new(value, &DOMAIN))
168        } else {
169            Ok(Self::encode_unchecked(value))
170        }
171    }
172
173    fn encode_unchecked(value: i64) -> u8 {
174        debug_assert!(Self::check(value));
175        // The use of `clamp` here is a quick way of sending `-1` to `0` and `1` to `1`.
176        value.clamp(0, 1) as u8
177    }
178
179    fn decode(raw: u8) -> i64 {
180        // Raw is either 0 or 1. We want to map it to -1 or 1.
181        // We can do this by multiplying by 2 and subtracting 1.
182        let raw: i64 = raw.into();
183        (raw << 1) - 1
184    }
185
186    fn check(value: i64) -> bool {
187        value == -1 || value == 1
188    }
189
190    /// Return the domain of the encoding.
191    ///
192    /// The domain is the set `{-1, 1}`.
193    fn domain() -> Self::Domain {
194        [-1, 1].into_iter()
195    }
196}
197
198////////////////////////
199// Permutation Traits //
200////////////////////////
201
/// A strategy that enables the dimensions within a `BitSlice` to be permuted (laid out in
/// memory) in an arbitrary way.
///
/// # Safety
///
/// This provides the computation for the number of bytes required to store a given number
/// of `NBITS` bit-packed values. Improper implementation will result in out-of-bounds
/// accesses being made.
///
/// The following must hold:
///
/// For every count `c`, let `b = Self::bytes(c)` be the number of bytes requested for `c`
/// values under this permutation strategy, and let `s` be a slice of bytes with length `b`.
/// Then, for all `i < c`, `Self::pack(s, i, _)` and `Self::unpack(s, i)` must only access
/// `s` in-bounds.
///
/// Unsafe code relies on this property holding, so every implementation must guarantee it.
pub unsafe trait PermutationStrategy<const NBITS: usize> {
    /// Return the number of bytes required to store `count` values of `NBITS` bits each.
    fn bytes(count: usize) -> usize;

    /// Pack the lower `NBITS` bits of `value` into `s` at logical index `i`.
    ///
    /// # Safety
    ///
    /// This is a tricky function to call with several subtle requirements.
    ///
    /// * Let `s` be a slice of length `c` where `c = Self::bytes(b)` for some `b`. Then
    ///   this function is safe to call if `i` is in `[0, b)`.
    unsafe fn pack(s: &mut [u8], i: usize, value: u8);

    /// Unpack the value stored at logical index `i` and return it as the lower `NBITS` bits
    /// in the return value.
    ///
    /// # Safety
    ///
    /// This is a tricky function to call with several subtle requirements.
    ///
    /// * Let `s` be a slice of length `c` where `c = Self::bytes(b)` for some `b`. Then
    ///   this function is safe to call if `i` is in `[0, b)`.
    unsafe fn unpack(s: &[u8], i: usize) -> u8;
}
242
243/// The identity permutation strategy.
244///
245/// All values are densly packed.
246#[derive(Debug, Clone, Copy)]
247pub struct Dense;
248
249impl Dense {
250    fn bytes<const NBITS: usize>(count: usize) -> usize {
251        utils::div_round_up(NBITS * count, 8)
252    }
253}
254
255/// Safety: For all `0 <= i < count`, `NBITS * i <= 8 * ceil((NBITS * count) / 8)`.
256unsafe impl<const NBITS: usize> PermutationStrategy<NBITS> for Dense {
257    fn bytes(count: usize) -> usize {
258        Self::bytes::<NBITS>(count)
259    }
260
261    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
262        let bitaddress = NBITS * i;
263
264        let bytestart = bitaddress / 8;
265        let bytestop = (bitaddress + NBITS - 1) / 8;
266        let bitstart = bitaddress - 8 * bytestart;
267        debug_assert!(bytestop < data.len());
268
269        if bytestart == bytestop {
270            // SAFETY: This is safe for the following:
271            // ```
272            // data.len() >= ceil(NBITS * i / 8)        from `pack`'s safety requirements.
273            //            >= floor(NBITS * i / 8)
274            //            = bytestart
275            //
276            // Since we are only reading one byte - this is in-bounds.
277            // ```
278            let raw = unsafe { data.as_ptr().add(bytestart).read() };
279            let packed = packing::pack_u8::<NBITS>(raw, encoded, bitstart);
280
281            // SAFETY: See previous argument for in-bounds access.
282            // For writing, we are the only writers in this function and we have a mutable
283            // reference to `data`.
284            unsafe { data.as_mut_ptr().add(bytestart).write(packed) };
285        } else {
286            // SAFETY: This is safe for the following reason:
287            // ```
288            // data.len() >= ceil(NBITS * i / 8)        from `pack`'s safety requirements.
289            //            = bytestop
290            //            = bytestart + 1
291            // ```
292            // Therefore, it is safe to read 2-bytes starting at `bytestart`.
293            let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
294            let packed = packing::pack_u16::<NBITS>(raw, encoded, bitstart);
295
296            // SAFETY: See previous argument for in-bounds access.
297            // For writing, we are the only writers in this function and we have a mutable
298            // reference to `data`.
299            unsafe {
300                data.as_mut_ptr()
301                    .add(bytestart)
302                    .cast::<u16>()
303                    .write_unaligned(packed)
304            };
305        }
306    }
307
308    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
309        let bitaddress = NBITS * i;
310
311        let bytestart = bitaddress / 8;
312        let bytestop = (bitaddress + NBITS - 1) / 8;
313        debug_assert!(bytestop < data.len());
314        if bytestart == bytestop {
315            // SAFETY: See the safety argument in `pack` for in-bounds.
316            let raw = unsafe { data.as_ptr().add(bytestart).read() };
317            packing::unpack_u8::<NBITS>(raw, bitaddress - 8 * bytestart)
318        } else {
319            // SAFETY: See the safety argument in `pack` for in-bounds.
320            let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
321            packing::unpack_u16::<NBITS>(raw, bitaddress - 8 * bytestart)
322        }
323    }
324}
325
326/// A layout specialized for performing multi-bit operations with 1-bit scalar quantization.
327///
328/// The layout provided by this struct is as follows. Assume we are compressing `N` bit data.
329/// Then, the store the data in blocks of `64 * N` bits (where 64 comes from the native CPU
330/// word size).
331///
332/// Each block can contain 64 values, stored in `N` 64-bit words. The 0th bit of each value
333/// is stored in word 0, the 1st bit is stored in word 1, etc.
334///
335/// # Partially Filled Blocks
336///
337/// This strategy always requests data in blocks. For partially filled blocks, the lower
338/// bits in the last block will be used.
339#[derive(Debug, Clone, Copy)]
340pub struct BitTranspose;
341
342/// Safety: We ask for bytes in multiples of 32. Furthermore, the accesses to the packed
343/// data in `pack` and `unpack` use checked accesses, so out-of-bounds reads will panic.
344unsafe impl PermutationStrategy<4> for BitTranspose {
345    fn bytes(count: usize) -> usize {
346        32 * utils::div_round_up(count, 64)
347    }
348
349    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
350        // Compute the byte-address of the block containing `i`.
351        let block_start = 32 * (i / 64);
352        // Compute the offset within the block to find the first byte containing `i`.
353        let byte_start = block_start + (i % 64) / 8;
354        // Finally, compute the bit within the byte that we are interested in.
355        let bit = i % 8;
356
357        let mask: u8 = 0x1 << bit;
358        for p in 0..4 {
359            let mut v = data[byte_start + 8 * p];
360            v = (v & !mask) | (((encoded >> p) & 0x1) << bit);
361            data[byte_start + 8 * p] = v;
362        }
363    }
364
365    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
366        // Compute the byte-address of the block containing `i`.
367        let block_start = 32 * (i / 64);
368        // Compute the offset within the block to find the first byte containing `i`.
369        let byte_start = block_start + (i % 64) / 8;
370        // Finally, compute the bit within the byte that we are interested in.
371        let bit = i % 8;
372
373        let mut output: u8 = 0;
374        for p in 0..4 {
375            let v = data[byte_start + 8 * p];
376            output |= ((v >> bit) & 0x1) << p
377        }
378        output
379    }
380}
381
382////////////
383// Errors //
384////////////
385
386#[derive(Debug, Error, Clone, Copy)]
387#[error("input span has length {got} bytes but expected {expected}")]
388pub struct ConstructionError {
389    got: usize,
390    expected: usize,
391}
392
393#[derive(Debug, Error, Clone, Copy)]
394#[error("index {index} exceeds the maximum length of {len}")]
395pub struct IndexOutOfBounds {
396    index: usize,
397    len: usize,
398}
399
400impl IndexOutOfBounds {
401    fn new(index: usize, len: usize) -> Self {
402        Self { index, len }
403    }
404}
405
406#[derive(Debug, Error, Clone, Copy)]
407#[error("error setting index in bitslice")]
408#[non_exhaustive]
409pub enum SetError {
410    IndexError(#[from] IndexOutOfBounds),
411    EncodingError(#[from] EncodingError),
412}
413
414#[derive(Debug, Error, Clone, Copy)]
415#[error("error getting index in bitslice")]
416pub enum GetError {
417    IndexError(#[from] IndexOutOfBounds),
418}
419
420//////////////
421// BitSlice //
422//////////////
423
424/// A generalized representation for packed small bit integer encodings over a contiguous
425/// span of memory.
426///
427/// Think of this as a Rust slice, but supporting integer elements with fewer than 8-bits.
428/// The borrowed representations [`BitSlice`] and [`MutBitSlice`] consist of just a pointer
429/// and a length and are therefore just 16-bytes in size and amenable to the niche
430/// optimization.
431///
432/// # Parameters
433///
434/// * `NBITS`: The number of bits occupied by each entry in the vector.
435///
436/// * `Repr`: The storage representation for each collection of 8-bits. This representation
437///   defines the domain of the encoding (i.e., range of realized values) as well as how
438///   this domain is mapped into `NBITS` bits.
439///
440/// * `Ptr`: The storage type for the contiguous memory. Possible representations are:
441///   - `diskann_quantization::bits::SlicePtr<'_, u8>`: For immutable views.
442///   - `diskann_quantization::bits::MutSlicePtr<'_, u8>`: For mutable views.
443///   - `Box<[u8]>`: For standalone vectors.
444///
445/// * `Perm`: By default, this type uses a dense storage strategy where the least significant
446///   bit of the value at index `i` occurs directly after the most significant bit of
447///   the value at index `i-1`.
448///
449///   Different permutations can be used to enable faster distance computations between
450///   compressed vectors and full-precision vectors by enabling faster SIMD unpacking.
451///
452/// * `Len`: The representation for the length of the vector. This may only be one of the
453///   two families of types:
454///   - `diskann_quantization::bits::Dynamic`: For instances with a run-time length.
455///   - `diskann_quantization::bits::Static<N>`: For instances with a compile-time known length
456///     of `N`.
457///
458/// # Examples
459///
460/// ## Canonical Bit Slice
461///
462/// The canonical `BitSlice` stores unsigned integers of `NBITS` densely in memory.
463/// That is, for a type `BitSliceBase<3, Unsigned, _>`, the layout is as follows:
464/// ```text
465/// |<--LSB-- byte 0 --MSB--->|<--LSB-- byte 1 --MSB--->|
466/// | a0 a1 a2 b0 b1 b2 c0 c1 | c2 d0 d1 d2 e0 e1 e2 f0 |
467/// |<-- A -->|<-- B ->|<--- C -->|<-- D ->|<-- E ->|<- F
468/// ```
469/// An example is shown below:
470///
471/// ```rust
472/// use diskann_quantization::bits::{BoxedBitSlice, Unsigned};
473/// // Create a new boxed bit-slice with capacity for 10 dimensions.
474/// let mut x = BoxedBitSlice::<3, Unsigned>::new_boxed(10);
475/// assert_eq!(x.len(), 10);
476/// // The number of bytes in the canonical representation is computed by
477/// // ceil((len * NBITS) / 8);
478/// assert_eq!(x.bytes(), 4);
479///
480/// // Assign values.
481/// x.set(0, 1).unwrap(); // assign the value 1 to index 0
482/// x.set(1, 5).unwrap(); // assign the value 5 to index 1
483/// assert_eq!(x.get(0).unwrap(), 1); // retrieve the value at index 0
484/// assert_eq!(x.get(1).unwrap(), 5); // retrieve the value at index 1
485///
486/// // Assigning out-of-bounds will result in an error.
487/// let err = x.set(1, 10).unwrap_err();
488/// assert!(matches!(diskann_quantization::bits::SetError::EncodingError, err));
489/// // The old value is left untouched.
490/// assert_eq!(x.get(1).unwrap(), 5);
491///
492/// // `BoxedBitSlice` allows itself to be consumed, returning the underlying storage.
493/// let y = x.into_inner();
494/// assert_eq!(y.len(), BoxedBitSlice::<3, Unsigned>::bytes_for(10));
495/// ```
496///
497/// The above example demonstrates a boxed bit slice - a type that owns its underlying
498/// memory. However, this is not always ergonomic when interfacing with data stores.
499/// For this, the viewing interface can be used.
500/// ```rust
501/// use diskann_quantization::bits::{MutBitSlice, Unsigned};
502///
503/// let mut x: Vec<u8> = vec![0; 4];
504/// let mut slice = MutBitSlice::<3, Unsigned>::new(x.as_mut_slice(), 10).unwrap();
505/// assert_eq!(slice.len(), 10);
506/// assert_eq!(slice.bytes(), 4);
507///
508/// // The slice reference behaves just like boxed slice.
509/// slice.set(0, 5).unwrap();
510/// assert_eq!(slice.get(0).unwrap(), 5);
511///
512/// // Note - if the number of bytes required for the provided dimensions does not match
513/// // the length of the provided span, than slice construction will return an error.
514/// let err = MutBitSlice::<3, Unsigned>::new(x.as_mut_slice(), 11).unwrap_err();
515/// assert_eq!(err.to_string(), "input span has length 4 bytes but expected 5");
516/// ```
517#[derive(Debug, Clone, Copy)]
518pub struct BitSliceBase<const NBITS: usize, Repr, Ptr, Perm = Dense, Len = Dynamic>
519where
520    Repr: Representation<NBITS>,
521    Ptr: AsPtr<Type = u8>,
522    Perm: PermutationStrategy<NBITS>,
523    Len: Length,
524{
525    ptr: Ptr,
526    len: Len,
527    repr: PhantomData<Repr>,
528    packing: PhantomData<Perm>,
529}
530
531impl<const NBITS: usize, Repr, Ptr, Perm, Len> BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
532where
533    Repr: Representation<NBITS>,
534    Ptr: AsPtr<Type = u8>,
535    Perm: PermutationStrategy<NBITS>,
536    Len: Length,
537{
538    /// Check that NBITS is in the interval [1, 8].
539    const _CHECK: () = assert!(NBITS > 0 && NBITS <= 8);
540
541    /// Return the exact number of bytes required to store `count` values.
542    pub fn bytes_for(count: usize) -> usize {
543        Perm::bytes(count)
544    }
545
546    /// Return a new `BitSlice` over the data behind `ptr`.
547    ///
548    /// # Safety
549    ///
550    /// It's the callers responsibility to ensure that all the invariants required for
551    /// `std::slice::from_raw_parts(ptr.as_ptr(), len.value)` hold.
552    unsafe fn new_unchecked_internal(ptr: Ptr, len: Len) -> Self {
553        Self {
554            ptr,
555            len,
556            repr: PhantomData,
557            packing: PhantomData,
558        }
559    }
560
561    /// Construct a new `BitSlice` without checking preconditions.
562    ///
563    /// # Safety
564    ///
565    /// Requires the following to avoid undefined behavior:
566    ///
567    /// * `precursor.precursor_len() == Self::bytes_for(<Count as Into<Len>>::into(count).value())`.
568    ///
569    /// This is checked in debug builds.
570    pub unsafe fn new_unchecked<Pre, Count>(precursor: Pre, count: Count) -> Self
571    where
572        Count: Into<Len>,
573        Pre: Precursor<Ptr>,
574    {
575        let count: Len = count.into();
576        debug_assert_eq!(precursor.precursor_len(), Self::bytes_for(count.value()));
577
578        // SAFETY: Inherited from the caller.
579        unsafe { Self::new_unchecked_internal(precursor.precursor_into(), count) }
580    }
581
582    /// Construct a new `BitSlice` from the `precursor` capable of holding `count` encoded
583    /// elements of size `NBITS.
584    ///
585    /// # Requirements
586    ///
587    /// The number of bytes pointed to by the precursor must be equal to the number of bytes
588    /// required by the layout. That is:
589    ///
590    /// * `precursor.precursor_len() == Self::bytes_for(<Count as Into<Len>>::into(count).value())`.
591    pub fn new<Pre, Count>(precursor: Pre, count: Count) -> Result<Self, ConstructionError>
592    where
593        Count: Into<Len>,
594        Pre: Precursor<Ptr>,
595    {
596        // Allow callers to pass in `usize` as the count when using dynamic
597        let count: Len = count.into();
598
599        // Make sure that the slice has the correct length.
600        if precursor.precursor_len() != Self::bytes_for(count.value()) {
601            Err(ConstructionError {
602                got: precursor.precursor_len(),
603                expected: Self::bytes_for(count.value()),
604            })
605        } else {
606            // SAFETY: We have checked that `precursor` has the correct number of bytes.
607            // The only implementations of `Precursor` are those we defined for slices, so we
608            // don't have to worry about downstream users inserting their own, incorrectly
609            // implemented implementation.
610            Ok(unsafe { Self::new_unchecked(precursor, count) })
611        }
612    }
613
614    /// Return the number of elements contained in the slice.
615    pub fn len(&self) -> usize {
616        self.len.value()
617    }
618
619    /// Return whether or not the slice is empty.
620    pub fn is_empty(&self) -> bool {
621        self.len() == 0
622    }
623
624    /// Return the number of bytes occupied by this slice.
625    pub fn bytes(&self) -> usize {
626        Self::bytes_for(self.len())
627    }
628
629    /// Return the value at logical index `i`.
630    pub fn get(&self, i: usize) -> Result<i64, GetError> {
631        if i >= self.len() {
632            Err(IndexOutOfBounds::new(i, self.len()).into())
633        } else {
634            // SAFETY: We've performed the bounds check.
635            Ok(unsafe { self.get_unchecked(i) })
636        }
637    }
638
639    /// Return the value at logical index `i`.
640    ///
641    /// # Safety
642    ///
643    /// Argument `i` must be in bounds: `0 <= i < self.len()`.
644    pub unsafe fn get_unchecked(&self, i: usize) -> i64 {
645        debug_assert!(i < self.len());
646        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
647
648        // SAFETY: We maintain the invariant that
649        // `self.as_slice().len() == Perm::bytes(self.len())`.
650        //
651        // So, `i < self.len()` implies we uphold the safety requirements of `unpack`.
652        Repr::decode(unsafe { Perm::unpack(self.as_slice(), i) })
653    }
654
655    /// Encode and assign `value` to logical index `i`.
656    pub fn set(&mut self, i: usize, value: i64) -> Result<(), SetError>
657    where
658        Ptr: AsMutPtr<Type = u8>,
659    {
660        if i >= self.len() {
661            return Err(IndexOutOfBounds::new(i, self.len()).into());
662        }
663
664        let encoded = Repr::encode(value)?;
665
666        // SAFETY: We've performed the bounds check.
667        unsafe { self.set_unchecked(i, encoded) }
668        Ok(())
669    }
670
671    /// Assign `value` to logical index `i`.
672    ///
673    /// # Safety
674    ///
675    /// Argument `i` must be in bounds: `0 <= i < self.len()`.
676    pub unsafe fn set_unchecked(&mut self, i: usize, encoded: u8)
677    where
678        Ptr: AsMutPtr<Type = u8>,
679    {
680        debug_assert!(i < self.len());
681        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
682
683        // SAFETY: We maintain the invariant that
684        // `self.as_slice().len() == Perm::bytes(self.len())`.
685        //
686        // So, `i < self.len()` implies we uphold the safety requirements of `unpack`.
687        unsafe { Perm::pack(self.as_mut_slice(), i, encoded) }
688    }
689
690    /// Return the domain of acceptable values.
691    pub fn domain(&self) -> Repr::Domain {
692        Repr::domain()
693    }
694
695    pub(crate) fn as_slice(&self) -> &'_ [u8] {
696        // SAFETY: This class has the invariant that the backing storage must be initialized
697        // and exist in a single allocation containing at least
698        // `[self.ptr.as_ptr(), self.ptr_ptr() + self.bytes())`.
699        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.bytes()) }
700    }
701
702    /// Return a pointer to the beginning of the memory associated with this slice.
703    ///
704    /// # NOTE
705    ///
706    /// The memory span underlying this instances is valid for `self.bytes()`, not
707    /// necessarily `self.len()`.
708    pub fn as_ptr(&self) -> *const u8 {
709        self.ptr.as_ptr()
710    }
711
712    /// This function is very easy to use incorrectly and hence is crate-local.
713    pub(super) fn as_mut_slice(&mut self) -> &'_ mut [u8]
714    where
715        Ptr: AsMutPtr,
716    {
717        // SAFETY: This class has the invariant that the backing storage must be initialized
718        // and exist in a single allocation containing at least
719        // `[self.ptr.as_ptr(), self.ptr_ptr() + self.bytes())`.
720        //
721        // A mutable reference to self with `Ptr: AsMutPtr` attests to the fact that we
722        // have an exclusive borrow over the underlying memory.
723        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut_ptr(), self.bytes()) }
724    }
725
726    /// This function is very easy to use incorrectly and hence is private.
727    fn as_mut_ptr(&mut self) -> *mut u8
728    where
729        Ptr: AsMutPtr,
730    {
731        self.ptr.as_mut_ptr()
732    }
733}
734
735impl<const NBITS: usize, Repr, Perm, Len>
736    BitSliceBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Perm, Len>
737where
738    Repr: Representation<NBITS>,
739    Perm: PermutationStrategy<NBITS>,
740    Len: Length,
741{
742    /// Construct a new owning `BitSlice` capable of holding `Count` logical values.
743    /// The slice is initialized in a valid but undefined state.
744    ///
745    /// # Example
746    ///
747    /// ```
748    /// use diskann_quantization::bits::{BoxedBitSlice, Unsigned};
749    /// let mut x = BoxedBitSlice::<3, Unsigned>::new_boxed(4);
750    /// x.set(0, 0).unwrap();
751    /// x.set(1, 2).unwrap();
752    /// x.set(2, 4).unwrap();
753    /// x.set(3, 6).unwrap();
754    ///
755    /// assert_eq!(x.get(0).unwrap(), 0);
756    /// assert_eq!(x.get(1).unwrap(), 2);
757    /// assert_eq!(x.get(2).unwrap(), 4);
758    /// assert_eq!(x.get(3).unwrap(), 6);
759    /// ```
760    pub fn new_boxed<Count>(count: Count) -> Self
761    where
762        Count: Into<Len>,
763    {
764        let count: Len = count.into();
765        let bytes = Self::bytes_for(count.value());
766        let storage: Box<[u8]> = (0..bytes).map(|_| 0).collect();
767
768        // SAFETY: We've ensured that the backing storage has the correct number of bytes
769        // as required by the count and PermutationStrategy.
770        //
771        // Since this is owned storage, we do not need to worry about capturing lifetimes.
772        unsafe { Self::new_unchecked(Poly::from(storage), count) }
773    }
774}
775
776impl<const NBITS: usize, Repr, Perm, Len, A> BitSliceBase<NBITS, Repr, Poly<[u8], A>, Perm, Len>
777where
778    Repr: Representation<NBITS>,
779    Perm: PermutationStrategy<NBITS>,
780    Len: Length,
781    A: AllocatorCore,
782{
783    /// Construct a new owning `BitSlice` capable of holding `Count` logical values using
784    /// the provided allocator.
785    ///
786    /// The slice is initialized in a valid but undefined state.
787    ///
788    /// # Example
789    ///
790    /// ```
791    /// use diskann_quantization::{
792    ///     alloc::GlobalAllocator,
793    ///     bits::{BoxedBitSlice, Unsigned}
794    /// };
795    /// let mut x = BoxedBitSlice::<3, Unsigned>::new_in(4, GlobalAllocator).unwrap();
796    /// x.set(0, 0).unwrap();
797    /// x.set(1, 2).unwrap();
798    /// x.set(2, 4).unwrap();
799    /// x.set(3, 6).unwrap();
800    ///
801    /// assert_eq!(x.get(0).unwrap(), 0);
802    /// assert_eq!(x.get(1).unwrap(), 2);
803    /// assert_eq!(x.get(2).unwrap(), 4);
804    /// assert_eq!(x.get(3).unwrap(), 6);
805    /// ```
806    pub fn new_in<Count>(count: Count, allocator: A) -> Result<Self, AllocatorError>
807    where
808        Count: Into<Len>,
809    {
810        let count: Len = count.into();
811        let bytes = Self::bytes_for(count.value());
812        let storage = Poly::broadcast(0, bytes, allocator)?;
813
814        // SAFETY: We've ensured that the backing storage has the correct number of bytes
815        // as required by the count and PermutationStrategy.
816        //
817        // Since this is owned storage, we do not need to worry about capturing lifetimes.
818        Ok(unsafe { Self::new_unchecked(storage, count) })
819    }
820
821    /// Consume `self` and return the boxed allocation.
822    pub fn into_inner(self) -> Poly<[u8], A> {
823        self.ptr
824    }
825}
826
827/// The layout for `N`-bit integers that references a raw underlying slice.
828pub type BitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
829    BitSliceBase<N, Repr, SlicePtr<'a, u8>, Perm, Len>;
830
831/// The layout for `N`-bit integers that mutable references a raw underlying slice.
832pub type MutBitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
833    BitSliceBase<N, Repr, MutSlicePtr<'a, u8>, Perm, Len>;
834
835/// The layout for `N`-bit integers that own the underlying slice.
836pub type PolyBitSlice<const N: usize, Repr, A, Perm = Dense, Len = Dynamic> =
837    BitSliceBase<N, Repr, Poly<[u8], A>, Perm, Len>;
838
839/// The layout for `N`-bit integers that own the underlying slice.
840pub type BoxedBitSlice<const N: usize, Repr, Perm = Dense, Len = Dynamic> =
841    PolyBitSlice<N, Repr, GlobalAllocator, Perm, Len>;
842
843///////////////////////////////
844// Special Cased Conversions //
845///////////////////////////////
846
847impl<'a, Ptr> From<&'a BitSliceBase<8, Unsigned, Ptr>> for &'a [u8]
848where
849    Ptr: AsPtr<Type = u8>,
850{
851    fn from(slice: &'a BitSliceBase<8, Unsigned, Ptr>) -> Self {
852        // SAFETY: The original pointer must have been obtained from a slice of the
853        // appropriate length.
854        //
855        // Furthermore, the layout of this type of slice is guaranteed to be identical
856        // to the layout of a `[u8]`.
857        unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) }
858    }
859}
860
861impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> Reborrow<'this>
862    for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
863where
864    Repr: Representation<NBITS>,
865    Ptr: AsPtr<Type = u8>,
866    Perm: PermutationStrategy<NBITS>,
867    Len: Length,
868{
869    type Target = BitSlice<'this, NBITS, Repr, Perm, Len>;
870
871    fn reborrow(&'this self) -> Self::Target {
872        let ptr: *const u8 = self.as_ptr();
873        debug_assert!(!ptr.is_null());
874
875        // Safety: `AsPtr` may never return null pointers.
876        // The `cast_mut()` is safe because `SlicePtr` does not provide a way of retrieving
877        // a mutable pointer.
878        let nonnull = unsafe { NonNull::new_unchecked(ptr.cast_mut()) };
879
880        // Safety: By struct invariant,
881        // `[self.ptr(), self.ptr() + Self::bytes_for(self.len()))` is a valid slice, so
882        // the returned object will also uphold these invariants.
883        //
884        // The returned struct will not outlive `&'this self`, so we've attached the
885        // proper lifetime.
886        let ptr = unsafe { SlicePtr::new_unchecked(nonnull) };
887
888        Self::Target {
889            ptr,
890            len: self.len,
891            repr: PhantomData,
892            packing: PhantomData,
893        }
894    }
895}
896
897impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> ReborrowMut<'this>
898    for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
899where
900    Repr: Representation<NBITS>,
901    Ptr: AsMutPtr<Type = u8>,
902    Perm: PermutationStrategy<NBITS>,
903    Len: Length,
904{
905    type Target = MutBitSlice<'this, NBITS, Repr, Perm, Len>;
906
907    fn reborrow_mut(&'this mut self) -> Self::Target {
908        let ptr: *mut u8 = self.as_mut_ptr();
909        debug_assert!(!ptr.is_null());
910
911        // Safety: `AsMutPtr` may never return null pointers.
912        let nonnull = unsafe { NonNull::new_unchecked(ptr) };
913
914        // Safety: By struct invariant,
915        // `[self.ptr(), self.ptr() + Self::bytes_for(self.len()))` is a valid slice, so
916        // the returned object will also uphold these invariants.
917        //
918        // The returned struct will not outlive `&'this mut self`, so we've attached the
919        // proper lifetime.
920        //
921        // Exclusive ownership is attested by both `AsMutPtr` and the mutable refernce
922        // to self.
923        let ptr = unsafe { MutSlicePtr::new_unchecked(nonnull) };
924
925        Self::Target {
926            ptr,
927            len: self.len,
928            repr: PhantomData,
929            packing: PhantomData,
930        }
931    }
932}
933
934///////////
935// Tests //
936///////////
937
938#[cfg(test)]
939mod tests {
940    use rand::{
941        Rng, SeedableRng,
942        distr::{Distribution, Uniform},
943        rngs::StdRng,
944        seq::{IndexedRandom, SliceRandom},
945    };
946
947    use super::*;
948    use crate::{bits::Static, test_util::AlwaysFails};
949
950    ////////////
951    // Errors //
952    ////////////
953
954    const BOUNDS: &str = "special bounds";
955
956    #[test]
957    fn test_encoding_error() {
958        assert_eq!(std::mem::size_of::<EncodingError>(), 16);
959        assert_eq!(
960            std::mem::size_of::<Option<EncodingError>>(),
961            16,
962            "expected EncodingError to have the niche optimization"
963        );
964        let err = EncodingError::new(7, &BOUNDS);
965        assert_eq!(
966            err.to_string(),
967            "value 7 is not in the encodable range of special bounds"
968        );
969    }
970
971    // Check that a type is `Send` and `Sync`.
972    fn assert_send_and_sync<T: Send + Sync>(_x: &T) {}
973
974    ////////////
975    // Binary //
976    ////////////
977
978    #[test]
979    fn test_binary_repr() {
980        assert_eq!(Binary::encode(-1).unwrap(), 0);
981        assert_eq!(Binary::encode(1).unwrap(), 1);
982        assert_eq!(Binary::decode(0), -1);
983        assert_eq!(Binary::decode(1), 1);
984
985        assert!(Binary::check(-1));
986        assert!(Binary::check(1));
987        assert!(!Binary::check(0));
988        assert!(!Binary::check(-2));
989        assert!(!Binary::check(2));
990
991        let domain: Vec<_> = Binary::domain().collect();
992        assert_eq!(domain, &[-1, 1]);
993    }
994
995    ///////////
996    // Sizes //
997    ///////////
998
999    #[test]
1000    fn test_sizes() {
1001        assert_eq!(std::mem::size_of::<BitSlice<'static, 8, Unsigned>>(), 16);
1002        assert_eq!(std::mem::size_of::<MutBitSlice<'static, 8, Unsigned>>(), 16);
1003
1004        // Ensure the borrowed slices are eligible for niche optimization.
1005        assert_eq!(
1006            std::mem::size_of::<Option<BitSlice<'static, 8, Unsigned>>>(),
1007            16
1008        );
1009        assert_eq!(
1010            std::mem::size_of::<Option<MutBitSlice<'static, 8, Unsigned>>>(),
1011            16
1012        );
1013
1014        assert_eq!(
1015            std::mem::size_of::<BitSlice<'static, 8, Unsigned, Dense, Static<128>>>(),
1016            8
1017        );
1018    }
1019
1020    ///////////////////
1021    // General Tests //
1022    ///////////////////
1023
1024    cfg_if::cfg_if! {
1025        if #[cfg(miri)] {
1026            const MAX_DIM: usize = 160;
1027            const FUZZ_ITERATIONS: usize = 1;
1028        } else if #[cfg(debug_assertions)] {
1029            const MAX_DIM: usize = 128;
1030            const FUZZ_ITERATIONS: usize = 10;
1031        } else {
1032            const MAX_DIM: usize = 256;
1033            const FUZZ_ITERATIONS: usize = 100;
1034        }
1035    }
1036
1037    fn test_send_and_sync<const NBITS: usize, Repr, Perm>()
1038    where
1039        Repr: Representation<NBITS> + Send + Sync,
1040        Perm: PermutationStrategy<NBITS> + Send + Sync,
1041    {
1042        let mut x = BoxedBitSlice::<NBITS, Repr, Perm>::new_boxed(1);
1043        assert_send_and_sync(&x);
1044        assert_send_and_sync(&x.reborrow());
1045        assert_send_and_sync(&x.reborrow_mut());
1046    }
1047
1048    fn test_empty<const NBITS: usize, Repr, Perm>()
1049    where
1050        Repr: Representation<NBITS>,
1051        Perm: PermutationStrategy<NBITS>,
1052    {
1053        let base: &mut [u8] = &mut [];
1054        let mut slice = MutBitSlice::<NBITS, Repr, Perm>::new(base, 0).unwrap();
1055        assert_eq!(slice.len(), 0);
1056        assert!(slice.is_empty());
1057
1058        {
1059            let reborrow = slice.reborrow();
1060            assert_eq!(reborrow.len(), 0);
1061            assert!(reborrow.is_empty());
1062        }
1063
1064        {
1065            let reborrow = slice.reborrow_mut();
1066            assert_eq!(reborrow.len(), 0);
1067            assert!(reborrow.is_empty());
1068        }
1069    }
1070
1071    // times, ensuring that values are preserved.
1072    fn test_construction_errors<const NBITS: usize, Repr, Perm>()
1073    where
1074        Repr: Representation<NBITS>,
1075        Perm: PermutationStrategy<NBITS>,
1076    {
1077        let len: usize = 10;
1078        let bytes = Perm::bytes(len);
1079
1080        // Construction errors for Boxes
1081        let box_big = Poly::broadcast(0u8, bytes + 1, GlobalAllocator).unwrap();
1082        let box_small = Poly::broadcast(0u8, bytes - 1, GlobalAllocator).unwrap();
1083        let box_right = Poly::broadcast(0u8, bytes, GlobalAllocator).unwrap();
1084
1085        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_big, len);
1086        match result {
1087            Err(ConstructionError { got, expected }) => {
1088                assert_eq!(got, bytes + 1);
1089                assert_eq!(expected, bytes);
1090            }
1091            _ => panic!("shouldn't have reached here!"),
1092        };
1093
1094        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_small, len);
1095        match result {
1096            Err(ConstructionError { got, expected }) => {
1097                assert_eq!(got, bytes - 1);
1098                assert_eq!(expected, bytes);
1099            }
1100            _ => panic!("shouldn't have reached here!"),
1101        };
1102
1103        let mut base = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_right, len).unwrap();
1104        let ptr = base.as_ptr();
1105        assert_eq!(base.len(), len);
1106
1107        // Successful mutable reborrow and borrow.
1108        {
1109            // Use reborrow
1110            let borrowed = base.reborrow_mut();
1111            assert_eq!(borrowed.as_ptr(), ptr);
1112            assert_eq!(borrowed.len(), len);
1113
1114            // Go through a slice.
1115            let borrowed = MutBitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1116            assert_eq!(borrowed.as_ptr(), ptr);
1117            assert_eq!(borrowed.len(), len);
1118        }
1119
1120        // Successful mutable borrow.
1121        {
1122            // Try constructing from an oversized slice.
1123            let mut oversized = vec![0; bytes + 1];
1124            let result = MutBitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1125            match result {
1126                Err(ConstructionError { got, expected }) => {
1127                    assert_eq!(got, bytes + 1);
1128                    assert_eq!(expected, bytes);
1129                }
1130                _ => panic!("shouldn't have reached here!"),
1131            };
1132
1133            let mut undersized = vec![0; bytes - 1];
1134            let result = MutBitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1135            match result {
1136                Err(ConstructionError { got, expected }) => {
1137                    assert_eq!(got, bytes - 1);
1138                    assert_eq!(expected, bytes);
1139                }
1140                _ => panic!("shouldn't have reached here!"),
1141            };
1142        }
1143
1144        // Successful const borrow and reborrow.
1145        {
1146            // Use reborrow
1147            let borrowed = base.reborrow();
1148            assert_eq!(borrowed.as_ptr(), ptr);
1149            assert_eq!(borrowed.len(), len);
1150
1151            // Go through a slice.
1152            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_slice(), len).unwrap();
1153            assert_eq!(borrowed.as_ptr(), ptr);
1154            assert_eq!(borrowed.len(), len);
1155
1156            // Go through a mutable slice.
1157            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1158            assert_eq!(borrowed.as_ptr(), ptr);
1159            assert_eq!(borrowed.len(), len);
1160        }
1161
1162        // Successful mutable borrow.
1163        {
1164            // Try constructing from an oversized slice.
1165            let mut oversized = vec![0; bytes + 1];
1166            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1167            match result {
1168                Err(ConstructionError { got, expected }) => {
1169                    assert_eq!(got, bytes + 1);
1170                    assert_eq!(expected, bytes);
1171                }
1172                _ => panic!("shouldn't have reached here!"),
1173            };
1174
1175            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_slice(), len);
1176            match result {
1177                Err(ConstructionError { got, expected }) => {
1178                    assert_eq!(got, bytes + 1);
1179                    assert_eq!(expected, bytes);
1180                }
1181                _ => panic!("shouldn't have reached here!"),
1182            };
1183
1184            // Try constructing from an undersized slice.
1185            let mut undersized = vec![0; bytes - 1];
1186            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1187            match result {
1188                Err(ConstructionError { got, expected }) => {
1189                    assert_eq!(got, bytes - 1);
1190                    assert_eq!(expected, bytes);
1191                }
1192                _ => panic!("shouldn't have reached here!"),
1193            };
1194
1195            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_slice(), len);
1196            match result {
1197                Err(ConstructionError { got, expected }) => {
1198                    assert_eq!(got, bytes - 1);
1199                    assert_eq!(expected, bytes);
1200                }
1201                _ => panic!("shouldn't have reached here!"),
1202            };
1203        }
1204    }
1205
1206    // This series of tests writes to all indices in the vector in random orders multiple
1207    // times, ensuring that values are preserved.
1208    fn run_overwrite_test<const NBITS: usize, Perm, Len, R>(
1209        base: &mut BoxedBitSlice<NBITS, Unsigned, Perm, Len>,
1210        num_iterations: usize,
1211        rng: &mut R,
1212    ) where
1213        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1214        Len: Length,
1215        Perm: PermutationStrategy<NBITS>,
1216        R: Rng,
1217    {
1218        let mut expected: Vec<i64> = vec![0; base.len()];
1219        let mut indices: Vec<usize> = (0..base.len()).collect();
1220        for i in 0..base.len() {
1221            base.set(i, 0).unwrap();
1222        }
1223
1224        for i in 0..base.len() {
1225            assert_eq!(base.get(i).unwrap(), 0, "failed to initialize bit vector");
1226        }
1227
1228        let domain = base.domain();
1229        assert_eq!(domain, 0..=2i64.pow(NBITS as u32) - 1);
1230        let distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
1231
1232        for iter in 0..num_iterations {
1233            // Shuffle insertion order.
1234            indices.shuffle(rng);
1235
1236            // Insert random values.
1237            for &i in indices.iter() {
1238                let value = distribution.sample(rng);
1239                expected[i] = value;
1240                base.set(i, value).unwrap();
1241            }
1242
1243            // Make sure values are preserved.
1244            for (i, &expect) in expected.iter().enumerate() {
1245                let value = base.get(i).unwrap();
1246                assert_eq!(
1247                    value, expect,
1248                    "retrieval failed on iteration {iter} at index {i}"
1249                );
1250            }
1251
1252            // Make sure the reborrowed version matches.
1253            let borrowed = base.reborrow();
1254            for (i, &expect) in expected.iter().enumerate() {
1255                let value = borrowed.get(i).unwrap();
1256                assert_eq!(
1257                    value, expect,
1258                    "reborrow retrieval failed on iteration {iter} at index {i}"
1259                );
1260            }
1261        }
1262    }
1263
1264    fn run_overwrite_binary_test<Perm, Len, R>(
1265        base: &mut BoxedBitSlice<1, Binary, Perm, Len>,
1266        num_iterations: usize,
1267        rng: &mut R,
1268    ) where
1269        Len: Length,
1270        Perm: PermutationStrategy<1>,
1271        R: Rng,
1272    {
1273        let mut expected: Vec<i64> = vec![0; base.len()];
1274        let mut indices: Vec<usize> = (0..base.len()).collect();
1275        for i in 0..base.len() {
1276            base.set(i, -1).unwrap();
1277        }
1278
1279        for i in 0..base.len() {
1280            assert_eq!(base.get(i).unwrap(), -1, "failed to initialize bit vector");
1281        }
1282
1283        let distribution: [i64; 2] = [-1, 1];
1284
1285        for iter in 0..num_iterations {
1286            // Shuffle insertion order.
1287            indices.shuffle(rng);
1288
1289            // Insert random values.
1290            for &i in indices.iter() {
1291                let value = distribution.choose(rng).unwrap();
1292                expected[i] = *value;
1293                base.set(i, *value).unwrap();
1294            }
1295
1296            // Make sure values are preserved.
1297            for (i, &expect) in expected.iter().enumerate() {
1298                let value = base.get(i).unwrap();
1299                assert_eq!(
1300                    value, expect,
1301                    "retrieval failed on iteration {iter} at index {i}"
1302                );
1303            }
1304
1305            // Make sure the reborrowed version matches.
1306            let borrowed = base.reborrow();
1307            for (i, &expect) in expected.iter().enumerate() {
1308                let value = borrowed.get(i).unwrap();
1309                assert_eq!(
1310                    value, expect,
1311                    "reborrow retrieval failed on iteration {iter} at index {i}"
1312                );
1313            }
1314        }
1315    }
1316
1317    //////////////////////
1318    // Unsigned - Dense //
1319    //////////////////////
1320
1321    fn test_unsigned_dense<const NBITS: usize, Len, R>(
1322        len: Len,
1323        minimum: i64,
1324        maximum: i64,
1325        rng: &mut R,
1326    ) where
1327        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1328        Dense: PermutationStrategy<NBITS>,
1329        Len: Length,
1330        R: Rng,
1331    {
1332        test_send_and_sync::<NBITS, Unsigned, Dense>();
1333        test_empty::<NBITS, Unsigned, Dense>();
1334        test_construction_errors::<NBITS, Unsigned, Dense>();
1335        assert_eq!(Unsigned::domain_const::<NBITS>(), Unsigned::domain(),);
1336
1337        match PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, AlwaysFails) {
1338            Ok(_) => {
1339                if len.value() != 0 {
1340                    panic!("zero sized allocations don't require an allocator");
1341                }
1342            }
1343            Err(AllocatorError) => {
1344                if len.value() == 0 {
1345                    panic!("allocation should have failed");
1346                }
1347            }
1348        }
1349
1350        let mut base =
1351            PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, GlobalAllocator).unwrap();
1352        assert_eq!(
1353            base.len(),
1354            len.value(),
1355            "BoxedBitSlice returned the incorrect length"
1356        );
1357
1358        let expected_bytes = BitSlice::<'static, NBITS, Unsigned>::bytes_for(len.value());
1359        assert_eq!(
1360            base.bytes(),
1361            expected_bytes,
1362            "BoxedBitSlice has the incorrect number of bytes"
1363        );
1364
1365        // Check that the minimum and maximum values reported by the struct are correct.
1366        assert_eq!(base.domain(), minimum..=maximum);
1367
1368        if len.value() == 0 {
1369            return;
1370        }
1371
1372        let ptr = base.as_ptr();
1373
1374        // Now that we know the length is non-zero, we can try testing the interface.
1375        // Setting the lowest index should always work.
1376        {
1377            let mut borrowed = base.reborrow_mut();
1378
1379            // Make sure the pointer is preserved.
1380            assert_eq!(
1381                borrowed.as_ptr(),
1382                ptr,
1383                "pointer was not preserved during borrowing!"
1384            );
1385            assert_eq!(
1386                borrowed.len(),
1387                len.value(),
1388                "borrowing did not preserve length!"
1389            );
1390
1391            borrowed.set(0, 0).unwrap();
1392            assert_eq!(borrowed.get(0).unwrap(), 0);
1393
1394            borrowed.set(0, 1).unwrap();
1395            assert_eq!(borrowed.get(0).unwrap(), 1);
1396
1397            borrowed.set(0, 0).unwrap();
1398            assert_eq!(borrowed.get(0).unwrap(), 0);
1399
1400            // Setting to an invalid value should yield an error.
1401            let result = borrowed.set(0, minimum - 1);
1402            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1403
1404            let result = borrowed.set(0, maximum + 1);
1405            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1406
1407            // Make sure an out-of-bounds access is caught.
1408            let result = borrowed.set(borrowed.len(), 0);
1409            assert!(matches!(result, Err(SetError::IndexError { .. })));
1410
1411            // Ensure that getting out-of-bounds is an error.
1412            let result = borrowed.get(borrowed.len());
1413            assert!(matches!(result, Err(GetError::IndexError { .. })));
1414        }
1415
1416        {
1417            // Reconsturct the mutable borrow directly through a slice.
1418            let borrowed =
1419                MutBitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1420
1421            assert_eq!(
1422                borrowed.as_ptr(),
1423                ptr,
1424                "pointer was not preserved during borrowing!"
1425            );
1426            assert_eq!(
1427                borrowed.len(),
1428                len.value(),
1429                "borrowing did not preserve length!"
1430            );
1431        }
1432
1433        {
1434            let borrowed = base.reborrow();
1435
1436            // Make sure the pointer is preserved.
1437            assert_eq!(
1438                borrowed.as_ptr(),
1439                ptr,
1440                "pointer was not preserved during borrowing!"
1441            );
1442
1443            assert_eq!(
1444                borrowed.len(),
1445                len.value(),
1446                "borrowing did not preserve length!"
1447            );
1448
1449            // Ensure that getting out-of-bounds is an error.
1450            let result = borrowed.get(borrowed.len());
1451            assert!(matches!(result, Err(GetError::IndexError { .. })));
1452        }
1453
1454        {
1455            // Reconsturct the mutable borrow directly through a slice.
1456            let borrowed =
1457                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_slice(), len).unwrap();
1458
1459            assert_eq!(
1460                borrowed.as_ptr(),
1461                ptr,
1462                "pointer was not preserved during borrowing!"
1463            );
1464            assert_eq!(
1465                borrowed.len(),
1466                len.value(),
1467                "borrowing did not preserve length!"
1468            );
1469        }
1470
1471        {
1472            // Reconsturct the mutable borrow directly through a slice.
1473            let borrowed =
1474                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1475
1476            assert_eq!(
1477                borrowed.as_ptr(),
1478                ptr,
1479                "pointer was not preserved during borrowing!"
1480            );
1481            assert_eq!(
1482                borrowed.len(),
1483                len.value(),
1484                "borrowing did not preserve length!"
1485            );
1486        }
1487
1488        // Now we begin the testing loop.
1489        run_overwrite_test(&mut base, FUZZ_ITERATIONS, rng);
1490    }
1491
1492    macro_rules! generate_unsigned_test {
1493        ($name:ident, $NBITS:literal, $MIN:literal, $MAX:literal, $SEED:literal) => {
1494            #[test]
1495            fn $name() {
1496                let mut rng = StdRng::seed_from_u64($SEED);
1497                for dim in 0..MAX_DIM {
1498                    test_unsigned_dense::<$NBITS, Dynamic, _>(dim.into(), $MIN, $MAX, &mut rng);
1499                }
1500            }
1501        };
1502    }
1503
1504    generate_unsigned_test!(test_unsigned_8bit, 8, 0, 0xff, 0xc652f2a1018f442b);
1505    generate_unsigned_test!(test_unsigned_7bit, 7, 0, 0x7f, 0xb732e59fec6d6c9c);
1506    generate_unsigned_test!(test_unsigned_6bit, 6, 0, 0x3f, 0x35d9380d0a318f21);
1507    generate_unsigned_test!(test_unsigned_5bit, 5, 0, 0x1f, 0xfb09895183334304);
1508    generate_unsigned_test!(test_unsigned_4bit, 4, 0, 0x0f, 0x38dfcf9e82c33f48);
1509    generate_unsigned_test!(test_unsigned_3bit, 3, 0, 0x07, 0xf9a94c8c749ee26c);
1510    generate_unsigned_test!(test_unsigned_2bit, 2, 0, 0x03, 0xbba03db62cecf4cf);
1511    generate_unsigned_test!(test_unsigned_1bit, 1, 0, 0x01, 0x54ea2a07d7c67f37);
1512
1513    #[test]
1514    fn test_binary_dense() {
1515        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
1516        for len in 0..MAX_DIM {
1517            test_send_and_sync::<1, Binary, Dense>();
1518            test_empty::<1, Binary, Dense>();
1519            test_construction_errors::<1, Binary, Dense>();
1520
1521            // Create a boxed base.
1522            let mut base = BoxedBitSlice::<1, Binary>::new_boxed(len);
1523            assert_eq!(
1524                base.len(),
1525                len,
1526                "BoxedBitSlice returned the incorrect length"
1527            );
1528
1529            assert_eq!(base.bytes(), len.div_ceil(8));
1530
1531            let bytes = BitSlice::<'static, 1, Binary>::bytes_for(len);
1532            assert_eq!(
1533                bytes,
1534                len.div_ceil(8),
1535                "BoxedBitSlice has the incorrect number of bytes"
1536            );
1537
1538            if len == 0 {
1539                continue;
1540            }
1541
1542            // Setting to an invalid value should yield an error.
1543            let result = base.set(0, 0);
1544            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1545
1546            // Make sure an out-of-bounds access is caught.
1547            let result = base.set(base.len(), -1);
1548            assert!(matches!(result, Err(SetError::IndexError { .. })));
1549
1550            // Ensure that getting out-of-bounds is an error.
1551            let result = base.get(base.len());
1552            assert!(matches!(result, Err(GetError::IndexError { .. })));
1553
1554            // Now we begin the testing loop.
1555            run_overwrite_binary_test(&mut base, FUZZ_ITERATIONS, &mut rng);
1556        }
1557    }
1558
1559    #[test]
1560    fn test_4bit_bit_transpose() {
1561        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
1562        for len in 0..MAX_DIM {
1563            test_send_and_sync::<4, Unsigned, BitTranspose>();
1564            test_empty::<4, Unsigned, BitTranspose>();
1565            test_construction_errors::<4, Unsigned, BitTranspose>();
1566
1567            // Create a boxed base.
1568            let mut base = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(len);
1569            assert_eq!(
1570                base.len(),
1571                len,
1572                "BoxedBitSlice returned the incorrect length"
1573            );
1574
1575            assert_eq!(base.bytes(), 32 * len.div_ceil(64));
1576
1577            let bytes = BitSlice::<'static, 4, Unsigned, BitTranspose>::bytes_for(len);
1578            assert_eq!(
1579                bytes,
1580                32 * len.div_ceil(64),
1581                "BoxedBitSlice has the incorrect number of bytes"
1582            );
1583
1584            if len == 0 {
1585                continue;
1586            }
1587
1588            // Setting to an invalid value should yield an error.
1589            let result = base.set(0, -1);
1590            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1591
1592            // Make sure an out-of-bounds access is caught.
1593            let result = base.set(base.len(), -1);
1594            assert!(matches!(result, Err(SetError::IndexError { .. })));
1595
1596            // Ensure that getting out-of-bounds is an error.
1597            let result = base.get(base.len());
1598            assert!(matches!(result, Err(GetError::IndexError { .. })));
1599
1600            // Now we begin the testing loop.
1601            run_overwrite_test(&mut base, FUZZ_ITERATIONS, &mut rng);
1602        }
1603    }
1604}