1use std::borrow::Borrow;
28use std::cmp::min;
29use std::iter;
30use std::iter::repeat;
31
32use crate::utils::TextSlice;
33
/// Unit substitution cost: `0` when `a == b`, `1` otherwise.
///
/// Intended as the `cost` argument of [`Ukkonen::with_capacity`].
pub fn unit_cost(a: u8, b: u8) -> u32 {
    let mismatch = a != b;
    u32::from(mismatch)
}
38
/// Reusable state for Ukkonen's cut-off algorithm for approximate pattern matching.
///
/// Holds two DP columns, which are cleared and refilled on every search so their
/// allocations can be reused, plus the substitution cost function.
// The DP column field is named `D` (capital letter), hence the lint exemption.
#[allow(non_snake_case)]
#[derive(Default, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Serialize, Deserialize)]
pub struct Ukkonen<F>
where
    F: Fn(u8, u8) -> u32,
{
    /// Two alternating DP columns; the column for text index `i` lives in `D[i % 2]`.
    D: [Vec<usize>; 2],
    /// Substitution cost of (pattern byte, text byte); insertions and deletions cost 1.
    cost: F,
}
49
50impl<F> Ukkonen<F>
51where
52 F: Fn(u8, u8) -> u32,
53{
54 pub fn with_capacity(m: usize, cost: F) -> Self {
56 let get_vec = || Vec::with_capacity(m + 1);
57 Ukkonen {
58 D: [get_vec(), get_vec()],
59 cost,
60 }
61 }
62
63 pub fn find_all_end<'a, C, T>(
66 &'a mut self,
67 pattern: TextSlice<'a>,
68 text: T,
69 k: usize,
70 ) -> Matches<'a, F, C, T::IntoIter>
71 where
72 C: Borrow<u8>,
73 T: IntoIterator<Item = C>,
74 {
75 let m = pattern.len();
76 self.D[0].clear();
77 self.D[0].extend(repeat(k + 1).take(m + 1));
78 self.D[1].clear();
79 self.D[1].extend(0..=m);
80 Matches {
81 ukkonen: self,
82 pattern,
83 text: text.into_iter().enumerate(),
84 lastk: min(k, m),
85 m,
86 k,
87 }
88 }
89}
90
/// Iterator over the end positions of approximate occurrences of a pattern,
/// created by [`Ukkonen::find_all_end`].
///
/// Yields `(end, dist)` pairs: `end` is the 0-based text index where an occurrence
/// ends and `dist` is its edit distance, which is at most `k`.
#[derive(Debug)]
pub struct Matches<'a, F, C, T>
where
    F: Fn(u8, u8) -> u32,
    C: Borrow<u8>,
    T: Iterator<Item = C>,
{
    /// Borrowed DP state; its two columns are overwritten as the text is consumed.
    ukkonen: &'a mut Ukkonen<F>,
    /// The pattern being searched for.
    pattern: TextSlice<'a>,
    /// Remaining text, each byte paired with its index.
    text: iter::Enumerate<T>,
    /// Deepest pattern row of the current column whose distance is known to be <= `k`.
    lastk: usize,
    /// Pattern length.
    m: usize,
    /// Maximum allowed edit distance.
    k: usize,
}
106
107impl<'a, F, C, T> Iterator for Matches<'a, F, C, T>
108where
109 F: 'a + Fn(u8, u8) -> u32,
110 C: Borrow<u8>,
111 T: Iterator<Item = C>,
112{
113 type Item = (usize, usize);
114
115 fn next(&mut self) -> Option<(usize, usize)> {
116 let cost = &self.ukkonen.cost;
117 for (i, c) in &mut self.text {
118 let col = i % 2;
119 let prev = 1 - col;
120
121 self.ukkonen.D[col][0] = 0;
123 self.lastk = min(self.lastk + 1, self.m);
124 for j in 1..=self.lastk {
127 self.ukkonen.D[col][j] = min(
128 min(self.ukkonen.D[prev][j] + 1, self.ukkonen.D[col][j - 1] + 1),
129 self.ukkonen.D[prev][j - 1] + (cost)(self.pattern[j - 1], *c.borrow()) as usize,
130 );
131 }
132
133 while self.ukkonen.D[col][self.lastk] > self.k {
136 self.lastk -= 1;
137 }
138
139 if self.lastk == self.m {
140 return Some((i, self.ukkonen.D[col][self.m]));
141 }
142 }
143
144 None
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_find_all_end() {
        let mut ukkonen = Ukkonen::with_capacity(10, unit_cost);
        let text = b"ACCGTGGATGAGCGCCATAG";
        let pattern = b"TGAGCGT";
        let occ: Vec<(usize, usize)> = ukkonen.find_all_end(pattern, text, 1).collect();
        assert_eq!(occ, [(13, 1), (14, 1)]);
    }

    /// One `Ukkonen` instance is reused for two searches; a leading extra text byte
    /// shifts every reported end position by one.
    #[test]
    fn test_find_all_end_reuse_and_shift() {
        let mut u = Ukkonen::with_capacity(10, unit_cost);

        let pattern = b"ACCGT";
        let text1 = b"ACCGTGGATGAGCGCCATAG";
        let text2 = b"AACCGTGGATGAGCGCCATAG";

        let occ: Vec<(usize, usize)> = u.find_all_end(pattern, text1, 1).collect();
        assert_eq!(occ, [(3, 1), (4, 0), (5, 1)]);
        let occ: Vec<(usize, usize)> = u.find_all_end(pattern, text2, 1).collect();
        assert_eq!(occ, [(4, 1), (5, 0), (6, 1)]);
    }

    /// With `k = 0` only exact occurrences are reported.
    #[test]
    fn test_find_all_end_exact() {
        let mut u = Ukkonen::with_capacity(5, unit_cost);
        let occ: Vec<(usize, usize)> =
            u.find_all_end(b"ACCGT", b"ACCGTGGATGAGCGCCATAG", 0).collect();
        assert_eq!(occ, [(4, 0)]);
    }

    /// An empty pattern matches with distance 0 at every text position.
    #[test]
    fn test_find_all_end_empty_pattern() {
        let mut u = Ukkonen::with_capacity(0, unit_cost);
        let occ: Vec<(usize, usize)> = u.find_all_end(b"", b"ACG", 0).collect();
        assert_eq!(occ, [(0, 0), (1, 0), (2, 0)]);
    }

    /// An empty text yields no matches.
    #[test]
    fn test_find_all_end_empty_text() {
        let mut u = Ukkonen::with_capacity(3, unit_cost);
        let occ: Vec<(usize, usize)> = u.find_all_end(b"ACG", b"", 1).collect();
        assert!(occ.is_empty());
    }

    /// Text given as an iterator of owned bytes behaves like a byte slice.
    #[test]
    fn test_find_all_end_owned_bytes() {
        let mut u = Ukkonen::with_capacity(5, unit_cost);
        let text = b"ACCGTGGATGAGCGCCATAG";
        let occ: Vec<(usize, usize)> = u.find_all_end(b"ACCGT", text.iter().copied(), 1).collect();
        assert_eq!(occ, [(3, 1), (4, 0), (5, 1)]);
    }
}