Skip to main content

lerc/
bitmask.rs

1use alloc::vec::Vec;
2
3use bitvec::prelude::*;
4
5/// Validity bitmask matching the C++ BitMask layout.
6///
7/// Two canonical representations:
8///
9/// - [`BitMask::AllValid`] — every pixel is valid. No heap allocation; all
10///   queries ([`is_valid`](Self::is_valid), [`count_valid`](Self::count_valid),
11///   [`is_all_valid`](Self::is_all_valid)) are O(1). This is the canonical form
12///   the decoder produces when the blob header reports full validity.
13/// - [`BitMask::Explicit`] — per-pixel bits stored MSB-first within each byte,
14///   matching the C++ convention: bit `k` is at byte `k >> 3`, position
15///   `(1 << 7) >> (k & 7)`.
16///
17/// Mutating an [`AllValid`](Self::AllValid) mask via [`set_invalid`](Self::set_invalid)
18/// transitions it to [`Explicit`](Self::Explicit) in place (one allocation).
19#[derive(Debug, Clone)]
20pub enum BitMask {
21    /// Every pixel is valid. The `usize` is the pixel count.
22    AllValid(usize),
23    /// Explicit per-pixel validity bits.
24    Explicit(BitVec<u8, Msb0>),
25}
26
27impl BitMask {
28    /// Create a new mask with every pixel marked invalid.
29    pub fn new(num_pixels: usize) -> Self {
30        Self::Explicit(bitvec![u8, Msb0; 0; num_pixels])
31    }
32
33    /// Create a new mask with every pixel marked valid — O(1), no allocation.
34    pub fn all_valid(num_pixels: usize) -> Self {
35        Self::AllValid(num_pixels)
36    }
37
38    /// Create a mask from raw MSB-first bytes, truncated to `num_pixels` bits.
39    pub fn from_bytes(data: Vec<u8>, num_pixels: usize) -> Self {
40        let mut bits = BitVec::<u8, Msb0>::from_vec(data);
41        bits.truncate(num_pixels);
42        Self::Explicit(bits)
43    }
44
45    /// Returns `true` if pixel `k` is valid. O(1).
46    #[inline]
47    pub fn is_valid(&self, k: usize) -> bool {
48        match self {
49            Self::AllValid(_) => true,
50            Self::Explicit(bits) => bits[k],
51        }
52    }
53
54    /// Mark pixel `k` as valid.
55    ///
56    /// No-op for [`AllValid`](Self::AllValid).
57    #[inline]
58    pub fn set_valid(&mut self, k: usize) {
59        if let Self::Explicit(bits) = self {
60            bits.set(k, true);
61        }
62    }
63
64    /// Mark pixel `k` as invalid.
65    ///
66    /// If the mask is [`AllValid`](Self::AllValid), this materializes it into
67    /// [`Explicit`](Self::Explicit) first (one allocation of `(num_pixels + 7) / 8`
68    /// bytes), then clears bit `k`.
69    #[inline]
70    pub fn set_invalid(&mut self, k: usize) {
71        if let Self::AllValid(n) = *self {
72            *self = Self::Explicit(bitvec![u8, Msb0; 1; n]);
73        }
74        if let Self::Explicit(bits) = self {
75            bits.set(k, false);
76        }
77    }
78
79    /// Returns the number of valid pixels. O(1) for `AllValid`; popcount for `Explicit`.
80    pub fn count_valid(&self) -> usize {
81        match self {
82            Self::AllValid(n) => *n,
83            Self::Explicit(bits) => bits.count_ones(),
84        }
85    }
86
87    /// Returns `true` if every pixel is valid.
88    ///
89    /// O(1) for [`AllValid`](Self::AllValid); O(n) popcount fallback for
90    /// [`Explicit`](Self::Explicit) (an explicit mask may still have every bit set).
91    #[inline]
92    pub fn is_all_valid(&self) -> bool {
93        match self {
94            Self::AllValid(_) => true,
95            Self::Explicit(bits) => bits.count_ones() == bits.len(),
96        }
97    }
98
99    /// Total pixel count. O(1).
100    pub fn num_pixels(&self) -> usize {
101        match self {
102            Self::AllValid(n) => *n,
103            Self::Explicit(bits) => bits.len(),
104        }
105    }
106
107    /// Number of bytes in the raw byte representation (`ceil(num_pixels / 8)`).
108    pub fn num_bytes(&self) -> usize {
109        match self {
110            Self::AllValid(n) => n.div_ceil(8),
111            Self::Explicit(bits) => bits.as_raw_slice().len(),
112        }
113    }
114
115    /// Borrow the underlying byte storage as a slice.
116    ///
117    /// Returns `None` for [`AllValid`](Self::AllValid) — no bytes are stored.
118    /// Callers that need bytes for serialization should first branch on
119    /// [`is_all_valid`](Self::is_all_valid), since the LERC blob format omits the
120    /// mask section entirely for all-valid bands.
121    pub fn as_bytes(&self) -> Option<&[u8]> {
122        match self {
123            Self::AllValid(_) => None,
124            Self::Explicit(bits) => Some(bits.as_raw_slice()),
125        }
126    }
127
128    /// Mutably borrow the underlying byte storage.
129    ///
130    /// Returns `None` for [`AllValid`](Self::AllValid); convert via a mutation
131    /// like [`set_invalid`](Self::set_invalid) first if you need mutable bytes.
132    pub fn as_bytes_mut(&mut self) -> Option<&mut [u8]> {
133        match self {
134            Self::AllValid(_) => None,
135            Self::Explicit(bits) => Some(bits.as_raw_mut_slice()),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A mask built with `new` starts with no valid pixels.
    #[test]
    fn new_all_invalid() {
        let fresh = BitMask::new(16);
        assert_eq!(fresh.count_valid(), 0);
        for i in 0..16 {
            assert!(!fresh.is_valid(i), "pixel {i} should be invalid");
        }
    }

    /// `all_valid` stays in the compact variant and answers every query
    /// from the stored count.
    #[test]
    fn all_valid_is_o1() {
        const PIXELS: usize = 1_000_000;
        let full = BitMask::all_valid(PIXELS);
        assert!(matches!(full, BitMask::AllValid(PIXELS)));
        assert!(full.is_all_valid());
        assert_eq!(full.count_valid(), PIXELS);
        assert_eq!(full.num_pixels(), PIXELS);
        for probe in [0, 1, 999, 1_000, 999_999] {
            assert!(full.is_valid(probe));
        }
    }

    /// Counts and byte sizes are right when the pixel count is not a
    /// multiple of eight.
    #[test]
    fn all_valid_non_byte_aligned_count() {
        let full = BitMask::all_valid(13);
        assert_eq!(full.count_valid(), 13);
        assert_eq!(full.num_pixels(), 13);
        // ceil(13 / 8)
        assert_eq!(full.num_bytes(), 2);
        for probe in 0..13 {
            assert!(full.is_valid(probe));
        }
    }

    /// Setting a pixel valid on an all-valid mask must not materialize it.
    #[test]
    fn set_valid_on_all_valid_is_noop() {
        let mut full = BitMask::all_valid(16);
        full.set_valid(5);
        assert!(matches!(full, BitMask::AllValid(16)));
        assert!(full.is_all_valid());
    }

    /// Clearing one pixel converts the mask to explicit bits and leaves
    /// every other pixel valid.
    #[test]
    fn set_invalid_materializes_all_valid() {
        let mut mask = BitMask::all_valid(16);
        mask.set_invalid(7);
        assert!(matches!(mask, BitMask::Explicit(_)));
        assert!(!mask.is_valid(7));
        for i in (0..16).filter(|&i| i != 7) {
            assert!(mask.is_valid(i), "pixel {i} should still be valid");
        }
        assert_eq!(mask.count_valid(), 15);
        assert!(!mask.is_all_valid());
    }

    /// Setting one bit affects that bit only.
    #[test]
    fn set_valid_then_is_valid() {
        let mut mask = BitMask::new(16);
        mask.set_valid(5);
        assert!(mask.is_valid(5));
        assert_eq!(mask.count_valid(), 1);
        for neighbour in [0, 4, 6] {
            assert!(!mask.is_valid(neighbour));
        }
    }

    /// A bit that was set can be cleared again.
    #[test]
    fn set_invalid_after_set_valid() {
        let mut mask = BitMask::new(16);
        mask.set_valid(7);
        assert!(mask.is_valid(7));
        mask.set_invalid(7);
        assert!(!mask.is_valid(7));
        assert_eq!(mask.count_valid(), 0);
    }

    /// The most significant bit of a byte maps to the lowest pixel index.
    #[test]
    fn from_bytes_msb_first_bit_ordering() {
        // 0x80 = 0b1000_0000: pixel 0 is valid, pixels 1-7 are not.
        let mask = BitMask::from_bytes(vec![0x80], 8);
        assert!(mask.is_valid(0));
        for probe in 1..8 {
            assert!(!mask.is_valid(probe));
        }
    }

    /// `from_bytes` never collapses an all-ones buffer into `AllValid`.
    #[test]
    fn from_bytes_all_ones_not_autoconverted() {
        // All-ones explicit bits are a legitimate Explicit mask; the decoder
        // canonicalizes from header information instead of doing it here.
        let mask = BitMask::from_bytes(vec![0xFF; 2], 16);
        assert!(matches!(mask, BitMask::Explicit(_)));
        // is_all_valid still recognizes it by inspecting the bits.
        assert!(mask.is_all_valid());
    }

    /// Bits spread over two bytes land on the expected pixel indices.
    #[test]
    fn from_bytes_multiple_bits() {
        let mask = BitMask::from_bytes(vec![0xC0, 0x01], 16);
        assert!(mask.is_valid(0));
        assert!(mask.is_valid(1));
        for probe in 2..15 {
            assert!(!mask.is_valid(probe));
        }
        assert!(mask.is_valid(15));
        assert_eq!(mask.count_valid(), 3);
    }

    /// The compact variant has no bytes to hand out.
    #[test]
    fn as_bytes_returns_none_for_all_valid() {
        assert!(BitMask::all_valid(16).as_bytes().is_none());
    }

    /// Bytes in, same bytes out, and rebuilding from them gives the same mask.
    #[test]
    fn as_bytes_round_trip() {
        let original_data = vec![0xA5, 0x3C];
        let mask = BitMask::from_bytes(original_data.clone(), 16);
        let raw = mask.as_bytes().unwrap();
        assert_eq!(raw, original_data.as_slice());

        let rebuilt = BitMask::from_bytes(raw.to_vec(), 16);
        for probe in 0..16 {
            assert_eq!(mask.is_valid(probe), rebuilt.is_valid(probe));
        }
    }

    /// The round trip also holds when the pixel count is not byte-aligned.
    #[test]
    fn as_bytes_round_trip_non_aligned() {
        let mut mask = BitMask::new(10);
        for pixel in [0, 3, 9] {
            mask.set_valid(pixel);
        }

        let raw = mask.as_bytes().unwrap().to_vec();
        let rebuilt = BitMask::from_bytes(raw, 10);
        for probe in 0..10 {
            assert_eq!(mask.is_valid(probe), rebuilt.is_valid(probe));
        }
    }

    /// `num_pixels` and `num_bytes` agree for a range of sizes.
    #[test]
    fn num_pixels_and_num_bytes_consistency() {
        // (pixel count, expected byte count)
        let expectations = [(16, 2), (13, 2), (1, 1), (8, 1), (9, 2)];
        for (pixels, bytes) in expectations {
            let mask = BitMask::new(pixels);
            assert_eq!(mask.num_pixels(), pixels);
            assert_eq!(mask.num_bytes(), bytes);
        }

        // AllValid reports the same byte count even though nothing is stored.
        assert_eq!(BitMask::all_valid(13).num_bytes(), 2);
    }

    /// The compact variant reports all-valid directly.
    #[test]
    fn is_all_valid_fast_path() {
        let full = BitMask::all_valid(100);
        assert!(full.is_all_valid());
    }

    /// Once a pixel is cleared the mask is no longer all-valid.
    #[test]
    fn is_all_valid_false_after_materialization() {
        let mut mask = BitMask::all_valid(16);
        assert!(mask.is_all_valid());
        mask.set_invalid(0);
        assert!(!mask.is_all_valid());
    }

    #[cfg(not(target_arch = "wasm32"))]
    mod proptest_tests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            /// Any in-range pixel that is set valid reads back as valid.
            #[test]
            fn prop_set_valid_is_valid(n in 1..1000usize, k in 0..999usize) {
                let pixels = n.max(1);
                let target = k % pixels;
                let mut mask = BitMask::new(pixels);
                mask.set_valid(target);
                prop_assert!(mask.is_valid(target));
            }

            /// Exporting bytes and re-importing them preserves every pixel.
            #[test]
            fn prop_from_bytes_round_trip(n in 1..200usize) {
                let source = vec![0xFF; n.div_ceil(8)];
                let mask = BitMask::from_bytes(source, n);
                let raw = mask.as_bytes().unwrap().to_vec();
                let restored = BitMask::from_bytes(raw, n);
                for probe in 0..n {
                    prop_assert_eq!(mask.is_valid(probe), restored.is_valid(probe));
                }
            }

            /// An all-valid mask answers `true` for every in-range pixel.
            #[test]
            fn prop_all_valid_is_valid_everywhere(n in 1..1000usize) {
                let full = BitMask::all_valid(n);
                for probe in 0..n {
                    prop_assert!(full.is_valid(probe));
                }
            }
        }
    }
}