use std::collections::HashMap;
use std::io::{Seek, SeekFrom};
use anyhow::{Context as _, Result, anyhow, bail};
use noodles_core::Region;
use noodles_fasta as fasta;
use noodles_sam::Header;
/// 4-bit code for reference base `A` (bit 0).
pub(crate) const BASE_A: u8 = 0b0001;
/// 4-bit code for reference base `C` (bit 1).
pub(crate) const BASE_C: u8 = 0b0010;
/// 4-bit code for reference base `G` (bit 2).
pub(crate) const BASE_G: u8 = 0b0100;
/// 4-bit code for reference base `T` (bit 3).
pub(crate) const BASE_T: u8 = 0b1000;
/// 4-bit code with all four base bits set; the ASCII lookup table assigns it to
/// every byte that is not `A`/`C`/`G`/`T` (either case).
pub(crate) const BASE_N: u8 = 0b1111;
/// Lookup table from a raw FASTA byte to its 4-bit base code, built at compile time.
const REF_ASCII_TO_CODE: [u8; 256] = build_ref_codes();

/// Builds [`REF_ASCII_TO_CODE`].
///
/// `A`, `C`, `G` and `T` map to their one-hot codes in both upper and lower case;
/// every other byte value maps to [`BASE_N`].
const fn build_ref_codes() -> [u8; 256] {
    const UPPERCASE: [(u8, u8); 4] =
        [(b'A', BASE_A), (b'C', BASE_C), (b'G', BASE_G), (b'T', BASE_T)];

    // Start from "everything is N" and overwrite only the eight recognised letters.
    let mut table = [BASE_N; 256];
    // `for` is not available in `const fn`, hence the manual index loop.
    let mut i = 0;
    while i < UPPERCASE.len() {
        let (letter, code) = UPPERCASE[i];
        table[letter as usize] = code;
        table[letter.to_ascii_lowercase() as usize] = code;
        i += 1;
    }
    table
}
/// The four `CpN` dinucleotide contexts.
///
/// The discriminants are spelled out because [`Context::index`] exposes them as
/// array indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Context {
    CpA = 0,
    CpC = 1,
    CpG = 2,
    CpT = 3,
}

impl Context {
    /// Every context, ordered so that `ALL[c.index()] == c`.
    pub(crate) const ALL: [Self; 4] = [Self::CpA, Self::CpC, Self::CpG, Self::CpT];

    /// Dense index of this context: `CpA` = 0, `CpC` = 1, `CpG` = 2, `CpT` = 3.
    #[inline]
    #[must_use]
    pub(crate) fn index(self) -> usize {
        self as usize
    }
}
/// Classifies a context from the 4-bit code of the following reference base.
///
/// Returns `None` for any code other than the four one-hot base codes
/// (for example [`BASE_N`]).
#[inline]
#[must_use]
pub(crate) fn top_context(next_code: u8) -> Option<Context> {
    let context = match next_code {
        BASE_A => Context::CpA,
        BASE_C => Context::CpC,
        BASE_G => Context::CpG,
        BASE_T => Context::CpT,
        // Not a single unambiguous base: no context can be assigned.
        _ => return None,
    };
    Some(context)
}
#[inline]
#[must_use]
pub(crate) fn bottom_context(prev_code: u8) -> Option<Context> {
match prev_code {
BASE_C => Some(Context::CpG),
BASE_T => Some(Context::CpA),
BASE_A => Some(Context::CpT),
BASE_G => Some(Context::CpC),
_ => None,
}
}
/// Position-indexed, read-only view over the encoded bases of one contig.
///
/// Implementors supply [`len`](RefCodes::len) and [`code`](RefCodes::code); the
/// remaining methods have defaults expressed through `code` and may be overridden.
pub(crate) trait RefCodes {
    /// Number of bases in the view.
    fn len(&self) -> usize;

    /// 4-bit code of the base at `pos`.
    fn code(&self, pos: usize) -> u8;

    /// Returns `true` when the base at `pos` has exactly the code `code`.
    #[inline]
    fn monitors(&self, pos: usize, code: u8) -> bool {
        code == self.code(pos)
    }

    /// Context obtained by passing the code at `pos` to [`top_context`].
    #[inline]
    fn ctx_top(&self, pos: usize) -> Option<Context> {
        let next_code = self.code(pos);
        top_context(next_code)
    }

