nale/align/bounded/
traceback_bounded.rs

1use crate::structs::dp_matrix::DpMatrix;
2use crate::structs::trace::constants::{
3    TRACE_B, TRACE_C, TRACE_D, TRACE_E, TRACE_I, TRACE_J, TRACE_M, TRACE_N, TRACE_S, TRACE_T,
4};
5use crate::structs::{Profile, Trace};
6
7pub fn traceback_bounded(
8    profile: &Profile,
9    posterior_matrix: &impl DpMatrix,
10    optimal_matrix: &impl DpMatrix,
11    trace: &mut Trace,
12    target_end: usize,
13) {
14    let mut target_idx = target_end;
15    let mut profile_idx = 0;
16
17    let mut current_state_posterior_probability: f32;
18    // we trace back starting from the last C state
19    let mut previous_state: usize = TRACE_C;
20    let mut current_state: usize;
21
22    trace.append_with_posterior_probability(TRACE_T, target_idx, profile_idx, 0.0);
23    trace.append_with_posterior_probability(TRACE_C, target_idx, profile_idx, 0.0);
24
25    while previous_state != TRACE_S {
26        current_state = match previous_state {
27            TRACE_C => {
28                let c_to_c_path = profile.special_transition_score_delta(
29                    Profile::SPECIAL_C_IDX,
30                    Profile::SPECIAL_LOOP_IDX,
31                ) * (optimal_matrix
32                    .get_special(target_idx - 1, Profile::SPECIAL_C_IDX)
33                    // TODO: why does this specific path involve a posterior probability?
34                    + posterior_matrix.get_special(target_idx, Profile::SPECIAL_C_IDX));
35
36                let c_to_e_path = profile
37                    .transition_score_delta(Profile::SPECIAL_E_IDX, Profile::SPECIAL_MOVE_IDX)
38                    * optimal_matrix.get_special(target_idx, Profile::SPECIAL_E_IDX);
39
40                if c_to_c_path > c_to_e_path {
41                    TRACE_C
42                } else {
43                    TRACE_E
44                }
45            }
46            TRACE_E => {
47                let mut max_score = -f32::INFINITY;
48                let mut state_of_max_score = 0;
49                let mut profile_idx_of_max_score = 0;
50
51                // TODO: why do we do this instead of just taking
52                //       the state that follows the E state?
53                for profile_idx in 1..=profile.length {
54                    if optimal_matrix.get_match(target_idx, profile_idx) >= max_score {
55                        max_score = optimal_matrix.get_match(target_idx, profile_idx);
56                        state_of_max_score = TRACE_M;
57                        profile_idx_of_max_score = profile_idx;
58                    }
59                    if optimal_matrix.get_delete(target_idx, profile_idx) > max_score {
60                        max_score = optimal_matrix.get_delete(target_idx, profile_idx);
61                        state_of_max_score = TRACE_D;
62                        profile_idx_of_max_score = profile_idx;
63                    }
64                }
65                profile_idx = profile_idx_of_max_score;
66                state_of_max_score
67            }
68            TRACE_M => {
69                let possible_states: [usize; 4] = [TRACE_M, TRACE_I, TRACE_D, TRACE_B];
70
71                let possible_paths: [f32; 4] = [
72                    profile.transition_score_delta(Profile::MATCH_TO_MATCH_IDX, profile_idx - 1)
73                        * optimal_matrix.get_match(target_idx - 1, profile_idx - 1),
74                    profile.transition_score_delta(Profile::INSERT_TO_MATCH_IDX, profile_idx - 1)
75                        * optimal_matrix.get_insert(target_idx - 1, profile_idx - 1),
76                    profile.transition_score_delta(Profile::DELETE_TO_MATCH_IDX, profile_idx - 1)
77                        * optimal_matrix.get_delete(target_idx - 1, profile_idx - 1),
78                    profile.transition_score_delta(Profile::BEGIN_TO_MATCH_IDX, profile_idx - 1)
79                        * optimal_matrix.get_special(target_idx - 1, Profile::SPECIAL_B_IDX),
80                ];
81
82                let mut argmax: usize = 0;
83                for i in 1..4 {
84                    if possible_paths[i] > possible_paths[argmax] {
85                        argmax = i;
86                    }
87                }
88
89                // a match means we have moved forward in the both the profile and the target
90                profile_idx -= 1;
91                target_idx -= 1;
92
93                possible_states[argmax]
94            }
95            TRACE_I => {
96                let match_to_insert_path = profile
97                    .transition_score_delta(Profile::MATCH_TO_INSERT_IDX, profile_idx)
98                    * optimal_matrix.get_match(target_idx - 1, profile_idx);
99
100                let insert_to_insert_path: f32 = profile
101                    .transition_score_delta(Profile::INSERT_TO_INSERT_IDX, profile_idx)
102                    * optimal_matrix.get_insert(target_idx - 1, profile_idx);
103
104                // an insert means we moved forward only in the profile
105                target_idx -= 1;
106
107                if match_to_insert_path >= insert_to_insert_path {
108                    TRACE_M
109                } else {
110                    TRACE_I
111                }
112            }
113            TRACE_D => {
114                let match_to_delete_path = profile
115                    .transition_score_delta(Profile::MATCH_TO_DELETE_IDX, profile_idx - 1)
116                    * optimal_matrix.get_match(target_idx, profile_idx - 1);
117
118                let delete_to_delete_path = profile
119                    .transition_score_delta(Profile::DELETE_TO_DELETE_IDX, profile_idx - 1)
120                    * optimal_matrix.get_delete(target_idx, profile_idx - 1);
121
122                // a delete means we moved forward only in the profile
123                profile_idx -= 1;
124
125                if match_to_delete_path >= delete_to_delete_path {
126                    TRACE_M
127                } else {
128                    TRACE_D
129                }
130            }
131            TRACE_B => {
132                let n_to_b_path = profile.special_transition_score_delta(
133                    Profile::SPECIAL_N_IDX,
134                    Profile::SPECIAL_MOVE_IDX,
135                ) * optimal_matrix
136                    .get_special(target_idx, Profile::SPECIAL_N_IDX);
137
138                let j_to_b_path = profile.special_transition_score_delta(
139                    Profile::SPECIAL_J_IDX,
140                    Profile::SPECIAL_MOVE_IDX,
141                ) * optimal_matrix
142                    .get_special(target_idx, Profile::SPECIAL_J_IDX);
143
144                if n_to_b_path >= j_to_b_path {
145                    TRACE_N
146                } else {
147                    TRACE_J
148                }
149            }
150            TRACE_N => {
151                if target_idx == 0 {
152                    TRACE_S
153                } else {
154                    TRACE_N
155                }
156            }
157            TRACE_J => {
158                let j_to_j_path = profile.special_transition_score_delta(
159                    Profile::SPECIAL_J_IDX,
160                    Profile::SPECIAL_LOOP_IDX,
161                ) * (optimal_matrix
162                    .get_special(target_idx - 1, Profile::SPECIAL_J_IDX)
163                    // TODO: why does this specific path involve a posterior probability?
164                    + posterior_matrix.get_special(target_idx, Profile::SPECIAL_J_IDX));
165
166                let e_to_j_path = profile.special_transition_score_delta(
167                    Profile::SPECIAL_E_IDX,
168                    Profile::SPECIAL_LOOP_IDX,
169                ) * optimal_matrix
170                    .get_special(target_idx, Profile::SPECIAL_E_IDX);
171
172                if j_to_j_path > e_to_j_path {
173                    TRACE_J
174                } else {
175                    TRACE_E
176                }
177            }
178            _ => {
179                panic!("bad state in traceback")
180            }
181        };
182
183        current_state_posterior_probability = get_posterior_probability(
184            posterior_matrix,
185            current_state,
186            previous_state,
187            profile_idx,
188            target_idx,
189        );
190
191        trace.append_with_posterior_probability(
192            current_state,
193            target_idx,
194            profile_idx,
195            current_state_posterior_probability,
196        );
197
198        if (current_state == TRACE_N || current_state == TRACE_J || current_state == TRACE_C)
199            && current_state == previous_state
200        {
201            target_idx -= 1;
202        }
203        previous_state = current_state;
204    }
205    trace.reverse();
206}
207
208pub fn get_posterior_probability(
209    optimal_matrix: &impl DpMatrix,
210    current_state: usize,
211    previous_state: usize,
212    profile_idx: usize,
213    target_idx: usize,
214) -> f32 {
215    match current_state {
216        TRACE_M => optimal_matrix.get_match(target_idx, profile_idx),
217        TRACE_I => optimal_matrix.get_insert(target_idx, profile_idx),
218        TRACE_N => {
219            if current_state == previous_state {
220                optimal_matrix.get_special(target_idx, Profile::SPECIAL_N_IDX)
221            } else {
222                0.0
223            }
224        }
225        TRACE_C => {
226            if current_state == previous_state {
227                optimal_matrix.get_special(target_idx, Profile::SPECIAL_C_IDX)
228            } else {
229                0.0
230            }
231        }
232        TRACE_J => {
233            if current_state == previous_state {
234                optimal_matrix.get_special(target_idx, Profile::SPECIAL_J_IDX)
235            } else {
236                0.0
237            }
238        }
239        _ => 0.0,
240    }
241}