1use std::mem::size_of_val;
7
/// Number of bits covered by one L2 (basic) block: 8 × 64-bit words.
const L2_BIT_SIZE: usize = 512;
/// Number of bits covered by one L1 super-block: 8 L2 blocks (64 words),
/// matching one `L12Rank` directory entry.
const L1_BIT_SIZE: usize = 8 * L2_BIT_SIZE;
12
13pub trait RankedBitsAccess {
18 fn rank(&self, idx: usize) -> Option<usize>;
20
21 #[inline]
27 unsafe fn rank_impl<T: L12RankAccess>(bits: &[u64], l12_ranks: &T, idx: usize) -> Option<usize> {
28 let word_idx = idx / 64;
29 let bit_idx = idx % 64;
30 let word = *bits.get_unchecked(word_idx);
31
32 if (word & (1u64 << bit_idx)) == 0 {
33 return None;
34 }
35
36 let l1_pos = idx / L1_BIT_SIZE;
37 let l2_pos = (idx % L1_BIT_SIZE) / L2_BIT_SIZE;
38
39 let idx_within_l2 = idx % L2_BIT_SIZE;
40 let blocks_num = idx_within_l2 / 64;
41 let offset = (idx / L2_BIT_SIZE) * 8;
42 let block = bits.get_unchecked(offset..offset + blocks_num);
43
44 let block_rank = block.iter().map(|&x| x.count_ones() as usize).sum::<usize>();
45
46 let word = *bits.get_unchecked(offset + blocks_num);
47 let word_mask = ((1u64 << (idx_within_l2 % 64)) - 1) * (idx_within_l2 > 0) as u64;
48 let word_rank = (word & word_mask).count_ones() as usize;
49
50 let (l1_rank, l2_rank) = l12_ranks.l12_ranks(l1_pos, l2_pos);
51 let total_rank = l1_rank + l2_rank + block_rank + word_rank;
52
53 Some(total_rank)
54 }
55}
56
/// A bit vector paired with a precomputed two-level rank directory.
///
/// One 16-byte `L12Rank` entry is stored per `L1_BIT_SIZE` (4096) bits of
/// data (~3% space overhead), enabling constant-time rank queries via
/// `RankedBitsAccess`.
#[derive(Debug, Default)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct RankedBits {
    // The raw bits, packed into 64-bit words; bit i of the structure is
    // bit (i % 64) of word (i / 64).
    bits: Box<[u64]>,
    // One packed L1/L2 directory entry per 64-word super-block of `bits`,
    // built by `RankedBits::new`.
    l12_ranks: Box<[L12Rank]>,
}
66
/// One packed rank-directory entry, stored as the little-endian bytes of a
/// `u128` (see `From<u128>`): bits 0..44 hold the cumulative rank before
/// the super-block (L1 rank); each 12-bit field above holds the cumulative
/// popcount through the i-th 8-word L2 block, as written by
/// `RankedBits::new`.
#[derive(Debug)]
#[cfg_attr(feature = "rkyv_derive", derive(rkyv::Archive, rkyv::Deserialize, rkyv::Serialize))]
#[cfg_attr(feature = "rkyv_derive", archive_attr(derive(rkyv::CheckBytes)))]
pub struct L12Rank([u8; 16]);
75
/// Read access to the packed L1/L2 rank directory entries.
pub trait L12RankAccess {
    /// Returns the raw 128-bit packed entry for super-block `l1_pos`.
    fn l12_rank(&self, l1_pos: usize) -> u128;

