1use std::{marker::PhantomData, ops::RangeInclusive, ptr::NonNull};
7
8use diskann_utils::{Reborrow, ReborrowMut};
9use thiserror::Error;
10
11use super::{
12 length::{Dynamic, Length},
13 packing,
14 ptr::{AsMutPtr, AsPtr, MutSlicePtr, Precursor, SlicePtr},
15};
16use crate::{
17 alloc::{AllocatorCore, AllocatorError, GlobalAllocator, Poly},
18 utils,
19};
20
21pub trait Representation<const NBITS: usize> {
27 type Domain: Iterator<Item = i64>;
29
30 fn encode(value: i64) -> Result<u8, EncodingError>;
33
34 fn encode_unchecked(value: i64) -> u8;
41
42 fn decode(raw: u8) -> i64;
49
50 fn check(value: i64) -> bool;
52
53 fn domain() -> Self::Domain;
55}
56
57#[derive(Debug, Error, Clone, Copy)]
58#[error("value {} is not in the encodable range of {}", got, domain)]
59pub struct EncodingError {
60 got: i64,
61 domain: &'static &'static str,
67}
68
69impl EncodingError {
70 fn new(got: i64, domain: &'static &'static str) -> Self {
71 Self { got, domain }
72 }
73}
74
/// Plain unsigned integer representation: an `NBITS`-bit code stores `0..=2^NBITS - 1`.
#[derive(Debug, Clone, Copy)]
pub struct Unsigned;

impl Unsigned {
    /// The closed range of values representable with `NBITS` unsigned bits.
    pub const fn domain_const<const NBITS: usize>() -> std::ops::RangeInclusive<i64> {
        let largest = 2i64.pow(NBITS as u32) - 1;
        0..=largest
    }

    /// Static textual description of the `nbits`-bit domain, used in error messages.
    ///
    /// # Panics
    ///
    /// Panics for any width outside `1..=8`.
    // Only widths 1..=8 get a `Representation` impl, so any other width is a programming error.
    #[allow(clippy::panic)]
    const fn domain_str(nbits: usize) -> &'static &'static str {
        match nbits {
            1 => &"[0, 1]",
            2 => &"[0, 3]",
            3 => &"[0, 7]",
            4 => &"[0, 15]",
            5 => &"[0, 31]",
            6 => &"[0, 63]",
            7 => &"[0, 127]",
            8 => &"[0, 255]",
            _ => panic!("unimplemented"),
        }
    }
}
107
108macro_rules! repr_unsigned {
109 ($N:literal) => {
110 impl Representation<$N> for Unsigned {
111 type Domain = RangeInclusive<i64>;
112
113 fn encode(value: i64) -> Result<u8, EncodingError> {
114 if !<Self as Representation<$N>>::check(value) {
115 let domain = Self::domain_str($N);
118 Err(EncodingError::new(value, domain))
119 } else {
120 Ok(<Self as Representation<$N>>::encode_unchecked(value))
121 }
122 }
123
124 fn encode_unchecked(value: i64) -> u8 {
125 debug_assert!(<Self as Representation<$N>>::check(value));
126 value as u8
127 }
128
129 fn decode(raw: u8) -> i64 {
130 let raw: i64 = raw.into();
132 debug_assert!(<Self as Representation<$N>>::check(raw));
133 raw
134 }
135
136 fn check(value: i64) -> bool {
137 <Self as Representation<$N>>::domain().contains(&value)
138 }
139
140 fn domain() -> Self::Domain {
141 Self::domain_const::<$N>()
142 }
143 }
144 };
145 ($N:literal, $($Ns:literal),+) => {
146 repr_unsigned!($N);
147 $(repr_unsigned!($Ns);)+
148 };
149}
150
151repr_unsigned!(1, 2, 3, 4, 5, 6, 7, 8);
152
153#[derive(Debug, Clone, Copy)]
159pub struct Binary;
160
161impl Representation<1> for Binary {
162 type Domain = std::array::IntoIter<i64, 2>;
163
164 fn encode(value: i64) -> Result<u8, EncodingError> {
165 if !Self::check(value) {
166 const DOMAIN: &str = "{-1, 1}";
167 Err(EncodingError::new(value, &DOMAIN))
168 } else {
169 Ok(Self::encode_unchecked(value))
170 }
171 }
172
173 fn encode_unchecked(value: i64) -> u8 {
174 debug_assert!(Self::check(value));
175 value.clamp(0, 1) as u8
177 }
178
179 fn decode(raw: u8) -> i64 {
180 let raw: i64 = raw.into();
183 (raw << 1) - 1
184 }
185
186 fn check(value: i64) -> bool {
187 value == -1 || value == 1
188 }
189
190 fn domain() -> Self::Domain {
194 [-1, 1].into_iter()
195 }
196}
197
/// Layout policy describing where the `NBITS`-bit code of each element lives in a byte buffer.
///
/// # Safety
///
/// `BitSliceBase` calls `pack`/`unpack` with a slice of exactly `bytes(count)` bytes and an
/// index `i < count` (it `debug_assert!`s both conditions). Implementations must guarantee
/// that, for such inputs, they only access bytes inside the given slice.
pub unsafe trait PermutationStrategy<const NBITS: usize> {
    /// Number of bytes required to hold `count` elements.
    fn bytes(count: usize) -> usize;

