boolvec/
bool_ref.rs

1
2use std::fmt;
3use std::ptr::NonNull;
4use std::marker::PhantomData;
5
6use std::cmp::{
7    Ordering,
8    Ord, PartialOrd,
9};
10
11use crate::mask::Mask;
12
13/// An unsafe interface that holds a reference to a boolean that may be
14/// in the middle of byte.
15#[derive(Clone, Copy)]
16pub(crate) struct UnsafeBoolRef {
17    pub byte: NonNull<u8>,
18    pub bit_mask: Mask,
19}
20
21impl fmt::Debug for UnsafeBoolRef {
22    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23        unsafe {
24            fmt::Debug::fmt(&self.get(), f)
25        }
26    }
27}
28
29impl fmt::Display for UnsafeBoolRef {
30    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
31        unsafe {
32            fmt::Debug::fmt(&self.get(), f)
33        }
34    }
35}
36
37impl UnsafeBoolRef {
38    /// Creates a new `MutBool`.
39    #[inline]
40    pub fn new(byte: NonNull<u8>, bit_mask: Mask) -> Self {
41        Self {
42            byte,
43            bit_mask,
44        }
45    }
46
47    /// Gets the value of the referenced boolean.
48    ///
49    /// # Safety
50    /// The lifetime of the referenced byte must be valid.
51    #[inline(always)]
52    pub unsafe fn get(&self) -> bool {
53        self.bit_mask.check(
54            *self.byte.as_ref()
55        )
56    }
57
58    /// Sets the value of the referenced boolean.
59    ///
60    /// # Safety
61    /// The lifetime of the referenced byte must be valid.
62    pub unsafe fn set(&mut self, value: bool) {
63        self.bit_mask.set(self.byte.as_mut(), value);
64    }
65
66    /// Returns a reference that points to the next bit.
67    /// 
68    /// # Safety
69    /// The next bit must be a valid value.
70    pub unsafe fn next_bit(mut self) -> Self {
71        self.bit_mask >>= 1;
72
73        if self.bit_mask == Mask::VALUES[0] {
74            self.byte = NonNull::new(self.byte.as_ptr().add(1)).unwrap();
75        }
76
77        self
78    }
79
80    /// Returns a reference that points to the next bit.
81    /// 
82    /// # Safety
83    /// The previous bit must be a valid value.
84    pub unsafe fn prev_bit(mut self) -> Self {
85        self.bit_mask <<= 1;
86
87        if self.bit_mask == Mask::VALUES[7] {
88            self.byte = NonNull::new(self.byte.as_ptr().sub(1)).unwrap();
89        }
90
91        self
92    }
93}
94
95impl PartialOrd for UnsafeBoolRef {
96    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
97        Some(self.cmp(other))
98    }
99}
100
101impl Ord for UnsafeBoolRef {
102    fn cmp(&self, other: &Self) -> Ordering {
103        match self.byte.cmp(&other.byte) {
104            Ordering::Greater => Ordering::Greater,
105            Ordering::Less => Ordering::Less,
106            Ordering::Equal => self.bit_mask.cmp(&other.bit_mask),
107        }
108    }
109}
110
111impl PartialEq for UnsafeBoolRef {
112    fn eq(&self, other: &Self) -> bool {
113        unsafe { self.get() == other.get() }
114    }
115}
116
117impl Eq for UnsafeBoolRef { }
118
119/// A mutable reference to a `bool` value that may be in the middle of a byte.
120#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq)]
121pub struct RefBoolMut<'s> {
122    inner: UnsafeBoolRef,
123    _marker: PhantomData<&'s mut u8>,
124}
125
126impl<'s> fmt::Debug for RefBoolMut<'s> {
127    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
128        fmt::Debug::fmt(&self.get(), f)
129    }
130}
131
132impl<'s> fmt::Display for RefBoolMut<'s> {
133    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
134        fmt::Debug::fmt(&self.get(), f)
135    }
136}
137
138impl<'s> RefBoolMut<'s> {
139    pub(crate) fn from_inner(inner: UnsafeBoolRef) -> Self {
140        Self {
141            inner,
142            _marker: PhantomData,
143        }
144    }
145
146    /// Creates a new `MutBool`.
147    pub fn new(byte: &'s mut u8, bit_mask: Mask) -> Self {
148        Self::from_inner(UnsafeBoolRef::new(NonNull::from(byte), bit_mask))
149    }
150
151    /// Gets the value of the referenced boolean.
152    #[inline(always)]
153    pub fn get(&self) -> bool {
154        // Safety: The lifetime is checked by the compiler thanks to the
155        // phantom data.
156        unsafe { self.inner.get() }
157    }
158
159    /// Sets the value of the referenced boolean.
160    #[inline(always)]
161    pub fn set(&mut self, value: bool) {
162        // Safety: The lifetime is checked by the compiler thanks to the
163        // phantom data.
164        unsafe { self.inner.set(value); }
165    }
166
167    /// Returns a reference that points to the next bit without being constrained
168    /// by the referenced byte.
169    ///
170    /// # Safety
171    /// The previous bit must be a valid value.
172    #[inline]
173    pub unsafe fn unconstrained_next_bit(mut self) -> Self {
174        self.inner = self.inner.next_bit();
175        self
176    }
177
178    /// Returns a reference that points to the previous bit without being constrained
179    /// by the referenced byte.
180    ///
181    /// # Safety
182    /// The previous bit must be a valid value.
183    #[inline]
184    pub unsafe fn unconstrained_prev_bit(mut self) -> Self {
185        self.inner = self.inner.prev_bit();
186        self
187    }
188
189    /// Returns a reference that points to the next bit within the referenced byte.
190    ///
191    /// # Panics
192    /// This function panics if the previous bit is not part of the referenced byte.
193    #[inline]
194    pub fn next_bit(mut self) -> Self {
195        if self.inner.bit_mask == Mask::VALUES[7] {
196            panic!("The next bit was not part of the referenced byte.");
197        }
198
199        self.inner = unsafe { self.inner.next_bit() };
200        self
201    }
202
203    /// Returns a reference that points to the previous bit within the referenced byte.
204    /// 
205    /// # Panics
206    /// This function panics if the previous bit is not part of the referenced byte.
207    #[inline]
208    pub fn prev_bit(mut self) -> Self {
209        if self.inner.bit_mask == Mask::VALUES[0] {
210            panic!("The previous bit was not part of the referenced byte.");
211        }
212
213        self.inner = unsafe { self.inner.prev_bit() };
214        self
215    }
216}
217
218/// A reference to a `bool` value that may be in the middle of a byte.
219#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq)]
220pub struct RefBool<'s> {
221    inner: UnsafeBoolRef,
222    _marker: PhantomData<&'s u8>,
223}
224
225impl<'s> fmt::Debug for RefBool<'s> {
226    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
227        fmt::Debug::fmt(&self.get(), f)
228    }
229}
230
231impl<'s> fmt::Display for RefBool<'s> {
232    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
233        fmt::Debug::fmt(&self.get(), f)
234    }
235}
236
237impl<'s> RefBool<'s> {
238    /// Creates a new `MutBool`.
239    pub fn new(byte: &'s u8, bit_mask: Mask) -> Self {
240        Self {
241            inner: UnsafeBoolRef::new(NonNull::from(byte), bit_mask),
242            _marker: PhantomData,
243        }
244    }
245
246    /// Gets the value of the referenced boolean.
247    pub fn get(&self) -> bool {
248        // Safety: The lifetime is checked by the compiler thanks to the
249        // phantom data.
250        unsafe { self.inner.get() }
251    }
252
253    /// Returns a reference that points to the next bit without being constrained
254    /// by the referenced byte.
255    ///
256    /// # Safety
257    /// The previous bit must be a valid value.
258    #[inline]
259    pub unsafe fn unconstrained_next_bit(mut self) -> Self {
260        self.inner = self.inner.next_bit();
261        self
262    }
263
264    /// Returns a reference that points to the previous bit without being constrained
265    /// by the referenced byte.
266    ///
267    /// # Safety
268    /// The previous bit must be a valid value.
269    #[inline]
270    pub unsafe fn unconstrained_prev_bit(mut self) -> Self {
271        self.inner = self.inner.prev_bit();
272        self
273    }
274
275    /// Returns a reference that points to the next bit within the referenced byte.
276    ///
277    /// # Panics
278    /// This function panics if the previous bit is not part of the referenced byte.
279    #[inline]
280    pub fn next_bit(mut self) -> Self {
281        if self.inner.bit_mask == Mask::VALUES[7] {
282            panic!("The next bit was not part of the referenced byte.");
283        }
284
285        self.inner = unsafe { self.inner.next_bit() };
286        self
287    }
288
289    /// Returns a reference that points to the previous bit within the referenced byte.
290    /// 
291    /// # Panics
292    /// This function panics if the previous bit is not part of the referenced byte.
293    #[inline]
294    pub fn prev_bit(mut self) -> Self {
295        if self.inner.bit_mask == Mask::VALUES[0] {
296            panic!("The previous bit was not part of the referenced byte.");
297        }
298
299        self.inner = unsafe { self.inner.prev_bit() };
300        self
301    }
302}
303
#[cfg(test)]
mod ref_tests {
    use super::*;

