1extern crate bio;
2extern crate consprob;
3extern crate my_bfgs as bfgs;
4extern crate ndarray_rand;
5extern crate rand;
6pub mod trained_alignfold_scores;
7pub mod trained_alignfold_scores_randinit;
8
9pub use bfgs::bfgs;
10pub use bio::io::fasta::*;
11pub use bio::utils::*;
12pub use consprob::*;
13pub use ndarray_rand::rand_distr::{Distribution, Normal};
14pub use ndarray_rand::RandomExt;
15pub use rand::thread_rng;
16pub use std::f32::INFINITY;
17pub use std::fs::{read_dir, DirEntry};
18pub use std::io::stdout;
19
20pub type SparsePosMat<T> = HashSet<PosPair<T>>;
21pub type FoldScoresPairTrained<T> = (FoldScoresTrained<T>, FoldScoresTrained<T>);
22pub type RefFoldScoresPair<'a, T> = (&'a FoldScoresTrained<T>, &'a FoldScoresTrained<T>);
23pub type Regularizers = Array1<Regularizer>;
24pub type Regularizer = Score;
25pub type BfgsScores = Array1<BfgsScore>;
26pub type BfgsScore = f64;
27pub type Scores = Array1<Score>;
28pub type TrainData<T> = Vec<TrainDatum<T>>;
29pub type RealSeqPair = (Seq, Seq);
30pub type LoopStruct = HashMap<(usize, usize), Vec<(usize, usize)>>;
31
/// Folding scores of one sequence, precomputed for its candidate base pairs
/// and stored sparsely (filled by `FoldScoresTrained::set_curr_scores`).
#[derive(Clone)]
pub struct FoldScoresTrained<T> {
    /// Hairpin score of each candidate pair `(i, j)` whose enclosed loop
    /// (`j - i - 1`) is at most `MAX_LOOP_LEN`.
    pub hairpin_scores: SparseScoreMat<T>,
    /// Two-loop score keyed by `(outer.0, outer.1, inner.0, inner.1)` for
    /// candidate pairs nested with at most `MAX_LOOP_LEN` unpaired bases between.
    pub twoloop_scores: ScoreMat4d<T>,
    /// Multibranch base + multibranch base-pair + junction score of `(i, j)`.
    pub multibranch_close_scores: SparseScoreMat<T>,
    /// Junction score of the reversed pair `(j, i)` + base-pair score +
    /// multibranch base-pair score.
    pub multibranch_accessible_scores: SparseScoreMat<T>,
    /// Junction score of the reversed pair `(j, i)` + base-pair score +
    /// external base-pair score.
    pub external_accessible_scores: SparseScoreMat<T>,
}
40
/// One training example: a sequence pair plus quantities cached for training.
///
/// `get_grad` consumes `alignfold_counts_observed` and
/// `alignfold_counts_expected`; the remaining fields are inputs/caches for
/// routines not visible here.
#[derive(Clone)]
pub struct TrainDatum<T> {
    /// The two sequences.
    pub seq_pair: RealSeqPair,
    /// NOTE(review): presumably the two sequences with alignment gaps — confirm.
    pub seq_pair_gapped: RealSeqPair,
    /// Feature counts on the "observed" side of the gradient in `get_grad`
    /// (presumably counted from the reference `alignfold`).
    pub alignfold_counts_observed: AlignfoldScores,
    /// Feature counts on the "expected" side of the gradient in `get_grad`
    /// (presumably expectations under the current parameters).
    pub alignfold_counts_expected: AlignfoldScores,
    /// Presumably candidate base-pair probabilities, one matrix per sequence.
    pub basepair_probs_pair: SparseProbMatPair<T>,
    /// Presumably the maximum base-pair span allowed in each sequence.
    pub max_basepair_span_pair: (T, T),
    /// NOTE(review): presumably a global partition-function value — confirm.
    pub global_sum: Prob,
    // NOTE(review): the next five fields look like sparsity/lookup structures
    // for the DP routines; their exact semantics are not visible here.
    pub forward_pos_pairs: PosPairMatSet<T>,
    pub backward_pos_pairs: PosPairMatSet<T>,
    pub pos_quads_hashed_lens: PosQuadsHashedLens<T>,
    pub matchable_poss: SparsePosSets<T>,
    pub matchable_poss2: SparsePosSets<T>,
    /// Folding score tables, one per sequence.
    pub fold_scores_pair: FoldScoresPairTrained<T>,
    pub match_probs: SparseProbMat<T>,
    /// Pairwise alignment of the two sequences (presumably the reference).
    pub alignfold: PairAlignfold<T>,
    pub accuracy: Score,
}
60
/// A pairwise alignment described by position sets.
///
/// NOTE(review): presumably `inserted_poss` / `deleted_poss` hold positions of
/// the first / second sequence that are left unmatched — confirm.
#[derive(Clone)]
pub struct PairAlignfold<T> {
    /// Position pairs aligned to each other.
    pub matched_pos_pairs: SparsePosMat<T>,
    pub inserted_poss: SparsePoss<T>,
    pub deleted_poss: SparsePoss<T>,
}
67
/// Complete parameter set of the alignment-folding model.
///
/// The same layout is also used for feature counts (`TrainDatum`) and for
/// gradients (`get_grad` builds one from `AlignfoldScores::new(0.)`).
/// Symmetric tables are trained only at their dictionary-minimal entries and
/// copied to the others by `mirror`; the `*_cumulative` tables are prefix sums
/// of the matching length tables, rebuilt by `accumulate`.
#[derive(Clone, Debug)]
pub struct AlignfoldScores {
    // Loop-length and loop-asymmetry tables.
    pub hairpin_scores_len: HairpinScoresLen,
    pub bulge_scores_len: BulgeScoresLen,
    pub interior_scores_len: InteriorScoresLen,
    pub interior_scores_symmetric: InteriorScoresSymmetric,
    pub interior_scores_asymmetric: InteriorScoresAsymmetric,
    // Tables indexed by nucleotides / base pairs.
    pub stack_scores: StackScores,
    pub terminal_mismatch_scores: TerminalMismatchScores,
    pub dangling_scores_left: DanglingScores,
    pub dangling_scores_right: DanglingScores,
    pub helix_close_scores: HelixCloseScores,
    pub basepair_scores: BasepairScores,
    pub interior_scores_explicit: InteriorScoresExplicit,
    pub bulge_scores_0x1: BulgeScores0x1,
    pub interior_scores_1x1: InteriorScores1x1Contra,
    // Scalar multibranch / external loop parameters.
    pub multibranch_score_base: Score,
    pub multibranch_score_basepair: Score,
    pub multibranch_score_unpair: Score,
    pub external_score_basepair: Score,
    pub external_score_unpair: Score,
    // Alignment parameters: transitions, per-base insertions, matches.
    pub match2match_score: Score,
    pub match2insert_score: Score,
    pub insert_extend_score: Score,
    pub init_match_score: Score,
    pub init_insert_score: Score,
    pub insert_scores: InsertScores,
    pub match_scores: MatchScores,
    // Derived prefix sums of the length tables (see `accumulate`).
    pub hairpin_scores_len_cumulative: HairpinScoresLen,
    pub bulge_scores_len_cumulative: BulgeScoresLen,
    pub interior_scores_len_cumulative: InteriorScoresLen,
    pub interior_scores_symmetric_cumulative: InteriorScoresSymmetric,
    pub interior_scores_asymmetric_cumulative: InteriorScoresAsymmetric,
}
/// Dense 2-D score table stored as nested `Vec`s.
pub type ScoreMat = Vec<Vec<Score>>;
/// Dense insertion-score tables.
///
/// NOTE(review): presumably entry `[i][j]` scores inserting a range of
/// positions in a plain, external-loop or multibranch-loop context, with the
/// `*2` fields referring to the second sequence — confirm with the code that
/// fills this struct.
pub struct RangeInsertScores {
    pub insert_scores: ScoreMat,
    pub insert_scores_external: ScoreMat,
    pub insert_scores_multibranch: ScoreMat,
    pub insert_scores2: ScoreMat,
    pub insert_scores_external2: ScoreMat,
    pub insert_scores_multibranch2: ScoreMat,
}
114
/// Positional argument bundle, presumably for the alignment-folding
/// probability getter.
///
/// NOTE(review): the meaning of each element (in particular the unnamed
/// `bool` / `Sum` entries) is fixed by the consumer's destructuring, which is
/// not visible here — confirm there before reordering anything.
pub type InputsAlignfoldProbsGetter<'a, T> = (
    &'a SeqPair<'a>,
    &'a AlignfoldScores,
    &'a PosPair<T>,
    &'a SparseProbMat<T>,
    &'a AlignfoldSums<T>,
    bool,
    Sum,
    bool,
    &'a mut AlignfoldScores,
    &'a PosQuadsHashedLens<T>,
    &'a RefFoldScoresPair<'a, T>,
    bool,
    &'a PosPairMatSet<T>,
    &'a PosPairMatSet<T>,
    &'a RangeInsertScores,
    &'a SparsePosSets<T>,
    &'a SparsePosSets<T>,
);

/// Positional argument bundle, presumably for the core probability routine.
/// NOTE(review): element meanings (notably the `bool`s) are defined by the
/// consumer — confirm there.
pub type InputsConsprobCore<'a, T> = (
    &'a SeqPair<'a>,
    &'a AlignfoldScores,
    &'a PosPair<T>,
    &'a SparseProbMat<T>,
    bool,
    bool,
    &'a mut AlignfoldScores,
    &'a PosPairMatSet<T>,
    &'a PosPairMatSet<T>,
    &'a PosQuadsHashedLens<T>,
    &'a RefFoldScoresPair<'a, T>,
    bool,
    &'a SparsePosSets<T>,
    &'a SparsePosSets<T>,
);

/// Positional argument bundle, presumably for the two-loop sum getter.
/// NOTE(review): element meanings are defined by the consumer — confirm there.
pub type Inputs2loopSumsGetter<'a, T> = (
    &'a SeqPair<'a>,
    &'a AlignfoldScores,
    &'a SparseProbMat<T>,
    &'a PosQuad<T>,
    &'a AlignfoldSums<T>,
    bool,
    &'a PosPairMatSet<T>,
    &'a RefFoldScoresPair<'a, T>,
    &'a RangeInsertScores,
    &'a SparsePosSets<T>,
    &'a SparsePosSets<T>,
);

/// Positional argument bundle, presumably for the loop sum getter (note the
/// mutable `AlignfoldSums`).
/// NOTE(review): element meanings are defined by the consumer — confirm there.
pub type InputsLoopSumsGetter<'a, T> = (
    &'a SeqPair<'a>,
    &'a AlignfoldScores,
    &'a SparseProbMat<T>,
    &'a PosQuad<T>,
    &'a mut AlignfoldSums<T>,
    bool,
    &'a PosPairMatSet<T>,
    &'a RangeInsertScores,
    &'a SparsePosSets<T>,
    &'a SparsePosSets<T>,
);

