1use alloc::vec::Vec;
2
3use bitvec::prelude::*;
4
5#[derive(Debug, Clone)]
20pub enum BitMask {
21 AllValid(usize),
23 Explicit(BitVec<u8, Msb0>),
25}
26
27impl BitMask {
28 pub fn new(num_pixels: usize) -> Self {
30 Self::Explicit(bitvec![u8, Msb0; 0; num_pixels])
31 }
32
33 pub fn all_valid(num_pixels: usize) -> Self {
35 Self::AllValid(num_pixels)
36 }
37
38 pub fn from_bytes(data: Vec<u8>, num_pixels: usize) -> Self {
40 let mut bits = BitVec::<u8, Msb0>::from_vec(data);
41 bits.truncate(num_pixels);
42 Self::Explicit(bits)
43 }
44
45 #[inline]
47 pub fn is_valid(&self, k: usize) -> bool {
48 match self {
49 Self::AllValid(_) => true,
50 Self::Explicit(bits) => bits[k],
51 }
52 }
53
54 #[inline]
58 pub fn set_valid(&mut self, k: usize) {
59 if let Self::Explicit(bits) = self {
60 bits.set(k, true);
61 }
62 }
63
64 #[inline]
70 pub fn set_invalid(&mut self, k: usize) {
71 if let Self::AllValid(n) = *self {
72 *self = Self::Explicit(bitvec![u8, Msb0; 1; n]);
73 }
74 if let Self::Explicit(bits) = self {
75 bits.set(k, false);
76 }
77 }
78
79 pub fn count_valid(&self) -> usize {
81 match self {
82 Self::AllValid(n) => *n,
83 Self::Explicit(bits) => bits.count_ones(),
84 }
85 }
86
87 #[inline]
92 pub fn is_all_valid(&self) -> bool {
93 match self {
94 Self::AllValid(_) => true,
95 Self::Explicit(bits) => bits.count_ones() == bits.len(),
96 }
97 }
98
99 pub fn num_pixels(&self) -> usize {
101 match self {
102 Self::AllValid(n) => *n,
103 Self::Explicit(bits) => bits.len(),
104 }
105 }
106
107 pub fn num_bytes(&self) -> usize {
109 match self {
110 Self::AllValid(n) => n.div_ceil(8),
111 Self::Explicit(bits) => bits.as_raw_slice().len(),
112 }
113 }
114
115 pub fn as_bytes(&self) -> Option<&[u8]> {
122 match self {
123 Self::AllValid(_) => None,
124 Self::Explicit(bits) => Some(bits.as_raw_slice()),
125 }
126 }
127
128 pub fn as_bytes_mut(&mut self) -> Option<&mut [u8]> {
133 match self {
134 Self::AllValid(_) => None,
135 Self::Explicit(bits) => Some(bits.as_raw_mut_slice()),
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_all_invalid() {
        let mask = BitMask::new(16);
        assert_eq!(mask.count_valid(), 0);
        for i in 0..16 {
            assert!(!mask.is_valid(i), "pixel {i} should be invalid");
        }
    }

    #[test]
    fn all_valid_is_o1() {
        let mask = BitMask::all_valid(1_000_000);
        assert!(matches!(mask, BitMask::AllValid(1_000_000)));
        assert!(mask.is_all_valid());
        assert_eq!(mask.count_valid(), 1_000_000);
        assert_eq!(mask.num_pixels(), 1_000_000);
        for i in [0, 1, 999, 1_000, 999_999] {
            assert!(mask.is_valid(i));
        }
    }

    #[test]
    fn all_valid_non_byte_aligned_count() {
        let mask = BitMask::all_valid(13);
        assert_eq!(mask.count_valid(), 13);
        assert_eq!(mask.num_pixels(), 13);
        assert_eq!(mask.num_bytes(), 2);
        for i in 0..13 {
            assert!(mask.is_valid(i));
        }
    }

    #[test]
    fn set_valid_on_all_valid_is_noop() {
        let mut mask = BitMask::all_valid(16);
        mask.set_valid(5);
        assert!(matches!(mask, BitMask::AllValid(16)));
        assert!(mask.is_all_valid());
    }

    #[test]
    fn set_invalid_materializes_all_valid() {
        let mut mask = BitMask::all_valid(16);
        mask.set_invalid(7);
        assert!(matches!(mask, BitMask::Explicit(_)));
        assert!(!mask.is_valid(7));
        for i in (0..16).filter(|&i| i != 7) {
            assert!(mask.is_valid(i), "pixel {i} should still be valid");
        }
        assert_eq!(mask.count_valid(), 15);
        assert!(!mask.is_all_valid());
    }

    #[test]
    fn set_valid_then_is_valid() {
        let mut mask = BitMask::new(16);
        mask.set_valid(5);
        assert!(mask.is_valid(5));
        assert_eq!(mask.count_valid(), 1);
        assert!(!mask.is_valid(0));
        assert!(!mask.is_valid(4));
        assert!(!mask.is_valid(6));
    }

    #[test]
    fn set_invalid_after_set_valid() {
        let mut mask = BitMask::new(16);
        mask.set_valid(7);
        assert!(mask.is_valid(7));
        mask.set_invalid(7);
        assert!(!mask.is_valid(7));
        assert_eq!(mask.count_valid(), 0);
    }

    #[test]
    fn from_bytes_msb_first_bit_ordering() {
        // 0x80 == 0b1000_0000: only the most-significant bit is set, and it
        // must map to pixel 0 under MSB-first ordering.
        let mask = BitMask::from_bytes(vec![0x80], 8);
        assert!(mask.is_valid(0));
        for i in 1..8 {
            assert!(!mask.is_valid(i));
        }
    }

    #[test]
    fn from_bytes_all_ones_not_autoconverted() {
        // A fully-set input stays `Explicit`; only the query reports it as
        // all-valid.
        let mask = BitMask::from_bytes(vec![0xFF; 2], 16);
        assert!(matches!(mask, BitMask::Explicit(_)));
        assert!(mask.is_all_valid());
    }

    #[test]
    fn from_bytes_multiple_bits() {
        // 0xC0 == 0b1100_0000 -> pixels 0 and 1; 0x01 -> last pixel (15).
        let mask = BitMask::from_bytes(vec![0xC0, 0x01], 16);
        assert!(mask.is_valid(0));
        assert!(mask.is_valid(1));
        for i in 2..15 {
            assert!(!mask.is_valid(i));
        }
        assert!(mask.is_valid(15));
        assert_eq!(mask.count_valid(), 3);
    }

    #[test]
    fn from_bytes_truncates_extra_bits() {
        // Bits beyond `num_pixels` must not be counted as pixels.
        let mask = BitMask::from_bytes(vec![0xFF, 0xFF], 4);
        assert_eq!(mask.num_pixels(), 4);
        assert_eq!(mask.count_valid(), 4);
        assert!(mask.is_all_valid());
    }

    #[test]
    fn as_bytes_returns_none_for_all_valid() {
        let mask = BitMask::all_valid(16);
        assert!(mask.as_bytes().is_none());
    }

    #[test]
    fn as_bytes_round_trip() {
        let original_data = vec![0xA5, 0x3C];
        let mask = BitMask::from_bytes(original_data.clone(), 16);
        let bytes = mask.as_bytes().unwrap();
        assert_eq!(bytes, &original_data[..]);

        let mask2 = BitMask::from_bytes(bytes.to_vec(), 16);
        for i in 0..16 {
            assert_eq!(mask.is_valid(i), mask2.is_valid(i));
        }
    }

    #[test]
    fn as_bytes_round_trip_non_aligned() {
        let mut mask = BitMask::new(10);
        mask.set_valid(0);
        mask.set_valid(3);
        mask.set_valid(9);

        let bytes = mask.as_bytes().unwrap().to_vec();
        let mask2 = BitMask::from_bytes(bytes, 10);
        for i in 0..10 {
            assert_eq!(mask.is_valid(i), mask2.is_valid(i));
        }
    }

    #[test]
    fn num_pixels_and_num_bytes_consistency() {
        // (pixel count, expected packed byte count)
        for (pixels, bytes) in [(16, 2), (13, 2), (1, 1), (8, 1), (9, 2)] {
            let mask = BitMask::new(pixels);
            assert_eq!(mask.num_pixels(), pixels);
            assert_eq!(mask.num_bytes(), bytes, "num_bytes for {pixels} pixels");
        }

        let mask = BitMask::all_valid(13);
        assert_eq!(mask.num_pixels(), 13);
        assert_eq!(mask.num_bytes(), 2);
    }

    #[test]
    fn is_all_valid_fast_path() {
        assert!(BitMask::all_valid(100).is_all_valid());
    }

    #[test]
    fn is_all_valid_false_after_materialization() {
        let mut mask = BitMask::all_valid(16);
        assert!(mask.is_all_valid());
        mask.set_invalid(0);
        assert!(!mask.is_all_valid());
    }

    #[cfg(not(target_arch = "wasm32"))]
    mod proptest_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn prop_set_valid_is_valid(n in 1..1000usize, k in 0..999usize) {
                // `n >= 1` by construction, so `k % n` is always a valid index.
                let k = k % n;
                let mut mask = BitMask::new(n);
                mask.set_valid(k);
                prop_assert!(mask.is_valid(k));
                prop_assert_eq!(mask.count_valid(), 1);
            }

            #[test]
            fn prop_from_bytes_round_trip(data in prop::collection::vec(any::<u8>(), 1..26usize), tail in 0..8usize) {
                // Pick `n` so that `data` is exactly `n.div_ceil(8)` bytes,
                // leaving 0..=7 unused padding bits in the final byte.
                let n = data.len() * 8 - tail;
                let mask = BitMask::from_bytes(data.clone(), n);
                prop_assert_eq!(mask.num_pixels(), n);

                // Pixel `i` is bit `7 - i % 8` of byte `i / 8` (MSB-first).
                for i in 0..n {
                    let expected = ((data[i / 8] >> (7 - i % 8)) & 1) == 1;
                    prop_assert_eq!(mask.is_valid(i), expected);
                }

                let bytes = mask.as_bytes().unwrap().to_vec();
                let restored = BitMask::from_bytes(bytes, n);
                for i in 0..n {
                    prop_assert_eq!(mask.is_valid(i), restored.is_valid(i));
                }
                prop_assert_eq!(mask.count_valid(), restored.count_valid());
            }

            #[test]
            fn prop_all_valid_is_valid_everywhere(n in 1..1000usize) {
                let mask = BitMask::all_valid(n);
                prop_assert_eq!(mask.num_pixels(), n);
                prop_assert_eq!(mask.count_valid(), n);
                for i in 0..n {
                    prop_assert!(mask.is_valid(i));
                }
            }
        }
    }
}