// nostd_bv/storage.rs

1use core::mem;
2use core::ops;
3
4/// Interface to primitive bit storage.
5///
6/// Types implementing this trait can be used as the blocks of a bit-vector.
7pub trait BlockType:
8    Copy
9    + PartialEq
10    + Ord
11    + ops::BitAnd<Output = Self>
12    + ops::BitOr<Output = Self>
13    + ops::BitXor<Output = Self>
14    + ops::Not<Output = Self>
15    + ops::Shl<usize, Output = Self>
16    + ops::Shr<usize, Output = Self>
17    + ops::Sub<Output = Self>
18{
19    /// The number of bits in a block.
20    #[inline]
21    fn nbits() -> usize {
22        8 * mem::size_of::<Self>()
23    }
24
25    /// Returns `index / Self::nbits()`, computed by shifting.
26    ///
27    /// This is intended for converting a bit address into a block
28    /// address, which is why it takes `u64` and returns `usize`.
29    /// There is no check that the result actually fits in a `usize`,
30    /// so this should only be used when `index` is already known to
31    /// be small enough.
32    #[inline]
33    fn div_nbits(index: u64) -> usize {
34        (index >> Self::lg_nbits()) as usize
35    }
36
37    /// Returns `index / Self::nbits()`, computed by shifting.
38    ///
39    /// This is intended for converting a bit address into a block
40    /// address, which is why it takes `u64` and returns `usize`. It can only fail (returning
41    /// `None`) if `usize` is 32 bits.
42    #[inline]
43    fn checked_div_nbits(index: u64) -> Option<usize> {
44        (index >> Self::lg_nbits()).to_usize()
45    }
46
47    /// Returns `index / Self::nbits()` rounded up, computed by shifting.
48    ///
49    /// This is intended for converting a bit size into a block
50    /// size, which is why it takes `u64` and returns `usize`.
51    #[inline]
52    fn ceil_div_nbits(index: u64) -> usize {
53        Self::div_nbits(index + (Self::nbits() as u64 - 1))
54    }
55
56    /// Returns `index / Self::nbits()` rounded up, computed by shifting.
57    ///
58    /// This is intended for converting a bit size into a block
59    /// size, which is why it takes `u64` and returns `usize`.
60    /// There is no check that the result actually fits in a `usize`,
61    /// so this should only be used when `index` is already known to
62    /// be small enough.
63    #[inline]
64    fn checked_ceil_div_nbits(index: u64) -> Option<usize> {
65        Self::checked_div_nbits(index + (Self::nbits() as u64 - 1))
66    }
67
68    /// Returns `index % Self::nbits()`, computed by masking.
69    ///
70    /// This is intended for converting a bit address into a bit offset
71    /// within a block, which is why it takes `u64` and returns `usize`.
72    #[inline]
73    fn mod_nbits(index: u64) -> usize {
74        let mask: u64 = Self::lg_nbits_mask();
75        (index & mask) as usize
76    }
77
78    /// Returns `index * Self::nbits()`, computed by shifting.
79    ///
80    /// This is intended for converting a block address into a bit address,
81    /// which is why it takes a `usize` and returns a `u64`.
82    #[inline]
83    fn mul_nbits(index: usize) -> u64 {
84        (index as u64) << Self::lg_nbits()
85    }
86
87    /// The number of bits in the block at `position`, given a total bit length
88    /// of `len`.
89    ///
90    /// This will be `Self::nbits()` for all but the last block, for which it may
91    /// be less.
92    ///
93    /// # Precondition
94    ///
95    /// `position * Self::nbits() <= len`, or the block doesn't exist and the result
96    /// is undefined.
97    #[inline]
98    fn block_bits(len: u64, position: usize) -> usize {
99        let block_start = Self::mul_nbits(position);
100        let block_limit = block_start + Self::nbits() as u64;
101
102        debug_assert!(block_start <= len, "BlockType::block_bits: precondition");
103
104        usize::if_then_else(
105            block_limit <= len,
106            Self::nbits(),
107            len.wrapping_sub(block_start) as usize,
108        )
109    }
110
111    /// Log-base-2 of the number of bits in a block.
112    #[inline]
113    fn lg_nbits() -> usize {
114        Self::nbits().floor_lg()
115    }
116
117    /// Mask with the lowest-order `lg_nbits()` set.
118    #[inline]
119    fn lg_nbits_mask<Result: BlockType>() -> Result {
120        Result::low_mask(Self::lg_nbits())
121    }
122
123    /// The bit mask consisting of `Self::nbits() - element_bits` zeroes
124    /// followed by `element_bits` ones.
125    ///
126    /// The default implementation has a branch, but should be overrided with
127    /// a branchless algorithm if possible.
128    ///
129    /// # Precondition
130    ///
131    /// `element_bits <= Self::nbits()`
132    #[inline]
133    fn low_mask(element_bits: usize) -> Self {
134        debug_assert!(element_bits <= Self::nbits());
135
136        if element_bits == Self::nbits() {
137            !Self::zero()
138        } else {
139            (Self::one() << element_bits) - Self::one()
140        }
141    }
142
143    /// The bit mask with the `bit_index`th bit set.
144    ///
145    /// Bits are indexed in little-endian style based at 0.
146    ///
147    /// # Precondition
148    ///
149    /// `bit_index < Self::nbits()`
150    #[inline]
151    fn nth_mask(bit_index: usize) -> Self {
152        Self::one() << bit_index
153    }
154
155    // Methods for getting and setting bits.
156
157    /// Extracts the value of the `bit_index`th bit.
158    ///
159    /// # Panics
160    ///
161    /// Panics if `bit_index` is out of bounds.
162    #[inline]
163    fn get_bit(self, bit_index: usize) -> bool {
164        assert!(bit_index < Self::nbits(), "Block::get_bit: out of bounds");
165        self & Self::nth_mask(bit_index) != Self::zero()
166    }
167
168    /// Functionally updates the value of the `bit_index`th bit to `bit_value`.
169    ///
170    /// # Panics
171    ///
172    /// Panics if `bit_index` is out of bounds.
173    #[inline]
174    fn with_bit(self, bit_index: usize, bit_value: bool) -> Self {
175        assert!(bit_index < Self::nbits(), "Block::with_bit: out of bounds");
176        if bit_value {
177            self | Self::nth_mask(bit_index)
178        } else {
179            self & !Self::nth_mask(bit_index)
180        }
181    }
182
183    /// Extracts `len` bits starting at bit offset `start`.
184    ///
185    /// # Panics
186    ///
187    /// Panics of the bit span is out of bounds.
188    #[inline]
189    fn get_bits(self, start: usize, len: usize) -> Self {
190        assert!(
191            start + len <= Self::nbits(),
192            "Block::get_bits: out of bounds"
193        );
194
195        (self >> start) & Self::low_mask(len)
196    }
197
198    /// Functionally updates `len` bits to `value` starting at offset `start`.
199    ///
200    /// # Panics
201    ///
202    /// Panics of the bit span is out of bounds.
203    #[inline]
204    fn with_bits(self, start: usize, len: usize, value: Self) -> Self {
205        assert!(
206            start + len <= Self::nbits(),
207            "Block::with_bits: out of bounds"
208        );
209
210        let mask = Self::low_mask(len) << start;
211        let shifted_value = value << start;
212
213        (self & !mask) | (shifted_value & mask)
214    }
215
216    /// Returns the smallest number `n` such that `2.pow(n) >= self`.
217    #[inline]
218    fn ceil_lg(self) -> usize {
219        usize::if_then(
220            self > Self::one(),
221            Self::nbits().wrapping_sub((self.wrapping_sub(Self::one())).leading_zeros()),
222        )
223    }
224
225    /// Returns the largest number `n` such that `2.pow(n) <= self`.
226    #[inline]
227    fn floor_lg(self) -> usize {
228        usize::if_then(
229            self > Self::one(),
230            Self::nbits()
231                .wrapping_sub(1)
232                .wrapping_sub(self.leading_zeros()),
233        )
234    }
235
236    /// A shift-left operation that does not overflow.
237    fn wrapping_shl(self, shift: u32) -> Self;
238
239    /// A subtraction operation that does not overflow.
240    fn wrapping_sub(self, other: Self) -> Self;
241
242    /// Returns the number of leading zero bits in the given number.
243    fn leading_zeros(self) -> usize;
244
245    /// Converts the number to a `usize`, if it fits.
246    fn to_usize(self) -> Option<usize>;
247
248    /// Returns 0.
249    fn zero() -> Self;
250
251    /// Returns 1.
252    fn one() -> Self;
253}
254
/// Conditional-selection helpers used in place of `if` expressions.
///
/// The integer implementations in this file compute the selection with
/// arithmetic on `cond as Self` rather than with a branch.
trait IfThenElse {
    /// Selects `then_val` when `cond` is true; the integer implementations
    /// in this file yield zero otherwise.
    fn if_then(cond: bool, then_val: Self) -> Self;

