1use std::{marker::PhantomData, ops::RangeInclusive, ptr::NonNull};
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use super::{
12 length::{Dynamic, Length},
13 packing,
14 ptr::{AsMutPtr, AsPtr, MutSlicePtr, Precursor, SlicePtr},
15};
16use crate::{
17 alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
18 utils,
19};
20
21pub trait Representation<const NBITS: usize> {
27 type Domain: Iterator<Item = i64>;
29
30 fn encode(value: i64) -> Result<u8, EncodingError>;
33
34 fn encode_unchecked(value: i64) -> u8;
41
42 fn decode(raw: u8) -> i64;
49
50 fn check(value: i64) -> bool;
52
53 fn domain() -> Self::Domain;
55}
56
57#[derive(Debug, Error, Clone, Copy)]
58#[error("value {} is not in the encodable range of {}", got, domain)]
59pub struct EncodingError {
60 got: i64,
61 domain: &'static &'static str,
67}
68
69impl EncodingError {
70 fn new(got: i64, domain: &'static &'static str) -> Self {
71 Self { got, domain }
72 }
73}
74
/// Plain unsigned representation: an `NBITS`-wide code stores the integers
/// `0 ..= 2^NBITS - 1` unchanged.
#[derive(Debug, Clone, Copy)]
pub struct Unsigned;

impl Unsigned {
    /// The inclusive range of values that fit in `NBITS` unsigned bits.
    pub const fn domain_const<const NBITS: usize>() -> std::ops::RangeInclusive<i64> {
        let largest = 2i64.pow(NBITS as u32) - 1;
        0..=largest
    }

    /// Static text describing the domain for `nbits`, used in [`EncodingError`] messages.
    ///
    /// # Panics
    ///
    /// Panics for any `nbits` outside `1..=8`.
    #[allow(clippy::panic)]
    const fn domain_str(nbits: usize) -> &'static &'static str {
        match nbits {
            1 => &"[0, 1]",
            2 => &"[0, 3]",
            3 => &"[0, 7]",
            4 => &"[0, 15]",
            5 => &"[0, 31]",
            6 => &"[0, 63]",
            7 => &"[0, 127]",
            8 => &"[0, 255]",
            _ => panic!("unimplemented"),
        }
    }
}
107
108macro_rules! repr_unsigned {
109 ($N:literal) => {
110 impl Representation<$N> for Unsigned {
111 type Domain = RangeInclusive<i64>;
112
113 fn encode(value: i64) -> Result<u8, EncodingError> {
114 if !<Self as Representation<$N>>::check(value) {
115 let domain = Self::domain_str($N);
118 Err(EncodingError::new(value, domain))
119 } else {
120 Ok(<Self as Representation<$N>>::encode_unchecked(value))
121 }
122 }
123
124 fn encode_unchecked(value: i64) -> u8 {
125 debug_assert!(<Self as Representation<$N>>::check(value));
126 value as u8
127 }
128
129 fn decode(raw: u8) -> i64 {
130 let raw: i64 = raw.into();
132 debug_assert!(<Self as Representation<$N>>::check(raw));
133 raw
134 }
135
136 fn check(value: i64) -> bool {
137 <Self as Representation<$N>>::domain().contains(&value)
138 }
139
140 fn domain() -> Self::Domain {
141 Self::domain_const::<$N>()
142 }
143 }
144 };
145 ($N:literal, $($Ns:literal),+) => {
146 repr_unsigned!($N);
147 $(repr_unsigned!($Ns);)+
148 };
149}
150
151repr_unsigned!(1, 2, 3, 4, 5, 6, 7, 8);
152
153#[derive(Debug, Clone, Copy)]
159pub struct Binary;
160
161impl Representation<1> for Binary {
162 type Domain = std::array::IntoIter<i64, 2>;
163
164 fn encode(value: i64) -> Result<u8, EncodingError> {
165 if !Self::check(value) {
166 const DOMAIN: &str = "{-1, 1}";
167 Err(EncodingError::new(value, &DOMAIN))
168 } else {
169 Ok(Self::encode_unchecked(value))
170 }
171 }
172
173 fn encode_unchecked(value: i64) -> u8 {
174 debug_assert!(Self::check(value));
175 value.clamp(0, 1) as u8
177 }
178
179 fn decode(raw: u8) -> i64 {
180 let raw: i64 = raw.into();
183 (raw << 1) - 1
184 }
185
186 fn check(value: i64) -> bool {
187 value == -1 || value == 1
188 }
189
190 fn domain() -> Self::Domain {
194 [-1, 1].into_iter()
195 }
196}
197
/// Describes how the `NBITS`-wide codes of a bit slice are arranged within its bytes.
///
/// # Safety
///
/// `BitSliceBase` calls `pack` / `unpack` with a buffer of exactly `Self::bytes(len)` bytes
/// and an index `i < len`. Implementations must never access memory outside that buffer
/// for such inputs; unchecked implementations (such as [`Dense`]) rely on this to avoid
/// undefined behavior.
pub unsafe trait PermutationStrategy<const NBITS: usize> {
    /// Number of bytes required to hold `count` values.
    fn bytes(count: usize) -> usize;

