Skip to main content

diskann_quantization/bits/
slice.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{marker::PhantomData, ops::RangeInclusive, ptr::NonNull};
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use super::{
12    length::{Dynamic, Length},
13    packing,
14    ptr::{AsMutPtr, AsPtr, MutSlicePtr, Precursor, SlicePtr},
15};
16use crate::{
17    alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
18    utils,
19};
20
21//////////////////////
22// Retrieval Traits //
23//////////////////////
24
/// Representation of `NBITS` bit numbers in the associated domain.
///
/// Implementations translate between logical `i64` values and the raw bit pattern that is
/// stored in the lower order bits of a `u8`.
pub trait Representation<const NBITS: usize> {
    /// The type of the domain accepted by this representation.
    type Domain: Iterator<Item = i64>;

    /// Encode `value` into the lower order bits of a byte. Returns the encoded value on
    /// success, or an `EncodingError` if the value is unencodable.
    fn encode(value: i64) -> Result<u8, EncodingError>;

    /// Encode `value` into the lower order bits of a byte without checking if `value`
    /// is encodable. This function is not marked as unsafe because in-and-of itself, it
    /// won't cause memory safety issues.
    ///
    /// This may panic in debug mode when `value` is outside of this representation's
    /// domain.
    fn encode_unchecked(value: i64) -> u8;

    /// Decode a previously encoded value. The result is expected to be a member of the
    /// domain reported by `Self::domain` (this trait declares no `MIN`/`MAX` constants).
    ///
    /// # Panics
    ///
    /// May panic in debug builds if `raw` is not a valid pattern emitted by `encode`.
    fn decode(raw: u8) -> i64;

    /// Check whether or not the argument is in the domain.
    fn check(value: i64) -> bool;

    /// Return an iterator over the domain of representable values.
    fn domain() -> Self::Domain;
}
56
/// Error reported when a value cannot be encoded because it is not a member of a
/// representation's domain.
#[derive(Debug, Error, Clone, Copy)]
#[error("value {} is not in the encodable range of {}", got, domain)]
pub struct EncodingError {
    // The value that was rejected.
    got: i64,
    // Question: Why is this a ref-ref??
    //
    // Answer: I have a personal vendetta to keep this struct within 16-bytes with a
    // niche-optimization. A `&'static str` is 16 bytes in and of itself. But a
    // `&'static &'static str`, now *that's* just 8 bytes.
    domain: &'static &'static str,
}
68
69impl EncodingError {
70    fn new(got: i64, domain: &'static &'static str) -> Self {
71        Self { got, domain }
72    }
73}
74
75//////////////
76// Unsigned //
77//////////////
78
/// Store unsigned integers in slices.
///
/// For a bit count of `NBITS`, the `Unsigned` type can store unsigned integers in
/// the range `[0, 2^NBITS - 1]`.
#[derive(Debug, Clone, Copy)]
pub struct Unsigned;

impl Unsigned {
    /// Return the dynamic range of an `Unsigned` encoding for `NBITS`, i.e. the inclusive
    /// range `0..=2^NBITS - 1`.
    pub const fn domain_const<const NBITS: usize>() -> std::ops::RangeInclusive<i64> {
        let upper = 2i64.pow(NBITS as u32) - 1;
        0..=upper
    }

    /// Return the textual form of the domain for `nbits`, as used in error messages.
    ///
    /// # Panics
    ///
    /// Panics for any `nbits` outside of `1..=8`.
    #[allow(clippy::panic)]
    const fn domain_str(nbits: usize) -> &'static &'static str {
        match nbits {
            1 => &"[0, 1]",
            2 => &"[0, 3]",
            3 => &"[0, 7]",
            4 => &"[0, 15]",
            5 => &"[0, 31]",
            6 => &"[0, 63]",
            7 => &"[0, 127]",
            8 => &"[0, 255]",
            _ => panic!("unimplemented"),
        }
    }
}
107
108macro_rules! repr_unsigned {
109    ($N:literal) => {
110        impl Representation<$N> for Unsigned {
111            type Domain = RangeInclusive<i64>;
112
113            fn encode(value: i64) -> Result<u8, EncodingError> {
114                if !<Self as Representation<$N>>::check(value) {
115                    // Even with the macro gymnastics - we still have to manually inline
116                    // this computation :(
117                    let domain = Self::domain_str($N);
118                    Err(EncodingError::new(value, domain))
119                } else {
120                    Ok(<Self as Representation<$N>>::encode_unchecked(value))
121                }
122            }
123
124            fn encode_unchecked(value: i64) -> u8 {
125                debug_assert!(<Self as Representation<$N>>::check(value));
126                value as u8
127            }
128
129            fn decode(raw: u8) -> i64 {
130                // Feed through the value un-modified.
131                let raw: i64 = raw.into();
132                debug_assert!(<Self as Representation<$N>>::check(raw));
133                raw
134            }
135
136            fn check(value: i64) -> bool {
137                <Self as Representation<$N>>::domain().contains(&value)
138            }
139
140            fn domain() -> Self::Domain {
141                Self::domain_const::<$N>()
142            }
143        }
144    };
145    ($N:literal, $($Ns:literal),+) => {
146        repr_unsigned!($N);
147        $(repr_unsigned!($Ns);)+
148    };
149}
150
151repr_unsigned!(1, 2, 3, 4, 5, 6, 7, 8);
152
153////////////
154// Binary //
155////////////
156
157/// A 1-bit binary quantization mapping `-1` to `0` and `1` to `1`.
158#[derive(Debug, Clone, Copy)]
159pub struct Binary;
160
161impl Representation<1> for Binary {
162    type Domain = std::array::IntoIter<i64, 2>;
163
164    fn encode(value: i64) -> Result<u8, EncodingError> {
165        if !Self::check(value) {
166            const DOMAIN: &str = "{-1, 1}";
167            Err(EncodingError::new(value, &DOMAIN))
168        } else {
169            Ok(Self::encode_unchecked(value))
170        }
171    }
172
173    fn encode_unchecked(value: i64) -> u8 {
174        debug_assert!(Self::check(value));
175        // The use of `clamp` here is a quick way of sending `-1` to `0` and `1` to `1`.
176        value.clamp(0, 1) as u8
177    }
178
179    fn decode(raw: u8) -> i64 {
180        // Raw is either 0 or 1. We want to map it to -1 or 1.
181        // We can do this by multiplying by 2 and subtracting 1.
182        let raw: i64 = raw.into();
183        (raw << 1) - 1
184    }
185
186    fn check(value: i64) -> bool {
187        value == -1 || value == 1
188    }
189
190    /// Return the domain of the encoding.
191    ///
192    /// The domain is the set `{-1, 1}`.
193    fn domain() -> Self::Domain {
194        [-1, 1].into_iter()
195    }
196}
197
198////////////////////////
199// Permutation Traits //
200////////////////////////
201
/// Enables the dimensions within a BitSlice to be permuted in an arbitrary way.
///
/// # Safety
///
/// This provides the computation for the number of bytes required to store a given number
/// of `NBITS` bit-packed values. Improper implementation will result in out-of-bounds
/// accesses being made.
///
/// The following must hold:
///
/// For all count values `c`, let `b = Self::bytes(c)` be the requested number of bytes for
/// `c` for this permutation strategy and `s` be a slice of bytes with length `b`. Then, for
/// all `i < c`, `Self::pack(s, i, _)` and `Self::unpack(s, i)` must only access `s`
/// in-bounds.
///
/// This implementation must be such that unsafe code can rely on this property holding.
pub unsafe trait PermutationStrategy<const NBITS: usize> {
    /// Return the number of bytes required to store `count` values of `NBITS` bits each.
    fn bytes(count: usize) -> usize;