    /// Selects `then_val` when `cond` is true and `else_val` otherwise.
    fn if_then_else(cond: bool, then_val: Self, else_val: Self) -> Self;
}
259
260macro_rules! impl_block_type {
261    ( $ty:ident ) => {
262        impl IfThenElse for $ty {
263            #[inline]
264            fn if_then_else(cond: bool, then_val: Self, else_val: Self) -> Self {
265                let then_cond = cond as Self;
266                let else_cond = 1 - then_cond;
267                (then_cond * then_val) | (else_cond * else_val)
268            }
269
270            #[inline]
271            fn if_then(cond: bool, then_val: Self) -> Self {
272                (cond as Self) * then_val
273            }
274        }
275
276        impl BlockType for $ty {
277            // The default `low_mask` has a branch, but we can do better if we have
278            // `wrapping_shl`. That isn't a member of any trait, but all the primitive
279            // numeric types have it, so we can override low_mask in this macro.
280            #[inline]
281            fn low_mask(k: usize) -> Self {
282                debug_assert!(k <= Self::nbits());
283
284                // Compute the mask when element_bits is not the word size:
285                let a = Self::one().wrapping_shl(k as u32).wrapping_sub(1);
286
287                // Special case for the word size:
288                let b = (Self::div_nbits(k as u64) & 1) as Self * !0;
289
290                a | b
291            }
292
293            #[inline]
294            fn wrapping_shl(self, shift: u32) -> Self {
295                self.wrapping_shl(shift)
296            }
297
298            #[inline]
299            fn wrapping_sub(self, other: Self) -> Self {
300                self.wrapping_sub(other)
301            }
302
303            #[inline]
304            fn leading_zeros(self) -> usize {
305                self.leading_zeros() as usize
306            }
307
308            #[inline]
309            fn to_usize(self) -> Option<usize> {
310                if self as usize as Self == self {
311                    Some(self as usize)
312                } else {
313                    None
314                }
315            }
316
317            #[inline]
318            fn zero() -> Self {
319                0
320            }
321
322            #[inline]
323            fn one() -> Self {
324                1
325            }
326        }
327    };
328}
329
330impl_block_type!(u8);
331impl_block_type!(u16);
332impl_block_type!(u32);
333impl_block_type!(u64);
334impl_block_type!(u128);
335impl_block_type!(usize);
336
337/// Represents the address of a bit, broken into a block component
338/// and a bit offset component.
339#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
340pub struct Address {
341    /// The index of the block containing the bit in question.
342    pub block_index: usize,
343    /// The position of the bit in question within its block.
344    pub bit_offset: usize,
345}
346
347impl Address {
348    /// Creates an `Address` for the given bit index for storage in
349    /// block type `Block`.
350    ///
351    /// # Panics
352    ///
353    /// Panics if `bit_index` divided by the block size doesn’t fit in a
354    /// `usize`.
355    #[inline]
356    pub fn new<Block: BlockType>(bit_index: u64) -> Self {
357        Address {
358            block_index: Block::checked_div_nbits(bit_index).expect("Address::new: index overflow"),
359            bit_offset: Block::mod_nbits(bit_index),
360        }
361    }
362
363    //    /// Converts an `Address` back into a raw bit index.
364    //    ///
365    //    /// This method and `new` should be inverses.
366    //    #[inline]
367    //    pub fn bit_index<Block: BlockType>(&self) -> u64 {
368    //        Block::mul_nbits(self.block_index) + self.bit_offset as u64
369    //    }
370}
371
/// Unit and property tests for `BlockType` on the primitive integer types.
#[cfg(test)]
mod test {
    use super::*;
    use quickcheck::{quickcheck, TestResult};

