use crate::genomics::CompressedDNA;
/// Number of distinct symbols tracked per position: one for each [`BaseCode`] variant.
pub const ALPHABET_SIZE: usize = 5;
/// Superblock width (in sequence positions) used by [`RankSelectIndex::build`].
pub const CHECKPOINT_STRIDE: usize = 1024;
/// Compact numeric code for a sequence symbol.
///
/// The discriminant of each variant is also its slot in per-symbol arrays;
/// see [`BaseCode::index`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseCode {
    /// Parsed from `A` / `a`.
    A = 0,
    /// Parsed from `C` / `c`.
    C = 1,
    /// Parsed from `G` / `g`.
    G = 2,
    /// Parsed from `T` / `t` and also from `U` / `u`.
    T = 3,
    /// Parsed from `N` / `n`.
    N = 4,
}

impl BaseCode {
    /// Converts an ASCII letter into its symbol code, ignoring case.
    ///
    /// `U` is folded into [`BaseCode::T`]. Returns `None` for every byte that
    /// is not one of `A`, `C`, `G`, `T`, `U`, `N` in either case.
    pub fn from_ascii(base: u8) -> Option<Self> {
        // Fold case once so each symbol needs a single match arm.
        let code = match base.to_ascii_uppercase() {
            b'A' => BaseCode::A,
            b'C' => BaseCode::C,
            b'G' => BaseCode::G,
            b'T' | b'U' => BaseCode::T,
            b'N' => BaseCode::N,
            _ => return None,
        };
        Some(code)
    }

    /// Returns the variant's discriminant as a `usize` array index.
    #[inline]
    pub fn index(self) -> usize {
        self as usize
    }
}
/// A sequence position paired with one `u32` counter per alphabet symbol.
///
/// NOTE(review): nothing visible in this file constructs or reads this type;
/// presumably `counts` holds cumulative symbol counts up to `position` —
/// confirm against callers.
#[derive(Debug, Clone)]
pub struct RankSelectCheckpoint {
    /// Position in the sequence this record refers to.
    pub position: usize,
    /// One counter per symbol; the array length is [`ALPHABET_SIZE`].
    pub counts: [u32; ALPHABET_SIZE],
}

impl RankSelectCheckpoint {
    /// Bundles a position and its per-symbol counters into a checkpoint.
    // Kept behind `allow(dead_code)` because the constructor has no caller in this file.
    #[allow(dead_code)]
    fn new(position: usize, counts: [u32; ALPHABET_SIZE]) -> Self {
        RankSelectCheckpoint { position, counts }
    }
}
/// Rank index over a sequence: one bitvector per symbol plus cumulative
/// per-superblock counts that bound how many words a query has to scan.
#[derive(Debug, Clone)]
pub struct RankSelectIndex {
    /// Superblock width in sequence positions; always at least 1.
    stride: usize,
    /// Per symbol: bit `i` is set when position `i` holds that symbol.
    bitvectors: [Vec<u64>; ALPHABET_SIZE],
    /// Per symbol: entry `k` is the symbol's count in positions
    /// `0..min(k * stride, len)`. Entry 0 is always 0.
    superblocks: [Vec<u32>; ALPHABET_SIZE],
    /// Per symbol: number of occurrences in the whole sequence.
    totals: [u32; ALPHABET_SIZE],
    /// Length of the sequence the index was built from.
    len: usize,
}

impl RankSelectIndex {
    /// Builds an index using the default stride, [`CHECKPOINT_STRIDE`].
    pub fn build(sequence: &CompressedDNA) -> Self {
        Self::build_with_stride(sequence, CHECKPOINT_STRIDE)
    }

