1#![allow(unused_comparisons)]
2
3use crate::{BitMask, BitSize};
4use std::ops::RangeBounds;
5
6pub trait BitIndex: BitSize + BitMask {
30 fn bit(&self, index: usize) -> bool;
32
33 fn bits<Idx: RangeBounds<usize>>(&self, index: Idx) -> Self;
35}
36
/// Implements [`BitIndex`] for one primitive integer type.
///
/// Positions at or above the type's width read as the sign fill: all zeros for
/// unsigned and non-negative values, all ones for negative values.
///
/// Shift amounts are compared against the width instead of being cast to `u32`, so huge
/// indices are never truncated. Range bounds are widened to `u128` before any `+ 1` or
/// subtraction, so bounds such as `..=usize::MAX` cannot overflow. Empty or reversed
/// ranges get a mask length of `0` instead of underflowing.
macro_rules! bit_index_impl {
    ($type:ty) => {
        impl BitIndex for $type {
            fn bit(&self, index: usize) -> bool {
                let width = ::std::mem::size_of::<$type>() * 8;
                let shifted: Self = if index < width {
                    // In range, so the shift cannot overflow.
                    *self >> index
                } else if *self < 0 {
                    // Negative values sign-extend forever. This comparison is always false
                    // for unsigned types, which trips the `unused_comparisons` lint.
                    !0
                } else {
                    0
                };
                (shifted & 1) == 1
            }

            fn bits<R: ::std::ops::RangeBounds<usize>>(&self, index: R) -> Self {
                use ::std::ops::Bound;

                let width = (::std::mem::size_of::<$type>() * 8) as u128;

                // First selected position. The `usize -> u128` conversion is lossless, and it
                // keeps `usize::MAX + 1` representable.
                let start: u128 = match index.start_bound() {
                    Bound::Included(&s) => s as u128,
                    Bound::Excluded(&s) => s as u128 + 1,
                    Bound::Unbounded => 0,
                };
                // One past the last selected position. `None` keeps every remaining bit.
                let end: Option<u128> = match index.end_bound() {
                    Bound::Included(&e) => Some(e as u128 + 1),
                    Bound::Excluded(&e) => Some(e as u128),
                    Bound::Unbounded => None,
                };

                // `>>` is arithmetic for signed types and logical for unsigned ones.
                // Shifting by at least the width leaves only the sign fill.
                let shifted: Self = if start < width {
                    // `start < width <= 128`, so the cast to `usize` cannot truncate.
                    *self >> (start as usize)
                } else if *self < 0 {
                    !0
                } else {
                    0
                };

                match end {
                    Some(end) => {
                        // Empty or reversed ranges saturate to a zero length. Only a range
                        // ending at `usize::MAX` inclusive can exceed `usize::MAX`.
                        let len = end.saturating_sub(start).min(usize::MAX as u128);
                        shifted.mask_to(len as usize)
                    }
                    None => shifted,
                }
            }
        }
    };
}
101
102bit_index_impl!(u8);
103bit_index_impl!(u16);
104bit_index_impl!(u32);
105bit_index_impl!(u64);
106bit_index_impl!(u128);
107bit_index_impl!(usize);
108bit_index_impl!(i8);
109bit_index_impl!(i16);
110bit_index_impl!(i32);
111bit_index_impl!(i64);
112bit_index_impl!(i128);
113bit_index_impl!(isize);
114
#[cfg(test)]
mod test {
    use super::*;
    use spectral::prelude::*;
    use std::ops::Bound::{self, Excluded, Included, Unbounded};

    // Range syntax cannot express an excluded start bound. Std already implements
    // `RangeBounds` for `(Bound<T>, Bound<T>)` pairs, so these helpers just build such
    // pairs instead of hand-rolling a `RangeBounds` impl per bound combination.
    type Bounds = (Bound<usize>, Bound<usize>);

    /// Both bounds excluded: `(start, end)`.
    fn range_ee(start: usize, end: usize) -> Bounds {
        (Excluded(start), Excluded(end))
    }

    /// Excluded start, included end: `(start, end]`.
    fn range_ei(start: usize, end: usize) -> Bounds {
        (Excluded(start), Included(end))
    }

    /// Excluded start, no end bound: `(start, ..)`.
    fn range_eu(start: usize) -> Bounds {
        (Excluded(start), Unbounded)
    }

