quark/
bit_index.rs

1#![allow(unused_comparisons)]
2
3use crate::{BitMask, BitSize};
4use std::ops::RangeBounds;
5
6/// Provides bit indexing operations.
7///
8/// This trait defines functions for accessing single bits to determine whether they are set and for
9/// accessing ranges of bits to extract the value they contain.
10///
11/// # Examples
12///
13/// ```
14/// use quark::BitIndex;
15///
16/// let value: u32 = 0xe01a3497;
17///
18/// let s = value.bit(20);
19/// assert!(s);
20/// let rd = value.bits(16..20);
21/// assert_eq!(rd, 10);
22/// let rn = value.bits(12..16);
23/// assert_eq!(rn, 3);
24/// let rs = value.bits(8..12);
25/// assert_eq!(rs, 4);
26/// let rm = value.bits(0..4);
27/// assert_eq!(rm, 7);
28/// ```
29pub trait BitIndex: BitSize + BitMask {
30    /// Returns whether the specified bit is set.
31    fn bit(&self, index: usize) -> bool;
32
33    /// Returns the bits contained in the specified bit range.
34    fn bits<Idx: RangeBounds<usize>>(&self, index: Idx) -> Self;
35}
36
/// Implements [`BitIndex`] for a primitive integer type.
///
/// The generated code uses the type's inherent `checked_shr`, comparison against `0` (to detect
/// negative signed values), and [`BitMask::mask_to`] to trim the result to the range's width.
macro_rules! bit_index_impl {
    ($type:ty) => {
        impl BitIndex for $type {
            /// Returns whether bit `index` is set.
            ///
            /// Positions at or beyond the type's width read as copies of the sign bit. The result
            /// is `true` for negative signed values and `false` otherwise.
            fn bit(&self, index: usize) -> bool {
                use ::std::convert::TryFrom;

                // Convert without truncating. An index that does not fit in `u32` is certainly
                // past the last bit, whereas `index as u32` would wrap it back into range.
                let shifted = u32::try_from(index)
                    .ok()
                    .and_then(|shift| self.checked_shr(shift));

                match shifted {
                    Some(value) => (value & 1) == 1,
                    None => *self < 0,
                }
            }

            /// Returns the bits selected by `index`, shifted down so that the lowest selected
            /// bit becomes bit 0, then trimmed to the range's width with `mask_to`.
            ///
            /// * Positions past the type's width read as copies of the sign bit, as with an
            ///   arithmetic shift.
            /// * An empty or reversed range (e.g. `4..4`, `6..2`) selects no bits and yields `0`.
            /// * Bounds near `usize::MAX` are handled without arithmetic overflow.
            fn bits<Idx: ::std::ops::RangeBounds<usize>>(&self, index: Idx) -> Self {
                use ::std::convert::TryFrom;
                use ::std::ops::Bound;

                // Position of the lowest selected bit. It is `None` when that position would lie
                // past `usize::MAX`, which only happens for an excluded start of `usize::MAX`.
                let first = match index.start_bound() {
                    Bound::Included(&start) => Some(start),
                    Bound::Excluded(&start) => start.checked_add(1),
                    Bound::Unbounded => Some(0),
                };

                // Number of selected bits, or `None` when the range has no upper bound.
                // Saturating/checked arithmetic clamps reversed ranges to zero width instead of
                // underflowing.
                let width = match (first, index.end_bound()) {
                    (_, Bound::Unbounded) => None,
                    (None, _) => Some(0),
                    (Some(lo), Bound::Excluded(&end)) => Some(end.saturating_sub(lo)),
                    (Some(lo), Bound::Included(&end)) => {
                        Some(end.checked_sub(lo).map_or(0, |span| span.saturating_add(1)))
                    }
                };

                // The value left once every bit has been shifted out: copies of the sign bit.
                let fill: Self = if *self < 0 { !0 } else { 0 };

                // Same non-truncating conversion as in `bit`: shifts that do not fit in `u32`
                // move every bit out.
                let shifted = first
                    .and_then(|lo| u32::try_from(lo).ok())
                    .and_then(|shift| self.checked_shr(shift))
                    .unwrap_or(fill);

                match width {
                    Some(0) => 0,
                    Some(bits) => shifted.mask_to(bits),
                    None => shifted,
                }
            }
        }
    };
}
101
102bit_index_impl!(u8);
103bit_index_impl!(u16);
104bit_index_impl!(u32);
105bit_index_impl!(u64);
106bit_index_impl!(u128);
107bit_index_impl!(usize);
108bit_index_impl!(i8);
109bit_index_impl!(i16);
110bit_index_impl!(i32);
111bit_index_impl!(i64);
112bit_index_impl!(i128);
113bit_index_impl!(isize);
114
#[cfg(test)]
mod test {
    use super::*;
    use spectral::prelude::*;
    use std::ops::Bound;

    /// A `usize` range built from arbitrary bounds.
    ///
    /// None of the standard range types can express an excluded start bound. This type exists to
    /// exercise the `Bound::Excluded` start handling of `BitIndex::bits`.
    struct Span {
        start: Bound<usize>,
        end: Bound<usize>,
    }

