1#[cfg(any(test, feature = "fuzzing"))]
7use proptest::{
8 arbitrary::{any, Arbitrary, StrategyFor},
9 collection::{vec, VecStrategy},
10 strategy::{Map, Strategy},
11};
12use serde::{de::Error, Deserialize, Deserializer, Serialize};
13use std::{
14 iter::FromIterator,
15 ops::{BitAnd, BitOr},
16};
17
/// Number of bits held by one bucket; every bucket is a single `u8`.
const BUCKET_SIZE: usize = 8 * std::mem::size_of::<u8>();
/// Largest number of buckets a `BitVec` may hold: bit positions are `u16`
/// values, so at most `u16::MAX + 1` bits are addressable.
const MAX_BUCKETS: usize = (u16::MAX as usize + 1) / BUCKET_SIZE;
21
/// A growable bitmap addressed by `u16` positions and backed by a byte vector.
///
/// Bit `pos` lives in byte `pos / BUCKET_SIZE`; within a byte, positions are
/// counted from the most significant bit, so position 0 is the mask
/// `0b1000_0000` of the first byte.
#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
pub struct BitVec {
    /// Backing storage, one bucket (byte) per `BUCKET_SIZE` bits; serialized
    /// as a byte string through `serde_bytes`.
    #[serde(with = "serde_bytes")]
    inner: Vec<u8>,
}
71
72impl BitVec {
73 fn with_capacity(num_buckets: usize) -> Self {
74 Self {
75 inner: Vec::with_capacity(num_buckets),
76 }
77 }
78
79 pub fn with_num_bits(num_bits: u16) -> Self {
81 Self {
82 inner: vec![0; Self::required_buckets(num_bits)],
83 }
84 }
85
86 pub fn set(&mut self, pos: u16) {
88 let bucket: usize = pos as usize / BUCKET_SIZE;
90 if self.inner.len() <= bucket {
91 self.inner.resize(bucket + 1, 0);
92 }
93 let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
95 self.inner[bucket] |= 0b1000_0000 >> bucket_pos as u8;
96 }
97
98 #[inline]
100 pub fn is_set(&self, pos: u16) -> bool {
101 let bucket: usize = pos as usize / BUCKET_SIZE;
103 if self.inner.len() <= bucket {
104 return false;
105 }
106 let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
108 (self.inner[bucket] & (0b1000_0000 >> bucket_pos as u8)) != 0
109 }
110
111 pub fn all_zeros(&self) -> bool {
113 self.inner.iter().all(|byte| *byte == 0)
114 }
115
116 pub fn count_ones(&self) -> u32 {
118 self.inner.iter().map(|a| a.count_ones()).sum()
119 }
120
121 pub fn last_set_bit(&self) -> Option<u16> {
123 self.inner
124 .iter()
125 .rev()
126 .enumerate()
127 .find(|(_, byte)| byte != &&0u8)
128 .map(|(i, byte)| {
129 (8 * (self.inner.len() - i) - byte.trailing_zeros() as usize - 1) as u16
130 })
131 }
132
133 pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
135 (0..self.inner.len() * BUCKET_SIZE).filter(move |idx| self.is_set(*idx as u16))
136 }
137
138 pub fn num_buckets(&self) -> usize {
140 self.inner.len()
141 }
142
143 pub fn required_buckets(num_bits: u16) -> usize {
145 num_bits
146 .checked_sub(1)
147 .map_or(0, |pos| pos as usize / BUCKET_SIZE + 1)
148 }
149}
150
151impl BitAnd for &BitVec {
152 type Output = BitVec;
153
154 fn bitand(self, other: Self) -> Self::Output {
156 let len = std::cmp::min(self.inner.len(), other.inner.len());
157 let mut ret = BitVec::with_capacity(len);
158 for i in 0..len {
159 ret.inner.push(self.inner[i] & other.inner[i]);
160 }
161 ret
162 }
163}
164
165impl BitOr for &BitVec {
166 type Output = BitVec;
167
168 fn bitor(self, other: Self) -> Self::Output {
170 let len = std::cmp::max(self.inner.len(), other.inner.len());
171 let mut ret = BitVec::with_capacity(len);
172 for i in 0..len {
173 let a = self.inner.get(i).copied().unwrap_or(0);
174 let b = other.inner.get(i).copied().unwrap_or(0);
175 ret.inner.push(a | b);
176 }
177 ret
178 }
179}
180
181impl FromIterator<u8> for BitVec {
182 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
183 let mut bitvec = Self::default();
184 for bit in iter {
185 bitvec.set(bit as u16);
186 }
187 bitvec
188 }
189}
190
191impl From<Vec<u8>> for BitVec {
192 fn from(raw_bytes: Vec<u8>) -> Self {
193 assert!(raw_bytes.len() <= MAX_BUCKETS);
194 Self { inner: raw_bytes }
195 }
196}
197
198impl From<BitVec> for Vec<u8> {
199 fn from(bitvec: BitVec) -> Self {
200 bitvec.inner
201 }
202}
203
204impl From<Vec<bool>> for BitVec {
205 fn from(bits: Vec<bool>) -> Self {
206 assert!(bits.len() <= MAX_BUCKETS * BUCKET_SIZE);
207 let mut bitvec = Self::with_num_bits(bits.len() as u16);
208 for (index, b) in bits.iter().enumerate() {
209 if *b {
210 bitvec.set(index as u16);
211 }
212 }
213 bitvec
214 }
215}
216
217impl<'de> Deserialize<'de> for BitVec {
218 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
219 where
220 D: Deserializer<'de>,
221 {
222 #[derive(Deserialize)]
223 #[serde(rename = "BitVec")]
224 struct RawData {
225 #[serde(with = "serde_bytes")]
226 inner: Vec<u8>,
227 }
228 let v = RawData::deserialize(deserializer)?.inner;
229 if v.len() > MAX_BUCKETS {
230 return Err(D::Error::custom(format!("BitVec too long: {}", v.len())));
231 }
232 Ok(BitVec { inner: v })
233 }
234}
235
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for BitVec {
    type Parameters = ();
    type Strategy = Map<VecStrategy<StrategyFor<u8>>, fn(Vec<u8>) -> BitVec>;

    /// Generates bit vectors backed by between zero and `MAX_BUCKETS`
    /// (inclusive) arbitrary bytes.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        let raw_buckets = vec(any::<u8>(), 0..=MAX_BUCKETS);
        raw_buckets.prop_map(|inner| BitVec { inner })
    }
}
245
#[cfg(test)]
mod test {
    use super::*;
    use proptest::proptest;

