//! `RsDict` data structure that supports both rank and select over a bitmap.
//!
//! This crate is an implementation of [Navarro and Providel, "Fast, Small,
//! Simple Rank/Select On
//! Bitmaps"](https://users.dcc.uchile.cl/~gnavarro/ps/sea12.1.pdf), with heavy
//! inspiration from a [Go implementation](https://github.com/hillbig/rsdic).
//!
//! ```
//! use rsdict::RsDict;
//!
//! let mut r = RsDict::new();
//! r.push(false);
//! r.push(true);
//! r.push(true);
//! r.push(false);
//!
//! // There's one bit set to the left of index 2.
//! assert_eq!(r.rank(2, true), 1);
//!
//! // The index of the second unset bit (zero-indexed as 1) is 3.
//! assert_eq!(r.select(1, false), Some(3));
//! ```
//!
//! # Implementation notes
//! First, we store the bitmap in compressed form.  Each block of 64 bits is
//! stored with a variable length code, where the length is determined by the
//! number of bits set in the block (its "class").  Then, we store the classes
//! (i.e. the number of bits set per block) in a separate array, allowing us to
//! iterate forward from a pointer into the variable length buffer.
//!
//! To allow efficient indexing, we then break up the input into
//! `LARGE_BLOCK_SIZE` blocks and store a pointer into the variable length
//! buffer per block.  As with other rank structures, we also store a
//! precomputed rank from the beginning of the large block.
//!
//! Finally, we store precomputed indices for selection in separate arrays.  For
//! every `SELECT_BLOCK_SIZE`th one, we store the index of the large block that
//! bit falls in, and we do the same for every `SELECT_BLOCK_SIZE`th zero.
//!
//! Then, we can compute ranks by consulting the large block rank and then
//! iterating over the small block classes before our desired position.  Once
//! we've found the boundary small block, we can then decode it and compute the
//! rank within the block.  The choice of variable length code allows computing
//! its internal rank without decoding the entire block.
//!
//! Select works similarly: we start from the precomputed large block index,
//! skip over as many small blocks as possible, and then select within the
//! boundary small block.  As with rank, the code lets us select within a small
//! block without decoding it entirely.
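//!
//! For instance, a small usage sketch of queries on an input spanning multiple
//! 64-bit blocks (the internal layout itself is not exposed):
//!
//! ```
//! use rsdict::RsDict;
//!
//! // 128 bits: the first block is all ones, the second all zeros.
//! let r = RsDict::from_blocks([u64::MAX, 0].iter().cloned());
//! assert_eq!(r.rank(100, true), 64);
//! assert_eq!(r.select(63, true), Some(63));
//! ```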

#![cfg_attr(feature = "simd", feature(portable_simd))]

#[cfg(test)]
extern crate quickcheck;
#[cfg(test)]
#[macro_use(quickcheck)]
extern crate quickcheck_macros;

use std::cmp::Ordering;
use std::mem;

mod constants;
mod enum_code;

mod rank_acceleration;

#[cfg(test)]
mod test_helpers;

use self::constants::{
    LARGE_BLOCK_SIZE, SELECT_BLOCK_SIZE, SMALL_BLOCK_PER_LARGE_BLOCK, SMALL_BLOCK_SIZE,
};
use self::enum_code::ENUM_CODE_LENGTH;

/// Data structure for efficiently computing both rank and select queries
#[derive(Debug, Clone)]
pub struct RsDict {
    len: u64,
    num_ones: u64,
    num_zeros: u64,

    // Small block metadata (stored every SMALL_BLOCK_SIZE bits):
    // * number of set bits (the "class") for the small block
    // * index within a class for each small block; note that the indexes are
    //   variable length (see `ENUM_CODE_LENGTH`), so there isn't direct access
    //   for a particular small block.
    sb_classes: Vec<u8>,
    sb_indices: VarintBuffer,

    // Large block metadata (stored every LARGE_BLOCK_SIZE bits):
    // * pointer into the variable-length `sb_indices` buffer for the block start
    // * cached rank at the block start
    large_blocks: Vec<LargeBlock>,

    // Select acceleration:
    // `select_{one,zero}_inds` store the (offset / LARGE_BLOCK_SIZE) of each
    // SELECT_BLOCK_SIZE'th bit.
    select_one_inds: Vec<u64>,
    select_zero_inds: Vec<u64>,

    // Current in-progress small block we're appending to
    last_block: LastBlock,
}

impl RsDict {
    /// Create a dictionary from a bitset, specified as an iterator of 64-bit blocks.  This function
    /// is equivalent to pushing each bit one at a time but is much faster.
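    ///
    /// A small usage sketch (bits are numbered LSB-first within each block):
    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// // The block 0b0110 has bits 1 and 2 set.
    /// let r = RsDict::from_blocks([0b0110u64].iter().cloned());
    /// assert_eq!(r.len(), 64);
    /// assert_eq!(r.count_ones(), 2);
    /// assert!(r.get_bit(1));
    /// ```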
    pub fn from_blocks(blocks: impl Iterator<Item = u64>) -> Self {
        #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
        {
            if is_x86_feature_detected!("popcnt") {
                return unsafe { Self::from_blocks_popcount(blocks) };
            }
        }
        Self::from_blocks_impl(blocks)
    }

    /// Return the size of the heap allocations associated with the `RsDict`.
    pub fn heap_size(&self) -> usize {
        self.sb_classes.capacity() * mem::size_of::<u8>()
            + self.sb_indices.heap_size()
            + self.large_blocks.capacity() * mem::size_of::<LargeBlock>()
            + self.select_one_inds.capacity() * mem::size_of::<u64>()
            + self.select_zero_inds.capacity() * mem::size_of::<u64>()
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "popcnt")]
    unsafe fn from_blocks_popcount(blocks: impl Iterator<Item = u64>) -> Self {
        Self::from_blocks_impl(blocks)
    }

    #[inline(always)]
    fn from_blocks_impl(blocks: impl Iterator<Item = u64>) -> Self {
        let (_, hint) = blocks.size_hint();
        let hint = hint.unwrap_or(0);

        let mut large_blocks = Vec::with_capacity(hint / LARGE_BLOCK_SIZE as usize);
        let mut select_one_inds = Vec::with_capacity(hint / SELECT_BLOCK_SIZE as usize);
        let mut select_zero_inds = Vec::with_capacity(hint / SELECT_BLOCK_SIZE as usize);
        let mut sb_classes = Vec::with_capacity(hint / SMALL_BLOCK_SIZE as usize);
        let mut sb_indices = VarintBuffer::with_capacity(hint);
        let mut last_block = LastBlock::new();

        let mut num_ones = 0;
        let mut num_zeros = 0;

        let mut iter = blocks.enumerate().peekable();

        while let Some((i, block)) = iter.next() {
            let sb_class = block.count_ones() as u8;

            if i as u64 % SMALL_BLOCK_PER_LARGE_BLOCK == 0 {
                let lblock = LargeBlock {
                    rank: num_ones,
                    pointer: sb_indices.len() as u64,
                };
                large_blocks.push(lblock);
            }

            // If we're on the last block, write to `last_block` rather than
            // pushing onto the `VarintBuffer`.
            if iter.peek().is_none() {
                last_block.bits = block;
                last_block.num_ones = sb_class as u64;
                last_block.num_zeros = 64 - sb_class as u64;
            } else {
                sb_classes.push(sb_class);
                let (code_len, code) = enum_code::encode(block, sb_class);
                sb_indices.push(code_len as usize, code);
            }

            let lb_start = i as u64 * SMALL_BLOCK_SIZE / LARGE_BLOCK_SIZE;

            // We want to see if there's any j in [num_ones, num_ones + sb_class) such
            // that j % SELECT_BLOCK_SIZE == 0.  We can check this arithmetically by
            // comparing two quotients:
            //
            // 1. (num_ones - 1) / SELECT_BLOCK_SIZE and
            // 2. (num_ones + sb_class - 1) / SELECT_BLOCK_SIZE.
            //
            // If they're not equal, there must be a multiple of SELECT_BLOCK_SIZE in
            // the interval [num_ones, num_ones + sb_class).  To see why, consider
            // the case where sb_class > 0 and SELECT_BLOCK_SIZE divides num_ones.
            // Then, the first quotient's numerator is one less than a multiple, and
            // the second one must be greater than or equal to it.  Similarly, if the
            // last value num_ones + sb_class - 1 is a multiple, then the first quotient
            // must be less than the second.  Then, since sb_class < SELECT_BLOCK_SIZE,
            // the same argument holds for any multiple in the middle.
            //
            // Finally, since we're working with unsigned integers, we add SELECT_BLOCK_SIZE
            // to both numerators so we don't ever underflow when subtracting one.
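            //
            // For illustration, suppose SELECT_BLOCK_SIZE were 8, num_ones = 14,
            // and sb_class = 3: the quotients are (14 + 8 - 1) / 8 = 2 and
            // (14 + 8 + 3 - 1) / 8 = 3, which differ because the multiple 16
            // lies in [14, 17).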
            let start = num_ones + SELECT_BLOCK_SIZE - 1;
            let end = num_ones + SELECT_BLOCK_SIZE + sb_class as u64 - 1;
            if start / SELECT_BLOCK_SIZE != end / SELECT_BLOCK_SIZE {
                select_one_inds.push(lb_start);
            }

            // Now do the same for the zero indices.
            let start = num_zeros + SELECT_BLOCK_SIZE - 1;
            let end = num_zeros + SELECT_BLOCK_SIZE + (64 - sb_class as u64) - 1;
            if start / SELECT_BLOCK_SIZE != end / SELECT_BLOCK_SIZE {
                select_zero_inds.push(lb_start);
            }

            num_ones += sb_class as u64;
            num_zeros += 64 - sb_class as u64;
        }

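        // As in `write_block` below, reserve extra space so that a scan over a
        // full large block in `sb_classes` stays within the allocation.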
        let num_sb = sb_classes.len();
        let align = SMALL_BLOCK_PER_LARGE_BLOCK as usize;
        sb_classes.reserve((num_sb + align - 1) / align * align);

        Self {
            large_blocks,
            select_one_inds,
            select_zero_inds,
            sb_classes,
            sb_indices,

            len: num_ones + num_zeros,
            num_ones,
            num_zeros,

            last_block,
        }
    }

    /// Create a new `RsDict` with zero capacity.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Create a new `RsDict` with the given capacity preallocated.
    pub fn with_capacity(n: usize) -> Self {
        Self {
            large_blocks: Vec::with_capacity(n / LARGE_BLOCK_SIZE as usize),
            select_one_inds: Vec::with_capacity(n / SELECT_BLOCK_SIZE as usize),
            select_zero_inds: Vec::with_capacity(n / SELECT_BLOCK_SIZE as usize),
            sb_classes: Vec::with_capacity(n / SMALL_BLOCK_SIZE as usize),
            sb_indices: VarintBuffer::with_capacity(n),

            len: 0,
            num_ones: 0,
            num_zeros: 0,

            last_block: LastBlock::new(),
        }
    }

    /// Non-inclusive rank: Count the number of `bit` values left of `pos`. Panics if `pos` is
    /// out-of-bounds.
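    ///
    /// For example:
    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[true, false, true, true] {
    ///     r.push(b);
    /// }
    /// // Indices 0, 1, and 2 hold two ones and one zero.
    /// assert_eq!(r.rank(3, true), 2);
    /// assert_eq!(r.rank(3, false), 1);
    /// ```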
    pub fn rank(&self, pos: u64, bit: bool) -> u64 {
        if pos >= self.len {
            panic!("Out of bounds position: {} >= {}", pos, self.len);
        }
        // If we're in the last block, count the number of ones set after our
        // bit in the last block and remove that from the global count.
        if self.is_last_block(pos) {
            let trailing_ones = self.last_block.count_suffix(pos % SMALL_BLOCK_SIZE);
            return rank_by_bit(self.num_ones - trailing_ones, pos, bit);
        }

        // Start with the rank from our position's large block.
        let lblock = pos / LARGE_BLOCK_SIZE;
        let LargeBlock {
            mut pointer,
            mut rank,
        } = self.large_blocks[lblock as usize];

        // Add in the ranks (i.e. the classes) per small block up to our
        // position's small block.
        let sblock_start = (lblock * SMALL_BLOCK_PER_LARGE_BLOCK) as usize;
        let sblock = (pos / SMALL_BLOCK_SIZE) as usize;
        let (class_sum, length_sum) =
            rank_acceleration::scan_block(&self.sb_classes, sblock_start, sblock);
        rank += class_sum;
        pointer += length_sum;

        // If we aren't on a small block boundary, add in the rank within the small block.
        if pos % SMALL_BLOCK_SIZE != 0 {
            let sb_class = self.sb_classes[sblock];
            let code = self.read_sb_index(pointer, ENUM_CODE_LENGTH[sb_class as usize]);
            rank += enum_code::rank(code, sb_class, pos % SMALL_BLOCK_SIZE);
        }

        rank_by_bit(rank, pos, bit)
    }

    /// Query the `pos`th bit (zero-indexed) of the underlying bitmap and the number of set bits
    /// to the left of `pos` in a single operation.  This method is faster than calling
    /// `get_bit(pos)` and `rank(pos, true)` separately.
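    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[true, false, true, true] {
    ///     r.push(b);
    /// }
    /// // Index 2 holds a one, and there is one set bit to its left.
    /// assert_eq!(r.bit_and_one_rank(2), (true, 1));
    /// ```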
    pub fn bit_and_one_rank(&self, pos: u64) -> (bool, u64) {
        if pos >= self.len {
            panic!("Out of bounds position: {} >= {}", pos, self.len);
        }
        if self.is_last_block(pos) {
            let sb_pos = pos % SMALL_BLOCK_SIZE;
            let bit = self.last_block.get_bit(sb_pos);
            let after_rank = self.last_block.count_suffix(sb_pos);
            return (bit, self.num_ones - after_rank);
        }
        let lblock = pos / LARGE_BLOCK_SIZE;
        let sblock = (pos / SMALL_BLOCK_SIZE) as usize;
        let sblock_start = (lblock * SMALL_BLOCK_PER_LARGE_BLOCK) as usize;
        let LargeBlock {
            mut pointer,
            mut rank,
        } = self.large_blocks[lblock as usize];
        for &sb_class in &self.sb_classes[sblock_start..sblock] {
            pointer += ENUM_CODE_LENGTH[sb_class as usize] as u64;
            rank += sb_class as u64;
        }
        let sb_class = self.sb_classes[sblock];
        let code_length = ENUM_CODE_LENGTH[sb_class as usize];
        let code = self.read_sb_index(pointer, code_length);

        rank += enum_code::rank(code, sb_class, pos % SMALL_BLOCK_SIZE);
        let bit = enum_code::decode_bit(code, sb_class, pos % SMALL_BLOCK_SIZE);
        (bit, rank)
    }

    /// Inclusive rank: Count the number of `bit` values at indices less than or equal to
    /// `pos`. Panics if `pos` is out-of-bounds.
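    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[true, false, true] {
    ///     r.push(b);
    /// }
    /// // Unlike `rank`, the bit at the queried index is counted too.
    /// assert_eq!(r.inclusive_rank(2, true), 2);
    /// assert_eq!(r.rank(2, true), 1);
    /// ```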
    pub fn inclusive_rank(&self, pos: u64, bit: bool) -> u64 {
        let (pos_bit, one_rank) = self.bit_and_one_rank(pos);
        rank_by_bit(one_rank, pos, bit) + if pos_bit == bit { 1 } else { 0 }
    }

    /// Compute the position of the `rank`th instance of `bit` (zero-indexed), returning `None` if
    /// there are not `rank + 1` instances of `bit` in the array.
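    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[false, true, true, false] {
    ///     r.push(b);
    /// }
    /// assert_eq!(r.select(1, true), Some(2));
    /// assert_eq!(r.select(2, true), None);
    /// ```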
    pub fn select(&self, rank: u64, bit: bool) -> Option<u64> {
        if bit {
            self.select1(rank)
        } else {
            self.select0(rank)
        }
    }

    /// Specialized version of [`RsDict::select`] for finding positions of zeros.
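    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[false, true, false] {
    ///     r.push(b);
    /// }
    /// // Zeros sit at indices 0 and 2.
    /// assert_eq!(r.select0(1), Some(2));
    /// assert_eq!(r.select0(2), None);
    /// ```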
    pub fn select0(&self, rank: u64) -> Option<u64> {
        if rank >= self.num_zeros {
            return None;
        }
        // How many zeros are there *excluding* the last block?
        let prefix_num_zeros = self.num_zeros - self.last_block.num_zeros;

        // Our rank must be in the last block.
        if rank >= prefix_num_zeros {
            let lb_rank = (rank - prefix_num_zeros) as u8;
            return Some(self.last_block_ind() + self.last_block.select0(lb_rank));
        }

        // First, use the select pointer to jump forward to a large block and
        // then walk forward over the large blocks until we pass our rank.
        let select_ind = (rank / SELECT_BLOCK_SIZE) as usize;
        let lb_start = self.select_zero_inds[select_ind] as usize;
        let mut lblock = None;
        for (i, large_block) in self.large_blocks[lb_start..].iter().enumerate() {
            let lb_ix = (lb_start + i) as u64;
            let lb_rank = lb_ix * LARGE_BLOCK_SIZE - large_block.rank;
            if rank < lb_rank {
                lblock = Some(lb_ix - 1);
                break;
            }
        }
        let lblock = lblock.unwrap_or(self.large_blocks.len() as u64 - 1);
        let large_block = &self.large_blocks[lblock as usize];

        // Next, iterate over the small blocks, using their cached class to
        // subtract out our rank.
        let sb_start = (lblock * SMALL_BLOCK_PER_LARGE_BLOCK) as usize;
        let mut pointer = large_block.pointer;
        let mut remaining = rank - (lblock * LARGE_BLOCK_SIZE - large_block.rank);
        for (i, &sb_class) in self.sb_classes[sb_start..].iter().enumerate() {
            let sb_zeros = (SMALL_BLOCK_SIZE as u8 - sb_class) as u64;
            let code_length = ENUM_CODE_LENGTH[sb_class as usize];

            // Our desired rank is within this block.
            if remaining < sb_zeros {
                let code = self.read_sb_index(pointer, code_length);
                let sb_rank = (sb_start + i) as u64 * SMALL_BLOCK_SIZE;
                let block_rank = enum_code::select0(code, sb_class, remaining);
                return Some(sb_rank + block_rank);
            }

            // Otherwise, subtract out this block and continue.
            remaining -= sb_zeros;
            pointer += code_length as u64;
        }
        panic!("Ran out of small blocks when iterating over rank");
    }

    /// Specialized version of [`RsDict::select`] for finding positions of ones.
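    ///
    /// ```
    /// use rsdict::RsDict;
    ///
    /// let mut r = RsDict::new();
    /// for &b in &[false, true, false, true] {
    ///     r.push(b);
    /// }
    /// // Ones sit at indices 1 and 3.
    /// assert_eq!(r.select1(1), Some(3));
    /// assert_eq!(r.select1(2), None);
    /// ```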
    pub fn select1(&self, rank: u64) -> Option<u64> {
        if rank >= self.num_ones {
            return None;
        }

        let prefix_num_ones = self.num_ones - self.last_block.num_ones;
        if rank >= prefix_num_ones {
            let lb_rank = (rank - prefix_num_ones) as u8;
            return Some(self.last_block_ind() + self.last_block.select1(lb_rank));
        }

        let select_ind = (rank / SELECT_BLOCK_SIZE) as usize;
        let lb_start = self.select_one_inds[select_ind] as usize;
        let mut lblock = None;
        for (i, large_block) in self.large_blocks[lb_start..].iter().enumerate() {
            if rank < large_block.rank {
                lblock = Some((lb_start + i - 1) as u64);
                break;
            }
        }
        let lblock = lblock.unwrap_or(self.large_blocks.len() as u64 - 1);
        let large_block = &self.large_blocks[lblock as usize];

        let sb_start = (lblock * SMALL_BLOCK_PER_LARGE_BLOCK) as usize;
        let mut pointer = large_block.pointer;
        let mut remaining = rank - large_block.rank;
        for (i, &sb_class) in self.sb_classes[sb_start..].iter().enumerate() {
            let sb_ones = sb_class as u64;
            let code_length = ENUM_CODE_LENGTH[sb_class as usize];

            if remaining < sb_ones {
                let code = self.read_sb_index(pointer, code_length);
                let sb_rank = (sb_start + i) as u64 * SMALL_BLOCK_SIZE;
                let block_rank = enum_code::select1(code, sb_class, remaining);
                return Some(sb_rank + block_rank);
            }

            remaining -= sb_ones;
            pointer += code_length as u64;
        }
        panic!("Ran out of small blocks when iterating over rank");
    }

    /// Return the length of the underlying bitmap.
    pub fn len(&self) -> usize {
        self.len as usize
    }

    /// Return whether the underlying bitmap is empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Count the number of set bits in the underlying bitmap.
    pub fn count_ones(&self) -> usize {
        self.num_ones as usize
    }

    /// Count the number of unset bits in the underlying bitmap.
    pub fn count_zeros(&self) -> usize {
        self.num_zeros as usize
    }

    /// Push a bit at the end of the underlying bitmap.
    pub fn push(&mut self, bit: bool) {
        if self.len % SMALL_BLOCK_SIZE == 0 {
            self.write_block();
        }
        if bit {
            self.last_block.set_one(self.len % SMALL_BLOCK_SIZE);
            if self.num_ones % SELECT_BLOCK_SIZE == 0 {
                self.select_one_inds.push(self.len / LARGE_BLOCK_SIZE);
            }
            self.num_ones += 1;
        } else {
            self.last_block.set_zero(self.len % SMALL_BLOCK_SIZE);
            if self.num_zeros % SELECT_BLOCK_SIZE == 0 {
                self.select_zero_inds.push(self.len / LARGE_BLOCK_SIZE);
            }
            self.num_zeros += 1;
        }
        self.len += 1;
    }

    /// Query the `pos`th bit (zero-indexed) of the underlying bitmap.
    pub fn get_bit(&self, pos: u64) -> bool {
        if self.is_last_block(pos) {
            return self.last_block.get_bit(pos % SMALL_BLOCK_SIZE);
        }
        let lblock = pos / LARGE_BLOCK_SIZE;
        let sblock = (pos / SMALL_BLOCK_SIZE) as usize;
        let sblock_start = (lblock * SMALL_BLOCK_PER_LARGE_BLOCK) as usize;
        let mut pointer = self.large_blocks[lblock as usize].pointer;
        for &sb_class in &self.sb_classes[sblock_start..sblock] {
            pointer += ENUM_CODE_LENGTH[sb_class as usize] as u64;
        }
        let sb_class = self.sb_classes[sblock];
        let code_length = ENUM_CODE_LENGTH[sb_class as usize];
        let code = self.read_sb_index(pointer, code_length);
        enum_code::decode_bit(code, sb_class, pos % SMALL_BLOCK_SIZE)
    }

    /// Iterate over the bits in the bitset.
    pub fn iter(&self) -> impl Iterator<Item = bool> + '_ {
        struct RsDictIter<'a> {
            rsdict: &'a RsDict,
            pos: u64,
        }
        impl<'a> Iterator for RsDictIter<'a> {
            type Item = bool;

            fn next(&mut self) -> Option<bool> {
                if self.pos >= self.rsdict.len {
                    return None;
                }
                // TODO: We could optimize this to read in a block once rather than decoding a bit
                // at a time.
                let out = self.rsdict.get_bit(self.pos);
                self.pos += 1;
                Some(out)
            }
        }
        RsDictIter {
            rsdict: self,
            pos: 0,
        }
    }

    fn write_block(&mut self) {
        if self.len > 0 {
            let block = mem::replace(&mut self.last_block, LastBlock::new());

            let sb_class = block.num_ones as u8;
            self.sb_classes.push(sb_class);

            // To avoid indexing past the end of our allocation when
            // scanning through a large block, reserve some extra space to
            // ensure that we always have a full large block in
            // `sb_classes`.
            let num_sb = self.sb_classes.len();
            let align = SMALL_BLOCK_PER_LARGE_BLOCK as usize;
            self.sb_classes
                .reserve((num_sb + align - 1) / align * align);

            let (code_len, code) = enum_code::encode(block.bits, sb_class);
            self.sb_indices.push(code_len as usize, code);
        }
        if self.len % LARGE_BLOCK_SIZE == 0 {
            let lblock = LargeBlock {
                rank: self.num_ones,
                pointer: self.sb_indices.len() as u64,
            };
            self.large_blocks.push(lblock);
        }
    }

    fn last_block_ind(&self) -> u64 {
        if self.len == 0 {
            return 0;
        }
        ((self.len - 1) / SMALL_BLOCK_SIZE) * SMALL_BLOCK_SIZE
    }

    fn is_last_block(&self, pos: u64) -> bool {
        pos >= self.last_block_ind()
    }

    fn read_sb_index(&self, ptr: u64, code_len: u8) -> u64 {
        self.sb_indices.get(ptr as usize, code_len as usize)
    }
}

impl PartialEq for RsDict {
    fn eq(&self, rhs: &Self) -> bool {
        self.iter().eq(rhs.iter())
    }
}

impl Eq for RsDict {}

impl PartialOrd for RsDict {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        self.iter().partial_cmp(rhs.iter())
    }
}

impl Ord for RsDict {
    fn cmp(&self, rhs: &Self) -> Ordering {
        self.iter().cmp(rhs.iter())
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
struct LargeBlock {
    pointer: u64,
    rank: u64,
}

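// Append-only bit buffer: variable-length codes are packed contiguously,
// LSB-first, into 64-bit words, so a single code may straddle two adjacent
// words.  `len` counts the total number of bits written so far.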
#[derive(Clone, Debug, Eq, PartialEq)]
struct VarintBuffer {
    buf: Vec<u64>,
    len: usize,
}

impl VarintBuffer {
    fn with_capacity(bits: usize) -> Self {
        Self {
            buf: Vec::with_capacity(bits / 64),
            len: 0,
        }
    }

    fn push(&mut self, num_bits: usize, value: u64) {
        debug_assert!(num_bits <= 64);
        if num_bits == 0 {
            return;
        }
        let (block, offset) = (self.len / 64, self.len % 64);
        if self.buf.len() == block || offset + num_bits > 64 {
            self.buf.push(0);
        }
        self.buf[block] |= value << offset;
        if offset + num_bits > 64 {
            self.buf[block + 1] |= value >> (64 - offset);
        }
        self.len += num_bits;
    }

    fn get(&self, index: usize, num_bits: usize) -> u64 {
        debug_assert!(num_bits <= 64);
        if num_bits == 0 {
            return 0;
        }
        let (block, offset) = (index / 64, index % 64);
        let mask = 1u64
            .checked_shl(num_bits as u32)
            .unwrap_or(0)
            .wrapping_sub(1);
        let mut ret = (self.buf[block] >> offset) & mask;
        if offset + num_bits > 64 {
            ret |= self.buf[block + 1] << (64 - offset);
        }
        ret & mask
    }

    fn heap_size(&self) -> usize {
        self.buf.capacity() * mem::size_of::<u64>()
    }

    fn len(&self) -> usize {
        self.len
    }
}

#[derive(Clone, Debug, Eq, PartialEq)]
struct LastBlock {
    bits: u64,
    num_ones: u64,
    num_zeros: u64,
}

impl LastBlock {
    fn new() -> Self {
        LastBlock {
            bits: 0,
            num_ones: 0,
            num_zeros: 0,
        }
    }

    fn select0(&self, rank: u8) -> u64 {
        debug_assert!(rank < self.num_zeros as u8);
        enum_code::select1_raw(!self.bits, rank as u64)
    }

    fn select1(&self, rank: u8) -> u64 {
        debug_assert!(rank < self.num_ones as u8);
        enum_code::select1_raw(self.bits, rank as u64)
    }

    // Count the number of bits set at indices i >= pos
    fn count_suffix(&self, pos: u64) -> u64 {
        (self.bits >> pos).count_ones() as u64
    }

    fn get_bit(&self, pos: u64) -> bool {
        (self.bits >> pos) & 1 == 1
    }

    // Only call one of `set_one` or `set_zero` for any `pos`.
    fn set_one(&mut self, pos: u64) {
        self.bits |= 1 << pos;
        self.num_ones += 1;
    }
    fn set_zero(&mut self, _pos: u64) {
        self.num_zeros += 1;
    }
}

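// Given `x` ones among the first `n` bits, return the number of `b`-valued
// bits: `x` when counting ones and `n - x` when counting zeros.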
fn rank_by_bit(x: u64, n: u64, b: bool) -> u64 {
    if b {
        x
    } else {
        n - x
    }
}

#[cfg(test)]
mod tests {
    use super::RsDict;
    use crate::test_helpers::hash_u64;

    fn hash_u64_blocks(blocks: &[u64]) -> Vec<bool> {
        let mut bits = Vec::with_capacity(blocks.len() * 64);
        let to_pop = blocks.get(0).unwrap_or(&0) % 64;
        for block in blocks {
            for i in 0..4 {
                let block = hash_u64(block.wrapping_add(i));
                if block % 2 != 0 {
                    for j in 0..64 {
                        let bit = (block >> j) & 1 != 0;
                        bits.push(bit);
                    }
                }
            }
        }
        for _ in 0..to_pop {
            bits.pop();
        }
        bits
    }

    fn check_rsdict(bits: &[bool]) {
        let mut rs_dict = RsDict::with_capacity(bits.len());
        for &bit in bits {
            rs_dict.push(bit);
        }

        // Check that rank(i) matches our naively computed ranks for all indices.
        let mut one_rank = 0;
        let mut zero_rank = 0;
        for (i, &inp_bit) in bits.iter().enumerate() {
            assert_eq!(rs_dict.rank(i as u64, false), zero_rank);
            assert_eq!(rs_dict.rank(i as u64, true), one_rank);
            if inp_bit {
                one_rank += 1;
            } else {
                zero_rank += 1;
            }
        }

        // Check `select(r)` for ranks "in bounds" within the bitvector against
        // our naively computed ranks.
        let mut one_rank = 0;
        let mut zero_rank = 0;
        for (i, &inp_bit) in bits.iter().enumerate() {
            if inp_bit {
                assert_eq!(rs_dict.select(one_rank as u64, true), Some(i as u64));
                one_rank += 1;
            } else {
                assert_eq!(rs_dict.select(zero_rank as u64, false), Some(i as u64));
                zero_rank += 1;
            }
        }
        // Check all of the "out of bounds" ranks up until `bits.len()`.  Since
        // there are exactly `one_rank` ones (and `zero_rank` zeros), those
        // ranks themselves are already out of bounds.
        for r in one_rank..bits.len() {
            assert_eq!(rs_dict.select(r as u64, true), None);
        }
        for r in zero_rank..bits.len() {
            assert_eq!(rs_dict.select(r as u64, false), None);
        }

        // Check that we can query all of the bits back out.
        for (i, &bit) in bits.iter().enumerate() {
            assert_eq!(rs_dict.get_bit(i as u64), bit);
        }

        // Check our combined bit and rank method.
        let mut one_rank = 0;
        for (i, &bit) in bits.iter().enumerate() {
            let (rs_bit, rs_rank) = rs_dict.bit_and_one_rank(i as u64);
            assert_eq!((rs_bit, rs_rank), (bit, one_rank));
            if bit {
                one_rank += 1;
            }
        }

        // Check that iteration matches.
        assert!(bits.iter().cloned().eq(rs_dict.iter()));

        // Check that equality is reflexive (this exercises `RsDict`'s
        // `PartialEq` implementation).
        assert_eq!(rs_dict, rs_dict)
    }

    #[quickcheck]
    fn qc_from_blocks(blocks: Vec<u64>) {
        let bits = hash_u64_blocks(&blocks);
        let mut rs_dict = RsDict::with_capacity(bits.len());
        for &bit in &bits {
            rs_dict.push(bit);
        }
        let blocks = (0..(bits.len() / 64)).map(|i| {
            let mut block = 0u64;
            for j in 0..64 {
                if bits[i * 64 + j] {
                    block |= 1 << j;
                }
            }
            block
        });
        let mut block_rs_dict = RsDict::from_blocks(blocks);
        for i in (bits.len() / 64 * 64)..bits.len() {
            block_rs_dict.push(bits[i]);
        }

        assert_eq!(rs_dict.len, block_rs_dict.len);
        assert_eq!(rs_dict.num_ones, block_rs_dict.num_ones);
        assert_eq!(rs_dict.num_zeros, block_rs_dict.num_zeros);
        assert_eq!(rs_dict.sb_classes, block_rs_dict.sb_classes);
        assert_eq!(rs_dict.sb_indices, block_rs_dict.sb_indices);
        assert_eq!(rs_dict.large_blocks, block_rs_dict.large_blocks);
        assert_eq!(rs_dict.select_one_inds, block_rs_dict.select_one_inds);
        assert_eq!(rs_dict.select_zero_inds, block_rs_dict.select_zero_inds);
        assert_eq!(rs_dict.last_block, block_rs_dict.last_block);
    }

    // Ask quickcheck to generate blocks of 64 bits so we get test
    // coverage for ranges spanning multiple small blocks.
    #[quickcheck]
    fn qc_rsdict(blocks: Vec<u64>) {
        check_rsdict(&hash_u64_blocks(&blocks));
    }

    #[test]
    fn test_large_rsdicts() {
        check_rsdict(&[true; 65]);
        check_rsdict(&[true; 1025]);
        check_rsdict(&[true; 3121]);
        check_rsdict(&[true; 3185]);
        check_rsdict(&[true; 4097]);
        check_rsdict(&[true; 8193]);

        check_rsdict(&[false; 65]);
        check_rsdict(&[false; 1025]);
        check_rsdict(&[false; 3121]);
        check_rsdict(&[false; 3185]);
        check_rsdict(&[false; 4097]);
        check_rsdict(&[false; 8193]);

        let alternating = &mut [false; 8193];
        for i in 0..8193 {
            alternating[i] = i % 2 == 0;
        }
        check_rsdict(alternating);
    }

    #[test]
    fn test_ordering() {
        let r1 = RsDict::from_blocks([0u64].iter().cloned());
        let r2 = RsDict::from_blocks([1u64].iter().cloned());

        assert_ne!(r1, r2);
        assert!(r1 < r2);
    }
}