use prefetch_index::prefetch_index;
use rayon::prelude::*;
use std::{
    iter::zip,
    mem::{size_of, MaybeUninit},
};

use super::{BasicBlock, BiRanker, SuperBlock};
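/// A two-level rank structure over a packed bit sequence.
///
/// The bits are grouped into basic blocks of `BB::B` bytes, and basic blocks
/// into super blocks of `SB::BLOCKS_PER_SUPERBLOCK`. A super block stores a
/// (possibly lossy) absolute rank for the start of each of its basic blocks;
/// each basic block stores the bits plus the small residual needed to recover
/// the exact rank.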
pub struct BiRank<
BB: BasicBlock = super::blocks::BinaryBlock16,
SB: SuperBlock<BB> = super::super_blocks::ShiftSB,
> {
basic_blocks: Vec<BB>,
super_blocks: Vec<SB>,
}
impl<BB: BasicBlock, SB: SuperBlock<BB>> BiRanker for BiRank<BB, SB> {
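    /// Builds the rank structure from a sequence of packed `u64` words.
    ///
    /// Construction makes two parallel passes (via rayon): one to count the
    /// ones per super block, and one to materialize the super blocks and
    /// basic blocks from the exclusive prefix sum of those counts.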
fn new_packed(seq: &[u64]) -> Self {
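        // Reinterpret the `u64` words as bytes. `u8` has alignment 1, so the
        // head and tail returned by `align_to` are always empty here.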
let (head, seq, tail) = unsafe { seq.align_to::<u8>() };
assert!(head.is_empty());
assert!(tail.is_empty());
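        // If the data ends exactly on a block (or super block) boundary,
        // append an all-zero sentinel so that a rank query at `pos == n`
        // still lands in a valid block.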
let add_block = seq.len() % BB::B == 0;
let add_superblock = seq.len() % (SB::BYTES_PER_SUPERBLOCK) == 0;
let n_blocks = seq.len().div_ceil(BB::B) + (add_block as usize);
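        // First pass: popcount each super block chunk in parallel.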
let mut sb_offsets: Vec<u64> = seq
.par_chunks(SB::BYTES_PER_SUPERBLOCK)
.map(|slice| slice.iter().map(|&b| b.count_ones() as u64).sum())
.collect();
if add_superblock {
sb_offsets.push(0);
}
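        // Turn the per-super-block counts into an exclusive prefix sum, so
        // `sb_offsets[i]` is the number of ones before super block `i`.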
{
let mut sum = 0;
for i in 0..sb_offsets.len() {
let cnt = sb_offsets[i];
sb_offsets[i] = sum;
sum += cnt;
}
}
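        // Allocate the basic blocks uninitialized; every slot is written
        // exactly once in the parallel pass (or the sentinel fix-up) below.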
let mut blocks = vec![];
blocks.resize_with(n_blocks, MaybeUninit::<BB>::uninit);
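        // Second pass: for each super block chunk, build the super block and
        // fill in its basic blocks, again in parallel. When a sentinel super
        // block is needed, `zip` truncates to the shortest input, so the
        // final lone block chunk is handled after this pass.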
let sb_chunks = seq.par_chunks(SB::BYTES_PER_SUPERBLOCK);
let mut super_blocks = sb_chunks
.zip(&sb_offsets)
.zip(blocks.par_chunks_mut(SB::BLOCKS_PER_SUPERBLOCK))
.map(|((sb_chunk, &sb_offset), blocks)| {
let sb = SB::new(sb_offset, sb_chunk);
let bb_chunks = sb_chunk.chunks(BB::B);
let num_chunks = bb_chunks.len();
let mut delta = 0u64;
for (i, (block, bb_chunk)) in zip(blocks.iter_mut(), bb_chunks).enumerate() {
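                    // The basic block stores the rank residual relative to
                    // the (possibly lossy) value the super block reports for
                    // block `i`; `wrapping_sub` here pairs with the
                    // `wrapping_add` in `rank_unchecked`.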
let remaining_delta = (sb_offset + delta).wrapping_sub(sb.get(i));
                    // The trailing chunk of a super block may be shorter than
                    // `BB::B` bytes; zero-pad it so every basic block sees a
                    // full-size slice.
                    let mut bb_chunk_buffer = vec![];
                    let bb_chunk = if bb_chunk.len() == BB::B {
                        bb_chunk
                    } else {
                        // `resize` already zero-fills the tail, so no
                        // separate `fill(0)` is needed afterwards.
                        bb_chunk_buffer.resize(BB::B, 0u8);
                        bb_chunk_buffer[..bb_chunk.len()].copy_from_slice(bb_chunk);
                        &bb_chunk_buffer
                    };
block.write(BB::new(remaining_delta, bb_chunk));
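                    // Accumulate the ones seen so far within this super
                    // block; the padding bytes are all zero, so counting the
                    // padded chunk is equivalent to counting the original.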
let count = bb_chunk.iter().map(|&b| b.count_ones() as u64).sum::<u64>();
delta += count;
}
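                // When the data ends exactly on a basic-block boundary inside
                // the last super block, one trailing sentinel block remains;
                // give it the correct residual and all-zero bits.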
if blocks.len() > num_chunks {
assert_eq!(blocks.len(), num_chunks + 1);
let i = num_chunks;
let remaining_delta = (sb_offset + delta).wrapping_sub(sb.get(i));
blocks[i].write(BB::new(remaining_delta, &vec![0u8; BB::B]));
}
sb
})
.collect::<Vec<_>>();
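        // If the data also ends on a super block boundary, the sentinel block
        // falls into a fresh sentinel super block that the zipped pass above
        // never visited; build both here.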
if add_superblock {
let sb_offset = *sb_offsets.last().unwrap();
super_blocks.push(SB::new(sb_offset, &[]));
let sb = super_blocks.last().unwrap();
let remaining_delta = sb_offset.wrapping_sub(sb.get(0));
blocks
.last_mut()
.unwrap()
.write(BB::new(remaining_delta, &vec![0u8; BB::B]));
}
Self {
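            // SAFETY: every `MaybeUninit<BB>` slot was initialized above,
            // and `MaybeUninit<BB>` has the same layout as `BB`.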
basic_blocks: unsafe { std::mem::transmute::<Vec<MaybeUninit<BB>>, Vec<BB>>(blocks) },
super_blocks,
}
}
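    /// Rank queries touch at most one basic block and one super block, both
    /// of which can be prefetched ahead of the query.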
const HAS_PREFETCH: bool = true;
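    /// Prefetches the basic block (and, when super blocks contribute to the
    /// rank, the super block) covering `pos`.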
#[inline(always)]
fn prefetch(&self, pos: usize) {
let block_idx = pos / BB::N;
prefetch_index(&self.basic_blocks, block_idx);
if BB::W < 64 {
let long_pos = block_idx / SB::BLOCKS_PER_SUPERBLOCK;
prefetch_index(&self.super_blocks, long_pos);
}
}
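    /// Heap size of the structure in bytes.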
fn size(&self) -> usize {
self.basic_blocks.len() * size_of::<BB>() + self.super_blocks.len() * size_of::<SB>()
}
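    /// Returns the rank at `pos`, i.e. the number of set bits in positions
    /// `[0, pos)`.
    ///
    /// # Safety
    /// `pos` must be at most the length of the indexed bit sequence, so that
    /// it maps to a valid basic block.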
#[inline(always)]
unsafe fn rank_unchecked(&self, mut pos: usize) -> u64 {
unsafe {
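            // Inclusive blocks count the bit at the queried position itself,
            // so shift the query to `pos - 1` and special-case `pos == 0`.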
if BB::INCLUSIVE {
if pos == 0 {
return 0;
}
pos -= 1;
}
let block_idx = pos / BB::N;
let block_pos = pos % BB::N;
debug_assert!(block_idx < self.basic_blocks.len());
let mut rank = self.basic_blocks.get_unchecked(block_idx).rank(block_pos);
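            // Blocks with counters narrower than 64 bits store only a
            // residual; add the super block's offset back in. `wrapping_add`
            // cancels the `wrapping_sub` from construction.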
if BB::W < 64 {
let long_pos = block_idx / SB::BLOCKS_PER_SUPERBLOCK;
let long_rank = self
.super_blocks
.get_unchecked(long_pos)
.get(block_idx % SB::BLOCKS_PER_SUPERBLOCK);
rank = rank.wrapping_add(long_rank);
}
rank
}
}
}
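
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sanity sketch, not part of the original API surface: it
    // assumes the default `BinaryBlock16`/`ShiftSB` parameters and that
    // `rank_unchecked(pos)` returns the number of set bits in positions
    // `[0, pos)` of the packed input.
    #[test]
    fn rank_matches_naive_popcount() {
        // Deterministic pseudo-random words via multiplicative hashing.
        let words: Vec<u64> = (0..256u64)
            .map(|i| (i + 1).wrapping_mul(0x9E37_79B9_7F4A_7C15))
            .collect();
        let ranker: BiRank = BiRank::new_packed(&words);
        let mut expected = 0u64;
        for (i, &w) in words.iter().enumerate() {
            for b in 0..64 {
                let pos = i * 64 + b;
                // SAFETY: `pos` is within the indexed bit sequence.
                assert_eq!(unsafe { ranker.rank_unchecked(pos) }, expected);
                expected += (w >> b) & 1;
            }
        }
        // Cross-check the running count against a direct popcount.
        assert_eq!(expected, words.iter().map(|w| w.count_ones() as u64).sum::<u64>());
    }
}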