    /// Decodes the entry for super-block `l1_pos` into
    /// `(rank before the super-block, rank of the L2 blocks preceding
    /// `l2_pos` within it)`.
    #[inline]
    fn l12_ranks(&self, l1_pos: usize, l2_pos: usize) -> (usize, usize) {
        let l12_rank = self.l12_rank(l1_pos);
        // Low 44 bits: cumulative rank before this super-block.
        let l1_rank = (l12_rank & 0xFFFFFFFFFFF) as usize;
        // L2 field i sits at bits 44 + 12*i and holds the cumulative count
        // *through* L2 block i, so the count before block l2_pos is field
        // l2_pos - 1, i.e. bits 32 + 12*l2_pos. For l2_pos == 0 this shift
        // branchlessly lands on bits 32..44 of the L1 field, which are
        // zero as long as the cumulative rank stays below 2^32.
        // NOTE(review): with >= 2^32 preceding set bits the l2_pos == 0
        // case would read garbage from the top of the L1 field — confirm
        // the intended capacity limit.
        let l2_rank = ((l12_rank >> (32 + 12 * l2_pos)) & 0xFFF) as usize;
        (l1_rank, l2_rank)
    }
}
90
91impl L12RankAccess for Box<[L12Rank]> {
92 #[inline]
93 fn l12_rank(&self, l1_pos: usize) -> u128 {
94 u128::from_le_bytes(unsafe { self.get_unchecked(l1_pos).0 })
95 }
96}
97
#[cfg(feature = "rkyv_derive")]
impl L12RankAccess for rkyv::boxed::ArchivedBox<[ArchivedL12Rank]> {
    /// Fetches the packed entry for super-block `l1_pos` directly from an
    /// rkyv archive.
    ///
    /// Checked indexing (through `ArchivedBox`'s deref to the archived
    /// slice) replaces the previous `get_unchecked`, which made this safe
    /// method undefined behavior for out-of-bounds `l1_pos`; it now panics
    /// instead. In-bounds behavior is unchanged.
    #[inline]
    fn l12_rank(&self, l1_pos: usize) -> u128 {
        u128::from_le_bytes(self[l1_pos].0)
    }
}
105
106impl From<u128> for L12Rank {
107 #[inline]
108 fn from(v: u128) -> Self {
109 L12Rank(v.to_le_bytes())
110 }
111}
112
113impl RankedBits {
114 pub fn new(bits: Box<[u64]>) -> Self {
116 let blocks = bits.chunks_exact(64);
117 let remainder = blocks.remainder();
118 let mut l12_ranks = Vec::with_capacity(bits.len().div_ceil(64));
119 let mut l1_rank: u128 = 0;
120
121 for block64 in blocks {
122 let mut l12_rank = 0u128;
123 let mut sum = 0u16;
124 for (i, block8) in block64.chunks_exact(8).enumerate() {
125 sum += block8.iter().map(|&x| x.count_ones() as u16).sum::<u16>();
126 l12_rank += (sum as u128) << (i * 12);
127 }
128 l12_rank = (l12_rank << 44) | l1_rank;
129 l12_ranks.push(l12_rank.into());
130 l1_rank += sum as u128;
131 }
132
133 if !remainder.is_empty() {
134 let mut l12_rank = 0u128;
135 let mut sum = 0u16;
136 for (i, block) in remainder.chunks(8).enumerate() {
137 sum += block.iter().map(|&x| x.count_ones() as u16).sum::<u16>();
138 l12_rank += (sum as u128) << (i * 12);
139 }
140 l12_rank = (l12_rank << 44) | l1_rank;
141 l12_ranks.push(l12_rank.into());
142 }
143
144 RankedBits { bits, l12_ranks: l12_ranks.into_boxed_slice() }
145 }
146
147 pub fn size(&self) -> usize {
149 size_of_val(self) + size_of_val(self.bits.as_ref()) + size_of_val(self.l12_ranks.as_ref())
150 }
151}
152
153impl RankedBitsAccess for RankedBits {
155 #[inline]
156 fn rank(&self, idx: usize) -> Option<usize> {
157 unsafe { Self::rank_impl(&self.bits, &self.l12_ranks, idx) }
158 }
159}
160
#[cfg(feature = "rkyv_derive")]
impl RankedBitsAccess for ArchivedRankedBits {
    /// Answers a rank query directly against the rkyv-archived fields,
    /// without deserializing the structure first.
    #[inline]
    fn rank(&self, idx: usize) -> Option<usize> {
        // SAFETY: the archived fields were serialized from a `RankedBits`,
        // so the rank directory corresponds to the bit words as
        // `rank_impl` requires.
        unsafe { Self::rank_impl(&self.bits, &self.l12_ranks, idx) }
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use bitvec::order::Lsb0;
    use bitvec::vec::BitVec;
    use rand::distributions::Standard;
    use rand::Rng;

    /// Spot-checks `rank` on a tiny hand-built bit vector.
    #[test]
    fn test_rank_and_get() {
        let bits = vec![
            0b11001010, 0b00110111, 0b11110000,
        ];

        let ranked_bits = RankedBits::new(bits.into_boxed_slice());
        // Bit 0 of the first word is clear, so rank is undefined.
        assert_eq!(ranked_bits.rank(0), None);
        // Bit 7 is set; bits 1, 3 and 6 below it are set, so rank is 3.
        assert_eq!(ranked_bits.rank(7), Some(3));
    }

    /// Cross-checks `rank` against bitvec's `count_ones` on random data.
    /// 1001 words is deliberately not a multiple of 64, exercising the
    /// trailing partial super-block path of `RankedBits::new`.
    #[test]
    fn test_random_bits() {
        let rng = rand::thread_rng();
        let bits: Vec<u64> = rng.sample_iter(Standard).take(1001).collect();
        let ranked_bits = RankedBits::new(bits.clone().into_boxed_slice());
        let bv = BitVec::<u64, Lsb0>::from_slice(&bits);

        for idx in 0..bv.len() {
            if bv[idx] {
                assert_eq!(
                    ranked_bits.rank(idx).unwrap(),
                    bv[..idx].count_ones(),
                    "Rank mismatch at index {}",
                    idx
                );
            }
        }
    }
}