use std::{marker::PhantomData, ops::RangeInclusive, ptr::NonNull};
use diskann_utils::{Reborrow, ReborrowMut};
use thiserror::Error;
use super::{
length::{Dynamic, Length},
packing,
ptr::{AsMutPtr, AsPtr, MutSlicePtr, Precursor, SlicePtr},
};
use crate::{
alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
utils,
};
/// Maps signed integer values to and from the compact `NBITS`-wide codes stored in a
/// bit slice.
pub trait Representation<const NBITS: usize> {
    /// Iterator over every value this representation can encode.
    type Domain: Iterator<Item = i64>;

    /// Return an iterator over every encodable value.
    fn domain() -> Self::Domain;

    /// Return `true` if `value` lies in the encodable domain.
    fn check(value: i64) -> bool;

    /// Encode `value`, returning an [`EncodingError`] if it is outside the domain.
    fn encode(value: i64) -> Result<u8, EncodingError>;

    /// Encode `value` without a release-mode domain check.
    ///
    /// The implementations in this module only `debug_assert!` that `value` passes
    /// [`Self::check`]; the code produced for out-of-domain inputs is
    /// implementation-specific.
    fn encode_unchecked(value: i64) -> u8;

    /// Decode a raw code back into its value.
    fn decode(raw: u8) -> i64;
}
/// Error returned by `Representation::encode` when a value lies outside the encodable
/// domain.
#[derive(Debug, Error, Clone, Copy)]
#[error("value {} is not in the encodable range of {}", got, domain)]
pub struct EncodingError {
/// The rejected value.
got: i64,
/// Human-readable description of the valid domain.
///
/// Stored as a reference to a `&'static str` (a thin pointer) rather than the fat
/// `&str` itself, which keeps the struct at 16 bytes and lets
/// `Option<EncodingError>` use the non-null niche (both sizes are asserted in the
/// test module).
domain: &'static &'static str,
}
impl EncodingError {
/// Build an error for the rejected value `got` and the domain description `domain`.
fn new(got: i64, domain: &'static &'static str) -> Self {
Self { got, domain }
}
}
/// Unsigned representation: an `NBITS`-wide code stores the integers
/// `0..=2^NBITS - 1` verbatim.
#[derive(Debug, Clone, Copy)]
pub struct Unsigned;

impl Unsigned {
    /// The encodable range for `NBITS` bits, usable in `const` contexts.
    pub const fn domain_const<const NBITS: usize>() -> std::ops::RangeInclusive<i64> {
        let largest = 2i64.pow(NBITS as u32) - 1;
        0..=largest
    }

    /// Human-readable domain used in `EncodingError` messages.
    ///
    /// # Panics
    ///
    /// Panics for any `nbits` outside `1..=8`.
    #[allow(clippy::panic)]
    const fn domain_str(nbits: usize) -> &'static &'static str {
        match nbits {
            1 => &"[0, 1]",
            2 => &"[0, 3]",
            3 => &"[0, 7]",
            4 => &"[0, 15]",
            5 => &"[0, 31]",
            6 => &"[0, 63]",
            7 => &"[0, 127]",
            8 => &"[0, 255]",
            _ => panic!("unimplemented"),
        }
    }
}
macro_rules! repr_unsigned {
($N:literal) => {
impl Representation<$N> for Unsigned {
type Domain = RangeInclusive<i64>;
fn encode(value: i64) -> Result<u8, EncodingError> {
if !<Self as Representation<$N>>::check(value) {
let domain = Self::domain_str($N);
Err(EncodingError::new(value, domain))
} else {
Ok(<Self as Representation<$N>>::encode_unchecked(value))
}
}
fn encode_unchecked(value: i64) -> u8 {
debug_assert!(<Self as Representation<$N>>::check(value));
value as u8
}
fn decode(raw: u8) -> i64 {
let raw: i64 = raw.into();
debug_assert!(<Self as Representation<$N>>::check(raw));
raw
}
fn check(value: i64) -> bool {
<Self as Representation<$N>>::domain().contains(&value)
}
fn domain() -> Self::Domain {
Self::domain_const::<$N>()
}
}
};
($N:literal, $($Ns:literal),+) => {
repr_unsigned!($N);
$(repr_unsigned!($Ns);)+
};
}
repr_unsigned!(1, 2, 3, 4, 5, 6, 7, 8);
/// Binary (sign) representation: a single bit encodes `-1` (code `0`) or `1` (code `1`).
#[derive(Debug, Clone, Copy)]
pub struct Binary;

impl Representation<1> for Binary {
    type Domain = std::array::IntoIter<i64, 2>;

    fn encode(value: i64) -> Result<u8, EncodingError> {
        if !Self::check(value) {
            const DOMAIN: &str = "{-1, 1}";
            Err(EncodingError::new(value, &DOMAIN))
        } else {
            Ok(Self::encode_unchecked(value))
        }
    }

    fn encode_unchecked(value: i64) -> u8 {
        debug_assert!(Self::check(value));
        // Maps -1 -> 0 and 1 -> 1.
        value.clamp(0, 1) as u8
    }

