rna_algos/
centroid_fold.rs1use utils::*;
2
3#[derive(Clone)]
4pub struct CentroidFold<T> {
5 pub basepair_pos_pairs: PosPairs<T>,
6 pub expect_accuracy: Score,
7}
/// Shorthand for an owned vector of `T` values.
// NOTE(review): not referenced anywhere in this file; confirm it is still used elsewhere.
pub type Poss<T> = std::vec::Vec<T>;
9
10impl<T> Default for CentroidFold<T> {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl<T> CentroidFold<T> {
17 pub fn new() -> CentroidFold<T> {
18 CentroidFold {
19 basepair_pos_pairs: PosPairs::<T>::new(),
20 expect_accuracy: 0.,
21 }
22 }
23}
24
25pub fn centroid_fold<T: HashIndex>(
26 basepair_probs: &SparseProbMat<T>,
27 seq_len: usize,
28 centroid_threshold: Prob,
29) -> CentroidFold<T>
30where
31 T: HashIndex,
32{
33 let mut max_expect_accuracies = vec![vec![0.; seq_len]; seq_len];
34 let seq_len = T::from_usize(seq_len).unwrap();
35 for subseq_len in range_inclusive(T::one(), seq_len) {
36 for i in range_inclusive(T::zero(), seq_len - subseq_len) {
37 let j = i + subseq_len - T::one();
38 let (long_i, long_j) = (i.to_usize().unwrap(), j.to_usize().unwrap());
39 if i == j {
40 continue;
41 }
42 let mut max_expect_accuracy = max_expect_accuracies[long_i + 1][long_j];
43 let expect_accuracy = max_expect_accuracies[long_i][long_j - 1];
44 if expect_accuracy > max_expect_accuracy {
45 max_expect_accuracy = expect_accuracy;
46 }
47 let pos_pair = (i, j);
48 if let Some(&x) = basepair_probs.get(&pos_pair) {
49 let expect_accuracy =
50 max_expect_accuracies[long_i + 1][long_j - 1] + centroid_threshold * x - 1.;
51 if expect_accuracy > max_expect_accuracy {
52 max_expect_accuracy = expect_accuracy;
53 }
54 }
55 for k in long_i + 1..long_j {
56 let expect_accuracy =
57 max_expect_accuracies[long_i][k] + max_expect_accuracies[k + 1][long_j];
58 if expect_accuracy > max_expect_accuracy {
59 max_expect_accuracy = expect_accuracy;
60 }
61 }
62 max_expect_accuracies[long_i][long_j] = max_expect_accuracy;
63 }
64 }
65 let mut centroid_fold = CentroidFold::<T>::new();
66 let mut pos_pair_stack = vec![(T::zero(), seq_len - T::one())];
67 while !pos_pair_stack.is_empty() {
68 let pos_pair = pos_pair_stack.pop().unwrap();
69 let (i, j) = pos_pair;
70 if j <= i {
71 continue;
72 }
73 let (long_i, long_j) = (i.to_usize().unwrap(), j.to_usize().unwrap());
74 let max_expect_accuracy = max_expect_accuracies[long_i][long_j];
75 if max_expect_accuracy == 0. {
76 continue;
77 }
78 if max_expect_accuracy == max_expect_accuracies[long_i + 1][long_j] {
79 pos_pair_stack.push((i + T::one(), j));
80 } else if max_expect_accuracy == max_expect_accuracies[long_i][long_j - 1] {
81 pos_pair_stack.push((i, j - T::one()));
82 } else if basepair_probs.contains_key(&pos_pair)
83 && max_expect_accuracy
84 == max_expect_accuracies[long_i + 1][long_j - 1]
85 + centroid_threshold * basepair_probs[&pos_pair]
86 - 1.
87 {
88 pos_pair_stack.push((i + T::one(), j - T::one()));
89 centroid_fold.basepair_pos_pairs.push(pos_pair);
90 } else {
91 for k in range(i + T::one(), j) {
92 let long_k = k.to_usize().unwrap();
93 if max_expect_accuracy
94 == max_expect_accuracies[long_i][long_k] + max_expect_accuracies[long_k + 1][long_j]
95 {
96 pos_pair_stack.push((i, k));
97 pos_pair_stack.push((k + T::one(), j));
98 break;
99 }
100 }
101 }
102 }
103 centroid_fold.expect_accuracy = max_expect_accuracies[0][seq_len.to_usize().unwrap() - 1];
104 centroid_fold
105}