use crate::genomics::{BaseCode, FMInterval, FmSymbol};
/// Storage abstraction consumed by the generic FM-index routines in this file
/// (`rank`, `total`, `symbol_at`, `lf_index`, `backward_search`, `sa_at`,
/// `locate_interval`).
///
/// The routines address the BWT in fixed-size blocks: position `p` is looked
/// up in block `p / block_size()` at offset `p - block * block_size()`.
pub(crate) trait BwtBacking {
/// Number of positions in the BWT; `symbol_at` and `sa_at` assert that
/// indices are strictly below this value.
fn bwt_len(&self) -> usize;
/// Number of BWT positions per block. It is used as a divisor, so it must
/// be non-zero.
fn block_size(&self) -> usize;
/// Number of blocks that hold symbol data. `rank` skips the in-block lookup
/// for block indices at or beyond this value, and `total` reads the
/// boundary entry at exactly this index.
fn num_blocks(&self) -> usize;
/// BWT index at which `symbol_at` reports `FmSymbol::Sentinel` without
/// consulting block data.
fn sentinel_pos(&self) -> usize;
/// Sampling rate of the `sampled_at` table. In this file it is only read by
/// a debug assertion in `sa_at` that bounds the number of LF steps.
fn sample_rate(&self) -> usize;
/// Per-symbol offsets indexed by `FmSymbol::order()`; added to a rank to
/// obtain a BWT row.
/// NOTE(review): presumably the classic FM-index C array (number of symbols
/// that sort before the given one) - confirm with the implementor.
fn c_table(&self) -> [u32; 6];
/// Checkpointed count of the base with index `base_index` at block boundary
/// `block_idx`. Callers may pass `block_idx == num_blocks()`; `total` treats
/// that entry as the count over the whole BWT.
fn boundary_base(&self, block_idx: usize, base_index: usize) -> u32;
/// Checkpointed count of the sentinel at block boundary `block_idx`; `rank`
/// queries it with the same block indices as `boundary_base`.
fn boundary_sentinel(&self, block_idx: usize) -> u32;
/// Count of `symbol` inside block `block_idx`, limited by the in-block
/// offset `within`.
/// NOTE(review): `rank` appears to rely on this covering the first `within`
/// positions of the block (offset exclusive) - confirm against implementors.
fn block_rank(&self, block_idx: usize, symbol: FmSymbol, within: usize) -> u32;
/// Symbol stored at offset `within` of block `block_idx`. `symbol_at`
/// handles the sentinel position itself before calling this.
fn block_symbol(&self, block_idx: usize, within: usize) -> FmSymbol;
/// Sampled value for BWT row `index`, or `None` when that row carries no
/// sample. `sa_at` adds the number of LF steps taken to the returned value.
fn sampled_at(&self, index: usize) -> Option<u32>;
}
/// Returns the occurrence count of `symbol` up to `position`.
///
/// `position` is clamped to `b.bwt_len()`. The result is the checkpoint stored
/// at the boundary of the block containing the clamped position, plus the
/// in-block count when that block index is below `b.num_blocks()`.
/// `lf_index` passes `index + 1` to include `index` itself, so `position`
/// acts as an exclusive bound.
pub(crate) fn rank<B: BwtBacking>(b: &B, symbol: FmSymbol, position: usize) -> u32 {
    let clamped = position.min(b.bwt_len());
    let block = clamped / b.block_size();

    // Occurrences accumulated before the start of `block`.
    let checkpoint = match symbol {
        FmSymbol::Base(code) => b.boundary_base(block, code.index()),
        FmSymbol::Sentinel => b.boundary_sentinel(block),
    };

    // A clamped position may land on the boundary one past the last block,
    // where there is no block data left to scan.
    if block >= b.num_blocks() {
        return checkpoint;
    }

    let offset = clamped % b.block_size();
    checkpoint + b.block_rank(block, symbol, offset)
}
/// Returns how many times `symbol` occurs in the whole BWT.
///
/// The sentinel always counts as exactly one; for a base the answer is the
/// boundary checkpoint at index `b.num_blocks()`.
pub(crate) fn total<B: BwtBacking>(b: &B, symbol: FmSymbol) -> u32 {
    match symbol {
        FmSymbol::Base(code) => {
            let final_boundary = b.num_blocks();
            b.boundary_base(final_boundary, code.index())
        }
        FmSymbol::Sentinel => 1,
    }
}
/// Returns the BWT symbol stored at `index`.
///
/// The sentinel position is answered directly; every other index is resolved
/// through the block that contains it.
///
/// # Panics
///
/// Panics if `index` is not below `b.bwt_len()`.
pub(crate) fn symbol_at<B: BwtBacking>(b: &B, index: usize) -> FmSymbol {
    assert!(index < b.bwt_len(), "BWT index out of range");

    if index != b.sentinel_pos() {
        let block = index / b.block_size();
        let offset = index % b.block_size();
        b.block_symbol(block, offset)
    } else {
        FmSymbol::Sentinel
    }
}
/// Applies one LF-mapping step to the BWT row `index`.
///
/// The target row is the `c_table` entry for the symbol at `index` plus the
/// rank of that symbol counted through `index` inclusive, minus one.
///
/// # Panics
///
/// Panics if `index` is out of range (see `symbol_at`).
pub(crate) fn lf_index<B: BwtBacking>(b: &B, index: usize) -> usize {
    let here = symbol_at(b, index);
    // Rank through `index` inclusive; expected to be at least 1 because the
    // symbol at `index` counts itself.
    let occurrences = rank(b, here, index + 1) as usize;
    let first_row = b.c_table()[here.order()] as usize;
    first_row + occurrences - 1
}
/// Runs FM-index backward search for `pattern` and returns the matching
/// interval of BWT rows.
///
/// The pattern is consumed from its last byte to its first. A byte that
/// `BaseCode::from_ascii` rejects yields the interval `{ lower: 0, upper: 0 }`
/// immediately, and the search stops as soon as the interval becomes empty.
/// An empty pattern returns the full interval.
pub(crate) fn backward_search<B: BwtBacking>(b: &B, pattern: &[u8]) -> FMInterval {
    let mut current = FMInterval::full(b.bwt_len());

    for byte in pattern.iter().rev().copied() {
        let symbol = match BaseCode::from_ascii(byte) {
            Some(code) => FmSymbol::Base(code),
            None => return FMInterval { lower: 0, upper: 0 },
        };

        // Narrow both ends of the interval with the same per-symbol offset.
        let offset = b.c_table()[symbol.order()];
        let lower = offset + rank(b, symbol, current.lower as usize);
        let upper = offset + rank(b, symbol, current.upper as usize);
        current = FMInterval { lower, upper };

        if current.is_empty() {
            return current;
        }
    }

    current
}
/// Resolves the value for BWT row `index` by walking the LF mapping until a
/// sampled row is reached.
///
/// The result is the sampled value plus the number of LF steps taken.
///
/// # Panics
///
/// Panics if `index` is not below `b.bwt_len()`. In debug builds it also
/// panics when more than `b.sample_rate() + 1` LF steps are needed.
pub(crate) fn sa_at<B: BwtBacking>(b: &B, index: usize) -> usize {
    assert!(index < b.bwt_len(), "BWT index out of range");

    let mut row = index;
    let mut walked: usize = 0;
    loop {
        match b.sampled_at(row) {
            // Each LF step moved one position, so add the steps back.
            Some(anchor) => break anchor as usize + walked,
            None => {
                row = lf_index(b, row);
                walked += 1;
                debug_assert!(
                    walked <= b.sample_rate() + 1,
                    "LF steps exceeded sample rate; sampling invariant violated"
                );
            }
        }
    }
}
/// Resolves the BWT rows in `interval` to positions, returning at most
/// `max_hits` of them (a `max_hits` of zero is treated as one).
///
/// Rows are visited in increasing order from `interval.lower` up to, but not
/// including, `interval.upper`. Positions at or beyond `b.bwt_len() - 1` are
/// dropped and do not count toward the limit.
///
/// # Panics
///
/// Panics if a visited row is not below `b.bwt_len()` (see `sa_at`).
pub(crate) fn locate_interval<B: BwtBacking>(
    b: &B,
    interval: FMInterval,
    max_hits: usize,
) -> Vec<u32> {
    let limit = max_hits.max(1);
    let reference_len = b.bwt_len().saturating_sub(1);
    let rows = (interval.lower as usize)..(interval.upper as usize);

    // The chain is lazy: once `limit` hits are collected, `take` stops pulling
    // rows, so no further `sa_at` walks are performed.
    rows.map(|row| sa_at(b, row))
        .filter(|&sa| sa < reference_len)
        .take(limit)
        .map(|sa| sa as u32)
        .collect()
}