    /// Context obtained by passing the code at `pos` to [`bottom_context`].
    #[inline]
    fn ctx_bottom(&self, pos: usize) -> Option<Context> {
        let prev_code = self.code(pos);
        bottom_context(prev_code)
    }
}
/// Borrowed view over a contig stored as one 4-bit code per byte.
#[derive(Clone, Copy)]
pub(crate) struct ByteCodes<'a>(pub(crate) &'a [u8]);

impl RefCodes for ByteCodes<'_> {
    /// Length of the wrapped slice.
    #[inline]
    fn len(&self) -> usize {
        let ByteCodes(codes) = *self;
        codes.len()
    }

    /// The byte stored at `pos`; panics if `pos` is outside the slice.
    #[inline]
    fn code(&self, pos: usize) -> u8 {
        let ByteCodes(codes) = *self;
        codes[pos]
    }
}
/// Borrowed view over a contig packed two 4-bit codes per byte.
///
/// Even positions live in the high nibble of their byte, odd positions in the low
/// nibble.
#[derive(Clone, Copy)]
pub(crate) struct NibbleCodes<'a> {
    data: &'a [u8],
    len: usize,
}

impl RefCodes for NibbleCodes<'_> {
    /// Number of bases, which is not necessarily twice the number of packed bytes.
    #[inline]
    fn len(&self) -> usize {
        self.len
    }

    /// Unpacks the code at `pos`.
    ///
    /// Only the packed slice is bounds-checked; `pos` is not compared against `len`.
    #[inline]
    fn code(&self, pos: usize) -> u8 {
        let packed = self.data[pos / 2];
        match pos % 2 {
            0 => packed >> 4,
            _ => packed & 0x0F,
        }
    }
}
/// Borrowed view over a contig packed four 2-bit values per byte.
///
/// The stored value is the bit index of the base's one-hot 4-bit code
/// (0 = A, 1 = C, 2 = G, 3 = T), with position `p` occupying bits
/// `2 * (p % 4)` and up of byte `p / 4`. Only the four unambiguous bases can be
/// represented, so every accessor always reports a concrete base.
#[derive(Clone, Copy)]
pub(crate) struct TwoBitCodes<'a> {
    data: &'a [u8],
    len: usize,
}

impl TwoBitCodes<'_> {
    /// Raw 2-bit value stored for `pos`.
    ///
    /// Only the packed slice is bounds-checked; `pos` is not compared against `len`.
    #[inline]
    fn two_bit(&self, pos: usize) -> u8 {
        (self.data[pos >> 2] >> ((pos & 3) * 2)) & 0x3
    }
}

