// Symbol codes for one side (reference or read) of an aligned column.
// Codes 0..=3 are the nucleotide codes produced by `ref_2bit`; the codes
// below are written by `AlnOp::set_bases` for columns where a side has no base.

/// Gap symbol: reference side of an insertion, read side of a deletion.
const ALN_DASH: usize = 4;
/// Reference-side symbol for a soft-clipped read base.
const ALN_SOFT_CLIP: usize = 5;
/// Symbol placed on both sides for a hard clip.
const ALN_HARD_CLIP: usize = 6;
/// Symbol placed on both sides for a padding operation.
const ALN_PAD: usize = 7;
/// Read-side symbol for a reference skip.
const ALN_REF_SKIP: usize = 8;
/// Number of distinct per-side symbols (4 nucleotides + 5 special codes).
const NUM_STATES: usize = 9;
/// Sentinel "previous state" used before the first column of an alignment;
/// it sits directly after the `NUM_STATES * NUM_STATES` paired states.
const START_STATE: usize = NUM_STATES * NUM_STATES;
/// Dimension of a transition matrix: every (ref, read) pair plus the start state.
const NUM_ALN_STATES: usize = START_STATE + 1;
/// Maps an ASCII nucleotide to its 2-bit code, ignoring case:
/// `A` → 0, `C` → 1, `G` → 2, `T` → 3.
///
/// Any other byte (for example `N`) falls back to 0, the same code as `A`.
#[inline]
fn ref_2bit(b: u8) -> usize {
    // `A`/`a` need no arm of their own: they share the fallback value 0.
    match b.to_ascii_uppercase() {
        b'C' => 1,
        b'G' => 2,
        b'T' => 3,
        _ => 0,
    }
}
/// One alignment operation kind; alignments are passed around as
/// `(AlnOp, length)` runs.
///
/// NOTE(review): the variant names appear to follow the SAM CIGAR operations
/// (M, =, X, I, D, N, S, H, P) — confirm against the code that builds them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AlnOp {
    /// Consumes one read base and one reference base.
    Match,
    /// Consumes one read base and one reference base.
    SeqMatch,
    /// Consumes one read base and one reference base.
    SeqMismatch,
    /// Consumes a read base only.
    Ins,
    /// Consumes a reference base only.
    Del,
    /// Consumes a reference base only.
    RefSkip,
    /// Consumes a read base only.
    SoftClip,
    /// Consumes neither read nor reference.
    HardClip,
    /// Consumes neither read nor reference.
    Pad,
}
impl AlnOp {
    /// Returns `true` when this operation advances the read position.
    #[inline]
    fn consume_seq(self) -> bool {
        match self {
            AlnOp::Match
            | AlnOp::SeqMatch
            | AlnOp::SeqMismatch
            | AlnOp::Ins
            | AlnOp::SoftClip => true,
            AlnOp::Del | AlnOp::RefSkip | AlnOp::HardClip | AlnOp::Pad => false,
        }
    }

    /// Returns `true` when this operation advances the reference position.
    #[inline]
    fn consume_ref(self) -> bool {
        match self {
            AlnOp::Match
            | AlnOp::SeqMatch
            | AlnOp::SeqMismatch
            | AlnOp::Del
            | AlnOp::RefSkip => true,
            AlnOp::Ins | AlnOp::SoftClip | AlnOp::HardClip | AlnOp::Pad => false,
        }
    }