    /// Pack the lower `NBITS` bits of `value` into `s` at logical index `i`.
    ///
    /// # Safety
    ///
    /// This is a tricky function to call with several subtle requirements.
    ///
    /// * Let `s` be a slice of length `c` where `c = Self::bytes(b)` for some `b`. Then
    ///   this function is safe to call if `i` is in `[0, b)`.
    unsafe fn pack(s: &mut [u8], i: usize, value: u8);

    /// Unpack the value stored at logical index `i` and return it as the lower `NBITS` bits
    /// in the return value.
    ///
    /// # Safety
    ///
    /// This is a tricky function to call with several subtle requirements.
    ///
    /// * Let `s` be a slice of length `c` where `c = Self::bytes(b)` for some `b`. Then
    ///   this function is safe to call if `i` is in `[0, b)`.
    unsafe fn unpack(s: &[u8], i: usize) -> u8;
}
242
243/// The identity permutation strategy.
244///
245/// All values are densly packed.
246#[derive(Debug, Clone, Copy)]
247pub struct Dense;
248
249impl Dense {
250    fn bytes<const NBITS: usize>(count: usize) -> usize {
251        utils::div_round_up(NBITS * count, 8)
252    }
253}
254
/// Safety: Let `data.len() == Self::bytes(count) == ceil(NBITS * count / 8)`. For all
/// `0 <= i < count`, the last bit belonging to value `i` has bit address
/// `NBITS * (i + 1) - 1 <= NBITS * count - 1`, which lives in byte
/// `floor((NBITS * (i + 1) - 1) / 8) < ceil(NBITS * count / 8)`. Every byte touched by
/// `pack` and `unpack` is therefore in-bounds.
///
/// The arithmetic below assumes `NBITS >= 1` (otherwise `NBITS - 1` underflows) and the
/// two-byte path assumes `NBITS <= 8` (so a value never spans more than two bytes).
unsafe impl<const NBITS: usize> PermutationStrategy<NBITS> for Dense {
    fn bytes(count: usize) -> usize {
        // Delegates to the inherent `Dense::bytes`.
        Self::bytes::<NBITS>(count)
    }

    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
        // Address of the first bit of value `i`.
        let bitaddress = NBITS * i;

        // First and last byte containing bits of value `i`, and the offset of the first
        // bit within the first byte.
        let bytestart = bitaddress / 8;
        let bytestop = (bitaddress + NBITS - 1) / 8;
        let bitstart = bitaddress - 8 * bytestart;
        debug_assert!(bytestop < data.len());