    /// Read the raw code stored for logical index `i`.
    ///
    /// # Safety
    ///
    /// `s` must be a buffer of `Self::bytes(count)` bytes for some `count > i`.
    unsafe fn unpack(s: &[u8], i: usize) -> u8;

    /// Store the raw code `value` for logical index `i`, leaving other indices untouched.
    ///
    /// # Safety
    ///
    /// `s` must be a buffer of `Self::bytes(count)` bytes for some `count > i`.
    unsafe fn pack(s: &mut [u8], i: usize, value: u8);
}
242
243#[derive(Debug, Clone, Copy)]
247pub struct Dense;
248
249impl Dense {
250 fn bytes<const NBITS: usize>(count: usize) -> usize {
251 utils::div_round_up(NBITS * count, 8)
252 }
253}
254
255unsafe impl<const NBITS: usize> PermutationStrategy<NBITS> for Dense {
257 fn bytes(count: usize) -> usize {
258 Self::bytes::<NBITS>(count)
259 }
260
261 unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
262 let bitaddress = NBITS * i;
263
264 let bytestart = bitaddress / 8;
265 let bytestop = (bitaddress + NBITS - 1) / 8;
266 let bitstart = bitaddress - 8 * bytestart;
267 debug_assert!(bytestop < data.len());
268
269 if bytestart == bytestop {
270 let raw = unsafe { data.as_ptr().add(bytestart).read() };
279 let packed = packing::pack_u8::<NBITS>(raw, encoded, bitstart);
280
281 unsafe { data.as_mut_ptr().add(bytestart).write(packed) };
285 } else {
286 let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
294 let packed = packing::pack_u16::<NBITS>(raw, encoded, bitstart);
295
296 unsafe {
300 data.as_mut_ptr()
301 .add(bytestart)
302 .cast::<u16>()
303 .write_unaligned(packed)
304 };
305 }
306 }
307
308 unsafe fn unpack(data: &[u8], i: usize) -> u8 {
309 let bitaddress = NBITS * i;
310
311 let bytestart = bitaddress / 8;
312 let bytestop = (bitaddress + NBITS - 1) / 8;
313 debug_assert!(bytestop < data.len());
314 if bytestart == bytestop {
315 let raw = unsafe { data.as_ptr().add(bytestart).read() };
317 packing::unpack_u8::<NBITS>(raw, bitaddress - 8 * bytestart)
318 } else {
319 let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
321 packing::unpack_u16::<NBITS>(raw, bitaddress - 8 * bytestart)
322 }
323 }
324}
325
326#[derive(Debug, Clone, Copy)]
340pub struct BitTranspose;
341
342unsafe impl PermutationStrategy<4> for BitTranspose {
345 fn bytes(count: usize) -> usize {
346 32 * utils::div_round_up(count, 64)
347 }
348
349 unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
350 let block_start = 32 * (i / 64);
352 let byte_start = block_start + (i % 64) / 8;
354 let bit = i % 8;
356
357 let mask: u8 = 0x1 << bit;
358 for p in 0..4 {
359 let mut v = data[byte_start + 8 * p];
360 v = (v & !mask) | (((encoded >> p) & 0x1) << bit);
361 data[byte_start + 8 * p] = v;
362 }
363 }
364
365 unsafe fn unpack(data: &[u8], i: usize) -> u8 {
366 let block_start = 32 * (i / 64);
368 let byte_start = block_start + (i % 64) / 8;
370 let bit = i % 8;
372
373 let mut output: u8 = 0;
374 for p in 0..4 {
375 let v = data[byte_start + 8 * p];
376 output |= ((v >> bit) & 0x1) << p
377 }
378 output
379 }
380}
381
382#[derive(Debug, Error, Clone, Copy)]
387#[error("input span has length {got} bytes but expected {expected}")]
388pub struct ConstructionError {
389 got: usize,
390 expected: usize,
391}
392
393#[derive(Debug, Error, Clone, Copy)]
394#[error("index {index} exceeds the maximum length of {len}")]
395pub struct IndexOutOfBounds {
396 index: usize,
397 len: usize,
398}
399
400impl IndexOutOfBounds {
401 fn new(index: usize, len: usize) -> Self {
402 Self { index, len }
403 }
404}
405
406#[derive(Debug, Error, Clone, Copy)]
407#[error("error setting index in bitslice")]
408#[non_exhaustive]
409pub enum SetError {
410 IndexError(#[from] IndexOutOfBounds),
411 EncodingError(#[from] EncodingError),
412}
413
414#[derive(Debug, Error, Clone, Copy)]
415#[error("error getting index in bitslice")]
416pub enum GetError {
417 IndexError(#[from] IndexOutOfBounds),
418}
419
420#[derive(Debug, Clone, Copy)]
518pub struct BitSliceBase<const NBITS: usize, Repr, Ptr, Perm = Dense, Len = Dynamic>
519where
520 Repr: Representation<NBITS>,
521 Ptr: AsPtr<Type = u8>,
522 Perm: PermutationStrategy<NBITS>,
523 Len: Length,
524{
525 ptr: Ptr,
526 len: Len,
527 repr: PhantomData<Repr>,
528 packing: PhantomData<Perm>,
529}
530
531impl<const NBITS: usize, Repr, Ptr, Perm, Len> BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
532where
533 Repr: Representation<NBITS>,
534 Ptr: AsPtr<Type = u8>,
535 Perm: PermutationStrategy<NBITS>,
536 Len: Length,
537{
538 const _CHECK: () = assert!(NBITS > 0 && NBITS <= 8);
540
541 pub fn bytes_for(count: usize) -> usize {
543 Perm::bytes(count)
544 }
545
546 unsafe fn new_unchecked_internal(ptr: Ptr, len: Len) -> Self {
553 Self {
554 ptr,
555 len,
556 repr: PhantomData,
557 packing: PhantomData,
558 }
559 }
560
561 pub unsafe fn new_unchecked<Pre, Count>(precursor: Pre, count: Count) -> Self
571 where
572 Count: Into<Len>,
573 Pre: Precursor<Ptr>,
574 {
575 let count: Len = count.into();
576 debug_assert_eq!(precursor.precursor_len(), Self::bytes_for(count.value()));
577
578 unsafe { Self::new_unchecked_internal(precursor.precursor_into(), count) }
580 }
581
582 pub fn new<Pre, Count>(precursor: Pre, count: Count) -> Result<Self, ConstructionError>
592 where
593 Count: Into<Len>,
594 Pre: Precursor<Ptr>,
595 {
596 let count: Len = count.into();
598
599 if precursor.precursor_len() != Self::bytes_for(count.value()) {
601 Err(ConstructionError {
602 got: precursor.precursor_len(),
603 expected: Self::bytes_for(count.value()),
604 })
605 } else {
606 Ok(unsafe { Self::new_unchecked(precursor, count) })
611 }
612 }
613
614 pub fn len(&self) -> usize {
616 self.len.value()
617 }
618
619 pub fn is_empty(&self) -> bool {
621 self.len() == 0
622 }
623
624 pub fn bytes(&self) -> usize {
626 Self::bytes_for(self.len())
627 }
628
629 pub fn get(&self, i: usize) -> Result<i64, GetError> {
631 if i >= self.len() {
632 Err(IndexOutOfBounds::new(i, self.len()).into())
633 } else {
634 Ok(unsafe { self.get_unchecked(i) })
636 }
637 }
638
639 pub unsafe fn get_unchecked(&self, i: usize) -> i64 {
645 debug_assert!(i < self.len());
646 debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
647
648 Repr::decode(unsafe { Perm::unpack(self.as_slice(), i) })
653 }
654
655 pub fn set(&mut self, i: usize, value: i64) -> Result<(), SetError>
657 where
658 Ptr: AsMutPtr<Type = u8>,
659 {
660 if i >= self.len() {
661 return Err(IndexOutOfBounds::new(i, self.len()).into());
662 }
663
664 let encoded = Repr::encode(value)?;
665
666 unsafe { self.set_unchecked(i, encoded) }
668 Ok(())
669 }
670
671 pub unsafe fn set_unchecked(&mut self, i: usize, encoded: u8)
677 where
678 Ptr: AsMutPtr<Type = u8>,
679 {
680 debug_assert!(i < self.len());
681 debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
682
683 unsafe { Perm::pack(self.as_mut_slice(), i, encoded) }
688 }
689
690 pub fn domain(&self) -> Repr::Domain {
692 Repr::domain()
693 }
694
695 pub(crate) fn as_slice(&self) -> &'_ [u8] {
696 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.bytes()) }
700 }
701
702 pub fn as_ptr(&self) -> *const u8 {
709 self.ptr.as_ptr()
710 }
711
712 pub(super) fn as_mut_slice(&mut self) -> &'_ mut [u8]
714 where
715 Ptr: AsMutPtr,
716 {
717 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut_ptr(), self.bytes()) }
724 }
725
726 fn as_mut_ptr(&mut self) -> *mut u8
728 where
729 Ptr: AsMutPtr,
730 {
731 self.ptr.as_mut_ptr()
732 }
733}
734
735impl<const NBITS: usize, Repr, Perm, Len>
736 BitSliceBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Perm, Len>
737where
738 Repr: Representation<NBITS>,
739 Perm: PermutationStrategy<NBITS>,
740 Len: Length,
741{
742 pub fn new_boxed<Count>(count: Count) -> Self
761 where
762 Count: Into<Len>,
763 {
764 let count: Len = count.into();
765 let bytes = Self::bytes_for(count.value());
766 let storage: Box<[u8]> = (0..bytes).map(|_| 0).collect();
767
768 unsafe { Self::new_unchecked(Poly::from(storage), count) }
773 }
774}
775
776impl<const NBITS: usize, Repr, Perm, Len, A> BitSliceBase<NBITS, Repr, Poly<[u8], A>, Perm, Len>
777where
778 Repr: Representation<NBITS>,
779 Perm: PermutationStrategy<NBITS>,
780 Len: Length,
781 A: AllocatorCore,
782{
783 pub fn new_in<Count>(count: Count, allocator: A) -> Result<Self, AllocatorError>
807 where
808 Count: Into<Len>,
809 {
810 let count: Len = count.into();
811 let bytes = Self::bytes_for(count.value());
812 let storage = Poly::broadcast(0, bytes, allocator)?;
813
814 Ok(unsafe { Self::new_unchecked(storage, count) })
819 }
820
821 pub fn into_inner(self) -> Poly<[u8], A> {
823 self.ptr
824 }
825}
826
827pub type BitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
829 BitSliceBase<N, Repr, SlicePtr<'a, u8>, Perm, Len>;
830
831pub type MutBitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
833 BitSliceBase<N, Repr, MutSlicePtr<'a, u8>, Perm, Len>;
834
835pub type PolyBitSlice<const N: usize, Repr, A, Perm = Dense, Len = Dynamic> =
837 BitSliceBase<N, Repr, Poly<[u8], A>, Perm, Len>;
838
839pub type BoxedBitSlice<const N: usize, Repr, Perm = Dense, Len = Dynamic> =
841 PolyBitSlice<N, Repr, GlobalAllocator, Perm, Len>;
842
843impl<'a, Ptr> From<&'a BitSliceBase<8, Unsigned, Ptr>> for &'a [u8]
848where
849 Ptr: AsPtr<Type = u8>,
850{
851 fn from(slice: &'a BitSliceBase<8, Unsigned, Ptr>) -> Self {
852 unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) }
858 }
859}
860
861impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> Reborrow<'this>
862 for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
863where
864 Repr: Representation<NBITS>,
865 Ptr: AsPtr<Type = u8>,
866 Perm: PermutationStrategy<NBITS>,
867 Len: Length,
868{
869 type Target = BitSlice<'this, NBITS, Repr, Perm, Len>;
870
871 fn reborrow(&'this self) -> Self::Target {
872 let ptr: *const u8 = self.as_ptr();
873 debug_assert!(!ptr.is_null());
874
875 let nonnull = unsafe { NonNull::new_unchecked(ptr.cast_mut()) };
879
880 let ptr = unsafe { SlicePtr::new_unchecked(nonnull) };
887
888 Self::Target {
889 ptr,
890 len: self.len,
891 repr: PhantomData,
892 packing: PhantomData,
893 }
894 }
895}
896
897impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> ReborrowMut<'this>
898 for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
899where
900 Repr: Representation<NBITS>,
901 Ptr: AsMutPtr<Type = u8>,
902 Perm: PermutationStrategy<NBITS>,
903 Len: Length,
904{
905 type Target = MutBitSlice<'this, NBITS, Repr, Perm, Len>;
906
907 fn reborrow_mut(&'this mut self) -> Self::Target {
908 let ptr: *mut u8 = self.as_mut_ptr();
909 debug_assert!(!ptr.is_null());
910
911 let nonnull = unsafe { NonNull::new_unchecked(ptr) };
913
914 let ptr = unsafe { MutSlicePtr::new_unchecked(nonnull) };
924
925 Self::Target {
926 ptr,
927 len: self.len,
928 repr: PhantomData,
929 packing: PhantomData,
930 }
931 }
932}
933
934#[cfg(test)]
939mod tests {
940 use rand::{
941 Rng, SeedableRng,
942 distr::{Distribution, Uniform},
943 rngs::StdRng,
944 seq::{IndexedRandom, SliceRandom},
945 };
946
947 use super::*;
948 use crate::{bits::Static, test_util::AlwaysFails};
949
const BOUNDS: &str = "special bounds";

#[test]
fn test_encoding_error() {
    let plain = std::mem::size_of::<EncodingError>();
    let optional = std::mem::size_of::<Option<EncodingError>>();
    assert_eq!(plain, 16);
    assert_eq!(
        optional, 16,
        "expected EncodingError to have the niche optimization"
    );

    let message = EncodingError::new(7, &BOUNDS).to_string();
    assert_eq!(
        message,
        "value 7 is not in the encodable range of special bounds"
    );
}

/// Compile-time check that `T` is both `Send` and `Sync`.
fn assert_send_and_sync<T>(_: &T)
where
    T: Send + Sync,
{
}

#[test]
fn test_binary_repr() {
    // Round trip between each logical value and its raw code.
    for (value, code) in [(-1, 0), (1, 1)] {
        assert_eq!(Binary::encode(value).unwrap(), code);
        assert_eq!(Binary::decode(code), value);
    }

    for accepted in [-1, 1] {
        assert!(Binary::check(accepted));
    }
    for rejected in [0, -2, 2] {
        assert!(!Binary::check(rejected));
    }

    let domain = Binary::domain().collect::<Vec<_>>();
    assert_eq!(domain, &[-1, 1]);
}

#[test]
fn test_sizes() {
    fn bytes_of<T>() -> usize {
        std::mem::size_of::<T>()
    }

    assert_eq!(bytes_of::<BitSlice<'static, 8, Unsigned>>(), 16);
    assert_eq!(bytes_of::<MutBitSlice<'static, 8, Unsigned>>(), 16);

    // Wrapping in `Option` should not grow the slice types (niche in the pointer).
    assert_eq!(bytes_of::<Option<BitSlice<'static, 8, Unsigned>>>(), 16);
    assert_eq!(bytes_of::<Option<MutBitSlice<'static, 8, Unsigned>>>(), 16);

    // With a statically known length only the pointer remains.
    assert_eq!(
        bytes_of::<BitSlice<'static, 8, Unsigned, Dense, Static<128>>>(),
        8
    );
}
1019
// Fuzzing effort: smallest under Miri, reduced in unoptimized (debug-assertion) builds,
// largest otherwise.
const MAX_DIM: usize = if cfg!(miri) {
    160
} else if cfg!(debug_assertions) {
    128
} else {
    256
};

const FUZZ_ITERATIONS: usize = if cfg!(miri) {
    1
} else if cfg!(debug_assertions) {
    10
} else {
    100
};
1036
1037 fn test_send_and_sync<const NBITS: usize, Repr, Perm>()
1038 where
1039 Repr: Representation<NBITS> + Send + Sync,
1040 Perm: PermutationStrategy<NBITS> + Send + Sync,
1041 {
1042 let mut x = BoxedBitSlice::<NBITS, Repr, Perm>::new_boxed(1);
1043 assert_send_and_sync(&x);
1044 assert_send_and_sync(&x.reborrow());
1045 assert_send_and_sync(&x.reborrow_mut());
1046 }
1047
1048 fn test_empty<const NBITS: usize, Repr, Perm>()
1049 where
1050 Repr: Representation<NBITS>,
1051 Perm: PermutationStrategy<NBITS>,
1052 {
1053 let base: &mut [u8] = &mut [];
1054 let mut slice = MutBitSlice::<NBITS, Repr, Perm>::new(base, 0).unwrap();
1055 assert_eq!(slice.len(), 0);
1056 assert!(slice.is_empty());
1057
1058 {
1059 let reborrow = slice.reborrow();
1060 assert_eq!(reborrow.len(), 0);
1061 assert!(reborrow.is_empty());
1062 }
1063
1064 {
1065 let reborrow = slice.reborrow_mut();
1066 assert_eq!(reborrow.len(), 0);
1067 assert!(reborrow.is_empty());
1068 }
1069 }
1070
1071 fn test_construction_errors<const NBITS: usize, Repr, Perm>()
1073 where
1074 Repr: Representation<NBITS>,
1075 Perm: PermutationStrategy<NBITS>,
1076 {
1077 let len: usize = 10;
1078 let bytes = Perm::bytes(len);
1079
1080 let box_big = Poly::broadcast(0u8, bytes + 1, GlobalAllocator).unwrap();
1082 let box_small = Poly::broadcast(0u8, bytes - 1, GlobalAllocator).unwrap();
1083 let box_right = Poly::broadcast(0u8, bytes, GlobalAllocator).unwrap();
1084
1085 let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_big, len);
1086 match result {
1087 Err(ConstructionError { got, expected }) => {
1088 assert_eq!(got, bytes + 1);
1089 assert_eq!(expected, bytes);
1090 }
1091 _ => panic!("shouldn't have reached here!"),
1092 };
1093
1094 let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_small, len);
1095 match result {
1096 Err(ConstructionError { got, expected }) => {
1097 assert_eq!(got, bytes - 1);
1098 assert_eq!(expected, bytes);
1099 }
1100 _ => panic!("shouldn't have reached here!"),
1101 };
1102
1103 let mut base = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_right, len).unwrap();
1104 let ptr = base.as_ptr();
1105 assert_eq!(base.len(), len);
1106
1107 {
1109 let borrowed = base.reborrow_mut();
1111 assert_eq!(borrowed.as_ptr(), ptr);
1112 assert_eq!(borrowed.len(), len);
1113
1114 let borrowed = MutBitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1116 assert_eq!(borrowed.as_ptr(), ptr);
1117 assert_eq!(borrowed.len(), len);
1118 }
1119
1120 {
1122 let mut oversized = vec![0; bytes + 1];
1124 let result = MutBitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1125 match result {
1126 Err(ConstructionError { got, expected }) => {
1127 assert_eq!(got, bytes + 1);
1128 assert_eq!(expected, bytes);
1129 }
1130 _ => panic!("shouldn't have reached here!"),
1131 };
1132
1133 let mut undersized = vec![0; bytes - 1];
1134 let result = MutBitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1135 match result {
1136 Err(ConstructionError { got, expected }) => {
1137 assert_eq!(got, bytes - 1);
1138 assert_eq!(expected, bytes);
1139 }
1140 _ => panic!("shouldn't have reached here!"),
1141 };
1142 }
1143
1144 {
1146 let borrowed = base.reborrow();
1148 assert_eq!(borrowed.as_ptr(), ptr);
1149 assert_eq!(borrowed.len(), len);
1150
1151 let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_slice(), len).unwrap();
1153 assert_eq!(borrowed.as_ptr(), ptr);
1154 assert_eq!(borrowed.len(), len);
1155
1156 let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
1158 assert_eq!(borrowed.as_ptr(), ptr);
1159 assert_eq!(borrowed.len(), len);
1160 }
1161
1162 {
1164 let mut oversized = vec![0; bytes + 1];
1166 let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
1167 match result {
1168 Err(ConstructionError { got, expected }) => {
1169 assert_eq!(got, bytes + 1);
1170 assert_eq!(expected, bytes);
1171 }
1172 _ => panic!("shouldn't have reached here!"),
1173 };
1174
1175 let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_slice(), len);
1176 match result {
1177 Err(ConstructionError { got, expected }) => {
1178 assert_eq!(got, bytes + 1);
1179 assert_eq!(expected, bytes);
1180 }
1181 _ => panic!("shouldn't have reached here!"),
1182 };
1183
1184 let mut undersized = vec![0; bytes - 1];
1186 let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
1187 match result {
1188 Err(ConstructionError { got, expected }) => {
1189 assert_eq!(got, bytes - 1);
1190 assert_eq!(expected, bytes);
1191 }
1192 _ => panic!("shouldn't have reached here!"),
1193 };
1194
1195 let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_slice(), len);
1196 match result {
1197 Err(ConstructionError { got, expected }) => {
1198 assert_eq!(got, bytes - 1);
1199 assert_eq!(expected, bytes);
1200 }
1201 _ => panic!("shouldn't have reached here!"),
1202 };
1203 }
1204 }
1205
1206 fn run_overwrite_test<const NBITS: usize, Perm, Len, R>(
1209 base: &mut BoxedBitSlice<NBITS, Unsigned, Perm, Len>,
1210 num_iterations: usize,
1211 rng: &mut R,
1212 ) where
1213 Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1214 Len: Length,
1215 Perm: PermutationStrategy<NBITS>,
1216 R: Rng,
1217 {
1218 let mut expected: Vec<i64> = vec![0; base.len()];
1219 let mut indices: Vec<usize> = (0..base.len()).collect();
1220 for i in 0..base.len() {
1221 base.set(i, 0).unwrap();
1222 }
1223
1224 for i in 0..base.len() {
1225 assert_eq!(base.get(i).unwrap(), 0, "failed to initialize bit vector");
1226 }
1227
1228 let domain = base.domain();
1229 assert_eq!(domain, 0..=2i64.pow(NBITS as u32) - 1);
1230 let distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();
1231
1232 for iter in 0..num_iterations {
1233 indices.shuffle(rng);
1235
1236 for &i in indices.iter() {
1238 let value = distribution.sample(rng);
1239 expected[i] = value;
1240 base.set(i, value).unwrap();
1241 }
1242
1243 for (i, &expect) in expected.iter().enumerate() {
1245 let value = base.get(i).unwrap();
1246 assert_eq!(
1247 value, expect,
1248 "retrieval failed on iteration {iter} at index {i}"
1249 );
1250 }
1251
1252 let borrowed = base.reborrow();
1254 for (i, &expect) in expected.iter().enumerate() {
1255 let value = borrowed.get(i).unwrap();
1256 assert_eq!(
1257 value, expect,
1258 "reborrow retrieval failed on iteration {iter} at index {i}"
1259 );
1260 }
1261 }
1262 }
1263
1264 fn run_overwrite_binary_test<Perm, Len, R>(
1265 base: &mut BoxedBitSlice<1, Binary, Perm, Len>,
1266 num_iterations: usize,
1267 rng: &mut R,
1268 ) where
1269 Len: Length,
1270 Perm: PermutationStrategy<1>,
1271 R: Rng,
1272 {
1273 let mut expected: Vec<i64> = vec![0; base.len()];
1274 let mut indices: Vec<usize> = (0..base.len()).collect();
1275 for i in 0..base.len() {
1276 base.set(i, -1).unwrap();
1277 }
1278
1279 for i in 0..base.len() {
1280 assert_eq!(base.get(i).unwrap(), -1, "failed to initialize bit vector");
1281 }
1282
1283 let distribution: [i64; 2] = [-1, 1];
1284
1285 for iter in 0..num_iterations {
1286 indices.shuffle(rng);
1288
1289 for &i in indices.iter() {
1291 let value = distribution.choose(rng).unwrap();
1292 expected[i] = *value;
1293 base.set(i, *value).unwrap();
1294 }
1295
1296 for (i, &expect) in expected.iter().enumerate() {
1298 let value = base.get(i).unwrap();
1299 assert_eq!(
1300 value, expect,
1301 "retrieval failed on iteration {iter} at index {i}"
1302 );
1303 }
1304
1305 let borrowed = base.reborrow();
1307 for (i, &expect) in expected.iter().enumerate() {
1308 let value = borrowed.get(i).unwrap();
1309 assert_eq!(
1310 value, expect,
1311 "reborrow retrieval failed on iteration {iter} at index {i}"
1312 );
1313 }
1314 }
1315 }
1316
1317 fn test_unsigned_dense<const NBITS: usize, Len, R>(
1322 len: Len,
1323 minimum: i64,
1324 maximum: i64,
1325 rng: &mut R,
1326 ) where
1327 Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
1328 Dense: PermutationStrategy<NBITS>,
1329 Len: Length,
1330 R: Rng,
1331 {
1332 test_send_and_sync::<NBITS, Unsigned, Dense>();
1333 test_empty::<NBITS, Unsigned, Dense>();
1334 test_construction_errors::<NBITS, Unsigned, Dense>();
1335 assert_eq!(Unsigned::domain_const::<NBITS>(), Unsigned::domain(),);
1336
1337 match PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, AlwaysFails) {
1338 Ok(_) => {
1339 if len.value() != 0 {
1340 panic!("zero sized allocations don't require an allocator");
1341 }
1342 }
1343 Err(AllocatorError) => {
1344 if len.value() == 0 {
1345 panic!("allocation should have failed");
1346 }
1347 }
1348 }
1349
1350 let mut base =
1351 PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, GlobalAllocator).unwrap();
1352 assert_eq!(
1353 base.len(),
1354 len.value(),
1355 "BoxedBitSlice returned the incorrect length"
1356 );
1357
1358 let expected_bytes = BitSlice::<'static, NBITS, Unsigned>::bytes_for(len.value());
1359 assert_eq!(
1360 base.bytes(),
1361 expected_bytes,
1362 "BoxedBitSlice has the incorrect number of bytes"
1363 );
1364
1365 assert_eq!(base.domain(), minimum..=maximum);
1367
1368 if len.value() == 0 {
1369 return;
1370 }
1371
1372 let ptr = base.as_ptr();
1373
1374 {
1377 let mut borrowed = base.reborrow_mut();
1378
1379 assert_eq!(
1381 borrowed.as_ptr(),
1382 ptr,
1383 "pointer was not preserved during borrowing!"
1384 );
1385 assert_eq!(
1386 borrowed.len(),
1387 len.value(),
1388 "borrowing did not preserve length!"
1389 );
1390
1391 borrowed.set(0, 0).unwrap();
1392 assert_eq!(borrowed.get(0).unwrap(), 0);
1393
1394 borrowed.set(0, 1).unwrap();
1395 assert_eq!(borrowed.get(0).unwrap(), 1);
1396
1397 borrowed.set(0, 0).unwrap();
1398 assert_eq!(borrowed.get(0).unwrap(), 0);
1399
1400 let result = borrowed.set(0, minimum - 1);
1402 assert!(matches!(result, Err(SetError::EncodingError { .. })));
1403
1404 let result = borrowed.set(0, maximum + 1);
1405 assert!(matches!(result, Err(SetError::EncodingError { .. })));
1406
1407 let result = borrowed.set(borrowed.len(), 0);
1409 assert!(matches!(result, Err(SetError::IndexError { .. })));
1410
1411 let result = borrowed.get(borrowed.len());
1413 assert!(matches!(result, Err(GetError::IndexError { .. })));
1414 }
1415
1416 {
1417 let borrowed =
1419 MutBitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1420
1421 assert_eq!(
1422 borrowed.as_ptr(),
1423 ptr,
1424 "pointer was not preserved during borrowing!"
1425 );
1426 assert_eq!(
1427 borrowed.len(),
1428 len.value(),
1429 "borrowing did not preserve length!"
1430 );
1431 }
1432
1433 {
1434 let borrowed = base.reborrow();
1435
1436 assert_eq!(
1438 borrowed.as_ptr(),
1439 ptr,
1440 "pointer was not preserved during borrowing!"
1441 );
1442
1443 assert_eq!(
1444 borrowed.len(),
1445 len.value(),
1446 "borrowing did not preserve length!"
1447 );
1448
1449 let result = borrowed.get(borrowed.len());
1451 assert!(matches!(result, Err(GetError::IndexError { .. })));
1452 }
1453
1454 {
1455 let borrowed =
1457 BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_slice(), len).unwrap();
1458
1459 assert_eq!(
1460 borrowed.as_ptr(),
1461 ptr,
1462 "pointer was not preserved during borrowing!"
1463 );
1464 assert_eq!(
1465 borrowed.len(),
1466 len.value(),
1467 "borrowing did not preserve length!"
1468 );
1469 }
1470
1471 {
1472 let borrowed =
1474 BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();
1475
1476 assert_eq!(
1477 borrowed.as_ptr(),
1478 ptr,
1479 "pointer was not preserved during borrowing!"
1480 );
1481 assert_eq!(
1482 borrowed.len(),
1483 len.value(),
1484 "borrowing did not preserve length!"
1485 );
1486 }
1487
1488 run_overwrite_test(&mut base, FUZZ_ITERATIONS, rng);
1490 }
1491
/// Define a `#[test]` named `$name` that runs [`test_unsigned_dense`] for every length in
/// `0..MAX_DIM`, seeding its RNG with `$SEED`.
macro_rules! generate_unsigned_test {
    ($name:ident, $NBITS:literal, $MIN:literal, $MAX:literal, $SEED:literal) => {
        #[test]
        fn $name() {
            let mut rng = StdRng::seed_from_u64($SEED);
            (0..MAX_DIM).for_each(|dim| {
                test_unsigned_dense::<$NBITS, Dynamic, _>(dim.into(), $MIN, $MAX, &mut rng)
            });
        }
    };
}

