1#[cfg(any(test, feature = "fuzzing"))]
7use proptest::{
8 arbitrary::{any, Arbitrary, StrategyFor},
9 collection::{vec, VecStrategy},
10 strategy::{Map, Strategy},
11};
12use serde::{de::Error, Deserialize, Deserializer, Serialize};
13use std::{
14 iter::FromIterator,
15 ops::{BitAnd, BitOr},
16};
17
/// Number of bit positions held by one bucket (one byte of the backing vector).
const BUCKET_SIZE: usize = 8;
/// Number of buckets needed to cover every `u8` position (256 positions / 8 per bucket = 32).
const MAX_BUCKETS: usize = (u8::MAX as usize + 1) / BUCKET_SIZE;
21
/// A compact bitmap indexed by `u8` positions, stored as a vector of one-byte buckets.
///
/// Position `p` lives in bucket `p / 8`. Within a bucket the most significant bit is the
/// lowest position (position 0 is `0b1000_0000` of bucket 0). The vector only grows as far as
/// the highest bucket that has been written, so it may be shorter than `MAX_BUCKETS`.
/// Equality is byte-wise (derived), so two values with the same set bits but a different number
/// of trailing zero buckets compare unequal.
#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
pub struct BitVec {
    // Serialized through `serde_bytes`, i.e. as a byte string rather than a generic sequence.
    #[serde(with = "serde_bytes")]
    inner: Vec<u8>,
}
71
72impl BitVec {
73 fn with_capacity(num_buckets: usize) -> Self {
74 Self {
75 inner: Vec::with_capacity(num_buckets),
76 }
77 }
78
79 pub fn set(&mut self, pos: u8) {
81 let bucket: usize = pos as usize / BUCKET_SIZE;
83 if self.inner.len() <= bucket {
84 self.inner.resize(bucket + 1, 0);
85 }
86 let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
88 self.inner[bucket] |= 0b1000_0000 >> bucket_pos as u8;
89 }
90
91 #[inline]
93 pub fn is_set(&self, pos: u8) -> bool {
94 let bucket: usize = pos as usize / BUCKET_SIZE;
96 if self.inner.len() <= bucket {
97 return false;
98 }
99 let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
101 (self.inner[bucket] & (0b1000_0000 >> bucket_pos as u8)) != 0
102 }
103
104 pub fn all_zeros(&self) -> bool {
106 self.inner.iter().all(|byte| *byte == 0)
107 }
108
109 pub fn count_ones(&self) -> u32 {
111 self.inner.iter().map(|a| a.count_ones()).sum()
112 }
113
114 pub fn last_set_bit(&self) -> Option<u8> {
116 self.inner
117 .iter()
118 .rev()
119 .enumerate()
120 .find(|(_, byte)| byte != &&0u8)
121 .map(|(i, byte)| {
122 (8 * (self.inner.len() - i) - byte.trailing_zeros() as usize - 1) as u8
123 })
124 }
125
126 pub fn iter_ones(&self) -> impl Iterator<Item = u8> + '_ {
128 (0..=u8::MAX).filter(move |idx| self.is_set(*idx))
129 }
130}
131
132impl BitAnd for &BitVec {
133 type Output = BitVec;
134
135 fn bitand(self, other: Self) -> Self::Output {
137 let len = std::cmp::min(self.inner.len(), other.inner.len());
138 let mut ret = BitVec::with_capacity(len);
139 for i in 0..len {
140 ret.inner.push(self.inner[i] & other.inner[i]);
141 }
142 ret
143 }
144}
145
146impl BitOr for &BitVec {
147 type Output = BitVec;
148
149 fn bitor(self, other: Self) -> Self::Output {
151 let len = std::cmp::max(self.inner.len(), other.inner.len());
152 let mut ret = BitVec::with_capacity(len);
153 for i in 0..len {
154 let a = self.inner.get(i).copied().unwrap_or(0);
155 let b = other.inner.get(i).copied().unwrap_or(0);
156 ret.inner.push(a | b);
157 }
158 ret
159 }
160}
161
162impl FromIterator<u8> for BitVec {
163 fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
164 let mut bitvec = Self::default();
165 for bit in iter {
166 bitvec.set(bit);
167 }
168 bitvec
169 }
170}
171
172impl<'de> Deserialize<'de> for BitVec {
175 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
176 where
177 D: Deserializer<'de>,
178 {
179 let v = serde_bytes::ByteBuf::deserialize(deserializer)?.into_vec();
180 if v.len() > MAX_BUCKETS {
181 return Err(D::Error::custom(format!("BitVec too long: {}", v.len())));
182 }
183 Ok(BitVec { inner: v })
184 }
185}
186
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for BitVec {
    type Parameters = ();
    type Strategy = Map<VecStrategy<StrategyFor<u8>>, fn(Vec<u8>) -> BitVec>;

    /// Generates bit vectors with 0 to `MAX_BUCKETS` arbitrary bytes, the same length range
    /// that deserialization accepts.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        let buckets = vec(any::<u8>(), 0..=MAX_BUCKETS);
        buckets.prop_map(|inner| BitVec { inner })
    }
}
196
// Unit tests for the bit layout plus property tests for the set operations.
#[cfg(test)]
mod test {
    use super::*;
    use proptest::{prop_assert, prop_assert_eq, proptest};

