Skip to main content

frizbee/smith_waterman/simd/
alignment_iter.rs

1use super::matrix::Matrix;
2use crate::simd::Vector256;
3
/// One step of an alignment traceback path.
///
/// Every variant carries the `(row, col)` position of the step: the first tuple
/// element is the row and the second is the column.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Alignment {
    /// The path moved left: a haystack character was skipped.
    Left((usize, usize)),
    /// The path moved up: a needle character was skipped.
    Up((usize, usize)),
    /// The path moved diagonally over a matching character pair.
    Match((usize, usize)),
    /// The path moved diagonally over a non-matching character pair.
    Mismatch((usize, usize)),
}

impl Alignment {
    /// Returns the `(row, col)` position carried by this step, whatever its variant.
    #[must_use]
    pub fn pos(&self) -> (usize, usize) {
        // Exhaustive on purpose: adding a variant must force this match to be revisited.
        match *self {
            Alignment::Left(pos)
            | Alignment::Up(pos)
            | Alignment::Match(pos)
            | Alignment::Mismatch(pos) => pos,
        }
    }

    /// Returns the column (second element) of this step's position.
    #[must_use]
    pub fn col(&self) -> usize {
        self.pos().1
    }

    /// Returns the row (first element) of this step's position.
    #[must_use]
    pub fn row(&self) -> usize {
        self.pos().0
    }
}
33
/// Iterator over the steps of an alignment traceback path, with support for a
/// maximum number of typos.
///
/// Each call to `next` yields `Some(Some(alignment))` for the next step of the
/// path, walking backwards from the final needle row. It yields `Some(None)`
/// once to signal that `max_typos` was exceeded, after which the iterator is
/// finished. The iterator ends (`None`) when row 0 is reached, or when the path
/// hits the left edge or a zero score while still within the typo budget.
pub struct AlignmentPathIter<'a> {
    /// Flattened score matrix: one `[u16; 16]` chunk per 16 columns, with
    /// `haystack_chunks` chunks per row.
    score_matrix: &'a [[u16; 16]],
    /// Match masks laid out like `score_matrix`; a non-zero lane marks a match.
    match_masks: &'a [[u16; 16]],
    /// Number of 16-lane chunks per matrix row (the row stride when indexing).
    haystack_chunks: usize,
    /// Current matrix row of the traceback; starts at the needle length.
    row_idx: usize,
    /// Current matrix column of the traceback (chunk index * 16 + lane).
    col_idx: usize,
    /// Added back (times 16) to yielded columns; presumably the number of
    /// haystack chunks skipped before the matrix begins (TODO confirm with caller).
    skipped_chunks: usize,
    /// Maximum number of typos tolerated, or `None` for no limit.
    max_typos: Option<u16>,
    /// Typos (mismatches and skipped needle characters) counted so far.
    typo_count: u16,
    /// Score of the cell at the current traceback position.
    score: u16,
    /// Set once the "max typos exceeded" signal has been yielded.
    finished: bool,
}
50
51impl<'a> AlignmentPathIter<'a> {
52    #[inline(always)]
53    pub fn new<Simd256: Vector256>(
54        score_matrix: &'a Matrix<Simd256>,
55        match_masks: &'a Matrix<Simd256>,
56        needle_len: usize,
57        haystack_chunks: usize,
58        skipped_chunks: usize,
59        score: u16,
60        max_typos: Option<u16>,
61    ) -> Self {
62        let col_idx = Self::get_col_idx(
63            score_matrix,
64            needle_len,
65            haystack_chunks,
66            score,
67        );
68
69        Self {
70            score_matrix: score_matrix.as_slice(),
71            match_masks: match_masks.as_slice(),
72            haystack_chunks: score_matrix.haystack_chunks,
73            row_idx: needle_len,
74            col_idx,
75            skipped_chunks,
76            max_typos,
77            typo_count: 0,
78            score,
79            finished: false,
80        }
81    }
82
83    #[inline(always)]
84    fn get_col_idx<Simd256: Vector256>(
85        score_matrix: &Matrix<Simd256>,
86        needle_len: usize,
87        haystack_chunks: usize,
88        score: u16,
89    ) -> usize {
90        for chunk_idx in 1..haystack_chunks {
91            let chunk = &score_matrix.get(needle_len, chunk_idx);
92            let idx = unsafe { chunk.idx_u16(score) };
93            if idx != 16 {
94                return chunk_idx * 16 + idx;
95            }
96        }
97        panic!("could not find max score in score matrix final row");
98    }
99
100    #[inline(always)]
101    fn get_score(&self, row: usize, col: usize) -> u16 {
102        self.score_matrix[row * self.haystack_chunks + col / 16][col % 16]
103    }
104
105    #[inline(always)]
106    fn get_is_match(&self, row: usize, col: usize) -> bool {
107        self.match_masks[row * self.haystack_chunks + col / 16][col % 16] != 0
108    }
109}
110
111impl<'a> Iterator for AlignmentPathIter<'a> {
112    type Item = Option<Alignment>;
113
114    #[inline(always)]
115    fn next(&mut self) -> Option<Self::Item> {
116        if self.row_idx == 0 || self.finished {
117            return None;
118        }
119
120        if let Some(max_typos) = self.max_typos
121            && self.typo_count > max_typos
122        {
123            self.finished = true;
124            return Some(None);
125        }
126
127        // Must be moving up only (at left edge), or lost alignment
128        if self.col_idx < 16 || self.score == 0 {
129            if let Some(max_typos) = self.max_typos
130                && (self.typo_count + self.row_idx as u16) > max_typos
131            {
132                self.finished = true;
133                return Some(None);
134            }
135            return None;
136        }
137
138        // Capture current position to yield (adjusted to 0-indexed)
139        let current_pos = (
140            self.row_idx - 1,
141            self.col_idx - 16 + self.skipped_chunks * 16,
142        );
143
144        if self.get_is_match(self.row_idx, self.col_idx) {
145            self.row_idx -= 1;
146            self.col_idx -= 1;
147            self.score = self.get_score(self.row_idx, self.col_idx);
148            return Some(Some(Alignment::Match(current_pos)));
149        }
150
151        // Gather scores for all possible paths
152        let diag = self.get_score(self.row_idx - 1, self.col_idx - 1);
153        let left = self.get_score(self.row_idx, self.col_idx - 1);
154        let up = self.get_score(self.row_idx - 1, self.col_idx);
155
156        // Match or mismatch (diagonal)
157        if diag >= left && diag >= up {
158            self.row_idx -= 1;
159            self.col_idx -= 1;
160            // Must be a mismatch if score didn't increase
161            if diag >= self.score {
162                self.typo_count += 1;
163                self.score = diag;
164                return Some(Some(Alignment::Mismatch(current_pos)));
165            }
166            self.score = diag;
167            Some(Some(Alignment::Match(current_pos)))
168        // Skipped character in haystack (left)
169        } else if left >= up {
170            self.col_idx -= 1;
171            self.score = left;
172            Some(Some(Alignment::Left(current_pos)))
173        // Skipped character in needle (up)
174        } else {
175            self.typo_count += 1;
176            self.row_idx -= 1;
177            self.score = up;
178            Some(Some(Alignment::Up(current_pos)))
179        }
180    }
181}