    /// Overwrites the side(s) of the current column that carry no real base
    /// with the special symbol for this operation.
    ///
    /// `cur_ref` / `cur_read` are left untouched for the three match-like
    /// operations, and on whichever side an indel/clip/skip does consume.
    #[inline]
    fn set_bases(self, cur_ref: &mut usize, cur_read: &mut usize) {
        // (replacement for the reference side, replacement for the read side)
        let (ref_override, read_override) = match self {
            AlnOp::Ins => (Some(ALN_DASH), None),
            AlnOp::Del => (None, Some(ALN_DASH)),
            AlnOp::RefSkip => (None, Some(ALN_REF_SKIP)),
            AlnOp::SoftClip => (Some(ALN_SOFT_CLIP), None),
            AlnOp::HardClip => (Some(ALN_HARD_CLIP), Some(ALN_HARD_CLIP)),
            AlnOp::Pad => (Some(ALN_PAD), Some(ALN_PAD)),
            AlnOp::Match | AlnOp::SeqMatch | AlnOp::SeqMismatch => (None, None),
        };
        if let Some(code) = ref_override {
            *cur_ref = code;
        }
        if let Some(code) = read_override {
            *cur_read = code;
        }
    }
}
/// Returns `ln(exp(a) + exp(b))` for two log-space values.
///
/// `f64::NEG_INFINITY` stands for zero mass: if either argument is `-inf`
/// the other one is returned unchanged. Otherwise the larger value is
/// factored out before exponentiating so the `exp` argument is never positive.
#[inline]
fn log_add(a: f64, b: f64) -> f64 {
    match (a == f64::NEG_INFINITY, b == f64::NEG_INFINITY) {
        (true, _) => b,
        (false, true) => a,
        (false, false) => {
            let a_is_larger = a > b;
            let larger = if a_is_larger { a } else { b };
            let smaller = if a_is_larger { b } else { a };
            let ratio = (smaller - larger).exp();
            larger + ratio.ln_1p()
        }
    }
}
/// Log-space accumulator of transition mass between alignment states.
#[derive(Clone)]
struct TransMatrix {
    /// Flattened square matrix of log masses; the methods on this type
    /// address cell (prev, cur) as `prev * NUM_ALN_STATES + cur`.
    storage: Vec<f64>,
    /// Log of the total mass accumulated in each `prev` row.
    rowsums: Vec<f64>,
}
impl TransMatrix {
    /// Builds a matrix in which every transition already holds mass `alpha`.
    ///
    /// Each cell starts at `ln(alpha)` and each row total at
    /// `ln(NUM_ALN_STATES * alpha)`.
    #[allow(dead_code)]
    fn new(alpha: f64) -> Self {
        let cell = alpha.ln();
        let row_total = (NUM_ALN_STATES as f64 * alpha).ln();
        Self {
            storage: vec![cell; NUM_ALN_STATES * NUM_ALN_STATES],
            rowsums: vec![row_total; NUM_ALN_STATES],
        }
    }

    /// Builds a matrix holding no mass at all (every entry is `-inf`).
    fn empty() -> Self {
        let no_mass = f64::NEG_INFINITY;
        Self {
            storage: vec![no_mass; NUM_ALN_STATES * NUM_ALN_STATES],
            rowsums: vec![no_mass; NUM_ALN_STATES],
        }
    }

    /// Adds log-mass `amt` to the `prev -> cur` cell and to the total of row `prev`.
    ///
    /// Panics if the resulting cell index or `prev` is out of range.
    #[inline]
    fn increment(&mut self, prev: usize, cur: usize, amt: f64) {
        let cell = &mut self.storage[prev * NUM_ALN_STATES + cur];
        *cell = log_add(*cell, amt);
        let total = &mut self.rowsums[prev];
        *total = log_add(*total, amt);
    }

    /// Resets every cell and every row total to `-inf`.
    fn clear(&mut self) {
        self.storage.fill(f64::NEG_INFINITY);
        self.rowsums.fill(f64::NEG_INFINITY);
    }

    /// Adds the mass held in `other` to this matrix, entry by entry.
    fn combine(&mut self, other: &TransMatrix) {
        self.storage
            .iter_mut()
            .zip(other.storage.iter())
            .for_each(|(mine, &theirs)| *mine = log_add(*mine, theirs));
        self.rowsums
            .iter_mut()
            .zip(other.rowsums.iter())
            .for_each(|(mine, &theirs)| *mine = log_add(*mine, theirs));
    }