    #[test]
    fn test_count_ones() {
        let p0 = BitVec::default();
        assert_eq!(p0.count_ones(), 0);

        // 0b0000_0111 has 3 ones and 0b0000_1111 has 4.
        let p1 = BitVec {
            inner: vec![7u8, 15u8],
        };
        assert_eq!(p1.count_ones(), 7);

        let p2 = BitVec {
            inner: vec![7u8; MAX_BUCKETS],
        };
        assert_eq!(p2.count_ones(), 3 * MAX_BUCKETS as u32);

        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.count_ones(), 8 * MAX_BUCKETS as u32);

        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.count_ones(), 0);
    }

    #[test]
    fn test_last_set_bit() {
        let p0 = BitVec::default();
        assert_eq!(p0.last_set_bit(), None);

        // 224 = 0b1110_0000: positions 0, 1 and 2 are set (MSB-first).
        let p1 = BitVec { inner: vec![224u8] };
        assert_eq!(p1.inner.len(), 1);
        assert_eq!(p1.last_set_bit(), Some(2));

        // 128 = 0b1000_0000 in bucket 1 is position 8.
        let p2 = BitVec {
            inner: vec![7u8, 128u8],
        };
        assert_eq!(p2.inner.len(), 2);
        assert_eq!(p2.last_set_bit(), Some(8));

        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.inner.len(), MAX_BUCKETS);
        assert_eq!(p3.last_set_bit(), Some(255));

        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.last_set_bit(), None);

        // Bucket 0 has position 7 set; bucket 1 has position 9 set.
        let mut p5 = BitVec {
            inner: vec![0b0000_0001, 0b0100_0000],
        };
        assert_eq!(p5.last_set_bit(), Some(9));
        assert!(p5.is_set(7));
        assert!(p5.is_set(9));
        assert!(!p5.is_set(0));

        p5.set(10);
        assert!(p5.is_set(10));
        assert_eq!(p5.last_set_bit(), Some(10));
        assert_eq!(p5.inner, vec![0b0000_0001, 0b0110_0000]);

        let p6 = BitVec {
            inner: vec![0b1000_0000],
        };
        assert_eq!(p6.inner.len(), 1);
        assert_eq!(p6.last_set_bit(), Some(0));
    }

    #[test]
    fn test_empty() {
        let p = BitVec::default();
        for i in 0..=u8::MAX {
            assert!(!p.is_set(i));
        }
    }

    #[test]
    fn test_extremes() {
        let mut p = BitVec::default();
        p.set(u8::MAX);
        p.set(0);
        assert!(p.is_set(u8::MAX));
        assert!(p.is_set(0));
        for i in 1..u8::MAX {
            assert!(!p.is_set(i));
        }
        assert_eq!(vec![0, u8::MAX], p.iter_ones().collect::<Vec<_>>());
    }

    #[test]
    fn test_deserialization() {
        // Length prefix 46 followed by 46 zero bytes: a valid Vec<u8>, but longer than
        // MAX_BUCKETS, so BitVec must reject it.
        let mut bytes = [0u8; 47];
        bytes[0] = 46;
        assert!(bcs::from_bytes::<Vec<u8>>(&bytes).is_ok());
        assert!(bcs::from_bytes::<BitVec>(&bytes).is_err());

        // Length prefix 32 followed by 32 zero bytes: exactly MAX_BUCKETS, which is accepted.
        let mut bytes = [0u8; 33];
        bytes[0] = 32;
        let bv = BitVec {
            inner: vec![0u8; 32],
        };
        assert_eq!(Ok(bv), bcs::from_bytes::<BitVec>(&bytes));
    }

    proptest! {
        #[test]
        fn test_and(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let intersection = &bv1 & &bv2;

            prop_assert!(intersection.count_ones() <= bv1.count_ones());
            prop_assert!(intersection.count_ones() <= bv2.count_ones());

            for i in 0..=u8::MAX {
                prop_assert_eq!(intersection.is_set(i), bv1.is_set(i) && bv2.is_set(i));
            }
        }

        #[test]
        fn test_or(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let union = &bv1 | &bv2;

            prop_assert!(union.count_ones() >= bv1.count_ones());
            prop_assert!(union.count_ones() >= bv2.count_ones());

            for i in 0..=u8::MAX {
                prop_assert_eq!(union.is_set(i), bv1.is_set(i) || bv2.is_set(i));
            }
        }

        #[test]
        fn test_iter_ones(bv1 in any::<BitVec>()) {
            prop_assert_eq!(bv1.iter_ones().count(), bv1.count_ones() as usize);
        }
    }
}