oxicuda_seq/alignment/
needleman_wunsch.rs1use crate::error::{SeqError, SeqResult};
4
/// Linear-gap scoring scheme used by [`needleman_wunsch`].
///
/// Each value is added to the running alignment score, so penalties are
/// expressed as negative numbers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ScoringMatrix {
    /// Added when the two aligned bytes are equal.
    pub match_score: i32,
    /// Added when the two aligned bytes differ.
    pub mismatch: i32,
    /// Added once for every position aligned against a gap.
    pub gap: i32,
}
12
13impl Default for ScoringMatrix {
14 fn default() -> Self {
15 Self {
16 match_score: 1,
17 mismatch: -1,
18 gap: -2,
19 }
20 }
21}
22
/// Result of a global pairwise alignment, as produced by [`needleman_wunsch`].
///
/// Column `k` of the alignment pairs `a_aligned[k]` with `b_aligned[k]`.
/// `Some(i)` is an index into the corresponding input sequence and `None`
/// marks a gap. In alignments returned by [`needleman_wunsch`] both vectors
/// have the same length and every input index appears exactly once, in order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Alignment {
    /// Per-column index into the first sequence, or `None` for a gap.
    pub a_aligned: Vec<Option<usize>>,
    /// Per-column index into the second sequence, or `None` for a gap.
    pub b_aligned: Vec<Option<usize>>,
    /// Total score of the alignment under the scoring scheme that produced it.
    pub score: i32,
}
30
31pub fn needleman_wunsch(a: &[u8], b: &[u8], score: &ScoringMatrix) -> SeqResult<Alignment> {
33 let m = a.len();
34 let n = b.len();
35 if m == 0 || n == 0 {
36 return Err(SeqError::EmptyInput);
37 }
38 let cols = n + 1;
39 let mut dp = vec![0i32; (m + 1) * cols];
40 let mut trace = vec![0u8; (m + 1) * cols];
41 for i in 0..=m {
44 dp[i * cols] = score.gap * i as i32;
45 trace[i * cols] = 1;
46 }
47 for j in 0..=n {
48 dp[j] = score.gap * j as i32;
49 trace[j] = 2;
50 }
51 trace[0] = 0;
52
53 for i in 1..=m {
54 for j in 1..=n {
55 let s = if a[i - 1] == b[j - 1] {
56 score.match_score
57 } else {
58 score.mismatch
59 };
60 let diag = dp[(i - 1) * cols + (j - 1)] + s;
61 let up = dp[(i - 1) * cols + j] + score.gap;
62 let left = dp[i * cols + (j - 1)] + score.gap;
63 let (best, dir) = if diag >= up && diag >= left {
64 (diag, 0u8)
65 } else if up >= left {
66 (up, 1u8)
67 } else {
68 (left, 2u8)
69 };
70 dp[i * cols + j] = best;
71 trace[i * cols + j] = dir;
72 }
73 }
74 let final_score = dp[m * cols + n];
75
76 let mut a_align = Vec::new();
78 let mut b_align = Vec::new();
79 let mut i = m;
80 let mut j = n;
81 while i > 0 || j > 0 {
82 let dir = trace[i * cols + j];
83 match dir {
84 0 if i > 0 && j > 0 => {
85 a_align.push(Some(i - 1));
86 b_align.push(Some(j - 1));
87 i -= 1;
88 j -= 1;
89 }
90 1 if i > 0 => {
91 a_align.push(Some(i - 1));
92 b_align.push(None);
93 i -= 1;
94 }
95 2 if j > 0 => {
96 a_align.push(None);
97 b_align.push(Some(j - 1));
98 j -= 1;
99 }
100 _ => {
101 if i > 0 {
103 a_align.push(Some(i - 1));
104 b_align.push(None);
105 i -= 1;
106 } else if j > 0 {
107 a_align.push(None);
108 b_align.push(Some(j - 1));
109 j -= 1;
110 } else {
111 break;
112 }
113 }
114 }
115 }
116 a_align.reverse();
117 b_align.reverse();
118 Ok(Alignment {
119 a_aligned: a_align,
120 b_aligned: b_align,
121 score: final_score,
122 })
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn nw_simple() {
        let a = b"GATTACA";
        let b = b"GCATGCU";
        let s = ScoringMatrix::default();
        let r = needleman_wunsch(a, b, &s).expect("ok");
        // Hand-computed optimum: the gap-free alignment (3 matches,
        // 4 mismatches) is the unique best under +1 / -1 / -2 scoring.
        assert_eq!(r.score, -1);
        let diagonal: Vec<Option<usize>> = (0..7).map(Some).collect();
        assert_eq!(r.a_aligned, diagonal);
        assert_eq!(r.b_aligned, diagonal);
    }

    #[test]
    fn nw_identical() {
        let a = b"ACGT";
        let s = ScoringMatrix::default();
        let r = needleman_wunsch(a, a, &s).expect("ok");
        assert_eq!(r.score, 4);
        let diagonal: Vec<Option<usize>> = (0..4).map(Some).collect();
        assert_eq!(r.a_aligned, diagonal);
        assert_eq!(r.b_aligned, diagonal);
    }

    #[test]
    fn nw_inserts_gap() {
        let s = ScoringMatrix::default();
        let r = needleman_wunsch(b"A", b"AT", &s).expect("ok");
        // A/A match (+1) plus one gap opposite T (-2).
        assert_eq!(r.score, -1);
        assert_eq!(r.a_aligned, vec![Some(0), None]);
        assert_eq!(r.b_aligned, vec![Some(0), Some(1)]);
    }

    #[test]
    fn nw_rejects_empty_input() {
        let s = ScoringMatrix::default();
        assert!(matches!(
            needleman_wunsch(b"", b"A", &s),
            Err(SeqError::EmptyInput)
        ));
        assert!(matches!(
            needleman_wunsch(b"A", b"", &s),
            Err(SeqError::EmptyInput)
        ));
    }
}