    #[test]
    fn nbits() {
        assert_eq!(8, u8::nbits());
        assert_eq!(16, u16::nbits());
        assert_eq!(32, u32::nbits());
        assert_eq!(64, u64::nbits());
        assert_eq!(128, u128::nbits());
    }

    quickcheck! {
        fn prop_div_nbits(n: u32) -> bool {
            u32::div_nbits(n as u64) == (n / 32) as usize
        }

        fn prop_ceil_div_nbits1(n: u32) -> bool {
            // `f64` represents every `u32` exactly; `f32` does not (it has
            // only 24 bits of mantissa), which would make this reference
            // computation wrong for large `n`.
            u32::ceil_div_nbits(n as u64) ==
                (n as f64 / 32.0f64).ceil() as usize
        }

        fn prop_ceil_div_nbits2(n: u32) -> bool {
            // Compare in `u64` so the multiplication cannot overflow on
            // targets where `usize` is 32 bits wide.
            let result = u32::ceil_div_nbits(n as u64) as u64;
            result * 32 >= n as u64 &&
                (result == 0 || (result - 1) * 32 < n as u64)
        }

        fn prop_mod_nbits(n: u32) -> bool {
            u32::mod_nbits(n as u64) == n as usize % 32
        }

        fn prop_mul_nbits(n: u32) -> bool {
            u32::mul_nbits(n as usize) == n as u64 * 32
        }
    }