    /// Store the raw code `value` for element `i` in `s`, leaving other elements untouched.
    ///
    /// # Safety
    ///
    /// `s` must hold `Self::bytes(count)` bytes for some `count > i`.
    unsafe fn pack(s: &mut [u8], i: usize, value: u8);

    /// Read back the raw code of element `i` from `s`.
    ///
    /// # Safety
    ///
    /// `s` must hold `Self::bytes(count)` bytes for some `count > i`.
    unsafe fn unpack(s: &[u8], i: usize) -> u8;
}
242
243#[derive(Debug, Clone, Copy)]
247pub struct Dense;
248
249impl Dense {
250 fn bytes<const NBITS: usize>(count: usize) -> usize {
251 utils::div_round_up(NBITS * count, 8)
252 }
253}
254
255unsafe impl<const NBITS: usize> PermutationStrategy<NBITS> for Dense {
257 fn bytes(count: usize) -> usize {
258 Self::bytes::<NBITS>(count)
259 }
260
261 unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
262 let bitaddress = NBITS * i;
263
264 let bytestart = bitaddress / 8;
265 let bytestop = (bitaddress + NBITS - 1) / 8;
266 let bitstart = bitaddress - 8 * bytestart;
267 debug_assert!(bytestop < data.len());
268
269 if bytestart == bytestop {
270 let raw = unsafe { data.as_ptr().add(bytestart).read() };
279 let packed = packing::pack_u8::<NBITS>(raw, encoded, bitstart);
280
281 unsafe { data.as_mut_ptr().add(bytestart).write(packed) };
285 } else {
286 let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
294 let packed = packing::pack_u16::<NBITS>(raw, encoded, bitstart);
295
296 unsafe {
300 data.as_mut_ptr()
301 .add(bytestart)
302 .cast::<u16>()
303 .write_unaligned(packed)
304 };
305 }
306 }
307
308 unsafe fn unpack(data: &[u8], i: usize) -> u8 {
309 let bitaddress = NBITS * i;
310
311 let bytestart = bitaddress / 8;
312 let bytestop = (bitaddress + NBITS - 1) / 8;
313 debug_assert!(bytestop < data.len());
314 if bytestart == bytestop {
315 let raw = unsafe { data.as_ptr().add(bytestart).read() };
317 packing::unpack_u8::<NBITS>(raw, bitaddress - 8 * bytestart)
318 } else {
319 let raw = unsafe { data.as_ptr().add(bytestart).cast::<u16>().read_unaligned() };
321 packing::unpack_u16::<NBITS>(raw, bitaddress - 8 * bytestart)
322 }
323 }
324}
325
326#[derive(Debug, Clone, Copy)]
340pub struct BitTranspose;
341
342unsafe impl PermutationStrategy<4> for BitTranspose {
345 fn bytes(count: usize) -> usize {
346 32 * utils::div_round_up(count, 64)
347 }
348
349 unsafe fn pack(data: &mut [u8], i: usize, encoded: u8) {
350 let block_start = 32 * (i / 64);
352 let byte_start = block_start + (i % 64) / 8;
354 let bit = i % 8;
356
357 let mask: u8 = 0x1 << bit;
358 for p in 0..4 {
359 let mut v = data[byte_start + 8 * p];
360 v = (v & !mask) | (((encoded >> p) & 0x1) << bit);
361 data[byte_start + 8 * p] = v;
362 }
363 }
364
365 unsafe fn unpack(data: &[u8], i: usize) -> u8 {
366 let block_start = 32 * (i / 64);
368 let byte_start = block_start + (i % 64) / 8;
370 let bit = i % 8;
372
373 let mut output: u8 = 0;
374 for p in 0..4 {
375 let v = data[byte_start + 8 * p];
376 output |= ((v >> bit) & 0x1) << p
377 }
378 output
379 }
380}
381
382#[derive(Debug, Error, Clone, Copy)]
387#[error("input span has length {got} bytes but expected {expected}")]
388pub struct ConstructionError {
389 got: usize,
390 expected: usize,
391}
392
393#[derive(Debug, Error, Clone, Copy)]
394#[error("index {index} exceeds the maximum length of {len}")]
395pub struct IndexOutOfBounds {
396 index: usize,
397 len: usize,
398}
399
400impl IndexOutOfBounds {
401 fn new(index: usize, len: usize) -> Self {
402 Self { index, len }
403 }
404}
405
406#[derive(Debug, Error, Clone, Copy)]
407#[error("error setting index in bitslice")]
408#[non_exhaustive]
409pub enum SetError {
410 IndexError(#[from] IndexOutOfBounds),
411 EncodingError(#[from] EncodingError),
412}
413
414#[derive(Debug, Error, Clone, Copy)]
415#[error("error getting index in bitslice")]
416pub enum GetError {
417 IndexError(#[from] IndexOutOfBounds),
418}
419
420#[derive(Debug, Clone, Copy)]
518pub struct BitSliceBase<const NBITS: usize, Repr, Ptr, Perm = Dense, Len = Dynamic>
519where
520 Repr: Representation<NBITS>,
521 Ptr: AsPtr<Type = u8>,
522 Perm: PermutationStrategy<NBITS>,
523 Len: Length,
524{
525 ptr: Ptr,
526 len: Len,
527 repr: PhantomData<Repr>,
528 packing: PhantomData<Perm>,
529}
530
531impl<const NBITS: usize, Repr, Ptr, Perm, Len> BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
532where
533 Repr: Representation<NBITS>,
534 Ptr: AsPtr<Type = u8>,
535 Perm: PermutationStrategy<NBITS>,
536 Len: Length,
537{
538 const _CHECK: () = assert!(NBITS > 0 && NBITS <= 8);
540
541 pub fn bytes_for(count: usize) -> usize {
543 Perm::bytes(count)
544 }
545
546 unsafe fn new_unchecked_internal(ptr: Ptr, len: Len) -> Self {
553 Self {
554 ptr,
555 len,
556 repr: PhantomData,
557 packing: PhantomData,
558 }
559 }
560
561 pub unsafe fn new_unchecked<Pre, Count>(precursor: Pre, count: Count) -> Self
571 where
572 Count: Into<Len>,
573 Pre: Precursor<Ptr>,
574 {
575 let count: Len = count.into();
576 debug_assert_eq!(precursor.precursor_len(), Self::bytes_for(count.value()));
577 Self::new_unchecked_internal(precursor.precursor_into(), count)
578 }
579
580 pub fn new<Pre, Count>(precursor: Pre, count: Count) -> Result<Self, ConstructionError>
590 where
591 Count: Into<Len>,
592 Pre: Precursor<Ptr>,
593 {
594 let count: Len = count.into();
596
597 if precursor.precursor_len() != Self::bytes_for(count.value()) {
599 Err(ConstructionError {
600 got: precursor.precursor_len(),
601 expected: Self::bytes_for(count.value()),
602 })
603 } else {
604 Ok(unsafe { Self::new_unchecked(precursor, count) })
609 }
610 }
611
612 pub fn len(&self) -> usize {
614 self.len.value()
615 }
616
617 pub fn is_empty(&self) -> bool {
619 self.len() == 0
620 }
621
622 pub fn bytes(&self) -> usize {
624 Self::bytes_for(self.len())
625 }
626
627 pub fn get(&self, i: usize) -> Result<i64, GetError> {
629 if i >= self.len() {
630 Err(IndexOutOfBounds::new(i, self.len()).into())
631 } else {
632 Ok(unsafe { self.get_unchecked(i) })
634 }
635 }
636
637 pub unsafe fn get_unchecked(&self, i: usize) -> i64 {
643 debug_assert!(i < self.len());
644 debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
645
646 Repr::decode(unsafe { Perm::unpack(self.as_slice(), i) })
651 }
652
653 pub fn set(&mut self, i: usize, value: i64) -> Result<(), SetError>
655 where
656 Ptr: AsMutPtr<Type = u8>,
657 {
658 if i >= self.len() {
659 return Err(IndexOutOfBounds::new(i, self.len()).into());
660 }
661
662 let encoded = Repr::encode(value)?;
663
664 unsafe { self.set_unchecked(i, encoded) }
666 Ok(())
667 }
668
669 pub unsafe fn set_unchecked(&mut self, i: usize, encoded: u8)
675 where
676 Ptr: AsMutPtr<Type = u8>,
677 {
678 debug_assert!(i < self.len());
679 debug_assert_eq!(self.as_slice().len(), Perm::bytes(self.len()));
680
681 unsafe { Perm::pack(self.as_mut_slice(), i, encoded) }
686 }
687
688 pub fn domain(&self) -> Repr::Domain {
690 Repr::domain()
691 }
692
693 pub(crate) fn as_slice(&self) -> &'_ [u8] {
694 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.bytes()) }
698 }
699
700 pub fn as_ptr(&self) -> *const u8 {
707 self.ptr.as_ptr()
708 }
709
710 pub(super) fn as_mut_slice(&mut self) -> &'_ mut [u8]
712 where
713 Ptr: AsMutPtr,
714 {
715 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_mut_ptr(), self.bytes()) }
722 }
723
724 fn as_mut_ptr(&mut self) -> *mut u8
726 where
727 Ptr: AsMutPtr,
728 {
729 self.ptr.as_mut_ptr()
730 }
731}
732
733impl<const NBITS: usize, Repr, Perm, Len>
734 BitSliceBase<NBITS, Repr, Poly<[u8], GlobalAllocator>, Perm, Len>
735where
736 Repr: Representation<NBITS>,
737 Perm: PermutationStrategy<NBITS>,
738 Len: Length,
739{
740 pub fn new_boxed<Count>(count: Count) -> Self
759 where
760 Count: Into<Len>,
761 {
762 let count: Len = count.into();
763 let bytes = Self::bytes_for(count.value());
764 let storage: Box<[u8]> = (0..bytes).map(|_| 0).collect();
765
766 unsafe { Self::new_unchecked(Poly::from(storage), count) }
771 }
772}
773
774impl<const NBITS: usize, Repr, Perm, Len, A> BitSliceBase<NBITS, Repr, Poly<[u8], A>, Perm, Len>
775where
776 Repr: Representation<NBITS>,
777 Perm: PermutationStrategy<NBITS>,
778 Len: Length,
779 A: AllocatorCore,
780{
781 pub fn new_in<Count>(count: Count, allocator: A) -> Result<Self, AllocatorError>
805 where
806 Count: Into<Len>,
807 {
808 let count: Len = count.into();
809 let bytes = Self::bytes_for(count.value());
810 let storage = Poly::broadcast(0, bytes, allocator)?;
811
812 Ok(unsafe { Self::new_unchecked(storage, count) })
817 }
818
819 pub fn into_inner(self) -> Poly<[u8], A> {
821 self.ptr
822 }
823}
824
825pub type BitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
827 BitSliceBase<N, Repr, SlicePtr<'a, u8>, Perm, Len>;
828
829pub type MutBitSlice<'a, const N: usize, Repr, Perm = Dense, Len = Dynamic> =
831 BitSliceBase<N, Repr, MutSlicePtr<'a, u8>, Perm, Len>;
832
833pub type PolyBitSlice<const N: usize, Repr, A, Perm = Dense, Len = Dynamic> =
835 BitSliceBase<N, Repr, Poly<[u8], A>, Perm, Len>;
836
837pub type BoxedBitSlice<const N: usize, Repr, Perm = Dense, Len = Dynamic> =
839 PolyBitSlice<N, Repr, GlobalAllocator, Perm, Len>;
840
841impl<'a, Ptr> From<&'a BitSliceBase<8, Unsigned, Ptr>> for &'a [u8]
846where
847 Ptr: AsPtr<Type = u8>,
848{
849 fn from(slice: &'a BitSliceBase<8, Unsigned, Ptr>) -> Self {
850 unsafe { std::slice::from_raw_parts(slice.as_ptr(), slice.len()) }
856 }
857}
858
859impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> Reborrow<'this>
860 for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
861where
862 Repr: Representation<NBITS>,
863 Ptr: AsPtr<Type = u8>,
864 Perm: PermutationStrategy<NBITS>,
865 Len: Length,
866{
867 type Target = BitSlice<'this, NBITS, Repr, Perm, Len>;
868
869 fn reborrow(&'this self) -> Self::Target {
870 let ptr: *const u8 = self.as_ptr();
871 debug_assert!(!ptr.is_null());
872
873 let nonnull = unsafe { NonNull::new_unchecked(ptr.cast_mut()) };
877
878 let ptr = unsafe { SlicePtr::new_unchecked(nonnull) };
885
886 Self::Target {
887 ptr,
888 len: self.len,
889 repr: PhantomData,
890 packing: PhantomData,
891 }
892 }
893}
894
895impl<'this, const NBITS: usize, Repr, Ptr, Perm, Len> ReborrowMut<'this>
896 for BitSliceBase<NBITS, Repr, Ptr, Perm, Len>
897where
898 Repr: Representation<NBITS>,
899 Ptr: AsMutPtr<Type = u8>,
900 Perm: PermutationStrategy<NBITS>,
901 Len: Length,
902{
903 type Target = MutBitSlice<'this, NBITS, Repr, Perm, Len>;
904
905 fn reborrow_mut(&'this mut self) -> Self::Target {
906 let ptr: *mut u8 = self.as_mut_ptr();
907 debug_assert!(!ptr.is_null());
908
909 let nonnull = unsafe { NonNull::new_unchecked(ptr) };
911
912 let ptr = unsafe { MutSlicePtr::new_unchecked(nonnull) };
922
923 Self::Target {
924 ptr,
925 len: self.len,
926 repr: PhantomData,
927 packing: PhantomData,
928 }
929 }
930}
931
#[cfg(test)]
mod tests {
    use rand::{
        distr::{Distribution, Uniform},
        rngs::StdRng,
        seq::{IndexedRandom, SliceRandom},
        Rng, SeedableRng,
    };

