1use compiled_align_scores::*;
2use utils::*;
3
4#[derive(Clone, Debug)]
5pub struct AlignScores {
6 pub match2match_score: Score,
7 pub match2insert_score: Score,
8 pub insert_extend_score: Score,
9 pub insert_switch_score: Score,
10 pub init_match_score: Score,
11 pub init_insert_score: Score,
12 pub insert_scores: InsertScores,
13 pub match_scores: MatchScores,
14}
15
16pub struct AlignSums {
17 pub forward_sums_match: SumMat,
18 pub forward_sums_insert: SumMat,
19 pub forward_sums_del: SumMat,
20 pub backward_sums_match: SumMat,
21 pub backward_sums_insert: SumMat,
22 pub backward_sums_del: SumMat,
23}
24
25impl AlignScores {
26 pub fn new(init_val: Score) -> AlignScores {
27 let init_vals = [init_val; NUM_BASES];
28 let mat = [[init_val; NUM_BASES]; NUM_BASES];
29 AlignScores {
30 match2match_score: init_val,
31 match2insert_score: init_val,
32 init_match_score: init_val,
33 insert_extend_score: init_val,
34 insert_switch_score: init_val,
35 init_insert_score: init_val,
36 insert_scores: init_vals,
37 match_scores: mat,
38 }
39 }
40
41 pub fn transfer(&mut self) {
42 self.match2match_score = MATCH2MATCH_SCORE;
43 self.match2insert_score = MATCH2INSERT_SCORE;
44 self.insert_extend_score = INSERT_EXTEND_SCORE;
45 self.insert_switch_score = INSERT_SWITCH_SCORE;
46 self.init_match_score = INIT_MATCH_SCORE;
47 self.init_insert_score = INIT_INSERT_SCORE;
48 for (i, &v) in INSERT_SCORES.iter().enumerate() {
49 self.insert_scores[i] = v;
50 }
51 for (i, x) in MATCH_SCORES.iter().enumerate() {
52 for (j, &x) in x.iter().enumerate() {
53 self.match_scores[i][j] = x;
54 }
55 }
56 }
57}
58
59impl AlignSums {
60 pub fn new(seq_len_pair: &(usize, usize)) -> AlignSums {
61 let neg_infs = vec![vec![NEG_INFINITY; seq_len_pair.1]; seq_len_pair.0];
62 AlignSums {
63 forward_sums_match: neg_infs.clone(),
64 forward_sums_insert: neg_infs.clone(),
65 forward_sums_del: neg_infs.clone(),
66 backward_sums_match: neg_infs.clone(),
67 backward_sums_insert: neg_infs.clone(),
68 backward_sums_del: neg_infs,
69 }
70 }
71}
72
73pub fn durbin_algo(seq_pair: &SeqPair, align_scores: &AlignScores) -> ProbMat {
74 let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
75 let align_sums = get_align_sums(seq_pair, align_scores);
76 get_match_probs(&align_sums, &seq_len_pair, align_scores)
77}
78
/// Fills the forward and backward log-space sums of the pair alignment model.
///
/// Positions `0` and `len - 1` of both sequences act as begin/end markers.
/// The forward recursion starts from `forward_sums_match[0][0] = 0`, and the
/// backward recursion starts from `backward_sums_match[len0 - 1][len1 - 1] = 0`.
/// Neither marker is ever scored with an emission.
///
/// State naming: "insert" emits a base of `seq_pair.0`, and "del" emits a base
/// of `seq_pair.1`. Transition scores:
/// - match -> match uses `match2match_score`. Out of the begin marker it uses
///   `init_match_score`; into the end marker it is `0`.
/// - Every match <-> gap transition uses `match2insert_score`. Out of the begin
///   marker it uses `init_insert_score`; into the end marker it is `0`.
/// - Gap extension uses `insert_extend_score`.
///
/// Every non-marker cell of both passes includes the emission score of the
/// state at that cell.
///
/// The gap matrices of the two passes use different index conventions.
/// `forward_sums_insert[i][j]` ends with `seq_pair.1` consumed through `j`.
/// `backward_sums_insert[i][j]` is followed by a match at `(i + 1, j)`.
/// The "del" matrices follow the same pattern symmetrically.
///
/// NOTE(review): `insert_switch_score` is never read here, so an insertion is
/// never followed directly by a deletion (or vice versa). Confirm this is the
/// intended model.
///
/// Assumes both sequences have length >= 2 (begin + end marker). Shorter
/// inputs underflow the `len - 1` loop bounds.
pub fn get_align_sums(seq_pair: &SeqPair, align_scores: &AlignScores) -> AlignSums {
    let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
    let mut align_sums = AlignSums::new(&seq_len_pair);
    // Forward pass in increasing (i, j); the end markers (index len - 1) are not visited.
    for i in 0..seq_len_pair.0 - 1 {
        for j in 0..seq_len_pair.1 - 1 {
            // Begin marker: log-space one, no emission.
            if i == 0 && j == 0 {
                align_sums.forward_sums_match[i][j] = 0.;
                continue;
            }
            // Match at (i, j): aligns seq_pair.0[i] with seq_pair.1[j].
            if i > 0 && j > 0 {
                let mut sum = NEG_INFINITY;
                let basepair = (seq_pair.0[i], seq_pair.1[j]);
                let match_score = align_scores.match_scores[basepair.0][basepair.1];
                // Leaving the begin marker uses the initial score instead of match -> match.
                let begins_sum = (i - 1, j - 1) == (0, 0);
                let term = align_sums.forward_sums_match[i - 1][j - 1]
                    + if begins_sum {
                        align_scores.init_match_score
                    } else {
                        align_scores.match2match_score
                    };
                logsumexp(&mut sum, term);
                // gap -> match reuses match2insert_score.
                let term = align_sums.forward_sums_insert[i - 1][j - 1] + align_scores.match2insert_score;
                logsumexp(&mut sum, term);
                let term = align_sums.forward_sums_del[i - 1][j - 1] + align_scores.match2insert_score;
                logsumexp(&mut sum, term);
                align_sums.forward_sums_match[i][j] = sum + match_score;
            }
            // Insert at (i, j): emits seq_pair.0[i] against a gap in seq_pair.1.
            if i > 0 {
                let base = seq_pair.0[i];
                let insert_score = align_scores.insert_scores[base];
                let begins_sum = (i - 1, j) == (0, 0);
                let mut sum = NEG_INFINITY;
                let term = align_sums.forward_sums_match[i - 1][j]
                    + if begins_sum {
                        align_scores.init_insert_score
                    } else {
                        align_scores.match2insert_score
                    };
                logsumexp(&mut sum, term);
                let term = align_sums.forward_sums_insert[i - 1][j] + align_scores.insert_extend_score;
                logsumexp(&mut sum, term);
                align_sums.forward_sums_insert[i][j] = sum + insert_score;
            }
            // Del at (i, j): emits seq_pair.1[j] against a gap in seq_pair.0.
            if j > 0 {
                let base = seq_pair.1[j];
                let insert_score = align_scores.insert_scores[base];
                let begins_sum = (i, j - 1) == (0, 0);
                let mut sum = NEG_INFINITY;
                let term = align_sums.forward_sums_match[i][j - 1]
                    + if begins_sum {
                        align_scores.init_insert_score
                    } else {
                        align_scores.match2insert_score
                    };
                logsumexp(&mut sum, term);
                let term = align_sums.forward_sums_del[i][j - 1] + align_scores.insert_extend_score;
                logsumexp(&mut sum, term);
                align_sums.forward_sums_del[i][j] = sum + insert_score;
            }
        }
    }
    // Backward pass in decreasing (i, j); index 0 (the begin marker) is not visited.
    for i in (1..seq_len_pair.0).rev() {
        for j in (1..seq_len_pair.1).rev() {
            // End marker: log-space one, no emission.
            if i == seq_len_pair.0 - 1 && j == seq_len_pair.1 - 1 {
                align_sums.backward_sums_match[i][j] = 0.;
                continue;
            }
            // Match at (i, j), followed by a state at (i + 1, j + 1).
            if i < seq_len_pair.0 - 1 && j < seq_len_pair.1 - 1 {
                let mut sum = NEG_INFINITY;
                let base_pair = (seq_pair.0[i], seq_pair.1[j]);
                let match_score = align_scores.match_scores[base_pair.0][base_pair.1];
                // Entering the end marker is free (score 0).
                let ends_sum = (i + 1, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
                let term = align_sums.backward_sums_match[i + 1][j + 1]
                    + if ends_sum {
                        0.
                    } else {
                        align_scores.match2match_score
                    };
                logsumexp(&mut sum, term);
                let term = align_sums.backward_sums_insert[i + 1][j + 1] + align_scores.match2insert_score;
                logsumexp(&mut sum, term);
                let term = align_sums.backward_sums_del[i + 1][j + 1] + align_scores.match2insert_score;
                logsumexp(&mut sum, term);
                align_sums.backward_sums_match[i][j] = sum + match_score;
            }
            // Insert of seq_pair.0[i], followed by a match at (i + 1, j) or another insert.
            if i < seq_len_pair.0 - 1 {
                let base = seq_pair.0[i];
                let insert_score = align_scores.insert_scores[base];
                let ends_sum = (i + 1, j) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
                let mut sum = NEG_INFINITY;
                let term = align_sums.backward_sums_match[i + 1][j]
                    + if ends_sum {
                        0.
                    } else {
                        align_scores.match2insert_score
                    };
                logsumexp(&mut sum, term);
                let term = align_sums.backward_sums_insert[i + 1][j] + align_scores.insert_extend_score;
                logsumexp(&mut sum, term);
                align_sums.backward_sums_insert[i][j] = sum + insert_score;
            }
            // Del of seq_pair.1[j], followed by a match at (i, j + 1) or another del.
            if j < seq_len_pair.1 - 1 {
                let base = seq_pair.1[j];
                let insert_score = align_scores.insert_scores[base];
                let ends_sum = (i, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
                let mut sum = NEG_INFINITY;
                let term = align_sums.backward_sums_match[i][j + 1]
                    + if ends_sum {
                        0.
                    } else {
                        align_scores.match2insert_score
                    };
                logsumexp(&mut sum, term);
                let term = align_sums.backward_sums_del[i][j + 1] + align_scores.insert_extend_score;
                logsumexp(&mut sum, term);
                align_sums.backward_sums_del[i][j] = sum + insert_score;
            }
        }
    }
    align_sums
}
200
201fn get_match_probs(
202 align_sums: &AlignSums,
203 seq_len_pair: &(usize, usize),
204 align_scores: &AlignScores,
205) -> ProbMat {
206 let mut match_probs = vec![vec![0.; seq_len_pair.1]; seq_len_pair.0];
207 let mut global_sum = align_sums.forward_sums_match[seq_len_pair.0 - 2][seq_len_pair.1 - 2];
208 logsumexp(
209 &mut global_sum,
210 align_sums.forward_sums_insert[seq_len_pair.0 - 2][seq_len_pair.1 - 2],
211 );
212 logsumexp(
213 &mut global_sum,
214 align_sums.forward_sums_del[seq_len_pair.0 - 2][seq_len_pair.1 - 2],
215 );
216 for (i, x) in match_probs.iter_mut().enumerate() {
217 if i == 0 || i == seq_len_pair.0 - 1 {
218 continue;
219 }
220 for (j, x) in x.iter_mut().enumerate() {
221 if j == 0 || j == seq_len_pair.1 - 1 {
222 continue;
223 }
224 let mut sum = NEG_INFINITY;
225 let forward_sum = align_sums.forward_sums_match[i][j];
226 let ends_sum = (i + 1, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
227 let term = if ends_sum {
228 0.
229 } else {
230 align_scores.match2match_score
231 } + align_sums.backward_sums_match[i + 1][j + 1];
232 logsumexp(&mut sum, term);
233 let term = align_scores.match2insert_score + align_sums.backward_sums_insert[i + 1][j + 1];
234 logsumexp(&mut sum, term);
235 let term = align_scores.match2insert_score + align_sums.backward_sums_del[i + 1][j + 1];
236 logsumexp(&mut sum, term);
237 let match_prob = expf(forward_sum + sum - global_sum);
238 *x = match_prob;
239 }
240 }
241 match_probs
242}