    #[test]
    fn lg_nbits() {
        assert_eq!(u8::lg_nbits(), 3);
        assert_eq!(u16::lg_nbits(), 4);
        assert_eq!(u32::lg_nbits(), 5);
        assert_eq!(u64::lg_nbits(), 6);
        assert_eq!(u128::lg_nbits(), 7);
    }

    #[test]
    fn low_mask() {
        assert_eq!(0b00011111, u8::low_mask(5));
        assert_eq!(0b0011111111111111, u16::low_mask(14));
        assert_eq!(0b1111111111111111, u16::low_mask(16));
    }

    #[test]
    fn nth_mask() {
        assert_eq!(0b10000000, u8::nth_mask(7));
        assert_eq!(0b01000000, u8::nth_mask(6));
        assert_eq!(0b00100000, u8::nth_mask(5));
        assert_eq!(0b00000010, u8::nth_mask(1));
        assert_eq!(0b00000001, u8::nth_mask(0));

        assert_eq!(0b0000000000000001, u16::nth_mask(0));
        assert_eq!(0b1000000000000000, u16::nth_mask(15));
    }

    #[test]
    fn get_bits() {
        assert_eq!(0b0, 0b0100110001110000u16.get_bits(0, 0));
        assert_eq!(0b010, 0b0100110001110000u16.get_bits(13, 3));
        assert_eq!(0b110001, 0b0100110001110000u16.get_bits(6, 6));
        assert_eq!(0b10000, 0b0100110001110000u16.get_bits(0, 5));
        assert_eq!(0b0100110001110000, 0b0100110001110000u16.get_bits(0, 16));
    }

    #[test]
    fn with_bits() {
        assert_eq!(
            0b0111111111000001,
            0b0110001111000001u16.with_bits(10, 3, 0b111)
        );
        assert_eq!(
            0b0101110111000001,
            0b0110001111000001u16.with_bits(9, 5, 0b01110)
        );
        assert_eq!(
            0b0110001111000001,
            0b0110001111000001u16.with_bits(14, 0, 0b01110)
        );
        assert_eq!(
            0b0110001110101010,
            0b0110001111000001u16.with_bits(0, 8, 0b10101010)
        );
        assert_eq!(
            0b0000000000000010,
            0b0110001111000001u16.with_bits(0, 16, 0b10)
        );
    }