    impl RangeBounds<usize> for Span {
        fn start_bound(&self) -> Bound<&usize> {
            borrow_bound(&self.start)
        }

        fn end_bound(&self) -> Bound<&usize> {
            borrow_bound(&self.end)
        }
    }

    /// Turns a borrowed `Bound<usize>` into a `Bound<&usize>`.
    fn borrow_bound(bound: &Bound<usize>) -> Bound<&usize> {
        match bound {
            Bound::Included(at) => Bound::Included(at),
            Bound::Excluded(at) => Bound::Excluded(at),
            Bound::Unbounded => Bound::Unbounded,
        }
    }

    /// Range with `start` excluded and `end` excluded.
    fn range_ee(start: usize, end: usize) -> Span {
        Span {
            start: Bound::Excluded(start),
            end: Bound::Excluded(end),
        }
    }

    /// Range with `start` excluded and `end` included.
    fn range_ei(start: usize, end: usize) -> Span {
        Span {
            start: Bound::Excluded(start),
            end: Bound::Included(end),
        }
    }

    /// Range with `start` excluded and no end bound.
    fn range_eu(start: usize) -> Span {
        Span {
            start: Bound::Excluded(start),
            end: Bound::Unbounded,
        }
    }

    #[test]
    fn unsigned_bit_index() {
        let value: u8 = 90;

        asserting!("bit() looks up the correct bit")
            .that(&[value.bit(2), value.bit(3), value.bit(4), value.bit(5)])
            .is_equal_to(&[false, true, true, false]);

        // Standard range types.
        asserting!("bits(RangeFull) returns the whole value")
            .that(&value.bits(..))
            .is_equal_to(90);
        asserting!("bits(RangeTo) excludes the end bit")
            .that(&value.bits(..4))
            .is_equal_to(10);
        asserting!("bits(RangeToInclusive) includes the end bit")
            .that(&value.bits(..=4))
            .is_equal_to(26);
        asserting!("bits(RangeFrom) includes the start bit")
            .that(&value.bits(4..))
            .is_equal_to(5);
        asserting!("bits(Range) includes the start bit")
            .that(&value.bits(4..8))
            .is_equal_to(5);
        asserting!("bits(Range) excludes the end bit")
            .that(&value.bits(0..4))
            .is_equal_to(10);
        asserting!("bits(RangeInclusive) includes the start bit")
            .that(&value.bits(4..=7))
            .is_equal_to(5);
        asserting!("bits(RangeInclusive) includes the end bit")
            .that(&value.bits(0..=4))
            .is_equal_to(26);

        // Ranges with an excluded start bound.
        asserting!("bits(RangeEU) excludes the start bit")
            .that(&value.bits(range_eu(4)))
            .is_equal_to(2);
        asserting!("bits(RangeEE) excludes the start bit")
            .that(&value.bits(range_ee(4, 8)))
            .is_equal_to(2);
        asserting!("bits(RangeEE) excludes the end bit")
            .that(&value.bits(range_ee(0, 4)))
            .is_equal_to(5);
        asserting!("bits(RangeEI) excludes the start bit")
            .that(&value.bits(range_ei(4, 7)))
            .is_equal_to(2);
        asserting!("bits(RangeEI) includes the end bit")
            .that(&value.bits(range_ei(0, 4)))
            .is_equal_to(13);
    }

    #[test]
    fn unsigned_extra_high_bits() {
        let value: u8 = 90;

        asserting!("bit() returns 0 when indexing past the last bit")
            .that(&[value.bit(8), value.bit(9), value.bit(10)])
            .is_equal_to(&[false, false, false]);

        // Standard range types.
        asserting!("bits(RangeTo) can index past the last bit")
            .that(&value.bits(..16))
            .is_equal_to(90);
        asserting!("bits(RangeToInclusive) can index past the last bit")
            .that(&value.bits(..=15))
            .is_equal_to(90);
        asserting!("bits(RangeFrom) can index past the last bit")
            .that(&value.bits(4..))
            .is_equal_to(5);
        asserting!("bits(RangeFrom) is 0 when completely past the last bit")
            .that(&value.bits(8..))
            .is_equal_to(0);
        asserting!("bits(Range) can index past the last bit")
            .that(&value.bits(4..16))
            .is_equal_to(5);
        asserting!("bits(Range) is 0 when completely past the last bit")
            .that(&value.bits(8..16))
            .is_equal_to(0);
        asserting!("bits(RangeInclusive) can index past the last bit")
            .that(&value.bits(4..=15))
            .is_equal_to(5);
        asserting!("bits(RangeInclusive) is 0 when completely past the last bit")
            .that(&value.bits(8..=15))
            .is_equal_to(0);

        // Ranges with an excluded start bound.
        asserting!("bits(RangeEU) can index past the last bit")
            .that(&value.bits(range_eu(4)))
            .is_equal_to(2);
        asserting!("bits(RangeEU) is 0 when completely past the last bit")
            .that(&value.bits(range_eu(8)))
            .is_equal_to(0);
        asserting!("bits(RangeEE) can index past the last bit")
            .that(&value.bits(range_ee(4, 16)))
            .is_equal_to(2);
        asserting!("bits(RangeEE) is 0 when completely past the last bit")
            .that(&value.bits(range_ee(8, 16)))
            .is_equal_to(0);
        asserting!("bits(RangeEI) can index past the last bit")
            .that(&value.bits(range_ei(4, 15)))
            .is_equal_to(2);
        asserting!("bits(RangeEI) is 0 when completely past the last bit")
            .that(&value.bits(range_ei(8, 15)))
            .is_equal_to(0);
    }