generate_unsigned_test!(test_unsigned_1bit, 1, 0, 0x01, 0x54ea2a07d7c67f37);
generate_unsigned_test!(test_unsigned_2bit, 2, 0, 0x03, 0xbba03db62cecf4cf);
generate_unsigned_test!(test_unsigned_3bit, 3, 0, 0x07, 0xf9a94c8c749ee26c);
generate_unsigned_test!(test_unsigned_4bit, 4, 0, 0x0f, 0x38dfcf9e82c33f48);
generate_unsigned_test!(test_unsigned_5bit, 5, 0, 0x1f, 0xfb09895183334304);
generate_unsigned_test!(test_unsigned_6bit, 6, 0, 0x3f, 0x35d9380d0a318f21);
generate_unsigned_test!(test_unsigned_7bit, 7, 0, 0x7f, 0xb732e59fec6d6c9c);
generate_unsigned_test!(test_unsigned_8bit, 8, 0, 0xff, 0xc652f2a1018f442b);
1512
#[test]
fn test_binary_dense() {
    // These checks do not depend on the slice length, so run them once instead of once
    // per length. They do not touch the RNG, so the fuzz sequence below is unchanged.
    test_send_and_sync::<1, Binary, Dense>();
    test_empty::<1, Binary, Dense>();
    test_construction_errors::<1, Binary, Dense>();

    let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
    for len in 0..MAX_DIM {
        let mut base = BoxedBitSlice::<1, Binary>::new_boxed(len);
        assert_eq!(
            base.len(),
            len,
            "BoxedBitSlice returned the incorrect length"
        );

        assert_eq!(base.bytes(), len.div_ceil(8));

        let bytes = BitSlice::<'static, 1, Binary>::bytes_for(len);
        assert_eq!(
            bytes,
            len.div_ceil(8),
            "BoxedBitSlice has the incorrect number of bytes"
        );

        if len == 0 {
            continue;
        }

        // `0` is not part of the binary domain `{-1, 1}`.
        let result = base.set(0, 0);
        assert!(matches!(result, Err(SetError::EncodingError { .. })));

        let result = base.set(base.len(), -1);
        assert!(matches!(result, Err(SetError::IndexError { .. })));

        let result = base.get(base.len());
        assert!(matches!(result, Err(GetError::IndexError { .. })));

        run_overwrite_binary_test(&mut base, FUZZ_ITERATIONS, &mut rng);
    }
}