    fn decode(raw: u8) -> i64 {
        let raw: i64 = raw.into();
        // Maps 0 -> -1 and 1 -> 1.
        let decoded = (raw << 1) - 1;
        // Mirror `Unsigned::decode`: a valid 1-bit code always decodes into the domain,
        // so any other result means the caller handed us a corrupt code.
        debug_assert!(Self::check(decoded));
        decoded
    }

    fn check(value: i64) -> bool {
        value == -1 || value == 1
    }

    fn domain() -> Self::Domain {
        [-1, 1].into_iter()
    }
}
/// Describes how `NBITS`-wide codes are laid out inside a byte buffer.
///
/// # Safety
///
/// NOTE(review): the exact obligations on implementors are not spelled out here.
/// `BitSliceBase` calls `pack`/`unpack` on a buffer of exactly `bytes(len)` bytes with
/// `i < len`, so implementations presumably must keep every access for such an `i`
/// inside that buffer. TODO confirm.
pub unsafe trait PermutationStrategy<const NBITS: usize> {
    /// Number of bytes needed to store `count` codes.
    fn bytes(count: usize) -> usize;

    /// Read the code stored in slot `i` of `s`.
    ///
    /// # Safety
    ///
    /// `s` must be `bytes(count)` bytes long for some `count > i`. This is the
    /// condition `BitSliceBase` upholds; TODO confirm it is the full requirement.
    unsafe fn unpack(s: &[u8], i: usize) -> u8;

    /// Store `value` into slot `i` of `s`.
    ///
    /// # Safety
    ///
    /// Same requirement as [`Self::unpack`].
    unsafe fn pack(s: &mut [u8], i: usize, value: u8);
}
/// Dense layout: codes are packed back to back with no padding between them.
#[derive(Debug, Clone, Copy)]
pub struct Dense;

impl Dense {
    /// Bytes needed to hold `count` consecutive `NBITS`-wide codes, with the final
    /// partial byte (if any) counted as a whole byte.
    fn bytes<const NBITS: usize>(count: usize) -> usize {
        let total_bits = NBITS * count;
        utils::div_round_up(total_bits, 8)
    }
}
unsafe impl<const NBITS: usize> PermutationStrategy<NBITS> for Dense {
    fn bytes(count: usize) -> usize {
        Self::bytes::<NBITS>(count)
    }

    /// Write the code `encoded` into slot `i` of `data`.
    ///
    /// Slot `i` occupies bits `[NBITS * i, NBITS * (i + 1))`, where bit `b` is bit
    /// `b % 8` of byte `b / 8`. With `NBITS <= 8`, a slot touches at most two adjacent
    /// bytes.
    ///
    /// # Safety
    ///
    /// Every byte spanned by slot `i` must lie within `data`. This is only checked
    /// with `debug_assert!`.
    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
        let bitaddress = NBITS * i;
        let bytestart = bitaddress / 8;
        let bytestop = (bitaddress + NBITS - 1) / 8;
        let bitstart = bitaddress - 8 * bytestart;
        debug_assert!(bytestop < data.len());
        if bytestart == bytestop {
            // SAFETY: `bytestart == bytestop < data.len()` per the caller's contract.
            let raw = unsafe { data.as_ptr().add(bytestart).read() };
            let packed = packing::pack_u8::<NBITS>(raw, encoded, bitstart);
            // SAFETY: same in-bounds byte as the read above.
            unsafe { data.as_mut_ptr().add(bytestart).write(packed) };
        } else {
            // The slot straddles two bytes. Treat them as a *little-endian* `u16` so that
            // bit `k` of the word is bit `k % 8` of byte `bytestart + k / 8` on every
            // target, matching the single-byte path above. A native-endian load would
            // put byte `bytestart` in the high half on big-endian hosts, scrambling the
            // layout and overwriting bits of neighbouring slots.
            //
            // SAFETY: `bytestop == bytestart + 1 < data.len()` (`NBITS <= 8`), so both
            // bytes are in bounds; `read_unaligned` has no alignment requirement.
            let raw = u16::from_le(unsafe {
                data.as_ptr().add(bytestart).cast::<u16>().read_unaligned()
            });
            let packed = packing::pack_u16::<NBITS>(raw, encoded, bitstart);
            // SAFETY: same two in-bounds bytes as the read above.
            unsafe {
                data.as_mut_ptr()
                    .add(bytestart)
                    .cast::<u16>()
                    .write_unaligned(packed.to_le())
            };
        }
    }

    /// Read the code stored in slot `i` of `data`, using the layout described on
    /// [`Self::pack`].
    ///
    /// # Safety
    ///
    /// Every byte spanned by slot `i` must lie within `data`. This is only checked
    /// with `debug_assert!`.
    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
        let bitaddress = NBITS * i;
        let bytestart = bitaddress / 8;
        let bytestop = (bitaddress + NBITS - 1) / 8;
        let bitstart = bitaddress - 8 * bytestart;
        debug_assert!(bytestop < data.len());
        if bytestart == bytestop {
            // SAFETY: `bytestart == bytestop < data.len()` per the caller's contract.
            let raw = unsafe { data.as_ptr().add(bytestart).read() };
            packing::unpack_u8::<NBITS>(raw, bitstart)
        } else {
            // Little-endian load for the same reason as in `pack`.
            //
            // SAFETY: `bytestop == bytestart + 1 < data.len()`, so both bytes are in
            // bounds; `read_unaligned` has no alignment requirement.
            let raw = u16::from_le(unsafe {
                data.as_ptr().add(bytestart).cast::<u16>().read_unaligned()
            });
            packing::unpack_u16::<NBITS>(raw, bitstart)
        }
    }
}
/// Bit-transposed layout for 4-bit codes.
///
/// Codes are grouped into blocks of 64, each taking 32 bytes. Within a block, bit `p`
/// of code `j` lives in byte `8 * p + j / 8`, bit `j % 8`. Each 8-byte plane therefore
/// holds one bit position of all 64 codes.
#[derive(Debug, Clone, Copy)]
pub struct BitTranspose;

