// rna_algos/durbin_algo.rs

1use compiled_align_scores::*;
2use utils::*;
3
4#[derive(Clone, Debug)]
5pub struct AlignScores {
6  pub match2match_score: Score,
7  pub match2insert_score: Score,
8  pub insert_extend_score: Score,
9  pub insert_switch_score: Score,
10  pub init_match_score: Score,
11  pub init_insert_score: Score,
12  pub insert_scores: InsertScores,
13  pub match_scores: MatchScores,
14}
15
16pub struct AlignSums {
17  pub forward_sums_match: SumMat,
18  pub forward_sums_insert: SumMat,
19  pub forward_sums_del: SumMat,
20  pub backward_sums_match: SumMat,
21  pub backward_sums_insert: SumMat,
22  pub backward_sums_del: SumMat,
23}
24
25impl AlignScores {
26  pub fn new(init_val: Score) -> AlignScores {
27    let init_vals = [init_val; NUM_BASES];
28    let mat = [[init_val; NUM_BASES]; NUM_BASES];
29    AlignScores {
30      match2match_score: init_val,
31      match2insert_score: init_val,
32      init_match_score: init_val,
33      insert_extend_score: init_val,
34      insert_switch_score: init_val,
35      init_insert_score: init_val,
36      insert_scores: init_vals,
37      match_scores: mat,
38    }
39  }
40
41  pub fn transfer(&mut self) {
42    self.match2match_score = MATCH2MATCH_SCORE;
43    self.match2insert_score = MATCH2INSERT_SCORE;
44    self.insert_extend_score = INSERT_EXTEND_SCORE;
45    self.insert_switch_score = INSERT_SWITCH_SCORE;
46    self.init_match_score = INIT_MATCH_SCORE;
47    self.init_insert_score = INIT_INSERT_SCORE;
48    for (i, &v) in INSERT_SCORES.iter().enumerate() {
49      self.insert_scores[i] = v;
50    }
51    for (i, x) in MATCH_SCORES.iter().enumerate() {
52      for (j, &x) in x.iter().enumerate() {
53        self.match_scores[i][j] = x;
54      }
55    }
56  }
57}
58
59impl AlignSums {
60  pub fn new(seq_len_pair: &(usize, usize)) -> AlignSums {
61    let neg_infs = vec![vec![NEG_INFINITY; seq_len_pair.1]; seq_len_pair.0];
62    AlignSums {
63      forward_sums_match: neg_infs.clone(),
64      forward_sums_insert: neg_infs.clone(),
65      forward_sums_del: neg_infs.clone(),
66      backward_sums_match: neg_infs.clone(),
67      backward_sums_insert: neg_infs.clone(),
68      backward_sums_del: neg_infs,
69    }
70  }
71}
72
73pub fn durbin_algo(seq_pair: &SeqPair, align_scores: &AlignScores) -> ProbMat {
74  let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
75  let align_sums = get_align_sums(seq_pair, align_scores);
76  get_match_probs(&align_sums, &seq_len_pair, align_scores)
77}
78
79pub fn get_align_sums(seq_pair: &SeqPair, align_scores: &AlignScores) -> AlignSums {
80  let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
81  let mut align_sums = AlignSums::new(&seq_len_pair);
82  for i in 0..seq_len_pair.0 - 1 {
83    for j in 0..seq_len_pair.1 - 1 {
84      if i == 0 && j == 0 {
85        align_sums.forward_sums_match[i][j] = 0.;
86        continue;
87      }
88      if i > 0 && j > 0 {
89        let mut sum = NEG_INFINITY;
90        let basepair = (seq_pair.0[i], seq_pair.1[j]);
91        let match_score = align_scores.match_scores[basepair.0][basepair.1];
92        let begins_sum = (i - 1, j - 1) == (0, 0);
93        let term = align_sums.forward_sums_match[i - 1][j - 1]
94          + if begins_sum {
95            align_scores.init_match_score
96          } else {
97            align_scores.match2match_score
98          };
99        logsumexp(&mut sum, term);
100        let term = align_sums.forward_sums_insert[i - 1][j - 1] + align_scores.match2insert_score;
101        logsumexp(&mut sum, term);
102        let term = align_sums.forward_sums_del[i - 1][j - 1] + align_scores.match2insert_score;
103        logsumexp(&mut sum, term);
104        align_sums.forward_sums_match[i][j] = sum + match_score;
105      }
106      if i > 0 {
107        let base = seq_pair.0[i];
108        let insert_score = align_scores.insert_scores[base];
109        let begins_sum = (i - 1, j) == (0, 0);
110        let mut sum = NEG_INFINITY;
111        let term = align_sums.forward_sums_match[i - 1][j]
112          + if begins_sum {
113            align_scores.init_insert_score
114          } else {
115            align_scores.match2insert_score
116          };
117        logsumexp(&mut sum, term);
118        let term = align_sums.forward_sums_insert[i - 1][j] + align_scores.insert_extend_score;
119        logsumexp(&mut sum, term);
120        align_sums.forward_sums_insert[i][j] = sum + insert_score;
121      }
122      if j > 0 {
123        let base = seq_pair.1[j];
124        let insert_score = align_scores.insert_scores[base];
125        let begins_sum = (i, j - 1) == (0, 0);
126        let mut sum = NEG_INFINITY;
127        let term = align_sums.forward_sums_match[i][j - 1]
128          + if begins_sum {
129            align_scores.init_insert_score
130          } else {
131            align_scores.match2insert_score
132          };
133        logsumexp(&mut sum, term);
134        let term = align_sums.forward_sums_del[i][j - 1] + align_scores.insert_extend_score;
135        logsumexp(&mut sum, term);
136        align_sums.forward_sums_del[i][j] = sum + insert_score;
137      }
138    }
139  }
140  for i in (1..seq_len_pair.0).rev() {
141    for j in (1..seq_len_pair.1).rev() {
142      if i == seq_len_pair.0 - 1 && j == seq_len_pair.1 - 1 {
143        align_sums.backward_sums_match[i][j] = 0.;
144        continue;
145      }
146      if i < seq_len_pair.0 - 1 && j < seq_len_pair.1 - 1 {
147        let mut sum = NEG_INFINITY;
148        let base_pair = (seq_pair.0[i], seq_pair.1[j]);
149        let match_score = align_scores.match_scores[base_pair.0][base_pair.1];
150        let ends_sum = (i + 1, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
151        let term = align_sums.backward_sums_match[i + 1][j + 1]
152          + if ends_sum {
153            0.
154          } else {
155            align_scores.match2match_score
156          };
157        logsumexp(&mut sum, term);
158        let term = align_sums.backward_sums_insert[i + 1][j + 1] + align_scores.match2insert_score;
159        logsumexp(&mut sum, term);
160        let term = align_sums.backward_sums_del[i + 1][j + 1] + align_scores.match2insert_score;
161        logsumexp(&mut sum, term);
162        align_sums.backward_sums_match[i][j] = sum + match_score;
163      }
164      if i < seq_len_pair.0 - 1 {
165        let base = seq_pair.0[i];
166        let insert_score = align_scores.insert_scores[base];
167        let ends_sum = (i + 1, j) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
168        let mut sum = NEG_INFINITY;
169        let term = align_sums.backward_sums_match[i + 1][j]
170          + if ends_sum {
171            0.
172          } else {
173            align_scores.match2insert_score
174          };
175        logsumexp(&mut sum, term);
176        let term = align_sums.backward_sums_insert[i + 1][j] + align_scores.insert_extend_score;
177        logsumexp(&mut sum, term);
178        align_sums.backward_sums_insert[i][j] = sum + insert_score;
179      }
180      if j < seq_len_pair.1 - 1 {
181        let base = seq_pair.1[j];
182        let insert_score = align_scores.insert_scores[base];
183        let ends_sum = (i, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
184        let mut sum = NEG_INFINITY;
185        let term = align_sums.backward_sums_match[i][j + 1]
186          + if ends_sum {
187            0.
188          } else {
189            align_scores.match2insert_score
190          };
191        logsumexp(&mut sum, term);
192        let term = align_sums.backward_sums_del[i][j + 1] + align_scores.insert_extend_score;
193        logsumexp(&mut sum, term);
194        align_sums.backward_sums_del[i][j] = sum + insert_score;
195      }
196    }
197  }
198  align_sums
199}
200
201fn get_match_probs(
202  align_sums: &AlignSums,
203  seq_len_pair: &(usize, usize),
204  align_scores: &AlignScores,
205) -> ProbMat {
206  let mut match_probs = vec![vec![0.; seq_len_pair.1]; seq_len_pair.0];
207  let mut global_sum = align_sums.forward_sums_match[seq_len_pair.0 - 2][seq_len_pair.1 - 2];
208  logsumexp(
209    &mut global_sum,
210    align_sums.forward_sums_insert[seq_len_pair.0 - 2][seq_len_pair.1 - 2],
211  );
212  logsumexp(
213    &mut global_sum,
214    align_sums.forward_sums_del[seq_len_pair.0 - 2][seq_len_pair.1 - 2],
215  );
216  for (i, x) in match_probs.iter_mut().enumerate() {
217    if i == 0 || i == seq_len_pair.0 - 1 {
218      continue;
219    }
220    for (j, x) in x.iter_mut().enumerate() {
221      if j == 0 || j == seq_len_pair.1 - 1 {
222        continue;
223      }
224      let mut sum = NEG_INFINITY;
225      let forward_sum = align_sums.forward_sums_match[i][j];
226      let ends_sum = (i + 1, j + 1) == (seq_len_pair.0 - 1, seq_len_pair.1 - 1);
227      let term = if ends_sum {
228        0.
229      } else {
230        align_scores.match2match_score
231      } + align_sums.backward_sums_match[i + 1][j + 1];
232      logsumexp(&mut sum, term);
233      let term = align_scores.match2insert_score + align_sums.backward_sums_insert[i + 1][j + 1];
234      logsumexp(&mut sum, term);
235      let term = align_scores.match2insert_score + align_sums.backward_sums_del[i + 1][j + 1];
236      logsumexp(&mut sum, term);
237      let match_prob = expf(forward_sum + sum - global_sum);
238      *x = match_prob;
239    }
240  }
241  match_probs
242}