    #[test]
    fn signed_bit_index() {
        let value: i8 = -90;

        asserting!("bit() looks up the correct bit")
            .that(&[value.bit(2), value.bit(3), value.bit(4), value.bit(5)])
            .is_equal_to(&[true, false, false, true]);

        // Agreement with plain shift-and-mask arithmetic.
        asserting!("bits(Range) is equal to the equivalent shift and mask")
            .that(&value.bits(2..6))
            .is_equal_to(value >> 2 & 0xf);
        asserting!("bits(RangeFrom) is equal to the equivalent shift and mask")
            .that(&value.bits(2..))
            .is_equal_to(value >> 2);

        // Standard range types.
        asserting!("bits(RangeFull) returns the whole value")
            .that(&value.bits(..))
            .is_equal_to(-90);
        asserting!("bits(RangeTo) excludes the end bit")
            .that(&value.bits(..5))
            .is_equal_to(6);
        asserting!("bits(RangeToInclusive) includes the end bit")
            .that(&value.bits(..=5))
            .is_equal_to(38);
        asserting!("bits(RangeFrom) includes the start bit")
            .that(&value.bits(4..))
            .is_equal_to(-6);
        asserting!("bits(Range) includes the start bit")
            .that(&value.bits(4..8))
            .is_equal_to(10);
        asserting!("bits(Range) excludes the end bit")
            .that(&value.bits(0..5))
            .is_equal_to(6);
        asserting!("bits(RangeInclusive) includes the start bit")
            .that(&value.bits(4..=7))
            .is_equal_to(10);
        asserting!("bits(RangeInclusive) includes the end bit")
            .that(&value.bits(0..=5))
            .is_equal_to(38);

        // Ranges with an excluded start bound.
        asserting!("bits(RangeEU) excludes the start bit")
            .that(&value.bits(range_eu(4)))
            .is_equal_to(-3);
        asserting!("bits(RangeEE) excludes the start bit")
            .that(&value.bits(range_ee(4, 8)))
            .is_equal_to(5);
        asserting!("bits(RangeEE) excludes the end bit")
            .that(&value.bits(range_ee(0, 2)))
            .is_equal_to(1);
        asserting!("bits(RangeEI) excludes the start bit")
            .that(&value.bits(range_ei(2, 4)))
            .is_equal_to(0);
        asserting!("bits(RangeEI) includes the end bit")
            .that(&value.bits(range_ei(2, 5)))
            .is_equal_to(4);
    }

    #[test]
    fn signed_extra_high_bits() {
        let value: i8 = -90;

        asserting!("bit() returns 1 when indexing past the last bit")
            .that(&[value.bit(8), value.bit(9), value.bit(10)])
            .is_equal_to(&[true, true, true]);

        // Standard range types.
        asserting!("bits(RangeTo) can index past the last bit")
            .that(&value.bits(..16))
            .is_equal_to(-90);
        asserting!("bits(RangeToInclusive) can index past the last bit")
            .that(&value.bits(..=15))
            .is_equal_to(-90);
        asserting!("bits(RangeFrom) can index past the last bit")
            .that(&value.bits(4..))
            .is_equal_to(-6);
        asserting!("bits(RangeFrom) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(8..))
            .is_equal_to(-1);
        asserting!("bits(Range) can index past the last bit")
            .that(&value.bits(4..16))
            .is_equal_to(-6);
        asserting!("bits(Range) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(8..16))
            .is_equal_to(-1);
        asserting!("bits(RangeInclusive) can index past the last bit")
            .that(&value.bits(4..=15))
            .is_equal_to(-6);
        asserting!("bits(RangeInclusive) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(8..=15))
            .is_equal_to(-1);

        // Ranges with an excluded start bound.
        asserting!("bits(RangeEU) can index past the last bit")
            .that(&value.bits(range_eu(4)))
            .is_equal_to(-3);
        asserting!("bits(RangeEU) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(range_eu(8)))
            .is_equal_to(-1);
        asserting!("bits(RangeEE) can index past the last bit")
            .that(&value.bits(range_ee(4, 16)))
            .is_equal_to(-3);
        asserting!("bits(RangeEE) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(range_ee(8, 17)))
            .is_equal_to(-1);
        asserting!("bits(RangeEI) can index past the last bit")
            .that(&value.bits(range_ei(4, 15)))
            .is_equal_to(-3);
        asserting!("bits(RangeEI) is -1 (0xff) when completely past the last bit")
            .that(&value.bits(range_ei(8, 16)))
            .is_equal_to(-1);
    }
}