#[test]
fn test_4bit_bit_transpose() {
    // Length-independent checks, hoisted out of the per-length loop (they do not use the RNG).
    test_send_and_sync::<4, Unsigned, BitTranspose>();
    test_empty::<4, Unsigned, BitTranspose>();
    test_construction_errors::<4, Unsigned, BitTranspose>();

    let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
    for len in 0..MAX_DIM {
        let mut base = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(len);
        assert_eq!(
            base.len(),
            len,
            "BoxedBitSlice returned the incorrect length"
        );

        // Every started block of 64 values occupies 32 bytes.
        assert_eq!(base.bytes(), 32 * len.div_ceil(64));

        let bytes = BitSlice::<'static, 4, Unsigned, BitTranspose>::bytes_for(len);
        assert_eq!(
            bytes,
            32 * len.div_ceil(64),
            "BoxedBitSlice has the incorrect number of bytes"
        );

        if len == 0 {
            continue;
        }

        let result = base.set(0, -1);
        assert!(matches!(result, Err(SetError::EncodingError { .. })));

        let result = base.set(base.len(), -1);
        assert!(matches!(result, Err(SetError::IndexError { .. })));

        let result = base.get(base.len());
        assert!(matches!(result, Err(GetError::IndexError { .. })));

        run_overwrite_test(&mut base, FUZZ_ITERATIONS, &mut rng);
    }
}
1604}