unsafe impl PermutationStrategy<4> for BitTranspose {
    fn bytes(count: usize) -> usize {
        let blocks = utils::div_round_up(count, 64);
        blocks * 32
    }

    unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
        let (block, lane) = (i / 64, i % 64);
        let first_byte = 32 * block + lane / 8;
        let shift = i % 8;
        let mask: u8 = 0x1 << shift;
        // Scatter the low four bits of `encoded` across the four planes.
        for plane in 0..4 {
            let index = first_byte + 8 * plane;
            let bit = ((encoded >> plane) & 0x1) << shift;
            data[index] = (data[index] & !mask) | bit;
        }
    }

    unsafe fn unpack(data: &[u8], i: usize) -> u8 {
        let (block, lane) = (i / 64, i % 64);
        let first_byte = 32 * block + lane / 8;
        let shift = i % 8;
        // Gather one bit from each plane, plane `p` supplying bit `p` of the result.
        (0..4).fold(0u8, |acc, plane| {
            acc | (((data[first_byte + 8 * plane] >> shift) & 0x1) << plane)
        })
    }
}
/// Error returned by `BitSliceBase::new` when the backing buffer's length differs from
/// the number of bytes required for the requested element count.
#[derive(Debug, Error, Clone, Copy)]
#[error("input span has length {got} bytes but expected {expected}")]
pub struct ConstructionError {
/// Length of the supplied buffer, in bytes.
got: usize,
/// Required length, as computed by `BitSliceBase::bytes_for`.
expected: usize,
}
/// Error for an element index that is not less than the slice length.
#[derive(Debug, Error, Clone, Copy)]
#[error("index {index} exceeds the maximum length of {len}")]
pub struct IndexOutOfBounds {
/// The rejected index.
index: usize,
/// Length of the slice at the time of the access.
len: usize,
}
impl IndexOutOfBounds {
/// Build an error for index `index` on a slice of length `len`.
fn new(index: usize, len: usize) -> Self {
Self { index, len }
}
}
/// Error returned by `BitSliceBase::set`.
///
/// thiserror's `#[from]` also marks the wrapped error as this error's `source()`.
#[derive(Debug, Error, Clone, Copy)]
#[error("error setting index in bitslice")]
#[non_exhaustive]
pub enum SetError {
/// The index was not less than the slice length.
IndexError(#[from] IndexOutOfBounds),
/// The value lies outside the representation's domain.
EncodingError(#[from] EncodingError),
}
/// Error returned by `BitSliceBase::get`.
#[derive(Debug, Error, Clone, Copy)]
#[error("error getting index in bitslice")]
pub enum GetError {
/// The index was not less than the slice length.
IndexError(#[from] IndexOutOfBounds),
}
/// A packed slice of `NBITS`-wide codes.
///
/// * `Repr` maps between `i64` values and codes.
/// * `Ptr` supplies the backing bytes. The type aliases in this module instantiate it
///   as a shared borrow, a mutable borrow, or an owned `Poly`.
/// * `Perm` chooses the in-memory layout of the codes.
/// * `Len` stores the element count.
#[derive(Debug, Clone, Copy)]
pub struct BitSliceBase<const NBITS: usize, Repr, Ptr, Perm = Dense, Len = Dynamic>
where
Repr: Representation<NBITS>,
Ptr: AsPtr<Type = u8>,
Perm: PermutationStrategy<NBITS>,
Len: Length,
{
/// Backing storage. `new` checks (and `new_unchecked` requires) that it spans
/// `Perm::bytes(len)` bytes.
ptr: Ptr,
/// Number of logical elements.
len: Len,
/// Marker for the value representation; no data is stored.
repr: PhantomData<Repr>,
/// Marker for the layout strategy; no data is stored.
packing: PhantomData<Perm>,
}
impl<const NBITS: usize, Repr, Ptr, Perm, Len> BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
where
    Repr: Representation<NBITS>,
    Ptr: AsPtr<Type = u8>,
    Perm: PermutationStrategy<NBITS>,
    Len: Length,
{
    /// Build-time guard that every code fits in a single byte.
    ///
    /// Associated constants are evaluated lazily, only when referenced. This one is
    /// therefore forced from [`Self::new_unchecked_internal`]; without that reference
    /// the assertion would never run.
    const _CHECK: () = assert!(NBITS > 0 && NBITS <= 8);

    /// Number of bytes of backing storage required for `count` elements.
    pub fn bytes_for(count: usize) -> usize {
        Perm::bytes(count)
    }

    /// Assemble a slice from its parts without any checks.
    ///
    /// # Safety
    ///
    /// `ptr` must address `Perm::bytes(len.value())` readable bytes for as long as the
    /// returned value is used.
    unsafe fn new_unchecked_internal(ptr: Ptr, len: Len) -> Self {
        // Referencing the constant instantiates it for this `NBITS`, so an
        // out-of-range width becomes a build-time error instead of being ignored.
        let () = Self::_CHECK;
        Self {
            ptr,
            len,
            repr: PhantomData,
            packing: PhantomData,
        }
    }

    /// Construct a slice over `precursor` holding `count` elements, without checking
    /// the buffer length in release builds.
    ///
    /// # Safety
    ///
    /// `precursor.precursor_len()` must equal `Self::bytes_for(count)`. This is only
    /// verified with `debug_assert_eq!`.
    pub unsafe fn new_unchecked<Pre, Count>(precursor: Pre, count: Count) -> Self
    where
        Count: Into<Len>,
        Pre: Precursor<Ptr>,
    {
        let count: Len = count.into();
        debug_assert_eq!(precursor.precursor_len(), Self::bytes_for(count.value()));
        // SAFETY: the caller guarantees the precursor spans `bytes_for(count)` bytes.
        unsafe { Self::new_unchecked_internal(precursor.precursor_into(), count) }
    }

    /// Construct a slice over `precursor` holding `count` elements.
    ///
    /// # Errors
    ///
    /// Returns [`ConstructionError`] if `precursor` does not span exactly
    /// `Self::bytes_for(count)` bytes.
    pub fn new<Pre, Count>(precursor: Pre, count: Count) -> Result<Self, ConstructionError>
    where
        Count: Into<Len>,
        Pre: Precursor<Ptr>,
    {
        let count: Len = count.into();
        let expected = Self::bytes_for(count.value());
        let got = precursor.precursor_len();
        if got != expected {
            return Err(ConstructionError { got, expected });
        }
        // SAFETY: the precursor length was just checked to equal `bytes_for(count)`.
        Ok(unsafe { Self::new_unchecked(precursor, count) })
    }

    /// Number of elements in the slice.
    pub fn len(&self) -> usize {
        self.len.value()
    }

    /// `true` if the slice holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of bytes of backing storage used by this slice.
    pub fn bytes(&self) -> usize {
        Self::bytes_for(self.len())
    }

    /// Decode the element at index `i`.
    ///
    /// # Errors
    ///
    /// Returns [`GetError::IndexError`] if `i >= self.len()`.
    pub fn get(&self, i: usize) -> Result<i64, GetError> {
        if i >= self.len() {
            Err(IndexOutOfBounds::new(i, self.len()).into())
        } else {
            // SAFETY: `i < self.len()` was just checked.
            Ok(unsafe { self.get_unchecked(i) })
        }
    }

    /// Decode the element at index `i` without a bounds check in release builds.
    ///
    /// # Safety
    ///
    /// `i` must be less than `self.len()`.
    pub unsafe fn get_unchecked(&self, i: usize) -> i64 {
        debug_assert!(i < self.len());
        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
        // SAFETY: `i < len` per the caller, and `as_slice` spans `Perm::bytes(len)` bytes.
        Repr::decode(unsafe { Perm::unpack(self.as_slice(), i) })
    }

    /// Encode `value` and store it at index `i`.
    ///
    /// # Errors
    ///
    /// Returns [`SetError::IndexError`] if `i >= self.len()`, or
    /// [`SetError::EncodingError`] if `value` is outside `Repr`'s domain. The slice is
    /// left unchanged on error.
    pub fn set(&mut self, i: usize, value: i64) -> Result<(), SetError>
    where
        Ptr: AsMutPtr<Type = u8>,
    {
        if i >= self.len() {
            return Err(IndexOutOfBounds::new(i, self.len()).into());
        }
        let encoded = Repr::encode(value)?;
        // SAFETY: `i < self.len()` was checked above.
        unsafe { self.set_unchecked(i, encoded) }
        Ok(())
    }

    /// Store an already-encoded code at index `i` without checks in release builds.
    ///
    /// # Safety
    ///
    /// `i` must be less than `self.len()`. `encoded` is presumably expected to be a
    /// code produced by `Repr::encode`. NOTE(review): confirm whether every `Perm`
    /// masks stray high bits.
    pub unsafe fn set_unchecked(&mut self, i: usize, encoded: u8)
    where
        Ptr: AsMutPtr<Type = u8>,
    {
        debug_assert!(i < self.len());
        debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
        // SAFETY: `i < len` per the caller, and `as_mut_slice` spans `Perm::bytes(len)` bytes.
        unsafe { Perm::pack(self.as_mut_slice(), i, encoded) }
    }

    /// Iterator over every value `Repr` can encode.
    pub fn domain(&self) -> Repr::Domain {
        Repr::domain()
    }

    /// View the backing storage as `self.bytes()` raw bytes.
    pub(crate) fn as_slice(&self) -> &'_ [u8] {
        // SAFETY: construction requires `ptr` to address `Perm::bytes(len)` bytes,
        // which is exactly `self.bytes()`.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.bytes()) }
    }

    /// Raw pointer to the start of the backing storage.
    pub fn as_ptr(&self) -> *const u8 {
        self.ptr.as_ptr()
    }

    /// View the backing storage as `self.bytes()` mutable raw bytes.
    pub(super) fn as_mut_slice(&mut self) -> &'_ mut [u8]
    where
        Ptr: AsMutPtr,
    {
        // SAFETY: same length invariant as `as_slice`. Exclusive access presumably
        // follows from `&mut self` plus the `AsMutPtr` implementor's contract (TODO
        // confirm).
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut_ptr(), self.bytes()) }
    }