    #[test]
    fn gets_sets() {
        let mut byte: u8 = 0b1100_1010;
        // Bits of `byte` in `Mask::VALUES` order (index 0 = most significant bit).
        let expected = [true, true, false, false, true, false, true, false];

        for (i, &bit) in expected.iter().enumerate() {
            let mask = Mask::VALUES[i];
            assert_eq!(RefBool::new(&byte, mask).get(), bit, "RefBool, bit {}", i);
            assert_eq!(RefBoolMut::new(&mut byte, mask).get(), bit, "RefBoolMut, bit {}", i);
        }

        // Flip every bit through the mutable handle...
        for (i, &bit) in expected.iter().enumerate() {
            RefBoolMut::new(&mut byte, Mask::VALUES[i]).set(!bit);
        }

        // ...and read the flipped values back.
        for (i, &bit) in expected.iter().enumerate() {
            assert_eq!(RefBoolMut::new(&mut byte, Mask::VALUES[i]).get(), !bit, "after set, bit {}", i);
        }
    }

    #[test]
    fn offsets() {
        let mut bytes: [u8; 2] = [0b1100_0011, 0b1101_1010];
        // Bits of both bytes in traversal order (most significant bit first).
        let expected = [
            true, true, false, false, false, false, true, true,
            true, true, false, true, true, false, true, false,
        ];

        // Build the handle from a pointer to the whole array so it is allowed to
        // step into the second byte; a pointer derived from a `&mut u8` to the
        // first element would only be valid for that one byte.
        let start = NonNull::new(bytes.as_mut_ptr()).expect("array pointers are never null");
        let mut b = RefBoolMut::from_inner(UnsafeBoolRef::new(start, Mask::VALUES[0]));

        assert_eq!(b.get(), expected[0]);
        for (i, &bit) in expected.iter().enumerate().skip(1) {
            // SAFETY: every position 1..16 lies inside `bytes`.
            b = unsafe { b.unconstrained_next_bit() };
            assert_eq!(b.get(), bit, "forward, bit {}", i);
        }

        for i in (0..expected.len() - 1).rev() {
            // SAFETY: every position 14..=0 lies inside `bytes`.
            b = unsafe { b.unconstrained_prev_bit() };
            assert_eq!(b.get(), expected[i], "backward, bit {}", i);
        }
    }
}