        if bytestart == bytestop {
            // SAFETY: From `pack`'s safety requirements, `data.len() == Self::bytes(b)`
            // for some `b > i`, so
            // ```
            // bytestart == bytestop
            //           == floor((NBITS * (i + 1) - 1) / 8)
            //           <  ceil(NBITS * b / 8)
            //           == data.len()
            // ```
            // Since we are only reading one byte - this is in-bounds.
            let raw = unsafe { data.as_ptr().add(bytestart).read() };
            let packed = packing::pack_u8::<NBITS>(raw, encoded, bitstart);

            // SAFETY: See previous argument for in-bounds access.
            // For writing, we are the only writers in this function and we have a mutable
            // reference to `data`.
            unsafe { data.as_mut_ptr().add(bytestart).write(packed) };
        } else {
            // SAFETY: By the same argument as above, `bytestop < data.len()`. The value
            // occupies at most 8 bits, so it straddles exactly two bytes here:
            // ```
            // bytestop == bytestart + 1 < data.len()
            // ```
            // Therefore, it is safe to read 2-bytes starting at `bytestart`. The read is
            // unaligned because `bytestart` carries no alignment guarantee.
            //
            // NOTE(review): the `u16` is read in native byte order; presumably
            // `packing::pack_u16` accounts for this - confirm for big-endian targets.
            let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
            let packed = packing::pack_u16::<NBITS>(raw, encoded, bitstart);

            // SAFETY: See previous argument for in-bounds access.
            // For writing, we are the only writers in this function and we have a mutable
            // reference to `data`.
            unsafe {
                data.as_mut_ptr()
                    .add(bytestart)
                    .cast::<u16>()
                    .write_unaligned(packed)
            };
        }
    }

    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
        // Address of the first bit of value `i`.
        let bitaddress = NBITS * i;

        // First and last byte containing bits of value `i`.
        let bytestart = bitaddress / 8;
        let bytestop = (bitaddress + NBITS - 1) / 8;
        debug_assert!(bytestop < data.len());
        if bytestart == bytestop {
            // SAFETY: See the safety argument in `pack` for in-bounds.
            let raw = unsafe { data.as_ptr().add(bytestart).read() };
            packing::unpack_u8::<NBITS>(raw, bitaddress - 8 * bytestart)
        } else {
            // SAFETY: See the safety argument in `pack` for in-bounds.
            let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
            packing::unpack_u16::<NBITS>(raw, bitaddress - 8 * bytestart)
        }
    }
}
325
326/// A layout specialized for performing multi-bit operations with 1-bit scalar quantization.
327///
328/// The layout provided by this struct is as follows. Assume we are compressing `N` bit data.
329/// Then, the store the data in blocks of `64 * N` bits (where 64 comes from the native CPU
330/// word size).
331///
332/// Each block can contain 64 values, stored in `N` 64-bit words. The 0th bit of each value
333/// is stored in word 0, the 1st bit is stored in word 1, etc.
334///
335/// # Partially Filled Blocks
336///
337/// This strategy always requests data in blocks. For partially filled blocks, the lower
338/// bits in the last block will be used.
339#[derive(Debug, Clone, Copy)]
340pub struct BitTranspose;
341
342/// Safety: We ask for bytes in multiples of 32. Furthermore, the accesses to the packed
343/// data in `pack` and `unpack` use checked accesses, so out-of-bounds reads will panic.
344unsafe impl PermutationStrategy<4> for BitTranspose {
345    fn bytes(count: usize) -> usize {
346        32 * utils::div_round_up(count, 64)
347    }
348
349    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
350        // Compute the byte-address of the block containing `i`.
351        let block_start = 32 * (i / 64);
352        // Compute the offset within the block to find the first byte containing `i`.
353        let byte_start = block_start + (i % 64) / 8;
354        // Finally, compute the bit within the byte that we are interested in.
355        let bit = i % 8;
356
357        let mask: u8 = 0x1 << bit;
358        for p in 0..4 {
359            let mut v = data[byte_start + 8 * p];
360            v = (v & !mask) | (((encoded >> p) & 0x1) << bit);
361            data[byte_start + 8 * p] = v;
362        }
363    }
364
365    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
366        // Compute the byte-address of the block containing `i`.
367        let block_start = 32 * (i / 64);
368        // Compute the offset within the block to find the first byte containing `i`.
369        let byte_start = block_start + (i % 64) / 8;
370        // Finally, compute the bit within the byte that we are interested in.
371        let bit = i % 8;
372
373        let mut output: u8 = 0;
374        for p in 0..4 {
375            let v = data[byte_start + 8 * p];
376            output |= ((v >> bit) & 0x1) << p
377        }
378        output
379    }
380}
381
382////////////
383// Errors //
384////////////
385
/// Error returned by `BitSliceBase::new` when the supplied storage does not have exactly
/// the number of bytes required for the requested element count.
#[derive(Debug, Error, Clone, Copy)]
#[error("input span has length {got} bytes but expected {expected}")]
pub struct ConstructionError {
    // Byte length of the storage that was supplied.
    got: usize,
    // Byte length required by the layout for the requested count.
    expected: usize,
}
392
393#[derive(Debug, Error, Clone, Copy)]
394#[error("index {index} exceeds the maximum length of {len}")]
395pub struct IndexOutOfBounds {
396    index: usize,
397    len: usize,
398}
399
400impl IndexOutOfBounds {
401    fn new(index: usize, len: usize) -> Self {
402        Self { index, len }
403    }
404}
405
/// Error returned by `BitSliceBase::set`.
#[derive(Debug, Error, Clone, Copy)]
#[error("error setting index in bitslice")]
#[non_exhaustive]
pub enum SetError {
    // The requested index was not smaller than the slice length.
    IndexError(#[from] IndexOutOfBounds),
    // The value is not a member of the representation's domain.
    EncodingError(#[from] EncodingError),
}
413
/// Error returned by `BitSliceBase::get`.
#[derive(Debug, Error, Clone, Copy)]
#[error("error getting index in bitslice")]
pub enum GetError {
    // The requested index was not smaller than the slice length.
    IndexError(#[from] IndexOutOfBounds),
}
419
420//////////////
421// BitSlice //
422//////////////
423
424/// A generalized representation for packed small bit integer encodings over a contiguous
425/// span of memory.
426///
427/// Think of this as a Rust slice, but supporting integer elements with fewer than 8-bits.
428/// The borrowed representations [`BitSlice`] and [`MutBitSlice`] consist of just a pointer
429/// and a length and are therefore just 16-bytes in size and amenable to the niche
430/// optimization.
431///
432/// # Parameters
433///
434/// * `NBITS`: The number of bits occupied by each entry in the vector.
435///
436/// * `Repr`: The storage representation for each collection of 8-bits. This representation
437///   defines the domain of the encoding (i.e., range of realized values) as well as how
438///   this domain is mapped into `NBITS` bits.
439///
440/// * `Ptr`: The storage type for the contiguous memory. Possible representations are:
441///   - `diskann_quantization::bits::SlicePtr<'_, u8>`: For immutable views.
442///   - `diskann_quantization::bits::MutSlicePtr<'_, u8>`: For mutable views.
///   - `diskann_quantization::alloc::Poly<[u8], A>`: For standalone (owning) vectors.
444///
445/// * `Perm`: By default, this type uses a dense storage strategy where the least significant
446///   bit of the value at index `i` occurs directly after the most significant bit of
447///   the value at index `i-1`.
448///
449///   Different permutations can be used to enable faster distance computations between
450///   compressed vectors and full-precision vectors by enabling faster SIMD unpacking.
451///
452/// * `Len`: The representation for the length of the vector. This may only be one of the
453///   two families of types:
454///   - `diskann_quantization::bits::Dynamic`: For instances with a run-time length.
455///   - `diskann_quantization::bits::Static<N>`: For instances with a compile-time known length
456///     of `N`.
457///
458/// # Examples
459///
460/// ## Canonical Bit Slice
461///
462/// The canonical `BitSlice` stores unsigned integers of `NBITS` densely in memory.
463/// That is, for a type `BitSliceBase<3, Unsigned, _>`, the layout is as follows:
464/// ```text
465/// |<--LSB-- byte 0 --MSB--->|<--LSB-- byte 1 --MSB--->|
466/// | a0 a1 a2 b0 b1 b2 c0 c1 | c2 d0 d1 d2 e0 e1 e2 f0 |
467/// |<-- A -->|<-- B ->|<--- C -->|<-- D ->|<-- E ->|<- F
468/// ```
469/// An example is shown below:
470///
471/// ```rust
472/// use diskann_quantization::bits::{BoxedBitSlice, Unsigned};
473/// // Create a new boxed bit-slice with capacity for 10 dimensions.
474/// let mut x = BoxedBitSlice::<3, Unsigned>::new_boxed(10);
475/// assert_eq!(x.len(), 10);
476/// // The number of bytes in the canonical representation is computed by
477/// // ceil((len * NBITS) / 8);
478/// assert_eq!(x.bytes(), 4);
479///
480/// // Assign values.
481/// x.set(0, 1).unwrap(); // assign the value 1 to index 0
482/// x.set(1, 5).unwrap(); // assign the value 5 to index 1
483/// assert_eq!(x.get(0).unwrap(), 1); // retrieve the value at index 0
484/// assert_eq!(x.get(1).unwrap(), 5); // retrieve the value at index 1
485///
/// // Assigning a value outside of the encodable domain will result in an error.
/// let err = x.set(1, 10).unwrap_err();
/// assert!(matches!(err, diskann_quantization::bits::SetError::EncodingError(_)));
489/// // The old value is left untouched.
490/// assert_eq!(x.get(1).unwrap(), 5);
491///
492/// // `BoxedBitSlice` allows itself to be consumed, returning the underlying storage.
493/// let y = x.into_inner();
494/// assert_eq!(y.len(), BoxedBitSlice::<3, Unsigned>::bytes_for(10));
495/// ```
496///
497/// The above example demonstrates a boxed bit slice - a type that owns its underlying
498/// memory. However, this is not always ergonomic when interfacing with data stores.
499/// For this, the viewing interface can be used.
500/// ```rust
501/// use diskann_quantization::bits::{MutBitSlice, Unsigned};
502///
503/// let mut x: Vec<u8> = vec![0; 4];
504/// let mut slice = MutBitSlice::<3, Unsigned>::new(x.as_mut_slice(), 10).unwrap();
505/// assert_eq!(slice.len(), 10);
506/// assert_eq!(slice.bytes(), 4);
507///
508/// // The slice reference behaves just like boxed slice.
509/// slice.set(0, 5).unwrap();
510/// assert_eq!(slice.get(0).unwrap(), 5);
511///
/// // Note - if the number of bytes required for the provided dimensions does not match
/// // the length of the provided span, then slice construction will return an error.
514/// let err = MutBitSlice::<3, Unsigned>::new(x.as_mut_slice(), 11).unwrap_err();
515/// assert_eq!(err.to_string(), "input span has length 4 bytes but expected 5");
516/// ```
#[derive(Debug, Clone, Copy)]
pub struct BitSliceBase<const NBITS: usize, Repr, Ptr, Perm = Dense, Len = Dynamic>
where
    Repr: Representation<NBITS>,
    Ptr: AsPtr<Type = u8>,
    Perm: PermutationStrategy<NBITS>,
    Len: Length,
{
    // Handle to the first byte of the packed storage.
    ptr: Ptr,
    // Number of logical `NBITS`-bit elements (not the number of bytes).
    len: Len,
    // Zero-sized marker recording the value encoding `Repr`.
    repr: PhantomData<Repr>,
    // Zero-sized marker recording the bit layout `Perm`.
    packing: PhantomData<Perm>,
}
530
531impl<const NBITS: usize, Repr, Ptr, Perm, Len> BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
532where
533    Repr: Representation<NBITS>,
534    Ptr: AsPtr<Type = u8>,
535    Perm: PermutationStrategy<NBITS>,
536    Len: Length,
537{
    /// Check that NBITS is in the interval [1, 8].
    ///
    /// This is a compile-time assertion. As an associated constant, it is only evaluated
    /// for instantiations in which the constant is actually referenced.
    const _CHECK: () = assert!(NBITS > 0 && NBITS <= 8);
540
541    /// Return the exact number of bytes required to store `count` values.
542    pub fn bytes_for(count: usize) -> usize {
543        Perm::bytes(count)
544    }
545
546    /// Return a new `BitSlice` over the data behind `ptr`.
547    ///
548    /// # Safety
549    ///
550    /// It's the callers responsibility to ensure that all the invariants required for
551    /// `std::slice::from_raw_parts(ptr.as_ptr(), len.value)` hold.
552    unsafe fn new_unchecked_internal(ptr: Ptr, len: Len) -> Self {
553        Self {
554            ptr,
555            len,
556            repr: PhantomData,
557            packing: PhantomData,
558        }
559    }
560
561    /// Construct a new `BitSlice` without checking preconditions.
562    ///
563    /// # Safety
564    ///
565    /// Requires the following to avoid undefined behavior:
566    ///
567    /// * `precursor.precursor_len() == Self::bytes_for(<Count as Into<Len>>::into(count).value())`.
568    ///
569    /// This is checked in debug builds.
570    pub unsafe fn new_unchecked<Pre, Count>(precursor: Pre, count: Count) -> Self
571    where
572        Count: Into<Len>,
573        Pre: Precursor<Ptr>,
574    {
575        let count: Len = count.into();
576        debug_assert_eq!(precursor.precursor_len(), Self::bytes_for(count.value()));
577        Self::new_unchecked_internal(precursor.precursor_into(), count)
578    }
579
580    /// Construct a new `BitSlice` from the `precursor` capable of holding `count` encoded
581    /// elements of size `NBITS.
582    ///
583    /// # Requirements
584    ///
585    /// The number of bytes pointed to by the precursor must be equal to the number of bytes
586    /// required by the layout. That is:
587    ///
588    /// * `precursor.precursor_len() == Self::bytes_for(<Count as Into<Len>>::into(count).value())`.
589    pub fn new<Pre, Count>(precursor: Pre, count: Count) -> Result<Self, ConstructionError>
590    where
591        Count: Into<Len>,
592        Pre: Precursor<Ptr>,
593    {
594        // Allow callers to pass in `usize` as the count when using dynamic
595        let count: Len = count.into();
596
597        // Make sure that the slice has the correct length.
598        if precursor.precursor_len() != Self::bytes_for(count.value()) {
599            Err(ConstructionError {
600                got: precursor.precursor_len(),
601                expected: Self::bytes_for(count.value()),
602            })
603        } else {
604            // SAFETY: We have checked that `precursor` has the correct number of bytes.
605            // The only implementations of `Precursor` are those we defined for slices, so we
606            // don't have to worry about downstream users inserting their own, incorrectly
607            // implemented implementation.
608            Ok(unsafe { Self::new_unchecked(precursor, count) })
609        }
610    }
611
    /// Return the number of logical `NBITS`-bit elements contained in the slice.
    ///
    /// This is an element count, not a byte count; see `bytes` for the latter.
    pub fn len(&self) -> usize {
        self.len.value()
    }
616
617    /// Return whether or not the slice is empty.
618    pub fn is_empty(&self) -> bool {
619        self.len() == 0
620    }
621
622    /// Return the number of bytes occupied by this slice.
623    pub fn bytes(&self) -> usize {
624        Self::bytes_for(self.len())
625    }
626
627    /// Return the value at logical index `i`.
628    pub fn get(&self, i: usize) -> Result<i64, GetError> {
629        if i >= self.len() {
630            Err(IndexOutOfBounds::new(i, self.len()).into())
631        } else {
632            // SAFETY: We've performed the bounds check.
633            Ok(unsafe { self.get_unchecked(i) })
634        }
635    }
636
637    /// Return the value at logical index `i`.
638    ///
639    /// # Safety
640    ///
641    /// Argument `i` must be in bounds: `0 <= i < self.len()`.
642    pub unsafe fn get_unchecked(&self, i: usize) -> i64 {
643        debug_assert!(i < self.len());
644        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
645
646        // SAFETY: We maintain the invariant that
647        // `self.as_slice().len() == Perm::bytes(self.len())`.
648        //
649        // So, `i < self.len()` implies we uphold the safety requirements of `unpack`.
650        Repr::decode(unsafe { Perm::unpack(self.as_slice(), i) })
651    }
652
653    /// Encode and assign `value` to logical index `i`.
654    pub fn set(&mut self, i: usize, value: i64) -> Result<(), SetError>
655    where
656        Ptr: AsMutPtr<Type = u8>,
657    {
658        if i >= self.len() {
659            return Err(IndexOutOfBounds::new(i, self.len()).into());
660        }
661
662        let encoded = Repr::encode(value)?;
663
664        // SAFETY: We've performed the bounds check.
665        unsafe { self.set_unchecked(i, encoded) }
666        Ok(())
667    }
668
669    /// Assign `value` to logical index `i`.
670    ///
671    /// # Safety
672    ///
673    /// Argument `i` must be in bounds: `0 <= i < self.len()`.
674    pub unsafe fn set_unchecked(&mut self, i: usize, encoded: u8)
675    where
676        Ptr: AsMutPtr<Type = u8>,
677    {
678        debug_assert!(i < self.len());
679        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
680
681        // SAFETY: We maintain the invariant that
682        // `self.as_slice().len() == Perm::bytes(self.len())`.
683        //
684        // So, `i < self.len()` implies we uphold the safety requirements of `unpack`.
685        unsafe { Perm::pack(self.as_mut_slice(), i, encoded) }
686    }
687
688    /// Return the domain of acceptable values.
689    pub fn domain(&self) -> Repr::Domain {
690        Repr::domain()
691    }
692
693    pub(crate) fn as_slice(&self) -> &'_ [u8] {
694        // SAFETY: This class has the invariant that the backing storage must be initialized
695        // and exist in a single allocation containing at least
696        // `[self.ptr.as_ptr(), self.ptr_ptr() + self.bytes())`.
697        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.bytes()) }
698    }
699
700    /// Return a pointer to the beginning of the memory associated with this slice.
701    ///
702    /// # NOTE
703    ///
704    /// The memory span underlying this instances is valid for `self.bytes()`, not
705    /// necessarily `self.len()`.
706    pub fn as_ptr(&self) -> *const u8 {
707        self.ptr.as_ptr()
708    }
709
710    /// This function is very easy to use incorrectly and hence is crate-local.
711    pub(super) fn as_mut_slice(&mut self) -> &'_ mut [u8]
712    where
713        Ptr: AsMutPtr,
714    {
715        // SAFETY: This class has the invariant that the backing storage must be initialized
716        // and exist in a single allocation containing at least
717        // `[self.ptr.as_ptr(), self.ptr_ptr() + self.bytes())`.
718        //
719        // A mutable reference to self with `Ptr: AsMutPtr` attests to the fact that we
720        // have an exclusive borrow over the underlying memory.
721        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut_ptr(), self.bytes()) }
722    }
723
724    /// This function is very easy to use incorrectly and hence is private.
725    fn as_mut_ptr(&mut self) -> *mut u8
726    where
727        Ptr: AsMutPtr,
728    {
729        self.ptr.as_mut_ptr()
730    }
731}
732
733impl<const NBITS: usize, Repr, Perm, Len>
734    BitSliceBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Perm, Len>
735where
736    Repr: Representation<NBITS>,
737    Perm: PermutationStrategy<NBITS>,
738    Len: Length,
739{
740    /// Construct a new owning `BitSlice` capable of holding `Count` logical values.
741    /// The slice is initialized in a valid but undefined state.
742    ///
743    /// # Example
744    ///
745    /// ```
746    /// use diskann_quantization::bits::{BoxedBitSlice, Unsigned};
747    /// let mut x = BoxedBitSlice::<3, Unsigned>::new_boxed(4);
748    /// x.set(0, 0).unwrap();
749    /// x.set(1, 2).unwrap();
750    /// x.set(2, 4).unwrap();
751    /// x.set(3, 6).unwrap();
752    ///
753    /// assert_eq!(x.get(0).unwrap(), 0);
754    /// assert_eq!(x.get(1).unwrap(), 2);
755    /// assert_eq!(x.get(2).unwrap(), 4);
756    /// assert_eq!(x.get(3).unwrap(), 6);
757    /// ```
758    pub fn new_boxed<Count>(count: Count) -> Self
759    where
760        Count: Into<Len>,
761    {
762        let count: Len = count.into();
763        let bytes = Self::bytes_for(count.value());
764        let storage: Box<[u8]> = (0..bytes).map(|_| 0).collect();
765
766        // SAFETY: We've ensured that the backing storage has the correct number of bytes
767        // as required by the count and PermutationStrategy.
768        //
769        // Since this is owned storage, we do not need to worry about capturing lifetimes.
770        unsafe { Self::new_unchecked(Poly::from(storage), count) }
771    }
772}
773
774impl<const NBITS: usize, Repr, Perm, Len, A> BitSliceBase<NBITS, Repr, Poly<[u8], A>, Perm, Len>
775where
776    Repr: Representation<NBITS>,
777    Perm: PermutationStrategy<NBITS>,
778    Len: Length,
779    A: AllocatorCore,
780{
781    /// Construct a new owning `BitSlice` capable of holding `Count` logical values using
782    /// the provided allocator.
783    ///
784    /// The slice is initialized in a valid but undefined state.
785    ///
786    /// # Example
787    ///
788    /// ```
789    /// use diskann_quantization::{
790    ///     alloc::GlobalAllocator,
791    ///     bits::{BoxedBitSlice, Unsigned}
792    /// };
793    /// let mut x = BoxedBitSlice::<3, Unsigned>::new_in(4, GlobalAllocator).unwrap();
794    /// x.set(0, 0).unwrap();
795    /// x.set(1, 2).unwrap();
796    /// x.set(2, 4).unwrap();
797    /// x.set(3, 6).unwrap();
798    ///
799    /// assert_eq!(x.get(0).unwrap(), 0);
800    /// assert_eq!(x.get(1).unwrap(), 2);
801    /// assert_eq!(x.get(2).unwrap(), 4);
802    /// assert_eq!(x.get(3).unwrap(), 6);
803    /// ```
804    pub fn new_in<Count>(count: Count, allocator: A) -> Result<Self, AllocatorError>
805    where
806        Count: Into<Len>,
807    {
808        let count: Len = count.into();
809        let bytes = Self::bytes_for(count.value());
810        let storage = Poly::broadcast(0, bytes, allocator)?;
811
812        // SAFETY: We've ensured that the backing storage has the correct number of bytes
813        // as required by the count and PermutationStrategy.
814        //
815        // Since this is owned storage, we do not need to worry about capturing lifetimes.
816        Ok(unsafe { Self::new_unchecked(storage, count) })
817    }
818
819    /// Consume `self` and return the boxed allocation.
820    pub fn into_inner(self) -> Poly<[u8], A> {
821        self.ptr
822    }
823}
824
/// The layout for `N`-bit integers that references a raw underlying slice.
pub type BitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
    BitSliceBase<N, Repr, SlicePtr<'a, u8>, Perm, Len>;

