//! quadrank 0.2.0
//!
//! Fast rank over binary and size-4 DNA alphabets.
use std::{iter::zip, mem::MaybeUninit};

use prefetch_index::prefetch_index;

use super::{BasicBlock, BiRanker, SuperBlock};
use rayon::prelude::*;

/// Rank queries over binary alphabet.
///
/// Supports various block and superblock implementations.
/// Has 3.28% space overhead by default.
pub struct BiRank<
    BB: BasicBlock = super::blocks::BinaryBlock16,
    SB: SuperBlock<BB> = super::super_blocks::ShiftSB,
> {
    // One entry per basic block: rank information for a short span of the
    // sequence, stored relative to the enclosing superblock (see
    // `new_packed`, which subtracts the superblock contribution).
    basic_blocks: Vec<BB>,
    // One entry per superblock: absolute one-count offsets that anchor the
    // relative counts stored in the basic blocks.
    super_blocks: Vec<SB>,
}

impl<BB: BasicBlock, SB: SuperBlock<BB>> BiRanker for BiRank<BB, SB> {
    /// Builds the rank structure from a bit-packed sequence of `u64` words.
    ///
    /// The words are reinterpreted as bytes and processed in two parallel
    /// passes: first, ones are counted per superblock and prefix-summed into
    /// absolute offsets; then each superblock's basic blocks are filled in
    /// with counts relative to that superblock.
    fn new_packed(seq: &[u64]) -> Self {
        // SAFETY: `u8` has alignment 1, so viewing `&[u64]` as bytes never
        // leaves an unaligned head or tail; the asserts document this.
        let (head, seq, tail) = unsafe { seq.align_to::<u8>() };
        assert!(head.is_empty());
        assert!(tail.is_empty());

        // When the input length is an exact multiple of the (super)block
        // size, one extra trailing (super)block is appended so that a rank
        // query at position exactly `8 * seq.len()` still hits a valid block.
        let add_block = seq.len() % BB::B == 0;
        let add_superblock = seq.len() % (SB::BYTES_PER_SUPERBLOCK) == 0;

        let n_blocks = seq.len().div_ceil(BB::B) + (add_block as usize);

        // 1. Count ones in each superblock (in parallel).
        let mut sb_offsets: Vec<u64> = seq
            .par_chunks(SB::BYTES_PER_SUPERBLOCK)
            .map(|slice| slice.iter().map(|&b| b.count_ones() as u64).sum())
            .collect();

        if add_superblock {
            sb_offsets.push(0);
        }

        // 2. Accumulate to get superblock offsets: an exclusive prefix sum,
        //    so `sb_offsets[i]` becomes the number of ones strictly before
        //    superblock `i`.
        {
            let mut sum = 0;
            for i in 0..sb_offsets.len() {
                let cnt = sb_offsets[i];
                sb_offsets[i] = sum;
                sum += cnt;
            }
        }

        // 3. Allocate space for blocks. Slots start uninitialized; every slot
        //    is written exactly once (either in the parallel loop below or in
        //    the `add_superblock` fixup) before the final transmute.
        let mut blocks = vec![];
        blocks.resize_with(n_blocks, MaybeUninit::<BB>::uninit);

        // 4. Build each superblock and its basic blocks in parallel. `zip`
        // truncates to the shorter side: when `add_superblock` holds, the
        // trailing blocks chunk containing the single extra block is skipped
        // here and initialized in the fixup below instead.
        let sb_chunks = seq.par_chunks(SB::BYTES_PER_SUPERBLOCK);
        let mut super_blocks = sb_chunks
            .zip(&sb_offsets)
            .zip(blocks.par_chunks_mut(SB::BLOCKS_PER_SUPERBLOCK))
            .map(|((sb_chunk, &sb_offset), blocks)| {
                let sb = SB::new(sb_offset, sb_chunk);

                let bb_chunks = sb_chunk.chunks(BB::B);
                let num_chunks = bb_chunks.len();
                // Number of ones seen so far within this superblock.
                let mut delta = 0u64;

                for (i, (block, bb_chunk)) in zip(blocks.iter_mut(), bb_chunks).enumerate() {
                    // The block stores its count relative to what the
                    // superblock already accounts for at block index `i`.
                    // This must be wrapping since `get_for_block` can return negative values.
                    let remaining_delta = (sb_offset + delta).wrapping_sub(sb.get(i));

                    // `BB::new` needs exactly `BB::B` input bytes; zero-pad a
                    // short trailing chunk. The buffer is only allocated on
                    // that (at most one) partial chunk.
                    let mut bb_chunk_buffer = vec![];
                    let bb_chunk = if bb_chunk.len() == BB::B {
                        bb_chunk
                    } else {
                        bb_chunk_buffer.resize(BB::B, 0u8);
                        bb_chunk_buffer[..bb_chunk.len()].copy_from_slice(bb_chunk);
                        bb_chunk_buffer[bb_chunk.len()..].fill(0);
                        &bb_chunk_buffer
                    };

                    block.write(BB::new(remaining_delta, bb_chunk));

                    let count = bb_chunk.iter().map(|&b| b.count_ones() as u64).sum::<u64>();
                    delta += count;
                }

                // If this is the last (incomplete) superblock, and it spans
                // exactly a full number of blocks, add one extra block.
                if blocks.len() > num_chunks {
                    assert_eq!(blocks.len(), num_chunks + 1);
                    let i = num_chunks;
                    let remaining_delta = (sb_offset + delta).wrapping_sub(sb.get(i));
                    blocks[i].write(BB::new(remaining_delta, &vec![0u8; BB::B]));
                }

                sb
            })
            .collect::<Vec<_>>();

        // Handle edge case where we need to push an additional superblock and block to handle queries exactly at the end.
        if add_superblock {
            // The last prefix-sum entry is the total number of ones.
            let sb_offset = *sb_offsets.last().unwrap();
            super_blocks.push(SB::new(sb_offset, &[]));
            let sb = super_blocks.last().unwrap();
            let remaining_delta = sb_offset.wrapping_sub(sb.get(0));
            blocks
                .last_mut()
                .unwrap()
                .write(BB::new(remaining_delta, &vec![0u8; BB::B]));
        }

        Self {
            // SAFETY: all `n_blocks` slots were written above, so every
            // `MaybeUninit<BB>` is initialized.
            // NOTE(review): transmuting a whole `Vec` assumes
            // `Vec<MaybeUninit<BB>>` and `Vec<BB>` share a layout, which the
            // std docs do not formally guarantee; rebuilding via
            // `Vec::from_raw_parts` would avoid that assumption — confirm.
            basic_blocks: unsafe { std::mem::transmute::<Vec<MaybeUninit<BB>>, Vec<BB>>(blocks) },
            super_blocks,
        }
    }

