use crate::utils::SelectInWord;
use ambassador::Delegate;
use mem_dbg::{MemDbg, MemSize};
use num_primitive::PrimitiveInteger;
use std::{
cmp::{max, min},
ops::Deref,
};
use crate::{
prelude::{BitCount, BitLength, Select, SelectHinted},
traits::{
Backend, NumBits, Rank, RankHinted, RankUnchecked, RankZero, SelectUnchecked, SelectZero,
SelectZeroHinted, SelectZeroUnchecked, Word,
},
};
use crate::ambassador_impl_Index;
use crate::traits::ambassador_impl_Backend;
use crate::traits::bal_paren::{BalParen, ambassador_impl_BalParen};
use crate::traits::bit_vec_ops::ambassador_impl_BitLength;
use crate::traits::rank_sel::ambassador_impl_BitCount;
use crate::traits::rank_sel::ambassador_impl_NumBits;
use crate::traits::rank_sel::ambassador_impl_Rank;
use crate::traits::rank_sel::ambassador_impl_RankHinted;
use crate::traits::rank_sel::ambassador_impl_RankUnchecked;
use crate::traits::rank_sel::ambassador_impl_RankZero;
use crate::traits::rank_sel::ambassador_impl_SelectHinted;
use crate::traits::rank_sel::ambassador_impl_SelectZero;
use crate::traits::rank_sel::ambassador_impl_SelectZeroHinted;
use crate::traits::rank_sel::ambassador_impl_SelectZeroUnchecked;
use std::ops::Index;
/// A select-on-ones adaptor around a bit-vector backend `B`.
///
/// All backend, rank, and select-zero operations are delegated to `bits`;
/// select-on-ones queries are answered using a two-level inventory
/// (`inventory`, plus an overflow `spill`) built from the bit content.
#[derive(Debug, Clone, MemSize, MemDbg, Delegate)]
#[cfg_attr(feature = "epserde", derive(epserde::Epserde))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[delegate(crate::traits::Backend, target = "bits")]
#[delegate(Index<usize>, target = "bits")]
#[delegate(crate::traits::rank_sel::BitCount, target = "bits")]
#[delegate(crate::traits::bit_vec_ops::BitLength, target = "bits")]
#[delegate(crate::traits::rank_sel::NumBits, target = "bits")]
#[delegate(crate::traits::rank_sel::Rank, target = "bits")]
#[delegate(crate::traits::rank_sel::RankHinted, target = "bits")]
#[delegate(crate::traits::rank_sel::RankUnchecked, target = "bits")]
#[delegate(crate::traits::rank_sel::RankZero, target = "bits")]
#[delegate(crate::traits::rank_sel::SelectHinted, target = "bits")]
#[delegate(crate::traits::rank_sel::SelectZero, target = "bits")]
#[delegate(crate::traits::rank_sel::SelectZeroHinted, target = "bits")]
#[delegate(crate::traits::rank_sel::SelectZeroUnchecked, target = "bits")]
#[delegate(crate::bal_paren::BalParen, target = "bits")]
pub struct SelectAdapt<B, I = Box<[usize]>> {
    /// The underlying bit vector.
    bits: B,
    /// First-level entries (tagged absolute positions) interleaved with
    /// subinventory words.
    inventory: I,
    /// Overflow area for subinventory data that does not fit inline.
    spill: I,
    /// log2 of the number of ones covered by each inventory entry.
    log2_ones_per_inventory: usize,
    /// log2 of the number of ones covered by each 16-bit subinventory slot.
    log2_ones_per_sub16: usize,
    /// log2 of the number of subinventory words following each inventory entry.
    log2_words_per_subinventory: usize,
    /// Mask extracting the rank remainder within an inventory entry.
    ones_per_inventory_mask: usize,
    /// Mask extracting the rank remainder within a 16-bit slot.
    ones_per_sub16_mask: usize,
}
impl<B: Backend + AsRef<[B::Word]>, I> AsRef<[B::Word]> for SelectAdapt<B, I> {
    /// Exposes the underlying backend as a slice of words.
    #[inline(always)]
    fn as_ref(&self) -> &[B::Word] {
        let backend = &self.bits;
        backend.as_ref()
    }
}
impl<B, I> Deref for SelectAdapt<B, I> {
    type Target = B;