impl RefCodes for TwoBitCodes<'_> {
    /// Number of bases, which is not necessarily four times the number of packed bytes.
    #[inline]
    fn len(&self) -> usize {
        self.len
    }

    /// One-hot 4-bit code of the base at `pos` (1, 2, 4 or 8).
    #[inline]
    fn code(&self, pos: usize) -> u8 {
        1u8 << self.two_bit(pos)
    }

    /// Returns `true` when the base at `pos` has exactly the code `code`.
    ///
    /// The comparison is made against the one-hot code so that the answer always
    /// equals `self.code(pos) == code`. Comparing the stored value with
    /// `code.trailing_zeros()` instead would report a match at every `A` for any
    /// query code whose lowest bit is set (such as the all-bits-set N code).
    #[inline]
    fn monitors(&self, pos: usize, code: u8) -> bool {
        code == (1u8 << self.two_bit(pos))
    }

    /// Context for the base at `pos` acting as the following base; never `None`.
    #[inline]
    fn ctx_top(&self, pos: usize) -> Option<Context> {
        // `Context::ALL` is ordered CpA, CpC, CpG, CpT, matching the stored values.
        Some(Context::ALL[usize::from(self.two_bit(pos))])
    }

    /// Context for the base at `pos` acting as the preceding base; never `None`.
    #[inline]
    fn ctx_bottom(&self, pos: usize) -> Option<Context> {
        // Complementing swaps 0<->3 (A<->T) and 1<->2 (C<->G), i.e. `3 - value`.
        Some(Context::ALL[3 - usize::from(self.two_bit(pos))])
    }
}
/// Selects the in-memory layout that `Reference::load` uses for reference bases.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RefEncoding {
/// One 4-bit code stored per byte.
Bytes,
/// Two 4-bit codes packed per byte (see `pack_nibble`).
Nibble,
/// Four 2-bit values packed per byte (see `pack_twobit`); only A/C/G/T are
/// representable, and every other code is stored as A.
TwoBit,
}
/// Owned storage for one contig in a packed (nibble or 2-bit) layout.
pub(crate) struct PackedContig {
// Packed bytes; holds `len.div_ceil(2)` bytes for nibble packing and
// `len.div_ceil(4)` bytes for 2-bit packing.
data: Vec<u8>,
// Number of bases encoded in `data` (not the number of bytes).
len: usize,
}
#[inline]
fn nibble_to_2bit(code: u8) -> u8 {
match code {
BASE_C => 1,
BASE_G => 2,
BASE_T => 3,
_ => 0, }
}
/// Packs 4-bit codes two per byte.
///
/// The code at an even index fills the high nibble of its byte and the following
/// odd-indexed code fills the low nibble. A trailing unpaired code leaves the low
/// nibble zero.
fn pack_nibble(codes: &[u8]) -> PackedContig {
    let data: Vec<u8> = codes
        .chunks(2)
        .map(|pair| {
            let high = pair[0] << 4;
            // The last chunk holds a single code when the input length is odd.
            let low = pair.get(1).map_or(0, |&code| code & 0x0F);
            high | low
        })
        .collect();
    PackedContig { data, len: codes.len() }
}
/// Packs 4-bit codes four per byte as 2-bit values.
///
/// Each code goes through [`nibble_to_2bit`]; the code at index `i` lands in bits
/// `2 * (i % 4)` and up of byte `i / 4`. Unused slots in the final byte stay zero.
fn pack_twobit(codes: &[u8]) -> PackedContig {
    let data: Vec<u8> = codes
        .chunks(4)
        .map(|quad| {
            quad.iter()
                .enumerate()
                .fold(0u8, |byte, (slot, &code)| byte | (nibble_to_2bit(code) << (slot * 2)))
        })
        .collect();
    PackedContig { data, len: codes.len() }
}
/// Encoded reference sequences held in memory, one entry per contig.
///
/// `Reference::load` fills the vector in the order of the header's reference
/// sequences, and the accessor methods look contigs up by that position.
pub(crate) enum Reference {
/// One 4-bit code per byte for every base.
Bytes(Vec<Vec<u8>>),
/// Two 4-bit codes per byte.
Nibble(Vec<PackedContig>),
/// Four 2-bit values per byte.
TwoBit(Vec<PackedContig>),
}
impl Reference {
    /// Loads every contig listed in `header` from the indexed FASTA at `path` and
    /// stores it in the requested `encoding`.
    ///
    /// Contigs are stored in the order of `header.reference_sequences()`, so the
    /// position of a contig in the header is the `tid` accepted by the accessors.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the FASTA or its index cannot be opened;
    /// - a header contig name is not valid UTF-8;
    /// - a header contig is missing from the FASTA index;
    /// - the header length and the index length of a contig differ;
    /// - looking up, seeking to, or reading a contig fails;
    /// - the number of bases actually read differs from the indexed length
    ///   (for example because the index is stale or the FASTA is truncated).
    pub(crate) fn load(
        path: &std::path::Path,
        header: &Header,
        encoding: RefEncoding,
    ) -> Result<Self> {
        let mut reader = fasta::io::indexed_reader::Builder::default()
            .build_from_path(path)
            .with_context(|| {
                format!("opening indexed FASTA {} (is there a .fai?)", path.display())
            })?;

        // Contig name -> length according to the FASTA index, used to validate the header.
        let fai_lengths: HashMap<String, u64> = reader
            .index()
            .as_ref()
            .iter()
            .map(|rec| (String::from_utf8_lossy(rec.name().as_ref()).into_owned(), rec.length()))
            .collect();

        // Only the vector matching `encoding` is ever filled; the other stays empty.
        let n = header.reference_sequences().len();
        let mut bytes_v: Vec<Vec<u8>> =
            if encoding == RefEncoding::Bytes { Vec::with_capacity(n) } else { Vec::new() };
        let mut packed_v: Vec<PackedContig> =
            if encoding == RefEncoding::Bytes { Vec::new() } else { Vec::with_capacity(n) };

        for (name, map) in header.reference_sequences() {
            let name_str = std::str::from_utf8(name.as_ref())
                .map_err(|_| anyhow!("BAM @SQ name is not valid UTF-8"))?;
            let contig_len = usize::from(map.length());
            let bam_len = contig_len as u64;

            match fai_lengths.get(name_str) {
                None => bail!(
                    "BAM contig '{name_str}' is not present in the reference FASTA index \
                     ({}). Every @SQ contig must exist in the reference.",
                    path.display()
                ),
                Some(&fai_len) if fai_len != bam_len => bail!(
                    "Length mismatch for contig '{name_str}': BAM @SQ says {bam_len} bp but the \
                     reference FASTA says {fai_len} bp. The BAM was aligned against a different \
                     reference."
                ),
                Some(_) => {}
            }

            // Position the underlying reader at the first base of this contig.
            let region = Region::new(name_str, ..);
            let offset = reader.index().query(&region).with_context(|| {
                format!("looking up contig '{name_str}' in {}.fai", path.display())
            })?;
            reader
                .get_mut()
                .seek(SeekFrom::Start(offset))
                .with_context(|| format!("seeking to contig '{name_str}' in {}", path.display()))?;

            let mut raw = Vec::with_capacity(contig_len);
            reader
                .read_sequence(&mut raw)
                .with_context(|| format!("reading contig '{name_str}' from {}", path.display()))?;

            // The index length was validated above, but the FASTA itself may disagree with
            // its index. A short contig would otherwise surface later as an out-of-range
            // position lookup, so reject it here with a descriptive error.
            if raw.len() != contig_len {
                bail!(
                    "Read {} bp for contig '{name_str}' from {} but its index says {bam_len} bp. \
                     The .fai may be stale or the FASTA truncated.",
                    raw.len(),
                    path.display()
                );
            }

            let codes: Vec<u8> = raw.iter().map(|&b| REF_ASCII_TO_CODE[b as usize]).collect();
            match encoding {
                RefEncoding::Bytes => bytes_v.push(codes),
                RefEncoding::Nibble => packed_v.push(pack_nibble(&codes)),
                RefEncoding::TwoBit => packed_v.push(pack_twobit(&codes)),
            }
        }

        Ok(match encoding {
            RefEncoding::Bytes => Reference::Bytes(bytes_v),
            RefEncoding::Nibble => Reference::Nibble(packed_v),
            RefEncoding::TwoBit => Reference::TwoBit(packed_v),
        })
    }