    use super::*;
    use crate::{bits::Static, test_util::AlwaysFails};

    const BOUNDS: &str = "special bounds";

    #[test]
    fn test_encoding_error() {
        assert_eq!(std::mem::size_of::<EncodingError>(), 16);
        assert_eq!(
            std::mem::size_of::<Option<EncodingError>>(),
            16,
            "expected EncodingError to have the niche optimization"
        );
        let err = EncodingError::new(7, &BOUNDS);
        assert_eq!(
            err.to_string(),
            "value 7 is not in the encodable range of special bounds"
        );
    }

    fn assert_send_and_sync<T: Send + Sync>(_x: &T) {}

    #[test]
    fn test_binary_repr() {
        assert_eq!(Binary::encode(-1).unwrap(), 0);
        assert_eq!(Binary::encode(1).unwrap(), 1);
        assert_eq!(Binary::decode(0), -1);
        assert_eq!(Binary::decode(1), 1);

        assert!(Binary::check(-1));
        assert!(Binary::check(1));
        assert!(!Binary::check(0));
        assert!(!Binary::check(-2));
        assert!(!Binary::check(2));

        let domain: Vec<_> = Binary::domain().collect();
        assert_eq!(domain, &[-1, 1]);
    }

    #[test]
    fn test_sizes() {
        assert_eq!(std::mem::size_of::<BitSlice<'static, 8, Unsigned>>(), 16);
        assert_eq!(std::mem::size_of::<MutBitSlice<'static, 8, Unsigned>>(), 16);

        assert_eq!(
            std::mem::size_of::<Option<BitSlice<'static, 8, Unsigned>>>(),
            16
        );
        assert_eq!(
            std::mem::size_of::<Option<MutBitSlice<'static, 8, Unsigned>>>(),
            16
        );

        assert_eq!(
            std::mem::size_of::<BitSlice<'static, 8, Unsigned, Dense, Static<128>>>(),
            8
        );
    }

    /// An 8-bit dense unsigned slice must expose exactly its element values as bytes.
    #[test]
    fn test_u8_byte_view() {
        let mut base = BoxedBitSlice::<8, Unsigned>::new_boxed(5);
        for i in 0..base.len() {
            base.set(i, 10 * i as i64).unwrap();
        }
        let bytes: &[u8] = (&base).into();
        assert_eq!(bytes, &[0u8, 10, 20, 30, 40]);
    }

    cfg_if::cfg_if! {
        if #[cfg(miri)] {
            const MAX_DIM: usize = 160;
            const FUZZ_ITERATIONS: usize = 1;
        } else if #[cfg(debug_assertions)] {
            const MAX_DIM: usize = 128;
            const FUZZ_ITERATIONS: usize = 10;
        } else {
            const MAX_DIM: usize = 256;
            const FUZZ_ITERATIONS: usize = 100;
        }
    }

