rna_algos/
centroid_fold.rs

1use utils::*;
2
3#[derive(Clone)]
4pub struct CentroidFold<T> {
5  pub basepair_pos_pairs: PosPairs<T>,
6  pub expect_accuracy: Score,
7}
8pub type Poss<T> = Vec<T>;
9
10impl<T> Default for CentroidFold<T> {
11  fn default() -> Self {
12    Self::new()
13  }
14}
15
16impl<T> CentroidFold<T> {
17  pub fn new() -> CentroidFold<T> {
18    CentroidFold {
19      basepair_pos_pairs: PosPairs::<T>::new(),
20      expect_accuracy: 0.,
21    }
22  }
23}
24
25pub fn centroid_fold<T: HashIndex>(
26  basepair_probs: &SparseProbMat<T>,
27  seq_len: usize,
28  centroid_threshold: Prob,
29) -> CentroidFold<T>
30where
31  T: HashIndex,
32{
33  let mut max_expect_accuracies = vec![vec![0.; seq_len]; seq_len];
34  let seq_len = T::from_usize(seq_len).unwrap();
35  for subseq_len in range_inclusive(T::one(), seq_len) {
36    for i in range_inclusive(T::zero(), seq_len - subseq_len) {
37      let j = i + subseq_len - T::one();
38      let (long_i, long_j) = (i.to_usize().unwrap(), j.to_usize().unwrap());
39      if i == j {
40        continue;
41      }
42      let mut max_expect_accuracy = max_expect_accuracies[long_i + 1][long_j];
43      let expect_accuracy = max_expect_accuracies[long_i][long_j - 1];
44      if expect_accuracy > max_expect_accuracy {
45        max_expect_accuracy = expect_accuracy;
46      }
47      let pos_pair = (i, j);
48      if let Some(&x) = basepair_probs.get(&pos_pair) {
49        let expect_accuracy =
50          max_expect_accuracies[long_i + 1][long_j - 1] + centroid_threshold * x - 1.;
51        if expect_accuracy > max_expect_accuracy {
52          max_expect_accuracy = expect_accuracy;
53        }
54      }
55      for k in long_i + 1..long_j {
56        let expect_accuracy =
57          max_expect_accuracies[long_i][k] + max_expect_accuracies[k + 1][long_j];
58        if expect_accuracy > max_expect_accuracy {
59          max_expect_accuracy = expect_accuracy;
60        }
61      }
62      max_expect_accuracies[long_i][long_j] = max_expect_accuracy;
63    }
64  }
65  let mut centroid_fold = CentroidFold::<T>::new();
66  let mut pos_pair_stack = vec![(T::zero(), seq_len - T::one())];
67  while !pos_pair_stack.is_empty() {
68    let pos_pair = pos_pair_stack.pop().unwrap();
69    let (i, j) = pos_pair;
70    if j <= i {
71      continue;
72    }
73    let (long_i, long_j) = (i.to_usize().unwrap(), j.to_usize().unwrap());
74    let max_expect_accuracy = max_expect_accuracies[long_i][long_j];
75    if max_expect_accuracy == 0. {
76      continue;
77    }
78    if max_expect_accuracy == max_expect_accuracies[long_i + 1][long_j] {
79      pos_pair_stack.push((i + T::one(), j));
80    } else if max_expect_accuracy == max_expect_accuracies[long_i][long_j - 1] {
81      pos_pair_stack.push((i, j - T::one()));
82    } else if basepair_probs.contains_key(&pos_pair)
83      && max_expect_accuracy
84        == max_expect_accuracies[long_i + 1][long_j - 1]
85          + centroid_threshold * basepair_probs[&pos_pair]
86          - 1.
87    {
88      pos_pair_stack.push((i + T::one(), j - T::one()));
89      centroid_fold.basepair_pos_pairs.push(pos_pair);
90    } else {
91      for k in range(i + T::one(), j) {
92        let long_k = k.to_usize().unwrap();
93        if max_expect_accuracy
94          == max_expect_accuracies[long_i][long_k] + max_expect_accuracies[long_k + 1][long_j]
95        {
96          pos_pair_stack.push((i, k));
97          pos_pair_stack.push((k + T::one(), j));
98          break;
99        }
100      }
101    }
102  }
103  centroid_fold.expect_accuracy = max_expect_accuracies[0][seq_len.to_usize().unwrap() - 1];
104  centroid_fold
105}