    /// Dereferences to the wrapped backend, forwarding its inherent methods.
    #[inline(always)]
    fn deref(&self) -> &Self::Target {
        let Self { bits, .. } = self;
        bits
    }
}
/// Default base-2 logarithm of the number of subinventory words stored
/// after each inventory entry.
pub const DEFAULT_LOG2_WORDS_PER_SUBINVENTORY: usize = 3;

/// Returns the default target number of bits spanned by one inventory entry
/// for the given log2 number of subinventory words.
pub const fn default_target_inventory_span(log2_words_per_subinventory: usize) -> usize {
    let word_bits = usize::BITS as usize;
    (word_bits * word_bits / 4) << log2_words_per_subinventory
}
/// Maximum representable bit-vector length: the upper two bits (one on
/// 32-bit targets) of each inventory entry are reserved for the span-type tag.
#[cfg(target_pointer_width = "64")]
pub(super) const MAX_INVENTORY_BITS: usize = usize::MAX >> 2;
#[cfg(target_pointer_width = "32")]
pub(super) const MAX_INVENTORY_BITS: usize = usize::MAX >> 1;
/// log2 of the number of `u16` values that fit in one `usize` word.
pub(super) const LOG2_U16_PER_USIZE: usize = (usize::BITS / 16).ilog2() as usize;
/// Number of `u32` values that fit in one `usize` word.
pub(super) const U32_PER_USIZE: usize = (usize::BITS / 32) as usize;
/// Panics if a bit-vector length cannot be represented inside an inventory
/// entry (its upper bits are reserved for the span-type tag).
#[inline]
pub(super) fn assert_inventory_length(len: usize) {
    if len > MAX_INVENTORY_BITS {
        panic!(
            "Bit vector length ({len}) exceeds the maximum representable \
        inventory value ({MAX_INVENTORY_BITS})"
        );
    }
}
/// Internal accessor trait for inventory entries: each entry packs an
/// absolute bit position together with a span-type tag in its upper bits.
pub(super) trait Inventory {
    /// Returns true if this entry heads a span whose offsets fit in 16 bits.
    fn is_u16_span(&self) -> bool;
    /// Returns true if this entry heads a span whose offsets fit in 32 bits.
    fn is_u32_span(&self) -> bool;
    /// Returns true if this entry heads a span requiring full positions.
    #[cfg(target_pointer_width = "64")]
    fn is_u64_span(&self) -> bool;
    /// Tags this entry as heading a 16-bit span (the tag is zero).
    fn set_u16_span(&mut self);
    /// Tags this entry as heading a 32-bit span.
    fn set_u32_span(&mut self);
    /// Tags this entry as heading a 64-bit span.
    #[cfg(target_pointer_width = "64")]
    fn set_u64_span(&mut self);
    /// Returns the stored bit position with the tag bits masked off.
    fn get(&self) -> usize;
}
// On 64-bit targets the two upper bits of an entry form the tag:
// 0x = u16 span, 10 = u32 span, 11 = u64 span. On 32-bit targets a single
// upper bit distinguishes u16 (0) from u32 (1) spans and u64 spans do not
// exist. `MAX_INVENTORY_BITS` guarantees stored positions never reach the
// tag bits, so `set_u32_span` only needs to set the topmost bit.
impl Inventory for usize {
    #[inline(always)]
    fn is_u16_span(&self) -> bool {
        // Tag 0x: topmost bit clear.
        *self >> (usize::BITS - 1) == 0
    }
    #[cfg(target_pointer_width = "64")]
    #[inline(always)]
    fn is_u32_span(&self) -> bool {
        // Tag 10.
        *self >> (usize::BITS - 2) == 2
    }
    #[cfg(target_pointer_width = "32")]
    #[inline(always)]
    fn is_u32_span(&self) -> bool {
        *self >> (usize::BITS - 1) == 1
    }
    #[cfg(target_pointer_width = "64")]
    #[inline(always)]
    fn is_u64_span(&self) -> bool {
        // Tag 11.
        *self >> (usize::BITS - 2) == 3
    }
    // The u16 tag is all-zeros, so tagging is a no-op.
    #[inline(always)]
    fn set_u16_span(&mut self) {}
    #[inline(always)]
    fn set_u32_span(&mut self) {
        *self |= 1 << (usize::BITS - 1);
    }
    #[cfg(target_pointer_width = "64")]
    #[inline(always)]
    fn set_u64_span(&mut self) {
        *self |= 3 << (usize::BITS - 2);
    }
    #[cfg(target_pointer_width = "64")]
    #[inline(always)]
    fn get(&self) -> usize {
        *self & (usize::MAX >> 2)
    }
    #[cfg(target_pointer_width = "32")]
    #[inline(always)]
    fn get(&self) -> usize {
        *self & (usize::MAX >> 1)
    }
}
/// Width class of the offsets needed to address positions within an
/// inventory entry's span.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum SpanType {
    /// Offsets fit in 16 bits (span up to 2^16 bits).
    U16,
    /// Offsets fit in 32 bits.
    U32,
    /// Full 64-bit positions are stored (only possible on 64-bit targets).
    #[cfg(target_pointer_width = "64")]
    U64,
}
impl SpanType {
    /// Classifies a span (in bits) by the smallest offset width able to
    /// address positions within it.
    ///
    /// A span of exactly 2^16 still fits 16-bit offsets because offsets are
    /// strictly smaller than the span; analogously for 2^32 and 32-bit
    /// offsets on 64-bit targets.
    pub fn from_span(x: usize) -> SpanType {
        match x {
            0..=0x10000 => SpanType::U16,
            #[cfg(not(target_pointer_width = "64"))]
            _ => SpanType::U32,
            #[cfg(target_pointer_width = "64")]
            0x10001..=0x100000000 => SpanType::U32,
            #[cfg(target_pointer_width = "64")]
            _ => SpanType::U64,
        }
    }
}
impl<B, I> SelectAdapt<B, I> {
    /// Consumes the structure and returns the underlying bit vector.
    pub fn into_inner(self) -> B {
        self.bits
    }
    /// Returns the log2 granularity (ones per 32-bit slot) used for a span
    /// larger than 2^16: the granularity shrinks as the span grows, so the
    /// number of 32-bit slots scales with the span.
    #[inline(always)]
    const fn log2_ones_per_sub32(span: usize, log2_ones_per_sub16: usize) -> usize {
        debug_assert!(span > 1 << 16);
        log2_ones_per_sub16.saturating_sub((span >> 15).ilog2() as usize + 1)
    }
    /// Replaces the backend with `f(bits)`, keeping all precomputed data.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `f` must return a backend with identical bit
    /// content, or the precomputed inventory/spill become inconsistent —
    /// confirm against the crate-level contract for `map` methods.
    pub unsafe fn map<C: SelectHinted>(self, f: impl FnOnce(B) -> C) -> SelectAdapt<C, I> {
        SelectAdapt {
            bits: f(self.bits),
            inventory: self.inventory,
            spill: self.spill,
            log2_ones_per_inventory: self.log2_ones_per_inventory,
            log2_ones_per_sub16: self.log2_ones_per_sub16,
            log2_words_per_subinventory: self.log2_words_per_subinventory,
            ones_per_inventory_mask: self.ones_per_inventory_mask,
            ones_per_sub16_mask: self.ones_per_sub16_mask,
        }
    }
}
impl<B: BitLength, C> SelectAdapt<B, C> {
    /// Returns the length in bits of the underlying bit vector.
    #[inline(always)]
    pub fn len(&self) -> usize {
        BitLength::len(self)
    }