    /// Returns the cell for `prev -> cur` minus the row total for `prev`,
    /// i.e. the log of the row-normalised transition mass.
    #[inline]
    #[allow(dead_code)]
    fn get(&self, prev: usize, cur: usize) -> f64 {
        let cell = self.storage[prev * NUM_ALN_STATES + cur];
        let row_total = self.rowsums[prev];
        cell - row_total
    }
}
/// Transition model over alignment columns, kept as one `TransMatrix` per
/// read-position bin and per side.
#[derive(Clone)]
pub struct AlignmentModel {
    /// Matrices used when a call passes `is_left == true`, indexed by bin.
    left: Vec<TransMatrix>,
    /// Matrices used when a call passes `is_left == false`, indexed by bin.
    right: Vec<TransMatrix>,
    /// Number of bins the read length is divided into.
    read_bins: usize,
}
impl AlignmentModel {
    /// Creates a model with `read_bins` matrices per side, each built with
    /// `TransMatrix::new(alpha)`.
    #[allow(dead_code)]
    pub fn new(alpha: f64, read_bins: usize) -> Self {
        Self {
            left: (0..read_bins).map(|_| TransMatrix::new(alpha)).collect(),
            right: (0..read_bins).map(|_| TransMatrix::new(alpha)).collect(),
            read_bins,
        }
    }

    /// Resets every matrix on both sides to hold no mass.
    pub fn clear(&mut self) {
        self.left.iter_mut().for_each(|m| m.clear());
        self.right.iter_mut().for_each(|m| m.clear());
    }

    /// Creates a model with `read_bins` matrices per side that hold no mass.
    pub fn empty(read_bins: usize) -> Self {
        Self {
            left: (0..read_bins).map(|_| TransMatrix::empty()).collect(),
            right: (0..read_bins).map(|_| TransMatrix::empty()).collect(),
            read_bins,
        }
    }

    /// Adds the mass accumulated in `other` to this model, bin by bin.
    ///
    /// Bins are paired positionally; if the two models have different bin
    /// counts only the common prefix is combined.
    pub fn combine(&mut self, other: &AlignmentModel) {
        for (s, o) in self.left.iter_mut().zip(&other.left) {
            s.combine(o);
        }
        for (s, o) in self.right.iter_mut().zip(&other.right) {
            s.combine(o);
        }
    }

    /// Walks an alignment column by column and reports every state
    /// transition to `f` as `f(bin, prev_state, cur_state)`.
    ///
    /// * `read_bins` - number of read-position bins; `bin` is always `< read_bins`.
    /// * `read_2bit` - per-base codes of the read.
    ///   NOTE(review): assumed to be in `0..=3`; larger values would collide
    ///   with the special symbols or leave the state range — confirm with callers.
    /// * `ref_bytes` - reference as ASCII bytes, encoded through `ref_2bit`.
    /// * `pos` - index into `ref_bytes` where the alignment starts.
    /// * `ops` - alignment as `(operation, run length)` pairs.
    ///
    /// The first transition starts from `START_STATE`. The walk does nothing
    /// when `read_bins` is 0 or when the read or reference is empty, and it
    /// stops silently as soon as an operation would read past the end of the
    /// read or of the reference.
    fn walk<F: FnMut(usize, usize, usize)>(
        read_bins: usize,
        read_2bit: &[u8],
        ref_bytes: &[u8],
        pos: usize,
        ops: &[(AlnOp, usize)],
        mut f: F,
    ) {
        let read_len = read_2bit.len();
        // With zero bins there is no matrix to report into, and
        // `read_bins - 1` below would underflow.
        if read_bins == 0 || read_len == 0 || ref_bytes.is_empty() {
            return;
        }
        let last_bin = read_bins - 1;
        let inv_len = read_bins as f64 / read_len as f64;
        let mut read_idx = 0usize;
        let mut ref_idx = pos;
        let mut prev = START_STATE;
        for &(op, op_len) in ops {
            // These depend on the operation only, not on the position in the run.
            let uses_read = op.consume_seq();
            let uses_ref = op.consume_ref();
            for _ in 0..op_len {
                if uses_read && read_idx >= read_len {
                    return;
                }
                if uses_ref && ref_idx >= ref_bytes.len() {
                    return;
                }
                let mut cur_read = if uses_read {
                    read_2bit[read_idx] as usize
                } else {
                    0
                };
                let mut cur_ref = if uses_ref {
                    ref_2bit(ref_bytes[ref_idx])
                } else {
                    0
                };
                // Replace the side(s) without a real base by the op's special symbol.
                op.set_bases(&mut cur_ref, &mut cur_read);
                // Clamp: `read_idx` can equal `read_len` for trailing ops that
                // do not consume the read, which would map to bin `read_bins`.
                let bin = ((read_idx as f64 * inv_len) as usize).min(last_bin);
                let cur = cur_ref * NUM_STATES + cur_read;
                f(bin, prev, cur);
                prev = cur;
                if uses_read {
                    read_idx += 1;
                }
                if uses_ref {
                    ref_idx += 1;
                }
            }
        }
    }

