Skip to main content

oxicuda_seq/align/
wfa.rs

1//! WFA — the Wavefront Alignment algorithm for gap-affine **global** alignment.
2//!
3//! This is a faithful implementation of the algorithm of Marco-Sola, Moure,
4//! Moreto & Espinosa, *"Fast gap-affine pairwise alignment using the wavefront
5//! algorithm"*, Bioinformatics 37(4):456–463, 2021.
6//!
7//! # Model
8//!
9//! WFA **minimizes** an alignment *penalty* under the gap-affine model
10//!
11//! ```text
12//! match     → 0
13//! mismatch  → x          (x > 0)
14//! gap run   → o + k·e    (gap-open o paid once per run, plus e per gap symbol
15//!                          INCLUDING the first; o ≥ 0, e > 0)
16//! ```
17//!
18//! Instead of filling an `(m+1)·(n+1)` dynamic-programming matrix, WFA tracks,
19//! for each increasing penalty `s`, the *furthest-reaching* point on every
20//! diagonal `k = i − j` (the "wavefront"). On similar sequences only a narrow
21//! band of diagonals is ever touched, which is what makes the algorithm fast.
22//!
23//! Three wavefront components are maintained per penalty `s`:
24//!
25//! * `M` — the match / substitution path (the alignment proper),
26//! * `I` — the *insertion* path (a gap in `a`, consuming a character of `b`),
27//! * `D` — the *deletion* path (a gap in `b`, consuming a character of `a`).
28//!
29//! # Convention
30//!
31//! We align `a` along the rows (index `i`, length `m`) and `b` along the
32//! columns (index `j`, length `n`). A diagonal is `k = i − j`. The *offset*
33//! stored on a diagonal is `i`, the number of characters of `a` consumed, so
34//! the matching column is `j = i − k`. The optimum is reached when the `M`
35//! wavefront on the final diagonal `k_final = m − n` attains offset `m`
36//! (equivalently `i = m`, `j = n`).
37//!
38//! * [`WfaOp::Ins`] is a gap in `a`: it consumes one character of `b` only.
39//! * [`WfaOp::Del`] is a gap in `b`: it consumes one character of `a` only.
40//!
41//! # Cross-check with Gotoh
42//!
43//! [`crate::alignment::gotoh::gotoh_align`] solves the *same* problem but
44//! **maximizes** a score. Given a [`GotohScoring`] `(M, mis, go, ge)` we derive
45//! the (×2-scaled, integral) WFA penalties
46//!
47//! ```text
48//! x = 2·(M − mis)        // mismatch penalty
49//! o = 2·(ge − go)        // gap-open penalty
50//! e = M − 2·ge           // gap-extend penalty
51//! ```
52//!
53//! Run WFA to obtain the minimum penalty `P`; the equivalent Gotoh maximum
54//! score is then
55//!
56//! ```text
57//! gotoh_score = ((m + n)·M − P) / 2
58//! ```
59//!
60//! which is exact (the dividend `(m + n)·M − P` is always even). [`WfaAlignment`]
61//! reports both the raw `penalty` and the converted `score`.
62
63use crate::alignment::gotoh::GotohScoring;
64use crate::error::{SeqError, SeqResult};
65
/// A single edit operation of a WFA alignment, in left-to-right order.
///
/// `a` is the row sequence (index `i`) and `b` the column sequence (index
/// `j`). The convention is:
///
/// * [`WfaOp::Match`] / [`WfaOp::Mismatch`] consume one character of *both* `a`
///   and `b`.
/// * [`WfaOp::Ins`] is a gap in `a`; it consumes one character of `b` only.
/// * [`WfaOp::Del`] is a gap in `b`; it consumes one character of `a` only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WfaOp {
    /// Aligned, equal characters (`a[i] == b[j]`); advances both `i` and `j`.
    Match,
    /// Aligned, unequal characters (`a[i] != b[j]`); advances both `i` and `j`.
    Mismatch,
    /// Insertion: gap in `a`, consumes `b[j]` (advances `j` only).
    Ins,
    /// Deletion: gap in `b`, consumes `a[i]` (advances `i` only).
    Del,
}
85
/// Result of a WFA gap-affine global alignment.
///
/// `score` and `penalty` describe the same optimum in two units: `penalty` is
/// what the wavefront search minimizes, `score` is its conversion back to the
/// Gotoh scoring scale.
#[derive(Debug, Clone)]
pub struct WfaAlignment {
    /// The equivalent Gotoh **maximum** score (`((m+n)·M − penalty) / 2`),
    /// where `m` and `n` are the two sequence lengths and `M` is the match
    /// score.
    pub score: i32,
    /// The raw, ×2-scaled WFA **minimum** penalty.
    pub penalty: i32,
    /// The optimal alignment as a left-to-right list of edit operations.
    pub cigar: Vec<WfaOp>,
}
96
/// The ×2-scaled, integral gap-affine penalties derived from a [`GotohScoring`].
///
/// A match costs `0`; a gap run of `L` symbols costs `o + L·e`.
#[derive(Debug, Clone, Copy)]
struct WfaPenalties {
    /// Mismatch penalty `x` (cost of one substituted column).
    x: i32,
    /// Gap-open penalty `o` (paid once per gap run).
    o: i32,
    /// Gap-extend penalty `e` (paid per gap symbol, including the first).
    e: i32,
}
107
108impl WfaPenalties {
109    /// Derive the ×2-scaled WFA penalties from a Gotoh scoring scheme.
110    ///
111    /// Returns [`SeqError::InvalidConfiguration`] when the resulting affine
112    /// model is degenerate (the algorithm needs `x > 0`, `o ≥ 0`, `e > 0`).
113    fn from_gotoh(sc: &GotohScoring) -> SeqResult<Self> {
114        let x = 2 * (sc.match_score - sc.mismatch);
115        let o = 2 * (sc.gap_extend - sc.gap_open);
116        let e = sc.match_score - 2 * sc.gap_extend;
117        if x <= 0 {
118            return Err(SeqError::InvalidConfiguration(format!(
119                "WFA requires a positive mismatch penalty (match_score must exceed mismatch); \
120                 derived x = {x}"
121            )));
122        }
123        if o < 0 {
124            return Err(SeqError::InvalidConfiguration(format!(
125                "WFA requires a non-negative gap-open penalty (gap_extend must be >= gap_open); \
126                 derived o = {o}"
127            )));
128        }
129        if e <= 0 {
130            return Err(SeqError::InvalidConfiguration(format!(
131                "WFA requires a positive gap-extend penalty (match_score must exceed 2*gap_extend); \
132                 derived e = {e}"
133            )));
134        }
135        Ok(Self { x, o, e })
136    }
137}
138
/// Sentinel marking an unreachable diagonal/offset.
///
/// It is the smallest `i32`, so an ordered comparison such as
/// `offset >= target` is false for an unreachable cell.
const NIL: i32 = i32::MIN;
141
/// A single wavefront component (`M`, `I` or `D`) at one penalty.
///
/// Offsets are stored densely over the inclusive diagonal range
/// `[lo, hi]`; unreachable diagonals hold [`NIL`]. A range with `hi < lo`
/// denotes an empty wavefront with no storage.
#[derive(Debug, Clone)]
struct Wavefront {
    /// Lowest diagonal index covered (inclusive).
    lo: i32,
    /// Highest diagonal index covered (inclusive).
    hi: i32,
    /// Dense offsets indexed by `k - lo`; length is `hi - lo + 1` (or `0`).
    offsets: Vec<i32>,
}
155
156impl Wavefront {
157    /// An empty (all-[`NIL`]) wavefront covering `[lo, hi]`.
158    fn new(lo: i32, hi: i32) -> Self {
159        let len = if hi >= lo { (hi - lo + 1) as usize } else { 0 };
160        Self {
161            lo,
162            hi,
163            offsets: vec![NIL; len],
164        }
165    }
166
167    /// The offset on diagonal `k`, or [`NIL`] if out of range / unreachable.
168    #[inline]
169    fn get(&self, k: i32) -> i32 {
170        if k < self.lo || k > self.hi {
171            NIL
172        } else {
173            self.offsets[(k - self.lo) as usize]
174        }
175    }
176
177    /// Set the offset on diagonal `k` (no-op if `k` is out of range).
178    #[inline]
179    fn set(&mut self, k: i32, v: i32) {
180        if k >= self.lo && k <= self.hi {
181            self.offsets[(k - self.lo) as usize] = v;
182        }
183    }
184}
185
/// The three wavefront components recorded at a single penalty `s`.
#[derive(Debug, Clone)]
struct WfSet {
    /// `M` component: furthest offsets of paths ending in a match/mismatch
    /// column (after match extension).
    m: Wavefront,
    /// `I` component: furthest offsets of paths ending in an insertion.
    i: Wavefront,
    /// `D` component: furthest offsets of paths ending in a deletion.
    d: Wavefront,
}
193
194/// Run the WFA gap-affine **global** alignment of `a` against `b`.
195///
196/// `sc` is interpreted exactly as in [`crate::alignment::gotoh::gotoh_align`];
197/// the returned [`WfaAlignment::score`] is guaranteed to equal that function's
198/// score on the same inputs.
199///
200/// # Errors
201///
202/// * [`SeqError::EmptyInput`] if either sequence is empty (mirroring Gotoh).
203/// * [`SeqError::InvalidConfiguration`] if the derived affine model is
204///   degenerate (see `WfaPenalties::from_gotoh`).
205pub fn wfa_align(a: &[u8], b: &[u8], sc: &GotohScoring) -> SeqResult<WfaAlignment> {
206    let m = a.len();
207    let n = b.len();
208    if m == 0 || n == 0 {
209        return Err(SeqError::EmptyInput);
210    }
211    let pen = WfaPenalties::from_gotoh(sc)?;
212
213    let m_i = m as i32;
214    let n_i = n as i32;
215    let k_final = m_i - n_i;
216    let a_off_max = m_i; // offset == i, so the final offset is m.
217
218    // History of wavefront sets, indexed by penalty s.
219    let mut history: Vec<WfSet> = Vec::new();
220
221    // s = 0: only M[0] = 0, then extend.
222    {
223        let mut m_wf = Wavefront::new(0, 0);
224        m_wf.set(0, 0);
225        extend(&mut m_wf, a, b);
226        let set = WfSet {
227            m: m_wf,
228            i: Wavefront::new(0, -1),
229            d: Wavefront::new(0, -1),
230        };
231        if reached(&set.m, k_final, a_off_max) {
232            let cigar = traceback(&history, &set, &pen, k_final);
233            return finish(0, m, n, sc, cigar);
234        }
235        history.push(set);
236    }
237
238    // A generous upper bound on the optimal penalty: the cost of aligning
239    // everything as gaps. We grow up to (and including) this value.
240    let max_pen = (m_i + n_i) * (pen.x + pen.o + pen.e) + pen.o + pen.e;
241
242    let mut s = 1i32;
243    loop {
244        if s > max_pen {
245            // Unreachable for valid positive penalties, but keep the loop total.
246            return Err(SeqError::NumericalInstability(
247                "WFA failed to reach the alignment endpoint within the penalty bound".into(),
248            ));
249        }
250        let set = compute_next(&history, s, &pen, k_final, a, b);
251        if reached(&set.m, k_final, a_off_max) {
252            let cigar = traceback(&history, &set, &pen, k_final);
253            return finish(s, m, n, sc, cigar);
254        }
255        history.push(set);
256        s += 1;
257    }
258}
259
260/// Advance the `M` wavefront along every diagonal while characters match.
261fn extend(m_wf: &mut Wavefront, a: &[u8], b: &[u8]) {
262    let m = a.len() as i32;
263    let n = b.len() as i32;
264    for k in m_wf.lo..=m_wf.hi {
265        let mut off = m_wf.get(k);
266        if off == NIL {
267            continue;
268        }
269        // offset == i, j = i - k.
270        loop {
271            let i = off;
272            let j = off - k;
273            if i < m && j >= 0 && j < n && a[i as usize] == b[j as usize] {
274                off += 1;
275            } else {
276                break;
277            }
278        }
279        m_wf.set(k, off);
280    }
281}
282
283/// Has the `M` wavefront reached the bottom-right corner?
284#[inline]
285fn reached(m_wf: &Wavefront, k_final: i32, a_off_max: i32) -> bool {
286    m_wf.get(k_final) >= a_off_max
287}
288
289/// Compute the wavefront set at penalty `s` from the recorded history, then
290/// extend its `M` component.
291fn compute_next(
292    history: &[WfSet],
293    s: i32,
294    pen: &WfaPenalties,
295    k_final: i32,
296    a: &[u8],
297    b: &[u8],
298) -> WfSet {
299    // Predecessor penalties.
300    let s_x = s - pen.x; // mismatch
301    let s_o_e = s - pen.o - pen.e; // gap open (+ first extend)
302    let s_e = s - pen.e; // gap extend
303
304    // Diagonal range of the new wavefront: union of predecessor ranges, grown
305    // by one on each side to allow opening fresh gaps, and always covering the
306    // final diagonal so the endpoint can be detected.
307    let mut lo = k_final;
308    let mut hi = k_final;
309    for &(sp, grow) in &[(s_x, 0i32), (s_o_e, 1), (s_e, 1)] {
310        if sp >= 0 {
311            if let Some(set) = history.get(sp as usize) {
312                lo = lo.min(set.m.lo - grow);
313                hi = hi.max(set.m.hi + grow);
314                lo = lo.min(set.i.lo - grow);
315                hi = hi.max(set.i.hi + grow);
316                lo = lo.min(set.d.lo - grow);
317                hi = hi.max(set.d.hi + grow);
318            }
319        }
320    }
321
322    let mut i_wf = Wavefront::new(lo, hi);
323    let mut d_wf = Wavefront::new(lo, hi);
324    let mut m_wf = Wavefront::new(lo, hi);
325
326    let m_open = history.get_at(s_o_e, |set| &set.m);
327    let i_ext = history.get_at(s_e, |set| &set.i);
328    let d_ext = history.get_at(s_e, |set| &set.d);
329    let m_mis = history.get_at(s_x, |set| &set.m);
330
331    for k in lo..=hi {
332        // I[k]: gap in a (consumes b → j+1, i unchanged). offset == i unchanged.
333        // Predecessor lives on diagonal k+1 (since k = i - j, j+1 ⇒ k-1; thus a
334        // cell on diagonal k is reached from a cell on diagonal k+1).
335        let i_from_open = opt_get(m_open, k + 1);
336        let i_from_ext = opt_get(i_ext, k + 1);
337        let i_val = max2(i_from_open, i_from_ext);
338        i_wf.set(k, i_val);
339
340        // D[k]: gap in b (consumes a → i+1, j unchanged). offset == i, so +1.
341        // Predecessor lives on diagonal k-1.
342        let d_open = opt_get(m_open, k - 1);
343        let d_ext_v = opt_get(d_ext, k - 1);
344        let d_pred = max2(d_open, d_ext_v);
345        let d_val = if d_pred == NIL { NIL } else { d_pred + 1 };
346        d_wf.set(k, d_val);
347
348        // M[k]: mismatch (i+1, j+1, offset+1) OR fold in I[k] / D[k].
349        let m_sub = {
350            let v = opt_get(m_mis, k);
351            if v == NIL { NIL } else { v + 1 }
352        };
353        let m_val = max3(m_sub, i_val, d_val);
354        m_wf.set(k, m_val);
355    }
356
357    extend(&mut m_wf, a, b);
358    WfSet {
359        m: m_wf,
360        i: i_wf,
361        d: d_wf,
362    }
363}
364
365/// `max` of two offsets, [`NIL`]-aware.
366#[inline]
367fn max2(a: i32, b: i32) -> i32 {
368    if a == NIL {
369        b
370    } else if b == NIL {
371        a
372    } else {
373        a.max(b)
374    }
375}
376
377/// `max` of three offsets, [`NIL`]-aware.
378#[inline]
379fn max3(a: i32, b: i32, c: i32) -> i32 {
380    max2(max2(a, b), c)
381}
382
383/// Fetch the offset on diagonal `k` from an optional wavefront reference.
384#[inline]
385fn opt_get(wf: Option<&Wavefront>, k: i32) -> i32 {
386    match wf {
387        Some(w) => w.get(k),
388        None => NIL,
389    }
390}
391
/// Small helper trait letting us index the history with a (possibly negative)
/// penalty and project to one wavefront component.
trait HistoryExt {
    /// Return `f` applied to the set recorded at penalty `s`, or `None` when
    /// no set is available for that penalty.
    fn get_at<'a, F>(&'a self, s: i32, f: F) -> Option<&'a Wavefront>
    where
        F: Fn(&'a WfSet) -> &'a Wavefront;
}
399
400impl HistoryExt for [WfSet] {
401    #[inline]
402    fn get_at<'a, F>(&'a self, s: i32, f: F) -> Option<&'a Wavefront>
403    where
404        F: Fn(&'a WfSet) -> &'a Wavefront,
405    {
406        if s < 0 {
407            None
408        } else {
409            self.get(s as usize).map(f)
410        }
411    }
412}
413
414/// Convert a raw penalty to a [`WfaAlignment`].
415fn finish(
416    penalty: i32,
417    m: usize,
418    n: usize,
419    sc: &GotohScoring,
420    cigar: Vec<WfaOp>,
421) -> SeqResult<WfaAlignment> {
422    let score = ((m as i32 + n as i32) * sc.match_score - penalty) / 2;
423    Ok(WfaAlignment {
424        score,
425        penalty,
426        cigar,
427    })
428}
429
/// Which wavefront component a traceback cursor currently sits in.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Comp {
    /// The match / substitution component.
    M,
    /// The insertion component (gap in `a`).
    I,
    /// The deletion component (gap in `b`).
    D,
}
437
/// What explains an `M` cell `(s, k, off)` once its trailing matches are peeled.
///
/// Produced by `Tracer::m_origin` and consumed by `Tracer::run`, which turns
/// each variant into emitted operations and a cursor move.
enum MOrigin {
    /// The cell sits at the origin `(0, 0)`; emit leading matches and stop.
    Start,
    /// The cell is reached by a run of matches down to `target_off`.
    Match { target_off: i32 },
    /// The cell is reached by a mismatch from `(prev_s, k, prev_off)`.
    Mismatch { prev_s: i32, prev_off: i32 },
    /// The cell coincides with the `I` component at the same `(s, k)`.
    FromI,
    /// The cell coincides with the `D` component at the same `(s, k)`.
    FromD,
}
451
/// Read-only context for walking the recorded wavefronts back to the origin.
struct Tracer<'a> {
    /// Wavefront sets for penalties `0..s_opt`, indexed by penalty.
    history: &'a [WfSet],
    /// The set at the optimal penalty, which is not stored in `history`.
    final_set: &'a WfSet,
    /// The penalties used to step between wavefront sets.
    pen: WfaPenalties,
    /// The optimal penalty, i.e. the index of `final_set`.
    s_opt: i32,
}
460
impl<'a> Tracer<'a> {
    /// Build a tracer over `history` plus the not-yet-recorded `final_set`.
    ///
    /// The optimal penalty is taken to be `history.len()`, i.e. the index
    /// `final_set` would occupy if it were pushed onto `history`.
    fn new(history: &'a [WfSet], final_set: &'a WfSet, pen: WfaPenalties) -> Self {
        Self {
            history,
            final_set,
            pen,
            s_opt: history.len() as i32,
        }
    }