    fn test_send_and_sync<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS> + Send + Sync,
        Perm: PermutationStrategy<NBITS> + Send + Sync,
    {
        let mut x = BoxedBitSlice::<NBITS, Repr, Perm>::new_boxed(1);
        assert_send_and_sync(&x);
        assert_send_and_sync(&x.reborrow());
        assert_send_and_sync(&x.reborrow_mut());
    }

    fn test_empty<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let base: &mut [u8] = &mut [];
        let mut slice = MutBitSlice::<NBITS, Repr, Perm>::new(base, 0).unwrap();
        assert_eq!(slice.len(), 0);
        assert!(slice.is_empty());

        {
            let reborrow = slice.reborrow();
            assert_eq!(reborrow.len(), 0);
            assert!(reborrow.is_empty());
        }

        {
            let reborrow = slice.reborrow_mut();
            assert_eq!(reborrow.len(), 0);
            assert!(reborrow.is_empty());
        }
    }

    fn test_construction_errors<const NBITS: usize, Repr, Perm>()
    where
        Repr: Representation<NBITS>,
        Perm: PermutationStrategy<NBITS>,
    {
        let len: usize = 10;
        let bytes = Perm::bytes(len);

        let box_big = Poly::broadcast(0u8, bytes + 1, GlobalAllocator).unwrap();
        let box_small = Poly::broadcast(0u8, bytes - 1, GlobalAllocator).unwrap();
        let box_right = Poly::broadcast(0u8, bytes, GlobalAllocator).unwrap();

        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_big, len);
        match result {
            Err(ConstructionError { got, expected }) => {
                assert_eq!(got, bytes + 1);
                assert_eq!(expected, bytes);
            }
            _ => panic!("shouldn't have reached here!"),
        };

        let result = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_small, len);
        match result {
            Err(ConstructionError { got, expected }) => {
                assert_eq!(got, bytes - 1);
                assert_eq!(expected, bytes);
            }
            _ => panic!("shouldn't have reached here!"),
        };

        let mut base = BoxedBitSlice::<NBITS, Repr, Perm>::new(box_right, len).unwrap();
        let ptr = base.as_ptr();
        assert_eq!(base.len(), len);

        {
            let borrowed = base.reborrow_mut();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);

            let borrowed = MutBitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
        }

        {
            let mut oversized = vec![0; bytes + 1];
            let result = MutBitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };

            let mut undersized = vec![0; bytes - 1];
            let result = MutBitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
        }

        {
            let borrowed = base.reborrow();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);

            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);

            let borrowed = BitSlice::<NBITS, Repr, Perm>::new(base.as_mut_slice(), len).unwrap();
            assert_eq!(borrowed.as_ptr(), ptr);
            assert_eq!(borrowed.len(), len);
        }

        {
            let mut oversized = vec![0; bytes + 1];
            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };

            let result = BitSlice::<NBITS, Repr, Perm>::new(oversized.as_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes + 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };

            let mut undersized = vec![0; bytes - 1];
            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_mut_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };

            let result = BitSlice::<NBITS, Repr, Perm>::new(undersized.as_slice(), len);
            match result {
                Err(ConstructionError { got, expected }) => {
                    assert_eq!(got, bytes - 1);
                    assert_eq!(expected, bytes);
                }
                _ => panic!("shouldn't have reached here!"),
            };
        }
    }

    fn run_overwrite_test<const NBITS: usize, Perm, Len, R>(
        base: &mut BoxedBitSlice<NBITS, Unsigned, Perm, Len>,
        num_iterations: usize,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
        Len: Length,
        Perm: PermutationStrategy<NBITS>,
        R: Rng,
    {
        let mut expected: Vec<i64> = vec![0; base.len()];
        let mut indices: Vec<usize> = (0..base.len()).collect();
        for i in 0..base.len() {
            base.set(i, 0).unwrap();
        }

        for i in 0..base.len() {
            assert_eq!(base.get(i).unwrap(), 0, "failed to initialize bit vector");
        }

        let domain = base.domain();
        assert_eq!(domain, 0..=2i64.pow(NBITS as u32) - 1);
        let distribution = Uniform::new_inclusive(*domain.start(), *domain.end()).unwrap();

        for iter in 0..num_iterations {
            indices.shuffle(rng);

            for &i in indices.iter() {
                let value = distribution.sample(rng);
                expected[i] = value;
                base.set(i, value).unwrap();
            }

            for (i, &expect) in expected.iter().enumerate() {
                let value = base.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "retrieval failed on iteration {iter} at index {i}"
                );
            }

            let borrowed = base.reborrow();
            for (i, &expect) in expected.iter().enumerate() {
                let value = borrowed.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "reborrow retrieval failed on iteration {iter} at index {i}"
                );
            }
        }
    }

    fn run_overwrite_binary_test<Perm, Len, R>(
        base: &mut BoxedBitSlice<1, Binary, Perm, Len>,
        num_iterations: usize,
        rng: &mut R,
    ) where
        Len: Length,
        Perm: PermutationStrategy<1>,
        R: Rng,
    {
        let mut expected: Vec<i64> = vec![0; base.len()];
        let mut indices: Vec<usize> = (0..base.len()).collect();
        for i in 0..base.len() {
            base.set(i, -1).unwrap();
        }

        for i in 0..base.len() {
            assert_eq!(base.get(i).unwrap(), -1, "failed to initialize bit vector");
        }

        let distribution: [i64; 2] = [-1, 1];

        for iter in 0..num_iterations {
            indices.shuffle(rng);

            for &i in indices.iter() {
                let value = distribution.choose(rng).unwrap();
                expected[i] = *value;
                base.set(i, *value).unwrap();
            }

            for (i, &expect) in expected.iter().enumerate() {
                let value = base.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "retrieval failed on iteration {iter} at index {i}"
                );
            }

            let borrowed = base.reborrow();
            for (i, &expect) in expected.iter().enumerate() {
                let value = borrowed.get(i).unwrap();
                assert_eq!(
                    value, expect,
                    "reborrow retrieval failed on iteration {iter} at index {i}"
                );
            }
        }
    }

    /// Exercise one `Dense` unsigned slice of length `len`.
    ///
    /// The length-independent checks (`test_send_and_sync`, `test_empty`,
    /// `test_construction_errors`) are run once per bit width by `generate_unsigned_test!`
    /// rather than once per length here.
    fn test_unsigned_dense<const NBITS: usize, Len, R>(
        len: Len,
        minimum: i64,
        maximum: i64,
        rng: &mut R,
    ) where
        Unsigned: Representation<NBITS, Domain = RangeInclusive<i64>>,
        Dense: PermutationStrategy<NBITS>,
        Len: Length,
        R: Rng,
    {
        assert_eq!(Unsigned::domain_const::<NBITS>(), Unsigned::domain(),);

        // Only non-empty slices need bytes from the allocator. A failing allocator must
        // therefore be reported exactly when `len != 0`.
        match PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, AlwaysFails) {
            Ok(_) => {
                if len.value() != 0 {
                    panic!("allocation should have failed");
                }
            }
            Err(AllocatorError) => {
                if len.value() == 0 {
                    panic!("zero sized allocations don't require an allocator");
                }
            }
        }

        let mut base =
            PolyBitSlice::<NBITS, Unsigned, _, Dense, Len>::new_in(len, GlobalAllocator).unwrap();
        assert_eq!(
            base.len(),
            len.value(),
            "BoxedBitSlice returned the incorrect length"
        );

        let expected_bytes = BitSlice::<'static, NBITS, Unsigned>::bytes_for(len.value());
        assert_eq!(
            base.bytes(),
            expected_bytes,
            "BoxedBitSlice has the incorrect number of bytes"
        );

        assert_eq!(base.domain(), minimum..=maximum);

        if len.value() == 0 {
            return;
        }

        let ptr = base.as_ptr();

        {
            let mut borrowed = base.reborrow_mut();

            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );

            borrowed.set(0, 0).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 0);

            borrowed.set(0, 1).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 1);

            borrowed.set(0, 0).unwrap();
            assert_eq!(borrowed.get(0).unwrap(), 0);

            let result = borrowed.set(0, minimum - 1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));

            let result = borrowed.set(0, maximum + 1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));

            let result = borrowed.set(borrowed.len(), 0);
            assert!(matches!(result, Err(SetError::IndexError { .. })));

            let result = borrowed.get(borrowed.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
        }

        {
            let borrowed =
                MutBitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();

            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }

        {
            let borrowed = base.reborrow();

            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );

            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );

            let result = borrowed.get(borrowed.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));
        }

        {
            let borrowed =
                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_slice(), len).unwrap();

            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }

        {
            let borrowed =
                BitSlice::<NBITS, Unsigned, Dense, Len>::new(base.as_mut_slice(), len).unwrap();

            assert_eq!(
                borrowed.as_ptr(),
                ptr,
                "pointer was not preserved during borrowing!"
            );
            assert_eq!(
                borrowed.len(),
                len.value(),
                "borrowing did not preserve length!"
            );
        }

        run_overwrite_test(&mut base, FUZZ_ITERATIONS, rng);
    }

    macro_rules! generate_unsigned_test {
        ($name:ident, $NBITS:literal, $MIN:literal, $MAX:literal, $SEED:literal) => {
            #[test]
            fn $name() {
                // Length-independent checks only need to run once per bit width.
                test_send_and_sync::<$NBITS, Unsigned, Dense>();
                test_empty::<$NBITS, Unsigned, Dense>();
                test_construction_errors::<$NBITS, Unsigned, Dense>();

                let mut rng = StdRng::seed_from_u64($SEED);
                for dim in 0..MAX_DIM {
                    test_unsigned_dense::<$NBITS, Dynamic, _>(dim.into(), $MIN, $MAX, &mut rng);
                }
            }
        };
    }

    generate_unsigned_test!(test_unsigned_8bit, 8, 0, 0xff, 0xc652f2a1018f442b);
    generate_unsigned_test!(test_unsigned_7bit, 7, 0, 0x7f, 0xb732e59fec6d6c9c);
    generate_unsigned_test!(test_unsigned_6bit, 6, 0, 0x3f, 0x35d9380d0a318f21);
    generate_unsigned_test!(test_unsigned_5bit, 5, 0, 0x1f, 0xfb09895183334304);
    generate_unsigned_test!(test_unsigned_4bit, 4, 0, 0x0f, 0x38dfcf9e82c33f48);
    generate_unsigned_test!(test_unsigned_3bit, 3, 0, 0x07, 0xf9a94c8c749ee26c);
    generate_unsigned_test!(test_unsigned_2bit, 2, 0, 0x03, 0xbba03db62cecf4cf);
    generate_unsigned_test!(test_unsigned_1bit, 1, 0, 0x01, 0x54ea2a07d7c67f37);

    #[test]
    fn test_binary_dense() {
        // Length-independent checks only need to run once.
        test_send_and_sync::<1, Binary, Dense>();
        test_empty::<1, Binary, Dense>();
        test_construction_errors::<1, Binary, Dense>();

        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
        for len in 0..MAX_DIM {
            let mut base = BoxedBitSlice::<1, Binary>::new_boxed(len);
            assert_eq!(
                base.len(),
                len,
                "BoxedBitSlice returned the incorrect length"
            );

            assert_eq!(base.bytes(), len.div_ceil(8));

            let bytes = BitSlice::<'static, 1, Binary>::bytes_for(len);
            assert_eq!(
                bytes,
                len.div_ceil(8),
                "BoxedBitSlice has the incorrect number of bytes"
            );

            if len == 0 {
                continue;
            }

            let result = base.set(0, 0);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));

            let result = base.set(base.len(), -1);
            assert!(matches!(result, Err(SetError::IndexError { .. })));

            let result = base.get(base.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));

            run_overwrite_binary_test(&mut base, FUZZ_ITERATIONS, &mut rng);
        }
    }

    #[test]
    fn test_4bit_bit_transpose() {
        // Length-independent checks only need to run once.
        test_send_and_sync::<4, Unsigned, BitTranspose>();
        test_empty::<4, Unsigned, BitTranspose>();
        test_construction_errors::<4, Unsigned, BitTranspose>();

        let mut rng = StdRng::seed_from_u64(0xb3c95e8e19d3842e);
        for len in 0..MAX_DIM {
            let mut base = BoxedBitSlice::<4, Unsigned, BitTranspose>::new_boxed(len);
            assert_eq!(
                base.len(),
                len,
                "BoxedBitSlice returned the incorrect length"
            );

            assert_eq!(base.bytes(), 32 * len.div_ceil(64));

            let bytes = BitSlice::<'static, 4, Unsigned, BitTranspose>::bytes_for(len);
            assert_eq!(
                bytes,
                32 * len.div_ceil(64),
                "BoxedBitSlice has the incorrect number of bytes"
            );

            if len == 0 {
                continue;
            }

            let result = base.set(0, -1);
            assert!(matches!(result, Err(SetError::EncodingError { .. })));

            let result = base.set(base.len(), -1);
            assert!(matches!(result, Err(SetError::IndexError { .. })));

            let result = base.get(base.len());
            assert!(matches!(result, Err(GetError::IndexError { .. })));

            run_overwrite_test(&mut base, FUZZ_ITERATIONS, &mut rng);
        }
    }
}