    /// The layout this reference was loaded with.
    #[must_use]
    pub(crate) fn encoding(&self) -> RefEncoding {
        match self {
            Reference::Bytes(_) => RefEncoding::Bytes,
            Reference::Nibble(_) => RefEncoding::Nibble,
            Reference::TwoBit(_) => RefEncoding::TwoBit,
        }
    }

    /// Byte-per-base view of contig `tid`.
    ///
    /// Returns `None` when the reference is not byte-encoded, when `tid` is negative,
    /// or when `tid` is past the last contig.
    #[inline]
    #[must_use]
    pub(crate) fn byte_codes(&self, tid: i32) -> Option<ByteCodes<'_>> {
        match self {
            // The guard makes the cast lossless: `tid` is known to be non-negative.
            Reference::Bytes(contigs) if tid >= 0 => {
                contigs.get(tid as usize).map(|codes| ByteCodes(codes.as_slice()))
            }
            _ => None,
        }
    }

    /// Nibble-packed view of contig `tid`.
    ///
    /// Returns `None` when the reference is not nibble-encoded, when `tid` is
    /// negative, or when `tid` is past the last contig.
    #[inline]
    #[must_use]
    pub(crate) fn nibble_codes(&self, tid: i32) -> Option<NibbleCodes<'_>> {
        match self {
            Reference::Nibble(contigs) if tid >= 0 => contigs
                .get(tid as usize)
                .map(|contig| NibbleCodes { data: &contig.data, len: contig.len }),
            _ => None,
        }
    }

    /// 2-bit-packed view of contig `tid`.
    ///
    /// Returns `None` when the reference is not 2-bit-encoded, when `tid` is
    /// negative, or when `tid` is past the last contig.
    #[inline]
    #[must_use]
    pub(crate) fn twobit_codes(&self, tid: i32) -> Option<TwoBitCodes<'_>> {
        match self {
            Reference::TwoBit(contigs) if tid >= 0 => contigs
                .get(tid as usize)
                .map(|contig| TwoBitCodes { data: &contig.data, len: contig.len }),
            _ => None,
        }
    }

    /// Test-only constructor wrapping already-encoded contigs as a byte-encoded reference.
    #[cfg(test)]
    #[must_use]
    pub(crate) fn from_encoded_contigs(contigs: Vec<Vec<u8>>) -> Self {
        Reference::Bytes(contigs)
    }
}
/// Test-only helper: looks up the 4-bit code for one ASCII reference byte.
#[cfg(test)]
#[must_use]
pub(crate) fn encode_ref_base(ascii: u8) -> u8 {
    let slot = usize::from(ascii);
    REF_ASCII_TO_CODE[slot]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns an ASCII base string into the 4-bit codes the packers consume.
    fn encode(seq: &str) -> Vec<u8> {
        seq.bytes().map(encode_ref_base).collect()
    }

    #[test]
    fn ascii_to_4bit_codes() {
        let cases = [
            (b'C', BASE_C),
            (b'c', BASE_C),
            (b'G', BASE_G),
            (b'A', BASE_A),
            (b'T', BASE_T),
            (b'N', BASE_N),
            // Any letter outside A/C/G/T is treated like N.
            (b'R', BASE_N),
        ];
        for (ascii, expected) in cases {
            assert_eq!(encode_ref_base(ascii), expected);
        }
    }

    #[test]
    fn top_context_buckets_by_next_base() {
        let cases = [
            (BASE_A, Some(Context::CpA)),
            (BASE_C, Some(Context::CpC)),
            (BASE_G, Some(Context::CpG)),
            (BASE_T, Some(Context::CpT)),
            (BASE_N, None),
        ];
        for (next, expected) in cases {
            assert_eq!(top_context(next), expected);
        }
    }

    #[test]
    fn bottom_context_uses_complement_of_prev_base() {
        let cases = [
            (BASE_C, Some(Context::CpG)),
            (BASE_T, Some(Context::CpA)),
            (BASE_A, Some(Context::CpT)),
            (BASE_G, Some(Context::CpC)),
            (BASE_N, None),
        ];
        for (prev, expected) in cases {
            assert_eq!(bottom_context(prev), expected);
        }
    }

    #[test]
    fn nibble_packing_round_trips_all_codes() {
        let codes = encode("CAGTNCCGTA");
        let packed = pack_nibble(&codes);
        let view = NibbleCodes { data: &packed.data, len: packed.len };

        assert_eq!(view.len(), codes.len());
        for (i, &c) in codes.iter().enumerate() {
            assert_eq!(view.code(i), c, "nibble mismatch at {i}");
        }
        // Two codes per byte, rounded up.
        assert_eq!(packed.data.len(), codes.len().div_ceil(2));
    }

    #[test]
    fn twobit_packing_round_trips_acgt_and_folds_n_to_a() {
        let codes = encode("CAGTNCCGTA");
        let packed = pack_twobit(&codes);
        let view = TwoBitCodes { data: &packed.data, len: packed.len };

        assert_eq!(view.len(), codes.len());
        for (i, &c) in codes.iter().enumerate() {
            // N has no 2-bit representation and comes back as A.
            let expected = match c {
                BASE_N => BASE_A,
                other => other,
            };
            assert_eq!(view.code(i), expected, "2-bit mismatch at {i}");
        }
        // Four codes per byte, rounded up.
        assert_eq!(packed.data.len(), codes.len().div_ceil(4));
        // Index 4 is the N in the input; it must not read back as C or G.
        for unexpected in [BASE_C, BASE_G] {
            assert_ne!(view.code(4), unexpected);
        }
    }

    #[test]
    fn twobit_monitors_matches_code_compare() {
        let codes = encode("CACGCATTGCGNCAGTACG");
        let packed = pack_twobit(&codes);
        let view = TwoBitCodes { data: &packed.data, len: packed.len };

        for i in 0..view.len() {
            for code in [BASE_C, BASE_G] {
                assert_eq!(view.monitors(i, code), view.code(i) == code, "pos {i} code {code}");
            }
        }
    }

    #[test]
    fn context_index_is_stable() {
        let in_index_order = [Context::CpA, Context::CpC, Context::CpG, Context::CpT];
        for (expected, ctx) in in_index_order.iter().enumerate() {
            assert_eq!(ctx.index(), expected);
        }
    }
}