    // This implementation supplies a meaningful `prefetch` hint.
    const HAS_PREFETCH: bool = true;

    /// Prefetches the cache lines that `rank_unchecked(pos)` will touch.
    #[inline(always)]
    fn prefetch(&self, pos: usize) {
        // `BB::N` is the number of bit positions covered by one basic block
        // (cf. `BB::B`, its size in bytes).
        let block_idx = pos / BB::N;
        prefetch_index(&self.basic_blocks, block_idx);
        // Prefetch superblocks if they (potentially) do not fit in L1.
        // (When `BB::W == 64` the superblock lookup is skipped entirely in
        // `rank_unchecked`, so there is nothing to prefetch.)
        if BB::W < 64 {
            let long_pos = block_idx / SB::BLOCKS_PER_SUPERBLOCK;
            prefetch_index(&self.super_blocks, long_pos);
        }
    }

    /// Size of the rank structure's heap storage, in bytes.
    fn size(&self) -> usize {
        self.basic_blocks.len() * size_of::<BB>() + self.super_blocks.len() * size_of::<SB>()
    }

    /// Rank query: the count of ones before position `pos`.
    ///
    /// # Safety
    /// `pos` must lie within the range covered by the blocks built in
    /// `new_packed`; the block index derived from it is used with
    /// `get_unchecked` (checked only via `debug_assert!`).
    #[inline(always)]
    unsafe fn rank_unchecked(&self, mut pos: usize) -> u64 {
        unsafe {
            // Blocks with inclusive semantics answer rank over `[0, pos]`,
            // so shift the query down by one; rank at position 0 is 0.
            if BB::INCLUSIVE {
                if pos == 0 {
                    return 0;
                }
                pos -= 1;
            }

            let block_idx = pos / BB::N;
            let block_pos = pos % BB::N;
            debug_assert!(block_idx < self.basic_blocks.len());
            let mut rank = self.basic_blocks.get_unchecked(block_idx).rank(block_pos);
            // Narrow (`W < 64`) per-block counters are relative; add the
            // absolute superblock offset back in. The wrapping add mirrors
            // the wrapping subtraction used during construction.
            if BB::W < 64 {
                let long_pos = block_idx / SB::BLOCKS_PER_SUPERBLOCK;
                let long_rank = self
                    .super_blocks
                    .get_unchecked(long_pos)
                    .get(block_idx % SB::BLOCKS_PER_SUPERBLOCK);
                rank = rank.wrapping_add(long_rank);
            }
            rank
        }
    }
}