    /// Mutable raw pointer to the start of the backing storage.
    fn as_mut_ptr(&mut self) -> *mut u8
    where
        Ptr: AsMutPtr,
    {
        self.ptr.as_mut_ptr()
    }
}
impl<const NBITS: usize, Repr, Perm, Len>
    BitSliceBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Perm, Len>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    Len: Length,
{
    /// Allocate a zero-filled slice for `count` elements on the global allocator.
    pub fn new_boxed<Count>(count: Count) -> Self
    where
        Count: Into<Len>,
    {
        let count: Len = count.into();
        let bytes = Self::bytes_for(count.value());
        // `vec![0u8; n]` is the idiomatic zero-filled buffer and can use a zeroed
        // allocation, instead of collecting an iterator of zeros one byte at a time.
        let storage: Box<[u8]> = vec![0u8; bytes].into_boxed_slice();
        // SAFETY: `storage` holds exactly `bytes_for(count)` bytes.
        unsafe { Self::new_unchecked(Poly::from(storage), count) }
    }
}
impl<const NBITS: usize, Repr, Perm, Len, A> BitSliceBase<NBITS, Repr, Poly<[u8], A>, Perm, Len>
where
    Repr: Representation<NBITS>,
    Perm: PermutationStrategy<NBITS>,
    Len: Length,
    A: AllocatorCore,
{
    /// Allocate a slice for `count` elements from `allocator`, with storage filled by
    /// `Poly::broadcast(0, ..)`.
    ///
    /// # Errors
    ///
    /// Propagates the [`AllocatorError`] returned by `Poly::broadcast`.
    pub fn new_in<Count>(count: Count, allocator: A) -> Result<Self, AllocatorError>
    where
        Count: Into<Len>,
    {
        let len: Len = count.into();
        let zeroed = Poly::broadcast(0, Self::bytes_for(len.value()), allocator)?;
        // SAFETY: `zeroed` was sized with `bytes_for(len)` elements.
        let slice = unsafe { Self::new_unchecked(zeroed, len) };
        Ok(slice)
    }

    /// Consume the slice and hand back its backing storage.
    pub fn into_inner(self) -> Poly<[u8], A> {
        let Self { ptr, .. } = self;
        ptr
    }
}
/// Read-only bit slice over borrowed bytes.
pub type BitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
BitSliceBase<N, Repr, SlicePtr<'a, u8>, Perm, Len>;
/// Mutable bit slice over mutably borrowed bytes.
pub type MutBitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
BitSliceBase<N, Repr, MutSlicePtr<'a, u8>, Perm, Len>;
/// Bit slice that owns its bytes through a `Poly` allocated from `A`.
pub type PolyBitSlice<const N: usize, Repr, A, Perm = Dense, Len = Dynamic> =
BitSliceBase<N, Repr, Poly<[u8], A>, Perm, Len>;
/// Owning bit slice backed by the global allocator.
pub type BoxedBitSlice<const N: usize, Repr, Perm = Dense, Len = Dynamic> =
PolyBitSlice<N, Repr, GlobalAllocator, Perm, Len>;
impl<'a, Ptr> From<&'a BitSliceBase<8, Unsigned, Ptr>> for &'a [u8]
where
    Ptr: AsPtr<Type = u8>,
{
    /// View an 8-bit unsigned, densely packed slice as its raw bytes.
    ///
    /// With `NBITS == 8` and the `Dense` layout, each element occupies exactly one
    /// byte, so the returned slice has `slice.len()` bytes.
    fn from(slice: &'a BitSliceBase<8, Unsigned, Ptr>) -> Self {
        // Reuse the existing byte view instead of rebuilding it in a hand-written
        // `unsafe` block. `as_slice` spans `bytes()` bytes, which equals `len()` here.
        slice.as_slice()
    }
}
impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> Reborrow<'this>
for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
where
Repr: Representation<NBITS>,
Ptr: AsPtr<Type = u8>,
Perm: PermutationStrategy<NBITS>,
Len: Length,
{
type Target = BitSlice<'this, NBITS, Repr, Perm, Len>;
fn reborrow(&'this self) -> Self::Target {
let ptr: *const u8 = self.as_ptr();
debug_assert!(!ptr.is_null());
let nonnull = unsafe { NonNull::new_unchecked(ptr.cast_mut()) };
let ptr = unsafe { SlicePtr::new_unchecked(nonnull) };
Self::Target {
ptr,
len: self.len,
repr: PhantomData,
packing: PhantomData,
}
}
}
impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> ReborrowMut<'this>
for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
where
Repr: Representation<NBITS>,
Ptr: AsMutPtr<Type = u8>,
Perm: PermutationStrategy<NBITS>,
Len: Length,
{
type Target = MutBitSlice<'this, NBITS, Repr, Perm, Len>;
fn reborrow_mut(&'this mut self) -> Self::Target {
let ptr: *mut u8 = self.as_mut_ptr();
debug_assert!(!ptr.is_null());
let nonnull = unsafe { NonNull::new_unchecked(ptr) };
let ptr = unsafe { MutSlicePtr::new_unchecked(nonnull) };
Self::Target {
ptr,
len: self.len,
repr: PhantomData,
packing: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    use rand::{
        Rng, SeedableRng,
        distr::{Distribution, Uniform},
        rngs::StdRng,
        seq::{IndexedRandom, SliceRandom},
    };

    use super::*;
    use crate::{bits::Static, test_util::AlwaysFails};

    const BOUNDS: &str = "special bounds";

    #[test]
    fn test_encoding_error() {
        assert_eq!(std::mem::size_of::<EncodingError>(), 16);
        assert_eq!(
            std::mem::size_of::<Option<EncodingError>>(),
            16,
            "expected EncodingError to have the niche optimization"
        );
        let err = EncodingError::new(7, &BOUNDS);
        assert_eq!(
            err.to_string(),
            "value 7 is not in the encodable range of special bounds"
        );
    }

    fn assert_send_and_sync<T: Send + Sync>(_x: &T) {}

    #[test]
    fn test_binary_repr() {
        assert_eq!(Binary::encode(-1).unwrap(), 0);
        assert_eq!(Binary::encode(1).unwrap(), 1);
        assert_eq!(Binary::decode(0), -1);
        assert_eq!(Binary::decode(1), 1);
        assert!(Binary::check(-1));
        assert!(Binary::check(1));
        assert!(!Binary::check(0));
        assert!(!Binary::check(-2));
        assert!(!Binary::check(2));
        let domain: Vec<_> = Binary::domain().collect();
        assert_eq!(domain, &[-1, 1]);
    }

    #[test]
    fn test_sizes() {
        assert_eq!(std::mem::size_of::<BitSlice<'static, 8, Unsigned>>(), 16);
        assert_eq!(std::mem::size_of::<MutBitSlice<'static, 8, Unsigned>>(), 16);
        assert_eq!(
            std::mem::size_of::<Option<BitSlice<'static, 8, Unsigned>>>(),
            16
        );
        assert_eq!(
            std::mem::size_of::<Option<MutBitSlice<'static, 8, Unsigned>>>(),
            16
        );
        assert_eq!(
            std::mem::size_of::<BitSlice<'static, 8, Unsigned, Dense, Static<128>>>(),
            8
        );
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 160;
            const FUZZ_ITERATIONS: usize = 1;
        } else if #[cfg(debug_assertions)] {
            const MAX_DIM: usize = 128;
            const FUZZ_ITERATIONS: usize = 10;
        } else {
            const MAX_DIM: usize = 256;
            const FUZZ_ITERATIONS: usize = 100;
        }
    }

    fn test_send_and_sync<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS> + Send + Sync,
        Perm: PermutationStrategy<NBITS> + Send + Sync,
    {
        let mut x = BoxedBitSlice::<NBITS, Repr, Perm>::new_boxed(1);
        assert_send_and_sync(&x);
        assert_send_and_sync(&x.reborrow());
        assert_send_and_sync(&x.reborrow_mut());
    }

    fn test_empty<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let base: &mut [u8] = &mut [];
        let mut slice = MutBitSlice::<NBITS, Repr, Perm>::new(base, 0).unwrap();
        assert_eq!(slice.len(), 0);
        assert!(slice.is_empty());
        {
            let reborrow = slice.reborrow();
            assert_eq!(reborrow.len(), 0);
            assert!(reborrow.is_empty());
        }
        {
            let reborrow = slice.reborrow_mut();
            assert_eq!(reborrow.len(), 0);
            assert!(reborrow.is_empty());
        }
    }

    fn test_construction_errors<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let len: usize = 10;
        let bytes = Perm::bytes(len);
        let box_big = Poly::broadcast(0u8, bytes + 1, GlobalAllocator).unwrap();
        let box_small = Poly::broadcast(0u8, bytes - 1, GlobalAllocator).unwrap();
        let box_right = Poly::broadcast(0u8, bytes, GlobalAllocator).unwrap();
        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_big, len);
        match result {
            Err(ConstructionError { got, expected }) => {
                assert_eq!(got, bytes + 1);
                assert_eq!(expected, bytes);
            }
            _ => panic!("shouldn't have reached here!"),
        };
        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_small, len);
        match result {
            Err(ConstructionError { got, expected }) => {
                assert_eq!(got, bytes - 1);
                assert_eq!(expected, bytes);
            }
            _ => panic!("shouldn't have reached here!"),
        };
        let mut base = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_right, len).unwrap();
        let ptr = base.as_ptr();
        assert_eq!(base.len(), len);
        {
            let borrowed = base.reborrow_mut();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
            let borrowed = MutBitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
        }
        {
            let mut oversized = vec![0; bytes + 1];
            let result = MutBitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
            let mut undersized = vec![0; bytes - 1];
            let result = MutBitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
        }
        {
            let borrowed = base.reborrow();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
        }
        {
            let mut oversized = vec![0; bytes + 1];
            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
            let mut undersized = vec![0; bytes - 1];
            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
        }
    }

    fn run_overwrite_test<const NBITS: usize, Perm, Len, R>(
        base: &mut BoxedBitSlice<NBITS, Unsigned, Perm, Len>,
        num_iterations: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
        Len: Length,
        Perm: PermutationStrategy<NBITS>,
        R: Rng,
    {
        let mut expected: Vec<i64> = vec![0; base.len()];
        let mut indices: Vec<usize> = (0..base.len()).collect();
        for i in 0..base.len() {
            base.set(i, 0).unwrap();
        }
        for i in 0..base.len() {
            assert_eq!(base.get(i).unwrap(), 0, "failed to initialize bit vector");
        }
        let domain = base.domain();
        assert_eq!(domain, 0..=2i64.pow(NBITS as u32) - 1);
        let distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
        for iter in 0..num_iterations {
            indices.shuffle(rng);
            for &i in indices.iter() {
                let value = distribution.sample(rng);
                expected[i] = value;
                base.set(i, value).unwrap();
            }
            for (i, &expect) in expected.iter().enumerate() {
                let value = base.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "retrieval failed on iteration {iter} at index {i}"
                );
            }
            let borrowed = base.reborrow();
            for (i, &expect) in expected.iter().enumerate() {
                let value = borrowed.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "reborrow retrieval failed on iteration {iter} at index {i}"
                );
            }
        }
    }

    fn run_overwrite_binary_test<Perm, Len, R>(
        base: &mut BoxedBitSlice<1, Binary, Perm, Len>,
        num_iterations: usize,
        rng: &mut R,
    ) where
        Len: Length,
        Perm: PermutationStrategy<1>,
        R: Rng,
    {
        let mut expected: Vec<i64> = vec![0; base.len()];
        let mut indices: Vec<usize> = (0..base.len()).collect();
        for i in 0..base.len() {
            base.set(i, -1).unwrap();
        }
        for i in 0..base.len() {
            assert_eq!(base.get(i).unwrap(), -1, "failed to initialize bit vector");
        }
        let distribution: [i64; 2] = [-1, 1];
        for iter in 0..num_iterations {
            indices.shuffle(rng);
            for &i in indices.iter() {
                let value = distribution.choose(rng).unwrap();
                expected[i] = *value;
                base.set(i, *value).unwrap();
            }
            for (i, &expect) in expected.iter().enumerate() {
                let value = base.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "retrieval failed on iteration {iter} at index {i}"
                );
            }
            let borrowed = base.reborrow();
            for (i, &expect) in expected.iter().enumerate() {
                let value = borrowed.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "reborrow retrieval failed on iteration {iter} at index {i}"
                );
            }
        }
    }

    /// Length-independent checks for an unsigned, densely packed `NBITS` slice.
    ///
    /// These do not depend on the fuzzed length, so each generated test runs them once
    /// rather than once per length.
    fn test_unsigned_common<const NBITS: usize>()
    where
        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
        Dense: PermutationStrategy<NBITS>,
    {
        test_send_and_sync::<NBITS, Unsigned, Dense>();
        test_empty::<NBITS, Unsigned, Dense>();
        test_construction_errors::<NBITS, Unsigned, Dense>();
        assert_eq!(Unsigned::domain_const::<NBITS>(), Unsigned::domain(),);
    }

    fn test_unsigned_dense<const NBITS: usize, Len, R>(
        len: Len,
        minimum: i64,
        maximum: i64,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
        Dense: PermutationStrategy<NBITS>,
        Len: Length,
        R: Rng,
    {
        // `AlwaysFails` rejects every allocation, so construction must fail exactly
        // when storage is actually needed (any non-zero length).
        match PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, AlwaysFails) {
            Ok(_) => {
                if len.value() != 0 {
                    panic!("allocation should have failed");
                }
            }
            Err(AllocatorError) => {
                if len.value() == 0 {
                    panic!("zero sized allocations don't require an allocator");
                }
            }
        }
        let mut base =
            PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, GlobalAllocator).unwrap();
        assert_eq!(
            base.len(),
            len.value(),
            "BoxedBitSlice returned the incorrect length"
        );
        let expected_bytes = BitSlice::<'static, NBITS, Unsigned>::bytes_for(len.value());
        assert_eq!(
            base.bytes(),
            expected_bytes,
            "BoxedBitSlice has the incorrect number of bytes"
        );
        assert_eq!(base.domain(), minimum..=maximum);
        if len.value() == 0 {
            return;
        }
        let ptr = base.as_ptr();
        {
            let mut borrowed = base.reborrow_mut();
            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
            borrowed.set(0, 0).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 0);
            borrowed.set(0, 1).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 1);
            borrowed.set(0, 0).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 0);
            let result = borrowed.set(0, minimum - 1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));
            let result = borrowed.set(0, maximum + 1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));
            let result = borrowed.set(borrowed.len(), 0);
            assert!(matches!(result, Err(SetError::IndexError { .. })));
            let result = borrowed.get(borrowed.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
        }
        {
            let borrowed =
                MutBitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }
        {
            let borrowed = base.reborrow();
            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
            let result = borrowed.get(borrowed.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
        }
        {
            let borrowed =
                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_slice(), len).unwrap();
            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }
        {
            let borrowed =
                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }
        run_overwrite_test(&mut base, FUZZ_ITERATIONS, rng);
    }

    macro_rules! generate_unsigned_test {
        ($name:ident, $NBITS:literal, $MIN:literal, $MAX:literal, $SEED:literal) => {
            #[test]
            fn $name() {
                test_unsigned_common::<$NBITS>();
                let mut rng = StdRng::seed_from_u64($SEED);
                for dim in 0..MAX_DIM {
                    test_unsigned_dense::<$NBITS, Dynamic, _>(dim.into(), $MIN, $MAX, &mut rng);
                }
            }
        };
    }

    generate_unsigned_test!(test_unsigned_8bit, 8, 0, 0xff, 0xc652f2a1018f442b);
    generate_unsigned_test!(test_unsigned_7bit, 7, 0, 0x7f, 0xb732e59fec6d6c9c);
    generate_unsigned_test!(test_unsigned_6bit, 6, 0, 0x3f, 0x35d9380d0a318f21);
    generate_unsigned_test!(test_unsigned_5bit, 5, 0, 0x1f, 0xfb09895183334304);
    generate_unsigned_test!(test_unsigned_4bit, 4, 0, 0x0f, 0x38dfcf9e82c33f48);
    generate_unsigned_test!(test_unsigned_3bit, 3, 0, 0x07, 0xf9a94c8c749ee26c);
    generate_unsigned_test!(test_unsigned_2bit, 2, 0, 0x03, 0xbba03db62cecf4cf);
    generate_unsigned_test!(test_unsigned_1bit, 1, 0, 0x01, 0x54ea2a07d7c67f37);

    #[test]
    fn test_binary_dense() {
        // Length-independent checks: run once, not once per fuzzed length.
        test_send_and_sync::<1, Binary, Dense>();
        test_empty::<1, Binary, Dense>();
        test_construction_errors::<1, Binary, Dense>();
        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
        for len in 0..MAX_DIM {
            let mut base = BoxedBitSlice::<1, Binary>::new_boxed(len);
            assert_eq!(
                base.len(),
                len,
                "BoxedBitSlice returned the incorrect length"
            );
            assert_eq!(base.bytes(), len.div_ceil(8));
            let bytes = BitSlice::<'static, 1, Binary>::bytes_for(len);
            assert_eq!(
                bytes,
                len.div_ceil(8),
                "BoxedBitSlice has the incorrect number of bytes"
            );
            if len == 0 {
                continue;
            }
            let result = base.set(0, 0);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));
            let result = base.set(base.len(), -1);
            assert!(matches!(result, Err(SetError::IndexError { .. })));
            let result = base.get(base.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
            run_overwrite_binary_test(&mut base, FUZZ_ITERATIONS, &mut rng);
        }
    }

    #[test]
    fn test_4bit_bit_transpose() {
        // Length-independent checks: run once, not once per fuzzed length.
        test_send_and_sync::<4, Unsigned, BitTranspose>();
        test_empty::<4, Unsigned, BitTranspose>();
        test_construction_errors::<4, Unsigned, BitTranspose>();
        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
        for len in 0..MAX_DIM {
            let mut base = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(len);
            assert_eq!(
                base.len(),
                len,
                "BoxedBitSlice returned the incorrect length"
            );
            assert_eq!(base.bytes(), 32 * len.div_ceil(64));
            let bytes = BitSlice::<'static, 4, Unsigned, BitTranspose>::bytes_for(len);
            assert_eq!(
                bytes,
                32 * len.div_ceil(64),
                "BoxedBitSlice has the incorrect number of bytes"
            );
            if len == 0 {
                continue;
            }
            let result = base.set(0, -1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));
            let result = base.set(base.len(), -1);
            assert!(matches!(result, Err(SetError::IndexError { .. })));
            let result = base.get(base.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
            run_overwrite_test(&mut base, FUZZ_ITERATIONS, &mut rng);
        }
    }
}