    /// Every range shape on an unsigned value whose bits all lie inside the type.
    #[test]
    fn unsigned_bit_index() {
        let byte: u8 = 90;

        asserting!("bit() looks up the correct bit")
            .that(&[byte.bit(2), byte.bit(3), byte.bit(4), byte.bit(5)])
            .is_equal_to(&[false, true, true, false]);

        asserting!("bits(RangeFull) returns the whole value")
            .that(&byte.bits(..))
            .is_equal_to(90);

        asserting!("bits(RangeTo) excludes the end bit")
            .that(&byte.bits(..4))
            .is_equal_to(10);

        asserting!("bits(RangeToInclusive) includes the end bit")
            .that(&byte.bits(..=4))
            .is_equal_to(26);

        asserting!("bits(RangeFrom) includes the start bit")
            .that(&byte.bits(4..))
            .is_equal_to(5);

        asserting!("bits(Range) includes the start bit")
            .that(&byte.bits(4..8))
            .is_equal_to(5);
        asserting!("bits(Range) excludes the end bit")
            .that(&byte.bits(0..4))
            .is_equal_to(10);

        asserting!("bits(RangeInclusive) includes the start bit")
            .that(&byte.bits(4..=7))
            .is_equal_to(5);
        asserting!("bits(RangeInclusive) includes the end bit")
            .that(&byte.bits(0..=4))
            .is_equal_to(26);

        asserting!("bits(RangeEU) excludes the start bit")
            .that(&byte.bits(range_eu(4)))
            .is_equal_to(2);

        asserting!("bits(RangeEE) excludes the start bit")
            .that(&byte.bits(range_ee(4, 8)))
            .is_equal_to(2);
        asserting!("bits(RangeEE) excludes the end bit")
            .that(&byte.bits(range_ee(0, 4)))
            .is_equal_to(5);

        asserting!("bits(RangeEI) excludes the start bit")
            .that(&byte.bits(range_ei(4, 7)))
            .is_equal_to(2);
        asserting!("bits(RangeEI) includes the end bit")
            .that(&byte.bits(range_ei(0, 4)))
            .is_equal_to(13);
    }

    /// Unsigned values read as zeros above their most significant bit.
    #[test]
    fn unsigned_extra_high_bits() {
        let byte: u8 = 90;

        asserting!("bit() returns 0 when indexing past the last bit")
            .that(&[byte.bit(8), byte.bit(9), byte.bit(10)])
            .is_equal_to(&[false, false, false]);

        asserting!("bits(RangeTo) can index past the last bit")
            .that(&byte.bits(..16))
            .is_equal_to(90);

        asserting!("bits(RangeToInclusive) can index past the last bit")
            .that(&byte.bits(..=15))
            .is_equal_to(90);

        asserting!("bits(RangeFrom) can index past the last bit")
            .that(&byte.bits(4..))
            .is_equal_to(5);
        asserting!("bits(RangeFrom) is 0 when completely past the last bit")
            .that(&byte.bits(8..))
            .is_equal_to(0);

        asserting!("bits(Range) can index past the last bit")
            .that(&byte.bits(4..16))
            .is_equal_to(5);
        asserting!("bits(Range) is 0 when completely past the last bit")
            .that(&byte.bits(8..16))
            .is_equal_to(0);

        asserting!("bits(RangeInclusive) can index past the last bit")
            .that(&byte.bits(4..=15))
            .is_equal_to(5);
        asserting!("bits(RangeInclusive) is 0 when completely past the last bit")
            .that(&byte.bits(8..=15))
            .is_equal_to(0);

        asserting!("bits(RangeEU) can index past the last bit")
            .that(&byte.bits(range_eu(4)))
            .is_equal_to(2);
        asserting!("bits(RangeEU) is 0 when completely past the last bit")
            .that(&byte.bits(range_eu(8)))
            .is_equal_to(0);

        asserting!("bits(RangeEE) can index past the last bit")
            .that(&byte.bits(range_ee(4, 16)))
            .is_equal_to(2);
        asserting!("bits(RangeEE) is 0 when completely past the last bit")
            .that(&byte.bits(range_ee(8, 16)))
            .is_equal_to(0);

        asserting!("bits(RangeEI) can index past the last bit")
            .that(&byte.bits(range_ei(4, 15)))
            .is_equal_to(2);
        asserting!("bits(RangeEI) is 0 when completely past the last bit")
            .that(&byte.bits(range_ei(8, 15)))
            .is_equal_to(0);
    }