    /// Returns true if the underlying bit vector contains no bits.
    ///
    /// Companion to [`Self::len`] (Clippy `len_without_is_empty`).
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<B: Backend<Word: Word + SelectInWord> + AsRef<[B::Word]> + BitCount>
    SelectAdapt<B, Box<[usize]>>
{
    /// Creates a new select structure with the default target inventory span
    /// and subinventory size.
    #[must_use]
    pub fn new(bits: B) -> Self {
        Self::with_span(
            bits,
            default_target_inventory_span(DEFAULT_LOG2_WORDS_PER_SUBINVENTORY),
            DEFAULT_LOG2_WORDS_PER_SUBINVENTORY,
        )
    }
    /// Creates a new select structure aiming at a given number of bits
    /// spanned by each inventory entry.
    ///
    /// # Panics
    ///
    /// If the bit-vector length exceeds `MAX_INVENTORY_BITS`.
    #[must_use]
    pub fn with_span(
        bits: B,
        target_inventory_span: usize,
        max_log2_words_per_subinventory: usize,
    ) -> Self {
        assert_inventory_length(bits.len());
        let num_bits = max(1usize, bits.len());
        let num_ones = bits.count_ones();
        // Derive the per-entry log2 number of ones from the target span;
        // 128-bit arithmetic avoids overflow of the product.
        let log2_ones_per_inventory = (num_ones as u128 * target_inventory_span as u128)
            .div_ceil(num_bits as u128)
            .max(1)
            .ilog2() as usize;
        Self::_new(
            bits,
            num_ones,
            log2_ones_per_inventory,
            max_log2_words_per_subinventory,
        )
    }
    /// Creates a new select structure with an explicit log2 number of ones
    /// per inventory entry.
    ///
    /// # Panics
    ///
    /// If the bit-vector length exceeds `MAX_INVENTORY_BITS`.
    pub fn with_inv(
        bits: B,
        log2_ones_per_inventory: usize,
        max_log2_words_per_subinventory: usize,
    ) -> Self {
        assert_inventory_length(bits.len());
        let num_ones = bits.count_ones();
        Self::_new(
            bits,
            num_ones,
            log2_ones_per_inventory,
            max_log2_words_per_subinventory,
        )
    }
    /// Creates a new select structure targeting a given space overhead,
    /// expressed as a percentage of the underlying bit vector's size.
    ///
    /// # Panics
    ///
    /// If `overhead_percentage` is not positive.
    pub fn with_overhead(
        bits: B,
        overhead_percentage: f64,
        max_log2_words_per_subinv: usize,
    ) -> Self {
        assert!(
            overhead_percentage > 0.0,
            "overhead_percentage must be positive"
        );
        let m = 1usize << max_log2_words_per_subinv;
        // Span large enough that the (1 + m) words stored per entry stay
        // within the requested overhead, clamped to a minimum span.
        let target_span =
            ((1 + m) as f64 * usize::BITS as f64 * 100.0 / overhead_percentage) as usize;
        let min_span = m * (usize::BITS as usize * usize::BITS as usize) / 16;
        Self::with_span(bits, target_span.max(min_span), max_log2_words_per_subinv)
    }
    /// Builds the inventory and spill structures.
    ///
    /// Layout of `inventory`: `inventory_size` groups of `words_per_inventory`
    /// words each — one (tagged) absolute bit position followed by
    /// `words_per_subinventory` subinventory words — plus one trailing
    /// sentinel word holding the total number of bits, so that each group's
    /// span can be computed from the next group's start.
    fn _new(
        bits: B,
        num_ones: usize,
        log2_ones_per_inventory: usize,
        max_log2_words_per_subinventory: usize,
    ) -> Self {
        assert_inventory_length(bits.len());
        let num_bits = max(1, bits.len());
        let ones_per_inventory = 1 << log2_ones_per_inventory;
        let ones_per_inventory_mask = ones_per_inventory - 1;
        let inventory_size = num_ones.div_ceil(ones_per_inventory);
        // Cap the subinventory so it never exceeds a quarter of the ones
        // covered by an entry.
        let log2_words_per_subinventory =
            max_log2_words_per_subinventory.min(log2_ones_per_inventory.saturating_sub(2));
        let words_per_subinventory = 1 << log2_words_per_subinventory;
        let words_per_inventory = words_per_subinventory + 1;
        let log2_ones_per_sub16 = log2_ones_per_inventory
            .saturating_sub(log2_words_per_subinventory + LOG2_U16_PER_USIZE);
        let ones_per_sub16 = 1 << log2_ones_per_sub16;
        let ones_per_sub16_mask = ones_per_sub16 - 1;
        let inventory_words = inventory_size * words_per_inventory + 1;
        let mut inventory: Vec<usize> = Vec::with_capacity(inventory_words);
        let mut past_ones = 0;
        let mut next_quantum = 0;
        let mut spilled = 0;
        let bits_per_word = B::Word::BITS as usize;
        // First pass: record the absolute position of every
        // `ones_per_inventory`-th one, each followed by a zeroed subinventory.
        for (i, word) in bits.as_ref().iter().copied().enumerate() {
            let ones_in_word = word.count_ones() as usize;
            while past_ones + ones_in_word > next_quantum {
                let in_word_index = word.select_in_word(next_quantum - past_ones);
                let index = (i * bits_per_word) + in_word_index;
                inventory.push(index);
                inventory.resize(inventory.len() + words_per_subinventory, 0);
                next_quantum += ones_per_inventory;
            }
            past_ones += ones_in_word;
        }
        assert_eq!(past_ones, num_ones);
        // Sentinel: the total number of bits closes the last span.
        inventory.push(num_bits);
        assert_eq!(inventory.len(), inventory_words);
        // Second pass: size the spill buffer by classifying each entry's span.
        for (i, inv) in inventory[..inventory_size * words_per_inventory]
            .iter()
            .copied()
            .step_by(words_per_inventory)
            .enumerate()
        {
            let start = inv;
            let span = inventory[i * words_per_inventory + words_per_inventory] - start;
            past_ones = i * ones_per_inventory;
            let ones = min(num_ones - past_ones, ones_per_inventory);
            // Only the last entry may cover fewer than `ones_per_inventory` ones.
            debug_assert!(start + span == num_bits || ones == ones_per_inventory);
            match SpanType::from_span(span) {
                SpanType::U32 => {
                    // 32-bit offset words that do not fit in the inline
                    // subinventory (first word holds the spill start) spill.
                    let log2_ones_per_sub32 = Self::log2_ones_per_sub32(span, log2_ones_per_sub16);
                    let num_u32s = ones.div_ceil(1 << log2_ones_per_sub32);
                    let num_words = num_u32s.div_ceil(U32_PER_USIZE);
                    let spilled_u64s = num_words.saturating_sub(words_per_subinventory - 1);
                    spilled += spilled_u64s;
                }
                #[cfg(target_pointer_width = "64")]
                SpanType::U64 => {
                    // One full position per one; the first is the entry itself
                    // and `words_per_subinventory - 1` more fit inline.
                    spilled += (ones - 1).saturating_sub(words_per_subinventory - 1);
                }
                _ => {}
            }
        }
        let spill_size = spilled;
        let mut inventory: Box<[usize]> = inventory.into();
        let mut spill: Box<[usize]> = vec![0; spill_size].into();
        spilled = 0;
        let locally_stored_u32s = U32_PER_USIZE * (words_per_subinventory - 1);
        // Third pass: tag each entry and fill its subinventory (and spill).
        for inventory_idx in 0..inventory_size {
            let start_inv_idx = inventory_idx * words_per_inventory;
            let end_inv_idx = start_inv_idx + words_per_inventory;
            let start_bit_idx = inventory[start_inv_idx];
            let end_bit_idx = inventory[end_inv_idx];
            let span = end_bit_idx - start_bit_idx;
            let span_type = SpanType::from_span(span);
            let mut past_ones = inventory_idx * ones_per_inventory;
            let mut next_quantum = past_ones;
            let log2_quantum;
            match span_type {
                SpanType::U16 => {
                    log2_quantum = log2_ones_per_sub16;
                    inventory[start_inv_idx].set_u16_span();
                }
                SpanType::U32 => {
                    log2_quantum = Self::log2_ones_per_sub32(span, log2_ones_per_sub16);
                    inventory[start_inv_idx].set_u32_span();
                    // The first subinventory word stores the spill start.
                    inventory[start_inv_idx + 1] = spilled;
                }
                #[cfg(target_pointer_width = "64")]
                SpanType::U64 => {
                    log2_quantum = 0;
                    inventory[start_inv_idx].set_u64_span();
                    inventory[start_inv_idx + 1] = spilled;
                }
            }
            let quantum = 1 << log2_quantum;
            // Slot 0 is implicit: it is the inventory entry itself.
            let mut subinventory_idx = 1;
            next_quantum += quantum;
            let mut word_idx = start_bit_idx / bits_per_word;
            let end_word_idx = end_bit_idx.div_ceil(bits_per_word);
            let bit_idx = start_bit_idx % bits_per_word;
            // Clear the bits preceding `start_bit_idx` in the first word.
            let mut word = (bits.as_ref()[word_idx] >> bit_idx) << bit_idx;
            'outer: loop {
                let ones_in_word = word.count_ones() as usize;
                while past_ones + ones_in_word > next_quantum {
                    debug_assert!(next_quantum <= end_bit_idx);
                    let in_word_index = word.select_in_word(next_quantum - past_ones);
                    let bit_index = (word_idx * bits_per_word) + in_word_index;
                    if bit_index >= end_bit_idx {
                        // First one of the next entry: this entry is done.
                        break 'outer;
                    }
                    let sub_offset = bit_index - start_bit_idx;
                    match span_type {
                        SpanType::U16 => {
                            // Reinterpret the subinventory words as u16 slots;
                            // read back via the same `align_to` in
                            // `select_unchecked`.
                            let subinventory: &mut [u16] = unsafe {
                                inventory[start_inv_idx + 1..end_inv_idx].align_to_mut().1
                            };
                            subinventory[subinventory_idx] = sub_offset as u16;
                            subinventory_idx += 1;
                            if subinventory_idx << log2_quantum == ones_per_inventory {
                                break 'outer;
                            }
                        }
                        SpanType::U32 => {
                            if subinventory_idx < locally_stored_u32s {
                                // Inline u32 slots start after the spill
                                // pointer at `start_inv_idx + 1`.
                                let subinventory: &mut [u32] = unsafe {
                                    inventory[start_inv_idx + 2..end_inv_idx].align_to_mut().1
                                };
                                debug_assert_eq!(subinventory[subinventory_idx], 0);
                                subinventory[subinventory_idx] = sub_offset as u32;
                            } else {
                                let u32_spill: &mut [u32] =
                                    unsafe { spill[spilled..].align_to_mut().1 };
                                debug_assert_eq!(
                                    u32_spill[subinventory_idx - locally_stored_u32s],
                                    0
                                );
                                u32_spill[subinventory_idx - locally_stored_u32s] =
                                    sub_offset as u32;
                            }
                            subinventory_idx += 1;
                            if subinventory_idx << log2_quantum == ones_per_inventory {
                                break 'outer;
                            }
                        }
                        #[cfg(target_pointer_width = "64")]
                        SpanType::U64 => {
                            if subinventory_idx < words_per_subinventory {
                                inventory[start_inv_idx + 1 + subinventory_idx] = bit_index;
                                subinventory_idx += 1;
                            } else {
                                // Inline slots exhausted: full positions go
                                // to the spill buffer sequentially.
                                assert!(spilled < spill_size);
                                spill[spilled] = bit_index;
                                spilled += 1;
                            }
                            if subinventory_idx == ones_per_inventory {
                                break 'outer;
                            }
                        }
                    }
                    next_quantum += quantum;
                }
                past_ones += ones_in_word;
                word_idx += 1;
                if word_idx == end_word_idx {
                    break;
                }
                word = bits.as_ref()[word_idx];
            }
            if span_type == SpanType::U32 {
                // Account for the spill words consumed by this entry's u32s
                // (the second pass sized the buffer the same way).
                spilled += subinventory_idx
                    .saturating_sub(locally_stored_u32s)
                    .div_ceil(U32_PER_USIZE);
            }
        }
        assert_eq!(spilled, spill_size);
        Self {
            bits,
            inventory,
            spill,
            log2_ones_per_inventory,
            log2_ones_per_sub16,
            log2_words_per_subinventory,
            ones_per_inventory_mask,
            ones_per_sub16_mask,
        }
    }
}
impl<
    B: Backend<Word: Word + SelectInWord> + AsRef<[B::Word]> + BitLength + SelectHinted,
    I: AsRef<[usize]>,