/// Positional argument bundle, presumably for the inside-sum getter.
/// NOTE(review): element meanings are defined by the consumer — confirm there.
pub type InputsInsideSumsGetter<'a, T> = (
    &'a SeqPair<'a>,
    &'a AlignfoldScores,
    &'a PosPair<T>,
    &'a SparseProbMat<T>,
    bool,
    &'a PosPairMatSet<T>,
    &'a PosPairMatSet<T>,
    &'a PosQuadsHashedLens<T>,
    &'a RefFoldScoresPair<'a, T>,
    &'a RangeInsertScores,
    &'a SparsePosSets<T>,
    &'a SparsePosSets<T>,
);
193
194impl<T: HashIndex> Default for FoldScoresTrained<T> {
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl<T: HashIndex> FoldScoresTrained<T> {
201 pub fn new() -> FoldScoresTrained<T> {
202 FoldScoresTrained {
203 hairpin_scores: SparseScoreMat::<T>::default(),
204 twoloop_scores: ScoreMat4d::<T>::default(),
205 multibranch_close_scores: SparseScoreMat::<T>::default(),
206 multibranch_accessible_scores: SparseScoreMat::<T>::default(),
207 external_accessible_scores: SparseScoreMat::<T>::default(),
208 }
209 }
210
211 pub fn set_curr_scores(
212 alignfold_scores: &AlignfoldScores,
213 seq: SeqSlice,
214 basepair_probs: &SparseProbMat<T>,
215 ) -> FoldScoresTrained<T> {
216 let mut fold_scores = FoldScoresTrained::<T>::new();
217 for pos_pair in basepair_probs.keys() {
218 let long_pos_pair = (
219 pos_pair.0.to_usize().unwrap(),
220 pos_pair.1.to_usize().unwrap(),
221 );
222 if long_pos_pair.1 - long_pos_pair.0 - 1 <= MAX_LOOP_LEN {
223 let x = get_hairpin_score(alignfold_scores, seq, &long_pos_pair);
224 fold_scores.hairpin_scores.insert(*pos_pair, x);
225 }
226 let multibranch_close_score = alignfold_scores.multibranch_score_base
227 + alignfold_scores.multibranch_score_basepair
228 + get_junction_score(alignfold_scores, seq, &long_pos_pair);
229 fold_scores
230 .multibranch_close_scores
231 .insert(*pos_pair, multibranch_close_score);
232 let basepair = (seq[long_pos_pair.0], seq[long_pos_pair.1]);
233 let junction_score =
234 get_junction_score(alignfold_scores, seq, &(long_pos_pair.1, long_pos_pair.0))
235 + alignfold_scores.basepair_scores[basepair.0][basepair.1];
236 let multibranch_accessible_score =
237 junction_score + alignfold_scores.multibranch_score_basepair;
238 fold_scores
239 .multibranch_accessible_scores
240 .insert(*pos_pair, multibranch_accessible_score);
241 let external_accessible_score = junction_score + alignfold_scores.external_score_basepair;
242 fold_scores
243 .external_accessible_scores
244 .insert(*pos_pair, external_accessible_score);
245 for x in basepair_probs.keys() {
246 if !(x.0 < pos_pair.0 && pos_pair.1 < x.1) {
247 continue;
248 }
249 let y = (x.0.to_usize().unwrap(), x.1.to_usize().unwrap());
250 if long_pos_pair.0 - y.0 - 1 + y.1 - long_pos_pair.1 - 1 > MAX_LOOP_LEN {
251 continue;
252 }
253 let y = get_twoloop_score(alignfold_scores, seq, &y, &long_pos_pair);
254 fold_scores
255 .twoloop_scores
256 .insert((x.0, x.1, pos_pair.0, pos_pair.1), y);
257 }
258 }
259 fold_scores
260 }
261}
262
263impl AlignfoldScores {
264 pub fn new(init_val: Score) -> AlignfoldScores {
265 let init_vals = [init_val; NUM_BASES];
266 let mat_2d = [init_vals; NUM_BASES];
267 let mat_3d = [mat_2d; NUM_BASES];
268 let mat_4d = [mat_3d; NUM_BASES];
269 AlignfoldScores {
270 hairpin_scores_len: [init_val; MAX_LOOP_LEN + 1],
272 bulge_scores_len: [init_val; MAX_LOOP_LEN],
273 interior_scores_len: [init_val; MAX_LOOP_LEN - 1],
274 interior_scores_symmetric: [init_val; MAX_INTERIOR_SYMMETRIC],
275 interior_scores_asymmetric: [init_val; MAX_INTERIOR_ASYMMETRIC],
276 stack_scores: mat_4d,
277 terminal_mismatch_scores: mat_4d,
278 dangling_scores_left: mat_3d,
279 dangling_scores_right: mat_3d,
280 helix_close_scores: mat_2d,
281 basepair_scores: mat_2d,
282 interior_scores_explicit: [[init_val; MAX_INTERIOR_EXPLICIT]; MAX_INTERIOR_EXPLICIT],
283 bulge_scores_0x1: init_vals,
284 interior_scores_1x1: mat_2d,
285 multibranch_score_base: init_val,
286 multibranch_score_basepair: init_val,
287 multibranch_score_unpair: init_val,
288 external_score_basepair: init_val,
289 external_score_unpair: init_val,
290 match2match_score: init_val,
292 match2insert_score: init_val,
293 init_match_score: init_val,
294 insert_extend_score: init_val,
295 init_insert_score: init_val,
296 insert_scores: init_vals,
297 match_scores: mat_2d,
298 hairpin_scores_len_cumulative: [init_val; MAX_LOOP_LEN + 1],
300 bulge_scores_len_cumulative: [init_val; MAX_LOOP_LEN],
301 interior_scores_len_cumulative: [init_val; MAX_LOOP_LEN - 1],
302 interior_scores_symmetric_cumulative: [init_val; MAX_INTERIOR_SYMMETRIC],
303 interior_scores_asymmetric_cumulative: [init_val; MAX_INTERIOR_ASYMMETRIC],
304 }
305 }
306
307 pub fn len(&self) -> usize {
308 let mut sum = 0;
309 sum += self.hairpin_scores_len.len();
310 sum += self.bulge_scores_len.len();
311 sum += self.interior_scores_len.len();
312 sum += self.interior_scores_symmetric.len();
313 sum += self.interior_scores_asymmetric.len();
314 let len = self.stack_scores.len();
315 for i in 0..len {
316 for j in 0..len {
317 if !has_canonical_basepair(&(i, j)) {
318 continue;
319 }
320 for k in 0..len {
321 for l in 0..len {
322 if !has_canonical_basepair(&(k, l)) {
323 continue;
324 }
325 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
326 if ((i, j), (k, l)) != dict_min_stack {
327 continue;
328 }
329 sum += 1;
330 }
331 }
332 }
333 }
334 let len = self.terminal_mismatch_scores.len();
335 for i in 0..len {
336 for j in 0..len {
337 if !has_canonical_basepair(&(i, j)) {
338 continue;
339 }
340 for _ in 0..len {
341 for _ in 0..len {
342 sum += 1;
343 }
344 }
345 }
346 }
347 let len = self.dangling_scores_left.len();
348 for i in 0..len {
349 for j in 0..len {
350 if !has_canonical_basepair(&(i, j)) {
351 continue;
352 }
353 for _ in 0..len {
354 sum += 1;
355 }
356 }
357 }
358 let len = self.dangling_scores_right.len();
359 for i in 0..len {
360 for j in 0..len {
361 if !has_canonical_basepair(&(i, j)) {
362 continue;
363 }
364 for _ in 0..len {
365 sum += 1;
366 }
367 }
368 }
369 let len = self.helix_close_scores.len();
370 for i in 0..len {
371 for j in 0..len {
372 if !has_canonical_basepair(&(i, j)) {
373 continue;
374 }
375 sum += 1;
376 }
377 }
378 let len = self.basepair_scores.len();
379 for i in 0..len {
380 for j in 0..len {
381 if !has_canonical_basepair(&(i, j)) {
382 continue;
383 }
384 let dict_min_basepair = get_dict_min_pair(&(i, j));
385 if (i, j) != dict_min_basepair {
386 continue;
387 }
388 sum += 1;
389 }
390 }
391 let len = self.interior_scores_explicit.len();
392 for i in 0..len {
393 for j in 0..len {
394 let dict_min_len_pair = get_dict_min_pair(&(i, j));
395 if (i, j) != dict_min_len_pair {
396 continue;
397 }
398 sum += 1;
399 }
400 }
401 sum += self.bulge_scores_0x1.len();
402 let len = self.interior_scores_1x1.len();
403 for i in 0..len {
404 for j in 0..len {
405 let dict_min_basepair = get_dict_min_pair(&(i, j));
406 if (i, j) != dict_min_basepair {
407 continue;
408 }
409 sum += 1;
410 }
411 }
412 sum += GROUP_SIZE_MULTIBRANCH;
413 sum += GROUP_SIZE_EXTERNAL;
414 sum += GROUP_SIZE_MATCH_TRANSITION;
415 sum += GROUP_SIZE_INSERT_TRANSITION;
416 sum += self.insert_scores.len();
417 let len = self.match_scores.len();
418 for i in 0..len {
419 for j in 0..len {
420 let dict_min_match = get_dict_min_pair(&(i, j));
421 if (i, j) != dict_min_match {
422 continue;
423 }
424 sum += 1;
425 }
426 }
427 sum
428 }
429
430 pub fn is_empty(&self) -> bool {
431 self.len() == 0
432 }
433
434 pub fn update_regularizers(&self, regularizers: &mut Regularizers) {
435 let mut regularizers2 = vec![0.; regularizers.len()];
436 let mut offset = 0;
437 let len = self.hairpin_scores_len.len();
438 let group_size = len;
439 let mut squared_sum = 0.;
440 for i in 0..len {
441 let x = self.hairpin_scores_len[i];
442 squared_sum += x * x;
443 }
444 let regularizer = get_regularizer(group_size, squared_sum);
445 for i in 0..len {
446 regularizers2[offset + i] = regularizer;
447 }
448 offset += group_size;
449 let len = self.bulge_scores_len.len();
450 let group_size = len;
451 let mut squared_sum = 0.;
452 for i in 0..len {
453 let x = self.bulge_scores_len[i];
454 squared_sum += x * x;
455 }
456 let regularizer = get_regularizer(group_size, squared_sum);
457 for i in 0..len {
458 regularizers2[offset + i] = regularizer;
459 }
460 offset += group_size;
461 let len = self.interior_scores_len.len();
462 let group_size = len;
463 let mut squared_sum = 0.;
464 for i in 0..len {
465 let x = self.interior_scores_len[i];
466 squared_sum += x * x;
467 }
468 let regularizer = get_regularizer(group_size, squared_sum);
469 for i in 0..len {
470 regularizers2[offset + i] = regularizer;
471 }
472 offset += group_size;
473 let len = self.interior_scores_symmetric.len();
474 let group_size = len;
475 let mut squared_sum = 0.;
476 for i in 0..len {
477 let x = self.interior_scores_symmetric[i];
478 squared_sum += x * x;
479 }
480 let regularizer = get_regularizer(group_size, squared_sum);
481 for i in 0..len {
482 regularizers2[offset + i] = regularizer;
483 }
484 offset += group_size;
485 let len = self.interior_scores_asymmetric.len();
486 let group_size = len;
487 let mut squared_sum = 0.;
488 for i in 0..len {
489 let x = self.interior_scores_asymmetric[i];
490 squared_sum += x * x;
491 }
492 let regularizer = get_regularizer(group_size, squared_sum);
493 for i in 0..len {
494 regularizers2[offset + i] = regularizer;
495 }
496 offset += group_size;
497 let len = self.stack_scores.len();
498 let mut effective_group_size = 0;
499 let mut squared_sum = 0.;
500 for i in 0..len {
501 for j in 0..len {
502 if !has_canonical_basepair(&(i, j)) {
503 continue;
504 }
505 for k in 0..len {
506 for l in 0..len {
507 if !has_canonical_basepair(&(k, l)) {
508 continue;
509 }
510 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
511 if ((i, j), (k, l)) != dict_min_stack {
512 continue;
513 }
514 let x = self.stack_scores[i][j][k][l];
515 squared_sum += x * x;
516 effective_group_size += 1;
517 }
518 }
519 }
520 }
521 let regularizer = get_regularizer(effective_group_size, squared_sum);
522 for i in 0..len {
523 for j in 0..len {
524 if !has_canonical_basepair(&(i, j)) {
525 continue;
526 }
527 for k in 0..len {
528 for l in 0..len {
529 if !has_canonical_basepair(&(k, l)) {
530 continue;
531 }
532 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
533 if ((i, j), (k, l)) != dict_min_stack {
534 continue;
535 }
536 regularizers2[offset] = regularizer;
537 offset += 1;
538 }
539 }
540 }
541 }
542 let len = self.terminal_mismatch_scores.len();
543 let mut effective_group_size = 0;
544 let mut squared_sum = 0.;
545 for i in 0..len {
546 for j in 0..len {
547 if !has_canonical_basepair(&(i, j)) {
548 continue;
549 }
550 for k in 0..len {
551 for l in 0..len {
552 let x = self.terminal_mismatch_scores[i][j][k][l];
553 squared_sum += x * x;
554 effective_group_size += 1;
555 }
556 }
557 }
558 }
559 let regularizer = get_regularizer(effective_group_size, squared_sum);
560 for i in 0..len {
561 for j in 0..len {
562 if !has_canonical_basepair(&(i, j)) {
563 continue;
564 }
565 for _ in 0..len {
566 for _ in 0..len {
567 regularizers2[offset] = regularizer;
568 offset += 1;
569 }
570 }
571 }
572 }
573 let len = self.dangling_scores_left.len();
574 let mut effective_group_size = 0;
575 let mut squared_sum = 0.;
576 for i in 0..len {
577 for j in 0..len {
578 if !has_canonical_basepair(&(i, j)) {
579 continue;
580 }
581 for k in 0..len {
582 let x = self.dangling_scores_left[i][j][k];
583 squared_sum += x * x;
584 effective_group_size += 1;
585 }
586 }
587 }
588 for i in 0..len {
589 for j in 0..len {
590 if !has_canonical_basepair(&(i, j)) {
591 continue;
592 }
593 for k in 0..len {
594 let x = self.dangling_scores_right[i][j][k];
595 squared_sum += x * x;
596 effective_group_size += 1;
597 }
598 }
599 }
600 let regularizer = get_regularizer(effective_group_size, squared_sum);
601 for i in 0..len {
602 for j in 0..len {
603 if !has_canonical_basepair(&(i, j)) {
604 continue;
605 }
606 for _ in 0..len {
607 regularizers2[offset] = regularizer;
608 offset += 1;
609 }
610 }
611 }
612 for i in 0..len {
613 for j in 0..len {
614 if !has_canonical_basepair(&(i, j)) {
615 continue;
616 }
617 for _ in 0..len {
618 regularizers2[offset] = regularizer;
619 offset += 1;
620 }
621 }
622 }
623 let len = self.helix_close_scores.len();
624 let mut effective_group_size = 0;
625 let mut squared_sum = 0.;
626 for i in 0..len {
627 for j in 0..len {
628 if !has_canonical_basepair(&(i, j)) {
629 continue;
630 }
631 let x = self.helix_close_scores[i][j];
632 squared_sum += x * x;
633 effective_group_size += 1;
634 }
635 }
636 let regularizer = get_regularizer(effective_group_size, squared_sum);
637 for i in 0..len {
638 for j in 0..len {
639 if !has_canonical_basepair(&(i, j)) {
640 continue;
641 }
642 regularizers2[offset] = regularizer;
643 offset += 1;
644 }
645 }
646 let len = self.basepair_scores.len();
647 let mut effective_group_size = 0;
648 let mut squared_sum = 0.;
649 for i in 0..len {
650 for j in 0..len {
651 if !has_canonical_basepair(&(i, j)) {
652 continue;
653 }
654 let dict_min_basepair = get_dict_min_pair(&(i, j));
655 if (i, j) != dict_min_basepair {
656 continue;
657 }
658 let x = self.basepair_scores[i][j];
659 squared_sum += x * x;
660 effective_group_size += 1;
661 }
662 }
663 let regularizer = get_regularizer(effective_group_size, squared_sum);
664 for i in 0..len {
665 for j in 0..len {
666 if !has_canonical_basepair(&(i, j)) {
667 continue;
668 }
669 let dict_min_basepair = get_dict_min_pair(&(i, j));
670 if (i, j) != dict_min_basepair {
671 continue;
672 }
673 regularizers2[offset] = regularizer;
674 offset += 1;
675 }
676 }
677 let len = self.interior_scores_explicit.len();
678 let mut effective_group_size = 0;
679 let mut squared_sum = 0.;
680 for i in 0..len {
681 for j in 0..len {
682 let dict_min_len_pair = get_dict_min_pair(&(i, j));
683 if (i, j) != dict_min_len_pair {
684 continue;
685 }
686 let x = self.interior_scores_explicit[i][j];
687 squared_sum += x * x;
688 effective_group_size += 1;
689 }
690 }
691 let regularizer = get_regularizer(effective_group_size, squared_sum);
692 for i in 0..len {
693 for j in 0..len {
694 let dict_min_len_pair = get_dict_min_pair(&(i, j));
695 if (i, j) != dict_min_len_pair {
696 continue;
697 }
698 regularizers2[offset] = regularizer;
699 offset += 1;
700 }
701 }
702 let len = self.bulge_scores_0x1.len();
703 let group_size = len;
704 let mut squared_sum = 0.;
705 for i in 0..len {
706 let x = self.bulge_scores_0x1[i];
707 squared_sum += x * x;
708 }
709 let regularizer = get_regularizer(group_size, squared_sum);
710 for i in 0..len {
711 regularizers2[offset + i] = regularizer;
712 }
713 offset += group_size;
714 let len = self.interior_scores_1x1.len();
715 let mut effective_group_size = 0;
716 let mut squared_sum = 0.;
717 for i in 0..len {
718 for j in 0..len {
719 let dict_min_basepair = get_dict_min_pair(&(i, j));
720 if (i, j) != dict_min_basepair {
721 continue;
722 }
723 let x = self.interior_scores_1x1[i][j];
724 squared_sum += x * x;
725 effective_group_size += 1;
726 }
727 }
728 let regularizer = get_regularizer(effective_group_size, squared_sum);
729 for i in 0..len {
730 for j in 0..len {
731 let dict_min_basepair = get_dict_min_pair(&(i, j));
732 if (i, j) != dict_min_basepair {
733 continue;
734 }
735 regularizers2[offset] = regularizer;
736 offset += 1;
737 }
738 }
739 let mut squared_sum = 0.;
740 squared_sum += self.multibranch_score_base * self.multibranch_score_base;
741 squared_sum += self.multibranch_score_basepair * self.multibranch_score_basepair;
742 squared_sum += self.multibranch_score_unpair * self.multibranch_score_unpair;
743 let regularizer = get_regularizer(GROUP_SIZE_MULTIBRANCH, squared_sum);
744 regularizers2[offset] = regularizer;
745 offset += 1;
746 regularizers2[offset] = regularizer;
747 offset += 1;
748 regularizers2[offset] = regularizer;
749 offset += 1;
750 let mut squared_sum = 0.;
751 squared_sum += self.external_score_basepair * self.external_score_basepair;
752 squared_sum += self.external_score_unpair * self.external_score_unpair;
753 let regularizer = get_regularizer(GROUP_SIZE_EXTERNAL, squared_sum);
754 regularizers2[offset] = regularizer;
755 offset += 1;
756 regularizers2[offset] = regularizer;
757 offset += 1;
758 let mut squared_sum = 0.;
759 squared_sum += self.match2match_score * self.match2match_score;
760 squared_sum += self.match2insert_score * self.match2insert_score;
761 squared_sum += self.init_match_score * self.init_match_score;
762 let regularizer = get_regularizer(GROUP_SIZE_MATCH_TRANSITION, squared_sum);
763 regularizers2[offset] = regularizer;
764 offset += 1;
765 regularizers2[offset] = regularizer;
766 offset += 1;
767 regularizers2[offset] = regularizer;
768 offset += 1;
769 let mut squared_sum = 0.;
770 squared_sum += self.insert_extend_score * self.insert_extend_score;
771 squared_sum += self.init_insert_score * self.init_insert_score;
772 let regularizer = get_regularizer(GROUP_SIZE_INSERT_TRANSITION, squared_sum);
773 regularizers2[offset] = regularizer;
774 offset += 1;
775 regularizers2[offset] = regularizer;
776 offset += 1;
777 let len = self.insert_scores.len();
778 let group_size = len;
779 let mut squared_sum = 0.;
780 for i in 0..len {
781 let x = self.insert_scores[i];
782 squared_sum += x * x;
783 }
784 let regularizer = get_regularizer(group_size, squared_sum);
785 for i in 0..len {
786 regularizers2[offset + i] = regularizer;
787 }
788 offset += group_size;
789 let len = self.match_scores.len();
790 let mut effective_group_size = 0;
791 let mut squared_sum = 0.;
792 for i in 0..len {
793 for j in 0..len {
794 let dict_min_match = get_dict_min_pair(&(i, j));
795 if (i, j) != dict_min_match {
796 continue;
797 }
798 let x = self.match_scores[i][j];
799 squared_sum += x * x;
800 effective_group_size += 1;
801 }
802 }
803 let regularizer = get_regularizer(effective_group_size, squared_sum);
804 for i in 0..len {
805 for j in 0..len {
806 let dict_min_match = get_dict_min_pair(&(i, j));
807 if (i, j) != dict_min_match {
808 continue;
809 }
810 regularizers2[offset] = regularizer;
811 offset += 1;
812 }
813 }
814 assert!(offset == self.len());
815 *regularizers = Array1::from(regularizers2);
816 }
817
818 pub fn update<T: HashIndex>(
819 &mut self,
820 train_data: &[TrainDatum<T>],
821 regularizers: &mut Regularizers,
822 ) {
823 let f = |_: &BfgsScores| self.get_cost(train_data, regularizers) as BfgsScore;
824 let g = |_: &BfgsScores| scores2bfgs_scores(&self.get_grad(train_data, regularizers));
825 let uses_cumulative_scores = false;
826 match bfgs(
827 scores2bfgs_scores(&struct2vec(self, uses_cumulative_scores)),
828 f,
829 g,
830 ) {
831 Ok(solution) => {
832 *self = vec2struct(&bfgs_scores2scores(&solution), uses_cumulative_scores);
833 }
834 Err(_) => {
835 println!("BFGS failed");
836 }
837 };
838 self.update_regularizers(regularizers);
839 self.mirror();
840 self.accumulate();
841 }
842
843 pub fn mirror(&mut self) {
844 for i in 0..NUM_BASES {
845 for j in 0..NUM_BASES {
846 if !has_canonical_basepair(&(i, j)) {
847 continue;
848 }
849 for k in 0..NUM_BASES {
850 for l in 0..NUM_BASES {
851 if !has_canonical_basepair(&(k, l)) {
852 continue;
853 }
854 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
855 if ((i, j), (k, l)) == dict_min_stack {
856 continue;
857 }
858 self.stack_scores[i][j][k][l] = self.stack_scores[dict_min_stack.0 .0]
859 [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1];
860 }
861 }
862 }
863 }
864 for i in 0..NUM_BASES {
865 for j in 0..NUM_BASES {
866 if !has_canonical_basepair(&(i, j)) {
867 continue;
868 }
869 let dict_min_basepair = get_dict_min_pair(&(i, j));
870 if (i, j) == dict_min_basepair {
871 continue;
872 }
873 self.basepair_scores[i][j] = self.basepair_scores[dict_min_basepair.0][dict_min_basepair.1];
874 }
875 }
876 let len = self.interior_scores_explicit.len();
877 for i in 0..len {
878 for j in 0..len {
879 let dict_min_len_pair = get_dict_min_pair(&(i, j));
880 if (i, j) == dict_min_len_pair {
881 continue;
882 }
883 self.interior_scores_explicit[i][j] =
884 self.interior_scores_explicit[dict_min_len_pair.0][dict_min_len_pair.1];
885 }
886 }
887 for i in 0..NUM_BASES {
888 for j in 0..NUM_BASES {
889 let dict_min_basepair = get_dict_min_pair(&(i, j));
890 if (i, j) == dict_min_basepair {
891 continue;
892 }
893 self.interior_scores_1x1[i][j] =
894 self.interior_scores_1x1[dict_min_basepair.0][dict_min_basepair.1];
895 }
896 }
897 for i in 0..NUM_BASES {
898 for j in 0..NUM_BASES {
899 let dict_min_match = get_dict_min_pair(&(i, j));
900 if (i, j) == dict_min_match {
901 continue;
902 }
903 self.match_scores[i][j] = self.match_scores[dict_min_match.0][dict_min_match.1];
904 }
905 }
906 }
907
908 pub fn accumulate(&mut self) {
909 let mut sum = 0.;
910 for i in 0..self.hairpin_scores_len_cumulative.len() {
911 sum += self.hairpin_scores_len[i];
912 self.hairpin_scores_len_cumulative[i] = sum;
913 }
914 let mut sum = 0.;
915 for i in 0..self.bulge_scores_len_cumulative.len() {
916 sum += self.bulge_scores_len[i];
917 self.bulge_scores_len_cumulative[i] = sum;
918 }
919 let mut sum = 0.;
920 for i in 0..self.interior_scores_len_cumulative.len() {
921 sum += self.interior_scores_len[i];
922 self.interior_scores_len_cumulative[i] = sum;
923 }
924 let mut sum = 0.;
925 for i in 0..self.interior_scores_symmetric_cumulative.len() {
926 sum += self.interior_scores_symmetric[i];
927 self.interior_scores_symmetric_cumulative[i] = sum;
928 }
929 let mut sum = 0.;
930 for i in 0..self.interior_scores_asymmetric_cumulative.len() {
931 sum += self.interior_scores_asymmetric[i];
932 self.interior_scores_asymmetric_cumulative[i] = sum;
933 }
934 }
935
936 pub fn get_grad<T: HashIndex>(
937 &self,
938 train_data: &[TrainDatum<T>],
939 regularizers: &Regularizers,
940 ) -> Scores {
941 let uses_cumulative_scores = false;
942 let alignfold_scores = struct2vec(self, uses_cumulative_scores);
943 let mut grad = AlignfoldScores::new(0.);
944 for train_datum in train_data {
945 let obs = &train_datum.alignfold_counts_observed;
946 let expect = &train_datum.alignfold_counts_expected;
947 let mut sum = 0.;
948 let len = obs.hairpin_scores_len.len();
949 for i in (0..len).rev() {
950 let x = obs.hairpin_scores_len[i];
951 let y = expect.hairpin_scores_len[i];
952 sum -= x - y;
953 grad.hairpin_scores_len[i] += sum;
954 }
955 let len = obs.bulge_scores_len.len();
956 let mut sum = 0.;
957 for i in (0..len).rev() {
958 let x = obs.bulge_scores_len[i];
959 let y = expect.bulge_scores_len[i];
960 sum -= x - y;
961 grad.bulge_scores_len[i] += sum;
962 }
963 let len = obs.interior_scores_len.len();
964 let mut sum = 0.;
965 for i in (0..len).rev() {
966 let x = obs.interior_scores_len[i];
967 let y = expect.interior_scores_len[i];
968 sum -= x - y;
969 grad.interior_scores_len[i] += sum;
970 }
971 let len = obs.interior_scores_symmetric.len();
972 let mut sum = 0.;
973 for i in (0..len).rev() {
974 let x = obs.interior_scores_symmetric[i];
975 let y = expect.interior_scores_symmetric[i];
976 sum -= x - y;
977 grad.interior_scores_symmetric[i] += sum;
978 }
979 let len = obs.interior_scores_asymmetric.len();
980 let mut sum = 0.;
981 for i in (0..len).rev() {
982 let x = obs.interior_scores_asymmetric[i];
983 let y = expect.interior_scores_asymmetric[i];
984 sum -= x - y;
985 grad.interior_scores_asymmetric[i] += sum;
986 }
987 for i in 0..NUM_BASES {
988 for j in 0..NUM_BASES {
989 if !has_canonical_basepair(&(i, j)) {
990 continue;
991 }
992 for k in 0..NUM_BASES {
993 for l in 0..NUM_BASES {
994 if !has_canonical_basepair(&(k, l)) {
995 continue;
996 }
997 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
998 if ((i, j), (k, l)) != dict_min_stack {
999 continue;
1000 }
1001 let x = obs.stack_scores[i][j][k][l];
1002 let y = expect.stack_scores[i][j][k][l];
1003 grad.stack_scores[i][j][k][l] -= x - y;
1004 }
1005 }
1006 }
1007 }
1008 for i in 0..NUM_BASES {
1009 for j in 0..NUM_BASES {
1010 if !has_canonical_basepair(&(i, j)) {
1011 continue;
1012 }
1013 for k in 0..NUM_BASES {
1014 for l in 0..NUM_BASES {
1015 let x = obs.terminal_mismatch_scores[i][j][k][l];
1016 let y = expect.terminal_mismatch_scores[i][j][k][l];
1017 grad.terminal_mismatch_scores[i][j][k][l] -= x - y;
1018 }
1019 }
1020 }
1021 }
1022 for i in 0..NUM_BASES {
1023 for j in 0..NUM_BASES {
1024 if !has_canonical_basepair(&(i, j)) {
1025 continue;
1026 }
1027 for k in 0..NUM_BASES {
1028 let x = obs.dangling_scores_left[i][j][k];
1029 let y = expect.dangling_scores_left[i][j][k];
1030 grad.dangling_scores_left[i][j][k] -= x - y;
1031 }
1032 }
1033 }
1034 for i in 0..NUM_BASES {
1035 for j in 0..NUM_BASES {
1036 if !has_canonical_basepair(&(i, j)) {
1037 continue;
1038 }
1039 for k in 0..NUM_BASES {
1040 let x = obs.dangling_scores_right[i][j][k];
1041 let y = expect.dangling_scores_right[i][j][k];
1042 grad.dangling_scores_right[i][j][k] -= x - y;
1043 }
1044 }
1045 }
1046 for i in 0..NUM_BASES {
1047 for j in 0..NUM_BASES {
1048 if !has_canonical_basepair(&(i, j)) {
1049 continue;
1050 }
1051 let x = obs.helix_close_scores[i][j];
1052 let y = expect.helix_close_scores[i][j];
1053 grad.helix_close_scores[i][j] -= x - y;
1054 }
1055 }
1056 for i in 0..NUM_BASES {
1057 for j in 0..NUM_BASES {
1058 if !has_canonical_basepair(&(i, j)) {
1059 continue;
1060 }
1061 let dict_min_basepair = get_dict_min_pair(&(i, j));
1062 if (i, j) != dict_min_basepair {
1063 continue;
1064 }
1065 let x = obs.basepair_scores[i][j];
1066 let y = expect.basepair_scores[i][j];
1067 grad.basepair_scores[i][j] -= x - y;
1068 }
1069 }
1070 let len = obs.interior_scores_explicit.len();
1071 for i in 0..len {
1072 for j in 0..len {
1073 let dict_min_len_pair = get_dict_min_pair(&(i, j));
1074 if (i, j) != dict_min_len_pair {
1075 continue;
1076 }
1077 let x = obs.interior_scores_explicit[i][j];
1078 let y = expect.interior_scores_explicit[i][j];
1079 grad.interior_scores_explicit[i][j] -= x - y;
1080 }
1081 }
1082 for i in 0..NUM_BASES {
1083 let x = obs.bulge_scores_0x1[i];
1084 let y = expect.bulge_scores_0x1[i];
1085 grad.bulge_scores_0x1[i] -= x - y;
1086 }
1087 for i in 0..NUM_BASES {
1088 for j in 0..NUM_BASES {
1089 let dict_min_basepair = get_dict_min_pair(&(i, j));
1090 if (i, j) != dict_min_basepair {
1091 continue;
1092 }
1093 let x = obs.interior_scores_1x1[i][j];
1094 let y = expect.interior_scores_1x1[i][j];
1095 grad.interior_scores_1x1[i][j] -= x - y;
1096 }
1097 }
1098 let obs_score = obs.multibranch_score_base;
1099 let expect_score = expect.multibranch_score_base;
1100 grad.multibranch_score_base -= obs_score - expect_score;
1101 let obs_score = obs.multibranch_score_basepair;
1102 let expect_score = expect.multibranch_score_basepair;
1103 grad.multibranch_score_basepair -= obs_score - expect_score;
1104 let obs_score = obs.multibranch_score_unpair;
1105 let expect_score = expect.multibranch_score_unpair;
1106 grad.multibranch_score_unpair -= obs_score - expect_score;
1107 let obs_score = obs.external_score_basepair;
1108 let expect_score = expect.external_score_basepair;
1109 grad.external_score_basepair -= obs_score - expect_score;
1110 let obs_score = obs.external_score_unpair;
1111 let expect_score = expect.external_score_unpair;
1112 grad.external_score_unpair -= obs_score - expect_score;
1113 let obs_score = obs.match2match_score;
1114 let expect_score = expect.match2match_score;
1115 grad.match2match_score -= obs_score - expect_score;
1116 let obs_score = obs.match2insert_score;
1117 let expect_score = expect.match2insert_score;
1118 grad.match2insert_score -= obs_score - expect_score;
1119 let obs_score = obs.insert_extend_score;
1120 let expect_score = expect.insert_extend_score;
1121 grad.insert_extend_score -= obs_score - expect_score;
1122 let obs_score = obs.init_match_score;
1123 let expect_score = expect.init_match_score;
1124 grad.init_match_score -= obs_score - expect_score;
1125 let obs_score = obs.init_insert_score;
1126 let expect_score = expect.init_insert_score;
1127 grad.init_insert_score -= obs_score - expect_score;
1128 for i in 0..NUM_BASES {
1129 let x = obs.insert_scores[i];
1130 let y = expect.insert_scores[i];
1131 grad.insert_scores[i] -= x - y;
1132 }
1133 for i in 0..NUM_BASES {
1134 for j in 0..NUM_BASES {
1135 let dict_min_match = get_dict_min_pair(&(i, j));
1136 if (i, j) != dict_min_match {
1137 continue;
1138 }
1139 let x = obs.match_scores[i][j];
1140 let y = expect.match_scores[i][j];
1141 grad.match_scores[i][j] -= x - y;
1142 }
1143 }
1144 }
1145 struct2vec(&grad, uses_cumulative_scores) + regularizers * &alignfold_scores
1146 }
1147
1148 pub fn get_cost<T: HashIndex>(
1149 &self,
1150 train_data: &[TrainDatum<T>],
1151 regularizers: &Regularizers,
1152 ) -> Score {
1153 let uses_cumulative_scores = true;
1154 let mut log_likelihood = 0.;
1155 let alignfold_scores_cumulative = struct2vec(self, uses_cumulative_scores);
1156 let uses_cumulative_scores = false;
1157 for train_datum in train_data {
1158 let obs = &train_datum.alignfold_counts_observed;
1159 log_likelihood += alignfold_scores_cumulative.dot(&struct2vec(obs, uses_cumulative_scores));
1160 log_likelihood -= train_datum.global_sum;
1161 }
1162 let alignfold_scores = struct2vec(self, uses_cumulative_scores);
1163 let product = regularizers * &alignfold_scores;
1164 -log_likelihood + product.dot(&alignfold_scores) / 2.
1165 }
1166
1167 pub fn rand_init(&mut self) {
1168 let len = self.len();
1169 let std_deviation = 1. / (len as Score).sqrt();
1170 let normal = Normal::new(0., std_deviation).unwrap();
1171 let mut thread_rng = thread_rng();
1172 for x in self.hairpin_scores_len.iter_mut() {
1173 *x = normal.sample(&mut thread_rng);
1174 }
1175 for x in self.bulge_scores_len.iter_mut() {
1176 *x = normal.sample(&mut thread_rng);
1177 }
1178 for x in self.interior_scores_len.iter_mut() {
1179 *x = normal.sample(&mut thread_rng);
1180 }
1181 for x in self.interior_scores_symmetric.iter_mut() {
1182 *x = normal.sample(&mut thread_rng);
1183 }
1184 for x in self.interior_scores_asymmetric.iter_mut() {
1185 *x = normal.sample(&mut thread_rng);
1186 }
1187 let len = self.stack_scores.len();
1188 for i in 0..len {
1189 for j in 0..len {
1190 if !has_canonical_basepair(&(i, j)) {
1191 continue;
1192 }
1193 for k in 0..len {
1194 for l in 0..len {
1195 if !has_canonical_basepair(&(k, l)) {
1196 continue;
1197 }
1198 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
1199 if ((i, j), (k, l)) != dict_min_stack {
1200 continue;
1201 }
1202 let x = normal.sample(&mut thread_rng);
1203 self.stack_scores[i][j][k][l] = x;
1204 }
1205 }
1206 }
1207 }
1208 let len = self.terminal_mismatch_scores.len();
1209 for i in 0..len {
1210 for j in 0..len {
1211 if !has_canonical_basepair(&(i, j)) {
1212 continue;
1213 }
1214 for k in 0..len {
1215 for l in 0..len {
1216 let x = normal.sample(&mut thread_rng);
1217 self.terminal_mismatch_scores[i][j][k][l] = x;
1218 }
1219 }
1220 }
1221 }
1222 let len = self.dangling_scores_left.len();
1223 for i in 0..len {
1224 for j in 0..len {
1225 if !has_canonical_basepair(&(i, j)) {
1226 continue;
1227 }
1228 for k in 0..len {
1229 let x = normal.sample(&mut thread_rng);
1230 self.dangling_scores_left[i][j][k] = x;
1231 }
1232 }
1233 }
1234 let len = self.dangling_scores_right.len();
1235 for i in 0..len {
1236 for j in 0..len {
1237 if !has_canonical_basepair(&(i, j)) {
1238 continue;
1239 }
1240 for k in 0..len {
1241 let x = normal.sample(&mut thread_rng);
1242 self.dangling_scores_right[i][j][k] = x;
1243 }
1244 }
1245 }
1246 let len = self.helix_close_scores.len();
1247 for i in 0..len {
1248 for j in 0..len {
1249 if !has_canonical_basepair(&(i, j)) {
1250 continue;
1251 }
1252 let x = normal.sample(&mut thread_rng);
1253 self.helix_close_scores[i][j] = x;
1254 }
1255 }
1256 let len = self.basepair_scores.len();
1257 for i in 0..len {
1258 for j in 0..len {
1259 if !has_canonical_basepair(&(i, j)) {
1260 continue;
1261 }
1262 let dict_min_basepair = get_dict_min_pair(&(i, j));
1263 if (i, j) != dict_min_basepair {
1264 continue;
1265 }
1266 let x = normal.sample(&mut thread_rng);
1267 self.basepair_scores[i][j] = x;
1268 }
1269 }
1270 let len = self.interior_scores_explicit.len();
1271 for i in 0..len {
1272 for j in 0..len {
1273 let dict_min_len_pair = get_dict_min_pair(&(i, j));
1274 if (i, j) != dict_min_len_pair {
1275 continue;
1276 }
1277 let x = normal.sample(&mut thread_rng);
1278 self.interior_scores_explicit[i][j] = x;
1279 }
1280 }
1281 for x in &mut self.bulge_scores_0x1 {
1282 *x = normal.sample(&mut thread_rng);
1283 }
1284 let len = self.interior_scores_1x1.len();
1285 for i in 0..len {
1286 for j in 0..len {
1287 let dict_min_basepair = get_dict_min_pair(&(i, j));
1288 if (i, j) != dict_min_basepair {
1289 continue;
1290 }
1291 let x = normal.sample(&mut thread_rng);
1292 self.interior_scores_1x1[i][j] = x;
1293 }
1294 }
1295 self.multibranch_score_base = normal.sample(&mut thread_rng);
1296 self.multibranch_score_basepair = normal.sample(&mut thread_rng);
1297 self.multibranch_score_unpair = normal.sample(&mut thread_rng);
1298 self.external_score_basepair = normal.sample(&mut thread_rng);
1299 self.external_score_unpair = normal.sample(&mut thread_rng);
1300 self.match2match_score = normal.sample(&mut thread_rng);
1301 self.match2insert_score = normal.sample(&mut thread_rng);
1302 self.insert_extend_score = normal.sample(&mut thread_rng);
1303 self.init_match_score = normal.sample(&mut thread_rng);
1304 self.init_insert_score = normal.sample(&mut thread_rng);
1305 let len = self.insert_scores.len();
1306 for i in 0..len {
1307 let x = normal.sample(&mut thread_rng);
1308 self.insert_scores[i] = x;
1309 }
1310 let len = self.match_scores.len();
1311 for i in 0..len {
1312 for j in 0..len {
1313 let dict_min_match = get_dict_min_pair(&(i, j));
1314 if (i, j) != dict_min_match {
1315 continue;
1316 }
1317 let x = normal.sample(&mut thread_rng);
1318 self.match_scores[i][j] = x;
1319 }
1320 }
1321 self.mirror();
1322 self.accumulate();
1323 }
1324
1325 pub fn transfer(&mut self) {
1326 for (x, &y) in self
1327 .hairpin_scores_len
1328 .iter_mut()
1329 .zip(HAIRPIN_SCORES_LEN_ATLEAST.iter())
1330 {
1331 *x = y;
1332 }
1333 for (x, &y) in self
1334 .bulge_scores_len
1335 .iter_mut()
1336 .zip(BULGE_SCORES_LEN_ATLEAST.iter())
1337 {
1338 *x = y;
1339 }
1340 for (x, &y) in self
1341 .interior_scores_len
1342 .iter_mut()
1343 .zip(INTERIOR_SCORES_LEN_ATLEAST.iter())
1344 {
1345 *x = y;
1346 }
1347 for (x, &y) in self
1348 .interior_scores_symmetric
1349 .iter_mut()
1350 .zip(INTERIOR_SCORES_SYMMETRIC_ATLEAST.iter())
1351 {
1352 *x = y;
1353 }
1354 for (x, &y) in self
1355 .interior_scores_asymmetric
1356 .iter_mut()
1357 .zip(INTERIOR_SCORES_ASYMMETRIC_ATLEAST.iter())
1358 {
1359 *x = y;
1360 }
1361 for (i, x) in STACK_SCORES_CONTRA.iter().enumerate() {
1362 for (j, x) in x.iter().enumerate() {
1363 if !has_canonical_basepair(&(i, j)) {
1364 continue;
1365 }
1366 for (k, x) in x.iter().enumerate() {
1367 for (l, &x) in x.iter().enumerate() {
1368 if !has_canonical_basepair(&(k, l)) {
1369 continue;
1370 }
1371 self.stack_scores[i][j][k][l] = x;
1372 }
1373 }
1374 }
1375 }
1376 for (i, x) in TERMINAL_MISMATCH_SCORES_CONTRA.iter().enumerate() {
1377 for (j, x) in x.iter().enumerate() {
1378 if !has_canonical_basepair(&(i, j)) {
1379 continue;
1380 }
1381 for (k, x) in x.iter().enumerate() {
1382 for (l, &x) in x.iter().enumerate() {
1383 self.terminal_mismatch_scores[i][j][k][l] = x;
1384 }
1385 }
1386 }
1387 }
1388 for (i, x) in DANGLING_SCORES_LEFT.iter().enumerate() {
1389 for (j, x) in x.iter().enumerate() {
1390 if !has_canonical_basepair(&(i, j)) {
1391 continue;
1392 }
1393 for (k, &x) in x.iter().enumerate() {
1394 self.dangling_scores_left[i][j][k] = x;
1395 }
1396 }
1397 }
1398 for (i, x) in DANGLING_SCORES_RIGHT.iter().enumerate() {
1399 for (j, x) in x.iter().enumerate() {
1400 if !has_canonical_basepair(&(i, j)) {
1401 continue;
1402 }
1403 for (k, &x) in x.iter().enumerate() {
1404 self.dangling_scores_right[i][j][k] = x;
1405 }
1406 }
1407 }
1408 for (i, x) in HELIX_CLOSE_SCORES.iter().enumerate() {
1409 for (j, &x) in x.iter().enumerate() {
1410 if !has_canonical_basepair(&(i, j)) {
1411 continue;
1412 }
1413 self.helix_close_scores[i][j] = x;
1414 }
1415 }
1416 for (i, x) in BASEPAIR_SCORES.iter().enumerate() {
1417 for (j, &x) in x.iter().enumerate() {
1418 if !has_canonical_basepair(&(i, j)) {
1419 continue;
1420 }
1421 self.basepair_scores[i][j] = x;
1422 }
1423 }
1424 for (i, x) in INTERIOR_SCORES_EXPLICIT.iter().enumerate() {
1425 for (j, &x) in x.iter().enumerate() {
1426 self.interior_scores_explicit[i][j] = x;
1427 }
1428 }
1429 for (x, &y) in self
1430 .bulge_scores_0x1
1431 .iter_mut()
1432 .zip(BULGE_SCORES_0X1.iter())
1433 {
1434 *x = y;
1435 }
1436 for (i, x) in INTERIOR_SCORES_1X1_CONTRA.iter().enumerate() {
1437 for (j, &x) in x.iter().enumerate() {
1438 self.interior_scores_1x1[i][j] = x;
1439 }
1440 }
1441 self.multibranch_score_base = MULTIBRANCH_SCORE_BASE;
1442 self.multibranch_score_basepair = MULTIBRANCH_SCORE_BASEPAIR;
1443 self.multibranch_score_unpair = MULTIBRANCH_SCORE_UNPAIR;
1444 self.external_score_basepair = EXTERNAL_SCORE_BASEPAIR;
1445 self.external_score_unpair = EXTERNAL_SCORE_UNPAIR;
1446 self.match2match_score = MATCH2MATCH_SCORE;
1447 self.match2insert_score = MATCH2INSERT_SCORE;
1448 self.insert_extend_score = INSERT_EXTEND_SCORE;
1449 self.init_match_score = INIT_MATCH_SCORE;
1450 self.init_insert_score = INIT_INSERT_SCORE;
1451 for (i, &x) in INSERT_SCORES.iter().enumerate() {
1452 self.insert_scores[i] = x;
1453 }
1454 for (i, x) in MATCH_SCORES.iter().enumerate() {
1455 for (j, &x) in x.iter().enumerate() {
1456 self.match_scores[i][j] = x;
1457 }
1458 }
1459 self.accumulate();
1460 }
1461}
1462
1463impl<T: HashIndex> TrainDatum<T> {
1464 pub fn origin() -> TrainDatum<T> {
1465 TrainDatum {
1466 seq_pair: (Seq::new(), Seq::new()),
1467 seq_pair_gapped: (Seq::new(), Seq::new()),
1468 alignfold_counts_observed: AlignfoldScores::new(0.),
1469 alignfold_counts_expected: AlignfoldScores::new(NEG_INFINITY),
1470 basepair_probs_pair: (SparseProbMat::<T>::default(), SparseProbMat::<T>::default()),
1471 max_basepair_span_pair: (T::zero(), T::zero()),
1472 global_sum: NEG_INFINITY,
1473 forward_pos_pairs: PosPairMatSet::<T>::default(),
1474 backward_pos_pairs: PosPairMatSet::<T>::default(),
1475 pos_quads_hashed_lens: PosQuadsHashedLens::<T>::default(),
1476 matchable_poss: SparsePosSets::<T>::default(),
1477 matchable_poss2: SparsePosSets::<T>::default(),
1478 fold_scores_pair: (FoldScoresTrained::<T>::new(), FoldScoresTrained::<T>::new()),
1479 match_probs: SparseProbMat::<T>::default(),
1480 alignfold: PairAlignfold::<T>::new(),
1481 accuracy: NEG_INFINITY,
1482 }
1483 }
1484
1485 pub fn new(
1486 input_file_path: &Path,
1487 min_basepair_prob: Prob,
1488 min_match_prob: Prob,
1489 align_scores: &AlignScores,
1490 ) -> TrainDatum<T> {
1491 let fasta_file_reader = Reader::from_file(Path::new(input_file_path)).unwrap();
1492 let fasta_records: Vec<Record> = fasta_file_reader
1493 .records()
1494 .map(|rec| rec.unwrap())
1495 .collect();
1496 let consensus_fold = fasta_records[2].seq();
1497 let seq_pair = (
1498 bytes2seq_gapped(fasta_records[0].seq()),
1499 bytes2seq_gapped(fasta_records[1].seq()),
1500 );
1501 let mut seq_pair_ungapped = (gapped2ungapped(&seq_pair.0), gapped2ungapped(&seq_pair.1));
1502 let uses_contra_model = false;
1503 let allows_short_hairpins = true;
1504 let basepair_probs_pair = (
1505 filter_basepair_probs::<T>(
1506 &mccaskill_algo(
1507 &seq_pair_ungapped.0[..],
1508 uses_contra_model,
1509 allows_short_hairpins,
1510 &FoldScoreSets::new(0.),
1511 )
1512 .0,
1513 min_basepair_prob,
1514 ),
1515 filter_basepair_probs::<T>(
1516 &mccaskill_algo(
1517 &seq_pair_ungapped.1[..],
1518 uses_contra_model,
1519 allows_short_hairpins,
1520 &FoldScoreSets::new(0.),
1521 )
1522 .0,
1523 min_basepair_prob,
1524 ),
1525 );
1526 seq_pair_ungapped.0.insert(0, PSEUDO_BASE);
1527 seq_pair_ungapped.0.push(PSEUDO_BASE);
1528 seq_pair_ungapped.1.insert(0, PSEUDO_BASE);
1529 seq_pair_ungapped.1.push(PSEUDO_BASE);
1530 let seq_len_pair = (
1531 T::from_usize(seq_pair_ungapped.0.len()).unwrap(),
1532 T::from_usize(seq_pair_ungapped.1.len()).unwrap(),
1533 );
1534 let match_probs = filter_match_probs(
1535 &durbin_algo(
1536 &(&seq_pair_ungapped.0[..], &seq_pair_ungapped.1[..]),
1537 align_scores,
1538 ),
1539 min_match_prob,
1540 );
1541 let (
1542 forward_pos_pairs,
1543 backward_pos_pairs,
1544 _,
1545 pos_quads_hashed_lens,
1546 matchable_poss,
1547 matchable_poss2,
1548 ) = get_sparse_poss(
1549 &(&basepair_probs_pair.0, &basepair_probs_pair.1),
1550 &match_probs,
1551 &seq_len_pair,
1552 );
1553 let max_basepair_span_pair = (
1554 get_max_basepair_span::<T>(&basepair_probs_pair.0),
1555 get_max_basepair_span::<T>(&basepair_probs_pair.1),
1556 );
1557 let mut train_datum = TrainDatum {
1558 seq_pair: seq_pair_ungapped,
1559 seq_pair_gapped: seq_pair,
1560 alignfold_counts_observed: AlignfoldScores::new(0.),
1561 alignfold_counts_expected: AlignfoldScores::new(NEG_INFINITY),
1562 basepair_probs_pair,
1563 max_basepair_span_pair,
1564 global_sum: NEG_INFINITY,
1565 forward_pos_pairs,
1566 backward_pos_pairs,
1567 pos_quads_hashed_lens,
1568 matchable_poss,
1569 matchable_poss2,
1570 fold_scores_pair: (FoldScoresTrained::<T>::new(), FoldScoresTrained::<T>::new()),
1571 match_probs,
1572 alignfold: PairAlignfold::<T>::new(),
1573 accuracy: NEG_INFINITY,
1574 };
1575 train_datum.obs2counts(consensus_fold);
1576 train_datum
1577 }
1578
1579 pub fn obs2counts(&mut self, consensus_fold: TextSlice) {
1580 let align_len = consensus_fold.len();
1581 let mut inserted = false;
1582 let mut inserted2 = inserted;
1583 let seq_pair = &self.seq_pair_gapped;
1584 let mut pos_pair = (T::one(), T::one());
1585 for (i, ¬ation) in consensus_fold.iter().enumerate() {
1586 let basepair = (seq_pair.0[i], seq_pair.1[i]);
1587 if notation != UNPAIR {
1588 let dict_min_match = get_dict_min_pair(&basepair);
1589 self.alignfold_counts_observed.match_scores[dict_min_match.0][dict_min_match.1] += 1.;
1590 if i == 0 {
1591 self.alignfold_counts_observed.init_match_score += 1.;
1592 } else if inserted || inserted2 {
1593 self.alignfold_counts_observed.match2insert_score += 1.;
1594 } else {
1595 self.alignfold_counts_observed.match2match_score += 1.;
1596 }
1597 inserted = false;
1598 inserted2 = inserted;
1599 if basepair.1 == PSEUDO_BASE {
1600 self.alignfold.inserted_poss.insert(pos_pair.0);
1601 pos_pair.0 = pos_pair.0 + T::one();
1602 } else if basepair.0 == PSEUDO_BASE {
1603 self.alignfold.deleted_poss.insert(pos_pair.1);
1604 pos_pair.1 = pos_pair.1 + T::one();
1605 } else {
1606 self.alignfold.matched_pos_pairs.insert(pos_pair);
1607 pos_pair.0 = pos_pair.0 + T::one();
1608 pos_pair.1 = pos_pair.1 + T::one();
1609 }
1610 continue;
1611 }
1612 if basepair.1 == PSEUDO_BASE {
1613 if i == 0 {
1614 self.alignfold_counts_observed.init_insert_score += 1.;
1615 } else if inserted {
1616 self.alignfold_counts_observed.insert_extend_score += 1.;
1617 } else if inserted2 {
1618 inserted = true;
1619 inserted2 = false;
1620 } else {
1621 self.alignfold_counts_observed.match2insert_score += 1.;
1622 inserted = true;
1623 }
1624 self.alignfold_counts_observed.insert_scores[basepair.0] += 1.;
1625 self.alignfold.inserted_poss.insert(pos_pair.0);
1626 pos_pair.0 = pos_pair.0 + T::one();
1627 } else if basepair.0 == PSEUDO_BASE {
1628 if i == 0 {
1629 self.alignfold_counts_observed.init_insert_score += 1.;
1630 inserted2 = true;
1631 } else if inserted2 {
1632 self.alignfold_counts_observed.insert_extend_score += 1.;
1633 } else if inserted {
1634 inserted2 = true;
1635 inserted = false;
1636 } else {
1637 self.alignfold_counts_observed.match2insert_score += 1.;
1638 inserted2 = true;
1639 }
1640 self.alignfold_counts_observed.insert_scores[basepair.1] += 1.;
1641 self.alignfold.deleted_poss.insert(pos_pair.1);
1642 pos_pair.1 = pos_pair.1 + T::one();
1643 } else {
1644 let dict_min_match = get_dict_min_pair(&basepair);
1645 self.alignfold_counts_observed.match_scores[dict_min_match.0][dict_min_match.1] += 1.;
1646 if i == 0 {
1647 self.alignfold_counts_observed.init_match_score += 1.;
1648 } else if inserted || inserted2 {
1649 self.alignfold_counts_observed.match2insert_score += 1.;
1650 } else {
1651 self.alignfold_counts_observed.match2match_score += 1.;
1652 }
1653 inserted = false;
1654 inserted2 = inserted;
1655 self.alignfold.matched_pos_pairs.insert(pos_pair);
1656 pos_pair.0 = pos_pair.0 + T::one();
1657 pos_pair.1 = pos_pair.1 + T::one();
1658 }
1659 }
1660 let mut stack = Vec::new();
1661 let mut consensus_basepairs = HashSet::<(usize, usize)>::default();
1662 for (i, &x) in consensus_fold.iter().enumerate() {
1663 if x == BASEPAIR_LEFT {
1664 stack.push(i);
1665 } else if x == BASEPAIR_RIGHT {
1666 let x = stack.pop().unwrap();
1667 consensus_basepairs.insert((x, i));
1668 let y = (seq_pair.0[x], seq_pair.0[i]);
1669 if has_canonical_basepair(&y) {
1670 let y = get_dict_min_pair(&y);
1671 self.alignfold_counts_observed.basepair_scores[y.0][y.1] += 1.;
1672 }
1673 let y = (seq_pair.1[x], seq_pair.1[i]);
1674 if has_canonical_basepair(&y) {
1675 let y = get_dict_min_pair(&y);
1676 self.alignfold_counts_observed.basepair_scores[y.0][y.1] += 1.;
1677 }
1678 }
1679 }
1680 let mut loop_struct = LoopStruct::default();
1681 let mut stored_basepairs = HashSet::<(usize, usize)>::default();
1682 for substr_len in 2..align_len + 1 {
1683 for i in 0..align_len - substr_len + 1 {
1684 let mut found_basepair = false;
1685 let j = i + substr_len - 1;
1686 if consensus_basepairs.contains(&(i, j)) {
1687 found_basepair = true;
1688 consensus_basepairs.remove(&(i, j));
1689 stored_basepairs.insert((i, j));
1690 }
1691 if found_basepair {
1692 let mut loop_basepairs = Vec::new();
1693 for x in stored_basepairs.iter() {
1694 if i < x.0 && x.1 < j {
1695 loop_basepairs.push(*x);
1696 }
1697 }
1698 for x in loop_basepairs.iter() {
1699 stored_basepairs.remove(x);
1700 }
1701 loop_basepairs.sort();
1702 loop_struct.insert((i, j), loop_basepairs);
1703 }
1704 }
1705 }
1706 for (basepair_close, basepairs_loop) in loop_struct.iter() {
1707 let num_basepairs_loop = basepairs_loop.len();
1708 let basepair = (seq_pair.0[basepair_close.0], seq_pair.0[basepair_close.1]);
1709 let basepair2 = (seq_pair.1[basepair_close.0], seq_pair.1[basepair_close.1]);
1710 let closes = true;
1711 let mismatch_pair = get_mismatch_pair(&seq_pair.0[..], basepair_close, closes);
1712 let mismatch_pair2 = get_mismatch_pair(&seq_pair.1[..], basepair_close, closes);
1713 if num_basepairs_loop == 0 {
1714 let hairpin_len_pair = (
1715 get_hairpin_len(&seq_pair.0[..], basepair_close),
1716 get_hairpin_len(&seq_pair.1[..], basepair_close),
1717 );
1718 if has_canonical_basepair(&basepair) {
1719 self.alignfold_counts_observed.terminal_mismatch_scores[basepair.0][basepair.1]
1720 [mismatch_pair.0][mismatch_pair.1] += 1.;
1721 self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1722 }
1723 if has_canonical_basepair(&basepair2) {
1724 self.alignfold_counts_observed.terminal_mismatch_scores[basepair2.0][basepair2.1]
1725 [mismatch_pair2.0][mismatch_pair2.1] += 1.;
1726 self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1727 }
1728 if hairpin_len_pair.0 <= MAX_LOOP_LEN {
1729 self.alignfold_counts_observed.hairpin_scores_len[hairpin_len_pair.0] += 1.;
1730 } else {
1731 self.alignfold_counts_observed.hairpin_scores_len[MAX_LOOP_LEN] += 1.;
1732 }
1733 if hairpin_len_pair.1 <= MAX_LOOP_LEN {
1734 self.alignfold_counts_observed.hairpin_scores_len[hairpin_len_pair.1] += 1.;
1735 } else {
1736 self.alignfold_counts_observed.hairpin_scores_len[MAX_LOOP_LEN] += 1.;
1737 }
1738 } else if num_basepairs_loop == 1 {
1739 let basepair_loop = &basepairs_loop[0];
1740 let basepair3 = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1741 let basepair4 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1742 let twoloop_len_pair = get_2loop_len_pair(&seq_pair.0[..], basepair_close, basepair_loop);
1743 let sum = twoloop_len_pair.0 + twoloop_len_pair.1;
1744 let twoloop_len_pair2 = get_2loop_len_pair(&seq_pair.1[..], basepair_close, basepair_loop);
1745 let sum2 = twoloop_len_pair2.0 + twoloop_len_pair2.1;
1746 let closes = false;
1747 let mismatch_pair3 = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1748 let mismatch_pair4 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1749 if sum == 0 {
1750 if has_canonical_basepair(&basepair) && has_canonical_basepair(&basepair3) {
1751 let dict_min_stack = get_dict_min_stack(&basepair, &basepair3);
1752 self.alignfold_counts_observed.stack_scores[dict_min_stack.0 .0]
1753 [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1] += 1.;
1754 }
1755 } else {
1756 if twoloop_len_pair.0 == 0 || twoloop_len_pair.1 == 0 {
1757 if sum <= MAX_LOOP_LEN {
1758 self.alignfold_counts_observed.bulge_scores_len[sum - 1] += 1.;
1759 if sum == 1 {
1760 let mismatch = if twoloop_len_pair.0 == 0 {
1761 mismatch_pair.1
1762 } else {
1763 mismatch_pair.0
1764 };
1765 self.alignfold_counts_observed.bulge_scores_0x1[mismatch] += 1.;
1766 }
1767 } else {
1768 self.alignfold_counts_observed.bulge_scores_len[MAX_LOOP_LEN - 1] += 1.;
1769 }
1770 } else {
1771 let diff = get_diff(twoloop_len_pair.0, twoloop_len_pair.1);
1772 if sum <= MAX_LOOP_LEN {
1773 self.alignfold_counts_observed.interior_scores_len[sum - 2] += 1.;
1774 if diff == 0 {
1775 self.alignfold_counts_observed.interior_scores_symmetric[twoloop_len_pair.0 - 1] +=
1776 1.;
1777 } else {
1778 self.alignfold_counts_observed.interior_scores_asymmetric[diff - 1] += 1.;
1779 }
1780 if twoloop_len_pair.0 == 1 && twoloop_len_pair.1 == 1 {
1781 let dict_min_mismatch_pair = get_dict_min_pair(&mismatch_pair);
1782 self.alignfold_counts_observed.interior_scores_1x1[dict_min_mismatch_pair.0]
1783 [dict_min_mismatch_pair.1] += 1.;
1784 }
1785 if twoloop_len_pair.0 <= MAX_INTERIOR_EXPLICIT
1786 && twoloop_len_pair.1 <= MAX_INTERIOR_EXPLICIT
1787 {
1788 let dict_min_len_pair = get_dict_min_pair(&twoloop_len_pair);
1789 self.alignfold_counts_observed.interior_scores_explicit[dict_min_len_pair.0 - 1]
1790 [dict_min_len_pair.1 - 1] += 1.;
1791 }
1792 } else {
1793 self.alignfold_counts_observed.interior_scores_len[MAX_LOOP_LEN - 2] += 1.;
1794 if diff == 0 {
1795 if twoloop_len_pair.0 <= MAX_INTERIOR_SYMMETRIC {
1796 self.alignfold_counts_observed.interior_scores_symmetric
1797 [twoloop_len_pair.0 - 1] += 1.;
1798 } else {
1799 self.alignfold_counts_observed.interior_scores_symmetric
1800 [MAX_INTERIOR_SYMMETRIC - 1] += 1.;
1801 }
1802 } else if diff <= MAX_INTERIOR_ASYMMETRIC {
1803 self.alignfold_counts_observed.interior_scores_asymmetric[diff - 1] += 1.;
1804 } else {
1805 self.alignfold_counts_observed.interior_scores_asymmetric
1806 [MAX_INTERIOR_ASYMMETRIC - 1] += 1.;
1807 }
1808 }
1809 }
1810 if has_canonical_basepair(&basepair) {
1811 self.alignfold_counts_observed.terminal_mismatch_scores[basepair.0][basepair.1]
1812 [mismatch_pair.0][mismatch_pair.1] += 1.;
1813 self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1814 }
1815 if has_canonical_basepair(&basepair3) {
1816 self.alignfold_counts_observed.terminal_mismatch_scores[basepair3.1][basepair3.0]
1817 [mismatch_pair3.1][mismatch_pair3.0] += 1.;
1818 self.alignfold_counts_observed.helix_close_scores[basepair3.1][basepair3.0] += 1.;
1819 }
1820 }
1821 if sum2 == 0 {
1822 if has_canonical_basepair(&basepair2) && has_canonical_basepair(&basepair4) {
1823 let dict_min_stack2 = get_dict_min_stack(&basepair2, &basepair4);
1824 self.alignfold_counts_observed.stack_scores[dict_min_stack2.0 .0]
1825 [dict_min_stack2.0 .1][dict_min_stack2.1 .0][dict_min_stack2.1 .1] += 1.;
1826 }
1827 } else {
1828 if twoloop_len_pair2.0 == 0 || twoloop_len_pair2.1 == 0 {
1829 if sum2 <= MAX_LOOP_LEN {
1830 self.alignfold_counts_observed.bulge_scores_len[sum2 - 1] += 1.;
1831 if sum2 == 1 {
1832 let mismatch2 = if twoloop_len_pair2.0 == 0 {
1833 mismatch_pair2.1
1834 } else {
1835 mismatch_pair2.0
1836 };
1837 self.alignfold_counts_observed.bulge_scores_0x1[mismatch2] += 1.;
1838 }
1839 } else {
1840 self.alignfold_counts_observed.bulge_scores_len[MAX_LOOP_LEN - 1] += 1.;
1841 }
1842 } else {
1843 let diff2 = get_diff(twoloop_len_pair2.0, twoloop_len_pair2.1);
1844 if sum2 <= MAX_LOOP_LEN {
1845 self.alignfold_counts_observed.interior_scores_len[sum2 - 2] += 1.;
1846 if diff2 == 0 {
1847 self.alignfold_counts_observed.interior_scores_symmetric
1848 [twoloop_len_pair2.0 - 1] += 1.;
1849 } else {
1850 self.alignfold_counts_observed.interior_scores_asymmetric[diff2 - 1] += 1.;
1851 }
1852 if twoloop_len_pair2.0 == 1 && twoloop_len_pair2.1 == 1 {
1853 let dict_min_mismatch_pair2 = get_dict_min_pair(&mismatch_pair2);
1854 self.alignfold_counts_observed.interior_scores_1x1[dict_min_mismatch_pair2.0]
1855 [dict_min_mismatch_pair2.1] += 1.;
1856 }
1857 if twoloop_len_pair2.0 <= MAX_INTERIOR_EXPLICIT
1858 && twoloop_len_pair2.1 <= MAX_INTERIOR_EXPLICIT
1859 {
1860 let dict_min_len_pair2 = get_dict_min_pair(&twoloop_len_pair2);
1861 self.alignfold_counts_observed.interior_scores_explicit
1862 [dict_min_len_pair2.0 - 1][dict_min_len_pair2.1 - 1] += 1.;
1863 }
1864 } else {
1865 self.alignfold_counts_observed.interior_scores_len[MAX_LOOP_LEN - 2] += 1.;
1866 if diff2 == 0 {
1867 if twoloop_len_pair2.0 <= MAX_INTERIOR_SYMMETRIC {
1868 self.alignfold_counts_observed.interior_scores_symmetric
1869 [twoloop_len_pair2.0 - 1] += 1.;
1870 } else {
1871 self.alignfold_counts_observed.interior_scores_symmetric
1872 [MAX_INTERIOR_SYMMETRIC - 1] += 1.;
1873 }
1874 } else if diff2 <= MAX_INTERIOR_ASYMMETRIC {
1875 self.alignfold_counts_observed.interior_scores_asymmetric[diff2 - 1] += 1.;
1876 } else {
1877 self.alignfold_counts_observed.interior_scores_asymmetric
1878 [MAX_INTERIOR_ASYMMETRIC - 1] += 1.;
1879 }
1880 }
1881 }
1882 if has_canonical_basepair(&basepair2) {
1883 self.alignfold_counts_observed.terminal_mismatch_scores[basepair2.0][basepair2.1]
1884 [mismatch_pair2.0][mismatch_pair2.1] += 1.;
1885 self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1886 }
1887 if has_canonical_basepair(&basepair4) {
1888 self.alignfold_counts_observed.terminal_mismatch_scores[basepair4.1][basepair4.0]
1889 [mismatch_pair4.1][mismatch_pair4.0] += 1.;
1890 self.alignfold_counts_observed.helix_close_scores[basepair4.1][basepair4.0] += 1.;
1891 }
1892 }
1893 } else {
1894 if has_canonical_basepair(&basepair) {
1895 self.alignfold_counts_observed.dangling_scores_left[basepair.0][basepair.1]
1896 [mismatch_pair.0] += 1.;
1897 self.alignfold_counts_observed.dangling_scores_right[basepair.0][basepair.1]
1898 [mismatch_pair.1] += 1.;
1899 self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1900 }
1901 if has_canonical_basepair(&basepair2) {
1902 self.alignfold_counts_observed.dangling_scores_left[basepair2.0][basepair2.1]
1903 [mismatch_pair2.0] += 1.;
1904 self.alignfold_counts_observed.dangling_scores_right[basepair2.0][basepair2.1]
1905 [mismatch_pair2.1] += 1.;
1906 self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1907 }
1908 self.alignfold_counts_observed.multibranch_score_base += 2.;
1909 self.alignfold_counts_observed.multibranch_score_basepair += 2.;
1910 self.alignfold_counts_observed.multibranch_score_basepair +=
1911 2. * num_basepairs_loop as Prob;
1912 let num_unpairs_multibranch =
1913 get_num_unpairs_multibranch(basepair_close, basepairs_loop, &seq_pair.0[..]);
1914 self.alignfold_counts_observed.multibranch_score_unpair += num_unpairs_multibranch as Prob;
1915 let num_unpairs_multibranch2 =
1916 get_num_unpairs_multibranch(basepair_close, basepairs_loop, &seq_pair.1[..]);
1917 self.alignfold_counts_observed.multibranch_score_unpair += num_unpairs_multibranch2 as Prob;
1918 for basepair_loop in basepairs_loop.iter() {
1919 let basepair3 = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1920 let closes = false;
1921 let mismatch_pair3 = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1922 let basepair4 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1923 let mismatch_pair4 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1924 if has_canonical_basepair(&basepair3) {
1925 self.alignfold_counts_observed.dangling_scores_left[basepair3.1][basepair3.0]
1926 [mismatch_pair3.1] += 1.;
1927 self.alignfold_counts_observed.dangling_scores_right[basepair3.1][basepair3.0]
1928 [mismatch_pair3.0] += 1.;
1929 self.alignfold_counts_observed.helix_close_scores[basepair3.1][basepair3.0] += 1.;
1930 }
1931 if has_canonical_basepair(&basepair4) {
1932 self.alignfold_counts_observed.dangling_scores_left[basepair4.1][basepair4.0]
1933 [mismatch_pair4.1] += 1.;
1934 self.alignfold_counts_observed.dangling_scores_right[basepair4.1][basepair4.0]
1935 [mismatch_pair4.0] += 1.;
1936 self.alignfold_counts_observed.helix_close_scores[basepair4.1][basepair4.0] += 1.;
1937 }
1938 }
1939 }
1940 }
1941 self.alignfold_counts_observed.external_score_basepair += 2. * stored_basepairs.len() as Prob;
1942 let mut stored_basepairs_sorted = stored_basepairs
1943 .iter()
1944 .copied()
1945 .collect::<Vec<(usize, usize)>>();
1946 stored_basepairs_sorted.sort();
1947 let num_unpairs_external = get_num_unpairs_external(&stored_basepairs_sorted, &seq_pair.0[..]);
1948 self.alignfold_counts_observed.external_score_unpair += num_unpairs_external as Prob;
1949 let num_unpairs_external2 = get_num_unpairs_external(&stored_basepairs_sorted, &seq_pair.1[..]);
1950 self.alignfold_counts_observed.external_score_unpair += num_unpairs_external2 as Prob;
1951 for basepair_loop in stored_basepairs.iter() {
1952 let basepair = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1953 let basepair2 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1954 let closes = false;
1955 let mismatch_pair = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1956 let mismatch_pair2 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1957 if has_canonical_basepair(&basepair) {
1958 if mismatch_pair.1 != PSEUDO_BASE {
1959 self.alignfold_counts_observed.dangling_scores_left[basepair.1][basepair.0]
1960 [mismatch_pair.1] += 1.;
1961 }
1962 if mismatch_pair.0 != PSEUDO_BASE {
1963 self.alignfold_counts_observed.dangling_scores_right[basepair.1][basepair.0]
1964 [mismatch_pair.0] += 1.;
1965 }
1966 self.alignfold_counts_observed.helix_close_scores[basepair.1][basepair.0] += 1.;
1967 }
1968 if has_canonical_basepair(&basepair2) {
1969 if mismatch_pair2.1 != PSEUDO_BASE {
1970 self.alignfold_counts_observed.dangling_scores_left[basepair2.1][basepair2.0]
1971 [mismatch_pair2.1] += 1.;
1972 }
1973 if mismatch_pair2.0 != PSEUDO_BASE {
1974 self.alignfold_counts_observed.dangling_scores_right[basepair2.1][basepair2.0]
1975 [mismatch_pair2.0] += 1.;
1976 }
1977 self.alignfold_counts_observed.helix_close_scores[basepair2.1][basepair2.0] += 1.;
1978 }
1979 }
1980 }
1981
1982 pub fn set_curr_scores(&mut self, alignfold_scores: &AlignfoldScores) {
1983 self.fold_scores_pair.0 = FoldScoresTrained::<T>::set_curr_scores(
1984 alignfold_scores,
1985 &self.seq_pair.0,
1986 &self.basepair_probs_pair.0,
1987 );
1988 self.fold_scores_pair.1 = FoldScoresTrained::<T>::set_curr_scores(
1989 alignfold_scores,
1990 &self.seq_pair.1,
1991 &self.basepair_probs_pair.1,
1992 );
1993 }
1994}
1995
1996impl RangeInsertScores {
1997 pub fn origin() -> RangeInsertScores {
1998 let x = Vec::new();
1999 RangeInsertScores {
2000 insert_scores: x.clone(),
2001 insert_scores_external: x.clone(),
2002 insert_scores_multibranch: x.clone(),
2003 insert_scores2: x.clone(),
2004 insert_scores_external2: x.clone(),
2005 insert_scores_multibranch2: x,
2006 }
2007 }
2008
2009 pub fn new(seq_pair: &SeqPair, alignfold_scores: &AlignfoldScores) -> RangeInsertScores {
2010 let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
2011 let mut range_insert_scores = RangeInsertScores::origin();
2012 let neg_infs = vec![
2013 vec![NEG_INFINITY; seq_len_pair.0.to_usize().unwrap()];
2014 seq_len_pair.0.to_usize().unwrap()
2015 ];
2016 range_insert_scores.insert_scores = neg_infs.clone();
2017 range_insert_scores.insert_scores_external = neg_infs.clone();
2018 range_insert_scores.insert_scores_multibranch = neg_infs;
2019 let neg_infs = vec![
2020 vec![NEG_INFINITY; seq_len_pair.1.to_usize().unwrap()];
2021 seq_len_pair.1.to_usize().unwrap()
2022 ];
2023 range_insert_scores.insert_scores2 = neg_infs.clone();
2024 range_insert_scores.insert_scores_external2 = neg_infs.clone();
2025 range_insert_scores.insert_scores_multibranch2 = neg_infs;
2026 for i in 1..seq_len_pair.1 - 1 {
2027 let base = seq_pair.1[i];
2028 let mut sum = alignfold_scores.insert_scores[base];
2029 let mut sum_external = sum + alignfold_scores.external_score_unpair;
2030 let mut sum_multibranch = sum + alignfold_scores.multibranch_score_unpair;
2031 range_insert_scores.insert_scores2[i][i] = sum;
2032 range_insert_scores.insert_scores_external2[i][i] = sum_external;
2033 range_insert_scores.insert_scores_multibranch2[i][i] = sum_multibranch;
2034 for j in i + 1..seq_len_pair.1 - 1 {
2035 let x = seq_pair.1[j];
2036 let x = alignfold_scores.insert_scores[x] + alignfold_scores.insert_extend_score;
2037 sum += x;
2038 sum_external += x + alignfold_scores.external_score_unpair;
2039 sum_multibranch += x + alignfold_scores.multibranch_score_unpair;
2040 range_insert_scores.insert_scores2[i][j] = sum;
2041 range_insert_scores.insert_scores_external2[i][j] = sum_external;
2042 range_insert_scores.insert_scores_multibranch2[i][j] = sum_multibranch;
2043 }
2044 }
2045 for i in 1..seq_len_pair.0 - 1 {
2046 let base = seq_pair.0[i];
2047 let term = alignfold_scores.insert_scores[base];
2048 let mut sum = term;
2049 let mut sum_external = sum + alignfold_scores.external_score_unpair;
2050 let mut sum_multibranch = sum + alignfold_scores.multibranch_score_unpair;
2051 range_insert_scores.insert_scores[i][i] = sum;
2052 range_insert_scores.insert_scores_external[i][i] = sum_external;
2053 range_insert_scores.insert_scores_multibranch[i][i] = sum_multibranch;
2054 for j in i + 1..seq_len_pair.0 - 1 {
2055 let x = seq_pair.0[j];
2056 let x = alignfold_scores.insert_scores[x] + alignfold_scores.insert_extend_score;
2057 sum += x;
2058 sum_external += x + alignfold_scores.external_score_unpair;
2059 sum_multibranch += x + alignfold_scores.multibranch_score_unpair;
2060 range_insert_scores.insert_scores[i][j] = sum;
2061 range_insert_scores.insert_scores_external[i][j] = sum_external;
2062 range_insert_scores.insert_scores_multibranch[i][j] = sum_multibranch;
2063 }
2064 }
2065 range_insert_scores
2066 }
2067}
2068
2069impl<T: HashIndex> Default for PairAlignfold<T> {
2070 fn default() -> Self {
2071 Self::new()
2072 }
2073}
2074
2075impl<T: HashIndex> PairAlignfold<T> {
2076 pub fn new() -> PairAlignfold<T> {
2077 PairAlignfold {
2078 matched_pos_pairs: SparsePosMat::<T>::default(),
2079 inserted_poss: SparsePoss::<T>::default(),
2080 deleted_poss: SparsePoss::<T>::default(),
2081 }
2082 }
2083}
2084
2085pub const DEFAULT_BASEPAIR_PROB_TRAIN: Prob = DEFAULT_MIN_BASEPAIR_PROB;
2086pub const DEFAULT_MATCH_PROB_TRAIN: Prob = DEFAULT_MIN_MATCH_PROB;
2087pub const NUM_BASEPAIRS: usize = 6;
2088pub const GROUP_SIZE_MULTIBRANCH: usize = 3;
2089pub const GROUP_SIZE_EXTERNAL: usize = 2;
2090pub const GROUP_SIZE_MATCH_TRANSITION: usize = 3;
2091pub const GROUP_SIZE_INSERT_TRANSITION: usize = 2;
2092pub const GAMMA_DISTRO_ALPHA: Score = 0.;
2093pub const GAMMA_DISTRO_BETA: Score = 1.;
2094pub const DEFAULT_LEARNING_TOLERANCE: Score = 0.000_1;
2095pub const TRAINED_SCORES_FILE: &str = "../src/trained_alignfold_scores.rs";
2096pub const TRAINED_SCORES_FILE_RANDINIT: &str = "../src/trained_alignfold_scores_randinit.rs";
2097#[derive(Clone, Copy)]
2098pub enum TrainType {
2099 TrainedTransfer,
2100 TrainedRandinit,
2101 TransferredOnly,
2102}
2103pub const DEFAULT_TRAIN_TYPE: &str = "trained_transfer";
2104
2105pub fn get_accuracy_expected<T>(
2106 seq_pair: &SeqPair,
2107 alignfold: &PairAlignfold<T>,
2108 match_probs: &SparseProbMat<T>,
2109) -> Score
2110where
2111 T: HashIndex,
2112{
2113 let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
2114 let mut insert_probs_pair = (vec![1.; seq_len_pair.0], vec![1.; seq_len_pair.1]);
2115 for (x, &y) in match_probs {
2116 let x = (x.0.to_usize().unwrap(), x.1.to_usize().unwrap());
2117 insert_probs_pair.0[x.0] -= y;
2118 insert_probs_pair.1[x.1] -= y;
2119 }
2120 let total = alignfold.matched_pos_pairs.len()
2121 + alignfold.inserted_poss.len()
2122 + alignfold.deleted_poss.len();
2123 let mut total_expected = alignfold
2124 .matched_pos_pairs
2125 .iter()
2126 .map(|x| match match_probs.get(x) {
2127 Some(&x) => x,
2128 None => 0.,
2129 })
2130 .sum::<Score>();
2131 total_expected += alignfold
2132 .inserted_poss
2133 .iter()
2134 .map(|&x| insert_probs_pair.0[x.to_usize().unwrap()])
2135 .sum::<Score>();
2136 total_expected += alignfold
2137 .deleted_poss
2138 .iter()
2139 .map(|&x| insert_probs_pair.1[x.to_usize().unwrap()])
2140 .sum::<Score>();
2141 total_expected / total as Score
2142}
2143
2144pub fn consprob_core<T>(inputs: InputsConsprobCore<T>) -> (AlignfoldProbMats<T>, Sum)
2145where
2146 T: HashIndex,
2147{
2148 let (
2149 seq_pair,
2150 alignfold_scores,
2151 max_basepair_span_pair,
2152 match_probs,
2153 produces_struct_profs,
2154 trains_alignfold_scores,
2155 alignfold_counts_expected,
2156 forward_pos_pairs,
2157 backward_pos_pairs,
2158 pos_quads_hashed_lens,
2159 fold_scores_pair,
2160 produces_match_probs,
2161 matchable_poss,
2162 matchable_poss2,
2163 ) = inputs;
2164 let range_insert_scores = RangeInsertScores::new(seq_pair, alignfold_scores);
2165 let (alignfold_sums, global_sum) = get_alignfold_sums::<T>((
2166 seq_pair,
2167 alignfold_scores,
2168 max_basepair_span_pair,
2169 match_probs,
2170 trains_alignfold_scores,
2171 forward_pos_pairs,
2172 backward_pos_pairs,
2173 pos_quads_hashed_lens,
2174 fold_scores_pair,
2175 &range_insert_scores,
2176 matchable_poss,
2177 matchable_poss2,
2178 ));
2179 (
2180 get_alignfold_probs::<T>((
2181 seq_pair,
2182 alignfold_scores,
2183 max_basepair_span_pair,
2184 match_probs,
2185 &alignfold_sums,
2186 produces_struct_profs,
2187 global_sum,
2188 trains_alignfold_scores,
2189 alignfold_counts_expected,
2190 pos_quads_hashed_lens,
2191 fold_scores_pair,
2192 produces_match_probs,
2193 forward_pos_pairs,
2194 backward_pos_pairs,
2195 &range_insert_scores,
2196 matchable_poss,
2197 matchable_poss2,
2198 )),
2199 global_sum,
2200 )
2201}
2202
/// Fills the log-space dynamic-programming sums for one sequence pair and returns them
/// together with the global log partition sum.
///
/// Phases:
/// 1. For every candidate matched base-pair quadruple `(i, j, k, l)` (pair `(i, j)` in the
///    first sequence aligned to pair `(k, l)` in the second), visited from shortest to
///    longest span, compute the loop-internal sums (via `get_loop_sums`) and the closing sum
///    (`sums_close`), plus its external- and multibranch-accessible variants.
/// 2. A forward pass over the external loop (`forward_sums_external*`).
/// 3. The global sum, read off at the position pair just before the last positions.
/// 4. A backward pass over the external loop (`backward_sums_external*`).
///
/// All accumulation uses `logsumexp`; `NEG_INFINITY` means "no contribution", and entries
/// equal to `NEG_INFINITY` are not stored.
pub fn get_alignfold_sums<T>(inputs: InputsInsideSumsGetter<T>) -> (AlignfoldSums<T>, Sum)
where
    T: HashIndex,
{
    let (
        seq_pair,
        alignfold_scores,
        max_basepair_span_pair,
        match_probs,
        trains_alignfold_scores,
        forward_pos_pairs,
        backward_pos_pairs,
        pos_quads_hashed_lens,
        fold_scores_pair,
        range_insert_scores,
        matchable_poss,
        matchable_poss2,
    ) = inputs;
    let seq_len_pair = (
        T::from_usize(seq_pair.0.len()).unwrap(),
        T::from_usize(seq_pair.1.len()).unwrap(),
    );
    let mut alignfold_sums = AlignfoldSums::<T>::new();
    // Phase 1: spans are enumerated in ascending order, so every enclosed quadruple read
    // from `sums_close` below (span at most `substr_len - 2`) has already been computed.
    // Training admits spans down to 2 instead of MIN_SPAN_HAIRPIN_CLOSE.
    for substr_len in range_inclusive(
        T::from_usize(if trains_alignfold_scores {
            2
        } else {
            MIN_SPAN_HAIRPIN_CLOSE
        })
        .unwrap(),
        max_basepair_span_pair.0,
    ) {
        for substr_len2 in range_inclusive(
            T::from_usize(if trains_alignfold_scores {
                2
            } else {
                MIN_SPAN_HAIRPIN_CLOSE
            })
            .unwrap(),
            max_basepair_span_pair.1,
        ) {
            // Left ends (i, k) of all candidate quadruples with these two span lengths.
            if let Some(pos_pairs) = pos_quads_hashed_lens.get(&(substr_len, substr_len2)) {
                for &(i, k) in pos_pairs {
                    let (j, l) = (i + substr_len - T::one(), k + substr_len2 - T::one());
                    let (long_i, long_j, long_k, long_l) = (
                        i.to_usize().unwrap(),
                        j.to_usize().unwrap(),
                        k.to_usize().unwrap(),
                        l.to_usize().unwrap(),
                    );
                    let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
                    let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
                    let pos_quad = (i, j, k, l);
                    // Forward loop sums inside the quadruple; the returned pair feeds the
                    // hairpin-like term (`sum_seqalign`) and the multibranch-closing term.
                    let computes_forward_sums = true;
                    let (sum_seqalign, sum_multibranch) = get_loop_sums::<T>((
                        seq_pair,
                        alignfold_scores,
                        match_probs,
                        &pos_quad,
                        &mut alignfold_sums,
                        computes_forward_sums,
                        forward_pos_pairs,
                        range_insert_scores,
                        matchable_poss,
                        matchable_poss2,
                    ));
                    // Backward loop sums are only stored in `alignfold_sums` here; the
                    // return value is unused in this direction.
                    let computes_forward_sums = false;
                    let _ = get_loop_sums::<T>((
                        seq_pair,
                        alignfold_scores,
                        match_probs,
                        &pos_quad,
                        &mut alignfold_sums,
                        computes_forward_sums,
                        backward_pos_pairs,
                        range_insert_scores,
                        matchable_poss,
                        matchable_poss2,
                    ));
                    let mut sum = NEG_INFINITY;
                    // Score of matching the two closing base pairs column by column.
                    let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
                        + alignfold_scores.match_scores[basepair.1][basepair2.1];
                    // Hairpin term: only when both enclosed loops are at most MAX_LOOP_LEN long.
                    // NOTE(review): `- 2` assumes span lengths >= 2, which the loop lower
                    // bounds provide when MIN_SPAN_HAIRPIN_CLOSE >= 2 — confirm.
                    if substr_len.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
                        && substr_len2.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
                    {
                        let hairpin_score = fold_scores_pair.0.hairpin_scores[&(i, j)];
                        let hairpin_score2 = fold_scores_pair.1.hairpin_scores[&(k, l)];
                        let score = hairpin_score + hairpin_score2 + sum_seqalign;
                        logsumexp(&mut sum, score);
                    }
                    // Tables just stored by the two `get_loop_sums` calls above.
                    let forward_sums = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
                    let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
                    let min = T::from_usize(if trains_alignfold_scores {
                        2
                    } else {
                        MIN_HAIRPIN_LEN
                    })
                    .unwrap();
                    // Shortest inner span that can still keep the two-loop within
                    // MAX_LOOP_LEN unpaired positions (at least `min`).
                    let min_len_pair = (
                        if substr_len <= min + T::from_usize(MAX_LOOP_LEN + 2).unwrap() {
                            min
                        } else {
                            substr_len - T::from_usize(MAX_LOOP_LEN + 2).unwrap()
                        },
                        if substr_len2 <= min + T::from_usize(MAX_LOOP_LEN + 2).unwrap() {
                            min
                        } else {
                            substr_len2 - T::from_usize(MAX_LOOP_LEN + 2).unwrap()
                        },
                    );
                    // Two-loop terms: enclosed quadruples (m, n, o, p) strictly inside
                    // (i, j, k, l) whose unpaired loop length is at most MAX_LOOP_LEN in
                    // each sequence.
                    for substr_len3 in range(min_len_pair.0, substr_len - T::one()) {
                        for substr_len4 in range(min_len_pair.1, substr_len2 - T::one()) {
                            if let Some(pos_pairs2) = pos_quads_hashed_lens.get(&(substr_len3, substr_len4)) {
                                for &(m, o) in pos_pairs2 {
                                    let (n, p) = (m + substr_len3 - T::one(), o + substr_len4 - T::one());
                                    if !(i < m && n < j && k < o && p < l) {
                                        continue;
                                    }
                                    if m - i - T::one() + j - n - T::one() > T::from_usize(MAX_LOOP_LEN).unwrap() {
                                        continue;
                                    }
                                    if o - k - T::one() + l - p - T::one() > T::from_usize(MAX_LOOP_LEN).unwrap() {
                                        continue;
                                    }
                                    let pos_quad2 = (m, n, o, p);
                                    if let Some(&x) = alignfold_sums.sums_close.get(&pos_quad2) {
                                        // Loop-alignment sums on the left of (m, o) and on
                                        // the right of (n, p).
                                        let mut forward_term = NEG_INFINITY;
                                        let mut backward_term = forward_term;
                                        let pos_pair2 = (m - T::one(), o - T::one());
                                        if let Some(x) = forward_sums.get(&pos_pair2) {
                                            logsumexp(&mut forward_term, x.sum_seqalign);
                                        }
                                        let pos_pair2 = (n + T::one(), p + T::one());
                                        if let Some(x) = backward_sums.get(&pos_pair2) {
                                            logsumexp(&mut backward_term, x.sum_seqalign);
                                        }
                                        let twoloop_score = fold_scores_pair.0.twoloop_scores[&(i, j, m, n)];
                                        let twoloop_score2 = fold_scores_pair.1.twoloop_scores[&(k, l, o, p)];
                                        let x = twoloop_score + twoloop_score2 + x + forward_term + backward_term;
                                        logsumexp(&mut sum, x);
                                    }
                                }
                            }
                        }
                    }
                    // Multibranch term: the quadruple closes a multiloop.
                    let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
                    let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
                    let score = multibranch_close_score + multibranch_close_score2 + sum_multibranch;
                    logsumexp(&mut sum, score);
                    if sum > NEG_INFINITY {
                        // Store the closing sum and its variants as seen from an external
                        // loop and from an enclosing multiloop.
                        let sum = sum + pairmatch_score;
                        alignfold_sums.sums_close.insert(pos_quad, sum);
                        let external_accessible_score = fold_scores_pair.0.external_accessible_scores[&(i, j)];
                        let external_accessible_score2 = fold_scores_pair.1.external_accessible_scores[&(k, l)];
                        alignfold_sums.sums_accessible_external.insert(
                            pos_quad,
                            sum + external_accessible_score + external_accessible_score2,
                        );
                        let multibranch_accessible_score =
                            fold_scores_pair.0.multibranch_accessible_scores[&(i, j)];
                        let multibranch_accessible_score2 =
                            fold_scores_pair.1.multibranch_accessible_scores[&(k, l)];
                        alignfold_sums.sums_accessible_multibranch.insert(
                            pos_quad,
                            sum + multibranch_accessible_score + multibranch_accessible_score2,
                        );
                    }
                }
            }
        }
    }
    // Phase 2: forward pass over the external loop. Positions 0 and len - 1 act as the
    // boundary positions; the forward recursion starts from (0, 0) with log-sum 0.
    let leftmost_pos_pair = (T::zero(), T::zero());
    let rightmost_pos_pair = (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one());
    alignfold_sums
        .forward_sums_external
        .insert(leftmost_pos_pair, 0.);
    alignfold_sums
        .backward_sums_external
        .insert(rightmost_pos_pair, 0.);
    for i in range(T::zero(), seq_len_pair.0 - T::one()) {
        let long_i = i.to_usize().unwrap();
        let base = seq_pair.0[long_i];
        for j in range(T::zero(), seq_len_pair.1 - T::one()) {
            let pos_pair = (i, j);
            if pos_pair == (T::zero(), T::zero()) {
                continue;
            }
            let long_j = j.to_usize().unwrap();
            let mut sum = NEG_INFINITY;
            // Alignments ending with a matched base-pair quadruple (k, i, l, j); its left
            // side is reached through `forward_sums_external2` at (k - 1, l - 1).
            if let Some(x) = forward_pos_pairs.get(&pos_pair) {
                for &(k, l) in x {
                    let pos_pair2 = (k - T::one(), l - T::one());
                    let pos_quad = (k, i, l, j);
                    if let Some(&x) = alignfold_sums.sums_accessible_external.get(&pos_quad) {
                        if let Some(&y) = alignfold_sums.forward_sums_external2.get(&pos_pair2) {
                            let y = x + y;
                            logsumexp(&mut sum, y);
                        }
                    }
                }
            }
            let base2 = seq_pair.1[long_j];
            // NOTE(review): `forward_sums_external` is only written inside this branch, so
            // a position pair that ends base pairs but is absent from `match_probs` gets no
            // entry — confirm such pairs cannot occur.
            if i > T::zero() && j > T::zero() && match_probs.contains_key(&pos_pair) {
                // `sum2`: sum over alignments ending at (i - 1, j - 1), including the
                // transition into a match at (i, j).
                let mut sum2 = NEG_INFINITY;
                let loopmatch_score =
                    alignfold_scores.match_scores[base][base2] + 2. * alignfold_scores.external_score_unpair;
                let pos_pair2 = (i - T::one(), j - T::one());
                let long_pos_pair2 = (
                    pos_pair2.0.to_usize().unwrap(),
                    pos_pair2.1.to_usize().unwrap(),
                );
                let begins_sum = pos_pair2 == leftmost_pos_pair;
                // Previous column was a match (or the alignment start).
                if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
                    let x = x
                        + if begins_sum {
                            alignfold_scores.init_match_score
                        } else {
                            alignfold_scores.match2match_score
                        };
                    logsumexp(&mut sum2, x);
                }
                // Previous columns were a run of second-sequence positions x+1..=j-1
                // aligned to gaps.
                // NOTE(review): the trailing `match2insert_score` appears to score the
                // insert->match transition (no separate insert2match score is used) — confirm.
                if let Some(x) = matchable_poss.get(&pos_pair2.0) {
                    for &x in x {
                        if x >= pos_pair2.1 {
                            continue;
                        }
                        let pos_pair3 = (pos_pair2.0, x);
                        if let Some(&y) = alignfold_sums.forward_sums_external.get(&pos_pair3) {
                            let long_x = x.to_usize().unwrap();
                            let begins_sum = pos_pair3 == leftmost_pos_pair;
                            let z = range_insert_scores.insert_scores_external2[long_x + 1][long_pos_pair2.1]
                                + if begins_sum {
                                    alignfold_scores.init_insert_score
                                } else {
                                    alignfold_scores.match2insert_score
                                };
                            let z = y + z + alignfold_scores.match2insert_score;
                            logsumexp(&mut sum2, z);
                        }
                    }
                }
                // Previous columns were a run of first-sequence positions x+1..=i-1
                // aligned to gaps.
                if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
                    for &x in x {
                        if x >= pos_pair2.0 {
                            continue;
                        }
                        let pos_pair3 = (x, pos_pair2.1);
                        if let Some(&y) = alignfold_sums.forward_sums_external.get(&pos_pair3) {
                            let long_x = x.to_usize().unwrap();
                            let begins_sum = pos_pair3 == leftmost_pos_pair;
                            let z = range_insert_scores.insert_scores_external[long_x + 1][long_pos_pair2.0]
                                + if begins_sum {
                                    alignfold_scores.init_insert_score
                                } else {
                                    alignfold_scores.match2insert_score
                                };
                            let z = y + z + alignfold_scores.match2insert_score;
                            logsumexp(&mut sum2, z);
                        }
                    }
                }
                if sum2 > NEG_INFINITY {
                    alignfold_sums
                        .forward_sums_external2
                        .insert(pos_pair2, sum2);
                }
                let term = sum2 + loopmatch_score;
                logsumexp(&mut sum, term);
                if sum > NEG_INFINITY {
                    alignfold_sums.forward_sums_external.insert(pos_pair, sum);
                }
            }
        }
    }
    // Phase 3: global sum — alignments whose last match is at (len0 - 2, len1 - 2), or
    // whose last match is followed by a trailing gap run in either sequence.
    let mut global_sum = NEG_INFINITY;
    let pos_pair2 = (
        rightmost_pos_pair.0 - T::one(),
        rightmost_pos_pair.1 - T::one(),
    );
    let long_pos_pair2 = (
        pos_pair2.0.to_usize().unwrap(),
        pos_pair2.1.to_usize().unwrap(),
    );
    if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
        logsumexp(&mut global_sum, x);
    }
    if let Some(x) = matchable_poss.get(&pos_pair2.0) {
        for &x in x {
            if x >= pos_pair2.1 {
                continue;
            }
            if let Some(&y) = alignfold_sums.forward_sums_external.get(&(pos_pair2.0, x)) {
                let long_x = x.to_usize().unwrap();
                let z = range_insert_scores.insert_scores_external2[long_x + 1][long_pos_pair2.1]
                    + alignfold_scores.match2insert_score;
                let z = y + z;
                logsumexp(&mut global_sum, z);
            }
        }
    }
    if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
        for &x in x {
            if x >= pos_pair2.0 {
                continue;
            }
            if let Some(&y) = alignfold_sums.forward_sums_external.get(&(x, pos_pair2.1)) {
                let long_x = x.to_usize().unwrap();
                let z = range_insert_scores.insert_scores_external[long_x + 1][long_pos_pair2.0]
                    + alignfold_scores.match2insert_score;
                let z = y + z;
                logsumexp(&mut global_sum, z);
            }
        }
    }
    // Phase 4: backward pass over the external loop, mirroring phase 2 from the right
    // boundary (len0 - 1, len1 - 1), whose log-sum was seeded with 0 above.
    for i in range(T::one(), seq_len_pair.0).rev() {
        let long_i = i.to_usize().unwrap();
        let base = seq_pair.0[long_i];
        for j in range(T::one(), seq_len_pair.1).rev() {
            let pos_pair = (i, j);
            if pos_pair == (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one()) {
                continue;
            }
            let long_j = j.to_usize().unwrap();
            let mut sum = NEG_INFINITY;
            // Alignments starting with a matched base-pair quadruple (i, k, j, l).
            if let Some(x) = backward_pos_pairs.get(&pos_pair) {
                for &(k, l) in x {
                    let pos_pair2 = (k + T::one(), l + T::one());
                    let pos_quad = (i, k, j, l);
                    if let Some(&x) = alignfold_sums.sums_accessible_external.get(&pos_quad) {
                        if let Some(&y) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
                            let y = x + y;
                            logsumexp(&mut sum, y);
                        }
                    }
                }
            }
            let base2 = seq_pair.1[long_j];
            if i < seq_len_pair.0 - T::one() && j < seq_len_pair.1 - T::one() {
                let pos_pair2 = (i + T::one(), j + T::one());
                let long_pos_pair2 = (
                    pos_pair2.0.to_usize().unwrap(),
                    pos_pair2.1.to_usize().unwrap(),
                );
                // Reaching the right boundary adds no transition score.
                let ends_sum = pos_pair2 == rightmost_pos_pair;
                if match_probs.contains_key(&pos_pair) {
                    let mut sum2 = NEG_INFINITY;
                    let loopmatch_score = alignfold_scores.match_scores[base][base2]
                        + 2. * alignfold_scores.external_score_unpair;
                    if let Some(&x) = alignfold_sums.backward_sums_external.get(&pos_pair2) {
                        let x = x
                            + if ends_sum {
                                0.
                            } else {
                                alignfold_scores.match2match_score
                            };
                        logsumexp(&mut sum2, x);
                    }
                    // Following columns are a gap run over second-sequence positions
                    // j+1..=x-1.
                    if let Some(x) = matchable_poss.get(&pos_pair2.0) {
                        for &x in x {
                            if x <= pos_pair2.1 {
                                continue;
                            }
                            let pos_pair3 = (pos_pair2.0, x);
                            if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
                                let long_x = x.to_usize().unwrap();
                                let ends_sum = pos_pair3 == rightmost_pos_pair;
                                let z = range_insert_scores.insert_scores_external2[long_pos_pair2.1][long_x - 1]
                                    + if ends_sum {
                                        0.
                                    } else {
                                        alignfold_scores.match2insert_score
                                    };
                                let z = y + z + alignfold_scores.match2insert_score;
                                logsumexp(&mut sum2, z);
                            }
                        }
                    }
                    // Following columns are a gap run over first-sequence positions
                    // i+1..=x-1.
                    if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
                        for &x in x {
                            if x <= pos_pair2.0 {
                                continue;
                            }
                            let pos_pair3 = (x, pos_pair2.1);
                            if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
                                let long_x = x.to_usize().unwrap();
                                let ends_sum = pos_pair3 == rightmost_pos_pair;
                                let z = range_insert_scores.insert_scores_external[long_pos_pair2.0][long_x - 1]
                                    + if ends_sum {
                                        0.
                                    } else {
                                        alignfold_scores.match2insert_score
                                    };
                                let z = y + z + alignfold_scores.match2insert_score;
                                logsumexp(&mut sum2, z);
                            }
                        }
                    }
                    if sum2 > NEG_INFINITY {
                        alignfold_sums
                            .backward_sums_external2
                            .insert(pos_pair2, sum2);
                    }
                    let term = sum2 + loopmatch_score;
                    logsumexp(&mut sum, term);
                    if sum > NEG_INFINITY {
                        alignfold_sums.backward_sums_external.insert(pos_pair, sum);
                    }
                }
            }
        }
    }
    (alignfold_sums, global_sum)
}
2616
/// Fills the loop-internal log-space sums for one matched base-pair quadruple
/// `(i, j, k, l)`.
///
/// Direction:
/// - With `computes_forward_sums == true`, position pairs `(u, v)` with `i <= u < j` and
///   `k <= v < l` are scanned left to right, starting from `(i, k)` (seeded with 0). The
///   results go into `forward_sums_hashed_poss[(i, k)]` and
///   `forward_sums_hashed_poss2[(i, k)]`.
/// - Otherwise, `i < u <= j` and `k < v <= l` are scanned right to left, starting from
///   `(j, l)`. The results go into the backward tables keyed by `(j, l)`.
///
/// Entries already present in the per-key table are skipped, so tables shared by several
/// quadruples are filled incrementally.
///
/// Table contents:
/// - The first table holds, per position pair, the sums over loop alignments ending
///   (forward) or starting (backward) with a match at that pair.
/// - The `...2` table holds, keyed by the neighbouring position pair, the sums just before
///   that match, including the transition into it.
///
/// Per-field relations in `LoopSums`:
/// - `sum_1st_pairmatches` extends `sum_seqalign_multibranch` (no base pair yet) with a
///   multibranch-accessible base pair.
/// - `sum_multibranch` extends `sum_1ormore_pairmatches` with one.
///
/// Returns `(final_sum_seqalign, final_sum_multibranch)` for the forward direction, i.e.
/// the sums ending just before the closing pair `(j, l)`. In the backward direction both
/// values are `NEG_INFINITY`.
pub fn get_loop_sums<T>(inputs: InputsLoopSumsGetter<T>) -> (Sum, Sum)
where
    T: HashIndex,
{
    let (
        seq_pair,
        alignfold_scores,
        match_probs,
        pos_quad,
        alignfold_sums,
        computes_forward_sums,
        pos_pairs,
        range_insert_scores,
        matchable_poss,
        matchable_poss2,
    ) = inputs;
    let &(i, j, k, l) = pos_quad;
    let leftmost_pos_pair = if computes_forward_sums {
        (i, k)
    } else {
        (i + T::one(), k + T::one())
    };
    let rightmost_pos_pair = if computes_forward_sums {
        (j - T::one(), l - T::one())
    } else {
        (j, l)
    };
    // Select the forward or backward tables; they are keyed by (i, k) when going forward
    // and by (j, l) when going backward.
    let sums_hashed_poss = if computes_forward_sums {
        &mut alignfold_sums.forward_sums_hashed_poss
    } else {
        &mut alignfold_sums.backward_sums_hashed_poss
    };
    let sums_hashed_poss2 = if computes_forward_sums {
        &mut alignfold_sums.forward_sums_hashed_poss2
    } else {
        &mut alignfold_sums.backward_sums_hashed_poss2
    };
    if !sums_hashed_poss.contains_key(&if computes_forward_sums {
        leftmost_pos_pair
    } else {
        rightmost_pos_pair
    }) {
        sums_hashed_poss.insert(
            if computes_forward_sums {
                leftmost_pos_pair
            } else {
                rightmost_pos_pair
            },
            LoopSumsMat::<T>::new(),
        );
    }
    if !sums_hashed_poss2.contains_key(&if computes_forward_sums {
        leftmost_pos_pair
    } else {
        rightmost_pos_pair
    }) {
        sums_hashed_poss2.insert(
            if computes_forward_sums {
                leftmost_pos_pair
            } else {
                rightmost_pos_pair
            },
            LoopSumsMat::<T>::new(),
        );
    }
    let sums_mat = &mut sums_hashed_poss
        .get_mut(&if computes_forward_sums {
            leftmost_pos_pair
        } else {
            rightmost_pos_pair
        })
        .unwrap();
    let sums_mat2 = &mut sums_hashed_poss2
        .get_mut(&if computes_forward_sums {
            leftmost_pos_pair
        } else {
            rightmost_pos_pair
        })
        .unwrap();
    // Scan order: ascending from (i, k) forward, descending from (j, l) backward.
    let iter: Poss<T> = if computes_forward_sums {
        range(i, j).collect()
    } else {
        range_inclusive(i + T::one(), j).rev().collect()
    };
    let iter2: Poss<T> = if computes_forward_sums {
        range(k, l).collect()
    } else {
        range_inclusive(k + T::one(), l).rev().collect()
    };
    for &u in iter.iter() {
        let long_u = u.to_usize().unwrap();
        let base = seq_pair.0[long_u];
        for &v in iter2.iter() {
            let pos_pair = (u, v);
            // Already computed by an earlier quadruple sharing this table.
            if sums_mat.contains_key(&pos_pair) {
                continue;
            }
            let mut sums = LoopSums::new();
            // Starting pair of the scan: the empty loop alignment, log-sum 0.
            if (computes_forward_sums && u == i && v == k) || (!computes_forward_sums && u == j && v == l)
            {
                sums.sum_seqalign = 0.;
                sums.sum_seqalign_multibranch = 0.;
                sums.sum_0ormore_pairmatches = 0.;
                sums_mat.insert(pos_pair, sums);
                continue;
            }
            let long_v = v.to_usize().unwrap();
            let mut sum_multibranch = NEG_INFINITY;
            let mut sum_1st_pairmatches = sum_multibranch;
            let mut sum = sum_multibranch;
            // Alignments ending (forward) / starting (backward) at (u, v) with a
            // multibranch-accessible base-pair quadruple lying strictly inside (i, j, k, l).
            if let Some(pos_pairs) = pos_pairs.get(&pos_pair) {
                for &(m, n) in pos_pairs {
                    if computes_forward_sums {
                        if !(i < m && k < n) {
                            continue;
                        }
                    } else if !(m < j && n < l) {
                        continue;
                    }
                    let pos_pair2 = if computes_forward_sums {
                        (m - T::one(), n - T::one())
                    } else {
                        (m + T::one(), n + T::one())
                    };
                    let pos_quad2 = if computes_forward_sums {
                        (m, u, n, v)
                    } else {
                        (u, m, v, n)
                    };
                    if let Some(&x) = alignfold_sums.sums_accessible_multibranch.get(&pos_quad2) {
                        if let Some(y) = sums_mat2.get(&pos_pair2) {
                            let z = x + y.sum_1ormore_pairmatches;
                            logsumexp(&mut sum_multibranch, z);
                            let z = x + y.sum_seqalign_multibranch;
                            logsumexp(&mut sum_1st_pairmatches, z);
                        }
                    }
                }
            }
            // Neighbouring pair preceding (forward) or following (backward) a match at (u, v).
            let pos_pair2 = if computes_forward_sums {
                (u - T::one(), v - T::one())
            } else {
                (u + T::one(), v + T::one())
            };
            let long_pos_pair2 = (
                pos_pair2.0.to_usize().unwrap(),
                pos_pair2.1.to_usize().unwrap(),
            );
            let base2 = seq_pair.1[long_v];
            if match_probs.contains_key(&pos_pair) {
                // `sums2`: sums at `pos_pair2` including the transition into the match at
                // (u, v); stored in `sums_mat2` under `pos_pair2`.
                let mut sums2 = LoopSums::new();
                let mut sum_seqalign2 = NEG_INFINITY;
                let mut sum_seqalign_multibranch2 = sum_seqalign2;
                let mut sum_multibranch2 = sum_seqalign2;
                let mut sum_1st_pairmatches2 = sum_seqalign2;
                let mut sum2 = sum_seqalign2;
                let loopmatch_score = alignfold_scores.match_scores[base][base2];
                let loopmatch_score_multibranch =
                    loopmatch_score + 2. * alignfold_scores.multibranch_score_unpair;
                // Neighbouring column is also a match.
                if let Some(x) = sums_mat.get(&pos_pair2) {
                    let y = x.sum_multibranch + alignfold_scores.match2match_score;
                    logsumexp(&mut sum_multibranch2, y);
                    let y = x.sum_1st_pairmatches + alignfold_scores.match2match_score;
                    logsumexp(&mut sum_1st_pairmatches2, y);
                    let y = x.sum_seqalign_multibranch + alignfold_scores.match2match_score;
                    logsumexp(&mut sum_seqalign_multibranch2, y);
                    let y = x.sum_seqalign + alignfold_scores.match2match_score;
                    logsumexp(&mut sum_seqalign2, y);
                }
                // Neighbouring columns are a gap run over second-sequence positions between
                // x and pos_pair2.1. `z` uses plain insert scores (for `sum_seqalign`);
                // `a` uses the multibranch variant.
                if let Some(x) = matchable_poss.get(&pos_pair2.0) {
                    for &x in x {
                        if computes_forward_sums && x >= pos_pair2.1
                            || (!computes_forward_sums && x <= pos_pair2.1)
                        {
                            continue;
                        }
                        if let Some(y) = sums_mat.get(&(pos_pair2.0, x)) {
                            let long_x = x.to_usize().unwrap();
                            let z = if computes_forward_sums {
                                range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
                            } else {
                                range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
                            } + alignfold_scores.match2insert_score;
                            let a = if computes_forward_sums {
                                range_insert_scores.insert_scores_multibranch2[long_x + 1][long_pos_pair2.1]
                            } else {
                                range_insert_scores.insert_scores_multibranch2[long_pos_pair2.1][long_x - 1]
                            } + alignfold_scores.match2insert_score;
                            let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_multibranch2, x);
                            let x = y.sum_1st_pairmatches + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_1st_pairmatches2, x);
                            let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
                            logsumexp(&mut sum_seqalign2, x);
                            let x = y.sum_seqalign_multibranch + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_seqalign_multibranch2, x);
                        }
                    }
                }
                // Same as above with the roles of the two sequences swapped.
                if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
                    for &x in x {
                        if computes_forward_sums && x >= pos_pair2.0
                            || (!computes_forward_sums && x <= pos_pair2.0)
                        {
                            continue;
                        }
                        if let Some(y) = sums_mat.get(&(x, pos_pair2.1)) {
                            let long_x = x.to_usize().unwrap();
                            let z = if computes_forward_sums {
                                range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
                            } else {
                                range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
                            } + alignfold_scores.match2insert_score;
                            let a = if computes_forward_sums {
                                range_insert_scores.insert_scores_multibranch[long_x + 1][long_pos_pair2.0]
                            } else {
                                range_insert_scores.insert_scores_multibranch[long_pos_pair2.0][long_x - 1]
                            } + alignfold_scores.match2insert_score;
                            let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_multibranch2, x);
                            let x = y.sum_1st_pairmatches + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_1st_pairmatches2, x);
                            let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
                            logsumexp(&mut sum_seqalign2, x);
                            let x = y.sum_seqalign_multibranch + alignfold_scores.match2insert_score + a;
                            logsumexp(&mut sum_seqalign_multibranch2, x);
                        }
                    }
                }
                // Aggregate: 1ormore = multibranch (+) 1st; 0ormore additionally includes
                // seqalign_multibranch (no base pair yet).
                sums2.sum_multibranch = sum_multibranch2;
                logsumexp(&mut sum2, sum_multibranch2);
                sums2.sum_1st_pairmatches = sum_1st_pairmatches2;
                logsumexp(&mut sum2, sum_1st_pairmatches2);
                sums2.sum_1ormore_pairmatches = sum2;
                sums2.sum_seqalign = sum_seqalign2;
                sums2.sum_seqalign_multibranch = sum_seqalign_multibranch2;
                logsumexp(&mut sum2, sum_seqalign_multibranch2);
                sums2.sum_0ormore_pairmatches = sum2;
                if has_valid_sums(&sums2) {
                    sums_mat2.insert(pos_pair2, sums2);
                }
                // Consume the match at (u, v) and combine with the base-pair terms above.
                let term = sum_multibranch2 + loopmatch_score_multibranch;
                logsumexp(&mut sum_multibranch, term);
                sums.sum_multibranch = sum_multibranch;
                logsumexp(&mut sum, sum_multibranch);
                let term = sum_1st_pairmatches2 + loopmatch_score_multibranch;
                logsumexp(&mut sum_1st_pairmatches, term);
                sums.sum_1st_pairmatches = sum_1st_pairmatches;
                logsumexp(&mut sum, sum_1st_pairmatches);
                sums.sum_1ormore_pairmatches = sum;
                let sum_seqalign_multibranch = sum_seqalign_multibranch2 + loopmatch_score_multibranch;
                sums.sum_seqalign_multibranch = sum_seqalign_multibranch;
                logsumexp(&mut sum, sum_seqalign_multibranch);
                sums.sum_0ormore_pairmatches = sum;
                let sum_seqalign = sum_seqalign2 + loopmatch_score;
                sums.sum_seqalign = sum_seqalign;
                if has_valid_sums(&sums) {
                    sums_mat.insert(pos_pair, sums);
                }
            }
        }
    }
    let mut final_sum_seqalign = NEG_INFINITY;
    let mut final_sum_multibranch = final_sum_seqalign;
    // Forward only: sums at (j - 1, l - 1) including the transition into the closing
    // base pair (j, l), via a match or a trailing gap run in either sequence. The result
    // is also cached in `sums_mat2` under (j - 1, l - 1).
    if computes_forward_sums {
        let pos_pair2 = rightmost_pos_pair;
        let long_pos_pair2 = (
            pos_pair2.0.to_usize().unwrap(),
            pos_pair2.1.to_usize().unwrap(),
        );
        if let Some(x) = sums_mat.get(&pos_pair2) {
            let y = x.sum_multibranch + alignfold_scores.match2match_score;
            logsumexp(&mut final_sum_multibranch, y);
            let y = x.sum_seqalign + alignfold_scores.match2match_score;
            logsumexp(&mut final_sum_seqalign, y);
        }
        if let Some(x) = matchable_poss.get(&pos_pair2.0) {
            for &x in x {
                if x >= pos_pair2.1 {
                    continue;
                }
                if let Some(y) = sums_mat.get(&(pos_pair2.0, x)) {
                    let long_x = x.to_usize().unwrap();
                    let z = range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
                        + alignfold_scores.match2insert_score;
                    let a = range_insert_scores.insert_scores_multibranch2[long_x + 1][long_pos_pair2.1]
                        + alignfold_scores.match2insert_score;
                    let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
                    logsumexp(&mut final_sum_multibranch, x);
                    let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
                    logsumexp(&mut final_sum_seqalign, x);
                }
            }
        }
        if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
            for &x in x {
                if x >= pos_pair2.0 {
                    continue;
                }
                if let Some(y) = sums_mat.get(&(x, pos_pair2.1)) {
                    let long_x = x.to_usize().unwrap();
                    let z = range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
                        + alignfold_scores.match2insert_score;
                    let a = range_insert_scores.insert_scores_multibranch[long_x + 1][long_pos_pair2.0]
                        + alignfold_scores.match2insert_score;
                    let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
                    logsumexp(&mut final_sum_multibranch, x);
                    let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
                    logsumexp(&mut final_sum_seqalign, x);
                }
            }
        }
        let mut sums2 = LoopSums::new();
        sums2.sum_multibranch = final_sum_multibranch;
        sums2.sum_seqalign = final_sum_seqalign;
        sums_mat2.insert(pos_pair2, sums2);
    }
    (final_sum_seqalign, final_sum_multibranch)
}
2936
2937pub fn get_2loop_sums<T>(inputs: Inputs2loopSumsGetter<T>) -> (SparseSumMat<T>, SparseSumMat<T>)
2938where
2939 T: HashIndex,
2940{
2941 let (
2942 seq_pair,
2943 alignfold_scores,
2944 match_probs,
2945 pos_quad,
2946 alignfold_sums,
2947 computes_forward_sums,
2948 pos_pairs,
2949 fold_scores_pair,
2950 range_insert_scores,
2951 matchable_poss,
2952 matchable_poss2,
2953 ) = inputs;
2954 let &(i, j, k, l) = pos_quad;
2955 let leftmost_pos_pair = if computes_forward_sums {
2956 (i, k)
2957 } else {
2958 (i + T::one(), k + T::one())
2959 };
2960 let rightmost_pos_pair = if computes_forward_sums {
2961 (j - T::one(), l - T::one())
2962 } else {
2963 (j, l)
2964 };
2965 let sums_hashed_poss = if computes_forward_sums {
2966 &alignfold_sums.forward_sums_hashed_poss2
2967 } else {
2968 &alignfold_sums.backward_sums_hashed_poss2
2969 };
2970 let sums = &sums_hashed_poss[&if computes_forward_sums {
2971 leftmost_pos_pair
2972 } else {
2973 rightmost_pos_pair
2974 }];
2975 let iter: Poss<T> = if computes_forward_sums {
2976 range(i, j).collect()
2977 } else {
2978 range_inclusive(i + T::one(), j).rev().collect()
2979 };
2980 let iter2: Poss<T> = if computes_forward_sums {
2981 range(k, l).collect()
2982 } else {
2983 range_inclusive(k + T::one(), l).rev().collect()
2984 };
2985 let mut sum_mat = SparseSumMat::<T>::default();
2986 let mut sum_mat2 = sum_mat.clone();
2987 for &u in iter.iter() {
2988 let long_u = u.to_usize().unwrap();
2989 let base = seq_pair.0[long_u];
2990 for &v in iter2.iter() {
2991 let pos_pair = (u, v);
2992 if (computes_forward_sums && u == i && v == k) || (!computes_forward_sums && u == j && v == l)
2993 {
2994 continue;
2995 }
2996 let long_v = v.to_usize().unwrap();
2997 let mut sum = NEG_INFINITY;
2998 if let Some(x) = pos_pairs.get(&pos_pair) {
2999 for &(m, n) in x {
3000 if computes_forward_sums {
3001 if !(i < m && k < n) {
3002 continue;
3003 }
3004 } else if !(m < j && n < l) {
3005 continue;
3006 }
3007 let pos_pair2 = if computes_forward_sums {
3008 (m - T::one(), n - T::one())
3009 } else {
3010 (m + T::one(), n + T::one())
3011 };
3012 let pos_quad2 = if computes_forward_sums {
3013 (m, u, n, v)
3014 } else {
3015 (u, m, v, n)
3016 };
3017 if pos_quad2.0 - i - T::one() + j - pos_quad2.1 - T::one()
3018 > T::from_usize(MAX_LOOP_LEN).unwrap()
3019 {
3020 continue;
3021 }
3022 if pos_quad2.2 - k - T::one() + l - pos_quad2.3 - T::one()
3023 > T::from_usize(MAX_LOOP_LEN).unwrap()
3024 {
3025 continue;
3026 }
3027 if let Some(&x) = alignfold_sums.sums_close.get(&pos_quad2) {
3028 if let Some(y) = sums.get(&pos_pair2) {
3029 let twoloop_score =
3030 fold_scores_pair.0.twoloop_scores[&(i, j, pos_quad2.0, pos_quad2.1)];
3031 let twoloop_score2 =
3032 fold_scores_pair.1.twoloop_scores[&(k, l, pos_quad2.2, pos_quad2.3)];
3033 let y = x + y.sum_seqalign + twoloop_score + twoloop_score2;
3034 logsumexp(&mut sum, y);
3035 }
3036 }
3037 }
3038 }
3039 let pos_pair2 = if computes_forward_sums {
3040 (u - T::one(), v - T::one())
3041 } else {
3042 (u + T::one(), v + T::one())
3043 };
3044 let long_pos_pair2 = (
3045 pos_pair2.0.to_usize().unwrap(),
3046 pos_pair2.1.to_usize().unwrap(),
3047 );
3048 let base2 = seq_pair.1[long_v];
3049 if match_probs.contains_key(&pos_pair) {
3050 let mut sum2 = NEG_INFINITY;
3051 let loopmatch_score = alignfold_scores.match_scores[base][base2];
3052 if let Some(&x) = sum_mat.get(&pos_pair2) {
3053 let x = x + alignfold_scores.match2match_score;
3054 logsumexp(&mut sum2, x);
3055 }
3056 if let Some(x) = matchable_poss.get(&pos_pair2.0) {
3057 for &x in x {
3058 if computes_forward_sums && x >= pos_pair2.1
3059 || (!computes_forward_sums && x <= pos_pair2.1)
3060 {
3061 continue;
3062 }
3063 if let Some(&y) = sum_mat.get(&(pos_pair2.0, x)) {
3064 let long_x = x.to_usize().unwrap();
3065 let z = if computes_forward_sums {
3066 range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
3067 } else {
3068 range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
3069 } + alignfold_scores.match2insert_score;
3070 let z = y + alignfold_scores.match2insert_score + z;
3071 logsumexp(&mut sum2, z);
3072 }
3073 }
3074 }
3075 if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
3076 for &x in x {
3077 if computes_forward_sums && x >= pos_pair2.0
3078 || (!computes_forward_sums && x <= pos_pair2.0)
3079 {
3080 continue;
3081 }
3082 if let Some(&y) = sum_mat.get(&(x, pos_pair2.1)) {
3083 let long_x = x.to_usize().unwrap();
3084 let z = if computes_forward_sums {
3085 range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
3086 } else {
3087 range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
3088 } + alignfold_scores.match2insert_score;
3089 let z = y + alignfold_scores.match2insert_score + z;
3090 logsumexp(&mut sum2, z);
3091 }
3092 }
3093 }
3094 if sum2 > NEG_INFINITY {
3095 sum_mat2.insert(pos_pair2, sum2);
3096 }
3097 let term = sum2 + loopmatch_score;
3098 logsumexp(&mut sum, term);
3099 if sum > NEG_INFINITY {
3100 sum_mat.insert(pos_pair, sum);
3101 }
3102 }
3103 }
3104 }
3105 (sum_mat, sum_mat2)
3106}
3107
3108pub fn get_alignfold_probs<T>(inputs: InputsAlignfoldProbsGetter<T>) -> AlignfoldProbMats<T>
3109where
3110 T: HashIndex,
3111{
3112 let (
3113 seq_pair,
3114 alignfold_scores,
3115 max_basepair_span_pair,
3116 match_probs,
3117 alignfold_sums,
3118 produces_struct_profs,
3119 global_sum,
3120 trains_alignfold_scores,
3121 alignfold_counts_expected,
3122 pos_quads_hashed_lens,
3123 fold_scores_pair,
3124 produces_match_probs,
3125 forward_pos_pairs,
3126 backward_pos_pairs,
3127 range_insert_scores,
3128 matchable_poss,
3129 matchable_poss2,
3130 ) = inputs;
3131 let seq_len_pair = (
3132 T::from_usize(seq_pair.0.len()).unwrap(),
3133 T::from_usize(seq_pair.1.len()).unwrap(),
3134 );
3135 let mut alignfold_outside_sums = SumMat4d::<T>::default();
3136 let mut alignfold_probs = AlignfoldProbMats::<T>::new(&(
3137 seq_len_pair.0.to_usize().unwrap(),
3138 seq_len_pair.1.to_usize().unwrap(),
3139 ));
3140 let leftmost_pos_pair = (T::zero(), T::zero());
3141 let rightmost_pos_pair = (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one());
3142 let mut prob_coeffs_multibranch = SumMat4d::<T>::default();
3143 let mut prob_coeffs_multibranch2 = prob_coeffs_multibranch.clone();
3144 for substr_len in range_inclusive(
3145 T::from_usize(if trains_alignfold_scores {
3146 2
3147 } else {
3148 MIN_SPAN_HAIRPIN_CLOSE
3149 })
3150 .unwrap(),
3151 max_basepair_span_pair.0,
3152 )
3153 .rev()
3154 {
3155 for substr_len2 in range_inclusive(
3156 T::from_usize(if trains_alignfold_scores {
3157 2
3158 } else {
3159 MIN_SPAN_HAIRPIN_CLOSE
3160 })
3161 .unwrap(),
3162 max_basepair_span_pair.1,
3163 )
3164 .rev()
3165 {
3166 if let Some(pos_pairs) = pos_quads_hashed_lens.get(&(substr_len, substr_len2)) {
3167 for &(i, k) in pos_pairs {
3168 let (j, l) = (i + substr_len - T::one(), k + substr_len2 - T::one());
3169 let pos_quad = (i, j, k, l);
3170 if let Some(&sum_close) = alignfold_sums.sums_close.get(&pos_quad) {
3171 let (long_i, long_j, long_k, long_l) = (
3172 i.to_usize().unwrap(),
3173 j.to_usize().unwrap(),
3174 k.to_usize().unwrap(),
3175 l.to_usize().unwrap(),
3176 );
3177 let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
3178 let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
3179 let mismatch_pair = (seq_pair.0[long_i - 1], seq_pair.0[long_j + 1]);
3180 let mismatch_pair2 = (seq_pair.1[long_k - 1], seq_pair.1[long_l + 1]);
3181 let prob_coeff = sum_close - global_sum;
3182 let mut sum = NEG_INFINITY;
3183 let mut forward_term = sum;
3184 let mut forward_term_match = sum;
3185 let mut backward_term = sum;
3186 let pos_pair2 = (i - T::one(), k - T::one());
3187 if let Some(&x) = alignfold_sums.forward_sums_external2.get(&pos_pair2) {
3188 logsumexp(&mut forward_term, x);
3189 }
3190 if trains_alignfold_scores {
3191 if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
3192 let begins_sum = pos_pair2 == leftmost_pos_pair;
3193 let x = x
3194 + if begins_sum {
3195 alignfold_scores.init_match_score
3196 } else {
3197 alignfold_scores.match2match_score
3198 };
3199 logsumexp(&mut forward_term_match, x);
3200 }
3201 }
3202 let pos_pair2 = (j + T::one(), l + T::one());
3203 if let Some(&x) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
3204 logsumexp(&mut backward_term, x);
3205 }
3206 let coeff = alignfold_sums.sums_accessible_external[&pos_quad] - sum_close;
3207 if trains_alignfold_scores {
3208 let begins_sum = (i - T::one(), k - T::one()) == leftmost_pos_pair;
3209 let x = prob_coeff + coeff + forward_term_match + backward_term;
3210 if begins_sum {
3211 logsumexp(&mut alignfold_counts_expected.init_match_score, x);
3212 } else {
3213 logsumexp(&mut alignfold_counts_expected.match2match_score, x);
3214 }
3215 }
3216 let sum_external = forward_term + backward_term;
3217 if sum_external > NEG_INFINITY {
3218 sum = coeff + sum_external;
3219 let x = prob_coeff + sum;
3220 if trains_alignfold_scores {
3221 logsumexp(
3223 &mut alignfold_counts_expected.external_score_basepair,
3224 (2. as Prob).ln() + x,
3225 );
3226 logsumexp(
3228 &mut alignfold_counts_expected.helix_close_scores[basepair.1][basepair.0],
3229 x,
3230 );
3231 logsumexp(
3232 &mut alignfold_counts_expected.helix_close_scores[basepair2.1][basepair2.0],
3233 x,
3234 );
3235 if j < seq_len_pair.0 - T::from_usize(2).unwrap() {
3237 logsumexp(
3238 &mut alignfold_counts_expected.dangling_scores_left[basepair.1][basepair.0]
3239 [mismatch_pair.1],
3240 x,
3241 );
3242 }
3243 if i > T::one() {
3244 logsumexp(
3245 &mut alignfold_counts_expected.dangling_scores_right[basepair.1][basepair.0]
3246 [mismatch_pair.0],
3247 x,
3248 );
3249 }
3250 if l < seq_len_pair.1 - T::from_usize(2).unwrap() {
3251 logsumexp(
3252 &mut alignfold_counts_expected.dangling_scores_left[basepair2.1][basepair2.0]
3253 [mismatch_pair2.1],
3254 x,
3255 );
3256 }
3257 if k > T::one() {
3258 logsumexp(
3259 &mut alignfold_counts_expected.dangling_scores_right[basepair2.1][basepair2.0]
3260 [mismatch_pair2.0],
3261 x,
3262 );
3263 }
3264 }
3265 }
3266 for substr_len3 in range_inclusive(
3267 substr_len + T::from_usize(2).unwrap(),
3268 (substr_len + T::from_usize(MAX_LOOP_LEN + 2).unwrap()).min(max_basepair_span_pair.0),
3269 ) {
3270 for substr_len4 in range_inclusive(
3271 substr_len2 + T::from_usize(2).unwrap(),
3272 (substr_len2 + T::from_usize(MAX_LOOP_LEN + 2).unwrap())
3273 .min(max_basepair_span_pair.1),
3274 ) {
3275 if let Some(pos_pairs2) = pos_quads_hashed_lens.get(&(substr_len3, substr_len4)) {
3276 for &(m, o) in pos_pairs2 {
3277 let (n, p) = (m + substr_len3 - T::one(), o + substr_len4 - T::one());
3278 if !(m < i && j < n && o < k && l < p) {
3279 continue;
3280 }
3281 let (long_m, long_n, long_o, long_p) = (
3282 m.to_usize().unwrap(),
3283 n.to_usize().unwrap(),
3284 o.to_usize().unwrap(),
3285 p.to_usize().unwrap(),
3286 );
3287 let loop_len_pair = (long_i - long_m - 1, long_n - long_j - 1);
3288 let loop_len_pair2 = (long_k - long_o - 1, long_p - long_l - 1);
3289 if loop_len_pair.0 + loop_len_pair.1 > MAX_LOOP_LEN {
3290 continue;
3291 }
3292 if loop_len_pair2.0 + loop_len_pair2.1 > MAX_LOOP_LEN {
3293 continue;
3294 }
3295 let basepair3 = (seq_pair.0[long_m], seq_pair.0[long_n]);
3296 let basepair4 = (seq_pair.1[long_o], seq_pair.1[long_p]);
3297 let found_stack = loop_len_pair.0 == 0 && loop_len_pair.1 == 0;
3298 let found_bulge =
3299 (loop_len_pair.0 == 0 || loop_len_pair.1 == 0) && !found_stack;
3300 let mismatch_pair3 = (seq_pair.0[long_m + 1], seq_pair.0[long_n - 1]);
3301 let pos_quad2 = (m, n, o, p);
3302 if let Some(&outside_sum) = alignfold_outside_sums.get(&pos_quad2) {
3303 let found_stack2 = loop_len_pair2.0 == 0 && loop_len_pair2.1 == 0;
3304 let found_bulge2 =
3305 (loop_len_pair2.0 == 0 || loop_len_pair2.1 == 0) && !found_stack2;
3306 let mismatch_pair4 = (seq_pair.1[long_o + 1], seq_pair.1[long_p - 1]);
3307 let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(m, o)];
3308 let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(m, o)];
3309 let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(n, p)];
3310 let mut forward_term = NEG_INFINITY;
3311 let mut forward_term_match = forward_term;
3312 let mut backward_term = forward_term;
3313 let pos_pair2 = (i - T::one(), k - T::one());
3314 if let Some(x) = forward_sums2.get(&pos_pair2) {
3315 logsumexp(&mut forward_term, x.sum_seqalign);
3316 }
3317 if trains_alignfold_scores {
3318 if let Some(x) = forward_sums.get(&pos_pair2) {
3319 let term = x.sum_seqalign + alignfold_scores.match2match_score;
3320 logsumexp(&mut forward_term_match, term);
3321 }
3322 }
3323 let pos_pair2 = (j + T::one(), l + T::one());
3324 if let Some(x) = backward_sums.get(&pos_pair2) {
3325 logsumexp(&mut backward_term, x.sum_seqalign);
3326 }
3327 let pairmatch_score = alignfold_scores.match_scores[basepair3.0][basepair4.0]
3328 + alignfold_scores.match_scores[basepair3.1][basepair4.1];
3329 let twoloop_score = fold_scores_pair.0.twoloop_scores[&(m, n, i, j)];
3330 let twoloop_score2 = fold_scores_pair.1.twoloop_scores[&(o, p, k, l)];
3331 let coeff = pairmatch_score + twoloop_score + twoloop_score2 + outside_sum;
3332 if trains_alignfold_scores {
3333 let x = prob_coeff + coeff + forward_term_match + backward_term;
3334 logsumexp(&mut alignfold_counts_expected.match2match_score, x);
3335 }
3336 let sum_2loop = forward_term + backward_term;
3337 if sum_2loop > NEG_INFINITY {
3338 let sum_2loop = coeff + sum_2loop;
3339 logsumexp(&mut sum, sum_2loop);
3340 let pairmatch_prob_2loop = prob_coeff + sum_2loop;
3341 if produces_struct_profs {
3342 let loop_len_pair = (long_i - long_m - 1, long_n - long_j - 1);
3343 let found_bulge = (loop_len_pair.0 == 0) ^ (loop_len_pair.1 == 0);
3344 let found_interior = loop_len_pair.0 > 0 && loop_len_pair.1 > 0;
3345 for q in long_m + 1..long_i {
3346 if found_bulge {
3347 logsumexp(
3348 &mut alignfold_probs.context_profs_pair.0[(q, CONTEXT_INDEX_BULGE)],
3349 pairmatch_prob_2loop,
3350 );
3351 } else if found_interior {
3352 logsumexp(
3353 &mut alignfold_probs.context_profs_pair.0
3354 [(q, CONTEXT_INDEX_INTERIOR)],
3355 pairmatch_prob_2loop,
3356 );
3357 }
3358 }
3359 for q in long_j + 1..long_n {
3360 if found_bulge {
3361 logsumexp(
3362 &mut alignfold_probs.context_profs_pair.0[(q, CONTEXT_INDEX_BULGE)],
3363 pairmatch_prob_2loop,
3364 );
3365 } else if found_interior {
3366 logsumexp(
3367 &mut alignfold_probs.context_profs_pair.0
3368 [(q, CONTEXT_INDEX_INTERIOR)],
3369 pairmatch_prob_2loop,
3370 );
3371 }
3372 }
3373 let loop_len_pair = (long_k - long_o - 1, long_p - long_l - 1);
3374 let found_bulge = (loop_len_pair.0 == 0) ^ (loop_len_pair.1 == 0);
3375 let found_interior = loop_len_pair.0 > 0 && loop_len_pair.1 > 0;
3376 for q in long_o + 1..long_k {
3377 if found_bulge {
3378 logsumexp(
3379 &mut alignfold_probs.context_profs_pair.1[(q, CONTEXT_INDEX_BULGE)],
3380 pairmatch_prob_2loop,
3381 );
3382 } else if found_interior {
3383 logsumexp(
3384 &mut alignfold_probs.context_profs_pair.1
3385 [(q, CONTEXT_INDEX_INTERIOR)],
3386 pairmatch_prob_2loop,
3387 );
3388 }
3389 }
3390 for q in long_l + 1..long_p {
3391 if found_bulge {
3392 logsumexp(
3393 &mut alignfold_probs.context_profs_pair.1[(q, CONTEXT_INDEX_BULGE)],
3394 pairmatch_prob_2loop,
3395 );
3396 } else if found_interior {
3397 logsumexp(
3398 &mut alignfold_probs.context_profs_pair.1
3399 [(q, CONTEXT_INDEX_INTERIOR)],
3400 pairmatch_prob_2loop,
3401 );
3402 }
3403 }
3404 }
3405 if trains_alignfold_scores {
3406 if found_stack {
3407 let dict_min_stack = get_dict_min_stack(&basepair3, &basepair);
3409 logsumexp(
3410 &mut alignfold_counts_expected.stack_scores[dict_min_stack.0 .0]
3411 [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1],
3412 pairmatch_prob_2loop,
3413 );
3414 } else {
3415 if found_bulge {
3416 let bulge_len = if loop_len_pair.0 == 0 {
3418 loop_len_pair.1
3419 } else {
3420 loop_len_pair.0
3421 };
3422 logsumexp(
3423 &mut alignfold_counts_expected.bulge_scores_len[bulge_len - 1],
3424 pairmatch_prob_2loop,
3425 );
3426 if bulge_len == 1 {
3428 let mismatch = if loop_len_pair.0 == 0 {
3429 mismatch_pair3.1
3430 } else {
3431 mismatch_pair3.0
3432 };
3433 logsumexp(
3434 &mut alignfold_counts_expected.bulge_scores_0x1[mismatch],
3435 pairmatch_prob_2loop,
3436 );
3437 }
3438 } else {
3439 logsumexp(
3441 &mut alignfold_counts_expected.interior_scores_len
3442 [loop_len_pair.0 + loop_len_pair.1 - 2],
3443 pairmatch_prob_2loop,
3444 );
3445 let diff = get_diff(loop_len_pair.0, loop_len_pair.1);
3446 if diff == 0 {
3447 logsumexp(
3448 &mut alignfold_counts_expected.interior_scores_symmetric
3449 [loop_len_pair.0 - 1],
3450 pairmatch_prob_2loop,
3451 );
3452 } else {
3453 logsumexp(
3454 &mut alignfold_counts_expected.interior_scores_asymmetric
3455 [diff - 1],
3456 pairmatch_prob_2loop,
3457 );
3458 }
3459 if loop_len_pair.0 == 1 && loop_len_pair.1 == 1 {
3461 let dict_min_mismatch_pair3 = get_dict_min_pair(&mismatch_pair3);
3462 logsumexp(
3463 &mut alignfold_counts_expected.interior_scores_1x1
3464 [dict_min_mismatch_pair3.0][dict_min_mismatch_pair3.1],
3465 pairmatch_prob_2loop,
3466 );
3467 }
3468 if loop_len_pair.0 <= MAX_INTERIOR_EXPLICIT
3470 && loop_len_pair.1 <= MAX_INTERIOR_EXPLICIT
3471 {
3472 let dict_min_len_pair = get_dict_min_pair(&loop_len_pair);
3473 logsumexp(
3474 &mut alignfold_counts_expected.interior_scores_explicit
3475 [dict_min_len_pair.0 - 1][dict_min_len_pair.1 - 1],
3476 pairmatch_prob_2loop,
3477 );
3478 }
3479 }
3480 logsumexp(
3482 &mut alignfold_counts_expected.helix_close_scores[basepair.1]
3483 [basepair.0],
3484 pairmatch_prob_2loop,
3485 );
3486 logsumexp(
3487 &mut alignfold_counts_expected.helix_close_scores[basepair3.0]
3488 [basepair3.1],
3489 pairmatch_prob_2loop,
3490 );
3491 logsumexp(
3493 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair3.0]
3494 [basepair3.1][mismatch_pair3.0][mismatch_pair3.1],
3495 pairmatch_prob_2loop,
3496 );
3497 logsumexp(
3498 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair.1]
3499 [basepair.0][mismatch_pair.1][mismatch_pair.0],
3500 pairmatch_prob_2loop,
3501 );
3502 }
3503 if found_stack2 {
3504 let dict_min_stack2 = get_dict_min_stack(&basepair4, &basepair2);
3506 logsumexp(
3507 &mut alignfold_counts_expected.stack_scores[dict_min_stack2.0 .0]
3508 [dict_min_stack2.0 .1][dict_min_stack2.1 .0][dict_min_stack2.1 .1],
3509 pairmatch_prob_2loop,
3510 );
3511 } else {
3512 if found_bulge2 {
3513 let bulge_len2 = if loop_len_pair2.0 == 0 {
3515 loop_len_pair2.1
3516 } else {
3517 loop_len_pair2.0
3518 };
3519 logsumexp(
3520 &mut alignfold_counts_expected.bulge_scores_len[bulge_len2 - 1],
3521 pairmatch_prob_2loop,
3522 );
3523 if bulge_len2 == 1 {
3525 let mismatch2 = if loop_len_pair2.0 == 0 {
3526 mismatch_pair4.1
3527 } else {
3528 mismatch_pair4.0
3529 };
3530 logsumexp(
3531 &mut alignfold_counts_expected.bulge_scores_0x1[mismatch2],
3532 pairmatch_prob_2loop,
3533 );
3534 }
3535 } else {
3536 logsumexp(
3538 &mut alignfold_counts_expected.interior_scores_len
3539 [loop_len_pair2.0 + loop_len_pair2.1 - 2],
3540 pairmatch_prob_2loop,
3541 );
3542 let diff2 = get_diff(loop_len_pair2.0, loop_len_pair2.1);
3543 if diff2 == 0 {
3544 logsumexp(
3545 &mut alignfold_counts_expected.interior_scores_symmetric
3546 [loop_len_pair2.0 - 1],
3547 pairmatch_prob_2loop,
3548 );
3549 } else {
3550 logsumexp(
3551 &mut alignfold_counts_expected.interior_scores_asymmetric
3552 [diff2 - 1],
3553 pairmatch_prob_2loop,
3554 );
3555 }
3556 if loop_len_pair2.0 == 1 && loop_len_pair2.1 == 1 {
3558 let dict_min_mismatch_pair4 = get_dict_min_pair(&mismatch_pair4);
3559 logsumexp(
3560 &mut alignfold_counts_expected.interior_scores_1x1
3561 [dict_min_mismatch_pair4.0][dict_min_mismatch_pair4.1],
3562 pairmatch_prob_2loop,
3563 );
3564 }
3565 if loop_len_pair2.0 <= MAX_INTERIOR_EXPLICIT
3567 && loop_len_pair2.1 <= MAX_INTERIOR_EXPLICIT
3568 {
3569 let dict_min_len_pair2 = get_dict_min_pair(&loop_len_pair2);
3570 logsumexp(
3571 &mut alignfold_counts_expected.interior_scores_explicit
3572 [dict_min_len_pair2.0 - 1][dict_min_len_pair2.1 - 1],
3573 pairmatch_prob_2loop,
3574 );
3575 }
3576 }
3577 logsumexp(
3579 &mut alignfold_counts_expected.helix_close_scores[basepair2.1]
3580 [basepair2.0],
3581 pairmatch_prob_2loop,
3582 );
3583 logsumexp(
3584 &mut alignfold_counts_expected.helix_close_scores[basepair4.0]
3585 [basepair4.1],
3586 pairmatch_prob_2loop,
3587 );
3588 logsumexp(
3590 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair4.0]
3591 [basepair4.1][mismatch_pair4.0][mismatch_pair4.1],
3592 pairmatch_prob_2loop,
3593 );
3594 logsumexp(
3595 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair2.1]
3596 [basepair2.0][mismatch_pair2.1][mismatch_pair2.0],
3597 pairmatch_prob_2loop,
3598 );
3599 }
3600 }
3601 }
3602 }
3603 }
3604 }
3605 }
3606 }
3607 let sum_ratio = alignfold_sums.sums_accessible_multibranch[&pos_quad] - sum_close;
3608 for (pos_pair, forward_sums) in &alignfold_sums.forward_sums_hashed_poss {
3609 let &(u, v) = pos_pair;
3610 if !(u < i && v < k) {
3611 continue;
3612 }
3613 let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[pos_pair];
3614 let pos_quad2 = (u, j, v, l);
3615 let mut forward_term = NEG_INFINITY;
3616 let mut forward_term_match = forward_term;
3617 let mut forward_term2 = forward_term;
3618 let mut forward_term_match2 = forward_term;
3619 let pos_pair2 = (i - T::one(), k - T::one());
3620 if let Some(x) = forward_sums2.get(&pos_pair2) {
3621 logsumexp(&mut forward_term, x.sum_1ormore_pairmatches);
3622 logsumexp(&mut forward_term2, x.sum_seqalign_multibranch);
3623 }
3624 if trains_alignfold_scores {
3625 if let Some(x) = forward_sums.get(&pos_pair2) {
3626 let y = x.sum_1ormore_pairmatches + alignfold_scores.match2match_score;
3627 logsumexp(&mut forward_term_match, y);
3628 let y = x.sum_seqalign_multibranch + alignfold_scores.match2match_score;
3629 logsumexp(&mut forward_term_match2, y);
3630 }
3631 }
3632 let mut sum_multibranch = NEG_INFINITY;
3633 if let Some(x) = prob_coeffs_multibranch.get(&pos_quad2) {
3634 let x = x + sum_ratio;
3635 let y = x + forward_term;
3636 logsumexp(&mut sum_multibranch, y);
3637 if trains_alignfold_scores {
3638 let y = prob_coeff + x + forward_term_match;
3639 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
3640 }
3641 }
3642 if let Some(x) = prob_coeffs_multibranch2.get(&pos_quad2) {
3643 let x = x + sum_ratio;
3644 let y = x + forward_term2;
3645 logsumexp(&mut sum_multibranch, y);
3646 if trains_alignfold_scores {
3647 let y = prob_coeff + x + forward_term_match2;
3648 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
3649 }
3650 }
3651 if sum_multibranch > NEG_INFINITY {
3652 logsumexp(&mut sum, sum_multibranch);
3653 let pairmatch_prob_multibranch = prob_coeff + sum_multibranch;
3654 if trains_alignfold_scores {
3655 logsumexp(
3657 &mut alignfold_counts_expected.dangling_scores_left[basepair.1][basepair.0]
3658 [mismatch_pair.1],
3659 pairmatch_prob_multibranch,
3660 );
3661 logsumexp(
3662 &mut alignfold_counts_expected.dangling_scores_right[basepair.1][basepair.0]
3663 [mismatch_pair.0],
3664 pairmatch_prob_multibranch,
3665 );
3666 logsumexp(
3667 &mut alignfold_counts_expected.dangling_scores_left[basepair2.1][basepair2.0]
3668 [mismatch_pair2.1],
3669 pairmatch_prob_multibranch,
3670 );
3671 logsumexp(
3672 &mut alignfold_counts_expected.dangling_scores_right[basepair2.1][basepair2.0]
3673 [mismatch_pair2.0],
3674 pairmatch_prob_multibranch,
3675 );
3676 logsumexp(
3678 &mut alignfold_counts_expected.helix_close_scores[basepair.1][basepair.0],
3679 pairmatch_prob_multibranch,
3680 );
3681 logsumexp(
3682 &mut alignfold_counts_expected.helix_close_scores[basepair2.1][basepair2.0],
3683 pairmatch_prob_multibranch,
3684 );
3685 logsumexp(
3687 &mut alignfold_counts_expected.multibranch_score_base,
3688 (2. as Prob).ln() + pairmatch_prob_multibranch,
3689 );
3690 logsumexp(
3692 &mut alignfold_counts_expected.multibranch_score_basepair,
3693 (2. as Prob).ln() + pairmatch_prob_multibranch,
3694 );
3695 logsumexp(
3697 &mut alignfold_counts_expected.multibranch_score_basepair,
3698 (2. as Prob).ln() + pairmatch_prob_multibranch,
3699 );
3700 }
3701 }
3702 }
3703 if sum > NEG_INFINITY {
3704 alignfold_outside_sums.insert(pos_quad, sum);
3705 let pairmatch_prob = prob_coeff + sum;
3706 if produces_match_probs {
3707 alignfold_probs
3708 .pairmatch_probs
3709 .insert(pos_quad, pairmatch_prob);
3710 match alignfold_probs.match_probs.get_mut(&(i, k)) {
3711 Some(x) => {
3712 logsumexp(x, pairmatch_prob);
3713 }
3714 None => {
3715 alignfold_probs.match_probs.insert((i, k), pairmatch_prob);
3716 }
3717 }
3718 match alignfold_probs.match_probs.get_mut(&(j, l)) {
3719 Some(x) => {
3720 logsumexp(x, pairmatch_prob);
3721 }
3722 None => {
3723 alignfold_probs.match_probs.insert((j, l), pairmatch_prob);
3724 }
3725 }
3726 }
3727 if trains_alignfold_scores {
3728 let dict_min_basepair = get_dict_min_pair(&basepair);
3730 let dict_min_basepair2 = get_dict_min_pair(&basepair2);
3731 logsumexp(
3732 &mut alignfold_counts_expected.basepair_scores[dict_min_basepair.0]
3733 [dict_min_basepair.1],
3734 pairmatch_prob,
3735 );
3736 logsumexp(
3737 &mut alignfold_counts_expected.basepair_scores[dict_min_basepair2.0]
3738 [dict_min_basepair2.1],
3739 pairmatch_prob,
3740 );
3741 let dict_min_match = get_dict_min_pair(&(basepair.0, basepair2.0));
3743 logsumexp(
3744 &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
3745 pairmatch_prob,
3746 );
3747 let dict_min_match = get_dict_min_pair(&(basepair.1, basepair2.1));
3748 logsumexp(
3749 &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
3750 pairmatch_prob,
3751 );
3752 }
3753 match alignfold_probs.basepair_probs_pair.0.get_mut(&(i, j)) {
3754 Some(x) => {
3755 logsumexp(x, pairmatch_prob);
3756 }
3757 None => {
3758 alignfold_probs
3759 .basepair_probs_pair
3760 .0
3761 .insert((i, j), pairmatch_prob);
3762 }
3763 }
3764 match alignfold_probs.basepair_probs_pair.1.get_mut(&(k, l)) {
3765 Some(x) => {
3766 logsumexp(x, pairmatch_prob);
3767 }
3768 None => {
3769 alignfold_probs
3770 .basepair_probs_pair
3771 .1
3772 .insert((k, l), pairmatch_prob);
3773 }
3774 }
3775 if produces_struct_profs {
3776 logsumexp(
3777 &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_BASEPAIR)],
3778 pairmatch_prob,
3779 );
3780 logsumexp(
3781 &mut alignfold_probs.context_profs_pair.0[(long_j, CONTEXT_INDEX_BASEPAIR)],
3782 pairmatch_prob,
3783 );
3784 logsumexp(
3785 &mut alignfold_probs.context_profs_pair.1[(long_k, CONTEXT_INDEX_BASEPAIR)],
3786 pairmatch_prob,
3787 );
3788 logsumexp(
3789 &mut alignfold_probs.context_profs_pair.1[(long_l, CONTEXT_INDEX_BASEPAIR)],
3790 pairmatch_prob,
3791 );
3792 }
3793 let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
3794 + alignfold_scores.match_scores[basepair.1][basepair2.1];
3795 let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
3796 let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
3797 if trains_alignfold_scores {
3798 let mismatch_pair = (seq_pair.0[long_i + 1], seq_pair.0[long_j - 1]);
3799 let mismatch_pair2 = (seq_pair.1[long_k + 1], seq_pair.1[long_l - 1]);
3800 let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(i, k)];
3801 let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
3802 if substr_len.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
3803 && substr_len2.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
3804 {
3805 let mut sum_seqalign = NEG_INFINITY;
3806 let mut sum_seqalign_match = sum_seqalign;
3807 let pos_pair2 = (j - T::one(), l - T::one());
3808 if let Some(x) = forward_sums2.get(&pos_pair2) {
3809 logsumexp(&mut sum_seqalign, x.sum_seqalign);
3810 }
3811 if let Some(x) = forward_sums.get(&pos_pair2) {
3812 let x = x.sum_seqalign + alignfold_scores.match2match_score;
3813 logsumexp(&mut sum_seqalign_match, x);
3814 }
3815 let hairpin_score = fold_scores_pair.0.hairpin_scores[&(i, j)];
3816 let hairpin_score2 = fold_scores_pair.1.hairpin_scores[&(k, l)];
3817 let prob = sum - global_sum
3818 + sum_seqalign_match
3819 + hairpin_score
3820 + hairpin_score2
3821 + pairmatch_score;
3822 logsumexp(&mut alignfold_counts_expected.match2match_score, prob);
3823 let pairmatch_prob_hairpin = sum - global_sum
3824 + sum_seqalign
3825 + hairpin_score
3826 + hairpin_score2
3827 + pairmatch_score;
3828 logsumexp(
3829 &mut alignfold_counts_expected.hairpin_scores_len[long_j - long_i - 1],
3830 pairmatch_prob_hairpin,
3831 );
3832 logsumexp(
3833 &mut alignfold_counts_expected.hairpin_scores_len[long_l - long_k - 1],
3834 pairmatch_prob_hairpin,
3835 );
3836 logsumexp(
3837 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair.0][basepair.1]
3838 [mismatch_pair.0][mismatch_pair.1],
3839 pairmatch_prob_hairpin,
3840 );
3841 logsumexp(
3842 &mut alignfold_counts_expected.terminal_mismatch_scores[basepair2.0]
3843 [basepair2.1][mismatch_pair2.0][mismatch_pair2.1],
3844 pairmatch_prob_hairpin,
3845 );
3846 logsumexp(
3848 &mut alignfold_counts_expected.helix_close_scores[basepair.0][basepair.1],
3849 pairmatch_prob_hairpin,
3850 );
3851 logsumexp(
3852 &mut alignfold_counts_expected.helix_close_scores[basepair2.0][basepair2.1],
3853 pairmatch_prob_hairpin,
3854 );
3855 }
3856 let mut sum_multibranch = NEG_INFINITY;
3857 let mut sum_multibranch_match = sum_multibranch;
3858 if let Some(x) = forward_sums2.get(&pos_pair2) {
3859 logsumexp(&mut sum_multibranch, x.sum_multibranch);
3860 }
3861 if let Some(x) = forward_sums.get(&pos_pair2) {
3862 let x = x.sum_multibranch + alignfold_scores.match2match_score;
3863 logsumexp(&mut sum_multibranch_match, x);
3864 }
3865 let prob = sum - global_sum
3866 + sum_multibranch_match
3867 + multibranch_close_score
3868 + multibranch_close_score2
3869 + pairmatch_score;
3870 logsumexp(&mut alignfold_counts_expected.match2match_score, prob);
3871 let pairmatch_prob_multibranch = sum - global_sum
3872 + sum_multibranch
3873 + multibranch_close_score
3874 + multibranch_close_score2
3875 + pairmatch_score;
3876 logsumexp(
3878 &mut alignfold_counts_expected.dangling_scores_left[basepair.0][basepair.1]
3879 [mismatch_pair.0],
3880 pairmatch_prob_multibranch,
3881 );
3882 logsumexp(
3883 &mut alignfold_counts_expected.dangling_scores_right[basepair.0][basepair.1]
3884 [mismatch_pair.1],
3885 pairmatch_prob_multibranch,
3886 );
3887 logsumexp(
3888 &mut alignfold_counts_expected.dangling_scores_left[basepair2.0][basepair2.1]
3889 [mismatch_pair2.0],
3890 pairmatch_prob_multibranch,
3891 );
3892 logsumexp(
3893 &mut alignfold_counts_expected.dangling_scores_right[basepair2.0][basepair2.1]
3894 [mismatch_pair2.1],
3895 pairmatch_prob_multibranch,
3896 );
3897 logsumexp(
3899 &mut alignfold_counts_expected.helix_close_scores[basepair.0][basepair.1],
3900 pairmatch_prob_multibranch,
3901 );
3902 logsumexp(
3903 &mut alignfold_counts_expected.helix_close_scores[basepair2.0][basepair2.1],
3904 pairmatch_prob_multibranch,
3905 );
3906 }
3907 let coeff =
3908 sum + pairmatch_score + multibranch_close_score + multibranch_close_score2;
3909 let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
3910 for pos_pair in match_probs.keys() {
3911 let &(u, v) = pos_pair;
3912 if !(i < u && u < j && k < v && v < l) {
3913 continue;
3914 }
3915 let mut backward_term = NEG_INFINITY;
3916 let mut backward_term2 = backward_term;
3917 let pos_pair2 = (u + T::one(), v + T::one());
3918 if let Some(x) = backward_sums.get(&pos_pair2) {
3919 logsumexp(&mut backward_term, x.sum_0ormore_pairmatches);
3920 logsumexp(&mut backward_term2, x.sum_1ormore_pairmatches);
3921 }
3922 let pos_quad2 = (i, u, k, v);
3923 let x = coeff + backward_term;
3924 match prob_coeffs_multibranch.get_mut(&pos_quad2) {
3925 Some(y) => {
3926 logsumexp(y, x);
3927 }
3928 None => {
3929 prob_coeffs_multibranch.insert(pos_quad2, x);
3930 }
3931 }
3932 let x = coeff + backward_term2;
3933 match prob_coeffs_multibranch2.get_mut(&pos_quad2) {
3934 Some(y) => {
3935 logsumexp(y, x);
3936 }
3937 None => {
3938 prob_coeffs_multibranch2.insert(pos_quad2, x);
3939 }
3940 }
3941 }
3942 }
3943 }
3944 }
3945 }
3946 }
3947 }
3948 for x in alignfold_probs.basepair_probs_pair.0.values_mut() {
3949 *x = expf(*x);
3950 }
3951 for x in alignfold_probs.basepair_probs_pair.1.values_mut() {
3952 *x = expf(*x);
3953 }
3954 let needs_twoloop_sums = produces_match_probs || trains_alignfold_scores;
3955 let needs_indel_info = produces_struct_profs || trains_alignfold_scores;
3956 if produces_struct_profs || needs_twoloop_sums {
3957 let mut unpair_probs_range_external =
3958 (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3959 let mut unpair_probs_range_hairpin =
3960 (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3961 let mut unpair_probs_range = (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3962 for u in range(T::zero(), seq_len_pair.0 - T::one()) {
3963 let long_u = u.to_usize().unwrap();
3964 let base = seq_pair.0[long_u];
3965 for v in range(T::zero(), seq_len_pair.1 - T::one()) {
3966 let pos_pair = (u, v);
3967 let long_v = v.to_usize().unwrap();
3968 let base2 = seq_pair.1[long_v];
3969 let pos_pair2 = (u + T::one(), v + T::one());
3970 let long_pos_pair2 = (
3971 pos_pair2.0.to_usize().unwrap(),
3972 pos_pair2.1.to_usize().unwrap(),
3973 );
3974 let dict_min_match = get_dict_min_pair(&(base, base2));
3975 if match_probs.contains_key(&pos_pair) {
3976 let pos_pair_loopmatch = (u - T::one(), v - T::one());
3977 let mut loopmatch_prob_external = NEG_INFINITY;
3978 let loopmatch_score = alignfold_scores.match_scores[base][base2]
3979 + 2. * alignfold_scores.external_score_unpair;
3980 let mut backward_term = NEG_INFINITY;
3981 if let Some(&x) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
3982 logsumexp(&mut backward_term, x);
3983 }
3984 if let Some(&x) = alignfold_sums
3985 .forward_sums_external2
3986 .get(&pos_pair_loopmatch)
3987 {
3988 let term = loopmatch_score + x + backward_term - global_sum;
3989 logsumexp(&mut loopmatch_prob_external, term);
3990 }
3991 if produces_struct_profs {
3992 logsumexp(
3993 &mut alignfold_probs.context_profs_pair.0[(long_u, CONTEXT_INDEX_EXTERNAL)],
3994 loopmatch_prob_external,
3995 );
3996 logsumexp(
3997 &mut alignfold_probs.context_profs_pair.1[(long_v, CONTEXT_INDEX_EXTERNAL)],
3998 loopmatch_prob_external,
3999 );
4000 }
4001 if produces_match_probs {
4002 match alignfold_probs.loopmatch_probs.get_mut(&pos_pair) {
4003 Some(x) => {
4004 logsumexp(x, loopmatch_prob_external);
4005 }
4006 None => {
4007 alignfold_probs
4008 .loopmatch_probs
4009 .insert(pos_pair, loopmatch_prob_external);
4010 }
4011 }
4012 }
4013 if trains_alignfold_scores {
4014 logsumexp(
4015 &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
4016 loopmatch_prob_external,
4017 );
4018 logsumexp(
4019 &mut alignfold_counts_expected.external_score_unpair,
4020 (2. as Prob).ln() + loopmatch_prob_external,
4021 );
4022 if let Some(&x) = alignfold_sums
4023 .forward_sums_external
4024 .get(&pos_pair_loopmatch)
4025 {
4026 let begins_sum = pos_pair_loopmatch == leftmost_pos_pair;
4027 let x = x
4028 + if begins_sum {
4029 alignfold_scores.init_match_score
4030 } else {
4031 alignfold_scores.match2match_score
4032 };
4033 let x = loopmatch_score + x + backward_term - global_sum;
4034 if begins_sum {
4035 logsumexp(&mut alignfold_counts_expected.init_match_score, x);
4036 } else {
4037 logsumexp(&mut alignfold_counts_expected.match2match_score, x);
4038 }
4039 }
4040 }
4041 }
4042 if needs_indel_info {
4043 if let Some(&sum) = alignfold_sums.forward_sums_external.get(&pos_pair) {
4044 let begins_sum = pos_pair == leftmost_pos_pair;
4045 let forward_term = sum
4046 + if begins_sum {
4047 alignfold_scores.init_insert_score
4048 } else {
4049 alignfold_scores.match2insert_score
4050 };
4051 if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4052 for &x in x {
4053 if x <= pos_pair2.1 {
4054 continue;
4055 }
4056 let pos_pair3 = (pos_pair2.0, x);
4057 if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
4058 let long_x = x.to_usize().unwrap();
4059 let ends_sum = pos_pair3 == rightmost_pos_pair;
4060 let z = range_insert_scores.insert_scores_external2[long_pos_pair2.1][long_x - 1]
4061 + if ends_sum {
4062 0.
4063 } else {
4064 alignfold_scores.match2insert_score
4065 };
4066 let z = forward_term + y + z - global_sum;
4067 if trains_alignfold_scores {
4068 if begins_sum {
4069 logsumexp(&mut alignfold_counts_expected.init_insert_score, z);
4070 } else {
4071 logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4072 }
4073 if !ends_sum {
4074 logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4075 }
4076 logsumexp(
4077 &mut alignfold_counts_expected.external_score_unpair,
4078 ((long_x - long_pos_pair2.1) as Prob).ln() + z,
4079 );
4080 if pos_pair2.1 < x - T::one() {
4081 logsumexp(
4082 &mut alignfold_counts_expected.insert_extend_score,
4083 ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + z,
4084 );
4085 }
4086 }
4087 let pos_pair4 = (pos_pair2.1, x - T::one());
4088 if produces_struct_profs {
4089 match unpair_probs_range_external.1.get_mut(&pos_pair4) {
4090 Some(x) => {
4091 logsumexp(x, z);
4092 }
4093 None => {
4094 unpair_probs_range_external.1.insert(pos_pair4, z);
4095 }
4096 }
4097 }
4098 if trains_alignfold_scores {
4099 match unpair_probs_range.1.get_mut(&pos_pair4) {
4100 Some(x) => {
4101 logsumexp(x, z);
4102 }
4103 None => {
4104 unpair_probs_range.1.insert(pos_pair4, z);
4105 }
4106 }
4107 }
4108 }
4109 }
4110 }
4111 if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4112 for &x in x {
4113 if x <= pos_pair2.0 {
4114 continue;
4115 }
4116 let pos_pair3 = (x, pos_pair2.1);
4117 if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
4118 let long_x = x.to_usize().unwrap();
4119 let ends_sum = pos_pair3 == rightmost_pos_pair;
4120 let z = range_insert_scores.insert_scores_external[long_pos_pair2.0][long_x - 1]
4121 + if ends_sum {
4122 0.
4123 } else {
4124 alignfold_scores.match2insert_score
4125 };
4126 let z = forward_term + y + z - global_sum;
4127 if trains_alignfold_scores {
4128 if begins_sum {
4129 logsumexp(&mut alignfold_counts_expected.init_insert_score, z);
4130 } else {
4131 logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4132 }
4133 if !ends_sum {
4134 logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4135 }
4136 logsumexp(
4137 &mut alignfold_counts_expected.external_score_unpair,
4138 ((long_x - long_pos_pair2.0) as Prob).ln() + z,
4139 );
4140 if pos_pair2.0 < x - T::one() {
4141 logsumexp(
4142 &mut alignfold_counts_expected.insert_extend_score,
4143 ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + z,
4144 );
4145 }
4146 }
4147 let pos_pair4 = (pos_pair2.0, x - T::one());
4148 if produces_struct_profs {
4149 match unpair_probs_range_external.0.get_mut(&pos_pair4) {
4150 Some(x) => {
4151 logsumexp(x, z);
4152 }
4153 None => {
4154 unpair_probs_range_external.0.insert(pos_pair4, z);
4155 }
4156 }
4157 }
4158 if trains_alignfold_scores {
4159 match unpair_probs_range.0.get_mut(&pos_pair4) {
4160 Some(x) => {
4161 logsumexp(x, z);
4162 }
4163 None => {
4164 unpair_probs_range.0.insert(pos_pair4, z);
4165 }
4166 }
4167 }
4168 }
4169 }
4170 }
4171 }
4172 }
4173 }
4174 }
4175 for (pos_quad, &outside_sum) in &alignfold_outside_sums {
4176 let (i, j, k, l) = *pos_quad;
4177 let (long_i, long_j, long_k, long_l) = (
4178 i.to_usize().unwrap(),
4179 j.to_usize().unwrap(),
4180 k.to_usize().unwrap(),
4181 l.to_usize().unwrap(),
4182 );
4183 let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
4184 let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
4185 let hairpin_score = if j - i - T::one() <= T::from_usize(MAX_LOOP_LEN).unwrap() {
4186 fold_scores_pair.0.hairpin_scores[&(i, j)]
4187 } else {
4188 NEG_INFINITY
4189 };
4190 let hairpin_score2 = if l - k - T::one() <= T::from_usize(MAX_LOOP_LEN).unwrap() {
4191 fold_scores_pair.1.hairpin_scores[&(k, l)]
4192 } else {
4193 NEG_INFINITY
4194 };
4195 let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
4196 let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
4197 let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
4198 + alignfold_scores.match_scores[basepair.1][basepair2.1];
4199 let prob_coeff = outside_sum - global_sum + pairmatch_score;
4200 let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(i, k)];
4201 let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
4202 let backward_sums = &alignfold_sums.backward_sums_hashed_poss[&(j, l)];
4203 let backward_sums2 = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
4204 let (forward_sums_2loop, forward_sums_2loop2) = if needs_twoloop_sums {
4205 let computes_forward_sums = true;
4206 get_2loop_sums((
4207 seq_pair,
4208 alignfold_scores,
4209 match_probs,
4210 pos_quad,
4211 alignfold_sums,
4212 computes_forward_sums,
4213 forward_pos_pairs,
4214 fold_scores_pair,
4215 range_insert_scores,
4216 matchable_poss,
4217 matchable_poss2,
4218 ))
4219 } else {
4220 (SparseSumMat::<T>::default(), SparseSumMat::<T>::default())
4221 };
4222 let (backward_sums_2loop, backward_sums_2loop2) = if needs_twoloop_sums {
4223 let computes_forward_sums = false;
4224 get_2loop_sums((
4225 seq_pair,
4226 alignfold_scores,
4227 match_probs,
4228 pos_quad,
4229 alignfold_sums,
4230 computes_forward_sums,
4231 backward_pos_pairs,
4232 fold_scores_pair,
4233 range_insert_scores,
4234 matchable_poss,
4235 matchable_poss2,
4236 ))
4237 } else {
4238 (SparseSumMat::<T>::default(), SparseSumMat::<T>::default())
4239 };
4240 if trains_alignfold_scores {
4241 let pos_pair2 = (j - T::one(), l - T::one());
4242 if let Some(&x) = forward_sums_2loop.get(&pos_pair2) {
4243 let x = prob_coeff + x + alignfold_scores.match2match_score;
4244 logsumexp(&mut alignfold_counts_expected.match2match_score, x);
4245 }
4246 }
4247 for u in range_inclusive(i, j) {
4248 let long_u = u.to_usize().unwrap();
4249 let base = seq_pair.0[long_u];
4250 for v in range_inclusive(k, l) {
4251 let pos_pair = (u, v);
4252 let long_v = v.to_usize().unwrap();
4253 let base2 = seq_pair.1[long_v];
4254 let pos_pair2 = (u + T::one(), v + T::one());
4255 let long_pos_pair2 = (
4256 pos_pair2.0.to_usize().unwrap(),
4257 pos_pair2.1.to_usize().unwrap(),
4258 );
4259 let dict_min_match = get_dict_min_pair(&(base, base2));
4260 let mut backward_term_seqalign = NEG_INFINITY;
4261 let mut backward_term_multibranch = backward_term_seqalign;
4262 let mut backward_term_1ormore_pairmatches = backward_term_seqalign;
4263 let mut backward_term_0ormore_pairmatches = backward_term_seqalign;
4264 let mut backward_term_2loop = backward_term_seqalign;
4265 if let Some(x) = backward_sums2.get(&pos_pair2) {
4266 logsumexp(&mut backward_term_seqalign, x.sum_seqalign);
4267 if needs_twoloop_sums {
4268 logsumexp(&mut backward_term_multibranch, x.sum_multibranch);
4269 logsumexp(
4270 &mut backward_term_1ormore_pairmatches,
4271 x.sum_1ormore_pairmatches,
4272 );
4273 logsumexp(
4274 &mut backward_term_0ormore_pairmatches,
4275 x.sum_0ormore_pairmatches,
4276 );
4277 }
4278 }
4279 if needs_twoloop_sums {
4280 if let Some(&x) = backward_sums_2loop2.get(&pos_pair2) {
4281 logsumexp(&mut backward_term_2loop, x);
4282 }
4283 }
4284 let prob_coeff_hairpin = prob_coeff + hairpin_score + hairpin_score2;
4285 let prob_coeff_multibranch =
4286 prob_coeff + multibranch_close_score + multibranch_close_score2;
4287 if match_probs.contains_key(&pos_pair) {
4288 let pos_pair_loopmatch = (u - T::one(), v - T::one());
4289 let loopmatch_score = alignfold_scores.match_scores[base][base2];
4290 let loopmatch_score_multibranch =
4291 loopmatch_score + 2. * alignfold_scores.multibranch_score_unpair;
4292 let mut loopmatch_prob_hairpin = NEG_INFINITY;
4293 let mut loopmatch_prob_multibranch = loopmatch_prob_hairpin;
4294 let mut loopmatch_prob_2loop = loopmatch_prob_hairpin;
4295 if let Some(x) = forward_sums2.get(&pos_pair_loopmatch) {
4296 let y =
4297 prob_coeff_hairpin + loopmatch_score + x.sum_seqalign + backward_term_seqalign;
4298 logsumexp(&mut loopmatch_prob_hairpin, y);
4299 if needs_twoloop_sums {
4300 let y = prob_coeff_multibranch
4301 + loopmatch_score_multibranch
4302 + x.sum_seqalign_multibranch
4303 + backward_term_multibranch;
4304 logsumexp(&mut loopmatch_prob_multibranch, y);
4305 let y = prob_coeff_multibranch
4306 + loopmatch_score_multibranch
4307 + x.sum_1st_pairmatches
4308 + backward_term_1ormore_pairmatches;
4309 logsumexp(&mut loopmatch_prob_multibranch, y);
4310 let y = prob_coeff_multibranch
4311 + loopmatch_score_multibranch
4312 + x.sum_multibranch
4313 + backward_term_0ormore_pairmatches;
4314 logsumexp(&mut loopmatch_prob_multibranch, y);
4315 let y = prob_coeff + loopmatch_score + x.sum_seqalign + backward_term_2loop;
4316 logsumexp(&mut loopmatch_prob_2loop, y);
4317 }
4318 }
4319 if produces_struct_profs {
4320 logsumexp(
4321 &mut alignfold_probs.context_profs_pair.0[(long_u, CONTEXT_INDEX_HAIRPIN)],
4322 loopmatch_prob_hairpin,
4323 );
4324 logsumexp(
4325 &mut alignfold_probs.context_profs_pair.1[(long_v, CONTEXT_INDEX_HAIRPIN)],
4326 loopmatch_prob_hairpin,
4327 );
4328 }
4329 if needs_twoloop_sums {
4330 if let Some(&x) = forward_sums_2loop2.get(&pos_pair_loopmatch) {
4331 let x = prob_coeff + loopmatch_score + x + backward_term_seqalign;
4332 logsumexp(&mut loopmatch_prob_2loop, x);
4333 }
4334 let mut prob = NEG_INFINITY;
4335 logsumexp(&mut prob, loopmatch_prob_hairpin);
4336 logsumexp(&mut prob, loopmatch_prob_multibranch);
4337 logsumexp(&mut prob, loopmatch_prob_2loop);
4338 if produces_match_probs {
4339 match alignfold_probs.loopmatch_probs.get_mut(&pos_pair) {
4340 Some(x) => {
4341 logsumexp(x, prob);
4342 }
4343 None => {
4344 alignfold_probs.loopmatch_probs.insert(pos_pair, prob);
4345 }
4346 }
4347 }
4348 if trains_alignfold_scores {
4349 logsumexp(
4350 &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
4351 prob,
4352 );
4353 logsumexp(
4354 &mut alignfold_counts_expected.multibranch_score_unpair,
4355 (2. as Prob).ln() + loopmatch_prob_multibranch,
4356 );
4357 }
4358 }
4359 if let Some(x) = forward_sums.get(&pos_pair_loopmatch) {
4360 let y = prob_coeff_hairpin
4361 + loopmatch_score
4362 + x.sum_seqalign
4363 + alignfold_scores.match2match_score
4364 + backward_term_seqalign;
4365 if trains_alignfold_scores {
4366 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4367 }
4368 if needs_twoloop_sums {
4369 let y = prob_coeff_multibranch
4370 + loopmatch_score_multibranch
4371 + x.sum_seqalign_multibranch
4372 + alignfold_scores.match2match_score
4373 + backward_term_multibranch;
4374 if trains_alignfold_scores {
4375 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4376 }
4377 let y = prob_coeff_multibranch
4378 + loopmatch_score_multibranch
4379 + x.sum_1st_pairmatches
4380 + alignfold_scores.match2match_score
4381 + backward_term_1ormore_pairmatches;
4382 if trains_alignfold_scores {
4383 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4384 }
4385 let y = prob_coeff_multibranch
4386 + loopmatch_score_multibranch
4387 + x.sum_multibranch
4388 + alignfold_scores.match2match_score
4389 + backward_term_0ormore_pairmatches;
4390 if trains_alignfold_scores {
4391 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4392 }
4393 let y = prob_coeff
4394 + loopmatch_score
4395 + x.sum_seqalign
4396 + alignfold_scores.match2match_score
4397 + backward_term_2loop;
4398 if trains_alignfold_scores {
4399 logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4400 }
4401 }
4402 }
4403 if needs_indel_info {
4404 if let Some(sums) = forward_sums.get(&pos_pair) {
4405 let forward_term_seqalign = sums.sum_seqalign + alignfold_scores.match2insert_score;
4406 let forward_term_seqalign_multibranch =
4407 sums.sum_seqalign_multibranch + alignfold_scores.match2insert_score;
4408 let forward_term_1st_pairmatches =
4409 sums.sum_1st_pairmatches + alignfold_scores.match2insert_score;
4410 let forward_term_multibranch =
4411 sums.sum_multibranch + alignfold_scores.match2insert_score;
4412 if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4413 for &x in x {
4414 if x <= pos_pair2.1 {
4415 continue;
4416 }
4417 let mut insert_prob = NEG_INFINITY;
4418 let mut insert_prob_multibranch = insert_prob;
4419 let long_x = x.to_usize().unwrap();
4420 let y = range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
4421 + alignfold_scores.match2insert_score;
4422 if let Some(z) = backward_sums.get(&(pos_pair2.0, x)) {
4423 let a = prob_coeff_hairpin + forward_term_seqalign + y + z.sum_seqalign;
4424 logsumexp(&mut insert_prob, a);
4425 if produces_struct_profs {
4426 let pos_pair4 = (pos_pair2.1, x - T::one());
4427 match unpair_probs_range_hairpin.1.get_mut(&pos_pair4) {
4428 Some(x) => {
4429 logsumexp(x, a);
4430 }
4431 None => {
4432 unpair_probs_range_hairpin.1.insert(pos_pair4, a);
4433 }
4434 }
4435 }
4436 if trains_alignfold_scores {
4437 let y = range_insert_scores.insert_scores_multibranch2[long_pos_pair2.1]
4438 [long_x - 1]
4439 + alignfold_scores.match2insert_score;
4440 let a = prob_coeff_multibranch
4441 + forward_term_seqalign_multibranch
4442 + y
4443 + z.sum_multibranch;
4444 logsumexp(&mut insert_prob_multibranch, a);
4445 let a = prob_coeff_multibranch
4446 + forward_term_1st_pairmatches
4447 + y
4448 + z.sum_1ormore_pairmatches;
4449 logsumexp(&mut insert_prob_multibranch, a);
4450 let a = prob_coeff_multibranch
4451 + forward_term_multibranch
4452 + y
4453 + z.sum_0ormore_pairmatches;
4454 logsumexp(&mut insert_prob_multibranch, a);
4455 logsumexp(&mut insert_prob, insert_prob_multibranch);
4456 }
4457 }
4458 if trains_alignfold_scores {
4459 if let Some(&z) = backward_sums_2loop.get(&(pos_pair2.0, x)) {
4460 let z = prob_coeff + forward_term_seqalign + y + z;
4461 logsumexp(&mut insert_prob, z);
4462 }
4463 logsumexp(
4464 &mut alignfold_counts_expected.match2insert_score,
4465 (2. as Prob).ln() + insert_prob,
4466 );
4467 let pos_pair4 = (pos_pair2.1, x - T::one());
4468 match unpair_probs_range.1.get_mut(&pos_pair4) {
4469 Some(x) => {
4470 logsumexp(x, insert_prob);
4471 }
4472 None => {
4473 unpair_probs_range.1.insert(pos_pair4, insert_prob);
4474 }
4475 }
4476 logsumexp(
4477 &mut alignfold_counts_expected.multibranch_score_unpair,
4478 ((long_x - long_pos_pair2.1) as Prob).ln() + insert_prob_multibranch,
4479 );
4480 if pos_pair2.1 < x - T::one() {
4481 logsumexp(
4482 &mut alignfold_counts_expected.insert_extend_score,
4483 ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + insert_prob,
4484 );
4485 }
4486 }
4487 }
4488 }
4489 if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4490 for &x in x {
4491 if x <= pos_pair2.0 {
4492 continue;
4493 }
4494 let mut insert_prob = NEG_INFINITY;
4495 let mut insert_prob_multibranch = insert_prob;
4496 let long_x = x.to_usize().unwrap();
4497 let y = range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
4498 + alignfold_scores.match2insert_score;
4499 if let Some(z) = backward_sums.get(&(x, pos_pair2.1)) {
4500 let a = prob_coeff_hairpin + forward_term_seqalign + y + z.sum_seqalign;
4501 logsumexp(&mut insert_prob, a);
4502 if produces_struct_profs {
4503 let pos_pair4 = (pos_pair2.0, x - T::one());
4504 match unpair_probs_range_hairpin.0.get_mut(&pos_pair4) {
4505 Some(x) => {
4506 logsumexp(x, a);
4507 }
4508 None => {
4509 unpair_probs_range_hairpin.0.insert(pos_pair4, a);
4510 }
4511 }
4512 }
4513 if trains_alignfold_scores {
4514 let y = range_insert_scores.insert_scores_multibranch[long_pos_pair2.0]
4515 [long_x - 1]
4516 + alignfold_scores.match2insert_score;
4517 let a = prob_coeff_multibranch
4518 + forward_term_seqalign_multibranch
4519 + y
4520 + z.sum_multibranch;
4521 logsumexp(&mut insert_prob_multibranch, a);
4522 let a = prob_coeff_multibranch
4523 + forward_term_1st_pairmatches
4524 + y
4525 + z.sum_1ormore_pairmatches;
4526 logsumexp(&mut insert_prob_multibranch, a);
4527 let a = prob_coeff_multibranch
4528 + forward_term_multibranch
4529 + y
4530 + z.sum_0ormore_pairmatches;
4531 logsumexp(&mut insert_prob_multibranch, a);
4532 logsumexp(&mut insert_prob, insert_prob_multibranch);
4533 }
4534 }
4535 if let Some(&z) = backward_sums_2loop.get(&(x, pos_pair2.1)) {
4536 let z = prob_coeff + forward_term_seqalign + y + z;
4537 logsumexp(&mut insert_prob, z);
4538 }
4539 if trains_alignfold_scores {
4540 logsumexp(
4541 &mut alignfold_counts_expected.match2insert_score,
4542 (2. as Prob).ln() + insert_prob,
4543 );
4544 let pos_pair4 = (pos_pair2.0, x - T::one());
4545 match unpair_probs_range.0.get_mut(&pos_pair4) {
4546 Some(x) => {
4547 logsumexp(x, insert_prob);
4548 }
4549 None => {
4550 unpair_probs_range.0.insert(pos_pair4, insert_prob);
4551 }
4552 }
4553 logsumexp(
4554 &mut alignfold_counts_expected.multibranch_score_unpair,
4555 ((long_x - long_pos_pair2.0) as Prob).ln() + insert_prob_multibranch,
4556 );
4557 if pos_pair2.0 < x - T::one() {
4558 logsumexp(
4559 &mut alignfold_counts_expected.insert_extend_score,
4560 ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + insert_prob,
4561 );
4562 }
4563 }
4564 }
4565 }
4566 }
4567 if trains_alignfold_scores {
4568 if let Some(&forward_sum) = forward_sums_2loop.get(&pos_pair) {
4569 if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4570 for &x in x {
4571 if x <= pos_pair2.1 {
4572 continue;
4573 }
4574 let mut insert_prob = NEG_INFINITY;
4575 let long_x = x.to_usize().unwrap();
4576 let y = range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
4577 + alignfold_scores.match2insert_score;
4578 if let Some(z) = backward_sums.get(&(pos_pair2.0, x)) {
4579 let a = prob_coeff + forward_sum + y + z.sum_seqalign;
4580 logsumexp(&mut insert_prob, a);
4581 logsumexp(
4582 &mut alignfold_counts_expected.match2insert_score,
4583 (2. as Prob).ln() + insert_prob,
4584 );
4585 let pos_pair4 = (pos_pair2.1, x - T::one());
4586 match unpair_probs_range.1.get_mut(&pos_pair4) {
4587 Some(x) => {
4588 logsumexp(x, insert_prob);
4589 }
4590 None => {
4591 unpair_probs_range.1.insert(pos_pair4, insert_prob);
4592 }
4593 }
4594 if pos_pair2.1 < x - T::one() {
4595 logsumexp(
4596 &mut alignfold_counts_expected.insert_extend_score,
4597 ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + insert_prob,
4598 );
4599 }
4600 }
4601 }
4602 }
4603 if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4604 for &x in x {
4605 if x <= pos_pair2.0 {
4606 continue;
4607 }
4608 let mut insert_prob = NEG_INFINITY;
4609 let long_x = x.to_usize().unwrap();
4610 let y = range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
4611 + alignfold_scores.match2insert_score;
4612 if let Some(z) = backward_sums.get(&(x, pos_pair2.1)) {
4613 let a = prob_coeff + forward_sum + y + z.sum_seqalign;
4614 logsumexp(&mut insert_prob, a);
4615 logsumexp(
4616 &mut alignfold_counts_expected.match2insert_score,
4617 (2. as Prob).ln() + insert_prob,
4618 );
4619 let pos_pair4 = (pos_pair2.0, x - T::one());
4620 match unpair_probs_range.0.get_mut(&pos_pair4) {
4621 Some(x) => {
4622 logsumexp(x, insert_prob);
4623 }
4624 None => {
4625 unpair_probs_range.0.insert(pos_pair4, insert_prob);
4626 }
4627 }
4628 if pos_pair2.0 < x - T::one() {
4629 logsumexp(
4630 &mut alignfold_counts_expected.insert_extend_score,
4631 ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + insert_prob,
4632 );
4633 }
4634 }
4635 }
4636 }
4637 }
4638 }
4639 }
4640 }
4641 }
4642 }
4643 }
4644 if needs_indel_info {
4645 for (x, &y) in &unpair_probs_range.0 {
4646 for i in range_inclusive(x.0, x.1) {
4647 let long_i = i.to_usize().unwrap();
4648 let x = seq_pair.0[long_i];
4649 logsumexp(&mut alignfold_counts_expected.insert_scores[x], y);
4650 }
4651 }
4652 for (x, &y) in &unpair_probs_range.1 {
4653 for i in range_inclusive(x.0, x.1) {
4654 let long_i = i.to_usize().unwrap();
4655 let x = seq_pair.1[long_i];
4656 logsumexp(&mut alignfold_counts_expected.insert_scores[x], y);
4657 }
4658 }
4659 }
4660 if produces_struct_profs {
4661 for (x, &y) in &unpair_probs_range_external.0 {
4662 for i in range_inclusive(x.0, x.1) {
4663 let long_i = i.to_usize().unwrap();
4664 logsumexp(
4665 &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_EXTERNAL)],
4666 y,
4667 );
4668 }
4669 }
4670 for (x, &y) in &unpair_probs_range_external.1 {
4671 for i in range_inclusive(x.0, x.1) {
4672 let long_i = i.to_usize().unwrap();
4673 logsumexp(
4674 &mut alignfold_probs.context_profs_pair.1[(long_i, CONTEXT_INDEX_EXTERNAL)],
4675 y,
4676 );
4677 }
4678 }
4679 for (x, &y) in &unpair_probs_range_hairpin.0 {
4680 for i in range_inclusive(x.0, x.1) {
4681 let long_i = i.to_usize().unwrap();
4682 logsumexp(
4683 &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_HAIRPIN)],
4684 y,
4685 );
4686 }
4687 }
4688 for (x, &y) in &unpair_probs_range_hairpin.1 {
4689 for i in range_inclusive(x.0, x.1) {
4690 let long_i = i.to_usize().unwrap();
4691 logsumexp(
4692 &mut alignfold_probs.context_profs_pair.1[(long_i, CONTEXT_INDEX_HAIRPIN)],
4693 y,
4694 );
4695 }
4696 }
4697 alignfold_probs
4698 .context_profs_pair
4699 .0
4700 .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4701 .mapv_inplace(expf);
4702 let fold = 1.
4703 - alignfold_probs
4704 .context_profs_pair
4705 .0
4706 .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4707 .sum_axis(Axis(1));
4708 alignfold_probs
4709 .context_profs_pair
4710 .0
4711 .slice_mut(s![.., CONTEXT_INDEX_MULTIBRANCH])
4712 .assign(&fold);
4713 alignfold_probs
4714 .context_profs_pair
4715 .1
4716 .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4717 .mapv_inplace(expf);
4718 let fold = 1.
4719 - alignfold_probs
4720 .context_profs_pair
4721 .1
4722 .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4723 .sum_axis(Axis(1));
4724 alignfold_probs
4725 .context_profs_pair
4726 .1
4727 .slice_mut(s![.., CONTEXT_INDEX_MULTIBRANCH])
4728 .assign(&fold);
4729 }
4730 if produces_match_probs {
4731 for (x, y) in alignfold_probs.loopmatch_probs.iter_mut() {
4732 match alignfold_probs.match_probs.get_mut(x) {
4733 Some(x) => {
4734 logsumexp(x, *y);
4735 *x = expf(*x);
4736 }
4737 None => {
4738 alignfold_probs.match_probs.insert(*x, expf(*y));
4739 }
4740 }
4741 *y = expf(*y);
4742 }
4743 for x in alignfold_probs.pairmatch_probs.values_mut() {
4744 *x = expf(*x);
4745 }
4746 }
4747 if trains_alignfold_scores {
4748 for x in alignfold_counts_expected.hairpin_scores_len.iter_mut() {
4749 *x = expf(*x);
4750 }
4751 for x in alignfold_counts_expected.bulge_scores_len.iter_mut() {
4752 *x = expf(*x);
4753 }
4754 for x in alignfold_counts_expected.interior_scores_len.iter_mut() {
4755 *x = expf(*x);
4756 }
4757 for x in alignfold_counts_expected
4758 .interior_scores_symmetric
4759 .iter_mut()
4760 {
4761 *x = expf(*x);
4762 }
4763 for x in alignfold_counts_expected
4764 .interior_scores_asymmetric
4765 .iter_mut()
4766 {
4767 *x = expf(*x);
4768 }
4769 for x in alignfold_counts_expected.stack_scores.iter_mut() {
4770 for x in x.iter_mut() {
4771 for x in x.iter_mut() {
4772 for x in x.iter_mut() {
4773 *x = expf(*x);
4774 }
4775 }
4776 }
4777 }
4778 for x in alignfold_counts_expected
4779 .terminal_mismatch_scores
4780 .iter_mut()
4781 {
4782 for x in x.iter_mut() {
4783 for x in x.iter_mut() {
4784 for x in x.iter_mut() {
4785 *x = expf(*x);
4786 }
4787 }
4788 }
4789 }
4790 for x in alignfold_counts_expected.dangling_scores_left.iter_mut() {
4791 for x in x.iter_mut() {
4792 for x in x.iter_mut() {
4793 *x = expf(*x);
4794 }
4795 }
4796 }
4797 for x in alignfold_counts_expected.dangling_scores_right.iter_mut() {
4798 for x in x.iter_mut() {
4799 for x in x.iter_mut() {
4800 *x = expf(*x);
4801 }
4802 }
4803 }
4804 for x in alignfold_counts_expected
4805 .interior_scores_explicit
4806 .iter_mut()
4807 {
4808 for x in x.iter_mut() {
4809 *x = expf(*x);
4810 }
4811 }
4812 for x in alignfold_counts_expected.bulge_scores_0x1.iter_mut() {
4813 *x = expf(*x);
4814 }
4815 for x in alignfold_counts_expected.interior_scores_1x1.iter_mut() {
4816 for x in x.iter_mut() {
4817 *x = expf(*x);
4818 }
4819 }
4820 for x in alignfold_counts_expected.helix_close_scores.iter_mut() {
4821 for x in x.iter_mut() {
4822 *x = expf(*x);
4823 }
4824 }
4825 for x in alignfold_counts_expected.basepair_scores.iter_mut() {
4826 for x in x.iter_mut() {
4827 *x = expf(*x);
4828 }
4829 }
4830 alignfold_counts_expected.multibranch_score_base =
4831 expf(alignfold_counts_expected.multibranch_score_base);
4832 alignfold_counts_expected.multibranch_score_basepair =
4833 expf(alignfold_counts_expected.multibranch_score_basepair);
4834 alignfold_counts_expected.multibranch_score_unpair =
4835 expf(alignfold_counts_expected.multibranch_score_unpair);
4836 alignfold_counts_expected.external_score_basepair =
4837 expf(alignfold_counts_expected.external_score_basepair);
4838 alignfold_counts_expected.external_score_unpair =
4839 expf(alignfold_counts_expected.external_score_unpair);
4840 alignfold_counts_expected.match2match_score =
4841 expf(alignfold_counts_expected.match2match_score);
4842 alignfold_counts_expected.match2insert_score =
4843 expf(alignfold_counts_expected.match2insert_score);
4844 alignfold_counts_expected.insert_extend_score =
4845 expf(alignfold_counts_expected.insert_extend_score);
4846 alignfold_counts_expected.init_match_score = expf(alignfold_counts_expected.init_match_score);
4847 alignfold_counts_expected.init_insert_score =
4848 expf(alignfold_counts_expected.init_insert_score);
4849 for x in alignfold_counts_expected.insert_scores.iter_mut() {
4850 *x = expf(*x);
4851 }
4852 for x in alignfold_counts_expected.match_scores.iter_mut() {
4853 for x in x.iter_mut() {
4854 *x = expf(*x);
4855 }
4856 }
4857 }
4858 }
4859 alignfold_probs
4860}
4861
/// Returns the absolute difference `|x - y|` of two unsigned integers.
///
/// The larger operand is always the minuend, so the subtraction never underflows.
pub fn get_diff(x: usize, y: usize) -> usize {
    if x >= y {
        x - y
    } else {
        y - x
    }
}
4865
4866pub fn get_hairpin_score(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4867 let a = z.1 - z.0 - 1;
4868 x.hairpin_scores_len_cumulative[a] + get_junction_score_single(x, y, z)
4869}
4870
4871pub fn get_twoloop_score(
4872 x: &AlignfoldScores,
4873 y: SeqSlice,
4874 z: &(usize, usize),
4875 a: &(usize, usize),
4876) -> Score {
4877 let b = (y[a.0], y[a.1]);
4878 let c = if z.0 + 1 == a.0 && z.1 - 1 == a.1 {
4879 get_stack_score(x, y, z, a)
4880 } else if z.0 + 1 == a.0 || z.1 - 1 == a.1 {
4881 get_bulge_score(x, y, z, a)
4882 } else {
4883 get_interior_score(x, y, z, a)
4884 };
4885 c + x.basepair_scores[b.0][b.1]
4886}
4887
4888pub fn get_stack_score(
4889 x: &AlignfoldScores,
4890 y: SeqSlice,
4891 z: &(usize, usize),
4892 a: &(usize, usize),
4893) -> Score {
4894 let b = (y[z.0], y[z.1]);
4895 let c = (y[a.0], y[a.1]);
4896 x.stack_scores[b.0][b.1][c.0][c.1]
4897}
4898
4899pub fn get_bulge_score(
4900 x: &AlignfoldScores,
4901 y: SeqSlice,
4902 z: &(usize, usize),
4903 a: &(usize, usize),
4904) -> Score {
4905 let b = a.0 - z.0 + z.1 - a.1 - 2;
4906 let c = if b == 1 {
4907 x.bulge_scores_0x1[if a.0 - z.0 - 1 == 1 {
4908 y[z.0 + 1]
4909 } else {
4910 y[z.1 - 1]
4911 }]
4912 } else {
4913 0.
4914 };
4915 c + x.bulge_scores_len_cumulative[b - 1]
4916 + get_junction_score_single(x, y, z)
4917 + get_junction_score_single(x, y, &(a.1, a.0))
4918}
4919
4920pub fn get_interior_score(
4921 x: &AlignfoldScores,
4922 y: SeqSlice,
4923 z: &(usize, usize),
4924 a: &(usize, usize),
4925) -> Score {
4926 let b = (a.0 - z.0 - 1, z.1 - a.1 - 1);
4927 let c = b.0 + b.1;
4928 let d = if b.0 == b.1 {
4929 let d = if c == 2 {
4930 x.interior_scores_1x1[y[z.0 + 1]][y[z.1 - 1]]
4931 } else {
4932 0.
4933 };
4934 d + x.interior_scores_symmetric_cumulative[b.0 - 1]
4935 } else {
4936 x.interior_scores_asymmetric_cumulative[get_abs_diff(b.0, b.1) - 1]
4937 };
4938 let e = if b.0 <= 4 && b.1 <= 4 {
4939 x.interior_scores_explicit[b.0 - 1][b.1 - 1]
4940 } else {
4941 0.
4942 };
4943 d + e
4944 + x.interior_scores_len_cumulative[c - 2]
4945 + get_junction_score_single(x, y, z)
4946 + get_junction_score_single(x, y, &(a.1, a.0))
4947}
4948
4949pub fn get_junction_score_single(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4950 let a = (y[z.0], y[z.1]);
4951 get_helix_close_score(x, &a) + get_terminal_mismatch_score(x, &a, &(y[z.0 + 1], y[z.1 - 1]))
4952}
4953
4954pub fn get_helix_close_score(x: &AlignfoldScores, y: &Basepair) -> Score {
4955 x.helix_close_scores[y.0][y.1]
4956}
4957
4958pub fn get_terminal_mismatch_score(x: &AlignfoldScores, y: &Basepair, z: &Basepair) -> Score {
4959 x.terminal_mismatch_scores[y.0][y.1][z.0][z.1]
4960}
4961
4962pub fn get_junction_score(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4963 let a = (y[z.0], y[z.1]);
4964 let b = 1;
4965 let c = y.len() - 2;
4966 get_helix_close_score(x, &a)
4967 + if z.0 < c {
4968 x.dangling_scores_left[a.0][a.1][y[z.0 + 1]]
4969 } else {
4970 0.
4971 }
4972 + if z.1 > b {
4973 x.dangling_scores_right[a.0][a.1][y[z.1 - 1]]
4974 } else {
4975 0.
4976 }
4977}
4978
4979pub fn get_dict_min_stack(x: &Basepair, y: &Basepair) -> (Basepair, Basepair) {
4980 let z = (*x, *y);
4981 let a = ((y.1, y.0), (x.1, x.0));
4982 if z < a {
4983 z
4984 } else {
4985 a
4986 }
4987}
4988
/// Returns the lexicographically smaller of `(a, b)` and `(b, a)`.
pub fn get_dict_min_pair(x: &(usize, usize)) -> (usize, usize) {
    let &(first, second) = x;
    min((first, second), (second, first))
}
4997
4998pub fn get_num_unpairs_multibranch(
4999 x: &(usize, usize),
5000 y: &Vec<(usize, usize)>,
5001 z: SeqSlice,
5002) -> usize {
5003 let mut a = 0;
5004 let mut b = (0, 0);
5005 for y in y {
5006 let c = if b == (0, 0) { x.0 + 1 } else { b.1 + 1 };
5007 for &d in &z[c..y.0] {
5008 a += usize::from(d != PSEUDO_BASE);
5009 }
5010 b = *y;
5011 }
5012 for &c in &z[b.1 + 1..x.1] {
5013 a += usize::from(c != PSEUDO_BASE);
5014 }
5015 a
5016}
5017
5018pub fn get_num_unpairs_external(x: &Vec<(usize, usize)>, y: SeqSlice) -> usize {
5019 let mut z = 0;
5020 let mut a = (0, 0);
5021 for x in x {
5022 let b = if a == (0, 0) { 0 } else { a.1 + 1 };
5023 for &b in y[b..x.0].iter() {
5024 z += usize::from(b != PSEUDO_BASE);
5025 }
5026 a = *x;
5027 }
5028 for &b in y[a.1 + 1..y.len()].iter() {
5029 z += usize::from(b != PSEUDO_BASE);
5030 }
5031 z
5032}
5033
/// Computes per-sequence averaged probabilities for `seqs` using trained alignfold scores.
///
/// Steps performed below:
/// 1. Load `AlignfoldScores::load_trained_scores()` and copy it into an
///    `AlignScores` (for `durbin_algo`) and a `FoldScoreSets` (for `mccaskill_algo`).
///    NOTE(review): these pre-filtering scores come from the transfer-trained set
///    regardless of `train_type` — confirm that is intended.
/// 2. Choose the scores passed to `consprob_core` by `train_type`: the loaded
///    trained scores for `TrainedTransfer`, `load_trained_scores_randinit()` for
///    `TrainedRandinit`, otherwise freshly `transfer()`-ed scores.
/// 3. Per sequence (in parallel on `thread_pool`): fold with `mccaskill_algo`,
///    filter base-pair probabilities with `min_basepair_prob`, record the max
///    base-pair span and build the per-sequence `FoldScoresTrained`.
/// 4. Per sequence pair `(i, j)` with `i < j` (in parallel): compute match
///    probabilities with `durbin_algo` and filter them with `min_match_prob`.
/// 5. Per pair (in parallel): run `consprob_core` on the sparse position sets.
/// 6. Per sequence (in parallel): average the pairwise results with
///    `pair_probs2avg_probs`.
/// 7. When `produces_match_probs` is set, copy loop-match, pair-match and match
///    probabilities of every pair into the returned map; otherwise that map is empty.
///
/// Returns `(per-sequence averaged probability matrices, per-pair match probabilities)`.
pub fn consprob_trained<T>(
    thread_pool: &mut Pool,
    seqs: &SeqSlices,
    min_basepair_prob: Prob,
    min_match_prob: Prob,
    produces_struct_profs: bool,
    produces_match_probs: bool,
    train_type: TrainType,
) -> (ProbMatSetsAvg<T>, MatchProbsHashedIds<T>)
where
    T: HashIndex,
{
    // Scores used only for the McCaskill/Durbin pre-filtering passes.
    let trained = AlignfoldScores::load_trained_scores();
    let mut align_scores = AlignScores::new(0.);
    copy_alignfold_scores_align(&mut align_scores, &trained);
    let ref_align_scores = &align_scores;
    let mut fold_scores = FoldScoreSets::new(0.);
    copy_alignfold_scores_fold(&mut fold_scores, &trained);
    let ref_fold_scores = &fold_scores;
    // Scores used by `consprob_core`, selected by `train_type`.
    let alignfold_scores = if matches!(train_type, TrainType::TrainedTransfer) {
        trained
    } else if matches!(train_type, TrainType::TrainedRandinit) {
        AlignfoldScores::load_trained_scores_randinit()
    } else {
        let mut transferred = AlignfoldScores::new(0.);
        transferred.transfer();
        transferred
    };
    let num_seqs = seqs.len();
    let mut basepair_prob_mats = vec![SparseProbMat::<T>::new(); num_seqs];
    let mut max_basepair_spans = vec![T::zero(); num_seqs];
    let mut fold_score_sets = vec![FoldScoresTrained::<T>::new(); num_seqs];
    let ref_alignfold_scores = &alignfold_scores;
    let uses_contra_model = true;
    let allows_short_hairpins = false;
    // Per-sequence folding; each job writes only its own output slots.
    thread_pool.scoped(|scope| {
        for (x, y, z, a) in multizip((
            basepair_prob_mats.iter_mut(),
            max_basepair_spans.iter_mut(),
            seqs.iter(),
            fold_score_sets.iter_mut(),
        )) {
            let b = z.len();
            scope.execute(move || {
                // The first and last elements of the sequence are excluded from folding.
                let c = mccaskill_algo(
                    &z[1..b - 1],
                    uses_contra_model,
                    allows_short_hairpins,
                    ref_fold_scores,
                )
                .0;
                *x = filter_basepair_probs::<T>(&c, min_basepair_prob);
                *y = get_max_basepair_span::<T>(x);
                *a = FoldScoresTrained::<T>::set_curr_scores(ref_alignfold_scores, z, x);
            });
        }
    });
    // One entry per unordered sequence pair (i, j) with i < j.
    let mut alignfold_probs_hashed_ids = AlignfoldProbsHashedIds::<T>::default();
    let mut match_probs_hashed_ids = SparseProbsHashedIds::<T>::default();
    for x in 0..num_seqs {
        for y in x + 1..num_seqs {
            let y = (x, y);
            alignfold_probs_hashed_ids.insert(y, AlignfoldProbMats::<T>::origin());
            match_probs_hashed_ids.insert(y, SparseProbMat::<T>::default());
        }
    }
    // Pairwise alignment (match) probabilities, filtered by `min_match_prob`.
    thread_pool.scoped(|x| {
        for (y, z) in match_probs_hashed_ids.iter_mut() {
            let y = (seqs[y.0], seqs[y.1]);
            x.execute(move || {
                *z = filter_match_probs(&durbin_algo(&y, ref_align_scores), min_match_prob);
            });
        }
    });
    // Inference only: expected counts are not accumulated here.
    let trains_alignfold_scores = false;
    thread_pool.scoped(|x| {
        let alignfold_scores = &alignfold_scores;
        for (y, z) in alignfold_probs_hashed_ids.iter_mut() {
            let seq_pair = (seqs[y.0], seqs[y.1]);
            let seq_len_pair = (
                T::from_usize(seq_pair.0.len()).unwrap(),
                T::from_usize(seq_pair.1.len()).unwrap(),
            );
            let max_basepair_span_pair = (max_basepair_spans[y.0], max_basepair_spans[y.1]);
            let basepair_probs_pair = (&basepair_prob_mats[y.0], &basepair_prob_mats[y.1]);
            let fold_scores_pair = (&fold_score_sets[y.0], &fold_score_sets[y.1]);
            let match_probs = &match_probs_hashed_ids[y];
            // Sparse position sets are built on the calling thread, then moved into the job.
            let (
                forward_pos_pairs,
                backward_pos_pairs,
                _,
                pos_quads_hashed_lens,
                matchable_poss,
                matchable_poss2,
            ) = get_sparse_poss(&basepair_probs_pair, match_probs, &seq_len_pair);
            x.execute(move || {
                *z = consprob_core::<T>((
                    &seq_pair,
                    alignfold_scores,
                    &max_basepair_span_pair,
                    match_probs,
                    produces_struct_profs,
                    trains_alignfold_scores,
                    &mut AlignfoldScores::new(NEG_INFINITY),
                    &forward_pos_pairs,
                    &backward_pos_pairs,
                    &pos_quads_hashed_lens,
                    &fold_scores_pair,
                    produces_match_probs,
                    &matchable_poss,
                    &matchable_poss2,
                ))
                .0;
            });
        }
    });
    // Average each sequence's pairwise results into a per-sequence profile.
    let mut alignfold_prob_mats_avg = vec![AlignfoldProbMatsAvg::<T>::origin(); num_seqs];
    thread_pool.scoped(|x| {
        let y = &alignfold_probs_hashed_ids;
        for (z, a) in alignfold_prob_mats_avg.iter_mut().enumerate() {
            let b = seqs[z].len();
            x.execute(move || {
                *a = pair_probs2avg_probs::<T>(y, z, num_seqs, b, produces_struct_profs);
            });
        }
    });
    // Shadows the sparse match-probability map above with the returned output map.
    let mut match_probs_hashed_ids = MatchProbsHashedIds::<T>::default();
    if produces_match_probs {
        for x in 0..num_seqs {
            for y in x + 1..num_seqs {
                let y = (x, y);
                let z = &alignfold_probs_hashed_ids[&y];
                let mut a = MatchProbMats::<T>::new();
                a.loopmatch_probs = z.loopmatch_probs.clone();
                a.pairmatch_probs = z.pairmatch_probs.clone();
                a.match_probs = z.match_probs.clone();
                match_probs_hashed_ids.insert(y, a);
            }
        }
    }
    (alignfold_prob_mats_avg, match_probs_hashed_ids)
}
5176
5177pub fn constrain<T>(
5178 thread_pool: &mut Pool,
5179 train_data: &mut TrainData<T>,
5180 output_file_path: &Path,
5181 enables_randinit: bool,
5182 learning_tolerance: Score,
5183) where
5184 T: HashIndex,
5185{
5186 let mut alignfold_scores = AlignfoldScores::new(0.);
5187 if enables_randinit {
5188 alignfold_scores.rand_init();
5189 } else {
5190 alignfold_scores.transfer();
5191 }
5192 for x in train_data.iter_mut() {
5193 x.set_curr_scores(&alignfold_scores);
5194 }
5195 let mut old_alignfold_scores = alignfold_scores.clone();
5196 let mut old_cost = INFINITY;
5197 let mut old_accuracy = NEG_INFINITY;
5198 let mut costs = Probs::new();
5199 let mut accuracies = costs.clone();
5200 let mut epoch = 0;
5201 let mut regularizers = Regularizers::from(vec![1.; alignfold_scores.len()]);
5202 let num_data = train_data.len() as Score;
5203 let produces_struct_profs = false;
5204 let trains_alignfold_scores = true;
5205 let produces_match_probs = true;
5206 loop {
5207 thread_pool.scoped(|scope| {
5208 let alignfold_scores = &alignfold_scores;
5209 for train_datum in train_data.iter_mut() {
5210 train_datum.alignfold_counts_expected = AlignfoldScores::new(NEG_INFINITY);
5211 let seq_pair = (&train_datum.seq_pair.0[..], &train_datum.seq_pair.1[..]);
5212 let max_basepair_span_pair = &train_datum.max_basepair_span_pair;
5213 let alignfold_counts_expected = &mut train_datum.alignfold_counts_expected;
5214 let accuracy = &mut train_datum.accuracy;
5215 let global_sum = &mut train_datum.global_sum;
5216 let forward_pos_pairs = &train_datum.forward_pos_pairs;
5217 let backward_pos_pairs = &train_datum.backward_pos_pairs;
5218 let pos_quads_hashed_lens = &train_datum.pos_quads_hashed_lens;
5219 let matchable_poss = &train_datum.matchable_poss;
5220 let matchable_poss2 = &train_datum.matchable_poss2;
5221 let match_probs = &train_datum.match_probs;
5222 let alignfold = &train_datum.alignfold;
5223 let fold_scores_pair = (
5224 &train_datum.fold_scores_pair.0,
5225 &train_datum.fold_scores_pair.1,
5226 );
5227 scope.execute(move || {
5228 let x = consprob_core::<T>((
5229 &seq_pair,
5230 alignfold_scores,
5231 max_basepair_span_pair,
5232 match_probs,
5233 produces_struct_profs,
5234 trains_alignfold_scores,
5235 alignfold_counts_expected,
5236 forward_pos_pairs,
5237 backward_pos_pairs,
5238 pos_quads_hashed_lens,
5239 &fold_scores_pair,
5240 produces_match_probs,
5241 matchable_poss,
5242 matchable_poss2,
5243 ));
5244 *accuracy = get_accuracy_expected::<T>(&seq_pair, alignfold, &x.0.match_probs);
5245 *global_sum = x.1;
5246 });
5247 }
5248 });
5249 let accuracy = train_data.iter().map(|x| x.accuracy).sum::<Score>() / train_data.len() as Score;
5250 let accuracy_change = accuracy - old_accuracy;
5251 if accuracy_change <= learning_tolerance {
5252 alignfold_scores = old_alignfold_scores;
5253 println!(
5254 "Accuracy change {accuracy_change} is <= learning tolerance {learning_tolerance}; training finished"
5255 );
5256 break;
5257 }
5258 old_alignfold_scores = alignfold_scores.clone();
5259 alignfold_scores.update(train_data, &mut regularizers);
5260 for x in train_data.iter_mut() {
5261 x.set_curr_scores(&alignfold_scores);
5262 }
5263 let cost = alignfold_scores.get_cost(&train_data[..], ®ularizers);
5264 let avg_cost_change = (cost - old_cost) / num_data;
5265 if avg_cost_change >= 0. {
5266 println!("Average cost change {avg_cost_change} is not negative; training finished");
5267 alignfold_scores = old_alignfold_scores;
5268 break;
5269 }
5270 costs.push(cost);
5271 accuracies.push(accuracy);
5272 println!("Epoch {} finished (current cost = {}, current accuracy = {}, average cost change = {}, accuracy change = {})", epoch + 1, cost, accuracy, avg_cost_change, accuracy_change);
5273 epoch += 1;
5274 old_cost = cost;
5275 old_accuracy = accuracy;
5276 }
5277 write_alignfold_scores_trained(&alignfold_scores, enables_randinit);
5278 write_logs(&costs, &accuracies, output_file_path);
5279}
5280
5281pub fn gapped2ungapped(x: &Seq) -> Seq {
5282 x.iter().filter(|&&x| x != PSEUDO_BASE).copied().collect()
5283}
5284
5285pub fn bytes2seq_gapped(x: &[u8]) -> Seq {
5286 let mut y = Seq::new();
5287 for &x in x {
5288 let x = convert_char(x);
5289 y.push(x);
5290 }
5291 y
5292}
5293
5294pub fn convert_char(x: u8) -> Base {
5295 match x {
5296 A_LOWER | A_UPPER => A,
5297 C_LOWER | C_UPPER => C,
5298 G_LOWER | G_UPPER => G,
5299 U_LOWER | U_UPPER => U,
5300 _ => PSEUDO_BASE,
5301 }
5302}
5303
5304pub fn get_mismatch_pair(x: SeqSlice, y: &(usize, usize), z: bool) -> (usize, usize) {
5305 let mut a = if z {
5306 (x[y.1], x[y.0])
5307 } else {
5308 (PSEUDO_BASE, PSEUDO_BASE)
5309 };
5310 if z {
5311 for &b in &x[y.0 + 1..y.1] {
5312 if b != PSEUDO_BASE {
5313 a.0 = b;
5314 break;
5315 }
5316 }
5317 for i in (y.0 + 1..y.1).rev() {
5318 let x = x[i];
5319 if x != PSEUDO_BASE {
5320 a.1 = x;
5321 break;
5322 }
5323 }
5324 } else {
5325 for i in (0..y.0).rev() {
5326 let x = x[i];
5327 if x != PSEUDO_BASE {
5328 a.0 = x;
5329 break;
5330 }
5331 }
5332 let b = x.len();
5333 for &x in &x[y.1 + 1..b] {
5334 if x != PSEUDO_BASE {
5335 a.1 = x;
5336 break;
5337 }
5338 }
5339 }
5340 a
5341}
5342
5343pub fn get_hairpin_len(x: SeqSlice, y: &(usize, usize)) -> usize {
5344 let mut z = 0;
5345 for &x in &x[y.0 + 1..y.1] {
5346 if x == PSEUDO_BASE {
5347 continue;
5348 }
5349 z += 1;
5350 }
5351 z
5352}
5353
5354pub fn get_2loop_len_pair(x: SeqSlice, y: &(usize, usize), z: &(usize, usize)) -> (usize, usize) {
5355 let mut a = (0, 0);
5356 for &x in &x[y.0 + 1..z.0] {
5357 if x == PSEUDO_BASE {
5358 continue;
5359 }
5360 a.0 += 1;
5361 }
5362 for &x in &x[z.1 + 1..y.1] {
5363 if x == PSEUDO_BASE {
5364 continue;
5365 }
5366 a.1 += 1;
5367 }
5368 a
5369}
5370
/// Unpacks a flat parameter vector `source` into an `AlignfoldScores`.
///
/// This walks the score tables in the same order as `struct2vec`; the two must stay
/// in sync. Starting from `AlignfoldScores::new(0.)`, entries are consumed from
/// `source` sequentially (tracked by `offset`):
/// - loop-length tables (hairpin, bulge, interior, interior symmetric/asymmetric)
///   are written into their `*_cumulative` variants when `uses_cumulative_scores`
///   is true, and into the plain tables otherwise;
/// - base-pair-indexed tables consume entries only for pairs accepted by
///   `has_canonical_basepair`;
/// - stack, base-pair, explicit interior, 1x1 interior and match tables consume one
///   entry per `get_dict_min_stack` / `get_dict_min_pair` representative. The
///   mirrored entry is left at `0.` here (NOTE(review): presumably filled by a later
///   symmetrization step — not visible in this function).
///
/// # Panics
///
/// Panics if `source` is too short. Asserts that the number of consumed entries
/// equals `target.len()`; extra trailing entries in `source` are not detected.
pub fn vec2struct(source: &Scores, uses_cumulative_scores: bool) -> AlignfoldScores {
    let mut target = AlignfoldScores::new(0.);
    // Index of the next unread entry in `source`.
    let mut offset = 0;
    // Hairpin loop-length scores.
    let len = target.hairpin_scores_len.len();
    for i in 0..len {
        let x = source[offset + i];
        if uses_cumulative_scores {
            target.hairpin_scores_len_cumulative[i] = x;
        } else {
            target.hairpin_scores_len[i] = x;
        }
    }
    offset += len;
    // Bulge loop-length scores.
    let len = target.bulge_scores_len.len();
    for i in 0..len {
        let x = source[offset + i];
        if uses_cumulative_scores {
            target.bulge_scores_len_cumulative[i] = x;
        } else {
            target.bulge_scores_len[i] = x;
        }
    }
    offset += len;
    // Interior loop-length scores.
    let len = target.interior_scores_len.len();
    for i in 0..len {
        let x = source[offset + i];
        if uses_cumulative_scores {
            target.interior_scores_len_cumulative[i] = x;
        } else {
            target.interior_scores_len[i] = x;
        }
    }
    offset += len;
    // Symmetric interior loop scores.
    let len = target.interior_scores_symmetric.len();
    for i in 0..len {
        let x = source[offset + i];
        if uses_cumulative_scores {
            target.interior_scores_symmetric_cumulative[i] = x;
        } else {
            target.interior_scores_symmetric[i] = x;
        }
    }
    offset += len;
    // Asymmetric interior loop scores.
    let len = target.interior_scores_asymmetric.len();
    for i in 0..len {
        let x = source[offset + i];
        if uses_cumulative_scores {
            target.interior_scores_asymmetric_cumulative[i] = x;
        } else {
            target.interior_scores_asymmetric[i] = x;
        }
    }
    offset += len;
    // Stacking scores: canonical pairs only, one entry per dict-min stack.
    let len = target.stack_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                for l in 0..len {
                    if !has_canonical_basepair(&(k, l)) {
                        continue;
                    }
                    let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
                    if ((i, j), (k, l)) != dict_min_stack {
                        continue;
                    }
                    target.stack_scores[i][j][k][l] = source[offset];
                    offset += 1;
                }
            }
        }
    }
    // Terminal mismatch scores: canonical closing pairs, all mismatch bases.
    let len = target.terminal_mismatch_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                for l in 0..len {
                    target.terminal_mismatch_scores[i][j][k][l] = source[offset];
                    offset += 1;
                }
            }
        }
    }
    // Left dangling scores: canonical pairs, every dangling base.
    let len = target.dangling_scores_left.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                target.dangling_scores_left[i][j][k] = source[offset];
                offset += 1;
            }
        }
    }
    // Right dangling scores: canonical pairs, every dangling base.
    let len = target.dangling_scores_right.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                target.dangling_scores_right[i][j][k] = source[offset];
                offset += 1;
            }
        }
    }
    // Helix-closing scores: canonical pairs only (both orientations).
    let len = target.helix_close_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            target.helix_close_scores[i][j] = source[offset];
            offset += 1;
        }
    }
    // Base-pair scores: canonical pairs, one entry per dict-min orientation.
    let len = target.basepair_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            let dict_min_basepair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_basepair {
                continue;
            }
            target.basepair_scores[i][j] = source[offset];
            offset += 1;
        }
    }
    // Explicit small interior loop scores: one entry per dict-min length pair.
    let len = target.interior_scores_explicit.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_len_pair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_len_pair {
                continue;
            }
            target.interior_scores_explicit[i][j] = source[offset];
            offset += 1;
        }
    }
    // Single-base bulge scores.
    let len = target.bulge_scores_0x1.len();
    for i in 0..len {
        target.bulge_scores_0x1[i] = source[offset + i];
    }
    offset += len;
    // 1x1 interior loop scores: one entry per dict-min base pair.
    let len = target.interior_scores_1x1.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_basepair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_basepair {
                continue;
            }
            target.interior_scores_1x1[i][j] = source[offset];
            offset += 1;
        }
    }
    // Scalar multibranch/external/alignment scores, in fixed order.
    target.multibranch_score_base = source[offset];
    offset += 1;
    target.multibranch_score_basepair = source[offset];
    offset += 1;
    target.multibranch_score_unpair = source[offset];
    offset += 1;
    target.external_score_basepair = source[offset];
    offset += 1;
    target.external_score_unpair = source[offset];
    offset += 1;
    target.match2match_score = source[offset];
    offset += 1;
    target.match2insert_score = source[offset];
    offset += 1;
    target.init_match_score = source[offset];
    offset += 1;
    target.insert_extend_score = source[offset];
    offset += 1;
    target.init_insert_score = source[offset];
    offset += 1;
    // Per-base insertion scores.
    let len = target.insert_scores.len();
    for i in 0..len {
        target.insert_scores[i] = source[offset + i];
    }
    offset += len;
    // Match scores: one entry per dict-min base pair.
    let len = target.match_scores.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_match = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_match {
                continue;
            }
            target.match_scores[i][j] = source[offset];
            offset += 1;
        }
    }
    // Every parameter slot must have been consumed exactly once.
    assert!(offset == target.len());
    target
}
5573
/// Flattens an `AlignfoldScores` into a single parameter vector.
///
/// This is the inverse of `vec2struct` and uses the same table order. Loop-length
/// tables are read from their `*_cumulative` variants when `uses_cumulative_scores`
/// is true, and from the plain tables otherwise. Base-pair-indexed tables emit only
/// entries accepted by `has_canonical_basepair`. Stack, base-pair, explicit interior,
/// 1x1 interior and match tables emit one entry per `get_dict_min_stack` /
/// `get_dict_min_pair` representative.
///
/// # Panics
///
/// Asserts that exactly `source.len()` entries were written.
pub fn struct2vec(source: &AlignfoldScores, uses_cumulative_scores: bool) -> Scores {
    let mut target = vec![0.; source.len()];
    // Index of the next slot to fill in `target`.
    let mut offset = 0;
    // Hairpin loop-length scores.
    let len = source.hairpin_scores_len.len();
    for i in 0..len {
        target[offset + i] = if uses_cumulative_scores {
            source.hairpin_scores_len_cumulative[i]
        } else {
            source.hairpin_scores_len[i]
        };
    }
    offset += len;
    // Bulge loop-length scores.
    let len = source.bulge_scores_len.len();
    for i in 0..len {
        target[offset + i] = if uses_cumulative_scores {
            source.bulge_scores_len_cumulative[i]
        } else {
            source.bulge_scores_len[i]
        };
    }
    offset += len;
    // Interior loop-length scores.
    let len = source.interior_scores_len.len();
    for i in 0..len {
        target[offset + i] = if uses_cumulative_scores {
            source.interior_scores_len_cumulative[i]
        } else {
            source.interior_scores_len[i]
        };
    }
    offset += len;
    // Symmetric interior loop scores.
    let len = source.interior_scores_symmetric.len();
    for i in 0..len {
        target[offset + i] = if uses_cumulative_scores {
            source.interior_scores_symmetric_cumulative[i]
        } else {
            source.interior_scores_symmetric[i]
        };
    }
    offset += len;
    // Asymmetric interior loop scores.
    let len = source.interior_scores_asymmetric.len();
    for i in 0..len {
        target[offset + i] = if uses_cumulative_scores {
            source.interior_scores_asymmetric_cumulative[i]
        } else {
            source.interior_scores_asymmetric[i]
        };
    }
    offset += len;
    // Stacking scores: canonical pairs only, one entry per dict-min stack.
    let len = source.stack_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                for l in 0..len {
                    if !has_canonical_basepair(&(k, l)) {
                        continue;
                    }
                    let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
                    if ((i, j), (k, l)) != dict_min_stack {
                        continue;
                    }
                    target[offset] = source.stack_scores[i][j][k][l];
                    offset += 1;
                }
            }
        }
    }
    // Terminal mismatch scores: canonical closing pairs, all mismatch bases.
    let len = source.terminal_mismatch_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                for l in 0..len {
                    target[offset] = source.terminal_mismatch_scores[i][j][k][l];
                    offset += 1;
                }
            }
        }
    }
    // Left dangling scores: canonical pairs, every dangling base.
    let len = source.dangling_scores_left.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                target[offset] = source.dangling_scores_left[i][j][k];
                offset += 1;
            }
        }
    }
    // Right dangling scores: canonical pairs, every dangling base.
    let len = source.dangling_scores_right.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            for k in 0..len {
                target[offset] = source.dangling_scores_right[i][j][k];
                offset += 1;
            }
        }
    }
    // Helix-closing scores: canonical pairs only (both orientations).
    let len = source.helix_close_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            target[offset] = source.helix_close_scores[i][j];
            offset += 1;
        }
    }
    // Base-pair scores: canonical pairs, one entry per dict-min orientation.
    let len = source.basepair_scores.len();
    for i in 0..len {
        for j in 0..len {
            if !has_canonical_basepair(&(i, j)) {
                continue;
            }
            let dict_min_basepair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_basepair {
                continue;
            }
            target[offset] = source.basepair_scores[i][j];
            offset += 1;
        }
    }
    // Explicit small interior loop scores: one entry per dict-min length pair.
    let len = source.interior_scores_explicit.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_len_pair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_len_pair {
                continue;
            }
            target[offset] = source.interior_scores_explicit[i][j];
            offset += 1;
        }
    }
    // Single-base bulge scores.
    let len = source.bulge_scores_0x1.len();
    for (i, &x) in source.bulge_scores_0x1.iter().enumerate() {
        target[offset + i] = x;
    }
    offset += len;
    // 1x1 interior loop scores: one entry per dict-min base pair.
    let len = source.interior_scores_1x1.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_basepair = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_basepair {
                continue;
            }
            target[offset] = source.interior_scores_1x1[i][j];
            offset += 1;
        }
    }
    // Scalar multibranch/external/alignment scores, in fixed order.
    target[offset] = source.multibranch_score_base;
    offset += 1;
    target[offset] = source.multibranch_score_basepair;
    offset += 1;
    target[offset] = source.multibranch_score_unpair;
    offset += 1;
    target[offset] = source.external_score_basepair;
    offset += 1;
    target[offset] = source.external_score_unpair;
    offset += 1;
    target[offset] = source.match2match_score;
    offset += 1;
    target[offset] = source.match2insert_score;
    offset += 1;
    target[offset] = source.init_match_score;
    offset += 1;
    target[offset] = source.insert_extend_score;
    offset += 1;
    target[offset] = source.init_insert_score;
    offset += 1;
    // Per-base insertion scores.
    let len = source.insert_scores.len();
    for (i, &x) in source.insert_scores.iter().enumerate() {
        target[offset + i] = x;
    }
    offset += len;
    // Match scores: one entry per dict-min base pair.
    let len = source.match_scores.len();
    for i in 0..len {
        for j in 0..len {
            let dict_min_match = get_dict_min_pair(&(i, j));
            if (i, j) != dict_min_match {
                continue;
            }
            target[offset] = source.match_scores[i][j];
            offset += 1;
        }
    }
    // Every slot of `target` must have been written exactly once.
    assert!(offset == source.len());
    Array::from(target)
}
5771
5772pub fn scores2bfgs_scores(x: &Scores) -> BfgsScores {
5773 let x: Vec<BfgsScore> = x.to_vec().iter().map(|x| *x as BfgsScore).collect();
5774 BfgsScores::from(x)
5775}
5776
5777pub fn bfgs_scores2scores(x: &BfgsScores) -> Scores {
5778 let x: Vec<Score> = x.to_vec().iter().map(|x| *x as Score).collect();
5779 Scores::from(x)
5780}
5781
5782pub fn get_regularizer(x: usize, y: Score) -> Regularizer {
5783 (x as Score / 2. + GAMMA_DISTRO_ALPHA) / (y / 2. + GAMMA_DISTRO_BETA)
5784}
5785
5786pub fn write_alignfold_scores_trained(alignfold_scores: &AlignfoldScores, enables_randinit: bool) {
5787 let mut writer = BufWriter::new(
5788 File::create(if enables_randinit {
5789 TRAINED_SCORES_FILE_RANDINIT
5790 } else {
5791 TRAINED_SCORES_FILE
5792 })
5793 .unwrap(),
5794 );
5795 let mut buf = format!("use AlignfoldScores;\nimpl AlignfoldScores {{\npub fn load_trained_scores{}() -> AlignfoldScores {{\nAlignfoldScores {{\nhairpin_scores_len: ", if enables_randinit {"_randinit"} else {""});
5796 buf.push_str(&format!(
5797 "{:?},\nbulge_scores_len: ",
5798 &alignfold_scores.hairpin_scores_len
5799 ));
5800 buf.push_str(&format!(
5801 "{:?},\ninterior_scores_len: ",
5802 &alignfold_scores.bulge_scores_len
5803 ));
5804 buf.push_str(&format!(
5805 "{:?},\ninterior_scores_symmetric: ",
5806 &alignfold_scores.interior_scores_len
5807 ));
5808 buf.push_str(&format!(
5809 "{:?},\ninterior_scores_asymmetric: ",
5810 &alignfold_scores.interior_scores_symmetric
5811 ));
5812 buf.push_str(&format!(
5813 "{:?},\nstack_scores: ",
5814 &alignfold_scores.interior_scores_asymmetric
5815 ));
5816 buf.push_str(&format!(
5817 "{:?},\nterminal_mismatch_scores: ",
5818 &alignfold_scores.stack_scores
5819 ));
5820 buf.push_str(&format!(
5821 "{:?},\ndangling_scores_left: ",
5822 &alignfold_scores.terminal_mismatch_scores
5823 ));
5824 buf.push_str(&format!(
5825 "{:?},\ndangling_scores_right: ",
5826 &alignfold_scores.dangling_scores_left
5827 ));
5828 buf.push_str(&format!(
5829 "{:?},\nhelix_close_scores: ",
5830 &alignfold_scores.dangling_scores_right
5831 ));
5832 buf.push_str(&format!(
5833 "{:?},\nbasepair_scores: ",
5834 &alignfold_scores.helix_close_scores
5835 ));
5836 buf.push_str(&format!(
5837 "{:?},\ninterior_scores_explicit: ",
5838 &alignfold_scores.basepair_scores
5839 ));
5840 buf.push_str(&format!(
5841 "{:?},\nbulge_scores_0x1: ",
5842 &alignfold_scores.interior_scores_explicit
5843 ));
5844 buf.push_str(&format!(
5845 "{:?},\ninterior_scores_1x1: ",
5846 &alignfold_scores.bulge_scores_0x1
5847 ));
5848 buf.push_str(&format!(
5849 "{:?},\nmultibranch_score_base: ",
5850 &alignfold_scores.interior_scores_1x1
5851 ));
5852 buf.push_str(&format!(
5853 "{:?},\nmultibranch_score_basepair: ",
5854 alignfold_scores.multibranch_score_base
5855 ));
5856 buf.push_str(&format!(
5857 "{:?},\nmultibranch_score_unpair: ",
5858 alignfold_scores.multibranch_score_basepair
5859 ));
5860 buf.push_str(&format!(
5861 "{:?},\nexternal_score_basepair: ",
5862 alignfold_scores.multibranch_score_unpair
5863 ));
5864 buf.push_str(&format!(
5865 "{:?},\nexternal_score_unpair: ",
5866 alignfold_scores.external_score_basepair
5867 ));
5868 buf.push_str(&format!(
5869 "{:?},\nmatch2match_score: ",
5870 alignfold_scores.external_score_unpair
5871 ));
5872 buf.push_str(&format!(
5873 "{:?},\nmatch2insert_score: ",
5874 alignfold_scores.match2match_score
5875 ));
5876 buf.push_str(&format!(
5877 "{:?},\ninsert_extend_score: ",
5878 alignfold_scores.match2insert_score
5879 ));
5880 buf.push_str(&format!(
5881 "{:?},\ninit_match_score: ",
5882 alignfold_scores.insert_extend_score
5883 ));
5884 buf.push_str(&format!(
5885 "{:?},\ninit_insert_score: ",
5886 alignfold_scores.init_match_score
5887 ));
5888 buf.push_str(&format!(
5889 "{:?},\ninsert_scores: ",
5890 alignfold_scores.init_insert_score
5891 ));
5892 buf.push_str(&format!(
5893 "{:?},\nmatch_scores: ",
5894 alignfold_scores.insert_scores
5895 ));
5896 buf.push_str(&format!(
5897 "{:?},\nhairpin_scores_len_cumulative: ",
5898 alignfold_scores.match_scores
5899 ));
5900 buf.push_str(&format!(
5901 "{:?},\nbulge_scores_len_cumulative: ",
5902 &alignfold_scores.hairpin_scores_len_cumulative
5903 ));
5904 buf.push_str(&format!(
5905 "{:?},\ninterior_scores_len_cumulative: ",
5906 &alignfold_scores.bulge_scores_len_cumulative
5907 ));
5908 buf.push_str(&format!(
5909 "{:?},\ninterior_scores_symmetric_cumulative: ",
5910 &alignfold_scores.interior_scores_len_cumulative
5911 ));
5912 buf.push_str(&format!(
5913 "{:?},\ninterior_scores_asymmetric_cumulative: ",
5914 &alignfold_scores.interior_scores_symmetric_cumulative
5915 ));
5916 buf.push_str(&format!(
5917 "{:?},",
5918 &alignfold_scores.interior_scores_asymmetric_cumulative
5919 ));
5920 buf.push_str(&String::from("\n}\n}\n}"));
5921 let _ = writer.write_all(buf.as_bytes());
5922}
5923
5924pub fn write_logs(x: &Probs, y: &Probs, z: &Path) {
5925 let mut z = BufWriter::new(File::create(z).unwrap());
5926 let mut a = String::new();
5927 for (x, y) in x.iter().zip(y.iter()) {
5928 a.push_str(&format!("{x},{y}\n"));
5929 }
5930 let _ = z.write_all(a.as_bytes());
5931}
5932
5933pub fn copy_alignfold_scores_align(x: &mut AlignScores, y: &AlignfoldScores) {
5934 x.match2match_score = y.match2match_score;
5935 x.match2insert_score = y.match2insert_score;
5936 x.insert_extend_score = y.insert_extend_score;
5937 x.init_match_score = y.init_match_score;
5938 x.init_insert_score = y.init_insert_score;
5939 let len = x.insert_scores.len();
5940 for i in 0..len {
5941 x.insert_scores[i] = y.insert_scores[i];
5942 }
5943 let len = x.match_scores.len();
5944 for i in 0..len {
5945 for j in 0..len {
5946 x.match_scores[i][j] = y.match_scores[i][j];
5947 }
5948 }
5949}
5950
5951pub fn copy_alignfold_scores_fold(
5952 fold_scores: &mut FoldScoreSets,
5953 alignfold_scores: &AlignfoldScores,
5954) {
5955 for (x, &y) in fold_scores
5956 .hairpin_scores_len
5957 .iter_mut()
5958 .zip(alignfold_scores.hairpin_scores_len.iter())
5959 {
5960 *x = y;
5961 }
5962 for (x, &y) in fold_scores
5963 .bulge_scores_len
5964 .iter_mut()
5965 .zip(alignfold_scores.bulge_scores_len.iter())
5966 {
5967 *x = y;
5968 }
5969 for (x, &y) in fold_scores
5970 .interior_scores_len
5971 .iter_mut()
5972 .zip(alignfold_scores.interior_scores_len.iter())
5973 {
5974 *x = y;
5975 }
5976 for (x, &y) in fold_scores
5977 .interior_scores_symmetric
5978 .iter_mut()
5979 .zip(alignfold_scores.interior_scores_symmetric.iter())
5980 {
5981 *x = y;
5982 }
5983 for (x, &y) in fold_scores
5984 .interior_scores_asymmetric
5985 .iter_mut()
5986 .zip(alignfold_scores.interior_scores_asymmetric.iter())
5987 {
5988 *x = y;
5989 }
5990 let len = fold_scores.stack_scores.len();
5991 for i in 0..len {
5992 for j in 0..len {
5993 if !has_canonical_basepair(&(i, j)) {
5994 continue;
5995 }
5996 for k in 0..len {
5997 for l in 0..len {
5998 if !has_canonical_basepair(&(k, l)) {
5999 continue;
6000 }
6001 fold_scores.stack_scores[i][j][k][l] = alignfold_scores.stack_scores[i][j][k][l];
6002 }
6003 }
6004 }
6005 }
6006 let len = fold_scores.terminal_mismatch_scores.len();
6007 for i in 0..len {
6008 for j in 0..len {
6009 if !has_canonical_basepair(&(i, j)) {
6010 continue;
6011 }
6012 for k in 0..len {
6013 for l in 0..len {
6014 fold_scores.terminal_mismatch_scores[i][j][k][l] =
6015 alignfold_scores.terminal_mismatch_scores[i][j][k][l];
6016 }
6017 }
6018 }
6019 }
6020 let len = fold_scores.dangling_scores_left.len();
6021 for i in 0..len {
6022 for j in 0..len {
6023 if !has_canonical_basepair(&(i, j)) {
6024 continue;
6025 }
6026 for k in 0..len {
6027 fold_scores.dangling_scores_left[i][j][k] = alignfold_scores.dangling_scores_left[i][j][k];
6028 }
6029 }
6030 }
6031 let len = fold_scores.dangling_scores_right.len();
6032 for i in 0..len {
6033 for j in 0..len {
6034 if !has_canonical_basepair(&(i, j)) {
6035 continue;
6036 }
6037 for k in 0..len {
6038 fold_scores.dangling_scores_right[i][j][k] =
6039 alignfold_scores.dangling_scores_right[i][j][k];
6040 }
6041 }
6042 }
6043 let len = fold_scores.helix_close_scores.len();
6044 for i in 0..len {
6045 for j in 0..len {
6046 if !has_canonical_basepair(&(i, j)) {
6047 continue;
6048 }
6049 fold_scores.helix_close_scores[i][j] = alignfold_scores.helix_close_scores[i][j];
6050 }
6051 }
6052 let len = fold_scores.basepair_scores.len();
6053 for i in 0..len {
6054 for j in 0..len {
6055 if !has_canonical_basepair(&(i, j)) {
6056 continue;
6057 }
6058 fold_scores.basepair_scores[i][j] = alignfold_scores.basepair_scores[i][j];
6059 }
6060 }
6061 let len = fold_scores.interior_scores_explicit.len();
6062 for i in 0..len {
6063 for j in 0..len {
6064 fold_scores.interior_scores_explicit[i][j] = alignfold_scores.interior_scores_explicit[i][j];
6065 }
6066 }
6067 for (x, &y) in fold_scores
6068 .bulge_scores_0x1
6069 .iter_mut()
6070 .zip(alignfold_scores.bulge_scores_0x1.iter())
6071 {
6072 *x = y;
6073 }
6074 let len = fold_scores.interior_scores_1x1.len();
6075 for i in 0..len {
6076 for j in 0..len {
6077 fold_scores.interior_scores_1x1[i][j] = alignfold_scores.interior_scores_1x1[i][j];
6078 }
6079 }
6080 fold_scores.multibranch_score_base = alignfold_scores.multibranch_score_base;
6081 fold_scores.multibranch_score_basepair = alignfold_scores.multibranch_score_basepair;
6082 fold_scores.multibranch_score_unpair = alignfold_scores.multibranch_score_unpair;
6083 fold_scores.external_score_basepair = alignfold_scores.external_score_basepair;
6084 fold_scores.external_score_unpair = alignfold_scores.external_score_unpair;
6085 fold_scores.accumulate();
6086}
6087
6088pub fn print_train_info(alignfold_scores: &AlignfoldScores) {
6089 let mut num_groups = 0;
6090 println!("Training the parameter groups below");
6091 println!("-----------------------------------");
6092 println!("Groups from the CONTRAfold model...");
6093 println!(
6094 "Hairpin loop length (group size {})",
6095 alignfold_scores.hairpin_scores_len.len()
6096 );
6097 num_groups += 1;
6098 println!(
6099 "Bulge loop length (group size {})",
6100 alignfold_scores.bulge_scores_len.len()
6101 );
6102 num_groups += 1;
6103 println!(
6104 "Interior loop length (group size {})",
6105 alignfold_scores.interior_scores_len.len()
6106 );
6107 num_groups += 1;
6108 println!(
6109 "Interior loop length symmetric (group size {})",
6110 alignfold_scores.interior_scores_symmetric.len()
6111 );
6112 num_groups += 1;
6113 println!(
6114 "Interior loop length asymmetric (group size {})",
6115 alignfold_scores.interior_scores_asymmetric.len()
6116 );
6117 num_groups += 1;
6118 let mut group_size = 0;
6119 let len = alignfold_scores.stack_scores.len();
6120 for i in 0..len {
6121 for j in 0..len {
6122 if !has_canonical_basepair(&(i, j)) {
6123 continue;
6124 }
6125 for k in 0..len {
6126 for l in 0..len {
6127 if !has_canonical_basepair(&(k, l)) {
6128 continue;
6129 }
6130 let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
6131 if ((i, j), (k, l)) != dict_min_stack {
6132 continue;
6133 }
6134 group_size += 1;
6135 }
6136 }
6137 }
6138 }
6139 println!("Stacking (group size {group_size})");
6140 num_groups += 1;
6141 let mut group_size = 0;
6142 let len = alignfold_scores.terminal_mismatch_scores.len();
6143 for i in 0..len {
6144 for j in 0..len {
6145 if !has_canonical_basepair(&(i, j)) {
6146 continue;
6147 }
6148 for _ in 0..len {
6149 for _ in 0..len {
6150 group_size += 1;
6151 }
6152 }
6153 }
6154 }
6155 println!("Terminal mismatch (group size {group_size})");
6156 num_groups += 1;
6157 let mut group_size = 0;
6158 let len = alignfold_scores.dangling_scores_left.len();
6159 for i in 0..len {
6160 for j in 0..len {
6161 if !has_canonical_basepair(&(i, j)) {
6162 continue;
6163 }
6164 for _ in 0..len {
6165 group_size += 1;
6166 }
6167 }
6168 }
6169 let len = alignfold_scores.dangling_scores_right.len();
6170 for i in 0..len {
6171 for j in 0..len {
6172 if !has_canonical_basepair(&(i, j)) {
6173 continue;
6174 }
6175 for _ in 0..len {
6176 group_size += 1;
6177 }
6178 }
6179 }
6180 println!("Dangling (group size {group_size})");
6181 num_groups += 1;
6182 let mut group_size = 0;
6183 let len = alignfold_scores.helix_close_scores.len();
6184 for i in 0..len {
6185 for j in 0..len {
6186 if !has_canonical_basepair(&(i, j)) {
6187 continue;
6188 }
6189 group_size += 1;
6190 }
6191 }
6192 println!("Helix end (group size {group_size})");
6193 num_groups += 1;
6194 let mut group_size = 0;
6195 let len = alignfold_scores.basepair_scores.len();
6196 for i in 0..len {
6197 for j in 0..len {
6198 if !has_canonical_basepair(&(i, j)) {
6199 continue;
6200 }
6201 let dict_min_basepair = get_dict_min_pair(&(i, j));
6202 if (i, j) != dict_min_basepair {
6203 continue;
6204 }
6205 group_size += 1;
6206 }
6207 }
6208 println!("Base-pairing (group size {group_size})");
6209 num_groups += 1;
6210 let mut group_size = 0;
6211 let len = alignfold_scores.interior_scores_explicit.len();
6212 for i in 0..len {
6213 for j in 0..len {
6214 let dict_min_len_pair = get_dict_min_pair(&(i, j));
6215 if (i, j) != dict_min_len_pair {
6216 continue;
6217 }
6218 group_size += 1;
6219 }
6220 }
6221 println!("Interior loop length explicit (group size {group_size})");
6222 num_groups += 1;
6223 println!(
6224 "Bulge loop length 0x1 (group size {})",
6225 alignfold_scores.bulge_scores_0x1.len()
6226 );
6227 num_groups += 1;
6228 let mut group_size = 0;
6229 let len = alignfold_scores.interior_scores_1x1.len();
6230 for i in 0..len {
6231 for j in 0..len {
6232 let dict_min_basepair = get_dict_min_pair(&(i, j));
6233 if (i, j) != dict_min_basepair {
6234 continue;
6235 }
6236 group_size += 1;
6237 }
6238 }
6239 println!("Interior loop length 1x1 (group size {group_size})");
6240 num_groups += 1;
6241 println!("Multi-loop length (group size {GROUP_SIZE_MULTIBRANCH})");
6242 num_groups += 1;
6243 println!("External-loop length (group size {GROUP_SIZE_EXTERNAL})");
6244 num_groups += 1;
6245 println!("-----------------------------------");
6246 println!("Groups from the CONTRAlign model...");
6247 println!("Match transition (group size {GROUP_SIZE_MATCH_TRANSITION})");
6248 num_groups += 1;
6249 println!("Insert transition (group size {GROUP_SIZE_INSERT_TRANSITION})");
6250 num_groups += 1;
6251 println!(
6252 "Insert emission (group size {})",
6253 alignfold_scores.insert_scores.len()
6254 );
6255 num_groups += 1;
6256 let mut group_size = 0;
6257 let len = alignfold_scores.match_scores.len();
6258 for i in 0..len {
6259 for j in 0..len {
6260 let dict_min_match = get_dict_min_pair(&(i, j));
6261 if (i, j) != dict_min_match {
6262 continue;
6263 }
6264 group_size += 1;
6265 }
6266 }
6267 println!("Match emission (group size {group_size})");
6268 num_groups += 1;
6269 println!("-----------------------------------");
6270 println!(
6271 "Total # scoring parameters (from {} groups) to be trained: {}",
6272 num_groups,
6273 alignfold_scores.len()
6274 );
6275}