/// The layout for `N`-bit integers that mutably references a raw underlying slice.
pub type MutBitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
    BitSliceBase<N, Repr, MutSlicePtr<'a, u8>, Perm, Len>;

/// The layout for `N`-bit integers that own the underlying slice, held in a
/// `Poly<[u8], A>` for the allocator `A`.
pub type PolyBitSlice<const N: usize, Repr, A, Perm = Dense, Len = Dynamic> =
    BitSliceBase<N, Repr, Poly<[u8], A>, Perm, Len>;

/// The layout for `N`-bit integers that own the underlying slice, using
/// `GlobalAllocator` as the allocator.
pub type BoxedBitSlice<const N: usize, Repr, Perm = Dense, Len = Dynamic> =
    PolyBitSlice<N, Repr, GlobalAllocator, Perm, Len>;
840
841///////////////////////////////
842// Special Cased Conversions //
843///////////////////////////////
844
845impl<'a, Ptr> From<&'a BitSliceBase<8, Unsigned, Ptr>> for &'a [u8]
846where
847    Ptr: AsPtr<Type = u8>,
848{
849    fn from(slice: &'a BitSliceBase<8, Unsigned, Ptr>) -> Self {
850        // SAFETY: The original pointer must have been obtained from a slice of the
851        // appropriate length.
852        //
853        // Furthermore, the layout of this type of slice is guaranteed to be identical
854        // to the layout of a `[u8]`.
855        unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) }
856    }
857}
858
859impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> Reborrow<'this>
860    for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
861where
862    Repr: Representation<NBITS>,
863    Ptr: AsPtr<Type = u8>,
864    Perm: PermutationStrategy<NBITS>,
865    Len: Length,
866{
867    type Target = BitSlice<'this, NBITS, Repr, Perm, Len>;
868
869    fn reborrow(&'this self) -> Self::Target {
870        let ptr: *const u8 = self.as_ptr();
871        debug_assert!(!ptr.is_null());
872
873        // Safety: `AsPtr` may never return null pointers.
874        // The `cast_mut()` is safe because `SlicePtr` does not provide a way of retrieving
875        // a mutable pointer.
876        let nonnull = unsafe { NonNull::new_unchecked(ptr.cast_mut()) };
877
878        // Safety: By struct invariant,
879        // `[self.ptr(), self.ptr() + Self::bytes_for(self.len()))` is a valid slice, so
880        // the returned object will also uphold these invariants.
881        //
882        // The returned struct will not outlive `&'this self`, so we've attached the
883        // proper lifetime.
884        let ptr = unsafe { SlicePtr::new_unchecked(nonnull) };
885
886        Self::Target {
887            ptr,
888            len: self.len,
889            repr: PhantomData,
890            packing: PhantomData,
891        }
892    }
893}
894
895impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> ReborrowMut<'this>
896    for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
897where
898    Repr: Representation<NBITS>,
899    Ptr: AsMutPtr<Type = u8>,
900    Perm: PermutationStrategy<NBITS>,
901    Len: Length,
902{
903    type Target = MutBitSlice<'this, NBITS, Repr, Perm, Len>;
904
905    fn reborrow_mut(&'this mut self) -> Self::Target {
906        let ptr: *mut u8 = self.as_mut_ptr();
907        debug_assert!(!ptr.is_null());
908
909        // Safety: `AsMutPtr` may never return null pointers.
910        let nonnull = unsafe { NonNull::new_unchecked(ptr) };
911
912        // Safety: By struct invariant,
913        // `[self.ptr(), self.ptr() + Self::bytes_for(self.len()))` is a valid slice, so
914        // the returned object will also uphold these invariants.
915        //
916        // The returned struct will not outlive `&'this mut self`, so we've attached the
917        // proper lifetime.
918        //
919        // Exclusive ownership is attested by both `AsMutPtr` and the mutable refernce
920        // to self.
921        let ptr = unsafe { MutSlicePtr::new_unchecked(nonnull) };
922
923        Self::Target {
924            ptr,
925            len: self.len,
926            repr: PhantomData,
927            packing: PhantomData,
928        }
929    }
930}
931
932///////////
933// Tests //
934///////////
935
936#[cfg(test)]
937mod tests {
938    use rand::{
939        distr::{Distribution, Uniform},
940        rngs::StdRng,
941        seq::{IndexedRandom, SliceRandom},
942        Rng, SeedableRng,
943    };
944
945    use super::*;
946    use crate::{bits::Static, test_util::AlwaysFails};
947
948    ////////////
949    // Errors //
950    ////////////
951
952    const BOUNDS: &str = "special bounds";
953
954    #[test]
955    fn test_encoding_error() {
956        assert_eq!(std::mem::size_of::<EncodingError>(), 16);
957        assert_eq!(
958            std::mem::size_of::<Option<EncodingError>>(),
959            16,
960            "expected EncodingError to have the niche optimization"
961        );
962        let err = EncodingError::new(7, &BOUNDS);
963        assert_eq!(
964            err.to_string(),
965            "value 7 is not in the encodable range of special bounds"
966        );
967    }
968
969    // Check that a type is `Send` and `Sync`.
970    fn assert_send_and_sync<T: Send + Sync>(_x: &T) {}
971
972    ////////////
973    // Binary //
974    ////////////
975
976    #[test]
977    fn test_binary_repr() {
978        assert_eq!(Binary::encode(-1).unwrap(), 0);
979        assert_eq!(Binary::encode(1).unwrap(), 1);
980        assert_eq!(Binary::decode(0), -1);
981        assert_eq!(Binary::decode(1), 1);
982
983        assert!(Binary::check(-1));
984        assert!(Binary::check(1));
985        assert!(!Binary::check(0));
986        assert!(!Binary::check(-2));
987        assert!(!Binary::check(2));
988
989        let domain: Vec<_> = Binary::domain().collect();
990        assert_eq!(domain, &[-1, 1]);
991    }
992
993    ///////////
994    // Sizes //
995    ///////////
996
997    #[test]
998    fn test_sizes() {
999        assert_eq!(std::mem::size_of::<BitSlice<'static, 8, Unsigned>>(), 16);
1000        assert_eq!(std::mem::size_of::<MutBitSlice<'static, 8, Unsigned>>(), 16);
1001
1002        // Ensure the borrowed slices are eligible for niche optimization.
1003        assert_eq!(
1004            std::mem::size_of::<Option<BitSlice<'static, 8, Unsigned>>>(),
1005            16
1006        );
1007        assert_eq!(
1008            std::mem::size_of::<Option<MutBitSlice<'static, 8, Unsigned>>>(),
1009            16
1010        );
1011
1012        assert_eq!(
1013            std::mem::size_of::<BitSlice<'static, 8, Unsigned, Dense, Static<128>>>(),
1014            8
1015        );
1016    }
1017
1018    ///////////////////
1019    // General Tests //
1020    ///////////////////
1021
    // Test sizing knobs, selected by build configuration:
    //
    // * `MAX_DIM`: exclusive upper bound on the slice lengths the tests below loop over.
    // * `FUZZ_ITERATIONS`: number of randomized overwrite rounds run per slice.
    //
    // Miri builds run a single round, debug builds run 10 rounds, and all other builds
    // run 100 rounds over the largest range of lengths.
    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 160;
            const FUZZ_ITERATIONS: usize = 1;
        } else if #[cfg(debug_assertions)] {
            const MAX_DIM: usize = 128;
            const FUZZ_ITERATIONS: usize = 10;
        } else {
            const MAX_DIM: usize = 256;
            const FUZZ_ITERATIONS: usize = 100;
        }
    }
