use crate::alignment::gotoh::GotohScoring;
use crate::error::{SeqError, SeqResult};
/// One step of an alignment between sequences `a` and `b`, as emitted in
/// [`WfaAlignment::cigar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WfaOp {
/// Advances both sequences by one symbol along a run of equal symbols.
Match,
/// Advances both sequences by one symbol and is charged the mismatch penalty.
Mismatch,
/// Advances `b` only (a symbol of `b` aligned against a gap).
Ins,
/// Advances `a` only (a symbol of `a` aligned against a gap).
Del,
}
/// Result of [`wfa_align`].
#[derive(Debug, Clone)]
pub struct WfaAlignment {
/// Alignment score in the caller's Gotoh scoring scheme, recovered from
/// `penalty` as `((len(a) + len(b)) * match_score - penalty) / 2`.
pub score: i32,
/// WFA penalty at which the end of both sequences was first reached.
pub penalty: i32,
/// Alignment operations ordered from the start of the sequences to the end.
pub cigar: Vec<WfaOp>,
}
/// Gap-affine penalties used internally by the wavefront recurrences.
///
/// Values are derived from a `GotohScoring` by `WfaPenalties::from_gotoh`.
#[derive(Debug, Clone, Copy)]
struct WfaPenalties {
/// Mismatch penalty (validated to be strictly positive).
x: i32,
/// Gap-open penalty (validated to be non-negative).
o: i32,
/// Gap-extend penalty (validated to be strictly positive).
e: i32,
}
impl WfaPenalties {
    /// Converts a Gotoh score scheme into the equivalent WFA penalty triple.
    ///
    /// The penalties are `x = 2 * (match_score - mismatch)`,
    /// `o = 2 * (gap_extend - gap_open)` and `e = match_score - 2 * gap_extend`.
    ///
    /// # Errors
    ///
    /// Returns [`SeqError::InvalidConfiguration`] when `x <= 0`, `o < 0` or
    /// `e <= 0`. When several constraints fail, the first one in that order
    /// is reported.
    fn from_gotoh(sc: &GotohScoring) -> SeqResult<Self> {
        let reject = |reason: String| -> SeqResult<Self> {
            Err(SeqError::InvalidConfiguration(reason))
        };

        let x = 2 * (sc.match_score - sc.mismatch);
        let o = 2 * (sc.gap_extend - sc.gap_open);
        let e = sc.match_score - 2 * sc.gap_extend;

        // The tuple order mirrors the reporting priority: x, then o, then e.
        match (x > 0, o >= 0, e > 0) {
            (false, _, _) => reject(format!(
                "WFA requires a positive mismatch penalty (match_score must exceed mismatch); \
                 derived x = {x}"
            )),
            (true, false, _) => reject(format!(
                "WFA requires a non-negative gap-open penalty (gap_extend must be >= gap_open); \
                 derived o = {o}"
            )),
            (true, true, false) => reject(format!(
                "WFA requires a positive gap-extend penalty (match_score must exceed 2*gap_extend); \
                 derived e = {e}"
            )),
            (true, true, true) => Ok(Self { x, o, e }),
        }
    }
}
/// Sentinel offset meaning "this diagonal has not been reached".
const NIL: i32 = i32::MIN;

/// Offsets stored per diagonal for the inclusive diagonal range `lo..=hi`.
///
/// `offsets[k - lo]` holds the value for diagonal `k`, or [`NIL`] when the
/// diagonal is unreached. A wavefront built with `hi < lo` is empty.
#[derive(Debug, Clone)]
struct Wavefront {
    lo: i32,
    hi: i32,
    offsets: Vec<i32>,
}

impl Wavefront {
    /// Creates a wavefront covering `lo..=hi` with every diagonal set to [`NIL`].
    fn new(lo: i32, hi: i32) -> Self {
        let len = if lo > hi { 0 } else { (hi - lo + 1) as usize };
        Self {
            lo,
            hi,
            offsets: vec![NIL; len],
        }
    }