    /// Adds log-mass `log_weight` to every transition observed along the
    /// given alignment, in the `left` matrices when `is_left` is true and in
    /// the `right` matrices otherwise.
    ///
    /// See [`AlignmentModel::walk`] for the meaning of the alignment
    /// arguments and for the conditions under which nothing is recorded.
    pub fn update(
        &mut self,
        read_2bit: &[u8],
        ref_bytes: &[u8],
        pos: usize,
        ops: &[(AlnOp, usize)],
        is_left: bool,
        log_weight: f64,
    ) {
        let read_bins = self.read_bins;
        let mats = if is_left {
            &mut self.left
        } else {
            &mut self.right
        };
        // Increment straight from the walk callback; no intermediate buffer needed.
        Self::walk(read_bins, read_2bit, ref_bytes, pos, ops, |bin, prev, cur| {
            mats[bin].increment(prev, cur, log_weight);
        });
    }

    /// Scores an alignment and returns `(fg, bg)`.
    ///
    /// `fg` is the sum over all columns of the normalised log transition
    /// value `get(prev, cur)` in the column's bin. `bg` is the sum, over the
    /// same columns and bins, of `get(0, 0)`.
    /// NOTE(review): the fixed 0 -> 0 transition seems to serve as the
    /// background/null score — confirm that this is the intended baseline.
    ///
    /// Both values are `0.0` when the walk visits no column.
    #[allow(dead_code)]
    pub fn log_likelihood(
        &self,
        read_2bit: &[u8],
        ref_bytes: &[u8],
        pos: usize,
        ops: &[(AlnOp, usize)],
        is_left: bool,
    ) -> (f64, f64) {
        let mats = if is_left { &self.left } else { &self.right };
        let mut fg = 0.0;
        let mut bg = 0.0;
        Self::walk(
            self.read_bins,
            read_2bit,
            ref_bytes,
            pos,
            ops,
            |bin, prev, cur| {
                fg += mats[bin].get(prev, cur);
                bg += mats[bin].get(0, 0);
            },
        );
        (fg, bg)
    }
}
/// Transition matrix with the same layout as `TransMatrix`, but with every
/// entry stored in an `AtomicF64` so it can be updated through `&self`.
struct SharedTransMatrix {
    /// Flattened square matrix of log masses, `prev * NUM_ALN_STATES + cur`.
    storage: Vec<salmon_core::atomic::AtomicF64>,
    /// Log of the total mass accumulated in each `prev` row.
    rowsums: Vec<salmon_core::atomic::AtomicF64>,
}
impl SharedTransMatrix {
    /// Builds a matrix in which every transition already holds mass `alpha`:
    /// cells start at `ln(alpha)`, row totals at `ln(NUM_ALN_STATES * alpha)`.
    fn new(alpha: f64) -> Self {
        let cell = alpha.ln();
        let row_total = (NUM_ALN_STATES as f64 * alpha).ln();
        Self {
            storage: std::iter::repeat_with(|| salmon_core::atomic::AtomicF64::new(cell))
                .take(NUM_ALN_STATES * NUM_ALN_STATES)
                .collect(),
            rowsums: std::iter::repeat_with(|| salmon_core::atomic::AtomicF64::new(row_total))
                .take(NUM_ALN_STATES)
                .collect(),
        }
    }

    /// Returns the cell for `prev -> cur` minus the row total for `prev`.
    ///
    /// The two values are read with two separate `load` calls.
    #[inline]
    fn get(&self, prev: usize, cur: usize) -> f64 {
        let cell = self.storage[prev * NUM_ALN_STATES + cur].load();
        let row_total = self.rowsums[prev].load();
        cell - row_total
    }