> SelectUnchecked for SelectAdapt<B, I>
{
    /// Returns the position of the one of given rank.
    ///
    /// NOTE(review): the caller must guarantee `rank` is smaller than the
    /// number of ones — confirm against the `SelectUnchecked` contract.
    unsafe fn select_unchecked(&self, rank: usize) -> usize {
        unsafe {
            let inventory = self.inventory.as_ref();
            let inventory_index = rank >> self.log2_ones_per_inventory;
            // Entry start: inventory_index * (words_per_subinventory + 1).
            let inventory_start_pos =
                (inventory_index << self.log2_words_per_subinventory) + inventory_index;
            let inventory_rank = { *inventory.get_unchecked(inventory_start_pos) };
            let subrank = rank & self.ones_per_inventory_mask;
            if inventory_rank.is_u16_span() {
                // u16 tag is zero, so the entry can be used untagged; the
                // subinventory stores 16-bit offsets from the entry position.
                let subinventory = inventory
                    .get_unchecked(inventory_start_pos + 1..)
                    .align_to::<u16>()
                    .1;
                debug_assert!(subrank >> self.log2_ones_per_sub16 < subinventory.len());
                let hint_pos = inventory_rank
                    + *subinventory.get_unchecked(subrank >> self.log2_ones_per_sub16) as usize;
                let residual = subrank & self.ones_per_sub16_mask;
                return self
                    .bits
                    .select_hinted::<{ usize::MAX }>(rank, hint_pos, rank - residual);
            }
            let words_per_subinventory = 1 << self.log2_words_per_subinventory;
            if inventory_rank.is_u32_span() {
                let inventory_rank = inventory_rank.get();
                // Recompute the span from the next entry to rederive the slot
                // granularity used at construction time.
                let span = (*inventory
                    .get_unchecked(inventory_start_pos + words_per_subinventory + 1))
                .get()
                    - inventory_rank;
                let log2_ones_per_sub32 = Self::log2_ones_per_sub32(span, self.log2_ones_per_sub16);
                let hint_pos = if subrank >> log2_ones_per_sub32
                    < (words_per_subinventory - 1) * U32_PER_USIZE
                {
                    // Offset stored inline, after the spill-start word.
                    let u32s = inventory
                        .get_unchecked(inventory_start_pos + 2..)
                        .align_to::<u32>()
                        .1;
                    inventory_rank + *u32s.get_unchecked(subrank >> log2_ones_per_sub32) as usize
                } else {
                    // Offset stored in the spill buffer, starting at the index
                    // kept in the first subinventory word.
                    let start_spill_idx = *inventory.get_unchecked(inventory_start_pos + 1);
                    let spilled_u32s = self
                        .spill
                        .as_ref()
                        .get_unchecked(start_spill_idx..)
                        .align_to::<u32>()
                        .1;
                    inventory_rank
                        + *spilled_u32s.get_unchecked(
                            (subrank >> log2_ones_per_sub32)
                                - (words_per_subinventory - 1) * U32_PER_USIZE,
                        ) as usize
                };
                let residual = subrank & ((1 << log2_ones_per_sub32) - 1);
                return self
                    .bits
                    .select_hinted::<{ usize::MAX }>(rank, hint_pos, rank - residual);
            }
            #[cfg(target_pointer_width = "64")]
            debug_assert!(inventory_rank.is_u64_span());
            let inventory_rank = inventory_rank.get();
            // u64 spans store exact positions: no scan of the bits is needed.
            if subrank < words_per_subinventory {
                if subrank == 0 {
                    // Slot 0 is the entry position itself.
                    return inventory_rank;
                }
                return *inventory.get_unchecked(inventory_start_pos + 1 + subrank);
            }
            let spill_idx = { *inventory.get_unchecked(inventory_start_pos + 1) } + subrank
                - words_per_subinventory;
            debug_assert!(spill_idx < self.spill.as_ref().len());
            *self.spill.as_ref().get_unchecked(spill_idx)
        }
    }
}
/// Marker implementation: the checked `select` methods are provided by the
/// `Select` trait on top of `SelectUnchecked` and `NumBits`.
impl<
    B: Backend<Word: Word + SelectInWord> + SelectHinted + AsRef<[B::Word]> + NumBits,
    I: AsRef<[usize]>,