    /// Builds a `BitVec` straight from its backing buckets.
    fn from_buckets(inner: Vec<u8>) -> BitVec {
        BitVec { inner }
    }

    #[test]
    fn test_count_ones() {
        assert_eq!(BitVec::default().count_ones(), 0);

        // 7 = 0b0000_0111 (three ones), 15 = 0b0000_1111 (four ones).
        let two_buckets = from_buckets(vec![7u8, 15u8]);
        assert_eq!(two_buckets.count_ones(), 7);

        let three_per_bucket = from_buckets(vec![7u8; MAX_BUCKETS]);
        assert_eq!(three_per_bucket.count_ones(), 3 * MAX_BUCKETS as u32);

        let all_ones = from_buckets(vec![255u8; MAX_BUCKETS]);
        assert_eq!(all_ones.count_ones(), 8 * MAX_BUCKETS as u32);

        let all_zeros = from_buckets(vec![0u8; MAX_BUCKETS]);
        assert_eq!(all_zeros.count_ones(), 0);
    }

    #[test]
    fn test_last_set_bit() {
        assert_eq!(BitVec::default().last_set_bit(), None);

        // 224 = 0b1110_0000: positions 0, 1 and 2 are set.
        let top_three = from_buckets(vec![224u8]);
        assert_eq!(top_three.inner.len(), 1);
        assert_eq!(top_three.last_set_bit(), Some(2));

        // 128 = 0b1000_0000 in the second bucket is position 8.
        let second_bucket = from_buckets(vec![7u8, 128u8]);
        assert_eq!(second_bucket.inner.len(), 2);
        assert_eq!(second_bucket.last_set_bit(), Some(8));

        let all_ones = from_buckets(vec![255u8; MAX_BUCKETS]);
        assert_eq!(all_ones.inner.len(), MAX_BUCKETS);
        assert_eq!(all_ones.last_set_bit(), Some(65535));

        let all_zeros = from_buckets(vec![0u8; MAX_BUCKETS]);
        assert_eq!(all_zeros.last_set_bit(), None);

        let mut growing = from_buckets(vec![0b0000_0001, 0b0100_0000]);
        assert_eq!(growing.last_set_bit(), Some(9));
        assert!(growing.is_set(7));
        assert!(growing.is_set(9));
        assert!(!growing.is_set(0));

        growing.set(10);
        assert!(growing.is_set(10));
        assert_eq!(growing.last_set_bit(), Some(10));
        assert_eq!(growing.inner, vec![0b0000_0001, 0b0110_0000]);

        let first_bit_only = from_buckets(vec![0b1000_0000]);
        assert_eq!(first_bit_only.inner.len(), 1);
        assert_eq!(first_bit_only.last_set_bit(), Some(0));
    }