    /// Adds the mass held in the local accumulator `delta` to this matrix.
    ///
    /// Cells are processed first, then row totals; entries of `delta` equal
    /// to `-inf` (no mass) are skipped.
    fn flush_from(&self, delta: &TransMatrix) {
        let cells = self.storage.iter().zip(delta.storage.iter());
        let totals = self.rowsums.iter().zip(delta.rowsums.iter());
        for (slot, &amount) in cells.chain(totals) {
            if amount == f64::NEG_INFINITY {
                continue;
            }
            slot.log_add_assign(amount);
        }
    }
}
/// Counterpart of `AlignmentModel` built from `SharedTransMatrix`, whose
/// methods all take `&self`.
pub struct SharedAlignmentModel {
    /// Matrices used when a call passes `is_left == true`, indexed by bin.
    left: Vec<SharedTransMatrix>,
    /// Matrices used when a call passes `is_left == false`, indexed by bin.
    right: Vec<SharedTransMatrix>,
    /// Number of bins the read length is divided into.
    read_bins: usize,
}
impl SharedAlignmentModel {
pub fn new(alpha: f64, read_bins: usize) -> Self {
Self {
left: (0..read_bins)
.map(|_| SharedTransMatrix::new(alpha))
.collect(),
right: (0..read_bins)
.map(|_| SharedTransMatrix::new(alpha))
.collect(),
read_bins,
}
}
pub fn log_likelihood(
&self,
read_2bit: &[u8],
ref_bytes: &[u8],
pos: usize,
ops: &[(AlnOp, usize)],
is_left: bool,
) -> (f64, f64) {
let mats = if is_left { &self.left } else { &self.right };
let mut fg = 0.0;
let mut bg = 0.0;
AlignmentModel::walk(
self.read_bins,
read_2bit,
ref_bytes,
pos,
ops,
|bin, prev, cur| {
fg += mats[bin].get(prev, cur);
bg += mats[bin].get(0, 0);
},
);
(fg, bg)
}
pub fn flush_from(&self, delta: &AlignmentModel) {
for (s, d) in self.left.iter().zip(&delta.left) {
s.flush_from(d);
}
for (s, d) in self.right.iter().zip(&delta.right) {
s.flush_from(d);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: the read codes `[0, 1, 2, 3, 0]` aligned against the
    /// reference `ACGTA` with a single `Match` run of length 5.
    fn perfect() -> (Vec<u8>, Vec<u8>, Vec<(AlnOp, usize)>) {
        let read: Vec<u8> = vec![0, 1, 2, 3, 0];
        let reference = b"ACGTA".to_vec();
        let ops = vec![(AlnOp::Match, 5)];
        (read, reference, ops)
    }

    /// After repeatedly training on the perfect alignment, it must score
    /// higher than the same alignment with one read base changed.
    #[test]
    fn matches_score_higher_than_mismatches_after_training() {
        let (read, refs, ops) = perfect();
        let mut model = AlignmentModel::new(1.0, 4);
        (0..200).for_each(|_| model.update(&read, &refs, 0, &ops, true, 0.0));

        // Same read, but with the third base replaced by code 0.
        let mut bad = read.clone();
        bad[2] = 0;

        let (fg_good, bg_good) = model.log_likelihood(&read, &refs, 0, &ops, true);
        let (fg_bad, bg_bad) = model.log_likelihood(&bad, &refs, 0, &ops, true);

        assert!(fg_good > fg_bad, "fg_good {fg_good} !> fg_bad {fg_bad}");
        assert!(
            (fg_good - bg_good) > (fg_bad - bg_bad),
            "perfect score {} !> mismatch score {}",
            fg_good - bg_good,
            fg_bad - bg_bad
        );
    }

    /// A freshly built model gives every transition the same value, so the
    /// foreground and background scores must coincide.
    #[test]
    fn untrained_model_is_neutral() {
        let (read, refs, ops) = perfect();
        let model = AlignmentModel::new(1.0, 4);
        let (fg, bg) = model.log_likelihood(&read, &refs, 0, &ops, true);
        assert!((fg - bg).abs() < 1e-9, "untrained score {} not 0", fg - bg);
    }
}