entropy_map/
rank.rs

1//! `RankedBits` efficiently handles rank queries on bit vectors.
2//! Optimized for minimal memory usage with ~3.125% overhead and fast lookups, it supports the
3//! crate's focus on low-latency hash maps. For detailed methodology, refer to the related paper:
4//! [Engineering Compact Data Structures for Rank and Select Queries on Bit Vectors](https://arxiv.org/pdf/2206.01149.pdf).
5
6use std::mem::size_of_val;
7
/// Number of bits covered by one L2 block.
const L2_BIT_SIZE: usize = 512;
/// Number of bits covered by one L1 block; an L1 block spans eight L2 blocks.
const L1_BIT_SIZE: usize = L2_BIT_SIZE * 8;
12
13/// Trait for efficient bit-level operations on ranked bit sequences.
14///
15/// This trait is designed to provide consistent methods for accessing ranked bit sequences in both
16/// their standard and `Archived` formats (utilizing the `rkyv` library).
17pub trait RankedBitsAccess {
18    /// Returns the number of set bits up to `idx`, or `None` if the bit at `idx` is not set.
19    fn rank(&self, idx: usize) -> Option<usize>;
20
21    /// Inner implementation of `rank` with `bits` and `l12_ranks` passed from different implementations.
22    ///
23    /// # Safety
24    /// This method is unsafe because `idx` must be within the bounds of the bits stored in `RankedBitsAccess`.
25    /// An index out of bounds can lead to undefined behavior.
26    #[inline]
27    unsafe fn rank_impl<T: L12RankAccess>(bits: &[u64], l12_ranks: &T, idx: usize) -> Option<usize> {
28        let word_idx = idx / 64;
29        let bit_idx = idx % 64;
30        let word = *bits.get_unchecked(word_idx);
31
32        if (word & (1u64 << bit_idx)) == 0 {
33            return None;
34        }
35
36        let l1_pos = idx / L1_BIT_SIZE;
37        let l2_pos = (idx % L1_BIT_SIZE) / L2_BIT_SIZE;
38
39        let idx_within_l2 = idx % L2_BIT_SIZE;
40        let blocks_num = idx_within_l2 / 64;
41        let offset = (idx / L2_BIT_SIZE) * 8;
42        let block = bits.get_unchecked(offset..offset + blocks_num);
43
44        let block_rank = block.iter().map(|&x| x.count_ones() as usize).sum::<usize>();
45
46        let word = *bits.get_unchecked(offset + blocks_num);
47        let word_mask = ((1u64 << (idx_within_l2 % 64)) - 1) * (idx_within_l2 > 0) as u64;
48        let word_rank = (word & word_mask).count_ones() as usize;
49
50        let (l1_rank, l2_rank) = l12_ranks.l12_ranks(l1_pos, l2_pos);
51        let total_rank = l1_rank + l2_rank + block_rank + word_rank;
52
53        Some(total_rank)
54    }
55}
56
/// Bit vector paired with precomputed L1/L2 rank data used to answer rank queries.
#[derive(Debug, Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct RankedBits {
    /// The bit vector represented as an array of u64 integers.
    bits: Box<[u64]>,
    /// Precomputed rank information for L1 and L2 blocks. `RankedBits::new` stores one entry
    /// per chunk of 64 words of `bits`, including a trailing partial chunk.
    l12_ranks: Box<[L12Rank]>,
}
66
/// L12Rank represents l1 and l2 bit ranks stored inside 16 bytes (little endian).
///
/// Read as a little-endian `u128`, the low 44 bits hold the L1 rank (the number of set bits
/// before this L1 block) and the bits above them hold 12-bit running counts of set bits over
/// the L2 blocks of this L1 block, as packed by `RankedBits::new`.
///
/// NB: it's important to use `[u8; 16]` instead of `u128` for `rkyv` versions 0.7.X
/// because of alignment differences between `x86_64` and `aarch64` architectures.
/// See https://github.com/rkyv/rkyv/issues/409 for more details.
#[derive(Debug)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct L12Rank([u8; 16]);
75
/// Trait used to access archived and non-archived L1 and L2 ranks
pub trait L12RankAccess {
    /// Return the packed `L12Rank` entry at `l1_pos` as `u128`
    fn l12_rank(&self, l1_pos: usize) -> u128;