    #[test]
    fn test_empty() {
        let empty = BitVec::default();
        assert!((0..=u16::MAX).all(|pos| !empty.is_set(pos)));
    }

    #[test]
    fn test_extremes() {
        let mut bits = BitVec::default();
        bits.set(u16::MAX);
        bits.set(0);

        assert!(bits.is_set(u16::MAX));
        assert!(bits.is_set(0));
        // Nothing strictly between the two extremes may be set.
        assert!((1..u16::MAX).all(|pos| !bits.is_set(pos)));

        let ones: Vec<usize> = bits.iter_ones().collect();
        assert_eq!(vec![0, u16::MAX as usize], ones);
    }

    #[test]
    fn test_conversion() {
        let expected = vec![
            false, true, true, false, false, true, true, false, true, true, true,
        ];
        let bitvec = BitVec::from(expected.clone());
        for (pos, &want) in expected.iter().enumerate() {
            assert_eq!(bitvec.is_set(pos as u16), want);
        }
    }

    #[test]
    fn test_deserialization() {
        // 9000 buckets exceed MAX_BUCKETS: valid as raw bytes, invalid as a BitVec.
        let oversized = bcs::to_bytes(&vec![0u8; 9000]).unwrap();
        assert!(bcs::from_bytes::<Vec<u8>>(&oversized).is_ok());
        assert!(bcs::from_bytes::<BitVec>(&oversized).is_err());

        // A length prefix of 32 followed by 32 zero bytes.
        let mut encoded = [0u8; 33];
        encoded[0] = 32;
        let expected = from_buckets(vec![0u8; 32]);
        assert_eq!(Ok(expected), bcs::from_bytes::<BitVec>(&encoded));
    }

    proptest! {
        #[test]
        fn test_and(lhs in any::<BitVec>(), rhs in any::<BitVec>()) {
            let intersection = &lhs & &rhs;
            let ones = intersection.count_ones();

            assert!(ones <= lhs.count_ones());
            assert!(ones <= rhs.count_ones());

            for pos in 0..=u16::MAX {
                let expected = lhs.is_set(pos) && rhs.is_set(pos);
                assert!(intersection.is_set(pos) == expected);
            }
        }

        #[test]
        fn test_or(lhs in any::<BitVec>(), rhs in any::<BitVec>()) {
            let union = &lhs | &rhs;
            let ones = union.count_ones();

            assert!(ones >= lhs.count_ones());
            assert!(ones >= rhs.count_ones());

            for pos in 0..=u16::MAX {
                let expected = lhs.is_set(pos) || rhs.is_set(pos);
                assert!(union.is_set(pos) == expected);
            }
        }

        #[test]
        fn test_iter_ones(bits in any::<BitVec>()) {
            let yielded = bits.iter_ones().count();
            assert_eq!(yielded, bits.count_ones() as usize);
        }

        #[test]
        fn test_serde_roundtrip(bits in vec(any::<bool>(), 0..u16::MAX as usize)) {
            let original = BitVec::from(bits);
            let encoded = serde_json::to_vec(&original).unwrap();
            let decoded: BitVec = serde_json::from_slice(&encoded).unwrap();
            assert_eq!(original, decoded);
        }

    }
}