indexed_bitvec_core/index_raw.rs

/*
   Copyright 2018 DarkOtter

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
*/
//! The raw functions for building and using rank/select indexes.
//!
//! The functions here do minimal checking, if any, of the size or
//! validity of an index against the bitvector it is used with, so you
//! may run into panics from e.g. out-of-bounds slice accesses. They
//! should all be memory-safe, though.
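//!
//! A minimal usage sketch (marked `ignore`; the exact `BitsRef`
//! construction details live elsewhere in this crate):
//!
//! ```ignore
//! let bytes = [0xffu8; 64];
//! let bits = BitsRef::from_bytes(&bytes[..], 512).unwrap();
//! let mut index = vec![0u64; index_size_for(bits)];
//! build_index_for(bits, &mut index).unwrap();
//! assert_eq!(count_ones(&index, bits), 512);
//! assert_eq!(rank_ones(&index, bits, 10), Some(10)); // set bits before position 10
//! assert_eq!(select_ones(&index, bits, 3), Some(3)); // position of the bit with rank 3
//! ```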

use crate::bits_ref::BitsRef;
use crate::ones_or_zeros::{OneBits, OnesOrZeros, ZeroBits};
use crate::with_offset::WithOffset;
use crate::{ceil_div, ceil_div_u64};
use core::cmp::min;

impl<'a> BitsRef<'a> {
    /// Split the bits into a sequence of chunks of up to *n* bytes.
    fn chunks_by_bytes<'s>(&'s self, bytes_per_chunk: usize) -> impl Iterator<Item = BitsRef<'s>> {
        let bits_per_chunk = (bytes_per_chunk as u64) * 8;
        self.bytes()
            .chunks(bytes_per_chunk)
            .enumerate()
            .map(move |(i, chunk)| {
                let len = i as u64 * bits_per_chunk;
                let bits = min(self.len() - len, bits_per_chunk);
                BitsRef::from_bytes(chunk, bits).expect("Size invariant violated")
            })
    }

    /// Drop the first *n* bytes of bits from the front of the sequence.
    fn drop_bytes<'s>(&'s self, n_bytes: usize) -> BitsRef<'s> {
        let bytes = self.bytes();
        if n_bytes >= bytes.len() {
            panic!("Index out of bounds: tried to drop all of the bits");
        }
        BitsRef::from_bytes(&bytes[n_bytes..], self.len() - (n_bytes as u64 * 8))
            .expect("Checked sufficient bytes are present")
    }
}

mod size {
    use super::*;

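    // Index layout in `u64` words: the L0 entries come first, then the packed
    // L1/L2 entries, then the select samples (two 32-bit samples per word).
    // Block hierarchy: each L0 block covers 2^32 bits, each L1 block covers
    // 4 * 512 = 2048 bits, and each L2 block covers 512 bits.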
    pub const BITS_PER_L0_BLOCK: u64 = 1 << 32;
    pub const BITS_PER_L1_BLOCK: u64 = BITS_PER_L2_BLOCK * 4;
    pub const BITS_PER_L2_BLOCK: u64 = 512;

    pub const BYTES_PER_L0_BLOCK: usize = (BITS_PER_L0_BLOCK / 8) as usize;
    pub const BYTES_PER_L1_BLOCK: usize = (BITS_PER_L1_BLOCK / 8) as usize;
    pub const BYTES_PER_L2_BLOCK: usize = (BITS_PER_L2_BLOCK / 8) as usize;

    pub fn l0(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L0_BLOCK) as usize
    }

    pub fn l1l2(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L1_BLOCK) as usize
    }

    pub fn blocks(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L2_BLOCK) as usize
    }

    pub const SAMPLE_LENGTH: u64 = 8192;

    /// How many samples are needed to cover *n* matching (one or zero) bits?
    pub fn samples_for_bits(matching_bitcount: u64) -> usize {
        ceil_div_u64(matching_bitcount, SAMPLE_LENGTH) as usize
    }
    /// How many `u64` words are needed to store all the samples (for ones and
    /// zeros together) of a bitvector with *n* bits in total? One extra sample
    /// slot is reserved because the one-sample and zero-sample counts are each
    /// rounded up separately, and two samples are packed into each word.
    pub fn sample_words(total_bits: u64) -> usize {
        ceil_div(samples_for_bits(total_bits) + 1, 2)
    }

    pub fn total_index_words(total_bits: u64) -> usize {
        l0(total_bits) + l1l2(total_bits) + sample_words(total_bits)
    }

    pub const L1_BLOCKS_PER_L0_BLOCK: usize = (BITS_PER_L0_BLOCK / BITS_PER_L1_BLOCK) as usize;
    pub const L2_BLOCKS_PER_L1_BLOCK: usize = (BITS_PER_L1_BLOCK / BITS_PER_L2_BLOCK) as usize;
    pub const L2_BLOCKS_PER_L0_BLOCK: usize = L2_BLOCKS_PER_L1_BLOCK * L1_BLOCKS_PER_L0_BLOCK;

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn bytes_evenly_divide_block_sizes() {
            assert_eq!(BITS_PER_L0_BLOCK % 8, 0);
            assert_eq!(BITS_PER_L1_BLOCK % 8, 0);
            assert_eq!(BITS_PER_L2_BLOCK % 8, 0);
        }

        #[test]
        fn l1l2_evenly_divide_l0() {
            // This property is needed so that the size of the l1l2
            // index works out correctly if calculated across separate
            // l0 blocks.
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L1_BLOCK, 0);
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L2_BLOCK, 0);
        }

        #[test]
        fn block_sizes_evenly_divide() {
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L1_BLOCK, 0);
            assert_eq!(BITS_PER_L1_BLOCK % BITS_PER_L2_BLOCK, 0);
        }

        #[test]
        fn sample_size_larger_than_l1() {
            // This is needed because we assume at most one sample can fall within each L1 block
            assert!(SAMPLE_LENGTH >= BITS_PER_L1_BLOCK);
        }

        #[test]
        fn size_of_index_for_zero() {
            assert_eq!(1, total_index_words(0));
        }
    }
}

mod structure {
    use super::*;

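    /// A packed L1/L2 index entry for one L1 block.
    ///
    /// Bits 63..=32 hold the rank of ones at the start of the L1 block
    /// (relative to the start of its L0 block); bits 31..=22, 21..=12 and
    /// 11..=2 hold the one-counts of the first three L2 blocks within it
    /// (the fourth count is implied). The low two bits are unused.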
    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Entry(u64);

    impl L1L2Entry {
        pub fn pack(base_rank: u32, first_counts: [u16; 3]) -> Self {
            debug_assert!(first_counts.iter().all(|&x| x < 0x0400));
            L1L2Entry(
                ((base_rank as u64) << 32)
                    | ((first_counts[0] as u64) << 22)
                    | ((first_counts[1] as u64) << 12)
                    | ((first_counts[2] as u64) << 2),
            )
        }

        pub fn base_rank(self) -> u64 {
            self.0 >> 32
        }

        fn fset_base_rank(self, base_rank: u32) -> Self {
            L1L2Entry(((base_rank as u64) << 32) | self.0 & 0xffffffff)
        }

        pub fn set_base_rank(&mut self, base_rank: u32) {
            *self = self.fset_base_rank(base_rank);
        }

        pub fn l2_count(self, i: usize) -> u64 {
            let shift = 22 - i * 10;
            (self.0 >> shift) & 0x3ff
        }
    }

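    /// A select sample, stored as the index (within its L0 block) of the
    /// 512-bit block containing the matching bit whose overall rank is a
    /// multiple of `size::SAMPLE_LENGTH`.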
    #[derive(Copy, Clone, Debug)]
    pub struct SampleEntry(u32);

    impl SampleEntry {
        pub fn pack(block_idx_in_l0_block: usize) -> Self {
            debug_assert!(block_idx_in_l0_block <= u32::max_value() as usize);
            SampleEntry(block_idx_in_l0_block as u32)
        }

        pub fn block_idx_in_l0_block(self) -> usize {
            self.0 as usize
        }
    }

    use core::mem::{align_of, size_of};

    fn cast_to_l1l2<'a>(data: &'a [u64]) -> &'a [L1L2Entry] {
        debug_assert_eq!(size_of::<u64>(), size_of::<L1L2Entry>());
        debug_assert_eq!(align_of::<u64>(), align_of::<L1L2Entry>());

        unsafe {
            use core::slice::from_raw_parts;
            let n = data.len();
            let ptr = data.as_ptr() as *const L1L2Entry;
            from_raw_parts(ptr, n)
        }
    }

    fn cast_to_l1l2_mut<'a>(data: &'a mut [u64]) -> &'a mut [L1L2Entry] {
        debug_assert_eq!(size_of::<u64>(), size_of::<L1L2Entry>());
        debug_assert_eq!(align_of::<u64>(), align_of::<L1L2Entry>());

        unsafe {
            use core::slice::from_raw_parts_mut;
            let n = data.len();
            let ptr = data.as_mut_ptr() as *mut L1L2Entry;
            from_raw_parts_mut(ptr, n)
        }
    }

    fn cast_to_samples<'a>(data: &'a [u64]) -> &'a [SampleEntry] {
        debug_assert_eq!(size_of::<u64>(), 2 * size_of::<SampleEntry>());
        debug_assert_eq!(align_of::<u64>(), 2 * align_of::<SampleEntry>());

        unsafe {
            use core::slice::from_raw_parts;
            let n = data.len() * 2;
            let ptr = data.as_ptr() as *const SampleEntry;
            from_raw_parts(ptr, n)
        }
    }

    fn cast_to_samples_mut<'a>(data: &'a mut [u64]) -> &'a mut [SampleEntry] {
        debug_assert_eq!(size_of::<u64>(), 2 * size_of::<SampleEntry>());
        debug_assert_eq!(align_of::<u64>(), 2 * align_of::<SampleEntry>());

        unsafe {
            use core::slice::from_raw_parts_mut;
            let n = data.len() * 2;
            let ptr = data.as_mut_ptr() as *mut SampleEntry;
            from_raw_parts_mut(ptr, n)
        }
    }

    pub fn split_l0<'a>(index: &'a [u64], data: BitsRef) -> (&'a [u64], &'a [u64]) {
        index.split_at(size::l0(data.len()))
    }

    pub fn split_l0_mut<'a>(index: &'a mut [u64], data: BitsRef) -> (&'a mut [u64], &'a mut [u64]) {
        index.split_at_mut(size::l0(data.len()))
    }

    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Indexes<'a>(&'a [L1L2Entry]);

    pub fn split_l1l2<'a>(
        index_after_l0: &'a [u64],
        data: BitsRef,
    ) -> (L1L2Indexes<'a>, &'a [u64]) {
        let (l1l2, other) = index_after_l0.split_at(size::l1l2(data.len()));
        (L1L2Indexes(cast_to_l1l2(l1l2)), other)
    }

    pub fn split_l1l2_mut<'a>(
        index_after_l0: &'a mut [u64],
        data: BitsRef,
    ) -> (&'a mut [L1L2Entry], &'a mut [u64]) {
        let (l1l2, other) = index_after_l0.split_at_mut(size::l1l2(data.len()));
        (cast_to_l1l2_mut(l1l2), other)
    }

    pub fn split_samples<'a>(
        index_after_l1l2: &'a [u64],
        data: BitsRef,
        count_ones: u64,
    ) -> (&'a [SampleEntry], &'a [SampleEntry]) {
        let all_samples = cast_to_samples(index_after_l1l2);
        let n_samples_ones = size::samples_for_bits(count_ones);
        let n_samples_zeros = size::samples_for_bits(data.len() - count_ones);
        let (ones_samples, other_samples) = all_samples.split_at(n_samples_ones);
        let zeros_samples = &other_samples[..n_samples_zeros];
        (ones_samples, zeros_samples)
    }

    pub fn split_samples_mut<'a>(
        index_after_l1l2: &'a mut [u64],
        data: BitsRef,
        count_ones: u64,
    ) -> (&'a mut [SampleEntry], &'a mut [SampleEntry]) {
        debug_assert!(index_after_l1l2.len() == size::sample_words(data.len()));
        let all_samples = cast_to_samples_mut(index_after_l1l2);
        let n_samples_ones = size::samples_for_bits(count_ones);
        let n_samples_zeros = size::samples_for_bits(data.len() - count_ones);
        debug_assert!(all_samples.len() >= n_samples_ones + n_samples_zeros);
        debug_assert!(all_samples.len() <= n_samples_ones + n_samples_zeros + 2);
        let (ones_samples, other_samples) = all_samples.split_at_mut(n_samples_ones);
        let zeros_samples = &mut other_samples[..n_samples_zeros];
        (ones_samples, zeros_samples)
    }

    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Index<'a> {
        block_count: usize,
        index_data: &'a [L1L2Entry],
    }

    impl<'a> L1L2Indexes<'a> {
        pub fn it_is_the_whole_index_honest(index: &'a [L1L2Entry]) -> Self {
            L1L2Indexes(index)
        }

        pub fn inner_index(self, all_bits: BitsRef, l0_idx: usize) -> L1L2Index<'a> {
            let start_idx = l0_idx * size::L1_BLOCKS_PER_L0_BLOCK;
            let end_idx = min(start_idx + size::L1_BLOCKS_PER_L0_BLOCK, self.0.len());
            let block_count_to_end =
                size::blocks(all_bits.len()) - start_idx * size::L2_BLOCKS_PER_L1_BLOCK;
            L1L2Index {
                block_count: min(block_count_to_end, size::L2_BLOCKS_PER_L0_BLOCK),
                index_data: &self.0[start_idx..end_idx],
            }
        }
    }

    impl<'a> L1L2Index<'a> {
        pub fn len(self) -> usize {
            self.block_count
        }

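        /// The rank (count of `W` bits) before the start of block `block_idx`,
        /// counted from the start of this L0 block.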
        pub fn rank_of_block<W: OnesOrZeros>(self, block_idx: usize) -> u64 {
            if block_idx >= self.block_count {
                panic!("Index out of bounds: not enough blocks");
            }

            let l1_idx = block_idx / size::L2_BLOCKS_PER_L1_BLOCK;
            let l2_idx = block_idx % size::L2_BLOCKS_PER_L1_BLOCK;
            let entry = self.index_data[l1_idx];
            let l1_rank_ones = entry.base_rank();
            let l2_rank_ones = {
                let mut l2_rank = 0;
                if l2_idx >= 3 {
                    l2_rank += entry.l2_count(2)
                }
                if l2_idx >= 2 {
                    l2_rank += entry.l2_count(1)
                }
                if l2_idx >= 1 {
                    l2_rank += entry.l2_count(0)
                }
                l2_rank
            };

            W::convert_count(
                l1_rank_ones + l2_rank_ones,
                block_idx as u64 * size::BITS_PER_L2_BLOCK,
            )
        }
    }
}
use self::structure::{L1L2Entry, L1L2Index, L1L2Indexes, SampleEntry};

/// Calculate the storage size for an index for a given bitvector (*O(1)*).
///
/// This just looks at the number of bits in the bitvector and does some
/// calculations. The number returned is the number of `u64`s needed to
/// store the index.
pub fn index_size_for(bits: BitsRef) -> usize {
    size::total_index_words(bits.len())
}

/// Indicates the index storage was the wrong size for the bit vector it was used with.
#[derive(Copy, Clone, Debug)]
pub struct IndexSizeError;

/// Check that an index is the right size for a given bitvector.
///
/// This does not in any way guarantee that the index was built for
/// that bitvector, or that neither has been modified since.
pub fn check_index_size(index: &[u64], bits: BitsRef) -> Result<(), IndexSizeError> {
    if index.len() != index_size_for(bits) {
        Err(IndexSizeError)
    } else {
        Ok(())
    }
}

/// Build the index data for a given bitvector (*O(n)*).
pub fn build_index_for(bits: BitsRef, into: &mut [u64]) -> Result<(), IndexSizeError> {
    check_index_size(into, bits)?;

    if bits.len() == 0 {
        return Ok(());
    }

    let (l0_index, index_after_l0) = structure::split_l0_mut(into, bits);
    let (l1l2_index, index_after_l1l2) = structure::split_l1l2_mut(index_after_l0, bits);

    // Build the L1L2 index, and get the L0 block bitcounts
    bits.chunks_by_bytes(size::BYTES_PER_L0_BLOCK)
        .zip(l1l2_index.chunks_mut(size::L1_BLOCKS_PER_L0_BLOCK))
        .zip(l0_index.iter_mut())
        .for_each(|((bits_chunk, l1l2_chunk), l0_entry)| {
            *l0_entry = build_inner_l1l2(l1l2_chunk, bits_chunk)
        });
    let l1l2_index = L1L2Indexes::it_is_the_whole_index_honest(l1l2_index);

    // Convert the L0 block bitcounts into the proper L0 index
    let mut total_count_ones = 0u64;
    for l0_entry in l0_index.iter_mut() {
        total_count_ones += *l0_entry;
        *l0_entry = total_count_ones;
    }
    let l0_index: &[u64] = l0_index;

    // Build the select index
    let (samples_ones, samples_zeros) =
        structure::split_samples_mut(index_after_l1l2, bits, total_count_ones);
    build_samples::<OneBits>(l0_index, l1l2_index, bits, samples_ones);
    build_samples::<ZeroBits>(l0_index, l1l2_index, bits, samples_zeros);

    Ok(())
}

/// Build the inner l1l2 index and return the total count of set bits.
fn build_inner_l1l2(l1l2_index: &mut [L1L2Entry], data_chunk: BitsRef) -> u64 {
    debug_assert!(data_chunk.len() > 0);
    debug_assert!(data_chunk.len() <= size::BITS_PER_L0_BLOCK);
    debug_assert!(l1l2_index.len() == size::l1l2(data_chunk.len()));

    data_chunk
        .chunks_by_bytes(size::BYTES_PER_L1_BLOCK)
        .zip(l1l2_index.iter_mut())
        .for_each(|(l1_chunk, write_to)| {
            let mut counts = [0u16; 3];
            let mut chunks = l1_chunk.chunks_by_bytes(size::BYTES_PER_L2_BLOCK);
            let count_or_zero =
                |opt: Option<BitsRef>| opt.map_or(0, |chunk| chunk.count_ones() as u16);

            counts[0] = count_or_zero(chunks.next());
            counts[1] = count_or_zero(chunks.next());
            counts[2] = count_or_zero(chunks.next());
            let mut total = count_or_zero(chunks.next());
            total += counts[0];
            total += counts[1];
            total += counts[2];

            *write_to = L1L2Entry::pack(total as u32, counts);
        });

    // Finish with a pass reassigning each entry to hold its base rank
    // (in the loop above the base-rank field temporarily held the L1
    // block's total count).
    let mut running_total = 0u64;
    for entry in l1l2_index.iter_mut() {
        let base_rank = running_total as u32;
        running_total += entry.base_rank();
        entry.set_base_rank(base_rank);
    }

    running_total
}

fn build_samples<W: OnesOrZeros>(
    l0_index: &[u64],
    l1l2_index: L1L2Indexes,
    all_bits: BitsRef,
    samples: &mut [SampleEntry],
) {
    build_samples_outer::<W>(
        l0_index,
        0,
        l0_index.len(),
        l1l2_index,
        all_bits,
        WithOffset::at_origin(samples),
    )
}

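/// Recursively split the sample slice at L0 block boundaries, so that each
/// contiguous run of samples is filled in against the L1/L2 index of the
/// single L0 block it falls within.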
fn build_samples_outer<W: OnesOrZeros>(
    l0_index: &[u64],
    low_l0_block: usize,
    high_l0_block: usize,
    l1l2_index: L1L2Indexes,
    all_bits: BitsRef,
    samples: WithOffset<&mut [SampleEntry]>,
) {
    if low_l0_block >= high_l0_block || samples.len() == 0 {
        return;
    } else if low_l0_block + 1 >= high_l0_block {
        let l0_idx = low_l0_block;
        let base_rank = read_l0_rank::<W>(l0_index, all_bits, l0_idx);
        let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);
        return build_samples_inner::<W>(
            base_rank,
            inner_l1l2_index,
            0,
            inner_l1l2_index.len(),
            samples,
        );
    }

    debug_assert!(low_l0_block + 1 < high_l0_block);
    let mid_l0_block = (low_l0_block + high_l0_block) / 2;
    debug_assert!(mid_l0_block > low_l0_block);
    debug_assert!(mid_l0_block < high_l0_block);

    let samples_before_mid_l0_block =
        size::samples_for_bits(read_l0_rank::<W>(l0_index, all_bits, mid_l0_block));
    let (before_mid, after_mid) = samples.split_at_mut_from_origin(samples_before_mid_l0_block);

    build_samples_outer::<W>(
        l0_index,
        low_l0_block,
        mid_l0_block,
        l1l2_index,
        all_bits,
        before_mid,
    );
    build_samples_outer::<W>(
        l0_index,
        mid_l0_block,
        high_l0_block,
        l1l2_index,
        all_bits,
        after_mid,
    );
}

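/// Fill in the samples for a range of 512-bit blocks within a single L0 block.
/// Each sample records the block containing its target bit, where the target
/// rank is `offset_from_origin * SAMPLE_LENGTH`.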
fn build_samples_inner<W: OnesOrZeros>(
    base_rank: u64,
    inner_l1l2_index: L1L2Index,
    low_block: usize,
    high_block: usize,
    samples: WithOffset<&mut [SampleEntry]>,
) {
    if samples.len() == 0 {
        return;
    } else if samples.len() == 1 {
        debug_assert!(high_block > low_block);
        let target_rank = samples.offset_from_origin() as u64 * size::SAMPLE_LENGTH;
        let target_rank_in_l0 = target_rank - base_rank;
        let following_block_idx = binary_search(low_block, high_block, |block_idx| {
            inner_l1l2_index.rank_of_block::<W>(block_idx) > target_rank_in_l0
        });
        debug_assert!(following_block_idx > low_block);
        samples.decompose()[0] = SampleEntry::pack(following_block_idx - 1);
        return;
    }

    debug_assert!(samples.len() > 1);
    debug_assert!(low_block + 1 < high_block);
    let mid_block = (low_block + high_block) / 2;
    debug_assert!(mid_block > low_block);
    debug_assert!(mid_block < high_block);

    let samples_before_mid_block =
        size::samples_for_bits(inner_l1l2_index.rank_of_block::<W>(mid_block) + base_rank);

    let (before_mid, after_mid) = samples.split_at_mut_from_origin(samples_before_mid_block);

    build_samples_inner::<W>(
        base_rank,
        inner_l1l2_index,
        low_block,
        mid_block,
        before_mid,
    );
    build_samples_inner::<W>(
        base_rank,
        inner_l1l2_index,
        mid_block,
        high_block,
        after_mid,
    );
}

/// Count the set bits using the index (fast *O(1)*).
#[inline]
pub fn count_ones(index: &[u64], bits: BitsRef) -> u64 {
    if bits.len() == 0 {
        return 0;
    }
    let l0_size = size::l0(bits.len());
    debug_assert!(l0_size > 0);
    index[l0_size - 1]
}

/// Count the unset bits using the index (fast *O(1)*).
#[inline]
pub fn count_zeros(index: &[u64], bits: BitsRef) -> u64 {
    ZeroBits::convert_count(count_ones(index, bits), bits.len())
}

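/// The number of `W` bits in L0 blocks `0..=idx`, i.e. the cumulative count up
/// to the end of block `idx`.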
fn read_l0_cumulative_count<W: OnesOrZeros>(l0_index: &[u64], bits: BitsRef, idx: usize) -> u64 {
    let count_ones = l0_index[idx];
    let total_count = if idx + 1 < l0_index.len() {
        (idx as u64 + 1) * size::BITS_PER_L0_BLOCK
    } else {
        bits.len()
    };
    W::convert_count(count_ones, total_count)
}

fn read_l0_rank<W: OnesOrZeros>(l0_index: &[u64], bits: BitsRef, idx: usize) -> u64 {
    if idx > 0 {
        read_l0_cumulative_count::<W>(l0_index, bits, idx - 1)
    } else {
        0
    }
}

/// Count the set bits before a position in the bits using the index (*O(1)*).
///
/// Returns `None` if the position is out of bounds.
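///
/// An illustrative sketch (`ignore`d; assumes an index built over an
/// all-ones bitvector as in the module-level example):
///
/// ```ignore
/// assert_eq!(rank_ones(&index, bits, 0), Some(0));
/// assert_eq!(rank_ones(&index, bits, 100), Some(100)); // every bit is set
/// assert_eq!(rank_ones(&index, bits, bits.len()), None); // out of bounds
/// ```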
pub fn rank_ones(index: &[u64], all_bits: BitsRef, idx: u64) -> Option<u64> {
    if idx >= all_bits.len() {
        return None;
    } else if idx == 0 {
        return Some(0);
    }

    let (l0_index, index_after_l0) = structure::split_l0(index, all_bits);

    let l0_idx = idx / size::BITS_PER_L0_BLOCK;
    debug_assert!(l0_idx < l0_index.len() as u64);
    let l0_idx = l0_idx as usize;
    let l0_offset = idx % size::BITS_PER_L0_BLOCK;
    let l0_rank = read_l0_rank::<OneBits>(l0_index, all_bits, l0_idx);

    let (l1l2_index, _) = structure::split_l1l2(index_after_l0, all_bits);
    let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);

    let block_idx = l0_offset / size::BITS_PER_L2_BLOCK;
    debug_assert!(
        block_idx < (inner_l1l2_index.len() as u64) * size::L2_BLOCKS_PER_L1_BLOCK as u64
    );
    let block_idx = block_idx as usize;
    let block_offset = l0_offset % size::BITS_PER_L2_BLOCK;
    let block_rank = inner_l1l2_index.rank_of_block::<OneBits>(block_idx);

    let scan_skip_bytes = l0_idx * size::BYTES_PER_L0_BLOCK + block_idx * size::BYTES_PER_L2_BLOCK;
    let scan_bits = all_bits.drop_bytes(scan_skip_bytes);
    let scanned_rank = scan_bits
        .rank_ones(block_offset)
        .expect("Already checked size");
    Some(l0_rank + block_rank + scanned_rank)
}

/// Count the unset bits before a position in the bits using the index (*O(1)*).
///
/// Returns `None` if the position is out of bounds.
#[inline]
pub fn rank_zeros(index: &[u64], bits: BitsRef, idx: u64) -> Option<u64> {
    rank_ones(index, bits, idx).map(|res_ones| ZeroBits::convert_count(res_ones, idx))
}

/// Find the index *i* which partitions the input space into values
/// satisfying the check and those which don't.
///
/// This assumes there is some *i* which is at least `from` and less
/// than `until` such that `check(j) == (j >= i)`.
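///
/// For example, with `from = 0`, `until = 8` and `check(j) == (j >= 5)` the
/// result is `5`; if nothing in the range satisfies `check`, the result is `until`.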
fn binary_search<F>(from: usize, until: usize, check: F) -> usize
where
    F: Fn(usize) -> bool,
{
    const LINEAR_FOR_N: usize = 16;

    let mut false_up_to = from;
    let mut true_from = until;

    while false_up_to + LINEAR_FOR_N < true_from {
        let mid_ish = (false_up_to + true_from) / 2;
        if check(mid_ish) {
            true_from = mid_ish;
        } else {
            false_up_to = mid_ish + 1;
        }
    }

    while false_up_to < true_from && !check(false_up_to) {
        false_up_to += 1;
    }
    debug_assert!(false_up_to <= true_from);
    debug_assert!(false_up_to == true_from || check(false_up_to));

    return false_up_to;
}

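/// Shared implementation of `select_ones` / `select_zeros`: binary-search the
/// L0 cumulative counts for the right L0 block, use the select samples to bound
/// the range of 512-bit blocks to consider within it, binary-search the L1/L2
/// index for the exact block, then scan from the start of that block for the
/// target bit.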
fn select<W: OnesOrZeros>(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    if all_bits.len() == 0 {
        return None;
    }
    let (l0_index, index_after_l0) = structure::split_l0(index, all_bits);
    debug_assert!(l0_index.len() > 0);
    let total_count_ones = l0_index[l0_index.len() - 1];
    let total_count = W::convert_count(total_count_ones, all_bits.len());
    if target_rank >= total_count {
        return None;
    }

    // Find the right l0 block by binary search
    let l0_idx = binary_search(0, l0_index.len(), |idx| {
        read_l0_cumulative_count::<W>(l0_index, all_bits, idx) > target_rank
    });
    debug_assert!(l0_idx < l0_index.len());
    let next_l0_block_rank = read_l0_cumulative_count::<W>(l0_index, all_bits, l0_idx);
    debug_assert!(next_l0_block_rank > target_rank);
    let l0_block_rank = read_l0_rank::<W>(l0_index, all_bits, l0_idx);
    debug_assert!(l0_block_rank <= target_rank);
    let target_rank_in_l0_block = target_rank - l0_block_rank;

    // Unpack the other parts of the index
    let (l1l2_index, index_after_l1l2) = structure::split_l1l2(index_after_l0, all_bits);
    let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);
    debug_assert!(inner_l1l2_index.len() > 0);
    let (select_ones_samples, select_zeros_samples) =
        structure::split_samples(index_after_l1l2, all_bits, total_count_ones);
    let select_samples = if W::is_ones() {
        select_ones_samples
    } else {
        select_zeros_samples
    };

    // Use the samples to find bounds on which block can contain our target bit
    let sample_idx = target_rank / size::SAMPLE_LENGTH;
    let block_idx_should_be_at_least = {
        let sample_rank = sample_idx * size::SAMPLE_LENGTH;
        if sample_rank < l0_block_rank {
            // Sample is from the previous l0 block
            0
        } else {
            select_samples[sample_idx as usize].block_idx_in_l0_block()
        }
    };
    let block_idx_should_be_less_than = {
        let next_sample_idx = sample_idx + 1;
        let next_sample_rank = next_sample_idx * size::SAMPLE_LENGTH;
        if next_sample_rank >= next_l0_block_rank {
            // Sample is in the next l0 block
            inner_l1l2_index.len()
        } else if next_sample_idx >= select_samples.len() as u64 {
            // Sample does not exist
            inner_l1l2_index.len()
        } else {
            select_samples[next_sample_idx as usize].block_idx_in_l0_block() + 1
        }
    };

    let block_idx = {
        let following_block_idx = binary_search(
            block_idx_should_be_at_least,
            block_idx_should_be_less_than,
            |idx| inner_l1l2_index.rank_of_block::<W>(idx) > target_rank_in_l0_block,
        );
        debug_assert!(following_block_idx > 0);
        following_block_idx - 1
    };
    let block_rank = inner_l1l2_index.rank_of_block::<W>(block_idx);
    let target_rank_in_block = target_rank_in_l0_block - block_rank;

    let scan_skip_bytes = l0_idx * size::BYTES_PER_L0_BLOCK + block_idx * size::BYTES_PER_L2_BLOCK;
    let scan_bits = all_bits.drop_bytes(scan_skip_bytes);
    let scanned_idx = scan_bits
        .select::<W>(target_rank_in_block)
        .expect("Already checked against total count");

    Some(scan_skip_bytes as u64 * 8 + scanned_idx)
}

/// Find the position of a set bit by its rank using the index (*O(log n)*).
///
/// Returns `None` if no suitable bit is found. Otherwise it is always
/// the case that `rank_ones(index, all_bits, result) == Some(target_rank)`
/// and `all_bits.get(result) == Some(true)`.
pub fn select_ones(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    select::<OneBits>(index, all_bits, target_rank)
}

/// Find the position of an unset bit by its rank using the index (*O(log n)*).
///
/// Returns `None` if no suitable bit is found. Otherwise it is always
/// the case that `rank_zeros(index, all_bits, result) == Some(target_rank)`
/// and `all_bits.get(result) == Some(false)`.
pub fn select_zeros(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    select::<ZeroBits>(index, all_bits, target_rank)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::vec::Vec;

    #[test]
    fn select_bug_issue_15() {
        // When the bit we are selecting is in the same block as the next index sample
        let mut data = vec![0xffu8; 8192 / 8 * 2];
        data[8192 / 8 - 1] = 0;
        let data = BitsRef::from_bytes(&data[..], 8192 * 2).unwrap();
        let mut index = vec![0u64; index_size_for(data)];
        build_index_for(data, &mut index).unwrap();
        let index = index;
        assert_eq!(select_ones(&index, data, 8191), Some(8199));
    }

    #[test]
    fn small_indexed_tests() {
        use rand::{Rng, RngCore, SeedableRng};
        use rand_xorshift::XorShiftRng;
        let n_bits: u64 = (1 << 19) - 1;
        let n_bytes: usize = ceil_div_u64(n_bits, 8) as usize;
        let seed = [
            42, 73, 197, 231, 255, 43, 87, 05, 50, 13, 74, 107, 195, 231, 5, 1,
        ];
        let mut rng = XorShiftRng::from_seed(seed);
        let data = {
            let mut data = vec![0u8; n_bytes];
            rng.fill_bytes(&mut data);
            data
        };
        let data = BitsRef::from_bytes(&data[..], n_bits).expect("Should have enough bytes");
        let index = {
            let mut index = vec![0u64; index_size_for(data)];
            build_index_for(data, &mut index).unwrap();
            index
        };

        let expected_count_ones = data.count_ones();
        let expected_count_zeros = n_bits - expected_count_ones;
        assert_eq!(expected_count_ones, count_ones(&index, data));
        assert_eq!(expected_count_zeros, count_zeros(&index, data));

        assert_eq!(None, rank_ones(&index, data, n_bits));
        assert_eq!(None, rank_zeros(&index, data, n_bits));

        let rank_idxs = {
            let mut idxs: Vec<u64> = (0..1000).map(|_| rng.gen_range(0, n_bits)).collect();
            idxs.sort();
            idxs
        };
        for idx in rank_idxs {
            assert_eq!(data.rank_ones(idx), rank_ones(&index, data, idx));
            assert_eq!(data.rank_zeros(idx), rank_zeros(&index, data, idx));
        }

        assert_eq!(None, select_ones(&index, data, expected_count_ones));
        let one_ranks = {
            let mut ranks: Vec<u64> = (0..1000)
                .map(|_| rng.gen_range(0, expected_count_ones))
                .collect();
            ranks.sort();
            ranks
        };
        for rank in one_ranks {
            assert_eq!(data.select_ones(rank), select_ones(&index, data, rank));
        }

        assert_eq!(None, select_zeros(&index, data, expected_count_zeros));
        let zero_ranks = {
            let mut ranks: Vec<u64> = (0..1000)
                .map(|_| rng.gen_range(0, expected_count_zeros))
                .collect();
            ranks.sort();
            ranks
        };
        for rank in zero_ranks {
            assert_eq!(data.select_zeros(rank), select_zeros(&index, data, rank));
        }
    }
}