    /// Maps diagonal `k` to its index in `offsets`, or `None` when out of range.
    #[inline]
    fn slot(&self, k: i32) -> Option<usize> {
        if (self.lo..=self.hi).contains(&k) {
            Some((k - self.lo) as usize)
        } else {
            None
        }
    }

    /// Returns the offset stored for diagonal `k`, or [`NIL`] outside the range.
    #[inline]
    fn get(&self, k: i32) -> i32 {
        match self.slot(k) {
            Some(idx) => self.offsets[idx],
            None => NIL,
        }
    }

    /// Stores `v` for diagonal `k`; writes outside the range are ignored.
    #[inline]
    fn set(&mut self, k: i32, v: i32) {
        if let Some(idx) = self.slot(k) {
            self.offsets[idx] = v;
        }
    }
}
/// The three wavefronts computed for a single penalty value.
#[derive(Debug, Clone)]
struct WfSet {
/// Main wavefront: best of mismatch, insertion and deletion, then extended
/// along runs of equal symbols.
m: Wavefront,
/// Wavefront of paths whose last step is an insertion (gap opened or extended).
i: Wavefront,
/// Wavefront of paths whose last step is a deletion (gap opened or extended).
d: Wavefront,
}
/// Globally aligns `a` against `b` with the wavefront algorithm under the
/// gap-affine penalties derived from `sc`.
///
/// Wavefronts are computed for increasing penalties until the main wavefront
/// reaches the end of `a` on the diagonal `len(a) - len(b)`; the alignment is
/// then recovered by traceback through the stored wavefronts.
///
/// # Errors
///
/// * [`SeqError::EmptyInput`] when either sequence is empty.
/// * [`SeqError::InvalidConfiguration`] when `sc` cannot be converted into
///   valid WFA penalties, or when a sequence is longer than `i32::MAX`
///   symbols (offsets and diagonals are stored as `i32`).
/// * [`SeqError::NumericalInstability`] when the endpoint is not reached
///   within the worst-case penalty bound.
pub fn wfa_align(a: &[u8], b: &[u8], sc: &GotohScoring) -> SeqResult<WfaAlignment> {
    let m = a.len();
    let n = b.len();
    if m == 0 || n == 0 {
        return Err(SeqError::EmptyInput);
    }
    let pen = WfaPenalties::from_gotoh(sc)?;

    // Offsets and diagonals are `i32`; reject lengths the casts below would
    // silently truncate instead of aligning garbage coordinates.
    let limit = i32::MAX as usize;
    if m > limit || n > limit {
        return Err(SeqError::InvalidConfiguration(format!(
            "WFA supports sequences of at most {limit} symbols; got lengths {m} and {n}"
        )));
    }
    let m_i = m as i32;
    let n_i = n as i32;
    let k_final = m_i - n_i;
    let a_off_max = m_i;

    // Loose upper bound on the optimal penalty. It is evaluated in `i64` with
    // saturation and clamped to `i32::MAX`, so long inputs cannot overflow it
    // into a negative bound that would abort the search immediately.
    let per_symbol = i64::from(pen.x) + i64::from(pen.o) + i64::from(pen.e);
    let bound = (i64::from(m_i) + i64::from(n_i))
        .saturating_mul(per_symbol)
        .saturating_add(i64::from(pen.o) + i64::from(pen.e));
    let max_pen = bound.min(i64::from(i32::MAX)) as i32;

    // `history[s]` holds the wavefront set for penalty `s`, for every penalty
    // below the one currently being examined.
    let mut history: Vec<WfSet> = Vec::new();

    // Penalty 0: only diagonal 0 is reachable, by matching from the origin.
    let mut m_wf = Wavefront::new(0, 0);
    m_wf.set(0, 0);
    extend(&mut m_wf, a, b);
    let mut set = WfSet {
        m: m_wf,
        i: Wavefront::new(0, -1),
        d: Wavefront::new(0, -1),
    };

    let mut s = 0i32;
    loop {
        if reached(&set.m, k_final, a_off_max) {
            let cigar = traceback(&history, &set, &pen, k_final);
            return finish(s, m, n, sc, cigar);
        }
        history.push(set);
        if s >= max_pen {
            return Err(SeqError::NumericalInstability(
                "WFA failed to reach the alignment endpoint within the penalty bound".into(),
            ));
        }
        // Cannot overflow: `s < max_pen <= i32::MAX` here.
        s += 1;
        set = compute_next(&history, s, &pen, k_final, a, b);
    }
}
/// Extends every reached diagonal of `m_wf` along its run of equal symbols.
///
/// An offset is an index into `a`; on diagonal `k` the matching index into `b`
/// is `offset - k`. Each non-`NIL` offset is advanced while both indices are
/// in bounds and the symbols are equal, and the result is written back.
fn extend(m_wf: &mut Wavefront, a: &[u8], b: &[u8]) {
    let a_len = a.len() as i32;
    let b_len = b.len() as i32;
    for diag in m_wf.lo..=m_wf.hi {
        let start = m_wf.get(diag);
        if start == NIL {
            continue;
        }
        let mut reach = start;
        loop {
            let col = reach - diag;
            let inside = reach < a_len && col >= 0 && col < b_len;
            if !inside || a[reach as usize] != b[col as usize] {
                break;
            }
            reach += 1;
        }
        m_wf.set(diag, reach);
    }
}
/// Returns `true` when the offset stored on diagonal `k_final` of `m_wf` is at
/// least `a_off_max`. An unreached (`NIL`) diagonal compares as the smallest
/// possible offset.
#[inline]
fn reached(m_wf: &Wavefront, k_final: i32, a_off_max: i32) -> bool {
    let furthest = m_wf.get(k_final);
    a_off_max <= furthest
}
/// Computes the wavefront set for penalty `s` from the sets stored in `history`.
///
/// For each diagonal `k` in the new range:
/// * `I[k]` is the best of `M[s - o - e][k + 1]` and `I[s - e][k + 1]`;
/// * `D[k]` is the best of `M[s - o - e][k - 1]` and `D[s - e][k - 1]`, plus one;
/// * `M[k]` is the best of `M[s - x][k] + 1`, `I[k]` and `D[k]`.
///
/// The M wavefront is then extended along matching symbols of `a` and `b`.
/// The diagonal range always contains `k_final` and covers every source
/// wavefront, widened by one diagonal for the gap sources.
fn compute_next(
    history: &[WfSet],
    s: i32,
    pen: &WfaPenalties,
    k_final: i32,
    a: &[u8],
    b: &[u8],
) -> WfSet {
    let mismatch_s = s - pen.x;
    let open_s = s - pen.o - pen.e;
    let extend_s = s - pen.e;

    // Diagonal range: start from the target diagonal and grow to cover every
    // predecessor set (gap moves shift the diagonal by one, hence `grow`).
    let mut lo = k_final;
    let mut hi = k_final;
    for (src, grow) in [(mismatch_s, 0i32), (open_s, 1), (extend_s, 1)] {
        if src < 0 {
            continue;
        }
        if let Some(set) = history.get(src as usize) {
            for wf in [&set.m, &set.i, &set.d] {
                lo = lo.min(wf.lo - grow);
                hi = hi.max(wf.hi + grow);
            }
        }
    }

    let mis_m = history.get_at(mismatch_s, |set| &set.m);
    let open_m = history.get_at(open_s, |set| &set.m);
    let ext_i = history.get_at(extend_s, |set| &set.i);
    let ext_d = history.get_at(extend_s, |set| &set.d);

    // Advancing an unreached diagonal keeps it unreached.
    let advance = |v: i32| if v == NIL { NIL } else { v + 1 };

    let mut next = WfSet {
        m: Wavefront::new(lo, hi),
        i: Wavefront::new(lo, hi),
        d: Wavefront::new(lo, hi),
    };
    for k in lo..=hi {
        let ins = max2(opt_get(open_m, k + 1), opt_get(ext_i, k + 1));
        let del = advance(max2(opt_get(open_m, k - 1), opt_get(ext_d, k - 1)));
        let sub = advance(opt_get(mis_m, k));
        next.i.set(k, ins);
        next.d.set(k, del);
        next.m.set(k, max3(sub, ins, del));
    }
    extend(&mut next.m, a, b);
    next
}
/// Returns the larger of two offsets, treating `NIL` as "no value": if one
/// argument is `NIL` the other is returned, and `NIL` only if both are.
#[inline]
fn max2(a: i32, b: i32) -> i32 {
    match (a, b) {
        (NIL, other) | (other, NIL) => other,
        _ => a.max(b),
    }
}
/// Three-way variant of `max2`: the largest non-`NIL` offset, or `NIL` when
/// all three are `NIL`.
#[inline]
fn max3(a: i32, b: i32, c: i32) -> i32 {
    let best_ab = max2(a, b);
    max2(best_ab, c)
}
/// Reads diagonal `k` from an optional wavefront; a missing wavefront reads
/// as `NIL` on every diagonal.
#[inline]
fn opt_get(wf: Option<&Wavefront>, k: i32) -> i32 {
    wf.map_or(NIL, |w| w.get(k))
}
/// Score-indexed lookup into the stored wavefront sets.
trait HistoryExt {
/// Returns the wavefront selected by `f` from the set stored for score `s`,
/// or `None` when no set is stored for that score.
fn get_at<'a, F>(&'a self, s: i32, f: F) -> Option<&'a Wavefront>
where
F: Fn(&'a WfSet) -> &'a Wavefront;
}
impl HistoryExt for [WfSet] {
    /// Treats the slice index as the score: negative scores and scores past
    /// the end of the slice yield `None`; otherwise `f` picks the wavefront
    /// out of the stored set.
    #[inline]
    fn get_at<'a, F>(&'a self, s: i32, f: F) -> Option<&'a Wavefront>
    where
        F: Fn(&'a WfSet) -> &'a Wavefront,
    {
        if s < 0 {
            return None;
        }
        self.get(s as usize).map(f)
    }
}
fn finish(
penalty: i32,
m: usize,
n: usize,
sc: &GotohScoring,
cigar: Vec<WfaOp>,
) -> SeqResult<WfaAlignment> {
let score = ((m as i32 + n as i32) * sc.match_score - penalty) / 2;
Ok(WfaAlignment {
score,
penalty,
cigar,
})
}
/// Wavefront component the traceback is currently walking through.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Comp {
/// Main (match/mismatch) wavefront.
M,
/// Insertion wavefront.
I,
/// Deletion wavefront.
D,
}
/// How an offset on the M wavefront was produced, as classified by
/// `Tracer::m_origin`.
enum MOrigin {
/// No predecessor candidate exists; the traceback stops here.
Start,
/// The offset lies past the best predecessor candidate, so the difference
/// was covered by matches; `target_off` is that candidate's offset.
Match { target_off: i32 },
/// Reached by a mismatch from offset `prev_off` on the same diagonal of the
/// M wavefront at score `prev_s`.
Mismatch { prev_s: i32, prev_off: i32 },
/// Equal to the I wavefront value at the same score and diagonal.
FromI,
/// Equal to the D wavefront value at the same score and diagonal.
FromD,
}
/// Read-only state needed to walk the stored wavefronts backwards.
struct Tracer<'a> {
/// Wavefront sets for the scores below `s_opt`, indexed by score.
history: &'a [WfSet],
/// Wavefront set at score `s_opt` (not stored in `history`).
final_set: &'a WfSet,
/// Penalties used to locate predecessor scores.
pen: WfaPenalties,
/// Score of `final_set`; set by `Tracer::new` to `history.len()`.
s_opt: i32,
}
impl<'a> Tracer<'a> {
fn new(history: &'a [WfSet], final_set: &'a WfSet, pen: WfaPenalties) -> Self {
Self {
history,
final_set,
pen,
s_opt: history.len() as i32,
}
}
fn get_set(&self, sp: i32) -> Option<&'a WfSet> {
if sp == self.s_opt {
Some(self.final_set)
} else if sp >= 0 {
self.history.get(sp as usize)
} else {
None
}
}
fn m_origin(&self, s: i32, k: i32, off: i32) -> MOrigin {
let s_x = s - self.pen.x;
let mis_pred = self.get_set(s_x).map(|set| set.m.get(k)).unwrap_or(NIL);
let mis_bare = if mis_pred == NIL { NIL } else { mis_pred + 1 };
let i_here = self.get_set(s).map(|set| set.i.get(k)).unwrap_or(NIL);
let d_here = self.get_set(s).map(|set| set.d.get(k)).unwrap_or(NIL);
let mut best_bare = NIL;
let mut kind = 0u8; for (cand, kd) in [(mis_bare, 1u8), (i_here, 2), (d_here, 3)] {
if cand != NIL && cand <= off && cand > best_bare {
best_bare = cand;
kind = kd;
}
}
if best_bare == NIL {
return MOrigin::Start;
}
if best_bare < off {
return MOrigin::Match {
target_off: best_bare,
};
}
match kind {
1 => MOrigin::Mismatch {
prev_s: s_x,
prev_off: mis_pred,
},
2 => MOrigin::FromI,
_ => MOrigin::FromD,
}
}
fn run(&self, k_final: i32) -> Vec<WfaOp> {
let mut ops: Vec<WfaOp> = Vec::new();
let mut s = self.s_opt;
let mut k = k_final;
let mut comp = Comp::M;
let mut off = self.final_set.m.get(k);
loop {
match comp {
Comp::M => match self.m_origin(s, k, off) {
MOrigin::Start => {
for _ in 0..off.max(0) {
ops.push(WfaOp::Match);
}
break;
}
MOrigin::Match { target_off } => {
let mut cur = off;
while cur > target_off {
ops.push(WfaOp::Match);
cur -= 1;
}
off = target_off;
}
MOrigin::Mismatch { prev_s, prev_off } => {
ops.push(WfaOp::Mismatch);
s = prev_s;
off = prev_off;
}
MOrigin::FromI => {
comp = Comp::I;
if let Some(set) = self.get_set(s) {
off = set.i.get(k);
}
}
MOrigin::FromD => {
comp = Comp::D;
if let Some(set) = self.get_set(s) {
off = set.d.get(k);
}
}
},
Comp::I => {
ops.push(WfaOp::Ins);
let s_o_e = s - self.pen.o - self.pen.e;
let s_e = s - self.pen.e;
let open = self
.get_set(s_o_e)
.map(|set| set.m.get(k + 1))
.unwrap_or(NIL);
let ext = self.get_set(s_e).map(|set| set.i.get(k + 1)).unwrap_or(NIL);
if ext != NIL && ext == off {
s = s_e;
k += 1;
comp = Comp::I;
} else if open != NIL && open == off {
s = s_o_e;
k += 1;
comp = Comp::M;
} else if open != NIL {
s = s_o_e;
k += 1;
comp = Comp::M;
off = open;
} else if ext != NIL {
s = s_e;
k += 1;
comp = Comp::I;
off = ext;
} else {
break;
}
}
Comp::D => {
ops.push(WfaOp::Del);
let s_o_e = s - self.pen.o - self.pen.e;
let s_e = s - self.pen.e;
let open = self
.get_set(s_o_e)
.map(|set| set.m.get(k - 1))
.unwrap_or(NIL);
let ext = self.get_set(s_e).map(|set| set.d.get(k - 1)).unwrap_or(NIL);
let pred_off = off - 1;
if ext != NIL && ext == pred_off {
s = s_e;
k -= 1;
off = pred_off;
comp = Comp::D;
} else if open != NIL && open == pred_off {
s = s_o_e;
k -= 1;
off = pred_off;
comp = Comp::M;
} else if open != NIL {
s = s_o_e;
k -= 1;
off = open;
comp = Comp::M;
} else if ext != NIL {
s = s_e;
k -= 1;
off = ext;
comp = Comp::D;
} else {
break;
}
}
}
}
ops.reverse();
ops
}
}
/// Recovers the alignment operations ending on diagonal `k_final` of
/// `final_set`, using `history` for the lower-score wavefronts.
fn traceback(history: &[WfSet], final_set: &WfSet, pen: &WfaPenalties, k_final: i32) -> Vec<WfaOp> {
    let tracer = Tracer::new(history, final_set, *pen);
    tracer.run(k_final)
}
// Unit tests: WFA results are cross-checked against the Gotoh aligner and
// against an independent re-scoring of the returned cigar.
#[cfg(test)]
mod tests {
use super::*;
use crate::alignment::gotoh::gotoh_align;
// Scoring scheme provided by `GotohScoring::default()`.
fn default_sc() -> GotohScoring {
GotohScoring::default()
}
// Alternative scheme: match 3, mismatch -2, gap open -6, gap extend -2.
fn custom_sc() -> GotohScoring {
GotohScoring {
match_score: 3,
mismatch: -2,
gap_open: -6,
gap_extend: -2,
}
}
// Re-scores a cigar directly in the Gotoh scheme: the first symbol of a gap
// run costs `gap_open`, each following symbol of the same gap type costs
// `gap_extend`. Also asserts that the cigar consumes both sequences fully.
fn score_cigar(a: &[u8], b: &[u8], cigar: &[WfaOp], sc: &GotohScoring) -> i32 {
let mut score = 0i32;
let mut i = 0usize;
let mut j = 0usize;
let mut prev: Option<WfaOp> = None;
for &op in cigar {
match op {
WfaOp::Match => {
score += sc.match_score;
i += 1;
j += 1;
}
WfaOp::Mismatch => {
score += sc.mismatch;
i += 1;
j += 1;
}
WfaOp::Ins => {
if prev == Some(WfaOp::Ins) {
score += sc.gap_extend;
} else {
score += sc.gap_open;
}
j += 1;
}
WfaOp::Del => {
if prev == Some(WfaOp::Del) {
score += sc.gap_extend;
} else {
score += sc.gap_open;
}
i += 1;
}
}
prev = Some(op);
}
assert_eq!(i, a.len(), "cigar must consume all of a");
assert_eq!(j, b.len(), "cigar must consume all of b");
score
}
// Asserts by counting operations that the cigar consumes exactly `a.len()`
// symbols of `a` and `b.len()` symbols of `b`.
fn check_consumption(a: &[u8], b: &[u8], cigar: &[WfaOp]) {
let consumes_a = cigar
.iter()
.filter(|o| matches!(o, WfaOp::Match | WfaOp::Mismatch | WfaOp::Del))
.count();
let consumes_b = cigar
.iter()
.filter(|o| matches!(o, WfaOp::Match | WfaOp::Mismatch | WfaOp::Ins))
.count();
assert_eq!(consumes_a, a.len(), "Match+Mismatch+Del must consume a");
assert_eq!(consumes_b, b.len(), "Match+Mismatch+Ins must consume b");
}
// Fixed pairs under both schemes: the WFA score must equal the Gotoh score
// and the cigar must re-score to the same value.
#[test]
fn central_cross_check_matches_gotoh() {
let pairs: &[(&[u8], &[u8])] = &[
(b"GATTACA", b"GCATGCU"),
(b"ACGTACGT", b"ACGTTCGT"),
(b"AAAA", b"AAAAGGGGAAAA"),
(b"ACGT", b"TGCA"),
(b"AGGGCT", b"AGGCT"),
(b"HELLOWORLD", b"HELOWRLD"),
];
for sc in [default_sc(), custom_sc()] {
for &(a, b) in pairs {
let w = wfa_align(a, b, &sc).expect("wfa ok");
let g = gotoh_align(a, b, &sc).expect("gotoh ok");
assert_eq!(
w.score,
g.score,
"score mismatch on {:?} vs {:?} with {:?}",
std::str::from_utf8(a),
std::str::from_utf8(b),
sc
);
check_consumption(a, b, &w.cigar);
assert_eq!(
score_cigar(a, b, &w.cigar, &sc),
w.score,
"cigar re-score mismatch on {:?} vs {:?}",
std::str::from_utf8(a),
std::str::from_utf8(b),
);
}
}
}
// Aligning a sequence with itself: zero penalty and an all-Match cigar.
#[test]
fn identical_sequences() {
let a = b"ACGTACGT";
let sc = default_sc();
let w = wfa_align(a, a, &sc).expect("ok");
assert_eq!(w.penalty, 0);
assert_eq!(w.score, sc.match_score * a.len() as i32);
assert!(w.cigar.iter().all(|o| *o == WfaOp::Match));
assert_eq!(w.cigar.len(), a.len());
}
// The cigar must be a complete alignment whose re-score equals the reported
// score, which in turn must equal the Gotoh score.
#[test]
fn traceback_validity() {
let sc = default_sc();
let cases: &[(&[u8], &[u8])] = &[
(b"GATTACA", b"GCATGCU"),
(b"ACGTACGTACGT", b"ACGTTTACGT"),
(b"BANANA", b"ANANAS"),
];
for &(a, b) in cases {
let w = wfa_align(a, b, &sc).expect("ok");
check_consumption(a, b, &w.cigar);
assert_eq!(score_cigar(a, b, &w.cigar, &sc), w.score);
let g = gotoh_align(a, b, &sc).expect("ok");
assert_eq!(w.score, g.score);
}
}
// `b` is `a` with "GGGG" inserted in the middle: expect exactly four Ins
// operations forming one contiguous run and no Del.
#[test]
fn affine_single_long_gap() {
let sc = default_sc();
let a = b"ACGTACGT";
let b = b"ACGTGGGGACGT";
let w = wfa_align(a, b, &sc).expect("ok");
let g = gotoh_align(a, b, &sc).expect("ok");
assert_eq!(w.score, g.score);
let ins = w.cigar.iter().filter(|o| **o == WfaOp::Ins).count();
let del = w.cigar.iter().filter(|o| **o == WfaOp::Del).count();
assert_eq!(ins, 4, "expected 4 inserted symbols, cigar = {:?}", w.cigar);
assert_eq!(del, 0);
let runs = count_runs(&w.cigar, WfaOp::Ins);
assert_eq!(runs, 1, "Ins must form a single run, cigar = {:?}", w.cigar);
}
// Counts maximal contiguous runs of `op` in `cigar`.
fn count_runs(cigar: &[WfaOp], op: WfaOp) -> usize {
let mut runs = 0;
let mut in_run = false;
for &c in cigar {
if c == op {
if !in_run {
runs += 1;
in_run = true;
}
} else {
in_run = false;
}
}
runs
}
// One substitution: the score is all matches but one mismatch, and the
// penalty equals the derived mismatch penalty 2 * (match_score - mismatch).
#[test]
fn single_mismatch_cost() {
let sc = default_sc();
let a = b"ACGTACGT";
let b = b"ACGTTCGT"; let w = wfa_align(a, b, &sc).expect("ok");
let g = gotoh_align(a, b, &sc).expect("ok");
assert_eq!(w.score, g.score);
let len = a.len() as i32;
assert_eq!(w.score, sc.match_score * (len - 1) + sc.mismatch);
assert_eq!(w.penalty, 2 * (sc.match_score - sc.mismatch));
}
// Both aligners must reject an empty sequence on either side.
#[test]
fn empty_sequence_errors() {
let sc = default_sc();
assert!(matches!(
wfa_align(b"", b"ACGT", &sc),
Err(SeqError::EmptyInput)
));
assert!(matches!(
wfa_align(b"ACGT", b"", &sc),
Err(SeqError::EmptyInput)
));
assert!(matches!(
gotoh_align(b"", b"ACGT", &sc),
Err(SeqError::EmptyInput)
));
assert!(matches!(
gotoh_align(b"ACGT", b"", &sc),
Err(SeqError::EmptyInput)
));
}
// A 50-symbol shared prefix followed by a tail with one substitution.
#[test]
fn long_match_extension() {
let sc = default_sc();
let prefix = vec![b'A'; 50];
let mut a = prefix.clone();
a.extend_from_slice(b"CGTACG");
let mut b = prefix.clone();
b.extend_from_slice(b"CTTACG"); let w = wfa_align(&a, &b, &sc).expect("ok");
let g = gotoh_align(&a, &b, &sc).expect("ok");
assert_eq!(w.score, g.score);
assert!(w.penalty > 0);
check_consumption(&a, &b, &w.cigar);
assert_eq!(score_cigar(&a, &b, &w.cigar, &sc), w.score);
}
// Schemes that do not yield valid WFA penalties must be rejected:
// `bad` gives x = 0 (match == mismatch); `bad_open` gives o < 0
// (gap_extend below gap_open).
#[test]
fn degenerate_scoring_rejected() {
let bad = GotohScoring {
match_score: 1,
mismatch: 1,
gap_open: -5,
gap_extend: -1,
};
assert!(matches!(
wfa_align(b"AC", b"AG", &bad),
Err(SeqError::InvalidConfiguration(_))
));
let bad_open = GotohScoring {
match_score: 2,
mismatch: -1,
gap_open: -1,
gap_extend: -5,
};
assert!(matches!(
wfa_align(b"AC", b"AG", &bad_open),
Err(SeqError::InvalidConfiguration(_))
));
}
// Randomized cross-check: for five schemes, 120 random pairs over "ACGT"
// each, generated from a fixed-seed `LcgRng` so the run is reproducible.
#[test]
fn randomized_cross_check_matches_gotoh() {
use crate::handle::LcgRng;
let alphabet = b"ACGT";
let schemes = [
GotohScoring::default(),
GotohScoring {
match_score: 3,
mismatch: -2,
gap_open: -6,
gap_extend: -2,
},
GotohScoring {
match_score: 1,
mismatch: -1,
gap_open: -2,
gap_extend: -1,
},
GotohScoring {
match_score: 4,
mismatch: -3,
gap_open: -8,
gap_extend: -1,
},
GotohScoring {
match_score: 2,
mismatch: 0,
gap_open: -4,
gap_extend: -1,
},
];
let mut rng = LcgRng::new(0x5EED_1234_ABCD);
for sc in schemes {
// Every scheme in the list must convert to valid WFA penalties.
assert!(WfaPenalties::from_gotoh(&sc).is_ok());
for _ in 0..120 {
let la = 1 + rng.next_usize(14);
let lb = 1 + rng.next_usize(14);
let a: Vec<u8> = (0..la).map(|_| alphabet[rng.next_usize(4)]).collect();
let b: Vec<u8> = (0..lb).map(|_| alphabet[rng.next_usize(4)]).collect();
let w = wfa_align(&a, &b, &sc).expect("wfa ok");
let g = gotoh_align(&a, &b, &sc).expect("gotoh ok");
assert_eq!(
w.score,
g.score,
"score mismatch: a={:?} b={:?} sc={:?} (wfa={} gotoh={})",
std::str::from_utf8(&a),
std::str::from_utf8(&b),
sc,
w.score,
g.score,
);
check_consumption(&a, &b, &w.cigar);
assert_eq!(
score_cigar(&a, &b, &w.cigar, &sc),
w.score,
"cigar re-score mismatch: a={:?} b={:?}",
std::str::from_utf8(&a),
std::str::from_utf8(&b),
);
}
}
}
// Pairs of very different lengths (long gaps in either direction) under the
// custom scheme.
#[test]
fn asymmetric_gaps_match_gotoh() {
let sc = custom_sc();
let cases: &[(&[u8], &[u8])] = &[
(b"AAAAGGGGAAAA", b"AAAA"), (b"AAAA", b"AAAAGGGGAAAA"), (b"ACGTACGTACGT", b"ACGT"), (b"ACGT", b"ACGTACGTACGT"), (b"TTTTACGTTTTT", b"ACGT"), (b"GATTACAGATTACA", b"GATTACA"), ];
for &(a, b) in cases {
let w = wfa_align(a, b, &sc).expect("ok");
let g = gotoh_align(a, b, &sc).expect("ok");
assert_eq!(
w.score,
g.score,
"mismatch on {:?} vs {:?}",
std::str::from_utf8(a),
std::str::from_utf8(b),
);
check_consumption(a, b, &w.cigar);
assert_eq!(score_cigar(a, b, &w.cigar, &sc), w.score);
}
}
}