    /// Return `l1_rank` and `l2_rank` for the L2 block `l2_pos` inside the L1 block `l1_pos`.
    ///
    /// `l1_rank` is the low 44 bits of the packed entry. `l2_rank` is the running count stored
    /// for the L2 blocks that precede `l2_pos`, and is always `0` for `l2_pos == 0`.
    ///
    /// `l2_pos` is expected to be in `0..8`.
    #[inline]
    fn l12_ranks(&self, l1_pos: usize, l2_pos: usize) -> (usize, usize) {
        const L1_RANK_BITS: usize = 44;
        const L2_RANK_BITS: usize = 12;
        const L1_RANK_MASK: u128 = (1u128 << L1_RANK_BITS) - 1;
        const L2_RANK_MASK: u128 = (1u128 << L2_RANK_BITS) - 1;

        let l12_rank = self.l12_rank(l1_pos);
        let l1_rank = (l12_rank & L1_RANK_MASK) as usize;

        // Drop the L1 bits, then prepend one empty 12-bit field so that field `k` holds the
        // count of set bits in the L2 blocks before block `k`. Field 0 therefore reads zero;
        // selecting it straight from the packed value would instead pick up the upper bits of
        // the L1 rank, which are non-zero once that rank reaches 2^32.
        let l2_fields = (l12_rank >> L1_RANK_BITS) << L2_RANK_BITS;
        let l2_rank = ((l2_fields >> (L2_RANK_BITS * l2_pos)) & L2_RANK_MASK) as usize;

        (l1_rank, l2_rank)
    }
}
90
91impl L12RankAccess for Box<[L12Rank]> {
92    #[inline]
93    fn l12_rank(&self, l1_pos: usize) -> u128 {
94        u128::from_le_bytes(unsafe { self.get_unchecked(l1_pos).0 })
95    }
96}
97
#[cfg(feature = "rkyv_derive")]
impl L12RankAccess for rkyv::boxed::ArchivedBox<[ArchivedL12Rank]> {
    /// Decodes the archived entry at `l1_pos` from its little-endian bytes.
    ///
    /// # Panics
    /// Panics if `l1_pos` is out of bounds. The lookup is bounds-checked because this is a safe
    /// method: an unchecked read here would let safe callers trigger undefined behavior.
    #[inline]
    fn l12_rank(&self, l1_pos: usize) -> u128 {
        let ranks: &[ArchivedL12Rank] = self;
        u128::from_le_bytes(ranks[l1_pos].0)
    }
}
105
106impl From<u128> for L12Rank {
107    #[inline]
108    fn from(v: u128) -> Self {
109        L12Rank(v.to_le_bytes())
110    }
111}
112
113impl RankedBits {
114    /// Initializes `RankedBits` with a provided bit vector.
115    pub fn new(bits: Box<[u64]>) -> Self {
116        let blocks = bits.chunks_exact(64);
117        let remainder = blocks.remainder();
118        let mut l12_ranks = Vec::with_capacity(bits.len().div_ceil(64));
119        let mut l1_rank: u128 = 0;
120
121        for block64 in blocks {
122            let mut l12_rank = 0u128;
123            let mut sum = 0u16;
124            for (i, block8) in block64.chunks_exact(8).enumerate() {
125                sum += block8.iter().map(|&x| x.count_ones() as u16).sum::<u16>();
126                l12_rank += (sum as u128) << (i * 12);
127            }
128            l12_rank = (l12_rank << 44) | l1_rank;
129            l12_ranks.push(l12_rank.into());
130            l1_rank += sum as u128;
131        }
132
133        if !remainder.is_empty() {
134            let mut l12_rank = 0u128;
135            let mut sum = 0u16;
136            for (i, block) in remainder.chunks(8).enumerate() {
137                sum += block.iter().map(|&x| x.count_ones() as u16).sum::<u16>();
138                l12_rank += (sum as u128) << (i * 12);
139            }
140            l12_rank = (l12_rank << 44) | l1_rank;
141            l12_ranks.push(l12_rank.into());
142        }
143
144        RankedBits { bits, l12_ranks: l12_ranks.into_boxed_slice() }
145    }
146
147    /// Returns the total number of bytes occupied by `RankedBits`
148    pub fn size(&self) -> usize {
149        size_of_val(self) + size_of_val(self.bits.as_ref()) + size_of_val(self.l12_ranks.as_ref())
150    }
151}
152
153/// Implement `rank` for `Archived` version of `RankedBits` if feature is enabled
154impl RankedBitsAccess for RankedBits {
155    #[inline]
156    fn rank(&self, idx: usize) -> Option<usize> {
157        unsafe { Self::rank_impl(&self.bits, &self.l12_ranks, idx) }
158    }
159}
160
/// Implement `rank` for `Archived` version of `RankedBits` if feature is enabled
#[cfg(feature = "rkyv_derive")]
impl RankedBitsAccess for ArchivedRankedBits {
    /// Returns the number of set bits strictly below `idx`, or `None` if the bit at `idx` is
    /// not set. An `idx` past the end of the archived bit vector is treated as an unset bit
    /// and also yields `None`.
    #[inline]
    fn rank(&self, idx: usize) -> Option<usize> {
        // `rank_impl` reads `bits` without bounds checks, so an out-of-range index must be
        // rejected here to keep this safe method free of undefined behavior.
        if idx / 64 >= self.bits.len() {
            return None;
        }
        // SAFETY: `idx / 64` was checked against `bits.len()` above.
        // NOTE(review): the L1 position is only guaranteed to be covered by `l12_ranks` when
        // the archive was serialized from a `RankedBits` built by `RankedBits::new` — confirm
        // this holds for archives from untrusted sources.
        unsafe { Self::rank_impl(&self.bits, &self.l12_ranks, idx) }
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use bitvec::order::Lsb0;
    use bitvec::vec::BitVec;
    use rand::distributions::Standard;
    use rand::Rng;

    /// Checks `rank` on a small hand-written vector, including positions in the second and
    /// third word so that the whole-word popcount path is exercised.
    #[test]
    fn test_rank_and_get() {
        let bits = vec![
            0b11001010, // 4 set bits (positions 1, 3, 6, 7)
            0b00110111, // 5 set bits (positions 0, 1, 2, 4, 5)
            0b11110000, // 4 set bits (positions 4, 5, 6, 7)
        ];

        let ranked_bits = RankedBits::new(bits.into_boxed_slice());

        // First word.
        assert_eq!(ranked_bits.rank(0), None); // bit 0 is not set
        assert_eq!(ranked_bits.rank(1), Some(0)); // first set bit, nothing set below it
        assert_eq!(ranked_bits.rank(7), Some(3)); // 3 set bits below the 7-th bit

        // Second word: the 4 set bits of the first word are counted in full.
        assert_eq!(ranked_bits.rank(64), Some(4));
        assert_eq!(ranked_bits.rank(64 + 3), None); // bit 3 of the second word is not set
        assert_eq!(ranked_bits.rank(64 + 5), Some(8));

        // Third word: 4 + 5 set bits precede it.
        assert_eq!(ranked_bits.rank(128 + 4), Some(9));
        assert_eq!(ranked_bits.rank(128 + 7), Some(12));
    }

    /// Compares `rank` against a `BitVec` reference for every set bit of a random vector whose
    /// length (1001 words) is not a multiple of the L1 block size.
    #[test]
    fn test_random_bits() {
        let rng = rand::thread_rng();
        let bits: Vec<u64> = rng.sample_iter(Standard).take(1001).collect();
        let ranked_bits = RankedBits::new(bits.clone().into_boxed_slice());
        let bv = BitVec::<u64, Lsb0>::from_slice(&bits);

        for idx in 0..bv.len() {
            if bv[idx] {
                assert_eq!(
                    ranked_bits.rank(idx).unwrap(),
                    bv[..idx].count_ones(),
                    "Rank mismatch at index {}",
                    idx
                );
            }
        }
    }
}