1034
1035    fn test_send_and_sync<const NBITS: usize, Repr, Perm>()
1036    where
1037        Repr: Representation<NBITS> + Send + Sync,
1038        Perm: PermutationStrategy<NBITS> + Send + Sync,
1039    {
1040        let mut x = BoxedBitSlice::<NBITS, Repr, Perm>::new_boxed(1);
1041        assert_send_and_sync(&x);
1042        assert_send_and_sync(&x.reborrow());
1043        assert_send_and_sync(&x.reborrow_mut());
1044    }
1045
1046    fn test_empty<const NBITS: usize, Repr, Perm>()
1047    where
1048        Repr: Representation<NBITS>,
1049        Perm: PermutationStrategy<NBITS>,
1050    {
1051        let base: &mut [u8] = &mut [];
1052        let mut slice = MutBitSlice::<NBITS, Repr, Perm>::new(base, 0).unwrap();
1053        assert_eq!(slice.len(), 0);
1054        assert!(slice.is_empty());
1055
1056        {
1057            let reborrow = slice.reborrow();
1058            assert_eq!(reborrow.len(), 0);
1059            assert!(reborrow.is_empty());
1060        }
1061
1062        {
1063            let reborrow = slice.reborrow_mut();
1064            assert_eq!(reborrow.len(), 0);
1065            assert!(reborrow.is_empty());
1066        }
1067    }
1068
1069    // times, ensuring that values are preserved.
1070    fn test_construction_errors<const NBITS: usize, Repr, Perm>()
1071    where
1072        Repr: Representation<NBITS>,
1073        Perm: PermutationStrategy<NBITS>,
1074    {
1075        let len: usize = 10;
1076        let bytes = Perm::bytes(len);
1077
1078        // Construction errors for Boxes
1079        let box_big = Poly::broadcast(0u8, bytes + 1, GlobalAllocator).unwrap();
1080        let box_small = Poly::broadcast(0u8, bytes - 1, GlobalAllocator).unwrap();
1081        let box_right = Poly::broadcast(0u8, bytes, GlobalAllocator).unwrap();
1082
1083        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_big, len);
1084        match result {
1085            Err(ConstructionError { got, expected }) => {
1086                assert_eq!(got, bytes + 1);
1087                assert_eq!(expected, bytes);
1088            }
1089            _ => panic!("shouldn't have reached here!"),
1090        };
1091
1092        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_small, len);
1093        match result {
1094            Err(ConstructionError { got, expected }) => {
1095                assert_eq!(got, bytes - 1);
1096                assert_eq!(expected, bytes);
1097            }
1098            _ => panic!("shouldn't have reached here!"),
1099        };
1100
1101        let mut base = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_right, len).unwrap();
1102        let ptr = base.as_ptr();
1103        assert_eq!(base.len(), len);
1104
1105        // Successful mutable reborrow and borrow.
1106        {
1107            // Use reborrow
1108            let borrowed = base.reborrow_mut();
1109            assert_eq!(borrowed.as_ptr(), ptr);
1110            assert_eq!(borrowed.len(), len);
1111
1112            // Go through a slice.
1113            let borrowed = MutBitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1114            assert_eq!(borrowed.as_ptr(), ptr);
1115            assert_eq!(borrowed.len(), len);
1116        }
1117
1118        // Successful mutable borrow.
1119        {
1120            // Try constructing from an oversized slice.
1121            let mut oversized = vec![0; bytes + 1];
1122            let result = MutBitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1123            match result {
1124                Err(ConstructionError { got, expected }) => {
1125                    assert_eq!(got, bytes + 1);
1126                    assert_eq!(expected, bytes);
1127                }
1128                _ => panic!("shouldn't have reached here!"),
1129            };
1130
1131            let mut undersized = vec![0; bytes - 1];
1132            let result = MutBitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1133            match result {
1134                Err(ConstructionError { got, expected }) => {
1135                    assert_eq!(got, bytes - 1);
1136                    assert_eq!(expected, bytes);
1137                }
1138                _ => panic!("shouldn't have reached here!"),
1139            };
1140        }
1141
1142        // Successful const borrow and reborrow.
1143        {
1144            // Use reborrow
1145            let borrowed = base.reborrow();
1146            assert_eq!(borrowed.as_ptr(), ptr);
1147            assert_eq!(borrowed.len(), len);
1148
1149            // Go through a slice.
1150            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_slice(), len).unwrap();
1151            assert_eq!(borrowed.as_ptr(), ptr);
1152            assert_eq!(borrowed.len(), len);
1153
1154            // Go through a mutable slice.
1155            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1156            assert_eq!(borrowed.as_ptr(), ptr);
1157            assert_eq!(borrowed.len(), len);
1158        }
1159
1160        // Successful mutable borrow.
1161        {
1162            // Try constructing from an oversized slice.
1163            let mut oversized = vec![0; bytes + 1];
1164            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1165            match result {
1166                Err(ConstructionError { got, expected }) => {
1167                    assert_eq!(got, bytes + 1);
1168                    assert_eq!(expected, bytes);
1169                }
1170                _ => panic!("shouldn't have reached here!"),
1171            };
1172
1173            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_slice(), len);
1174            match result {
1175                Err(ConstructionError { got, expected }) => {
1176                    assert_eq!(got, bytes + 1);
1177                    assert_eq!(expected, bytes);
1178                }
1179                _ => panic!("shouldn't have reached here!"),
1180            };
1181
1182            // Try constructing from an undersized slice.
1183            let mut undersized = vec![0; bytes - 1];
1184            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1185            match result {
1186                Err(ConstructionError { got, expected }) => {
1187                    assert_eq!(got, bytes - 1);
1188                    assert_eq!(expected, bytes);
1189                }
1190                _ => panic!("shouldn't have reached here!"),
1191            };
1192
1193            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_slice(), len);
1194            match result {
1195                Err(ConstructionError { got, expected }) => {
1196                    assert_eq!(got, bytes - 1);
1197                    assert_eq!(expected, bytes);
1198                }
1199                _ => panic!("shouldn't have reached here!"),
1200            };
1201        }
1202    }
1203
1204    // This series of tests writes to all indices in the vector in random orders multiple
1205    // times, ensuring that values are preserved.
1206    fn run_overwrite_test<const NBITS: usize, Perm, Len, R>(
1207        base: &mut BoxedBitSlice<NBITS, Unsigned, Perm, Len>,
1208        num_iterations: usize,
1209        rng: &mut R,
1210    ) where
1211        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1212        Len: Length,
1213        Perm: PermutationStrategy<NBITS>,
1214        R: Rng,
1215    {
1216        let mut expected: Vec<i64> = vec![0; base.len()];
1217        let mut indices: Vec<usize> = (0..base.len()).collect();
1218        for i in 0..base.len() {
1219            base.set(i, 0).unwrap();
1220        }
1221
1222        for i in 0..base.len() {
1223            assert_eq!(base.get(i).unwrap(), 0, "failed to initialize bit vector");
1224        }
1225
1226        let domain = base.domain();
1227        assert_eq!(domain, 0..=2i64.pow(NBITS as u32) - 1);
1228        let distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
1229
1230        for iter in 0..num_iterations {
1231            // Shuffle insertion order.
1232            indices.shuffle(rng);
1233
1234            // Insert random values.
1235            for &i in indices.iter() {
1236                let value = distribution.sample(rng);
1237                expected[i] = value;
1238                base.set(i, value).unwrap();
1239            }
1240
1241            // Make sure values are preserved.
1242            for (i, &expect) in expected.iter().enumerate() {
1243                let value = base.get(i).unwrap();
1244                assert_eq!(
1245                    value, expect,
1246                    "retrieval failed on iteration {iter} at index {i}"
1247                );
1248            }
1249
1250            // Make sure the reborrowed version matches.
1251            let borrowed = base.reborrow();
1252            for (i, &expect) in expected.iter().enumerate() {
1253                let value = borrowed.get(i).unwrap();
1254                assert_eq!(
1255                    value, expect,
1256                    "reborrow retrieval failed on iteration {iter} at index {i}"
1257                );
1258            }
1259        }
1260    }
1261
1262    fn run_overwrite_binary_test<Perm, Len, R>(
1263        base: &mut BoxedBitSlice<1, Binary, Perm, Len>,
1264        num_iterations: usize,
1265        rng: &mut R,
1266    ) where
1267        Len: Length,
1268        Perm: PermutationStrategy<1>,
1269        R: Rng,
1270    {
1271        let mut expected: Vec<i64> = vec![0; base.len()];
1272        let mut indices: Vec<usize> = (0..base.len()).collect();
1273        for i in 0..base.len() {
1274            base.set(i, -1).unwrap();
1275        }
1276
1277        for i in 0..base.len() {
1278            assert_eq!(base.get(i).unwrap(), -1, "failed to initialize bit vector");
1279        }
1280
1281        let distribution: [i64; 2] = [-1, 1];
1282
1283        for iter in 0..num_iterations {
1284            // Shuffle insertion order.
1285            indices.shuffle(rng);
1286
1287            // Insert random values.
1288            for &i in indices.iter() {
1289                let value = distribution.choose(rng).unwrap();
1290                expected[i] = *value;
1291                base.set(i, *value).unwrap();
1292            }
1293
1294            // Make sure values are preserved.
1295            for (i, &expect) in expected.iter().enumerate() {
1296                let value = base.get(i).unwrap();
1297                assert_eq!(
1298                    value, expect,
1299                    "retrieval failed on iteration {iter} at index {i}"
1300                );
1301            }
1302
1303            // Make sure the reborrowed version matches.
1304            let borrowed = base.reborrow();
1305            for (i, &expect) in expected.iter().enumerate() {
1306                let value = borrowed.get(i).unwrap();
1307                assert_eq!(
1308                    value, expect,
1309                    "reborrow retrieval failed on iteration {iter} at index {i}"
1310                );
1311            }
1312        }
1313    }
1314
1315    //////////////////////
1316    // Unsigned - Dense //
1317    //////////////////////
1318
1319    fn test_unsigned_dense<const NBITS: usize, Len, R>(
1320        len: Len,
1321        minimum: i64,
1322        maximum: i64,
1323        rng: &mut R,
1324    ) where
1325        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1326        Dense: PermutationStrategy<NBITS>,
1327        Len: Length,
1328        R: Rng,
1329    {
1330        test_send_and_sync::<NBITS, Unsigned, Dense>();
1331        test_empty::<NBITS, Unsigned, Dense>();
1332        test_construction_errors::<NBITS, Unsigned, Dense>();
1333        assert_eq!(Unsigned::domain_const::<NBITS>(), Unsigned::domain(),);
1334
1335        match PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, AlwaysFails) {
1336            Ok(_) => {
1337                if len.value() != 0 {
1338                    panic!("zero sized allocations don't require an allocator");
1339                }
1340            }
1341            Err(AllocatorError) => {
1342                if len.value() == 0 {
1343                    panic!("allocation should have failed");
1344                }
1345            }
1346        }
1347
1348        let mut base =
1349            PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, GlobalAllocator).unwrap();
1350        assert_eq!(
1351            base.len(),
1352            len.value(),
1353            "BoxedBitSlice returned the incorrect length"
1354        );
1355
1356        let expected_bytes = BitSlice::<'static, NBITS, Unsigned>::bytes_for(len.value());
1357        assert_eq!(
1358            base.bytes(),
1359            expected_bytes,
1360            "BoxedBitSlice has the incorrect number of bytes"
1361        );
1362
1363        // Check that the minimum and maximum values reported by the struct are correct.
1364        assert_eq!(base.domain(), minimum..=maximum);
1365
1366        if len.value() == 0 {
1367            return;
1368        }
1369
1370        let ptr = base.as_ptr();
1371
1372        // Now that we know the length is non-zero, we can try testing the interface.
1373        // Setting the lowest index should always work.
1374        {
1375            let mut borrowed = base.reborrow_mut();
1376
1377            // Make sure the pointer is preserved.
1378            assert_eq!(
1379                borrowed.as_ptr(),
1380                ptr,
1381                "pointer was not preserved during borrowing!"
1382            );
1383            assert_eq!(
1384                borrowed.len(),
1385                len.value(),
1386                "borrowing did not preserve length!"
1387            );
1388
1389            borrowed.set(0, 0).unwrap();
1390            assert_eq!(borrowed.get(0).unwrap(), 0);
1391
1392            borrowed.set(0, 1).unwrap();
1393            assert_eq!(borrowed.get(0).unwrap(), 1);
1394
1395            borrowed.set(0, 0).unwrap();
1396            assert_eq!(borrowed.get(0).unwrap(), 0);
1397
1398            // Setting to an invalid value should yield an error.
1399            let result = borrowed.set(0, minimum - 1);
1400            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1401
1402            let result = borrowed.set(0, maximum + 1);
1403            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1404
1405            // Make sure an out-of-bounds access is caught.
1406            let result = borrowed.set(borrowed.len(), 0);
1407            assert!(matches!(result, Err(SetError::IndexError { .. })));
1408
1409            // Ensure that getting out-of-bounds is an error.
1410            let result = borrowed.get(borrowed.len());
1411            assert!(matches!(result, Err(GetError::IndexError { .. })));
1412        }
1413
1414        {
1415            // Reconsturct the mutable borrow directly through a slice.
1416            let borrowed =
1417                MutBitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1418
1419            assert_eq!(
1420                borrowed.as_ptr(),
1421                ptr,
1422                "pointer was not preserved during borrowing!"
1423            );
1424            assert_eq!(
1425                borrowed.len(),
1426                len.value(),
1427                "borrowing did not preserve length!"
1428            );
1429        }
1430
1431        {
1432            let borrowed = base.reborrow();
1433
1434            // Make sure the pointer is preserved.
1435            assert_eq!(
1436                borrowed.as_ptr(),
1437                ptr,
1438                "pointer was not preserved during borrowing!"
1439            );
1440
1441            assert_eq!(
1442                borrowed.len(),
1443                len.value(),
1444                "borrowing did not preserve length!"
1445            );
1446
1447            // Ensure that getting out-of-bounds is an error.
1448            let result = borrowed.get(borrowed.len());
1449            assert!(matches!(result, Err(GetError::IndexError { .. })));
1450        }
1451
1452        {
1453            // Reconsturct the mutable borrow directly through a slice.
1454            let borrowed =
1455                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_slice(), len).unwrap();
1456
1457            assert_eq!(
1458                borrowed.as_ptr(),
1459                ptr,
1460                "pointer was not preserved during borrowing!"
1461            );
1462            assert_eq!(
1463                borrowed.len(),
1464                len.value(),
1465                "borrowing did not preserve length!"
1466            );
1467        }
1468
1469        {
1470            // Reconsturct the mutable borrow directly through a slice.
1471            let borrowed =
1472                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1473
1474            assert_eq!(
1475                borrowed.as_ptr(),
1476                ptr,
1477                "pointer was not preserved during borrowing!"
1478            );
1479            assert_eq!(
1480                borrowed.len(),
1481                len.value(),
1482                "borrowing did not preserve length!"
1483            );
1484        }
1485
1486        // Now we begin the testing loop.
1487        run_overwrite_test(&mut base, FUZZ_ITERATIONS, rng);
1488    }
1489
1490    macro_rules! generate_unsigned_test {
1491        ($name:ident, $NBITS:literal, $MIN:literal, $MAX:literal, $SEED:literal) => {
1492            #[test]
1493            fn $name() {
1494                let mut rng = StdRng::seed_from_u64($SEED);
1495                for dim in 0..MAX_DIM {
1496                    test_unsigned_dense::<$NBITS, Dynamic, _>(dim.into(), $MIN, $MAX, &mut rng);
1497                }
1498            }
1499        };
1500    }
1501
1502    generate_unsigned_test!(test_unsigned_8bit, 8, 0, 0xff, 0xc652f2a1018f442b);
1503    generate_unsigned_test!(test_unsigned_7bit, 7, 0, 0x7f, 0xb732e59fec6d6c9c);
1504    generate_unsigned_test!(test_unsigned_6bit, 6, 0, 0x3f, 0x35d9380d0a318f21);
1505    generate_unsigned_test!(test_unsigned_5bit, 5, 0, 0x1f, 0xfb09895183334304);
1506    generate_unsigned_test!(test_unsigned_4bit, 4, 0, 0x0f, 0x38dfcf9e82c33f48);
1507    generate_unsigned_test!(test_unsigned_3bit, 3, 0, 0x07, 0xf9a94c8c749ee26c);
1508    generate_unsigned_test!(test_unsigned_2bit, 2, 0, 0x03, 0xbba03db62cecf4cf);
1509    generate_unsigned_test!(test_unsigned_1bit, 1, 0, 0x01, 0x54ea2a07d7c67f37);
1510
1511    #[test]
1512    fn test_binary_dense() {
1513        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
1514        for len in 0..MAX_DIM {
1515            test_send_and_sync::<1, Binary, Dense>();
1516            test_empty::<1, Binary, Dense>();
1517            test_construction_errors::<1, Binary, Dense>();
1518
1519            // Create a boxed base.
1520            let mut base = BoxedBitSlice::<1, Binary>::new_boxed(len);
1521            assert_eq!(
1522                base.len(),
1523                len,
1524                "BoxedBitSlice returned the incorrect length"
1525            );
1526
1527            assert_eq!(base.bytes(), len.div_ceil(8));
1528
1529            let bytes = BitSlice::<'static, 1, Binary>::bytes_for(len);
1530            assert_eq!(
1531                bytes,
1532                len.div_ceil(8),
1533                "BoxedBitSlice has the incorrect number of bytes"
1534            );
1535
1536            if len == 0 {
1537                continue;
1538            }
1539
1540            // Setting to an invalid value should yield an error.
1541            let result = base.set(0, 0);
1542            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1543
1544            // Make sure an out-of-bounds access is caught.
1545            let result = base.set(base.len(), -1);
1546            assert!(matches!(result, Err(SetError::IndexError { .. })));
1547
1548            // Ensure that getting out-of-bounds is an error.
1549            let result = base.get(base.len());
1550            assert!(matches!(result, Err(GetError::IndexError { .. })));
1551
1552            // Now we begin the testing loop.
1553            run_overwrite_binary_test(&mut base, FUZZ_ITERATIONS, &mut rng);
1554        }
1555    }
1556
1557    #[test]
1558    fn test_4bit_bit_transpose() {
1559        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
1560        for len in 0..MAX_DIM {
1561            test_send_and_sync::<4, Unsigned, BitTranspose>();
1562            test_empty::<4, Unsigned, BitTranspose>();
1563            test_construction_errors::<4, Unsigned, BitTranspose>();
1564
1565            // Create a boxed base.
1566            let mut base = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(len);
1567            assert_eq!(
1568                base.len(),
1569                len,
1570                "BoxedBitSlice returned the incorrect length"
1571            );
1572
1573            assert_eq!(base.bytes(), 32 * len.div_ceil(64));
1574
1575            let bytes = BitSlice::<'static, 4, Unsigned, BitTranspose>::bytes_for(len);
1576            assert_eq!(
1577                bytes,
1578                32 * len.div_ceil(64),
1579                "BoxedBitSlice has the incorrect number of bytes"
1580            );
1581
1582            if len == 0 {
1583                continue;
1584            }
1585
1586            // Setting to an invalid value should yield an error.
1587            let result = base.set(0, -1);
1588            assert!(matches!(result, Err(SetError::EncodingError { .. })));
1589
1590            // Make sure an out-of-bounds access is caught.
1591            let result = base.set(base.len(), -1);
1592            assert!(matches!(result, Err(SetError::IndexError { .. })));
1593
1594            // Ensure that getting out-of-bounds is an error.
1595            let result = base.get(base.len());
1596            assert!(matches!(result, Err(GetError::IndexError { .. })));
1597
1598            // Now we begin the testing loop.
1599            run_overwrite_test(&mut base, FUZZ_ITERATIONS, &mut rng);
1600        }
1601    }
1602}