    /// Borrow the wavefront set recorded at penalty `sp`, treating the optimal
    /// penalty as the (non-recorded) `final_set`.
    ///
    /// Returns `None` for negative penalties and for penalties with no
    /// recorded set.
    fn get_set(&self, sp: i32) -> Option<&'a WfSet> {
        if sp == self.s_opt {
            Some(self.final_set)
        } else if sp >= 0 {
            self.history.get(sp as usize)
        } else {
            None
        }
    }

    /// Determine how the `M` cell at `(s, k, off)` was produced.
    ///
    /// The three candidate "bare" (pre-extension) offsets are the mismatch
    /// predecessor at `s − x` advanced by one, and the `I` and `D` components
    /// at the same `(s, k)`.
    fn m_origin(&self, s: i32, k: i32, off: i32) -> MOrigin {
        let s_x = s - self.pen.x;
        let mis_pred = self.get_set(s_x).map(|set| set.m.get(k)).unwrap_or(NIL);
        let mis_bare = if mis_pred == NIL { NIL } else { mis_pred + 1 };
        let i_here = self.get_set(s).map(|set| set.i.get(k)).unwrap_or(NIL);
        let d_here = self.get_set(s).map(|set| set.d.get(k)).unwrap_or(NIL);

        // Pick the largest bare offset not exceeding `off`; the gap up to `off`
        // is the matched run that `extend` appended. Any predecessor whose bare
        // offset equals the true value is a valid traceback choice.
        // The strict `>` means that on ties the earlier candidate in the list
        // (mismatch, then I, then D) wins.
        let mut best_bare = NIL;
        let mut kind = 0u8; // 1=mismatch 2=I 3=D
        for (cand, kd) in [(mis_bare, 1u8), (i_here, 2), (d_here, 3)] {
            if cand != NIL && cand <= off && cand > best_bare {
                best_bare = cand;
                kind = kd;
            }
        }

        // No predecessor at all: treated as the start of the alignment.
        if best_bare == NIL {
            return MOrigin::Start;
        }
        // The cell lies beyond the bare offset: peel the matched run first; the
        // caller re-queries at `target_off` to resolve the actual predecessor.
        if best_bare < off {
            return MOrigin::Match {
                target_off: best_bare,
            };
        }
        match kind {
            1 => MOrigin::Mismatch {
                prev_s: s_x,
                prev_off: mis_pred,
            },
            2 => MOrigin::FromI,
            _ => MOrigin::FromD,
        }
    }

    /// Walk back from the terminal cell on diagonal `k_final`, emitting ops in
    /// reverse, then reverse to left-to-right order.
    ///
    /// The cursor is `(s, k, comp, off)`: penalty, diagonal, component and row
    /// offset. It starts in the `M` component of `final_set`.
    fn run(&self, k_final: i32) -> Vec<WfaOp> {
        let mut ops: Vec<WfaOp> = Vec::new();
        let mut s = self.s_opt;
        let mut k = k_final;
        let mut comp = Comp::M;
        let mut off = self.final_set.m.get(k);

        loop {
            match comp {
                Comp::M => match self.m_origin(s, k, off) {
                    MOrigin::Start => {
                        // Leading matches down to the origin (0, 0).
                        for _ in 0..off.max(0) {
                            ops.push(WfaOp::Match);
                        }
                        break;
                    }
                    MOrigin::Match { target_off } => {
                        // Emit one Match per row between `target_off` and `off`;
                        // penalty and diagonal are unchanged.
                        let mut cur = off;
                        while cur > target_off {
                            ops.push(WfaOp::Match);
                            cur -= 1;
                        }
                        off = target_off;
                    }
                    MOrigin::Mismatch { prev_s, prev_off } => {
                        // Same diagonal, previous penalty `s − x`.
                        ops.push(WfaOp::Mismatch);
                        s = prev_s;
                        off = prev_off;
                    }
                    MOrigin::FromI => {
                        // Switch component without emitting; the I branch below
                        // emits the Ins.
                        comp = Comp::I;
                        if let Some(set) = self.get_set(s) {
                            off = set.i.get(k);
                        }
                    }
                    MOrigin::FromD => {
                        // Switch component without emitting; the D branch below
                        // emits the Del.
                        comp = Comp::D;
                        if let Some(set) = self.get_set(s) {
                            off = set.d.get(k);
                        }
                    }
                },
                Comp::I => {
                    // I[k] ← M[k+1] @ s-o-e (open) or I[k+1] @ s-e (extend);
                    // offset is preserved. Emit one Ins, move to diagonal k+1.
                    ops.push(WfaOp::Ins);
                    let s_o_e = s - self.pen.o - self.pen.e;
                    let s_e = s - self.pen.e;
                    let open = self
                        .get_set(s_o_e)
                        .map(|set| set.m.get(k + 1))
                        .unwrap_or(NIL);
                    let ext = self.get_set(s_e).map(|set| set.i.get(k + 1)).unwrap_or(NIL);
                    // Prefer a predecessor whose offset matches exactly
                    // (extension first, then opening). The last two arms are
                    // fallbacks that adopt whatever predecessor exists.
                    if ext != NIL && ext == off {
                        s = s_e;
                        k += 1;
                        comp = Comp::I;
                    } else if open != NIL && open == off {
                        s = s_o_e;
                        k += 1;
                        comp = Comp::M;
                    } else if open != NIL {
                        s = s_o_e;
                        k += 1;
                        comp = Comp::M;
                        off = open;
                    } else if ext != NIL {
                        s = s_e;
                        k += 1;
                        comp = Comp::I;
                        off = ext;
                    } else {
                        // No predecessor recorded: stop the walk.
                        break;
                    }
                }
                Comp::D => {
                    // D[k] ← M[k-1] @ s-o-e (open) or D[k-1] @ s-e (extend),
                    // offset+1. Emit one Del, move to diagonal k-1, offset−1.
                    ops.push(WfaOp::Del);
                    let s_o_e = s - self.pen.o - self.pen.e;
                    let s_e = s - self.pen.e;
                    let open = self
                        .get_set(s_o_e)
                        .map(|set| set.m.get(k - 1))
                        .unwrap_or(NIL);
                    let ext = self.get_set(s_e).map(|set| set.d.get(k - 1)).unwrap_or(NIL);
                    let pred_off = off - 1;
                    // Same preference order as the I branch: exact extension,
                    // exact opening, then the two fallbacks.
                    if ext != NIL && ext == pred_off {
                        s = s_e;
                        k -= 1;
                        off = pred_off;
                        comp = Comp::D;
                    } else if open != NIL && open == pred_off {
                        s = s_o_e;
                        k -= 1;
                        off = pred_off;
                        comp = Comp::M;
                    } else if open != NIL {
                        s = s_o_e;
                        k -= 1;
                        off = open;
                        comp = Comp::M;
                    } else if ext != NIL {
                        s = s_e;
                        k -= 1;
                        off = ext;
                        comp = Comp::D;
                    } else {
                        // No predecessor recorded: stop the walk.
                        break;
                    }
                }
            }
        }

        ops.reverse();
        ops
    }
}
642
643/// Reconstruct the optimal alignment by walking the recorded wavefronts back
644/// from the terminal cell to the origin.
645///
646/// `final_set` is the freshly-computed wavefront set at the optimal penalty
647/// `s = history.len()` (it is *not* part of `history`).
648fn traceback(history: &[WfSet], final_set: &WfSet, pen: &WfaPenalties, k_final: i32) -> Vec<WfaOp> {
649    Tracer::new(history, final_set, *pen).run(k_final)
650}
651
652#[cfg(test)]
653mod tests {
654    use super::*;
655    use crate::alignment::gotoh::gotoh_align;
656
657    fn default_sc() -> GotohScoring {
658        GotohScoring::default()
659    }
660
661    fn custom_sc() -> GotohScoring {
662        GotohScoring {
663            match_score: 3,
664            mismatch: -2,
665            gap_open: -6,
666            gap_extend: -2,
667        }
668    }
669
670    /// Independent re-scorer: walk a CIGAR applying Gotoh's scoring rules.
671    fn score_cigar(a: &[u8], b: &[u8], cigar: &[WfaOp], sc: &GotohScoring) -> i32 {
672        let mut score = 0i32;
673        let mut i = 0usize;
674        let mut j = 0usize;
675        // Track whether the previous op was the same kind of gap (for affine).
676        let mut prev: Option<WfaOp> = None;
677        for &op in cigar {
678            match op {
679                WfaOp::Match => {
680                    score += sc.match_score;
681                    i += 1;
682                    j += 1;
683                }
684                WfaOp::Mismatch => {
685                    score += sc.mismatch;
686                    i += 1;
687                    j += 1;
688                }
689                WfaOp::Ins => {
690                    // gap in a, consumes b.
691                    if prev == Some(WfaOp::Ins) {
692                        score += sc.gap_extend;
693                    } else {
694                        score += sc.gap_open;
695                    }
696                    j += 1;
697                }
698                WfaOp::Del => {
699                    // gap in b, consumes a.
700                    if prev == Some(WfaOp::Del) {
701                        score += sc.gap_extend;
702                    } else {
703                        score += sc.gap_open;
704                    }
705                    i += 1;
706                }
707            }
708            prev = Some(op);
709        }
710        assert_eq!(i, a.len(), "cigar must consume all of a");
711        assert_eq!(j, b.len(), "cigar must consume all of b");
712        score
713    }
714
715    fn check_consumption(a: &[u8], b: &[u8], cigar: &[WfaOp]) {
716        let consumes_a = cigar
717            .iter()
718            .filter(|o| matches!(o, WfaOp::Match | WfaOp::Mismatch | WfaOp::Del))
719            .count();
720        let consumes_b = cigar
721            .iter()
722            .filter(|o| matches!(o, WfaOp::Match | WfaOp::Mismatch | WfaOp::Ins))
723            .count();
724        assert_eq!(consumes_a, a.len(), "Match+Mismatch+Del must consume a");
725        assert_eq!(consumes_b, b.len(), "Match+Mismatch+Ins must consume b");
726    }
727
728    // (a) CENTRAL cross-check: WFA converted score == Gotoh score.
729    #[test]
730    fn central_cross_check_matches_gotoh() {
731        let pairs: &[(&[u8], &[u8])] = &[
732            (b"GATTACA", b"GCATGCU"),
733            (b"ACGTACGT", b"ACGTTCGT"),
734            (b"AAAA", b"AAAAGGGGAAAA"),
735            (b"ACGT", b"TGCA"),
736            (b"AGGGCT", b"AGGCT"),
737            (b"HELLOWORLD", b"HELOWRLD"),
738        ];
739        for sc in [default_sc(), custom_sc()] {
740            for &(a, b) in pairs {
741                let w = wfa_align(a, b, &sc).expect("wfa ok");
742                let g = gotoh_align(a, b, &sc).expect("gotoh ok");
743                assert_eq!(
744                    w.score,
745                    g.score,
746                    "score mismatch on {:?} vs {:?} with {:?}",
747                    std::str::from_utf8(a),
748                    std::str::from_utf8(b),
749                    sc
750                );
751                // The CIGAR must itself reproduce the score.
752                check_consumption(a, b, &w.cigar);
753                assert_eq!(
754                    score_cigar(a, b, &w.cigar, &sc),
755                    w.score,
756                    "cigar re-score mismatch on {:?} vs {:?}",
757                    std::str::from_utf8(a),
758                    std::str::from_utf8(b),
759                );
760            }
761        }
762    }
763
764    // (b) identical sequences.
765    #[test]
766    fn identical_sequences() {
767        let a = b"ACGTACGT";
768        let sc = default_sc();
769        let w = wfa_align(a, a, &sc).expect("ok");
770        assert_eq!(w.penalty, 0);
771        assert_eq!(w.score, sc.match_score * a.len() as i32);
772        assert!(w.cigar.iter().all(|o| *o == WfaOp::Match));
773        assert_eq!(w.cigar.len(), a.len());
774    }
775
776    // (c) traceback validity (consumption + re-score) on a tricky pair.
777    #[test]
778    fn traceback_validity() {
779        let sc = default_sc();
780        let cases: &[(&[u8], &[u8])] = &[
781            (b"GATTACA", b"GCATGCU"),
782            (b"ACGTACGTACGT", b"ACGTTTACGT"),
783            (b"BANANA", b"ANANAS"),
784        ];
785        for &(a, b) in cases {
786            let w = wfa_align(a, b, &sc).expect("ok");
787            check_consumption(a, b, &w.cigar);
788            assert_eq!(score_cigar(a, b, &w.cigar, &sc), w.score);
789            let g = gotoh_align(a, b, &sc).expect("ok");
790            assert_eq!(w.score, g.score);
791        }
792    }
793
794    // (d) affine: one contiguous length-4 gap stays a single run.
795    #[test]
796    fn affine_single_long_gap() {
797        let sc = default_sc();
798        let a = b"ACGTACGT";
799        // Insert 4 contiguous characters into the middle of `a` to make `b`.
800        let b = b"ACGTGGGGACGT";
801        let w = wfa_align(a, b, &sc).expect("ok");
802        let g = gotoh_align(a, b, &sc).expect("ok");
803        assert_eq!(w.score, g.score);
804        // There must be exactly one Ins run of length 4 and no Del.
805        let ins = w.cigar.iter().filter(|o| **o == WfaOp::Ins).count();
806        let del = w.cigar.iter().filter(|o| **o == WfaOp::Del).count();
807        assert_eq!(ins, 4, "expected 4 inserted symbols, cigar = {:?}", w.cigar);
808        assert_eq!(del, 0);
809        // And they must be contiguous (exactly one maximal Ins run).
810        let runs = count_runs(&w.cigar, WfaOp::Ins);
811        assert_eq!(runs, 1, "Ins must form a single run, cigar = {:?}", w.cigar);
812    }
813
814    fn count_runs(cigar: &[WfaOp], op: WfaOp) -> usize {
815        let mut runs = 0;
816        let mut in_run = false;
817        for &c in cigar {
818            if c == op {
819                if !in_run {
820                    runs += 1;
821                    in_run = true;
822                }
823            } else {
824                in_run = false;
825            }
826        }
827        runs
828    }
829
830    // (e) single mismatch.
831    #[test]
832    fn single_mismatch_cost() {
833        let sc = default_sc();
834        let a = b"ACGTACGT";
835        let b = b"ACGTTCGT"; // differs at index 4 (A vs T).
836        let w = wfa_align(a, b, &sc).expect("ok");
837        let g = gotoh_align(a, b, &sc).expect("ok");
838        assert_eq!(w.score, g.score);
839        let len = a.len() as i32;
840        assert_eq!(w.score, sc.match_score * (len - 1) + sc.mismatch);
841        // Penalty is exactly one mismatch unit x = 2*(M - mis).
842        assert_eq!(w.penalty, 2 * (sc.match_score - sc.mismatch));
843    }
844
845    // (f) empty-sequence handling mirrors gotoh.
846    #[test]
847    fn empty_sequence_errors() {
848        let sc = default_sc();
849        assert!(matches!(
850            wfa_align(b"", b"ACGT", &sc),
851            Err(SeqError::EmptyInput)
852        ));
853        assert!(matches!(
854            wfa_align(b"ACGT", b"", &sc),
855            Err(SeqError::EmptyInput)
856        ));
857        // Confirm gotoh errors the same way.
858        assert!(matches!(
859            gotoh_align(b"", b"ACGT", &sc),
860            Err(SeqError::EmptyInput)
861        ));
862        assert!(matches!(
863            gotoh_align(b"ACGT", b"", &sc),
864            Err(SeqError::EmptyInput)
865        ));
866    }
867
// (g) Match extension across a long identical run: both inputs share 50
// leading `A`s and only diverge inside their short suffixes.
#[test]
fn long_match_extension() {
    let scoring = default_sc();

    // Builds the common 50-byte `A` prefix followed by the given suffix.
    let with_suffix = |suffix: &[u8]| -> Vec<u8> {
        let mut seq = vec![b'A'; 50];
        seq.extend_from_slice(suffix);
        seq
    };
    let a = with_suffix(b"CGTACG");
    let b = with_suffix(b"CTTACG"); // diverges within the suffix.

    let wfa = wfa_align(&a, &b, &scoring).expect("ok");
    let gotoh = gotoh_align(&a, &b, &scoring).expect("ok");
    assert_eq!(wfa.score, gotoh.score);
    // The sequences are not identical, so some penalty must be paid.
    assert!(wfa.penalty > 0);

    // The CIGAR must consume both inputs and re-score to the reported score.
    check_consumption(&a, &b, &wfa.cigar);
    let rescored = score_cigar(&a, &b, &wfa.cigar, &scoring);
    assert_eq!(rescored, wfa.score);
}
884
// Scoring schemes that cannot be turned into valid WFA penalties must be
// rejected with `InvalidConfiguration`.
#[test]
fn degenerate_scoring_rejected() {
    // match_score <= mismatch ⇒ x <= 0.
    let free_mismatch = GotohScoring {
        match_score: 1,
        mismatch: 1,
        gap_open: -5,
        gap_extend: -1,
    };
    let outcome = wfa_align(b"AC", b"AG", &free_mismatch);
    assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));

    // gap_extend < gap_open ⇒ o < 0.
    let negative_open = GotohScoring {
        match_score: 2,
        mismatch: -1,
        gap_open: -1,
        gap_extend: -5,
    };
    let outcome = wfa_align(b"AC", b"AG", &negative_open);
    assert!(matches!(outcome, Err(SeqError::InvalidConfiguration(_))));
}
911
// Heavy randomized cross-check: for each of several valid scoring schemes,
// 120 random ACGT pairs are aligned with both WFA and Gotoh. The converted
// WFA score must equal the Gotoh score on every pair, and the reconstructed
// CIGAR must independently reproduce that score.
#[test]
fn randomized_cross_check_matches_gotoh() {
    use crate::handle::LcgRng;

    const PAIRS_PER_SCHEME: usize = 120;
    let alphabet = b"ACGT";

    // A handful of valid (positive-penalty) affine scoring schemes.
    let schemes = [
        GotohScoring::default(),
        GotohScoring {
            match_score: 3,
            mismatch: -2,
            gap_open: -6,
            gap_extend: -2,
        },
        GotohScoring {
            match_score: 1,
            mismatch: -1,
            gap_open: -2,
            gap_extend: -1,
        },
        GotohScoring {
            match_score: 4,
            mismatch: -3,
            gap_open: -8,
            gap_extend: -1,
        },
        GotohScoring {
            match_score: 2,
            mismatch: 0,
            gap_open: -4,
            gap_extend: -1,
        },
    ];

    // One fixed-seed generator is shared across all schemes, so the draw
    // order below (len a, len b, bytes of a, bytes of b) fixes the inputs.
    let mut rng = LcgRng::new(0x5EED_1234_ABCD);
    for scoring in &schemes {
        // Sanity: every scheme must derive valid positive penalties.
        assert!(WfaPenalties::from_gotoh(scoring).is_ok());

        for _ in 0..PAIRS_PER_SCHEME {
            let len_a = 1 + rng.next_usize(14);
            let len_b = 1 + rng.next_usize(14);

            let mut a: Vec<u8> = Vec::with_capacity(len_a);
            for _ in 0..len_a {
                a.push(alphabet[rng.next_usize(4)]);
            }
            let mut b: Vec<u8> = Vec::with_capacity(len_b);
            for _ in 0..len_b {
                b.push(alphabet[rng.next_usize(4)]);
            }

            let wfa = wfa_align(&a, &b, scoring).expect("wfa ok");
            let gotoh = gotoh_align(&a, &b, scoring).expect("gotoh ok");
            assert_eq!(
                wfa.score,
                gotoh.score,
                "score mismatch: a={:?} b={:?} sc={:?} (wfa={} gotoh={})",
                std::str::from_utf8(&a),
                std::str::from_utf8(&b),
                scoring,
                wfa.score,
                gotoh.score,
            );

            // The CIGAR must consume both inputs and re-score to `wfa.score`.
            check_consumption(&a, &b, &wfa.cigar);
            let rescored = score_cigar(&a, &b, &wfa.cigar, scoring);
            assert_eq!(
                rescored,
                wfa.score,
                "cigar re-score mismatch: a={:?} b={:?}",
                std::str::from_utf8(&a),
                std::str::from_utf8(&b),
            );
        }
    }
}
981
// Long one-sided gaps, both deletion-heavy and insertion-heavy: WFA must
// still match Gotoh exactly, and its CIGAR must re-score to the same value.
#[test]
fn asymmetric_gaps_match_gotoh() {
    let scoring = custom_sc();
    let pairs: [(&[u8], &[u8]); 6] = [
        (b"AAAAGGGGAAAA", b"AAAA"),      // deletion-heavy
        (b"AAAA", b"AAAAGGGGAAAA"),      // insertion-heavy
        (b"ACGTACGTACGT", b"ACGT"),      // big deletion
        (b"ACGT", b"ACGTACGTACGT"),      // big insertion
        (b"TTTTACGTTTTT", b"ACGT"),      // flanking deletions
        (b"GATTACAGATTACA", b"GATTACA"), // tandem deletion
    ];

    for &(a, b) in pairs.iter() {
        let wfa = wfa_align(a, b, &scoring).expect("ok");
        let gotoh = gotoh_align(a, b, &scoring).expect("ok");
        assert_eq!(
            wfa.score,
            gotoh.score,
            "mismatch on {:?} vs {:?}",
            std::str::from_utf8(a),
            std::str::from_utf8(b),
        );

        // Independent validation of the reconstructed alignment.
        check_consumption(a, b, &wfa.cigar);
        let rescored = score_cigar(a, b, &wfa.cigar, &scoring);
        assert_eq!(rescored, wfa.score);
    }
}
1009}