> Select for SelectAdapt<B, I>
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bits::BitVec;

    /// Building over a vector longer than `MAX_INVENTORY_BITS` must panic,
    /// since the upper bits of inventory entries are reserved for tags.
    #[test]
    #[should_panic(expected = "exceeds the maximum representable")]
    fn test_max_length_panic() {
        // Fake an over-long bit vector without allocating it: only the
        // declared length matters for the assertion under test.
        let too_long = MAX_INVENTORY_BITS + 1;
        let bits = unsafe { BitVec::from_raw_parts(vec![0usize; 1], too_long) };
        let _select = SelectAdapt::new(bits);
    }
}
#[cfg(test)]
#[cfg(target_pointer_width = "64")]
mod tests_64 {
    use std::collections::BTreeSet;

    use super::*;
    use crate::bits::BitVec;
    use crate::traits::AddNumBits;
    use crate::traits::BitVecOpsMut;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    /// Exercises the u64-span code path: a very long (5 Gbit), very sparse
    /// vector forces inventory spans beyond 2^32 bits.
    #[test]
    fn test_sub64s() {
        let len = 5_000_000_000;
        let mut rng = SmallRng::seed_from_u64(0);
        let mut bits = BitVec::new(len);
        // Keep the set of positions sorted and deduplicated for checking.
        let mut pos = BTreeSet::new();
        for _ in 0..(1 << 13) / 4 * 3 {
            let p = rng.random_range(0..len);
            if pos.insert(p) {
                bits.set(p, true);
            }
        }
        let bits: AddNumBits<BitVec> = bits.into();
        // Try subinventories of different sizes, including none (m = 0).
        for m in [0, 3, 16] {
            let simple = SelectAdapt::with_inv(&bits, 13, m);
            assert!(simple.inventory[0].is_u64_span());
            for (i, &p) in pos.iter().enumerate() {
                assert_eq!(simple.select(i), Some(p));
            }
            // One past the last one must be out of range.
            assert_eq!(simple.select(pos.len()), None);
        }
    }
}