use crate::bits_ref::BitsRef;
use crate::ones_or_zeros::{OneBits, OnesOrZeros, ZeroBits};
use crate::with_offset::WithOffset;
use crate::{ceil_div, ceil_div_u64};
use core::cmp::min;

impl<'a> BitsRef<'a> {
    /// Split the bits into chunks of `bytes_per_chunk` bytes; the final chunk
    /// may hold fewer bits.
    fn chunks_by_bytes<'s>(&'s self, bytes_per_chunk: usize) -> impl Iterator<Item = BitsRef<'s>> {
        let bits_per_chunk = (bytes_per_chunk as u64) * 8;
        self.bytes()
            .chunks(bytes_per_chunk)
            .enumerate()
            .map(move |(i, chunk)| {
                let len = i as u64 * bits_per_chunk;
                let bits = min(self.len() - len, bits_per_chunk);
                BitsRef::from_bytes(chunk, bits).expect("Size invariant violated")
            })
    }

    /// Skip the first `n_bytes` bytes, keeping the remaining bits.
    /// Panics if that would leave no bytes at all.
    fn drop_bytes<'s>(&'s self, n_bytes: usize) -> BitsRef<'s> {
        let bytes = self.bytes();
        if n_bytes >= bytes.len() {
            panic!("Index out of bounds: tried to drop all of the bits");
        }
        BitsRef::from_bytes(&bytes[n_bytes..], self.len() - (n_bytes as u64 * 8))
            .expect("Checked sufficient bytes are present")
    }
}

mod size {
    use super::*;

    pub const BITS_PER_L0_BLOCK: u64 = 1 << 32;
    pub const BITS_PER_L1_BLOCK: u64 = BITS_PER_L2_BLOCK * 4;
    pub const BITS_PER_L2_BLOCK: u64 = 512;

    pub const BYTES_PER_L0_BLOCK: usize = (BITS_PER_L0_BLOCK / 8) as usize;
    pub const BYTES_PER_L1_BLOCK: usize = (BITS_PER_L1_BLOCK / 8) as usize;
    pub const BYTES_PER_L2_BLOCK: usize = (BITS_PER_L2_BLOCK / 8) as usize;

    pub fn l0(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L0_BLOCK) as usize
    }

    pub fn l1l2(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L1_BLOCK) as usize
    }

    pub fn blocks(total_bits: u64) -> usize {
        ceil_div_u64(total_bits, BITS_PER_L2_BLOCK) as usize
    }

    pub const SAMPLE_LENGTH: u64 = 8192;

    pub fn samples_for_bits(matching_bitcount: u64) -> usize {
        ceil_div_u64(matching_bitcount, SAMPLE_LENGTH) as usize
    }

    // Two u32 samples fit in each u64 word; the +1 covers the fact that the
    // ones and zeros samples are counted separately.
    pub fn sample_words(total_bits: u64) -> usize {
        ceil_div(samples_for_bits(total_bits) + 1, 2)
    }

    pub fn total_index_words(total_bits: u64) -> usize {
        l0(total_bits) + l1l2(total_bits) + sample_words(total_bits)
    }

    pub const L1_BLOCKS_PER_L0_BLOCK: usize = (BITS_PER_L0_BLOCK / BITS_PER_L1_BLOCK) as usize;
    pub const L2_BLOCKS_PER_L1_BLOCK: usize = (BITS_PER_L1_BLOCK / BITS_PER_L2_BLOCK) as usize;
    pub const L2_BLOCKS_PER_L0_BLOCK: usize = L2_BLOCKS_PER_L1_BLOCK * L1_BLOCKS_PER_L0_BLOCK;

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn bytes_evenly_divide_block_sizes() {
            assert_eq!(BITS_PER_L0_BLOCK % 8, 0);
            assert_eq!(BITS_PER_L1_BLOCK % 8, 0);
            assert_eq!(BITS_PER_L2_BLOCK % 8, 0);
        }

        #[test]
        fn l1l2_evenly_divide_l0() {
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L1_BLOCK, 0);
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L2_BLOCK, 0);
        }

        #[test]
        fn block_sizes_evenly_divide() {
            assert_eq!(BITS_PER_L0_BLOCK % BITS_PER_L1_BLOCK, 0);
            assert_eq!(BITS_PER_L1_BLOCK % BITS_PER_L2_BLOCK, 0);
        }

        #[test]
        fn sample_size_larger_than_l1() {
            assert!(SAMPLE_LENGTH >= BITS_PER_L1_BLOCK);
        }

        #[test]
        fn size_of_index_for_zero() {
            assert_eq!(1, total_index_words(0));
        }
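
        // Added worked example (not part of the original tests): for a single
        // full L0 block of 2^32 bits the index needs one L0 word, 2^21 L1L2
        // words, and ceil((2^19 + 1) / 2) = 262,145 select-sample words.
        #[test]
        fn worked_example_index_size_for_one_l0_block() {
            assert_eq!(l0(BITS_PER_L0_BLOCK), 1);
            assert_eq!(l1l2(BITS_PER_L0_BLOCK), 1 << 21);
            assert_eq!(sample_words(BITS_PER_L0_BLOCK), 262_145);
            assert_eq!(
                total_index_words(BITS_PER_L0_BLOCK),
                1 + (1 << 21) + 262_145
            );
        }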
    }
}

mod structure {
    use super::*;

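    /// Packed L1/L2 entry, one per L1 block (layout as implied by `pack` and
    /// `l2_count` below): bits 63..32 hold the L1 base rank (ones preceding
    /// this L1 block within its L0 block), and bits 31..22, 21..12 and 11..2
    /// hold the ones-counts of the block's first three L2 blocks; the count of
    /// the fourth L2 block is implied. The low two bits are unused.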
    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Entry(u64);

    impl L1L2Entry {
        pub fn pack(base_rank: u32, first_counts: [u16; 3]) -> Self {
            debug_assert!(first_counts.iter().all(|&x| x < 0x0400));
            L1L2Entry(
                ((base_rank as u64) << 32)
                    | ((first_counts[0] as u64) << 22)
                    | ((first_counts[1] as u64) << 12)
                    | ((first_counts[2] as u64) << 2),
            )
        }

        pub fn base_rank(self) -> u64 {
            self.0 >> 32
        }

        fn fset_base_rank(self, base_rank: u32) -> Self {
            L1L2Entry(((base_rank as u64) << 32) | (self.0 & 0xffff_ffff))
        }

        pub fn set_base_rank(&mut self, base_rank: u32) {
            *self = self.fset_base_rank(base_rank);
        }

        pub fn l2_count(self, i: usize) -> u64 {
            let shift = 22 - i * 10;
            (self.0 >> shift) & 0x3ff
        }
    }
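    /// Select sample: for the k-th sample of a given bit kind, the index
    /// (within its L0 block) of the L2 block containing the
    /// (k * SAMPLE_LENGTH)-th matching bit.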
    #[derive(Copy, Clone, Debug)]
    pub struct SampleEntry(u32);

    impl SampleEntry {
        pub fn pack(block_idx_in_l0_block: usize) -> Self {
            debug_assert!(block_idx_in_l0_block <= u32::max_value() as usize);
            SampleEntry(block_idx_in_l0_block as u32)
        }

        pub fn block_idx_in_l0_block(self) -> usize {
            self.0 as usize
        }
    }

    use core::mem::{align_of, size_of};

    fn cast_to_l1l2<'a>(data: &'a [u64]) -> &'a [L1L2Entry] {
        debug_assert_eq!(size_of::<u64>(), size_of::<L1L2Entry>());
        debug_assert_eq!(align_of::<u64>(), align_of::<L1L2Entry>());

        // Safety: `L1L2Entry` is a single-field newtype over `u64`; the asserts
        // above check matching size and alignment, and the length is unchanged.
        unsafe {
            use core::slice::from_raw_parts;
            let n = data.len();
            let ptr = data.as_ptr() as *const L1L2Entry;
            from_raw_parts(ptr, n)
        }
    }

    fn cast_to_l1l2_mut<'a>(data: &'a mut [u64]) -> &'a mut [L1L2Entry] {
        debug_assert_eq!(size_of::<u64>(), size_of::<L1L2Entry>());
        debug_assert_eq!(align_of::<u64>(), align_of::<L1L2Entry>());

        // Safety: as for `cast_to_l1l2`, with the exclusive borrow carried over
        // to the returned slice.
        unsafe {
            use core::slice::from_raw_parts_mut;
            let n = data.len();
            let ptr = data.as_mut_ptr() as *mut L1L2Entry;
            from_raw_parts_mut(ptr, n)
        }
    }

    fn cast_to_samples<'a>(data: &'a [u64]) -> &'a [SampleEntry] {
        debug_assert_eq!(size_of::<u64>(), 2 * size_of::<SampleEntry>());
        debug_assert_eq!(align_of::<u64>(), 2 * align_of::<SampleEntry>());

        // Safety: each `u64` word holds exactly two `SampleEntry` (u32) values,
        // so the element count doubles while the byte length is unchanged.
        unsafe {
            use core::slice::from_raw_parts;
            let n = data.len() * 2;
            let ptr = data.as_ptr() as *const SampleEntry;
            from_raw_parts(ptr, n)
        }
    }

    fn cast_to_samples_mut<'a>(data: &'a mut [u64]) -> &'a mut [SampleEntry] {
        debug_assert_eq!(size_of::<u64>(), 2 * size_of::<SampleEntry>());
        debug_assert_eq!(align_of::<u64>(), 2 * align_of::<SampleEntry>());

        // Safety: as for `cast_to_samples`, with the exclusive borrow carried over.
        unsafe {
            use core::slice::from_raw_parts_mut;
            let n = data.len() * 2;
            let ptr = data.as_mut_ptr() as *mut SampleEntry;
            from_raw_parts_mut(ptr, n)
        }
    }

    pub fn split_l0<'a>(index: &'a [u64], data: BitsRef) -> (&'a [u64], &'a [u64]) {
        index.split_at(size::l0(data.len()))
    }

    pub fn split_l0_mut<'a>(index: &'a mut [u64], data: BitsRef) -> (&'a mut [u64], &'a mut [u64]) {
        index.split_at_mut(size::l0(data.len()))
    }

    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Indexes<'a>(&'a [L1L2Entry]);

    pub fn split_l1l2<'a>(
        index_after_l0: &'a [u64],
        data: BitsRef,
    ) -> (L1L2Indexes<'a>, &'a [u64]) {
        let (l1l2, other) = index_after_l0.split_at(size::l1l2(data.len()));
        (L1L2Indexes(cast_to_l1l2(l1l2)), other)
    }

    pub fn split_l1l2_mut<'a>(
        index_after_l0: &'a mut [u64],
        data: BitsRef,
    ) -> (&'a mut [L1L2Entry], &'a mut [u64]) {
        let (l1l2, other) = index_after_l0.split_at_mut(size::l1l2(data.len()));
        (cast_to_l1l2_mut(l1l2), other)
    }

    pub fn split_samples<'a>(
        index_after_l1l2: &'a [u64],
        data: BitsRef,
        count_ones: u64,
    ) -> (&'a [SampleEntry], &'a [SampleEntry]) {
        let all_samples = cast_to_samples(index_after_l1l2);
        let n_samples_ones = size::samples_for_bits(count_ones);
        let n_samples_zeros = size::samples_for_bits(data.len() - count_ones);
        let (ones_samples, other_samples) = all_samples.split_at(n_samples_ones);
        let zeros_samples = &other_samples[..n_samples_zeros];
        (ones_samples, zeros_samples)
    }

    pub fn split_samples_mut<'a>(
        index_after_l1l2: &'a mut [u64],
        data: BitsRef,
        count_ones: u64,
    ) -> (&'a mut [SampleEntry], &'a mut [SampleEntry]) {
        debug_assert!(index_after_l1l2.len() == size::sample_words(data.len()));
        let all_samples = cast_to_samples_mut(index_after_l1l2);
        let n_samples_ones = size::samples_for_bits(count_ones);
        let n_samples_zeros = size::samples_for_bits(data.len() - count_ones);
        debug_assert!(all_samples.len() >= n_samples_ones + n_samples_zeros);
        debug_assert!(all_samples.len() <= n_samples_ones + n_samples_zeros + 2);
        let (ones_samples, other_samples) = all_samples.split_at_mut(n_samples_ones);
        let zeros_samples = &mut other_samples[..n_samples_zeros];
        (ones_samples, zeros_samples)
    }
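
    // Overall layout of the index words, as carved up by the `split_*`
    // functions above: the L0 cumulative counts come first, then the packed
    // L1L2 entries, then the select samples (ones first, zeros after).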

    #[derive(Copy, Clone, Debug)]
    pub struct L1L2Index<'a> {
        block_count: usize,
        index_data: &'a [L1L2Entry],
    }

    impl<'a> L1L2Indexes<'a> {
        pub fn it_is_the_whole_index_honest(index: &'a [L1L2Entry]) -> Self {
            L1L2Indexes(index)
        }

        pub fn inner_index(self, all_bits: BitsRef, l0_idx: usize) -> L1L2Index<'a> {
            let start_idx = l0_idx * size::L1_BLOCKS_PER_L0_BLOCK;
            let end_idx = min(start_idx + size::L1_BLOCKS_PER_L0_BLOCK, self.0.len());
            let block_count_to_end =
                size::blocks(all_bits.len()) - start_idx * size::L2_BLOCKS_PER_L1_BLOCK;
            L1L2Index {
                block_count: min(block_count_to_end, size::L2_BLOCKS_PER_L0_BLOCK),
                index_data: &self.0[start_idx..end_idx],
            }
        }
    }

    impl<'a> L1L2Index<'a> {
        pub fn len(self) -> usize {
            self.block_count
        }

        pub fn rank_of_block<W: OnesOrZeros>(self, block_idx: usize) -> u64 {
            if block_idx >= self.block_count {
                panic!("Index out of bounds: not enough blocks");
            }

            let l1_idx = block_idx / size::L2_BLOCKS_PER_L1_BLOCK;
            let l2_idx = block_idx % size::L2_BLOCKS_PER_L1_BLOCK;
            let entry = self.index_data[l1_idx];
            let l1_rank_ones = entry.base_rank();
            // Add the counts of the L2 blocks preceding this one within its L1 block.
            let l2_rank_ones = {
                let mut l2_rank = 0;
                if l2_idx >= 3 {
                    l2_rank += entry.l2_count(2)
                }
                if l2_idx >= 2 {
                    l2_rank += entry.l2_count(1)
                }
                if l2_idx >= 1 {
                    l2_rank += entry.l2_count(0)
                }
                l2_rank
            };

            W::convert_count(
                l1_rank_ones + l2_rank_ones,
                block_idx as u64 * size::BITS_PER_L2_BLOCK,
            )
        }
    }
}
use self::structure::{L1L2Entry, L1L2Index, L1L2Indexes, SampleEntry};

/// The number of `u64` index words needed to index `bits`.
pub fn index_size_for(bits: BitsRef) -> usize {
    size::total_index_words(bits.len())
}

/// Error returned when an index slice does not have the length required by
/// `index_size_for`.
#[derive(Copy, Clone, Debug)]
pub struct IndexSizeError;

pub fn check_index_size(index: &[u64], bits: BitsRef) -> Result<(), IndexSizeError> {
    if index.len() != index_size_for(bits) {
        Err(IndexSizeError)
    } else {
        Ok(())
    }
}

pub fn build_index_for(bits: BitsRef, into: &mut [u64]) -> Result<(), IndexSizeError> {
    check_index_size(into, bits)?;

    if bits.len() == 0 {
        return Ok(());
    }

    let (l0_index, index_after_l0) = structure::split_l0_mut(into, bits);
    let (l1l2_index, index_after_l1l2) = structure::split_l1l2_mut(index_after_l0, bits);

    // Pass 1: build the L1L2 entries for each L0 block, recording each block's
    // total ones-count in its L0 slot.
    bits.chunks_by_bytes(size::BYTES_PER_L0_BLOCK)
        .zip(l1l2_index.chunks_mut(size::L1_BLOCKS_PER_L0_BLOCK))
        .zip(l0_index.iter_mut())
        .for_each(|((bits_chunk, l1l2_chunk), l0_entry)| {
            *l0_entry = build_inner_l1l2(l1l2_chunk, bits_chunk)
        });
    let l1l2_index = L1L2Indexes::it_is_the_whole_index_honest(l1l2_index);

    // Pass 2: turn the per-block counts into cumulative counts.
    let mut total_count_ones = 0u64;
    for l0_entry in l0_index.iter_mut() {
        total_count_ones += *l0_entry;
        *l0_entry = total_count_ones;
    }
    let l0_index: &[u64] = l0_index;

    // Pass 3: place the select samples for ones and zeros.
    let (samples_ones, samples_zeros) =
        structure::split_samples_mut(index_after_l1l2, bits, total_count_ones);
    build_samples::<OneBits>(l0_index, l1l2_index, bits, samples_ones);
    build_samples::<ZeroBits>(l0_index, l1l2_index, bits, samples_zeros);

    Ok(())
}

fn build_inner_l1l2(l1l2_index: &mut [L1L2Entry], data_chunk: BitsRef) -> u64 {
    debug_assert!(data_chunk.len() > 0);
    debug_assert!(data_chunk.len() <= size::BITS_PER_L0_BLOCK);
    debug_assert!(l1l2_index.len() == size::l1l2(data_chunk.len()));

    // First store each L1 block's own ones-count in the base-rank field,
    // alongside the counts of its first three L2 blocks.
    data_chunk
        .chunks_by_bytes(size::BYTES_PER_L1_BLOCK)
        .zip(l1l2_index.iter_mut())
        .for_each(|(l1_chunk, write_to)| {
            let mut counts = [0u16; 3];
            let mut chunks = l1_chunk.chunks_by_bytes(size::BYTES_PER_L2_BLOCK);
            let count_or_zero =
                |opt: Option<BitsRef>| opt.map_or(0, |chunk| chunk.count_ones() as u16);

            counts[0] = count_or_zero(chunks.next());
            counts[1] = count_or_zero(chunks.next());
            counts[2] = count_or_zero(chunks.next());
            let mut total = count_or_zero(chunks.next());
            total += counts[0];
            total += counts[1];
            total += counts[2];

            *write_to = L1L2Entry::pack(total as u32, counts);
        });

    // Then convert those per-block counts into a running prefix sum, so each
    // entry's base rank is the ones-count before its L1 block; the return
    // value is the total ones-count of the whole L0 block.
    let mut running_total = 0u64;
    for entry in l1l2_index.iter_mut() {
        let base_rank = running_total as u32;
        running_total += entry.base_rank();
        entry.set_base_rank(base_rank);
    }

    running_total
}

fn build_samples<W: OnesOrZeros>(
    l0_index: &[u64],
    l1l2_index: L1L2Indexes,
    all_bits: BitsRef,
    samples: &mut [SampleEntry],
) {
    build_samples_outer::<W>(
        l0_index,
        0,
        l0_index.len(),
        l1l2_index,
        all_bits,
        WithOffset::at_origin(samples),
    )
}
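
// Sample placement works by divide and conquer over the still-unfilled slice
// of samples: `build_samples_outer` splits on an L0 block boundary and
// `build_samples_inner` on an L2 block boundary, using the ranks recorded in
// the L0 and L1L2 indexes to decide how many samples land on each side. A
// sample slice of length one is then resolved by binary-searching for the
// block containing its target rank.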

fn build_samples_outer<W: OnesOrZeros>(
    l0_index: &[u64],
    low_l0_block: usize,
    high_l0_block: usize,
    l1l2_index: L1L2Indexes,
    all_bits: BitsRef,
    samples: WithOffset<&mut [SampleEntry]>,
) {
    if low_l0_block >= high_l0_block || samples.len() == 0 {
        return;
    } else if low_l0_block + 1 >= high_l0_block {
        let l0_idx = low_l0_block;
        let base_rank = read_l0_rank::<W>(l0_index, all_bits, l0_idx);
        let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);
        return build_samples_inner::<W>(
            base_rank,
            inner_l1l2_index,
            0,
            inner_l1l2_index.len(),
            samples,
        );
    }

    debug_assert!(low_l0_block + 1 < high_l0_block);
    let mid_l0_block = (low_l0_block + high_l0_block) / 2;
    debug_assert!(mid_l0_block > low_l0_block);
    debug_assert!(mid_l0_block < high_l0_block);

    let samples_before_mid_l0_block =
        size::samples_for_bits(read_l0_rank::<W>(l0_index, all_bits, mid_l0_block));
    let (before_mid, after_mid) = samples.split_at_mut_from_origin(samples_before_mid_l0_block);

    build_samples_outer::<W>(
        l0_index,
        low_l0_block,
        mid_l0_block,
        l1l2_index,
        all_bits,
        before_mid,
    );
    build_samples_outer::<W>(
        l0_index,
        mid_l0_block,
        high_l0_block,
        l1l2_index,
        all_bits,
        after_mid,
    );
}

fn build_samples_inner<W: OnesOrZeros>(
    base_rank: u64,
    inner_l1l2_index: L1L2Index,
    low_block: usize,
    high_block: usize,
    samples: WithOffset<&mut [SampleEntry]>,
) {
    if samples.len() == 0 {
        return;
    } else if samples.len() == 1 {
        debug_assert!(high_block > low_block);
        let target_rank = samples.offset_from_origin() as u64 * size::SAMPLE_LENGTH;
        let target_rank_in_l0 = target_rank - base_rank;
        let following_block_idx = binary_search(low_block, high_block, |block_idx| {
            inner_l1l2_index.rank_of_block::<W>(block_idx) > target_rank_in_l0
        });
        debug_assert!(following_block_idx > low_block);
        samples.decompose()[0] = SampleEntry::pack(following_block_idx - 1);
        return;
    }

    debug_assert!(samples.len() > 1);
    debug_assert!(low_block + 1 < high_block);
    let mid_block = (low_block + high_block) / 2;
    debug_assert!(mid_block > low_block);
    debug_assert!(mid_block < high_block);

    let samples_before_mid_block =
        size::samples_for_bits(inner_l1l2_index.rank_of_block::<W>(mid_block) + base_rank);

    let (before_mid, after_mid) = samples.split_at_mut_from_origin(samples_before_mid_block);

    build_samples_inner::<W>(
        base_rank,
        inner_l1l2_index,
        low_block,
        mid_block,
        before_mid,
    );
    build_samples_inner::<W>(
        base_rank,
        inner_l1l2_index,
        mid_block,
        high_block,
        after_mid,
    );
}

/// Count the one bits in `bits`, read from the last L0 cumulative entry.
#[inline]
pub fn count_ones(index: &[u64], bits: BitsRef) -> u64 {
    if bits.len() == 0 {
        return 0;
    }
    let l0_size = size::l0(bits.len());
    debug_assert!(l0_size > 0);
    index[l0_size - 1]
}

/// Count the zero bits in `bits`.
#[inline]
pub fn count_zeros(index: &[u64], bits: BitsRef) -> u64 {
    ZeroBits::convert_count(count_ones(index, bits), bits.len())
}

fn read_l0_cumulative_count<W: OnesOrZeros>(l0_index: &[u64], bits: BitsRef, idx: usize) -> u64 {
    let count_ones = l0_index[idx];
    let total_count = if idx + 1 < l0_index.len() {
        (idx as u64 + 1) * size::BITS_PER_L0_BLOCK
    } else {
        bits.len()
    };
    W::convert_count(count_ones, total_count)
}

fn read_l0_rank<W: OnesOrZeros>(l0_index: &[u64], bits: BitsRef, idx: usize) -> u64 {
    if idx > 0 {
        read_l0_cumulative_count::<W>(l0_index, bits, idx - 1)
    } else {
        0
    }
}

pub fn rank_ones(index: &[u64], all_bits: BitsRef, idx: u64) -> Option<u64> {
    if idx >= all_bits.len() {
        return None;
    } else if idx == 0 {
        return Some(0);
    }

    let (l0_index, index_after_l0) = structure::split_l0(index, all_bits);

    // Rank up to the start of the containing L0 block.
    let l0_idx = idx / size::BITS_PER_L0_BLOCK;
    debug_assert!(l0_idx < l0_index.len() as u64);
    let l0_idx = l0_idx as usize;
    let l0_offset = idx % size::BITS_PER_L0_BLOCK;
    let l0_rank = read_l0_rank::<OneBits>(l0_index, all_bits, l0_idx);

    let (l1l2_index, _) = structure::split_l1l2(index_after_l0, all_bits);
    let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);

    // Rank from there to the start of the containing L2 block.
    let block_idx = l0_offset / size::BITS_PER_L2_BLOCK;
    debug_assert!(
        block_idx < (inner_l1l2_index.len() as u64) * size::L2_BLOCKS_PER_L1_BLOCK as u64
    );
    let block_idx = block_idx as usize;
    let block_offset = l0_offset % size::BITS_PER_L2_BLOCK;
    let block_rank = inner_l1l2_index.rank_of_block::<OneBits>(block_idx);

    // Scan the remaining partial block directly.
    let scan_skip_bytes = l0_idx * size::BYTES_PER_L0_BLOCK + block_idx * size::BYTES_PER_L2_BLOCK;
    let scan_bits = all_bits.drop_bytes(scan_skip_bytes);
    let scanned_rank = scan_bits
        .rank_ones(block_offset)
        .expect("Already checked size");
    Some(l0_rank + block_rank + scanned_rank)
}

/// Number of zero bits before `idx`, or `None` if `idx` is out of bounds.
#[inline]
pub fn rank_zeros(index: &[u64], bits: BitsRef, idx: u64) -> Option<u64> {
    rank_ones(index, bits, idx).map(|res_ones| ZeroBits::convert_count(res_ones, idx))
}

/// Find the first index in `from..until` for which `check` returns true,
/// assuming `check` is monotone (false then true); returns `until` if it is
/// never true. Falls back to a linear scan once the range is small.
fn binary_search<F>(from: usize, until: usize, check: F) -> usize
where
    F: Fn(usize) -> bool,
{
    const LINEAR_FOR_N: usize = 16;

    let mut false_up_to = from;
    let mut true_from = until;

    while false_up_to + LINEAR_FOR_N < true_from {
        let mid_ish = (false_up_to + true_from) / 2;
        if check(mid_ish) {
            true_from = mid_ish;
        } else {
            false_up_to = mid_ish + 1;
        }
    }

    while false_up_to < true_from && !check(false_up_to) {
        false_up_to += 1;
    }
    debug_assert!(false_up_to <= true_from);
    debug_assert!(false_up_to == true_from || check(false_up_to));

    false_up_to
}

fn select<W: OnesOrZeros>(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    if all_bits.len() == 0 {
        return None;
    }
    let (l0_index, index_after_l0) = structure::split_l0(index, all_bits);
    debug_assert!(l0_index.len() > 0);
    let total_count_ones = l0_index[l0_index.len() - 1];
    let total_count = W::convert_count(total_count_ones, all_bits.len());
    if target_rank >= total_count {
        return None;
    }

    // Find the L0 block containing the target rank.
    let l0_idx = binary_search(0, l0_index.len(), |idx| {
        read_l0_cumulative_count::<W>(l0_index, all_bits, idx) > target_rank
    });
    debug_assert!(l0_idx < l0_index.len());
    let next_l0_block_rank = read_l0_cumulative_count::<W>(l0_index, all_bits, l0_idx);
    debug_assert!(next_l0_block_rank > target_rank);
    let l0_block_rank = read_l0_rank::<W>(l0_index, all_bits, l0_idx);
    debug_assert!(l0_block_rank <= target_rank);
    let target_rank_in_l0_block = target_rank - l0_block_rank;

    let (l1l2_index, index_after_l1l2) = structure::split_l1l2(index_after_l0, all_bits);
    let inner_l1l2_index = l1l2_index.inner_index(all_bits, l0_idx);
    debug_assert!(inner_l1l2_index.len() > 0);
    let (select_ones_samples, select_zeros_samples) =
        structure::split_samples(index_after_l1l2, all_bits, total_count_ones);
    let select_samples = if W::is_ones() {
        select_ones_samples
    } else {
        select_zeros_samples
    };

    // Use the samples to bound the search for the containing L2 block.
    let sample_idx = target_rank / size::SAMPLE_LENGTH;
    let block_idx_should_be_at_least = {
        let sample_rank = sample_idx * size::SAMPLE_LENGTH;
        if sample_rank < l0_block_rank {
            // That sample lies in an earlier L0 block, so it gives no bound here.
            0
        } else {
            select_samples[sample_idx as usize].block_idx_in_l0_block()
        }
    };
    let block_idx_should_be_less_than = {
        let next_sample_idx = sample_idx + 1;
        let next_sample_rank = next_sample_idx * size::SAMPLE_LENGTH;
        if next_sample_rank >= next_l0_block_rank {
            // The next sample lies beyond this L0 block.
            inner_l1l2_index.len()
        } else if next_sample_idx >= select_samples.len() as u64 {
            // There is no next sample.
            inner_l1l2_index.len()
        } else {
            select_samples[next_sample_idx as usize].block_idx_in_l0_block() + 1
        }
    };

    let block_idx = {
        let following_block_idx = binary_search(
            block_idx_should_be_at_least,
            block_idx_should_be_less_than,
            |idx| inner_l1l2_index.rank_of_block::<W>(idx) > target_rank_in_l0_block,
        );
        debug_assert!(following_block_idx > 0);
        following_block_idx - 1
    };
    let block_rank = inner_l1l2_index.rank_of_block::<W>(block_idx);
    let target_rank_in_block = target_rank_in_l0_block - block_rank;

    // Scan within the block for the exact position.
    let scan_skip_bytes = l0_idx * size::BYTES_PER_L0_BLOCK + block_idx * size::BYTES_PER_L2_BLOCK;
    let scan_bits = all_bits.drop_bytes(scan_skip_bytes);
    let scanned_idx = scan_bits
        .select::<W>(target_rank_in_block)
        .expect("Already checked against total count");

    Some(scan_skip_bytes as u64 * 8 + scanned_idx)
}

/// Position of the `target_rank`-th (zero-based) one bit, if there is one.
pub fn select_ones(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    select::<OneBits>(index, all_bits, target_rank)
}

/// Position of the `target_rank`-th (zero-based) zero bit, if there is one.
pub fn select_zeros(index: &[u64], all_bits: BitsRef, target_rank: u64) -> Option<u64> {
    select::<ZeroBits>(index, all_bits, target_rank)
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::vec::Vec;

    #[test]
    fn select_bug_issue_15() {
        let mut data = vec![0xffu8; 8192 / 8 * 2];
        data[8192 / 8 - 1] = 0;
        let data = BitsRef::from_bytes(&data[..], 8192 * 2).unwrap();
        let mut index = vec![0u64; index_size_for(data)];
        build_index_for(data, &mut index).unwrap();
        let index = index;
        assert_eq!(select_ones(&index, data, 8191), Some(8199));
    }
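
    // Added illustration (not among the original tests): `L1L2Entry` packs a
    // 32-bit base rank plus three 10-bit L2 counts, and `base_rank`/`l2_count`
    // read the fields back out.
    #[test]
    fn l1l2_entry_pack_roundtrip() {
        let entry = L1L2Entry::pack(123_456, [7, 511, 300]);
        assert_eq!(entry.base_rank(), 123_456);
        assert_eq!(entry.l2_count(0), 7);
        assert_eq!(entry.l2_count(1), 511);
        assert_eq!(entry.l2_count(2), 300);
    }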

    #[test]
    fn small_indexed_tests() {
        use rand::{Rng, RngCore, SeedableRng};
        use rand_xorshift::XorShiftRng;
        let n_bits: u64 = (1 << 19) - 1;
        let n_bytes: usize = ceil_div_u64(n_bits, 8) as usize;
        let seed = [
            42, 73, 197, 231, 255, 43, 87, 5, 50, 13, 74, 107, 195, 231, 5, 1,
        ];
        let mut rng = XorShiftRng::from_seed(seed);
        let data = {
            let mut data = vec![0u8; n_bytes];
            rng.fill_bytes(&mut data);
            data
        };
        let data = BitsRef::from_bytes(&data[..], n_bits).expect("Should have enough bytes");
        let index = {
            let mut index = vec![0u64; index_size_for(data)];
            build_index_for(data, &mut index).unwrap();
            index
        };

        let expected_count_ones = data.count_ones();
        let expected_count_zeros = n_bits - expected_count_ones;
        assert_eq!(expected_count_ones, count_ones(&index, data));
        assert_eq!(expected_count_zeros, count_zeros(&index, data));

        assert_eq!(None, rank_ones(&index, data, n_bits));
        assert_eq!(None, rank_zeros(&index, data, n_bits));

        let rank_idxs = {
            let mut idxs: Vec<u64> = (0..1000).map(|_| rng.gen_range(0, n_bits)).collect();
            idxs.sort();
            idxs
        };
        for idx in rank_idxs {
            assert_eq!(data.rank_ones(idx), rank_ones(&index, data, idx));
            assert_eq!(data.rank_zeros(idx), rank_zeros(&index, data, idx));
        }

        assert_eq!(None, select_ones(&index, data, expected_count_ones));
        let one_ranks = {
            let mut ranks: Vec<u64> = (0..1000)
                .map(|_| rng.gen_range(0, expected_count_ones))
                .collect();
            ranks.sort();
            ranks
        };
        for rank in one_ranks {
            assert_eq!(data.select_ones(rank), select_ones(&index, data, rank));
        }

        assert_eq!(None, select_zeros(&index, data, expected_count_zeros));
        let zero_ranks = {
            let mut ranks: Vec<u64> = (0..1000)
                .map(|_| rng.gen_range(0, expected_count_zeros))
                .collect();
            ranks.sort();
            ranks
        };
        for rank in zero_ranks {
            assert_eq!(data.select_zeros(rank), select_zeros(&index, data, rank));
        }
    }
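
    // Added end-to-end example (not among the original tests): the public
    // entry points compose as `index_size_for` -> `build_index_for` ->
    // count/rank/select, shown here on a tiny all-ones bitvector.
    #[test]
    fn tiny_end_to_end_example() {
        let bytes = [0xffu8; 2];
        let data = BitsRef::from_bytes(&bytes[..], 16).unwrap();
        let mut index = vec![0u64; index_size_for(data)];
        build_index_for(data, &mut index).unwrap();
        assert_eq!(count_ones(&index, data), 16);
        assert_eq!(count_zeros(&index, data), 0);
        assert_eq!(rank_ones(&index, data, 10), Some(10));
        assert_eq!(select_ones(&index, data, 3), Some(3));
        assert_eq!(select_zeros(&index, data, 0), None);
    }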
}