    /// Every range shape on a negative value whose bits all lie inside the type.
    #[test]
    fn signed_bit_index() {
        let byte: i8 = -90;

        asserting!("bit() looks up the correct bit")
            .that(&[byte.bit(2), byte.bit(3), byte.bit(4), byte.bit(5)])
            .is_equal_to(&[true, false, false, true]);

        asserting!("bits(Range) is equal to the equivalent shift and mask")
            .that(&byte.bits(2..6))
            .is_equal_to(byte >> 2 & 0xf);
        asserting!("bits(RangeFrom) is equal to the equivalent shift and mask")
            .that(&byte.bits(2..))
            .is_equal_to(byte >> 2);

        asserting!("bits(RangeFull) returns the whole value")
            .that(&byte.bits(..))
            .is_equal_to(-90);

        asserting!("bits(RangeTo) excludes the end bit")
            .that(&byte.bits(..5))
            .is_equal_to(6);

        asserting!("bits(RangeToInclusive) includes the end bit")
            .that(&byte.bits(..=5))
            .is_equal_to(38);

        asserting!("bits(RangeFrom) includes the start bit")
            .that(&byte.bits(4..))
            .is_equal_to(-6);

        asserting!("bits(Range) includes the start bit")
            .that(&byte.bits(4..8))
            .is_equal_to(10);
        asserting!("bits(Range) excludes the end bit")
            .that(&byte.bits(0..5))
            .is_equal_to(6);

        asserting!("bits(RangeInclusive) includes the start bit")
            .that(&byte.bits(4..=7))
            .is_equal_to(10);
        asserting!("bits(RangeInclusive) includes the end bit")
            .that(&byte.bits(0..=5))
            .is_equal_to(38);

        asserting!("bits(RangeEU) excludes the start bit")
            .that(&byte.bits(range_eu(4)))
            .is_equal_to(-3);

        asserting!("bits(RangeEE) excludes the start bit")
            .that(&byte.bits(range_ee(4, 8)))
            .is_equal_to(5);
        asserting!("bits(RangeEE) excludes the end bit")
            .that(&byte.bits(range_ee(0, 2)))
            .is_equal_to(1);

        asserting!("bits(RangeEI) excludes the start bit")
            .that(&byte.bits(range_ei(2, 4)))
            .is_equal_to(0);
        asserting!("bits(RangeEI) includes the end bit")
            .that(&byte.bits(range_ei(2, 5)))
            .is_equal_to(4);
    }

    /// Negative values read as ones (sign extension) above their most significant bit.
    #[test]
    fn signed_extra_high_bits() {
        let byte: i8 = -90;

        asserting!("bit() returns 1 when indexing past the last bit")
            .that(&[byte.bit(8), byte.bit(9), byte.bit(10)])
            .is_equal_to(&[true, true, true]);

        asserting!("bits(RangeTo) can index past the last bit")
            .that(&byte.bits(..16))
            .is_equal_to(-90);

        asserting!("bits(RangeToInclusive) can index past the last bit")
            .that(&byte.bits(..=15))
            .is_equal_to(-90);

        asserting!("bits(RangeFrom) can index past the last bit")
            .that(&byte.bits(4..))
            .is_equal_to(-6);
        asserting!("bits(RangeFrom) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(8..))
            .is_equal_to(-1);

        asserting!("bits(Range) can index past the last bit")
            .that(&byte.bits(4..16))
            .is_equal_to(-6);
        asserting!("bits(Range) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(8..16))
            .is_equal_to(-1);

        asserting!("bits(RangeInclusive) can index past the last bit")
            .that(&byte.bits(4..=15))
            .is_equal_to(-6);
        asserting!("bits(RangeInclusive) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(8..=15))
            .is_equal_to(-1);

        asserting!("bits(RangeEU) can index past the last bit")
            .that(&byte.bits(range_eu(4)))
            .is_equal_to(-3);
        asserting!("bits(RangeEU) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(range_eu(8)))
            .is_equal_to(-1);

        asserting!("bits(RangeEE) can index past the last bit")
            .that(&byte.bits(range_ee(4, 16)))
            .is_equal_to(-3);
        asserting!("bits(RangeEE) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(range_ee(8, 17)))
            .is_equal_to(-1);

        asserting!("bits(RangeEI) can index past the last bit")
            .that(&byte.bits(range_ei(4, 15)))
            .is_equal_to(-3);
        asserting!("bits(RangeEI) is -1 (0xff) when completely past the last bit")
            .that(&byte.bits(range_ei(8, 16)))
            .is_equal_to(-1);
    }
}