    /// Builds an index whose superblocks span `stride` positions each.
    ///
    /// A `stride` of 0 is treated as 1. Positions whose base cannot be read,
    /// or whose byte is not a recognised letter, are counted as
    /// [`BaseCode::N`].
    pub fn build_with_stride(sequence: &CompressedDNA, stride: usize) -> Self {
        let stride = stride.max(1);
        let len = sequence.len();

        // Pass 1: mark every position in the bitvector of its symbol.
        let words = words_for_bits(len);
        let mut bitvectors: [Vec<u64>; ALPHABET_SIZE] =
            std::array::from_fn(|_| vec![0u64; words]);
        let mut totals = [0u32; ALPHABET_SIZE];
        for pos in 0..len {
            let raw = sequence.base_at(pos).unwrap_or(b'N');
            let slot = BaseCode::from_ascii(raw).unwrap_or(BaseCode::N).index();
            totals[slot] += 1;
            let (word, mask) = bit_position(pos);
            bitvectors[slot][word] |= mask;
        }

        // Pass 2: per symbol, accumulate a running count at every superblock
        // boundary, starting with the leading 0 for the empty prefix.
        let num_superblocks = len.div_ceil(stride);
        let superblocks: [Vec<u32>; ALPHABET_SIZE] = std::array::from_fn(|slot| {
            let bits = &bitvectors[slot];
            let mut prefix = Vec::with_capacity(num_superblocks + 1);
            let mut running = 0u32;
            prefix.push(running);
            for block in 0..num_superblocks {
                let start = block * stride;
                let end = (start + stride).min(len);
                running += popcount_range(bits, start, end);
                prefix.push(running);
            }
            prefix
        });

        Self {
            stride,
            bitvectors,
            superblocks,
            totals,
            len,
        }
    }

    /// Superblock width the index was built with.
    pub fn stride(&self) -> usize {
        self.stride
    }

    /// Whole-sequence occurrence count of each symbol, in [`BaseCode::index`] order.
    pub fn totals(&self) -> [u32; ALPHABET_SIZE] {
        self.totals
    }

    /// Per-symbol occupancy bitvectors, in [`BaseCode::index`] order.
    pub(crate) fn bitvectors(&self) -> &[Vec<u64>; ALPHABET_SIZE] {
        &self.bitvectors
    }

    /// Per-symbol cumulative superblock counts, in [`BaseCode::index`] order.
    pub(crate) fn superblocks(&self) -> &[Vec<u32>; ALPHABET_SIZE] {
        &self.superblocks
    }

    /// Counts occurrences of `base` in positions `0..position`.
    ///
    /// `position` is clamped to the indexed length. `sequence` is only used
    /// for a debug-build check that its length matches the indexed length.
    pub fn rank(&self, sequence: &CompressedDNA, base: BaseCode, position: usize) -> u32 {
        debug_assert_eq!(
            sequence.len(),
            self.len,
            "RankSelectIndex must be queried with the same sequence it was built for"
        );
        self.prefix_rank(base.index(), position)
    }

    /// Counts occurrences of every symbol in positions `0..position`.
    ///
    /// The result is in [`BaseCode::index`] order; `position` is clamped to
    /// the indexed length. `sequence` is only used for a debug-build length check.
    pub fn rank_all(&self, sequence: &CompressedDNA, position: usize) -> [u32; ALPHABET_SIZE] {
        debug_assert_eq!(
            sequence.len(),
            self.len,
            "RankSelectIndex must be queried with the same sequence it was built for"
        );
        std::array::from_fn(|slot| self.prefix_rank(slot, position))
    }