    #[test]
    fn get_bit() {
        assert!(!0b00000000u8.get_bit(0));
        assert!(!0b00000000u8.get_bit(1));
        assert!(!0b00000000u8.get_bit(2));
        assert!(!0b00000000u8.get_bit(3));
        assert!(!0b00000000u8.get_bit(7));
        assert!(!0b10101010u8.get_bit(0));
        assert!(0b10101010u8.get_bit(1));
        assert!(!0b10101010u8.get_bit(2));
        assert!(0b10101010u8.get_bit(3));
        assert!(0b10101010u8.get_bit(7));
    }

    #[test]
    fn with_bit() {
        assert_eq!(0b00100000, 0b00000000u8.with_bit(5, true));
        assert_eq!(0b00000000, 0b00000000u8.with_bit(5, false));
        assert_eq!(0b10101010, 0b10101010u8.with_bit(7, true));
        assert_eq!(0b00101010, 0b10101010u8.with_bit(7, false));
        assert_eq!(0b10101011, 0b10101010u8.with_bit(0, true));
        assert_eq!(0b10101010, 0b10101010u8.with_bit(0, false));
    }

    #[test]
    fn floor_lg() {
        assert_eq!(0, 1u32.floor_lg());
        assert_eq!(1, 2u32.floor_lg());
        assert_eq!(1, 3u32.floor_lg());
        assert_eq!(2, 4u32.floor_lg());
        assert_eq!(2, 5u32.floor_lg());
        assert_eq!(2, 7u32.floor_lg());
        assert_eq!(3, 8u32.floor_lg());

        // Zero is discarded because the property is stated only for n >= 1.
        fn prop(n: u64) -> TestResult {
            if n == 0 {
                return TestResult::discard();
            }

            TestResult::from_bool(
                2u64.pow(n.floor_lg() as u32) <= n && 2u64.pow(n.floor_lg() as u32 + 1) > n,
            )
        }

        quickcheck(prop as fn(u64) -> TestResult);
    }

    #[test]
    fn ceil_lg() {
        assert_eq!(0, 1u32.ceil_lg());
        assert_eq!(1, 2u32.ceil_lg());
        assert_eq!(2, 3u32.ceil_lg());
        assert_eq!(2, 4u32.ceil_lg());
        assert_eq!(3, 5u32.ceil_lg());
        assert_eq!(3, 7u32.ceil_lg());
        assert_eq!(3, 8u32.ceil_lg());
        assert_eq!(4, 9u32.ceil_lg());

        // Inputs of 0 and 1 are discarded: the lower-bound half of the
        // property subtracts 1 from the result, which is 0 for them.
        fn prop(n: u64) -> TestResult {
            if n <= 1 {
                return TestResult::discard();
            }

            TestResult::from_bool(
                2u64.pow(n.ceil_lg() as u32) >= n && 2u64.pow(n.ceil_lg() as u32 - 1) < n,
            )
        }
        quickcheck(prop as fn(u64) -> TestResult);
    }

    #[test]
    fn block_bits() {
        assert_eq!(u16::block_bits(1, 0), 1);
        assert_eq!(u16::block_bits(2, 0), 2);
        assert_eq!(u16::block_bits(16, 0), 16);
        assert_eq!(u16::block_bits(16, 1), 0); // boundary condition
        assert_eq!(u16::block_bits(23, 0), 16);
        assert_eq!(u16::block_bits(23, 1), 7);
        assert_eq!(u16::block_bits(35, 0), 16);
        assert_eq!(u16::block_bits(35, 1), 16);
        assert_eq!(u16::block_bits(35, 2), 3);
        assert_eq!(u16::block_bits(48, 0), 16);
        assert_eq!(u16::block_bits(48, 1), 16);
        assert_eq!(u16::block_bits(48, 2), 16);
        assert_eq!(u16::block_bits(48, 3), 0); // boundary condition
    }
}