    /// Rank of the symbol stored in `slot`: the stored superblock prefix plus
    /// a popcount over the remainder of the enclosing superblock.
    fn prefix_rank(&self, slot: usize, position: usize) -> u32 {
        let clamped = position.min(self.len);
        let block = clamped / self.stride;
        let block_start = block * self.stride;
        let before_block = self.superblocks[slot][block];
        let inside_block = popcount_range(&self.bitvectors[slot], block_start, clamped);
        before_block + inside_block
    }
}
/// Number of `u64` words needed to store `bits` bits.
#[inline]
fn words_for_bits(bits: usize) -> usize {
    // Rounding up already yields 0 for an empty bitvector.
    bits.div_ceil(64)
}
/// Splits a bit index into the index of its `u64` word and a mask with only
/// that bit set within the word.
#[inline]
fn bit_position(idx: usize) -> (usize, u64) {
    // 64 bits per word: the high part selects the word, the low 6 bits the offset.
    (idx >> 6, 1u64 << (idx & 63))
}
/// Counts the set bits at bit positions `start..end` of `words`.
///
/// Bit `i` lives in `words[i / 64]` at offset `i % 64`. An empty or inverted
/// range (`end <= start`) yields 0 without touching `words`.
///
/// # Panics
///
/// Panics if a non-empty range reaches past the last word of `words`.
#[inline]
pub(crate) fn popcount_range(words: &[u64], start: usize, end: usize) -> u32 {
    if start >= end {
        return 0;
    }

    let first_word = start / 64;
    // `end` is exclusive, so the last touched word is the one holding bit `end - 1`.
    let last_word = (end - 1) / 64;

    // Drops the bits below `start` in the first word.
    let low_mask = !0u64 << (start % 64);
    // Keeps the bits below `end` in the last word; a word-aligned `end` keeps all 64.
    let high_mask = match end % 64 {
        0 => !0u64,
        tail => (1u64 << tail) - 1,
    };

    if first_word == last_word {
        return (words[first_word] & low_mask & high_mask).count_ones();
    }

    let head = (words[first_word] & low_mask).count_ones();
    let body: u32 = words[first_word + 1..last_word]
        .iter()
        .map(|w| w.count_ones())
        .sum();
    let tail = (words[last_word] & high_mask).count_ones();
    head + body + tail
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every symbol code, in `BaseCode::index` order.
    const ALL_CODES: [BaseCode; ALPHABET_SIZE] = [
        BaseCode::A,
        BaseCode::C,
        BaseCode::G,
        BaseCode::T,
        BaseCode::N,
    ];

    /// Checks `rank` for every symbol and every prefix length against a
    /// direct count, using a small stride so several superblocks are crossed.
    #[test]
    fn rank_queries_match_naive_counts() {
        let seq = b"AAACCCGGGTTTNNNAAGT";
        let compressed = CompressedDNA::compress(seq).unwrap();
        let index = RankSelectIndex::build_with_stride(&compressed, 4);
        for pos in 0..=seq.len() {
            let prefix = &seq[..pos];
            for base in ALL_CODES {
                let expected = prefix
                    .iter()
                    .filter(|&&b| BaseCode::from_ascii(b).unwrap_or(BaseCode::N) == base)
                    .count() as u32;
                assert_eq!(index.rank(&compressed, base, pos), expected);
            }
        }
    }

    /// Checks `rank_all` for every prefix length against per-letter counts.
    #[test]
    fn rank_all_returns_expected_counts() {
        // Raw letters that count towards each slot, in `BaseCode::index` order.
        const LETTERS: [&[u8]; ALPHABET_SIZE] = [b"Aa", b"Cc", b"Gg", b"TtUu", b"Nn"];

        let seq = b"ATCGNNATCG";
        let compressed = CompressedDNA::compress(seq).unwrap();
        let index = RankSelectIndex::build(&compressed);
        for pos in 0..=seq.len() {
            let prefix = &seq[..pos];
            let expected: [u32; ALPHABET_SIZE] = std::array::from_fn(|slot| {
                prefix
                    .iter()
                    .filter(|&&b| LETTERS[slot].contains(&b))
                    .count() as u32
            });
            let counts = index.rank_all(&compressed, pos);
            assert_eq!(counts, expected);
        }
    }

    /// Checks that `totals` reports the whole-sequence count of each symbol.
    #[test]
    fn totals_match_full_sequence() {
        let seq = b"AACCGGTTNN";
        let compressed = CompressedDNA::compress(seq).unwrap();
        let index = RankSelectIndex::build(&compressed);
        let expected = [2, 2, 2, 2, 2];
        assert_eq!(index.totals(), expected);
    }
}