consprob_trained/
lib.rs

1extern crate bio;
2extern crate consprob;
3extern crate my_bfgs as bfgs;
4extern crate ndarray_rand;
5extern crate rand;
6pub mod trained_alignfold_scores;
7pub mod trained_alignfold_scores_randinit;
8
9pub use bfgs::bfgs;
10pub use bio::io::fasta::*;
11pub use bio::utils::*;
12pub use consprob::*;
13pub use ndarray_rand::rand_distr::{Distribution, Normal};
14pub use ndarray_rand::RandomExt;
15pub use rand::thread_rng;
16pub use std::f32::INFINITY;
17pub use std::fs::{read_dir, DirEntry};
18pub use std::io::stdout;
19
/// Sparse set of position pairs; used for `PairAlignfold::matched_pos_pairs`.
20pub type SparsePosMat<T> = HashSet<PosPair<T>>;
/// Owned per-sequence folding scores for the two sequences of an example.
21pub type FoldScoresPairTrained<T> = (FoldScoresTrained<T>, FoldScoresTrained<T>);
/// Borrowed counterpart of `FoldScoresPairTrained`.
22pub type RefFoldScoresPair<'a, T> = (&'a FoldScoresTrained<T>, &'a FoldScoresTrained<T>);
/// One regularization coefficient per free parameter; rebuilt by
/// `AlignfoldScores::update_regularizers`.
23pub type Regularizers = Array1<Regularizer>;
24pub type Regularizer = Score;
/// Parameter / gradient vectors in the `f64` precision handed to `bfgs`.
25pub type BfgsScores = Array1<BfgsScore>;
26pub type BfgsScore = f64;
/// Vector of model-precision scores; e.g. the gradient returned by `AlignfoldScores::get_grad`.
27pub type Scores = Array1<Score>;
28pub type TrainData<T> = Vec<TrainDatum<T>>;
/// A pair of sequences; used for `TrainDatum::seq_pair` and `TrainDatum::seq_pair_gapped`.
29pub type RealSeqPair = (Seq, Seq);
/// NOTE(review): not used in the visible code; presumably maps a closing position pair to the
/// position pairs it encloses — confirm at the use site.
30pub type LoopStruct = HashMap<(usize, usize), Vec<(usize, usize)>>;
31
/// Folding scores precomputed for the candidate base pairs of one sequence.
///
/// Filled by `FoldScoresTrained::set_curr_scores`; every table is keyed by candidate base pairs
/// (the keys of the base-pairing probability matrix passed there).
32#[derive(Clone)]
33pub struct FoldScoresTrained<T> {
  /// Hairpin score closed by each candidate pair that encloses at most `MAX_LOOP_LEN` positions.
34  pub hairpin_scores: SparseScoreMat<T>,
  /// Two-loop score for an outer pair enclosing an inner pair, keyed by
  /// `(outer.0, outer.1, inner.0, inner.1)`; only loops with at most `MAX_LOOP_LEN` unpaired
  /// positions are stored.
35  pub twoloop_scores: ScoreMat4d<T>,
  /// Multibranch base + multibranch base-pair score + junction score at `(i, j)`.
36  pub multibranch_close_scores: SparseScoreMat<T>,
  /// Junction score at `(j, i)` + base-pair score + multibranch base-pair score.
37  pub multibranch_accessible_scores: SparseScoreMat<T>,
  /// Junction score at `(j, i)` + base-pair score + external base-pair score.
38  pub external_accessible_scores: SparseScoreMat<T>,
39}
40
/// One training example.
///
/// `alignfold_counts_observed` and `alignfold_counts_expected` are compared parameter by
/// parameter in `AlignfoldScores::get_grad`; their difference (observed minus expected) enters
/// the gradient. NOTE(review): the remaining fields look like cached per-example inputs
/// (sparse pairing/matching candidates and their scores); how they are filled is not visible
/// here — confirm at the construction site.
41#[derive(Clone)]
42pub struct TrainDatum<T> {
  /// The two sequences of the example.
43  pub seq_pair: RealSeqPair,
  /// NOTE(review): presumably the same pair with alignment gaps inserted — confirm.
44  pub seq_pair_gapped: RealSeqPair,
45  pub alignfold_counts_observed: AlignfoldScores,
46  pub alignfold_counts_expected: AlignfoldScores,
47  pub basepair_probs_pair: SparseProbMatPair<T>,
48  pub max_basepair_span_pair: (T, T),
49  pub global_sum: Prob,
50  pub forward_pos_pairs: PosPairMatSet<T>,
51  pub backward_pos_pairs: PosPairMatSet<T>,
52  pub pos_quads_hashed_lens: PosQuadsHashedLens<T>,
53  pub matchable_poss: SparsePosSets<T>,
54  pub matchable_poss2: SparsePosSets<T>,
  /// Folding scores of the first and second sequence, in that order.
55  pub fold_scores_pair: FoldScoresPairTrained<T>,
56  pub match_probs: SparseProbMat<T>,
  /// NOTE(review): presumably the reference alignment of `seq_pair` — confirm.
57  pub alignfold: PairAlignfold<T>,
58  pub accuracy: Score,
59}
60
/// A pairwise alignment expressed as matched, inserted and deleted positions.
61#[derive(Clone)]
62pub struct PairAlignfold<T> {
  /// Position pairs aligned to each other.
63  pub matched_pos_pairs: SparsePosMat<T>,
  /// NOTE(review): presumably positions of one sequence aligned to gaps — confirm which side.
64  pub inserted_poss: SparsePoss<T>,
  /// NOTE(review): presumably positions of the other sequence aligned to gaps — confirm.
65  pub deleted_poss: SparsePoss<T>,
66}
67
/// Parameters of the joint alignment-and-folding model.
///
/// The same layout is also reused in this file for per-example feature counts
/// (`TrainDatum::alignfold_counts_*`) and for gradients (`get_grad` starts from
/// `AlignfoldScores::new(0.)`). Several tables are tied: only the entries counted by `len`
/// (canonical base pairs; dictionary-minimal representatives for symmetric tables) are free,
/// and `mirror` copies them into the remaining entries. The `*_cumulative` fields hold running
/// sums of the length tables and are refreshed by `accumulate`.
68#[derive(Clone, Debug)]
69pub struct AlignfoldScores {
70  // The CONTRAfold model.
71  pub hairpin_scores_len: HairpinScoresLen,
72  pub bulge_scores_len: BulgeScoresLen,
73  pub interior_scores_len: InteriorScoresLen,
74  pub interior_scores_symmetric: InteriorScoresSymmetric,
75  pub interior_scores_asymmetric: InteriorScoresAsymmetric,
  // Stacking scores; only canonical outer/inner pairs are used, tied by `get_dict_min_stack`.
76  pub stack_scores: StackScores,
77  pub terminal_mismatch_scores: TerminalMismatchScores,
  // Left and right dangling scores share one regularization group.
78  pub dangling_scores_left: DanglingScores,
79  pub dangling_scores_right: DanglingScores,
80  pub helix_close_scores: HelixCloseScores,
81  pub basepair_scores: BasepairScores,
82  pub interior_scores_explicit: InteriorScoresExplicit,
83  pub bulge_scores_0x1: BulgeScores0x1,
84  pub interior_scores_1x1: InteriorScores1x1Contra,
85  pub multibranch_score_base: Score,
86  pub multibranch_score_basepair: Score,
87  pub multibranch_score_unpair: Score,
88  pub external_score_basepair: Score,
89  pub external_score_unpair: Score,
90  // The CONTRAlign model.
91  pub match2match_score: Score,
92  pub match2insert_score: Score,
93  pub insert_extend_score: Score,
94  pub init_match_score: Score,
95  pub init_insert_score: Score,
96  pub insert_scores: InsertScores,
97  pub match_scores: MatchScores,
98  // The cumulative parameters of the CONTRAfold model.
  // Entry i is the sum of entries 0..=i of the matching non-cumulative table (see `accumulate`).
99  pub hairpin_scores_len_cumulative: HairpinScoresLen,
100  pub bulge_scores_len_cumulative: BulgeScoresLen,
101  pub interior_scores_len_cumulative: InteriorScoresLen,
102  pub interior_scores_symmetric_cumulative: InteriorScoresSymmetric,
103  pub interior_scores_asymmetric_cumulative: InteriorScoresAsymmetric,
104}
/// Dense 2-D score table stored as a `Vec` of rows.
105pub type ScoreMat = Vec<Vec<Score>>;
/// Precomputed insertion scores over position ranges.
///
/// NOTE(review): presumably indexed by the start/end of an inserted stretch, with the
/// `_external` / `_multibranch` variants for insertions inside external / multibranch loops
/// and the `*2` fields for the second sequence — confirm against the code that fills them.
106pub struct RangeInsertScores {
107  pub insert_scores: ScoreMat,
108  pub insert_scores_external: ScoreMat,
109  pub insert_scores_multibranch: ScoreMat,
110  pub insert_scores2: ScoreMat,
111  pub insert_scores_external2: ScoreMat,
112  pub insert_scores_multibranch2: ScoreMat,
113}
114
/// Positional argument bundle for the alignfold-probability getter (consumer not visible here).
///
/// Elements are identified only by position; see the consumer's destructuring for their
/// meaning. The `&'a mut AlignfoldScores` slot is the only mutable element — presumably the
/// output that receives expected counts; confirm at the use site.
115pub type InputsAlignfoldProbsGetter<'a, T> = (
116  &'a SeqPair<'a>,
117  &'a AlignfoldScores,
118  &'a PosPair<T>,
119  &'a SparseProbMat<T>,
120  &'a AlignfoldSums<T>,
121  bool,
122  Sum,
123  bool,
124  &'a mut AlignfoldScores,
125  &'a PosQuadsHashedLens<T>,
126  &'a RefFoldScoresPair<'a, T>,
127  bool,
128  &'a PosPairMatSet<T>,
129  &'a PosPairMatSet<T>,
130  &'a RangeInsertScores,
131  &'a SparsePosSets<T>,
132  &'a SparsePosSets<T>,
133);
134
/// Positional argument bundle for the core probability computation (consumer not visible here).
///
/// The `&'a mut AlignfoldScores` slot is the only mutable element — presumably an output
/// accumulator; confirm at the destructuring site.
135pub type InputsConsprobCore<'a, T> = (
136  &'a SeqPair<'a>,
137  &'a AlignfoldScores,
138  &'a PosPair<T>,
139  &'a SparseProbMat<T>,
140  bool,
141  bool,
142  &'a mut AlignfoldScores,
143  &'a PosPairMatSet<T>,
144  &'a PosPairMatSet<T>,
145  &'a PosQuadsHashedLens<T>,
146  &'a RefFoldScoresPair<'a, T>,
147  bool,
148  &'a SparsePosSets<T>,
149  &'a SparsePosSets<T>,
150);
151
/// Positional argument bundle for the two-loop sums getter (consumer not visible here).
///
/// All elements are shared borrows or flags; the consumer cannot mutate them through this tuple.
152pub type Inputs2loopSumsGetter<'a, T> = (
153  &'a SeqPair<'a>,
154  &'a AlignfoldScores,
155  &'a SparseProbMat<T>,
156  &'a PosQuad<T>,
157  &'a AlignfoldSums<T>,
158  bool,
159  &'a PosPairMatSet<T>,
160  &'a RefFoldScoresPair<'a, T>,
161  &'a RangeInsertScores,
162  &'a SparsePosSets<T>,
163  &'a SparsePosSets<T>,
164);
165
/// Positional argument bundle for the loop sums getter (consumer not visible here).
///
/// The `&'a mut AlignfoldSums<T>` slot is the only mutable element — presumably where the
/// computed sums are written; confirm at the destructuring site.
166pub type InputsLoopSumsGetter<'a, T> = (
167  &'a SeqPair<'a>,
168  &'a AlignfoldScores,
169  &'a SparseProbMat<T>,
170  &'a PosQuad<T>,
171  &'a mut AlignfoldSums<T>,
172  bool,
173  &'a PosPairMatSet<T>,
174  &'a RangeInsertScores,
175  &'a SparsePosSets<T>,
176  &'a SparsePosSets<T>,
177);
178
/// Positional argument bundle for the inside sums getter (consumer not visible here).
///
/// All elements are shared borrows or flags; element meanings are positional and defined by
/// the consumer's destructuring.
179pub type InputsInsideSumsGetter<'a, T> = (
180  &'a SeqPair<'a>,
181  &'a AlignfoldScores,
182  &'a PosPair<T>,
183  &'a SparseProbMat<T>,
184  bool,
185  &'a PosPairMatSet<T>,
186  &'a PosPairMatSet<T>,
187  &'a PosQuadsHashedLens<T>,
188  &'a RefFoldScoresPair<'a, T>,
189  &'a RangeInsertScores,
190  &'a SparsePosSets<T>,
191  &'a SparsePosSets<T>,
192);
193
194impl<T: HashIndex> Default for FoldScoresTrained<T> {
195  fn default() -> Self {
196    Self::new()
197  }
198}
199
200impl<T: HashIndex> FoldScoresTrained<T> {
201  pub fn new() -> FoldScoresTrained<T> {
202    FoldScoresTrained {
203      hairpin_scores: SparseScoreMat::<T>::default(),
204      twoloop_scores: ScoreMat4d::<T>::default(),
205      multibranch_close_scores: SparseScoreMat::<T>::default(),
206      multibranch_accessible_scores: SparseScoreMat::<T>::default(),
207      external_accessible_scores: SparseScoreMat::<T>::default(),
208    }
209  }
210
211  pub fn set_curr_scores(
212    alignfold_scores: &AlignfoldScores,
213    seq: SeqSlice,
214    basepair_probs: &SparseProbMat<T>,
215  ) -> FoldScoresTrained<T> {
216    let mut fold_scores = FoldScoresTrained::<T>::new();
217    for pos_pair in basepair_probs.keys() {
218      let long_pos_pair = (
219        pos_pair.0.to_usize().unwrap(),
220        pos_pair.1.to_usize().unwrap(),
221      );
222      if long_pos_pair.1 - long_pos_pair.0 - 1 <= MAX_LOOP_LEN {
223        let x = get_hairpin_score(alignfold_scores, seq, &long_pos_pair);
224        fold_scores.hairpin_scores.insert(*pos_pair, x);
225      }
226      let multibranch_close_score = alignfold_scores.multibranch_score_base
227        + alignfold_scores.multibranch_score_basepair
228        + get_junction_score(alignfold_scores, seq, &long_pos_pair);
229      fold_scores
230        .multibranch_close_scores
231        .insert(*pos_pair, multibranch_close_score);
232      let basepair = (seq[long_pos_pair.0], seq[long_pos_pair.1]);
233      let junction_score =
234        get_junction_score(alignfold_scores, seq, &(long_pos_pair.1, long_pos_pair.0))
235          + alignfold_scores.basepair_scores[basepair.0][basepair.1];
236      let multibranch_accessible_score =
237        junction_score + alignfold_scores.multibranch_score_basepair;
238      fold_scores
239        .multibranch_accessible_scores
240        .insert(*pos_pair, multibranch_accessible_score);
241      let external_accessible_score = junction_score + alignfold_scores.external_score_basepair;
242      fold_scores
243        .external_accessible_scores
244        .insert(*pos_pair, external_accessible_score);
245      for x in basepair_probs.keys() {
246        if !(x.0 < pos_pair.0 && pos_pair.1 < x.1) {
247          continue;
248        }
249        let y = (x.0.to_usize().unwrap(), x.1.to_usize().unwrap());
250        if long_pos_pair.0 - y.0 - 1 + y.1 - long_pos_pair.1 - 1 > MAX_LOOP_LEN {
251          continue;
252        }
253        let y = get_twoloop_score(alignfold_scores, seq, &y, &long_pos_pair);
254        fold_scores
255          .twoloop_scores
256          .insert((x.0, x.1, pos_pair.0, pos_pair.1), y);
257      }
258    }
259    fold_scores
260  }
261}
262
263impl AlignfoldScores {
264  pub fn new(init_val: Score) -> AlignfoldScores {
265    let init_vals = [init_val; NUM_BASES];
266    let mat_2d = [init_vals; NUM_BASES];
267    let mat_3d = [mat_2d; NUM_BASES];
268    let mat_4d = [mat_3d; NUM_BASES];
269    AlignfoldScores {
270      // The CONTRAfold model.
271      hairpin_scores_len: [init_val; MAX_LOOP_LEN + 1],
272      bulge_scores_len: [init_val; MAX_LOOP_LEN],
273      interior_scores_len: [init_val; MAX_LOOP_LEN - 1],
274      interior_scores_symmetric: [init_val; MAX_INTERIOR_SYMMETRIC],
275      interior_scores_asymmetric: [init_val; MAX_INTERIOR_ASYMMETRIC],
276      stack_scores: mat_4d,
277      terminal_mismatch_scores: mat_4d,
278      dangling_scores_left: mat_3d,
279      dangling_scores_right: mat_3d,
280      helix_close_scores: mat_2d,
281      basepair_scores: mat_2d,
282      interior_scores_explicit: [[init_val; MAX_INTERIOR_EXPLICIT]; MAX_INTERIOR_EXPLICIT],
283      bulge_scores_0x1: init_vals,
284      interior_scores_1x1: mat_2d,
285      multibranch_score_base: init_val,
286      multibranch_score_basepair: init_val,
287      multibranch_score_unpair: init_val,
288      external_score_basepair: init_val,
289      external_score_unpair: init_val,
290      // The CONTRAlign model.
291      match2match_score: init_val,
292      match2insert_score: init_val,
293      init_match_score: init_val,
294      insert_extend_score: init_val,
295      init_insert_score: init_val,
296      insert_scores: init_vals,
297      match_scores: mat_2d,
298      // The cumulative parameters of the CONTRAfold model.
299      hairpin_scores_len_cumulative: [init_val; MAX_LOOP_LEN + 1],
300      bulge_scores_len_cumulative: [init_val; MAX_LOOP_LEN],
301      interior_scores_len_cumulative: [init_val; MAX_LOOP_LEN - 1],
302      interior_scores_symmetric_cumulative: [init_val; MAX_INTERIOR_SYMMETRIC],
303      interior_scores_asymmetric_cumulative: [init_val; MAX_INTERIOR_ASYMMETRIC],
304    }
305  }
306
307  pub fn len(&self) -> usize {
308    let mut sum = 0;
309    sum += self.hairpin_scores_len.len();
310    sum += self.bulge_scores_len.len();
311    sum += self.interior_scores_len.len();
312    sum += self.interior_scores_symmetric.len();
313    sum += self.interior_scores_asymmetric.len();
314    let len = self.stack_scores.len();
315    for i in 0..len {
316      for j in 0..len {
317        if !has_canonical_basepair(&(i, j)) {
318          continue;
319        }
320        for k in 0..len {
321          for l in 0..len {
322            if !has_canonical_basepair(&(k, l)) {
323              continue;
324            }
325            let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
326            if ((i, j), (k, l)) != dict_min_stack {
327              continue;
328            }
329            sum += 1;
330          }
331        }
332      }
333    }
334    let len = self.terminal_mismatch_scores.len();
335    for i in 0..len {
336      for j in 0..len {
337        if !has_canonical_basepair(&(i, j)) {
338          continue;
339        }
340        for _ in 0..len {
341          for _ in 0..len {
342            sum += 1;
343          }
344        }
345      }
346    }
347    let len = self.dangling_scores_left.len();
348    for i in 0..len {
349      for j in 0..len {
350        if !has_canonical_basepair(&(i, j)) {
351          continue;
352        }
353        for _ in 0..len {
354          sum += 1;
355        }
356      }
357    }
358    let len = self.dangling_scores_right.len();
359    for i in 0..len {
360      for j in 0..len {
361        if !has_canonical_basepair(&(i, j)) {
362          continue;
363        }
364        for _ in 0..len {
365          sum += 1;
366        }
367      }
368    }
369    let len = self.helix_close_scores.len();
370    for i in 0..len {
371      for j in 0..len {
372        if !has_canonical_basepair(&(i, j)) {
373          continue;
374        }
375        sum += 1;
376      }
377    }
378    let len = self.basepair_scores.len();
379    for i in 0..len {
380      for j in 0..len {
381        if !has_canonical_basepair(&(i, j)) {
382          continue;
383        }
384        let dict_min_basepair = get_dict_min_pair(&(i, j));
385        if (i, j) != dict_min_basepair {
386          continue;
387        }
388        sum += 1;
389      }
390    }
391    let len = self.interior_scores_explicit.len();
392    for i in 0..len {
393      for j in 0..len {
394        let dict_min_len_pair = get_dict_min_pair(&(i, j));
395        if (i, j) != dict_min_len_pair {
396          continue;
397        }
398        sum += 1;
399      }
400    }
401    sum += self.bulge_scores_0x1.len();
402    let len = self.interior_scores_1x1.len();
403    for i in 0..len {
404      for j in 0..len {
405        let dict_min_basepair = get_dict_min_pair(&(i, j));
406        if (i, j) != dict_min_basepair {
407          continue;
408        }
409        sum += 1;
410      }
411    }
412    sum += GROUP_SIZE_MULTIBRANCH;
413    sum += GROUP_SIZE_EXTERNAL;
414    sum += GROUP_SIZE_MATCH_TRANSITION;
415    sum += GROUP_SIZE_INSERT_TRANSITION;
416    sum += self.insert_scores.len();
417    let len = self.match_scores.len();
418    for i in 0..len {
419      for j in 0..len {
420        let dict_min_match = get_dict_min_pair(&(i, j));
421        if (i, j) != dict_min_match {
422          continue;
423        }
424        sum += 1;
425      }
426    }
427    sum
428  }
429
430  pub fn is_empty(&self) -> bool {
431    self.len() == 0
432  }
433
434  pub fn update_regularizers(&self, regularizers: &mut Regularizers) {
435    let mut regularizers2 = vec![0.; regularizers.len()];
436    let mut offset = 0;
437    let len = self.hairpin_scores_len.len();
438    let group_size = len;
439    let mut squared_sum = 0.;
440    for i in 0..len {
441      let x = self.hairpin_scores_len[i];
442      squared_sum += x * x;
443    }
444    let regularizer = get_regularizer(group_size, squared_sum);
445    for i in 0..len {
446      regularizers2[offset + i] = regularizer;
447    }
448    offset += group_size;
449    let len = self.bulge_scores_len.len();
450    let group_size = len;
451    let mut squared_sum = 0.;
452    for i in 0..len {
453      let x = self.bulge_scores_len[i];
454      squared_sum += x * x;
455    }
456    let regularizer = get_regularizer(group_size, squared_sum);
457    for i in 0..len {
458      regularizers2[offset + i] = regularizer;
459    }
460    offset += group_size;
461    let len = self.interior_scores_len.len();
462    let group_size = len;
463    let mut squared_sum = 0.;
464    for i in 0..len {
465      let x = self.interior_scores_len[i];
466      squared_sum += x * x;
467    }
468    let regularizer = get_regularizer(group_size, squared_sum);
469    for i in 0..len {
470      regularizers2[offset + i] = regularizer;
471    }
472    offset += group_size;
473    let len = self.interior_scores_symmetric.len();
474    let group_size = len;
475    let mut squared_sum = 0.;
476    for i in 0..len {
477      let x = self.interior_scores_symmetric[i];
478      squared_sum += x * x;
479    }
480    let regularizer = get_regularizer(group_size, squared_sum);
481    for i in 0..len {
482      regularizers2[offset + i] = regularizer;
483    }
484    offset += group_size;
485    let len = self.interior_scores_asymmetric.len();
486    let group_size = len;
487    let mut squared_sum = 0.;
488    for i in 0..len {
489      let x = self.interior_scores_asymmetric[i];
490      squared_sum += x * x;
491    }
492    let regularizer = get_regularizer(group_size, squared_sum);
493    for i in 0..len {
494      regularizers2[offset + i] = regularizer;
495    }
496    offset += group_size;
497    let len = self.stack_scores.len();
498    let mut effective_group_size = 0;
499    let mut squared_sum = 0.;
500    for i in 0..len {
501      for j in 0..len {
502        if !has_canonical_basepair(&(i, j)) {
503          continue;
504        }
505        for k in 0..len {
506          for l in 0..len {
507            if !has_canonical_basepair(&(k, l)) {
508              continue;
509            }
510            let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
511            if ((i, j), (k, l)) != dict_min_stack {
512              continue;
513            }
514            let x = self.stack_scores[i][j][k][l];
515            squared_sum += x * x;
516            effective_group_size += 1;
517          }
518        }
519      }
520    }
521    let regularizer = get_regularizer(effective_group_size, squared_sum);
522    for i in 0..len {
523      for j in 0..len {
524        if !has_canonical_basepair(&(i, j)) {
525          continue;
526        }
527        for k in 0..len {
528          for l in 0..len {
529            if !has_canonical_basepair(&(k, l)) {
530              continue;
531            }
532            let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
533            if ((i, j), (k, l)) != dict_min_stack {
534              continue;
535            }
536            regularizers2[offset] = regularizer;
537            offset += 1;
538          }
539        }
540      }
541    }
542    let len = self.terminal_mismatch_scores.len();
543    let mut effective_group_size = 0;
544    let mut squared_sum = 0.;
545    for i in 0..len {
546      for j in 0..len {
547        if !has_canonical_basepair(&(i, j)) {
548          continue;
549        }
550        for k in 0..len {
551          for l in 0..len {
552            let x = self.terminal_mismatch_scores[i][j][k][l];
553            squared_sum += x * x;
554            effective_group_size += 1;
555          }
556        }
557      }
558    }
559    let regularizer = get_regularizer(effective_group_size, squared_sum);
560    for i in 0..len {
561      for j in 0..len {
562        if !has_canonical_basepair(&(i, j)) {
563          continue;
564        }
565        for _ in 0..len {
566          for _ in 0..len {
567            regularizers2[offset] = regularizer;
568            offset += 1;
569          }
570        }
571      }
572    }
573    let len = self.dangling_scores_left.len();
574    let mut effective_group_size = 0;
575    let mut squared_sum = 0.;
576    for i in 0..len {
577      for j in 0..len {
578        if !has_canonical_basepair(&(i, j)) {
579          continue;
580        }
581        for k in 0..len {
582          let x = self.dangling_scores_left[i][j][k];
583          squared_sum += x * x;
584          effective_group_size += 1;
585        }
586      }
587    }
588    for i in 0..len {
589      for j in 0..len {
590        if !has_canonical_basepair(&(i, j)) {
591          continue;
592        }
593        for k in 0..len {
594          let x = self.dangling_scores_right[i][j][k];
595          squared_sum += x * x;
596          effective_group_size += 1;
597        }
598      }
599    }
600    let regularizer = get_regularizer(effective_group_size, squared_sum);
601    for i in 0..len {
602      for j in 0..len {
603        if !has_canonical_basepair(&(i, j)) {
604          continue;
605        }
606        for _ in 0..len {
607          regularizers2[offset] = regularizer;
608          offset += 1;
609        }
610      }
611    }
612    for i in 0..len {
613      for j in 0..len {
614        if !has_canonical_basepair(&(i, j)) {
615          continue;
616        }
617        for _ in 0..len {
618          regularizers2[offset] = regularizer;
619          offset += 1;
620        }
621      }
622    }
623    let len = self.helix_close_scores.len();
624    let mut effective_group_size = 0;
625    let mut squared_sum = 0.;
626    for i in 0..len {
627      for j in 0..len {
628        if !has_canonical_basepair(&(i, j)) {
629          continue;
630        }
631        let x = self.helix_close_scores[i][j];
632        squared_sum += x * x;
633        effective_group_size += 1;
634      }
635    }
636    let regularizer = get_regularizer(effective_group_size, squared_sum);
637    for i in 0..len {
638      for j in 0..len {
639        if !has_canonical_basepair(&(i, j)) {
640          continue;
641        }
642        regularizers2[offset] = regularizer;
643        offset += 1;
644      }
645    }
646    let len = self.basepair_scores.len();
647    let mut effective_group_size = 0;
648    let mut squared_sum = 0.;
649    for i in 0..len {
650      for j in 0..len {
651        if !has_canonical_basepair(&(i, j)) {
652          continue;
653        }
654        let dict_min_basepair = get_dict_min_pair(&(i, j));
655        if (i, j) != dict_min_basepair {
656          continue;
657        }
658        let x = self.basepair_scores[i][j];
659        squared_sum += x * x;
660        effective_group_size += 1;
661      }
662    }
663    let regularizer = get_regularizer(effective_group_size, squared_sum);
664    for i in 0..len {
665      for j in 0..len {
666        if !has_canonical_basepair(&(i, j)) {
667          continue;
668        }
669        let dict_min_basepair = get_dict_min_pair(&(i, j));
670        if (i, j) != dict_min_basepair {
671          continue;
672        }
673        regularizers2[offset] = regularizer;
674        offset += 1;
675      }
676    }
677    let len = self.interior_scores_explicit.len();
678    let mut effective_group_size = 0;
679    let mut squared_sum = 0.;
680    for i in 0..len {
681      for j in 0..len {
682        let dict_min_len_pair = get_dict_min_pair(&(i, j));
683        if (i, j) != dict_min_len_pair {
684          continue;
685        }
686        let x = self.interior_scores_explicit[i][j];
687        squared_sum += x * x;
688        effective_group_size += 1;
689      }
690    }
691    let regularizer = get_regularizer(effective_group_size, squared_sum);
692    for i in 0..len {
693      for j in 0..len {
694        let dict_min_len_pair = get_dict_min_pair(&(i, j));
695        if (i, j) != dict_min_len_pair {
696          continue;
697        }
698        regularizers2[offset] = regularizer;
699        offset += 1;
700      }
701    }
702    let len = self.bulge_scores_0x1.len();
703    let group_size = len;
704    let mut squared_sum = 0.;
705    for i in 0..len {
706      let x = self.bulge_scores_0x1[i];
707      squared_sum += x * x;
708    }
709    let regularizer = get_regularizer(group_size, squared_sum);
710    for i in 0..len {
711      regularizers2[offset + i] = regularizer;
712    }
713    offset += group_size;
714    let len = self.interior_scores_1x1.len();
715    let mut effective_group_size = 0;
716    let mut squared_sum = 0.;
717    for i in 0..len {
718      for j in 0..len {
719        let dict_min_basepair = get_dict_min_pair(&(i, j));
720        if (i, j) != dict_min_basepair {
721          continue;
722        }
723        let x = self.interior_scores_1x1[i][j];
724        squared_sum += x * x;
725        effective_group_size += 1;
726      }
727    }
728    let regularizer = get_regularizer(effective_group_size, squared_sum);
729    for i in 0..len {
730      for j in 0..len {
731        let dict_min_basepair = get_dict_min_pair(&(i, j));
732        if (i, j) != dict_min_basepair {
733          continue;
734        }
735        regularizers2[offset] = regularizer;
736        offset += 1;
737      }
738    }
739    let mut squared_sum = 0.;
740    squared_sum += self.multibranch_score_base * self.multibranch_score_base;
741    squared_sum += self.multibranch_score_basepair * self.multibranch_score_basepair;
742    squared_sum += self.multibranch_score_unpair * self.multibranch_score_unpair;
743    let regularizer = get_regularizer(GROUP_SIZE_MULTIBRANCH, squared_sum);
744    regularizers2[offset] = regularizer;
745    offset += 1;
746    regularizers2[offset] = regularizer;
747    offset += 1;
748    regularizers2[offset] = regularizer;
749    offset += 1;
750    let mut squared_sum = 0.;
751    squared_sum += self.external_score_basepair * self.external_score_basepair;
752    squared_sum += self.external_score_unpair * self.external_score_unpair;
753    let regularizer = get_regularizer(GROUP_SIZE_EXTERNAL, squared_sum);
754    regularizers2[offset] = regularizer;
755    offset += 1;
756    regularizers2[offset] = regularizer;
757    offset += 1;
758    let mut squared_sum = 0.;
759    squared_sum += self.match2match_score * self.match2match_score;
760    squared_sum += self.match2insert_score * self.match2insert_score;
761    squared_sum += self.init_match_score * self.init_match_score;
762    let regularizer = get_regularizer(GROUP_SIZE_MATCH_TRANSITION, squared_sum);
763    regularizers2[offset] = regularizer;
764    offset += 1;
765    regularizers2[offset] = regularizer;
766    offset += 1;
767    regularizers2[offset] = regularizer;
768    offset += 1;
769    let mut squared_sum = 0.;
770    squared_sum += self.insert_extend_score * self.insert_extend_score;
771    squared_sum += self.init_insert_score * self.init_insert_score;
772    let regularizer = get_regularizer(GROUP_SIZE_INSERT_TRANSITION, squared_sum);
773    regularizers2[offset] = regularizer;
774    offset += 1;
775    regularizers2[offset] = regularizer;
776    offset += 1;
777    let len = self.insert_scores.len();
778    let group_size = len;
779    let mut squared_sum = 0.;
780    for i in 0..len {
781      let x = self.insert_scores[i];
782      squared_sum += x * x;
783    }
784    let regularizer = get_regularizer(group_size, squared_sum);
785    for i in 0..len {
786      regularizers2[offset + i] = regularizer;
787    }
788    offset += group_size;
789    let len = self.match_scores.len();
790    let mut effective_group_size = 0;
791    let mut squared_sum = 0.;
792    for i in 0..len {
793      for j in 0..len {
794        let dict_min_match = get_dict_min_pair(&(i, j));
795        if (i, j) != dict_min_match {
796          continue;
797        }
798        let x = self.match_scores[i][j];
799        squared_sum += x * x;
800        effective_group_size += 1;
801      }
802    }
803    let regularizer = get_regularizer(effective_group_size, squared_sum);
804    for i in 0..len {
805      for j in 0..len {
806        let dict_min_match = get_dict_min_pair(&(i, j));
807        if (i, j) != dict_min_match {
808          continue;
809        }
810        regularizers2[offset] = regularizer;
811        offset += 1;
812      }
813    }
814    assert!(offset == self.len());
815    *regularizers = Array1::from(regularizers2);
816  }
817
818  pub fn update<T: HashIndex>(
819    &mut self,
820    train_data: &[TrainDatum<T>],
821    regularizers: &mut Regularizers,
822  ) {
823    let f = |_: &BfgsScores| self.get_cost(train_data, regularizers) as BfgsScore;
824    let g = |_: &BfgsScores| scores2bfgs_scores(&self.get_grad(train_data, regularizers));
825    let uses_cumulative_scores = false;
826    match bfgs(
827      scores2bfgs_scores(&struct2vec(self, uses_cumulative_scores)),
828      f,
829      g,
830    ) {
831      Ok(solution) => {
832        *self = vec2struct(&bfgs_scores2scores(&solution), uses_cumulative_scores);
833      }
834      Err(_) => {
835        println!("BFGS failed");
836      }
837    };
838    self.update_regularizers(regularizers);
839    self.mirror();
840    self.accumulate();
841  }
842
843  pub fn mirror(&mut self) {
844    for i in 0..NUM_BASES {
845      for j in 0..NUM_BASES {
846        if !has_canonical_basepair(&(i, j)) {
847          continue;
848        }
849        for k in 0..NUM_BASES {
850          for l in 0..NUM_BASES {
851            if !has_canonical_basepair(&(k, l)) {
852              continue;
853            }
854            let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
855            if ((i, j), (k, l)) == dict_min_stack {
856              continue;
857            }
858            self.stack_scores[i][j][k][l] = self.stack_scores[dict_min_stack.0 .0]
859              [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1];
860          }
861        }
862      }
863    }
864    for i in 0..NUM_BASES {
865      for j in 0..NUM_BASES {
866        if !has_canonical_basepair(&(i, j)) {
867          continue;
868        }
869        let dict_min_basepair = get_dict_min_pair(&(i, j));
870        if (i, j) == dict_min_basepair {
871          continue;
872        }
873        self.basepair_scores[i][j] = self.basepair_scores[dict_min_basepair.0][dict_min_basepair.1];
874      }
875    }
876    let len = self.interior_scores_explicit.len();
877    for i in 0..len {
878      for j in 0..len {
879        let dict_min_len_pair = get_dict_min_pair(&(i, j));
880        if (i, j) == dict_min_len_pair {
881          continue;
882        }
883        self.interior_scores_explicit[i][j] =
884          self.interior_scores_explicit[dict_min_len_pair.0][dict_min_len_pair.1];
885      }
886    }
887    for i in 0..NUM_BASES {
888      for j in 0..NUM_BASES {
889        let dict_min_basepair = get_dict_min_pair(&(i, j));
890        if (i, j) == dict_min_basepair {
891          continue;
892        }
893        self.interior_scores_1x1[i][j] =
894          self.interior_scores_1x1[dict_min_basepair.0][dict_min_basepair.1];
895      }
896    }
897    for i in 0..NUM_BASES {
898      for j in 0..NUM_BASES {
899        let dict_min_match = get_dict_min_pair(&(i, j));
900        if (i, j) == dict_min_match {
901          continue;
902        }
903        self.match_scores[i][j] = self.match_scores[dict_min_match.0][dict_min_match.1];
904      }
905    }
906  }
907
908  pub fn accumulate(&mut self) {
909    let mut sum = 0.;
910    for i in 0..self.hairpin_scores_len_cumulative.len() {
911      sum += self.hairpin_scores_len[i];
912      self.hairpin_scores_len_cumulative[i] = sum;
913    }
914    let mut sum = 0.;
915    for i in 0..self.bulge_scores_len_cumulative.len() {
916      sum += self.bulge_scores_len[i];
917      self.bulge_scores_len_cumulative[i] = sum;
918    }
919    let mut sum = 0.;
920    for i in 0..self.interior_scores_len_cumulative.len() {
921      sum += self.interior_scores_len[i];
922      self.interior_scores_len_cumulative[i] = sum;
923    }
924    let mut sum = 0.;
925    for i in 0..self.interior_scores_symmetric_cumulative.len() {
926      sum += self.interior_scores_symmetric[i];
927      self.interior_scores_symmetric_cumulative[i] = sum;
928    }
929    let mut sum = 0.;
930    for i in 0..self.interior_scores_asymmetric_cumulative.len() {
931      sum += self.interior_scores_asymmetric[i];
932      self.interior_scores_asymmetric_cumulative[i] = sum;
933    }
934  }
935
936  pub fn get_grad<T: HashIndex>(
937    &self,
938    train_data: &[TrainDatum<T>],
939    regularizers: &Regularizers,
940  ) -> Scores {
941    let uses_cumulative_scores = false;
942    let alignfold_scores = struct2vec(self, uses_cumulative_scores);
943    let mut grad = AlignfoldScores::new(0.);
944    for train_datum in train_data {
945      let obs = &train_datum.alignfold_counts_observed;
946      let expect = &train_datum.alignfold_counts_expected;
947      let mut sum = 0.;
948      let len = obs.hairpin_scores_len.len();
949      for i in (0..len).rev() {
950        let x = obs.hairpin_scores_len[i];
951        let y = expect.hairpin_scores_len[i];
952        sum -= x - y;
953        grad.hairpin_scores_len[i] += sum;
954      }
955      let len = obs.bulge_scores_len.len();
956      let mut sum = 0.;
957      for i in (0..len).rev() {
958        let x = obs.bulge_scores_len[i];
959        let y = expect.bulge_scores_len[i];
960        sum -= x - y;
961        grad.bulge_scores_len[i] += sum;
962      }
963      let len = obs.interior_scores_len.len();
964      let mut sum = 0.;
965      for i in (0..len).rev() {
966        let x = obs.interior_scores_len[i];
967        let y = expect.interior_scores_len[i];
968        sum -= x - y;
969        grad.interior_scores_len[i] += sum;
970      }
971      let len = obs.interior_scores_symmetric.len();
972      let mut sum = 0.;
973      for i in (0..len).rev() {
974        let x = obs.interior_scores_symmetric[i];
975        let y = expect.interior_scores_symmetric[i];
976        sum -= x - y;
977        grad.interior_scores_symmetric[i] += sum;
978      }
979      let len = obs.interior_scores_asymmetric.len();
980      let mut sum = 0.;
981      for i in (0..len).rev() {
982        let x = obs.interior_scores_asymmetric[i];
983        let y = expect.interior_scores_asymmetric[i];
984        sum -= x - y;
985        grad.interior_scores_asymmetric[i] += sum;
986      }
987      for i in 0..NUM_BASES {
988        for j in 0..NUM_BASES {
989          if !has_canonical_basepair(&(i, j)) {
990            continue;
991          }
992          for k in 0..NUM_BASES {
993            for l in 0..NUM_BASES {
994              if !has_canonical_basepair(&(k, l)) {
995                continue;
996              }
997              let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
998              if ((i, j), (k, l)) != dict_min_stack {
999                continue;
1000              }
1001              let x = obs.stack_scores[i][j][k][l];
1002              let y = expect.stack_scores[i][j][k][l];
1003              grad.stack_scores[i][j][k][l] -= x - y;
1004            }
1005          }
1006        }
1007      }
1008      for i in 0..NUM_BASES {
1009        for j in 0..NUM_BASES {
1010          if !has_canonical_basepair(&(i, j)) {
1011            continue;
1012          }
1013          for k in 0..NUM_BASES {
1014            for l in 0..NUM_BASES {
1015              let x = obs.terminal_mismatch_scores[i][j][k][l];
1016              let y = expect.terminal_mismatch_scores[i][j][k][l];
1017              grad.terminal_mismatch_scores[i][j][k][l] -= x - y;
1018            }
1019          }
1020        }
1021      }
1022      for i in 0..NUM_BASES {
1023        for j in 0..NUM_BASES {
1024          if !has_canonical_basepair(&(i, j)) {
1025            continue;
1026          }
1027          for k in 0..NUM_BASES {
1028            let x = obs.dangling_scores_left[i][j][k];
1029            let y = expect.dangling_scores_left[i][j][k];
1030            grad.dangling_scores_left[i][j][k] -= x - y;
1031          }
1032        }
1033      }
1034      for i in 0..NUM_BASES {
1035        for j in 0..NUM_BASES {
1036          if !has_canonical_basepair(&(i, j)) {
1037            continue;
1038          }
1039          for k in 0..NUM_BASES {
1040            let x = obs.dangling_scores_right[i][j][k];
1041            let y = expect.dangling_scores_right[i][j][k];
1042            grad.dangling_scores_right[i][j][k] -= x - y;
1043          }
1044        }
1045      }
1046      for i in 0..NUM_BASES {
1047        for j in 0..NUM_BASES {
1048          if !has_canonical_basepair(&(i, j)) {
1049            continue;
1050          }
1051          let x = obs.helix_close_scores[i][j];
1052          let y = expect.helix_close_scores[i][j];
1053          grad.helix_close_scores[i][j] -= x - y;
1054        }
1055      }
1056      for i in 0..NUM_BASES {
1057        for j in 0..NUM_BASES {
1058          if !has_canonical_basepair(&(i, j)) {
1059            continue;
1060          }
1061          let dict_min_basepair = get_dict_min_pair(&(i, j));
1062          if (i, j) != dict_min_basepair {
1063            continue;
1064          }
1065          let x = obs.basepair_scores[i][j];
1066          let y = expect.basepair_scores[i][j];
1067          grad.basepair_scores[i][j] -= x - y;
1068        }
1069      }
1070      let len = obs.interior_scores_explicit.len();
1071      for i in 0..len {
1072        for j in 0..len {
1073          let dict_min_len_pair = get_dict_min_pair(&(i, j));
1074          if (i, j) != dict_min_len_pair {
1075            continue;
1076          }
1077          let x = obs.interior_scores_explicit[i][j];
1078          let y = expect.interior_scores_explicit[i][j];
1079          grad.interior_scores_explicit[i][j] -= x - y;
1080        }
1081      }
1082      for i in 0..NUM_BASES {
1083        let x = obs.bulge_scores_0x1[i];
1084        let y = expect.bulge_scores_0x1[i];
1085        grad.bulge_scores_0x1[i] -= x - y;
1086      }
1087      for i in 0..NUM_BASES {
1088        for j in 0..NUM_BASES {
1089          let dict_min_basepair = get_dict_min_pair(&(i, j));
1090          if (i, j) != dict_min_basepair {
1091            continue;
1092          }
1093          let x = obs.interior_scores_1x1[i][j];
1094          let y = expect.interior_scores_1x1[i][j];
1095          grad.interior_scores_1x1[i][j] -= x - y;
1096        }
1097      }
1098      let obs_score = obs.multibranch_score_base;
1099      let expect_score = expect.multibranch_score_base;
1100      grad.multibranch_score_base -= obs_score - expect_score;
1101      let obs_score = obs.multibranch_score_basepair;
1102      let expect_score = expect.multibranch_score_basepair;
1103      grad.multibranch_score_basepair -= obs_score - expect_score;
1104      let obs_score = obs.multibranch_score_unpair;
1105      let expect_score = expect.multibranch_score_unpair;
1106      grad.multibranch_score_unpair -= obs_score - expect_score;
1107      let obs_score = obs.external_score_basepair;
1108      let expect_score = expect.external_score_basepair;
1109      grad.external_score_basepair -= obs_score - expect_score;
1110      let obs_score = obs.external_score_unpair;
1111      let expect_score = expect.external_score_unpair;
1112      grad.external_score_unpair -= obs_score - expect_score;
1113      let obs_score = obs.match2match_score;
1114      let expect_score = expect.match2match_score;
1115      grad.match2match_score -= obs_score - expect_score;
1116      let obs_score = obs.match2insert_score;
1117      let expect_score = expect.match2insert_score;
1118      grad.match2insert_score -= obs_score - expect_score;
1119      let obs_score = obs.insert_extend_score;
1120      let expect_score = expect.insert_extend_score;
1121      grad.insert_extend_score -= obs_score - expect_score;
1122      let obs_score = obs.init_match_score;
1123      let expect_score = expect.init_match_score;
1124      grad.init_match_score -= obs_score - expect_score;
1125      let obs_score = obs.init_insert_score;
1126      let expect_score = expect.init_insert_score;
1127      grad.init_insert_score -= obs_score - expect_score;
1128      for i in 0..NUM_BASES {
1129        let x = obs.insert_scores[i];
1130        let y = expect.insert_scores[i];
1131        grad.insert_scores[i] -= x - y;
1132      }
1133      for i in 0..NUM_BASES {
1134        for j in 0..NUM_BASES {
1135          let dict_min_match = get_dict_min_pair(&(i, j));
1136          if (i, j) != dict_min_match {
1137            continue;
1138          }
1139          let x = obs.match_scores[i][j];
1140          let y = expect.match_scores[i][j];
1141          grad.match_scores[i][j] -= x - y;
1142        }
1143      }
1144    }
1145    struct2vec(&grad, uses_cumulative_scores) + regularizers * &alignfold_scores
1146  }
1147
1148  pub fn get_cost<T: HashIndex>(
1149    &self,
1150    train_data: &[TrainDatum<T>],
1151    regularizers: &Regularizers,
1152  ) -> Score {
1153    let uses_cumulative_scores = true;
1154    let mut log_likelihood = 0.;
1155    let alignfold_scores_cumulative = struct2vec(self, uses_cumulative_scores);
1156    let uses_cumulative_scores = false;
1157    for train_datum in train_data {
1158      let obs = &train_datum.alignfold_counts_observed;
1159      log_likelihood += alignfold_scores_cumulative.dot(&struct2vec(obs, uses_cumulative_scores));
1160      log_likelihood -= train_datum.global_sum;
1161    }
1162    let alignfold_scores = struct2vec(self, uses_cumulative_scores);
1163    let product = regularizers * &alignfold_scores;
1164    -log_likelihood + product.dot(&alignfold_scores) / 2.
1165  }
1166
1167  pub fn rand_init(&mut self) {
1168    let len = self.len();
1169    let std_deviation = 1. / (len as Score).sqrt();
1170    let normal = Normal::new(0., std_deviation).unwrap();
1171    let mut thread_rng = thread_rng();
1172    for x in self.hairpin_scores_len.iter_mut() {
1173      *x = normal.sample(&mut thread_rng);
1174    }
1175    for x in self.bulge_scores_len.iter_mut() {
1176      *x = normal.sample(&mut thread_rng);
1177    }
1178    for x in self.interior_scores_len.iter_mut() {
1179      *x = normal.sample(&mut thread_rng);
1180    }
1181    for x in self.interior_scores_symmetric.iter_mut() {
1182      *x = normal.sample(&mut thread_rng);
1183    }
1184    for x in self.interior_scores_asymmetric.iter_mut() {
1185      *x = normal.sample(&mut thread_rng);
1186    }
1187    let len = self.stack_scores.len();
1188    for i in 0..len {
1189      for j in 0..len {
1190        if !has_canonical_basepair(&(i, j)) {
1191          continue;
1192        }
1193        for k in 0..len {
1194          for l in 0..len {
1195            if !has_canonical_basepair(&(k, l)) {
1196              continue;
1197            }
1198            let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
1199            if ((i, j), (k, l)) != dict_min_stack {
1200              continue;
1201            }
1202            let x = normal.sample(&mut thread_rng);
1203            self.stack_scores[i][j][k][l] = x;
1204          }
1205        }
1206      }
1207    }
1208    let len = self.terminal_mismatch_scores.len();
1209    for i in 0..len {
1210      for j in 0..len {
1211        if !has_canonical_basepair(&(i, j)) {
1212          continue;
1213        }
1214        for k in 0..len {
1215          for l in 0..len {
1216            let x = normal.sample(&mut thread_rng);
1217            self.terminal_mismatch_scores[i][j][k][l] = x;
1218          }
1219        }
1220      }
1221    }
1222    let len = self.dangling_scores_left.len();
1223    for i in 0..len {
1224      for j in 0..len {
1225        if !has_canonical_basepair(&(i, j)) {
1226          continue;
1227        }
1228        for k in 0..len {
1229          let x = normal.sample(&mut thread_rng);
1230          self.dangling_scores_left[i][j][k] = x;
1231        }
1232      }
1233    }
1234    let len = self.dangling_scores_right.len();
1235    for i in 0..len {
1236      for j in 0..len {
1237        if !has_canonical_basepair(&(i, j)) {
1238          continue;
1239        }
1240        for k in 0..len {
1241          let x = normal.sample(&mut thread_rng);
1242          self.dangling_scores_right[i][j][k] = x;
1243        }
1244      }
1245    }
1246    let len = self.helix_close_scores.len();
1247    for i in 0..len {
1248      for j in 0..len {
1249        if !has_canonical_basepair(&(i, j)) {
1250          continue;
1251        }
1252        let x = normal.sample(&mut thread_rng);
1253        self.helix_close_scores[i][j] = x;
1254      }
1255    }
1256    let len = self.basepair_scores.len();
1257    for i in 0..len {
1258      for j in 0..len {
1259        if !has_canonical_basepair(&(i, j)) {
1260          continue;
1261        }
1262        let dict_min_basepair = get_dict_min_pair(&(i, j));
1263        if (i, j) != dict_min_basepair {
1264          continue;
1265        }
1266        let x = normal.sample(&mut thread_rng);
1267        self.basepair_scores[i][j] = x;
1268      }
1269    }
1270    let len = self.interior_scores_explicit.len();
1271    for i in 0..len {
1272      for j in 0..len {
1273        let dict_min_len_pair = get_dict_min_pair(&(i, j));
1274        if (i, j) != dict_min_len_pair {
1275          continue;
1276        }
1277        let x = normal.sample(&mut thread_rng);
1278        self.interior_scores_explicit[i][j] = x;
1279      }
1280    }
1281    for x in &mut self.bulge_scores_0x1 {
1282      *x = normal.sample(&mut thread_rng);
1283    }
1284    let len = self.interior_scores_1x1.len();
1285    for i in 0..len {
1286      for j in 0..len {
1287        let dict_min_basepair = get_dict_min_pair(&(i, j));
1288        if (i, j) != dict_min_basepair {
1289          continue;
1290        }
1291        let x = normal.sample(&mut thread_rng);
1292        self.interior_scores_1x1[i][j] = x;
1293      }
1294    }
1295    self.multibranch_score_base = normal.sample(&mut thread_rng);
1296    self.multibranch_score_basepair = normal.sample(&mut thread_rng);
1297    self.multibranch_score_unpair = normal.sample(&mut thread_rng);
1298    self.external_score_basepair = normal.sample(&mut thread_rng);
1299    self.external_score_unpair = normal.sample(&mut thread_rng);
1300    self.match2match_score = normal.sample(&mut thread_rng);
1301    self.match2insert_score = normal.sample(&mut thread_rng);
1302    self.insert_extend_score = normal.sample(&mut thread_rng);
1303    self.init_match_score = normal.sample(&mut thread_rng);
1304    self.init_insert_score = normal.sample(&mut thread_rng);
1305    let len = self.insert_scores.len();
1306    for i in 0..len {
1307      let x = normal.sample(&mut thread_rng);
1308      self.insert_scores[i] = x;
1309    }
1310    let len = self.match_scores.len();
1311    for i in 0..len {
1312      for j in 0..len {
1313        let dict_min_match = get_dict_min_pair(&(i, j));
1314        if (i, j) != dict_min_match {
1315          continue;
1316        }
1317        let x = normal.sample(&mut thread_rng);
1318        self.match_scores[i][j] = x;
1319      }
1320    }
1321    self.mirror();
1322    self.accumulate();
1323  }
1324
1325  pub fn transfer(&mut self) {
1326    for (x, &y) in self
1327      .hairpin_scores_len
1328      .iter_mut()
1329      .zip(HAIRPIN_SCORES_LEN_ATLEAST.iter())
1330    {
1331      *x = y;
1332    }
1333    for (x, &y) in self
1334      .bulge_scores_len
1335      .iter_mut()
1336      .zip(BULGE_SCORES_LEN_ATLEAST.iter())
1337    {
1338      *x = y;
1339    }
1340    for (x, &y) in self
1341      .interior_scores_len
1342      .iter_mut()
1343      .zip(INTERIOR_SCORES_LEN_ATLEAST.iter())
1344    {
1345      *x = y;
1346    }
1347    for (x, &y) in self
1348      .interior_scores_symmetric
1349      .iter_mut()
1350      .zip(INTERIOR_SCORES_SYMMETRIC_ATLEAST.iter())
1351    {
1352      *x = y;
1353    }
1354    for (x, &y) in self
1355      .interior_scores_asymmetric
1356      .iter_mut()
1357      .zip(INTERIOR_SCORES_ASYMMETRIC_ATLEAST.iter())
1358    {
1359      *x = y;
1360    }
1361    for (i, x) in STACK_SCORES_CONTRA.iter().enumerate() {
1362      for (j, x) in x.iter().enumerate() {
1363        if !has_canonical_basepair(&(i, j)) {
1364          continue;
1365        }
1366        for (k, x) in x.iter().enumerate() {
1367          for (l, &x) in x.iter().enumerate() {
1368            if !has_canonical_basepair(&(k, l)) {
1369              continue;
1370            }
1371            self.stack_scores[i][j][k][l] = x;
1372          }
1373        }
1374      }
1375    }
1376    for (i, x) in TERMINAL_MISMATCH_SCORES_CONTRA.iter().enumerate() {
1377      for (j, x) in x.iter().enumerate() {
1378        if !has_canonical_basepair(&(i, j)) {
1379          continue;
1380        }
1381        for (k, x) in x.iter().enumerate() {
1382          for (l, &x) in x.iter().enumerate() {
1383            self.terminal_mismatch_scores[i][j][k][l] = x;
1384          }
1385        }
1386      }
1387    }
1388    for (i, x) in DANGLING_SCORES_LEFT.iter().enumerate() {
1389      for (j, x) in x.iter().enumerate() {
1390        if !has_canonical_basepair(&(i, j)) {
1391          continue;
1392        }
1393        for (k, &x) in x.iter().enumerate() {
1394          self.dangling_scores_left[i][j][k] = x;
1395        }
1396      }
1397    }
1398    for (i, x) in DANGLING_SCORES_RIGHT.iter().enumerate() {
1399      for (j, x) in x.iter().enumerate() {
1400        if !has_canonical_basepair(&(i, j)) {
1401          continue;
1402        }
1403        for (k, &x) in x.iter().enumerate() {
1404          self.dangling_scores_right[i][j][k] = x;
1405        }
1406      }
1407    }
1408    for (i, x) in HELIX_CLOSE_SCORES.iter().enumerate() {
1409      for (j, &x) in x.iter().enumerate() {
1410        if !has_canonical_basepair(&(i, j)) {
1411          continue;
1412        }
1413        self.helix_close_scores[i][j] = x;
1414      }
1415    }
1416    for (i, x) in BASEPAIR_SCORES.iter().enumerate() {
1417      for (j, &x) in x.iter().enumerate() {
1418        if !has_canonical_basepair(&(i, j)) {
1419          continue;
1420        }
1421        self.basepair_scores[i][j] = x;
1422      }
1423    }
1424    for (i, x) in INTERIOR_SCORES_EXPLICIT.iter().enumerate() {
1425      for (j, &x) in x.iter().enumerate() {
1426        self.interior_scores_explicit[i][j] = x;
1427      }
1428    }
1429    for (x, &y) in self
1430      .bulge_scores_0x1
1431      .iter_mut()
1432      .zip(BULGE_SCORES_0X1.iter())
1433    {
1434      *x = y;
1435    }
1436    for (i, x) in INTERIOR_SCORES_1X1_CONTRA.iter().enumerate() {
1437      for (j, &x) in x.iter().enumerate() {
1438        self.interior_scores_1x1[i][j] = x;
1439      }
1440    }
1441    self.multibranch_score_base = MULTIBRANCH_SCORE_BASE;
1442    self.multibranch_score_basepair = MULTIBRANCH_SCORE_BASEPAIR;
1443    self.multibranch_score_unpair = MULTIBRANCH_SCORE_UNPAIR;
1444    self.external_score_basepair = EXTERNAL_SCORE_BASEPAIR;
1445    self.external_score_unpair = EXTERNAL_SCORE_UNPAIR;
1446    self.match2match_score = MATCH2MATCH_SCORE;
1447    self.match2insert_score = MATCH2INSERT_SCORE;
1448    self.insert_extend_score = INSERT_EXTEND_SCORE;
1449    self.init_match_score = INIT_MATCH_SCORE;
1450    self.init_insert_score = INIT_INSERT_SCORE;
1451    for (i, &x) in INSERT_SCORES.iter().enumerate() {
1452      self.insert_scores[i] = x;
1453    }
1454    for (i, x) in MATCH_SCORES.iter().enumerate() {
1455      for (j, &x) in x.iter().enumerate() {
1456        self.match_scores[i][j] = x;
1457      }
1458    }
1459    self.accumulate();
1460  }
1461}
1462
1463impl<T: HashIndex> TrainDatum<T> {
1464  pub fn origin() -> TrainDatum<T> {
1465    TrainDatum {
1466      seq_pair: (Seq::new(), Seq::new()),
1467      seq_pair_gapped: (Seq::new(), Seq::new()),
1468      alignfold_counts_observed: AlignfoldScores::new(0.),
1469      alignfold_counts_expected: AlignfoldScores::new(NEG_INFINITY),
1470      basepair_probs_pair: (SparseProbMat::<T>::default(), SparseProbMat::<T>::default()),
1471      max_basepair_span_pair: (T::zero(), T::zero()),
1472      global_sum: NEG_INFINITY,
1473      forward_pos_pairs: PosPairMatSet::<T>::default(),
1474      backward_pos_pairs: PosPairMatSet::<T>::default(),
1475      pos_quads_hashed_lens: PosQuadsHashedLens::<T>::default(),
1476      matchable_poss: SparsePosSets::<T>::default(),
1477      matchable_poss2: SparsePosSets::<T>::default(),
1478      fold_scores_pair: (FoldScoresTrained::<T>::new(), FoldScoresTrained::<T>::new()),
1479      match_probs: SparseProbMat::<T>::default(),
1480      alignfold: PairAlignfold::<T>::new(),
1481      accuracy: NEG_INFINITY,
1482    }
1483  }
1484
1485  pub fn new(
1486    input_file_path: &Path,
1487    min_basepair_prob: Prob,
1488    min_match_prob: Prob,
1489    align_scores: &AlignScores,
1490  ) -> TrainDatum<T> {
1491    let fasta_file_reader = Reader::from_file(Path::new(input_file_path)).unwrap();
1492    let fasta_records: Vec<Record> = fasta_file_reader
1493      .records()
1494      .map(|rec| rec.unwrap())
1495      .collect();
1496    let consensus_fold = fasta_records[2].seq();
1497    let seq_pair = (
1498      bytes2seq_gapped(fasta_records[0].seq()),
1499      bytes2seq_gapped(fasta_records[1].seq()),
1500    );
1501    let mut seq_pair_ungapped = (gapped2ungapped(&seq_pair.0), gapped2ungapped(&seq_pair.1));
1502    let uses_contra_model = false;
1503    let allows_short_hairpins = true;
1504    let basepair_probs_pair = (
1505      filter_basepair_probs::<T>(
1506        &mccaskill_algo(
1507          &seq_pair_ungapped.0[..],
1508          uses_contra_model,
1509          allows_short_hairpins,
1510          &FoldScoreSets::new(0.),
1511        )
1512        .0,
1513        min_basepair_prob,
1514      ),
1515      filter_basepair_probs::<T>(
1516        &mccaskill_algo(
1517          &seq_pair_ungapped.1[..],
1518          uses_contra_model,
1519          allows_short_hairpins,
1520          &FoldScoreSets::new(0.),
1521        )
1522        .0,
1523        min_basepair_prob,
1524      ),
1525    );
1526    seq_pair_ungapped.0.insert(0, PSEUDO_BASE);
1527    seq_pair_ungapped.0.push(PSEUDO_BASE);
1528    seq_pair_ungapped.1.insert(0, PSEUDO_BASE);
1529    seq_pair_ungapped.1.push(PSEUDO_BASE);
1530    let seq_len_pair = (
1531      T::from_usize(seq_pair_ungapped.0.len()).unwrap(),
1532      T::from_usize(seq_pair_ungapped.1.len()).unwrap(),
1533    );
1534    let match_probs = filter_match_probs(
1535      &durbin_algo(
1536        &(&seq_pair_ungapped.0[..], &seq_pair_ungapped.1[..]),
1537        align_scores,
1538      ),
1539      min_match_prob,
1540    );
1541    let (
1542      forward_pos_pairs,
1543      backward_pos_pairs,
1544      _,
1545      pos_quads_hashed_lens,
1546      matchable_poss,
1547      matchable_poss2,
1548    ) = get_sparse_poss(
1549      &(&basepair_probs_pair.0, &basepair_probs_pair.1),
1550      &match_probs,
1551      &seq_len_pair,
1552    );
1553    let max_basepair_span_pair = (
1554      get_max_basepair_span::<T>(&basepair_probs_pair.0),
1555      get_max_basepair_span::<T>(&basepair_probs_pair.1),
1556    );
1557    let mut train_datum = TrainDatum {
1558      seq_pair: seq_pair_ungapped,
1559      seq_pair_gapped: seq_pair,
1560      alignfold_counts_observed: AlignfoldScores::new(0.),
1561      alignfold_counts_expected: AlignfoldScores::new(NEG_INFINITY),
1562      basepair_probs_pair,
1563      max_basepair_span_pair,
1564      global_sum: NEG_INFINITY,
1565      forward_pos_pairs,
1566      backward_pos_pairs,
1567      pos_quads_hashed_lens,
1568      matchable_poss,
1569      matchable_poss2,
1570      fold_scores_pair: (FoldScoresTrained::<T>::new(), FoldScoresTrained::<T>::new()),
1571      match_probs,
1572      alignfold: PairAlignfold::<T>::new(),
1573      accuracy: NEG_INFINITY,
1574    };
1575    train_datum.obs2counts(consensus_fold);
1576    train_datum
1577  }
1578
1579  pub fn obs2counts(&mut self, consensus_fold: TextSlice) {
1580    let align_len = consensus_fold.len();
1581    let mut inserted = false;
1582    let mut inserted2 = inserted;
1583    let seq_pair = &self.seq_pair_gapped;
1584    let mut pos_pair = (T::one(), T::one());
1585    for (i, &notation) in consensus_fold.iter().enumerate() {
1586      let basepair = (seq_pair.0[i], seq_pair.1[i]);
1587      if notation != UNPAIR {
1588        let dict_min_match = get_dict_min_pair(&basepair);
1589        self.alignfold_counts_observed.match_scores[dict_min_match.0][dict_min_match.1] += 1.;
1590        if i == 0 {
1591          self.alignfold_counts_observed.init_match_score += 1.;
1592        } else if inserted || inserted2 {
1593          self.alignfold_counts_observed.match2insert_score += 1.;
1594        } else {
1595          self.alignfold_counts_observed.match2match_score += 1.;
1596        }
1597        inserted = false;
1598        inserted2 = inserted;
1599        if basepair.1 == PSEUDO_BASE {
1600          self.alignfold.inserted_poss.insert(pos_pair.0);
1601          pos_pair.0 = pos_pair.0 + T::one();
1602        } else if basepair.0 == PSEUDO_BASE {
1603          self.alignfold.deleted_poss.insert(pos_pair.1);
1604          pos_pair.1 = pos_pair.1 + T::one();
1605        } else {
1606          self.alignfold.matched_pos_pairs.insert(pos_pair);
1607          pos_pair.0 = pos_pair.0 + T::one();
1608          pos_pair.1 = pos_pair.1 + T::one();
1609        }
1610        continue;
1611      }
1612      if basepair.1 == PSEUDO_BASE {
1613        if i == 0 {
1614          self.alignfold_counts_observed.init_insert_score += 1.;
1615        } else if inserted {
1616          self.alignfold_counts_observed.insert_extend_score += 1.;
1617        } else if inserted2 {
1618          inserted = true;
1619          inserted2 = false;
1620        } else {
1621          self.alignfold_counts_observed.match2insert_score += 1.;
1622          inserted = true;
1623        }
1624        self.alignfold_counts_observed.insert_scores[basepair.0] += 1.;
1625        self.alignfold.inserted_poss.insert(pos_pair.0);
1626        pos_pair.0 = pos_pair.0 + T::one();
1627      } else if basepair.0 == PSEUDO_BASE {
1628        if i == 0 {
1629          self.alignfold_counts_observed.init_insert_score += 1.;
1630          inserted2 = true;
1631        } else if inserted2 {
1632          self.alignfold_counts_observed.insert_extend_score += 1.;
1633        } else if inserted {
1634          inserted2 = true;
1635          inserted = false;
1636        } else {
1637          self.alignfold_counts_observed.match2insert_score += 1.;
1638          inserted2 = true;
1639        }
1640        self.alignfold_counts_observed.insert_scores[basepair.1] += 1.;
1641        self.alignfold.deleted_poss.insert(pos_pair.1);
1642        pos_pair.1 = pos_pair.1 + T::one();
1643      } else {
1644        let dict_min_match = get_dict_min_pair(&basepair);
1645        self.alignfold_counts_observed.match_scores[dict_min_match.0][dict_min_match.1] += 1.;
1646        if i == 0 {
1647          self.alignfold_counts_observed.init_match_score += 1.;
1648        } else if inserted || inserted2 {
1649          self.alignfold_counts_observed.match2insert_score += 1.;
1650        } else {
1651          self.alignfold_counts_observed.match2match_score += 1.;
1652        }
1653        inserted = false;
1654        inserted2 = inserted;
1655        self.alignfold.matched_pos_pairs.insert(pos_pair);
1656        pos_pair.0 = pos_pair.0 + T::one();
1657        pos_pair.1 = pos_pair.1 + T::one();
1658      }
1659    }
1660    let mut stack = Vec::new();
1661    let mut consensus_basepairs = HashSet::<(usize, usize)>::default();
1662    for (i, &x) in consensus_fold.iter().enumerate() {
1663      if x == BASEPAIR_LEFT {
1664        stack.push(i);
1665      } else if x == BASEPAIR_RIGHT {
1666        let x = stack.pop().unwrap();
1667        consensus_basepairs.insert((x, i));
1668        let y = (seq_pair.0[x], seq_pair.0[i]);
1669        if has_canonical_basepair(&y) {
1670          let y = get_dict_min_pair(&y);
1671          self.alignfold_counts_observed.basepair_scores[y.0][y.1] += 1.;
1672        }
1673        let y = (seq_pair.1[x], seq_pair.1[i]);
1674        if has_canonical_basepair(&y) {
1675          let y = get_dict_min_pair(&y);
1676          self.alignfold_counts_observed.basepair_scores[y.0][y.1] += 1.;
1677        }
1678      }
1679    }
1680    let mut loop_struct = LoopStruct::default();
1681    let mut stored_basepairs = HashSet::<(usize, usize)>::default();
1682    for substr_len in 2..align_len + 1 {
1683      for i in 0..align_len - substr_len + 1 {
1684        let mut found_basepair = false;
1685        let j = i + substr_len - 1;
1686        if consensus_basepairs.contains(&(i, j)) {
1687          found_basepair = true;
1688          consensus_basepairs.remove(&(i, j));
1689          stored_basepairs.insert((i, j));
1690        }
1691        if found_basepair {
1692          let mut loop_basepairs = Vec::new();
1693          for x in stored_basepairs.iter() {
1694            if i < x.0 && x.1 < j {
1695              loop_basepairs.push(*x);
1696            }
1697          }
1698          for x in loop_basepairs.iter() {
1699            stored_basepairs.remove(x);
1700          }
1701          loop_basepairs.sort();
1702          loop_struct.insert((i, j), loop_basepairs);
1703        }
1704      }
1705    }
1706    for (basepair_close, basepairs_loop) in loop_struct.iter() {
1707      let num_basepairs_loop = basepairs_loop.len();
1708      let basepair = (seq_pair.0[basepair_close.0], seq_pair.0[basepair_close.1]);
1709      let basepair2 = (seq_pair.1[basepair_close.0], seq_pair.1[basepair_close.1]);
1710      let closes = true;
1711      let mismatch_pair = get_mismatch_pair(&seq_pair.0[..], basepair_close, closes);
1712      let mismatch_pair2 = get_mismatch_pair(&seq_pair.1[..], basepair_close, closes);
1713      if num_basepairs_loop == 0 {
1714        let hairpin_len_pair = (
1715          get_hairpin_len(&seq_pair.0[..], basepair_close),
1716          get_hairpin_len(&seq_pair.1[..], basepair_close),
1717        );
1718        if has_canonical_basepair(&basepair) {
1719          self.alignfold_counts_observed.terminal_mismatch_scores[basepair.0][basepair.1]
1720            [mismatch_pair.0][mismatch_pair.1] += 1.;
1721          self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1722        }
1723        if has_canonical_basepair(&basepair2) {
1724          self.alignfold_counts_observed.terminal_mismatch_scores[basepair2.0][basepair2.1]
1725            [mismatch_pair2.0][mismatch_pair2.1] += 1.;
1726          self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1727        }
1728        if hairpin_len_pair.0 <= MAX_LOOP_LEN {
1729          self.alignfold_counts_observed.hairpin_scores_len[hairpin_len_pair.0] += 1.;
1730        } else {
1731          self.alignfold_counts_observed.hairpin_scores_len[MAX_LOOP_LEN] += 1.;
1732        }
1733        if hairpin_len_pair.1 <= MAX_LOOP_LEN {
1734          self.alignfold_counts_observed.hairpin_scores_len[hairpin_len_pair.1] += 1.;
1735        } else {
1736          self.alignfold_counts_observed.hairpin_scores_len[MAX_LOOP_LEN] += 1.;
1737        }
1738      } else if num_basepairs_loop == 1 {
1739        let basepair_loop = &basepairs_loop[0];
1740        let basepair3 = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1741        let basepair4 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1742        let twoloop_len_pair = get_2loop_len_pair(&seq_pair.0[..], basepair_close, basepair_loop);
1743        let sum = twoloop_len_pair.0 + twoloop_len_pair.1;
1744        let twoloop_len_pair2 = get_2loop_len_pair(&seq_pair.1[..], basepair_close, basepair_loop);
1745        let sum2 = twoloop_len_pair2.0 + twoloop_len_pair2.1;
1746        let closes = false;
1747        let mismatch_pair3 = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1748        let mismatch_pair4 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1749        if sum == 0 {
1750          if has_canonical_basepair(&basepair) && has_canonical_basepair(&basepair3) {
1751            let dict_min_stack = get_dict_min_stack(&basepair, &basepair3);
1752            self.alignfold_counts_observed.stack_scores[dict_min_stack.0 .0]
1753              [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1] += 1.;
1754          }
1755        } else {
1756          if twoloop_len_pair.0 == 0 || twoloop_len_pair.1 == 0 {
1757            if sum <= MAX_LOOP_LEN {
1758              self.alignfold_counts_observed.bulge_scores_len[sum - 1] += 1.;
1759              if sum == 1 {
1760                let mismatch = if twoloop_len_pair.0 == 0 {
1761                  mismatch_pair.1
1762                } else {
1763                  mismatch_pair.0
1764                };
1765                self.alignfold_counts_observed.bulge_scores_0x1[mismatch] += 1.;
1766              }
1767            } else {
1768              self.alignfold_counts_observed.bulge_scores_len[MAX_LOOP_LEN - 1] += 1.;
1769            }
1770          } else {
1771            let diff = get_diff(twoloop_len_pair.0, twoloop_len_pair.1);
1772            if sum <= MAX_LOOP_LEN {
1773              self.alignfold_counts_observed.interior_scores_len[sum - 2] += 1.;
1774              if diff == 0 {
1775                self.alignfold_counts_observed.interior_scores_symmetric[twoloop_len_pair.0 - 1] +=
1776                  1.;
1777              } else {
1778                self.alignfold_counts_observed.interior_scores_asymmetric[diff - 1] += 1.;
1779              }
1780              if twoloop_len_pair.0 == 1 && twoloop_len_pair.1 == 1 {
1781                let dict_min_mismatch_pair = get_dict_min_pair(&mismatch_pair);
1782                self.alignfold_counts_observed.interior_scores_1x1[dict_min_mismatch_pair.0]
1783                  [dict_min_mismatch_pair.1] += 1.;
1784              }
1785              if twoloop_len_pair.0 <= MAX_INTERIOR_EXPLICIT
1786                && twoloop_len_pair.1 <= MAX_INTERIOR_EXPLICIT
1787              {
1788                let dict_min_len_pair = get_dict_min_pair(&twoloop_len_pair);
1789                self.alignfold_counts_observed.interior_scores_explicit[dict_min_len_pair.0 - 1]
1790                  [dict_min_len_pair.1 - 1] += 1.;
1791              }
1792            } else {
1793              self.alignfold_counts_observed.interior_scores_len[MAX_LOOP_LEN - 2] += 1.;
1794              if diff == 0 {
1795                if twoloop_len_pair.0 <= MAX_INTERIOR_SYMMETRIC {
1796                  self.alignfold_counts_observed.interior_scores_symmetric
1797                    [twoloop_len_pair.0 - 1] += 1.;
1798                } else {
1799                  self.alignfold_counts_observed.interior_scores_symmetric
1800                    [MAX_INTERIOR_SYMMETRIC - 1] += 1.;
1801                }
1802              } else if diff <= MAX_INTERIOR_ASYMMETRIC {
1803                self.alignfold_counts_observed.interior_scores_asymmetric[diff - 1] += 1.;
1804              } else {
1805                self.alignfold_counts_observed.interior_scores_asymmetric
1806                  [MAX_INTERIOR_ASYMMETRIC - 1] += 1.;
1807              }
1808            }
1809          }
1810          if has_canonical_basepair(&basepair) {
1811            self.alignfold_counts_observed.terminal_mismatch_scores[basepair.0][basepair.1]
1812              [mismatch_pair.0][mismatch_pair.1] += 1.;
1813            self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1814          }
1815          if has_canonical_basepair(&basepair3) {
1816            self.alignfold_counts_observed.terminal_mismatch_scores[basepair3.1][basepair3.0]
1817              [mismatch_pair3.1][mismatch_pair3.0] += 1.;
1818            self.alignfold_counts_observed.helix_close_scores[basepair3.1][basepair3.0] += 1.;
1819          }
1820        }
1821        if sum2 == 0 {
1822          if has_canonical_basepair(&basepair2) && has_canonical_basepair(&basepair4) {
1823            let dict_min_stack2 = get_dict_min_stack(&basepair2, &basepair4);
1824            self.alignfold_counts_observed.stack_scores[dict_min_stack2.0 .0]
1825              [dict_min_stack2.0 .1][dict_min_stack2.1 .0][dict_min_stack2.1 .1] += 1.;
1826          }
1827        } else {
1828          if twoloop_len_pair2.0 == 0 || twoloop_len_pair2.1 == 0 {
1829            if sum2 <= MAX_LOOP_LEN {
1830              self.alignfold_counts_observed.bulge_scores_len[sum2 - 1] += 1.;
1831              if sum2 == 1 {
1832                let mismatch2 = if twoloop_len_pair2.0 == 0 {
1833                  mismatch_pair2.1
1834                } else {
1835                  mismatch_pair2.0
1836                };
1837                self.alignfold_counts_observed.bulge_scores_0x1[mismatch2] += 1.;
1838              }
1839            } else {
1840              self.alignfold_counts_observed.bulge_scores_len[MAX_LOOP_LEN - 1] += 1.;
1841            }
1842          } else {
1843            let diff2 = get_diff(twoloop_len_pair2.0, twoloop_len_pair2.1);
1844            if sum2 <= MAX_LOOP_LEN {
1845              self.alignfold_counts_observed.interior_scores_len[sum2 - 2] += 1.;
1846              if diff2 == 0 {
1847                self.alignfold_counts_observed.interior_scores_symmetric
1848                  [twoloop_len_pair2.0 - 1] += 1.;
1849              } else {
1850                self.alignfold_counts_observed.interior_scores_asymmetric[diff2 - 1] += 1.;
1851              }
1852              if twoloop_len_pair2.0 == 1 && twoloop_len_pair2.1 == 1 {
1853                let dict_min_mismatch_pair2 = get_dict_min_pair(&mismatch_pair2);
1854                self.alignfold_counts_observed.interior_scores_1x1[dict_min_mismatch_pair2.0]
1855                  [dict_min_mismatch_pair2.1] += 1.;
1856              }
1857              if twoloop_len_pair2.0 <= MAX_INTERIOR_EXPLICIT
1858                && twoloop_len_pair2.1 <= MAX_INTERIOR_EXPLICIT
1859              {
1860                let dict_min_len_pair2 = get_dict_min_pair(&twoloop_len_pair2);
1861                self.alignfold_counts_observed.interior_scores_explicit
1862                  [dict_min_len_pair2.0 - 1][dict_min_len_pair2.1 - 1] += 1.;
1863              }
1864            } else {
1865              self.alignfold_counts_observed.interior_scores_len[MAX_LOOP_LEN - 2] += 1.;
1866              if diff2 == 0 {
1867                if twoloop_len_pair2.0 <= MAX_INTERIOR_SYMMETRIC {
1868                  self.alignfold_counts_observed.interior_scores_symmetric
1869                    [twoloop_len_pair2.0 - 1] += 1.;
1870                } else {
1871                  self.alignfold_counts_observed.interior_scores_symmetric
1872                    [MAX_INTERIOR_SYMMETRIC - 1] += 1.;
1873                }
1874              } else if diff2 <= MAX_INTERIOR_ASYMMETRIC {
1875                self.alignfold_counts_observed.interior_scores_asymmetric[diff2 - 1] += 1.;
1876              } else {
1877                self.alignfold_counts_observed.interior_scores_asymmetric
1878                  [MAX_INTERIOR_ASYMMETRIC - 1] += 1.;
1879              }
1880            }
1881          }
1882          if has_canonical_basepair(&basepair2) {
1883            self.alignfold_counts_observed.terminal_mismatch_scores[basepair2.0][basepair2.1]
1884              [mismatch_pair2.0][mismatch_pair2.1] += 1.;
1885            self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1886          }
1887          if has_canonical_basepair(&basepair4) {
1888            self.alignfold_counts_observed.terminal_mismatch_scores[basepair4.1][basepair4.0]
1889              [mismatch_pair4.1][mismatch_pair4.0] += 1.;
1890            self.alignfold_counts_observed.helix_close_scores[basepair4.1][basepair4.0] += 1.;
1891          }
1892        }
1893      } else {
1894        if has_canonical_basepair(&basepair) {
1895          self.alignfold_counts_observed.dangling_scores_left[basepair.0][basepair.1]
1896            [mismatch_pair.0] += 1.;
1897          self.alignfold_counts_observed.dangling_scores_right[basepair.0][basepair.1]
1898            [mismatch_pair.1] += 1.;
1899          self.alignfold_counts_observed.helix_close_scores[basepair.0][basepair.1] += 1.;
1900        }
1901        if has_canonical_basepair(&basepair2) {
1902          self.alignfold_counts_observed.dangling_scores_left[basepair2.0][basepair2.1]
1903            [mismatch_pair2.0] += 1.;
1904          self.alignfold_counts_observed.dangling_scores_right[basepair2.0][basepair2.1]
1905            [mismatch_pair2.1] += 1.;
1906          self.alignfold_counts_observed.helix_close_scores[basepair2.0][basepair2.1] += 1.;
1907        }
1908        self.alignfold_counts_observed.multibranch_score_base += 2.;
1909        self.alignfold_counts_observed.multibranch_score_basepair += 2.;
1910        self.alignfold_counts_observed.multibranch_score_basepair +=
1911          2. * num_basepairs_loop as Prob;
1912        let num_unpairs_multibranch =
1913          get_num_unpairs_multibranch(basepair_close, basepairs_loop, &seq_pair.0[..]);
1914        self.alignfold_counts_observed.multibranch_score_unpair += num_unpairs_multibranch as Prob;
1915        let num_unpairs_multibranch2 =
1916          get_num_unpairs_multibranch(basepair_close, basepairs_loop, &seq_pair.1[..]);
1917        self.alignfold_counts_observed.multibranch_score_unpair += num_unpairs_multibranch2 as Prob;
1918        for basepair_loop in basepairs_loop.iter() {
1919          let basepair3 = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1920          let closes = false;
1921          let mismatch_pair3 = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1922          let basepair4 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1923          let mismatch_pair4 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1924          if has_canonical_basepair(&basepair3) {
1925            self.alignfold_counts_observed.dangling_scores_left[basepair3.1][basepair3.0]
1926              [mismatch_pair3.1] += 1.;
1927            self.alignfold_counts_observed.dangling_scores_right[basepair3.1][basepair3.0]
1928              [mismatch_pair3.0] += 1.;
1929            self.alignfold_counts_observed.helix_close_scores[basepair3.1][basepair3.0] += 1.;
1930          }
1931          if has_canonical_basepair(&basepair4) {
1932            self.alignfold_counts_observed.dangling_scores_left[basepair4.1][basepair4.0]
1933              [mismatch_pair4.1] += 1.;
1934            self.alignfold_counts_observed.dangling_scores_right[basepair4.1][basepair4.0]
1935              [mismatch_pair4.0] += 1.;
1936            self.alignfold_counts_observed.helix_close_scores[basepair4.1][basepair4.0] += 1.;
1937          }
1938        }
1939      }
1940    }
1941    self.alignfold_counts_observed.external_score_basepair += 2. * stored_basepairs.len() as Prob;
1942    let mut stored_basepairs_sorted = stored_basepairs
1943      .iter()
1944      .copied()
1945      .collect::<Vec<(usize, usize)>>();
1946    stored_basepairs_sorted.sort();
1947    let num_unpairs_external = get_num_unpairs_external(&stored_basepairs_sorted, &seq_pair.0[..]);
1948    self.alignfold_counts_observed.external_score_unpair += num_unpairs_external as Prob;
1949    let num_unpairs_external2 = get_num_unpairs_external(&stored_basepairs_sorted, &seq_pair.1[..]);
1950    self.alignfold_counts_observed.external_score_unpair += num_unpairs_external2 as Prob;
1951    for basepair_loop in stored_basepairs.iter() {
1952      let basepair = (seq_pair.0[basepair_loop.0], seq_pair.0[basepair_loop.1]);
1953      let basepair2 = (seq_pair.1[basepair_loop.0], seq_pair.1[basepair_loop.1]);
1954      let closes = false;
1955      let mismatch_pair = get_mismatch_pair(&seq_pair.0[..], basepair_loop, closes);
1956      let mismatch_pair2 = get_mismatch_pair(&seq_pair.1[..], basepair_loop, closes);
1957      if has_canonical_basepair(&basepair) {
1958        if mismatch_pair.1 != PSEUDO_BASE {
1959          self.alignfold_counts_observed.dangling_scores_left[basepair.1][basepair.0]
1960            [mismatch_pair.1] += 1.;
1961        }
1962        if mismatch_pair.0 != PSEUDO_BASE {
1963          self.alignfold_counts_observed.dangling_scores_right[basepair.1][basepair.0]
1964            [mismatch_pair.0] += 1.;
1965        }
1966        self.alignfold_counts_observed.helix_close_scores[basepair.1][basepair.0] += 1.;
1967      }
1968      if has_canonical_basepair(&basepair2) {
1969        if mismatch_pair2.1 != PSEUDO_BASE {
1970          self.alignfold_counts_observed.dangling_scores_left[basepair2.1][basepair2.0]
1971            [mismatch_pair2.1] += 1.;
1972        }
1973        if mismatch_pair2.0 != PSEUDO_BASE {
1974          self.alignfold_counts_observed.dangling_scores_right[basepair2.1][basepair2.0]
1975            [mismatch_pair2.0] += 1.;
1976        }
1977        self.alignfold_counts_observed.helix_close_scores[basepair2.1][basepair2.0] += 1.;
1978      }
1979    }
1980  }
1981
1982  pub fn set_curr_scores(&mut self, alignfold_scores: &AlignfoldScores) {
1983    self.fold_scores_pair.0 = FoldScoresTrained::<T>::set_curr_scores(
1984      alignfold_scores,
1985      &self.seq_pair.0,
1986      &self.basepair_probs_pair.0,
1987    );
1988    self.fold_scores_pair.1 = FoldScoresTrained::<T>::set_curr_scores(
1989      alignfold_scores,
1990      &self.seq_pair.1,
1991      &self.basepair_probs_pair.1,
1992    );
1993  }
1994}
1995
1996impl RangeInsertScores {
1997  pub fn origin() -> RangeInsertScores {
1998    let x = Vec::new();
1999    RangeInsertScores {
2000      insert_scores: x.clone(),
2001      insert_scores_external: x.clone(),
2002      insert_scores_multibranch: x.clone(),
2003      insert_scores2: x.clone(),
2004      insert_scores_external2: x.clone(),
2005      insert_scores_multibranch2: x,
2006    }
2007  }
2008
2009  pub fn new(seq_pair: &SeqPair, alignfold_scores: &AlignfoldScores) -> RangeInsertScores {
2010    let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
2011    let mut range_insert_scores = RangeInsertScores::origin();
2012    let neg_infs = vec![
2013      vec![NEG_INFINITY; seq_len_pair.0.to_usize().unwrap()];
2014      seq_len_pair.0.to_usize().unwrap()
2015    ];
2016    range_insert_scores.insert_scores = neg_infs.clone();
2017    range_insert_scores.insert_scores_external = neg_infs.clone();
2018    range_insert_scores.insert_scores_multibranch = neg_infs;
2019    let neg_infs = vec![
2020      vec![NEG_INFINITY; seq_len_pair.1.to_usize().unwrap()];
2021      seq_len_pair.1.to_usize().unwrap()
2022    ];
2023    range_insert_scores.insert_scores2 = neg_infs.clone();
2024    range_insert_scores.insert_scores_external2 = neg_infs.clone();
2025    range_insert_scores.insert_scores_multibranch2 = neg_infs;
2026    for i in 1..seq_len_pair.1 - 1 {
2027      let base = seq_pair.1[i];
2028      let mut sum = alignfold_scores.insert_scores[base];
2029      let mut sum_external = sum + alignfold_scores.external_score_unpair;
2030      let mut sum_multibranch = sum + alignfold_scores.multibranch_score_unpair;
2031      range_insert_scores.insert_scores2[i][i] = sum;
2032      range_insert_scores.insert_scores_external2[i][i] = sum_external;
2033      range_insert_scores.insert_scores_multibranch2[i][i] = sum_multibranch;
2034      for j in i + 1..seq_len_pair.1 - 1 {
2035        let x = seq_pair.1[j];
2036        let x = alignfold_scores.insert_scores[x] + alignfold_scores.insert_extend_score;
2037        sum += x;
2038        sum_external += x + alignfold_scores.external_score_unpair;
2039        sum_multibranch += x + alignfold_scores.multibranch_score_unpair;
2040        range_insert_scores.insert_scores2[i][j] = sum;
2041        range_insert_scores.insert_scores_external2[i][j] = sum_external;
2042        range_insert_scores.insert_scores_multibranch2[i][j] = sum_multibranch;
2043      }
2044    }
2045    for i in 1..seq_len_pair.0 - 1 {
2046      let base = seq_pair.0[i];
2047      let term = alignfold_scores.insert_scores[base];
2048      let mut sum = term;
2049      let mut sum_external = sum + alignfold_scores.external_score_unpair;
2050      let mut sum_multibranch = sum + alignfold_scores.multibranch_score_unpair;
2051      range_insert_scores.insert_scores[i][i] = sum;
2052      range_insert_scores.insert_scores_external[i][i] = sum_external;
2053      range_insert_scores.insert_scores_multibranch[i][i] = sum_multibranch;
2054      for j in i + 1..seq_len_pair.0 - 1 {
2055        let x = seq_pair.0[j];
2056        let x = alignfold_scores.insert_scores[x] + alignfold_scores.insert_extend_score;
2057        sum += x;
2058        sum_external += x + alignfold_scores.external_score_unpair;
2059        sum_multibranch += x + alignfold_scores.multibranch_score_unpair;
2060        range_insert_scores.insert_scores[i][j] = sum;
2061        range_insert_scores.insert_scores_external[i][j] = sum_external;
2062        range_insert_scores.insert_scores_multibranch[i][j] = sum_multibranch;
2063      }
2064    }
2065    range_insert_scores
2066  }
2067}
2068
2069impl<T: HashIndex> Default for PairAlignfold<T> {
2070  fn default() -> Self {
2071    Self::new()
2072  }
2073}
2074
2075impl<T: HashIndex> PairAlignfold<T> {
2076  pub fn new() -> PairAlignfold<T> {
2077    PairAlignfold {
2078      matched_pos_pairs: SparsePosMat::<T>::default(),
2079      inserted_poss: SparsePoss::<T>::default(),
2080      deleted_poss: SparsePoss::<T>::default(),
2081    }
2082  }
2083}
2084
2085pub const DEFAULT_BASEPAIR_PROB_TRAIN: Prob = DEFAULT_MIN_BASEPAIR_PROB;
2086pub const DEFAULT_MATCH_PROB_TRAIN: Prob = DEFAULT_MIN_MATCH_PROB;
2087pub const NUM_BASEPAIRS: usize = 6;
2088pub const GROUP_SIZE_MULTIBRANCH: usize = 3;
2089pub const GROUP_SIZE_EXTERNAL: usize = 2;
2090pub const GROUP_SIZE_MATCH_TRANSITION: usize = 3;
2091pub const GROUP_SIZE_INSERT_TRANSITION: usize = 2;
2092pub const GAMMA_DISTRO_ALPHA: Score = 0.;
2093pub const GAMMA_DISTRO_BETA: Score = 1.;
2094pub const DEFAULT_LEARNING_TOLERANCE: Score = 0.000_1;
2095pub const TRAINED_SCORES_FILE: &str = "../src/trained_alignfold_scores.rs";
2096pub const TRAINED_SCORES_FILE_RANDINIT: &str = "../src/trained_alignfold_scores_randinit.rs";
/// How the alignment-folding scores are obtained.
///
/// Variant semantics are defined by the code that consumes this enum. The
/// names suggest (TODO confirm): `TrainedTransfer` trains starting from
/// transferred scores, `TrainedRandinit` trains from a random initialization,
/// and `TransferredOnly` uses the transferred scores without training.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainType {
  TrainedTransfer,
  TrainedRandinit,
  TransferredOnly,
}
/// Default training type as a string (presumably the command-line spelling of
/// [`TrainType::TrainedTransfer`] — TODO confirm).
pub const DEFAULT_TRAIN_TYPE: &str = "trained_transfer";
2104
2105pub fn get_accuracy_expected<T>(
2106  seq_pair: &SeqPair,
2107  alignfold: &PairAlignfold<T>,
2108  match_probs: &SparseProbMat<T>,
2109) -> Score
2110where
2111  T: HashIndex,
2112{
2113  let seq_len_pair = (seq_pair.0.len(), seq_pair.1.len());
2114  let mut insert_probs_pair = (vec![1.; seq_len_pair.0], vec![1.; seq_len_pair.1]);
2115  for (x, &y) in match_probs {
2116    let x = (x.0.to_usize().unwrap(), x.1.to_usize().unwrap());
2117    insert_probs_pair.0[x.0] -= y;
2118    insert_probs_pair.1[x.1] -= y;
2119  }
2120  let total = alignfold.matched_pos_pairs.len()
2121    + alignfold.inserted_poss.len()
2122    + alignfold.deleted_poss.len();
2123  let mut total_expected = alignfold
2124    .matched_pos_pairs
2125    .iter()
2126    .map(|x| match match_probs.get(x) {
2127      Some(&x) => x,
2128      None => 0.,
2129    })
2130    .sum::<Score>();
2131  total_expected += alignfold
2132    .inserted_poss
2133    .iter()
2134    .map(|&x| insert_probs_pair.0[x.to_usize().unwrap()])
2135    .sum::<Score>();
2136  total_expected += alignfold
2137    .deleted_poss
2138    .iter()
2139    .map(|&x| insert_probs_pair.1[x.to_usize().unwrap()])
2140    .sum::<Score>();
2141  total_expected / total as Score
2142}
2143
2144pub fn consprob_core<T>(inputs: InputsConsprobCore<T>) -> (AlignfoldProbMats<T>, Sum)
2145where
2146  T: HashIndex,
2147{
2148  let (
2149    seq_pair,
2150    alignfold_scores,
2151    max_basepair_span_pair,
2152    match_probs,
2153    produces_struct_profs,
2154    trains_alignfold_scores,
2155    alignfold_counts_expected,
2156    forward_pos_pairs,
2157    backward_pos_pairs,
2158    pos_quads_hashed_lens,
2159    fold_scores_pair,
2160    produces_match_probs,
2161    matchable_poss,
2162    matchable_poss2,
2163  ) = inputs;
2164  let range_insert_scores = RangeInsertScores::new(seq_pair, alignfold_scores);
2165  let (alignfold_sums, global_sum) = get_alignfold_sums::<T>((
2166    seq_pair,
2167    alignfold_scores,
2168    max_basepair_span_pair,
2169    match_probs,
2170    trains_alignfold_scores,
2171    forward_pos_pairs,
2172    backward_pos_pairs,
2173    pos_quads_hashed_lens,
2174    fold_scores_pair,
2175    &range_insert_scores,
2176    matchable_poss,
2177    matchable_poss2,
2178  ));
2179  (
2180    get_alignfold_probs::<T>((
2181      seq_pair,
2182      alignfold_scores,
2183      max_basepair_span_pair,
2184      match_probs,
2185      &alignfold_sums,
2186      produces_struct_profs,
2187      global_sum,
2188      trains_alignfold_scores,
2189      alignfold_counts_expected,
2190      pos_quads_hashed_lens,
2191      fold_scores_pair,
2192      produces_match_probs,
2193      forward_pos_pairs,
2194      backward_pos_pairs,
2195      &range_insert_scores,
2196      matchable_poss,
2197      matchable_poss2,
2198    )),
2199    global_sum,
2200  )
2201}
2202
2203pub fn get_alignfold_sums<T>(inputs: InputsInsideSumsGetter<T>) -> (AlignfoldSums<T>, Sum)
2204where
2205  T: HashIndex,
2206{
2207  let (
2208    seq_pair,
2209    alignfold_scores,
2210    max_basepair_span_pair,
2211    match_probs,
2212    trains_alignfold_scores,
2213    forward_pos_pairs,
2214    backward_pos_pairs,
2215    pos_quads_hashed_lens,
2216    fold_scores_pair,
2217    range_insert_scores,
2218    matchable_poss,
2219    matchable_poss2,
2220  ) = inputs;
2221  let seq_len_pair = (
2222    T::from_usize(seq_pair.0.len()).unwrap(),
2223    T::from_usize(seq_pair.1.len()).unwrap(),
2224  );
2225  let mut alignfold_sums = AlignfoldSums::<T>::new();
2226  for substr_len in range_inclusive(
2227    T::from_usize(if trains_alignfold_scores {
2228      2
2229    } else {
2230      MIN_SPAN_HAIRPIN_CLOSE
2231    })
2232    .unwrap(),
2233    max_basepair_span_pair.0,
2234  ) {
2235    for substr_len2 in range_inclusive(
2236      T::from_usize(if trains_alignfold_scores {
2237        2
2238      } else {
2239        MIN_SPAN_HAIRPIN_CLOSE
2240      })
2241      .unwrap(),
2242      max_basepair_span_pair.1,
2243    ) {
2244      if let Some(pos_pairs) = pos_quads_hashed_lens.get(&(substr_len, substr_len2)) {
2245        for &(i, k) in pos_pairs {
2246          let (j, l) = (i + substr_len - T::one(), k + substr_len2 - T::one());
2247          let (long_i, long_j, long_k, long_l) = (
2248            i.to_usize().unwrap(),
2249            j.to_usize().unwrap(),
2250            k.to_usize().unwrap(),
2251            l.to_usize().unwrap(),
2252          );
2253          let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
2254          let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
2255          let pos_quad = (i, j, k, l);
2256          let computes_forward_sums = true;
2257          let (sum_seqalign, sum_multibranch) = get_loop_sums::<T>((
2258            seq_pair,
2259            alignfold_scores,
2260            match_probs,
2261            &pos_quad,
2262            &mut alignfold_sums,
2263            computes_forward_sums,
2264            forward_pos_pairs,
2265            range_insert_scores,
2266            matchable_poss,
2267            matchable_poss2,
2268          ));
2269          let computes_forward_sums = false;
2270          let _ = get_loop_sums::<T>((
2271            seq_pair,
2272            alignfold_scores,
2273            match_probs,
2274            &pos_quad,
2275            &mut alignfold_sums,
2276            computes_forward_sums,
2277            backward_pos_pairs,
2278            range_insert_scores,
2279            matchable_poss,
2280            matchable_poss2,
2281          ));
2282          let mut sum = NEG_INFINITY;
2283          let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
2284            + alignfold_scores.match_scores[basepair.1][basepair2.1];
2285          if substr_len.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
2286            && substr_len2.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
2287          {
2288            let hairpin_score = fold_scores_pair.0.hairpin_scores[&(i, j)];
2289            let hairpin_score2 = fold_scores_pair.1.hairpin_scores[&(k, l)];
2290            let score = hairpin_score + hairpin_score2 + sum_seqalign;
2291            logsumexp(&mut sum, score);
2292          }
2293          let forward_sums = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
2294          let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
2295          let min = T::from_usize(if trains_alignfold_scores {
2296            2
2297          } else {
2298            MIN_HAIRPIN_LEN
2299          })
2300          .unwrap();
2301          let min_len_pair = (
2302            if substr_len <= min + T::from_usize(MAX_LOOP_LEN + 2).unwrap() {
2303              min
2304            } else {
2305              substr_len - T::from_usize(MAX_LOOP_LEN + 2).unwrap()
2306            },
2307            if substr_len2 <= min + T::from_usize(MAX_LOOP_LEN + 2).unwrap() {
2308              min
2309            } else {
2310              substr_len2 - T::from_usize(MAX_LOOP_LEN + 2).unwrap()
2311            },
2312          );
2313          for substr_len3 in range(min_len_pair.0, substr_len - T::one()) {
2314            for substr_len4 in range(min_len_pair.1, substr_len2 - T::one()) {
2315              if let Some(pos_pairs2) = pos_quads_hashed_lens.get(&(substr_len3, substr_len4)) {
2316                for &(m, o) in pos_pairs2 {
2317                  let (n, p) = (m + substr_len3 - T::one(), o + substr_len4 - T::one());
2318                  if !(i < m && n < j && k < o && p < l) {
2319                    continue;
2320                  }
2321                  if m - i - T::one() + j - n - T::one() > T::from_usize(MAX_LOOP_LEN).unwrap() {
2322                    continue;
2323                  }
2324                  if o - k - T::one() + l - p - T::one() > T::from_usize(MAX_LOOP_LEN).unwrap() {
2325                    continue;
2326                  }
2327                  let pos_quad2 = (m, n, o, p);
2328                  if let Some(&x) = alignfold_sums.sums_close.get(&pos_quad2) {
2329                    let mut forward_term = NEG_INFINITY;
2330                    let mut backward_term = forward_term;
2331                    let pos_pair2 = (m - T::one(), o - T::one());
2332                    if let Some(x) = forward_sums.get(&pos_pair2) {
2333                      logsumexp(&mut forward_term, x.sum_seqalign);
2334                    }
2335                    let pos_pair2 = (n + T::one(), p + T::one());
2336                    if let Some(x) = backward_sums.get(&pos_pair2) {
2337                      logsumexp(&mut backward_term, x.sum_seqalign);
2338                    }
2339                    let twoloop_score = fold_scores_pair.0.twoloop_scores[&(i, j, m, n)];
2340                    let twoloop_score2 = fold_scores_pair.1.twoloop_scores[&(k, l, o, p)];
2341                    let x = twoloop_score + twoloop_score2 + x + forward_term + backward_term;
2342                    logsumexp(&mut sum, x);
2343                  }
2344                }
2345              }
2346            }
2347          }
2348          let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
2349          let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
2350          let score = multibranch_close_score + multibranch_close_score2 + sum_multibranch;
2351          logsumexp(&mut sum, score);
2352          if sum > NEG_INFINITY {
2353            let sum = sum + pairmatch_score;
2354            alignfold_sums.sums_close.insert(pos_quad, sum);
2355            let external_accessible_score = fold_scores_pair.0.external_accessible_scores[&(i, j)];
2356            let external_accessible_score2 = fold_scores_pair.1.external_accessible_scores[&(k, l)];
2357            alignfold_sums.sums_accessible_external.insert(
2358              pos_quad,
2359              sum + external_accessible_score + external_accessible_score2,
2360            );
2361            let multibranch_accessible_score =
2362              fold_scores_pair.0.multibranch_accessible_scores[&(i, j)];
2363            let multibranch_accessible_score2 =
2364              fold_scores_pair.1.multibranch_accessible_scores[&(k, l)];
2365            alignfold_sums.sums_accessible_multibranch.insert(
2366              pos_quad,
2367              sum + multibranch_accessible_score + multibranch_accessible_score2,
2368            );
2369          }
2370        }
2371      }
2372    }
2373  }
2374  let leftmost_pos_pair = (T::zero(), T::zero());
2375  let rightmost_pos_pair = (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one());
2376  alignfold_sums
2377    .forward_sums_external
2378    .insert(leftmost_pos_pair, 0.);
2379  alignfold_sums
2380    .backward_sums_external
2381    .insert(rightmost_pos_pair, 0.);
2382  for i in range(T::zero(), seq_len_pair.0 - T::one()) {
2383    let long_i = i.to_usize().unwrap();
2384    let base = seq_pair.0[long_i];
2385    for j in range(T::zero(), seq_len_pair.1 - T::one()) {
2386      let pos_pair = (i, j);
2387      if pos_pair == (T::zero(), T::zero()) {
2388        continue;
2389      }
2390      let long_j = j.to_usize().unwrap();
2391      let mut sum = NEG_INFINITY;
2392      if let Some(x) = forward_pos_pairs.get(&pos_pair) {
2393        for &(k, l) in x {
2394          let pos_pair2 = (k - T::one(), l - T::one());
2395          let pos_quad = (k, i, l, j);
2396          if let Some(&x) = alignfold_sums.sums_accessible_external.get(&pos_quad) {
2397            if let Some(&y) = alignfold_sums.forward_sums_external2.get(&pos_pair2) {
2398              let y = x + y;
2399              logsumexp(&mut sum, y);
2400            }
2401          }
2402        }
2403      }
2404      let base2 = seq_pair.1[long_j];
2405      if i > T::zero() && j > T::zero() && match_probs.contains_key(&pos_pair) {
2406        let mut sum2 = NEG_INFINITY;
2407        let loopmatch_score =
2408          alignfold_scores.match_scores[base][base2] + 2. * alignfold_scores.external_score_unpair;
2409        let pos_pair2 = (i - T::one(), j - T::one());
2410        let long_pos_pair2 = (
2411          pos_pair2.0.to_usize().unwrap(),
2412          pos_pair2.1.to_usize().unwrap(),
2413        );
2414        let begins_sum = pos_pair2 == leftmost_pos_pair;
2415        if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
2416          let x = x
2417            + if begins_sum {
2418              alignfold_scores.init_match_score
2419            } else {
2420              alignfold_scores.match2match_score
2421            };
2422          logsumexp(&mut sum2, x);
2423        }
2424        if let Some(x) = matchable_poss.get(&pos_pair2.0) {
2425          for &x in x {
2426            if x >= pos_pair2.1 {
2427              continue;
2428            }
2429            let pos_pair3 = (pos_pair2.0, x);
2430            if let Some(&y) = alignfold_sums.forward_sums_external.get(&pos_pair3) {
2431              let long_x = x.to_usize().unwrap();
2432              let begins_sum = pos_pair3 == leftmost_pos_pair;
2433              let z = range_insert_scores.insert_scores_external2[long_x + 1][long_pos_pair2.1]
2434                + if begins_sum {
2435                  alignfold_scores.init_insert_score
2436                } else {
2437                  alignfold_scores.match2insert_score
2438                };
2439              let z = y + z + alignfold_scores.match2insert_score;
2440              logsumexp(&mut sum2, z);
2441            }
2442          }
2443        }
2444        if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
2445          for &x in x {
2446            if x >= pos_pair2.0 {
2447              continue;
2448            }
2449            let pos_pair3 = (x, pos_pair2.1);
2450            if let Some(&y) = alignfold_sums.forward_sums_external.get(&pos_pair3) {
2451              let long_x = x.to_usize().unwrap();
2452              let begins_sum = pos_pair3 == leftmost_pos_pair;
2453              let z = range_insert_scores.insert_scores_external[long_x + 1][long_pos_pair2.0]
2454                + if begins_sum {
2455                  alignfold_scores.init_insert_score
2456                } else {
2457                  alignfold_scores.match2insert_score
2458                };
2459              let z = y + z + alignfold_scores.match2insert_score;
2460              logsumexp(&mut sum2, z);
2461            }
2462          }
2463        }
2464        if sum2 > NEG_INFINITY {
2465          alignfold_sums
2466            .forward_sums_external2
2467            .insert(pos_pair2, sum2);
2468        }
2469        let term = sum2 + loopmatch_score;
2470        logsumexp(&mut sum, term);
2471        if sum > NEG_INFINITY {
2472          alignfold_sums.forward_sums_external.insert(pos_pair, sum);
2473        }
2474      }
2475    }
2476  }
2477  let mut global_sum = NEG_INFINITY;
2478  let pos_pair2 = (
2479    rightmost_pos_pair.0 - T::one(),
2480    rightmost_pos_pair.1 - T::one(),
2481  );
2482  let long_pos_pair2 = (
2483    pos_pair2.0.to_usize().unwrap(),
2484    pos_pair2.1.to_usize().unwrap(),
2485  );
2486  if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
2487    logsumexp(&mut global_sum, x);
2488  }
2489  if let Some(x) = matchable_poss.get(&pos_pair2.0) {
2490    for &x in x {
2491      if x >= pos_pair2.1 {
2492        continue;
2493      }
2494      if let Some(&y) = alignfold_sums.forward_sums_external.get(&(pos_pair2.0, x)) {
2495        let long_x = x.to_usize().unwrap();
2496        let z = range_insert_scores.insert_scores_external2[long_x + 1][long_pos_pair2.1]
2497          + alignfold_scores.match2insert_score;
2498        let z = y + z;
2499        logsumexp(&mut global_sum, z);
2500      }
2501    }
2502  }
2503  if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
2504    for &x in x {
2505      if x >= pos_pair2.0 {
2506        continue;
2507      }
2508      if let Some(&y) = alignfold_sums.forward_sums_external.get(&(x, pos_pair2.1)) {
2509        let long_x = x.to_usize().unwrap();
2510        let z = range_insert_scores.insert_scores_external[long_x + 1][long_pos_pair2.0]
2511          + alignfold_scores.match2insert_score;
2512        let z = y + z;
2513        logsumexp(&mut global_sum, z);
2514      }
2515    }
2516  }
2517  for i in range(T::one(), seq_len_pair.0).rev() {
2518    let long_i = i.to_usize().unwrap();
2519    let base = seq_pair.0[long_i];
2520    for j in range(T::one(), seq_len_pair.1).rev() {
2521      let pos_pair = (i, j);
2522      if pos_pair == (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one()) {
2523        continue;
2524      }
2525      let long_j = j.to_usize().unwrap();
2526      let mut sum = NEG_INFINITY;
2527      if let Some(x) = backward_pos_pairs.get(&pos_pair) {
2528        for &(k, l) in x {
2529          let pos_pair2 = (k + T::one(), l + T::one());
2530          let pos_quad = (i, k, j, l);
2531          if let Some(&x) = alignfold_sums.sums_accessible_external.get(&pos_quad) {
2532            if let Some(&y) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
2533              let y = x + y;
2534              logsumexp(&mut sum, y);
2535            }
2536          }
2537        }
2538      }
2539      let base2 = seq_pair.1[long_j];
2540      if i < seq_len_pair.0 - T::one() && j < seq_len_pair.1 - T::one() {
2541        let pos_pair2 = (i + T::one(), j + T::one());
2542        let long_pos_pair2 = (
2543          pos_pair2.0.to_usize().unwrap(),
2544          pos_pair2.1.to_usize().unwrap(),
2545        );
2546        let ends_sum = pos_pair2 == rightmost_pos_pair;
2547        if match_probs.contains_key(&pos_pair) {
2548          let mut sum2 = NEG_INFINITY;
2549          let loopmatch_score = alignfold_scores.match_scores[base][base2]
2550            + 2. * alignfold_scores.external_score_unpair;
2551          if let Some(&x) = alignfold_sums.backward_sums_external.get(&pos_pair2) {
2552            let x = x
2553              + if ends_sum {
2554                0.
2555              } else {
2556                alignfold_scores.match2match_score
2557              };
2558            logsumexp(&mut sum2, x);
2559          }
2560          if let Some(x) = matchable_poss.get(&pos_pair2.0) {
2561            for &x in x {
2562              if x <= pos_pair2.1 {
2563                continue;
2564              }
2565              let pos_pair3 = (pos_pair2.0, x);
2566              if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
2567                let long_x = x.to_usize().unwrap();
2568                let ends_sum = pos_pair3 == rightmost_pos_pair;
2569                let z = range_insert_scores.insert_scores_external2[long_pos_pair2.1][long_x - 1]
2570                  + if ends_sum {
2571                    0.
2572                  } else {
2573                    alignfold_scores.match2insert_score
2574                  };
2575                let z = y + z + alignfold_scores.match2insert_score;
2576                logsumexp(&mut sum2, z);
2577              }
2578            }
2579          }
2580          if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
2581            for &x in x {
2582              if x <= pos_pair2.0 {
2583                continue;
2584              }
2585              let pos_pair3 = (x, pos_pair2.1);
2586              if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
2587                let long_x = x.to_usize().unwrap();
2588                let ends_sum = pos_pair3 == rightmost_pos_pair;
2589                let z = range_insert_scores.insert_scores_external[long_pos_pair2.0][long_x - 1]
2590                  + if ends_sum {
2591                    0.
2592                  } else {
2593                    alignfold_scores.match2insert_score
2594                  };
2595                let z = y + z + alignfold_scores.match2insert_score;
2596                logsumexp(&mut sum2, z);
2597              }
2598            }
2599          }
2600          if sum2 > NEG_INFINITY {
2601            alignfold_sums
2602              .backward_sums_external2
2603              .insert(pos_pair2, sum2);
2604          }
2605          let term = sum2 + loopmatch_score;
2606          logsumexp(&mut sum, term);
2607          if sum > NEG_INFINITY {
2608            alignfold_sums.backward_sums_external.insert(pos_pair, sum);
2609          }
2610        }
2611      }
2612    }
2613  }
2614  (alignfold_sums, global_sum)
2615}
2616
2617pub fn get_loop_sums<T>(inputs: InputsLoopSumsGetter<T>) -> (Sum, Sum)
2618where
2619  T: HashIndex,
2620{
2621  let (
2622    seq_pair,
2623    alignfold_scores,
2624    match_probs,
2625    pos_quad,
2626    alignfold_sums,
2627    computes_forward_sums,
2628    pos_pairs,
2629    range_insert_scores,
2630    matchable_poss,
2631    matchable_poss2,
2632  ) = inputs;
2633  let &(i, j, k, l) = pos_quad;
2634  let leftmost_pos_pair = if computes_forward_sums {
2635    (i, k)
2636  } else {
2637    (i + T::one(), k + T::one())
2638  };
2639  let rightmost_pos_pair = if computes_forward_sums {
2640    (j - T::one(), l - T::one())
2641  } else {
2642    (j, l)
2643  };
2644  let sums_hashed_poss = if computes_forward_sums {
2645    &mut alignfold_sums.forward_sums_hashed_poss
2646  } else {
2647    &mut alignfold_sums.backward_sums_hashed_poss
2648  };
2649  let sums_hashed_poss2 = if computes_forward_sums {
2650    &mut alignfold_sums.forward_sums_hashed_poss2
2651  } else {
2652    &mut alignfold_sums.backward_sums_hashed_poss2
2653  };
2654  if !sums_hashed_poss.contains_key(&if computes_forward_sums {
2655    leftmost_pos_pair
2656  } else {
2657    rightmost_pos_pair
2658  }) {
2659    sums_hashed_poss.insert(
2660      if computes_forward_sums {
2661        leftmost_pos_pair
2662      } else {
2663        rightmost_pos_pair
2664      },
2665      LoopSumsMat::<T>::new(),
2666    );
2667  }
2668  if !sums_hashed_poss2.contains_key(&if computes_forward_sums {
2669    leftmost_pos_pair
2670  } else {
2671    rightmost_pos_pair
2672  }) {
2673    sums_hashed_poss2.insert(
2674      if computes_forward_sums {
2675        leftmost_pos_pair
2676      } else {
2677        rightmost_pos_pair
2678      },
2679      LoopSumsMat::<T>::new(),
2680    );
2681  }
2682  let sums_mat = &mut sums_hashed_poss
2683    .get_mut(&if computes_forward_sums {
2684      leftmost_pos_pair
2685    } else {
2686      rightmost_pos_pair
2687    })
2688    .unwrap();
2689  let sums_mat2 = &mut sums_hashed_poss2
2690    .get_mut(&if computes_forward_sums {
2691      leftmost_pos_pair
2692    } else {
2693      rightmost_pos_pair
2694    })
2695    .unwrap();
2696  let iter: Poss<T> = if computes_forward_sums {
2697    range(i, j).collect()
2698  } else {
2699    range_inclusive(i + T::one(), j).rev().collect()
2700  };
2701  let iter2: Poss<T> = if computes_forward_sums {
2702    range(k, l).collect()
2703  } else {
2704    range_inclusive(k + T::one(), l).rev().collect()
2705  };
2706  for &u in iter.iter() {
2707    let long_u = u.to_usize().unwrap();
2708    let base = seq_pair.0[long_u];
2709    for &v in iter2.iter() {
2710      let pos_pair = (u, v);
2711      if sums_mat.contains_key(&pos_pair) {
2712        continue;
2713      }
2714      let mut sums = LoopSums::new();
2715      if (computes_forward_sums && u == i && v == k) || (!computes_forward_sums && u == j && v == l)
2716      {
2717        sums.sum_seqalign = 0.;
2718        sums.sum_seqalign_multibranch = 0.;
2719        sums.sum_0ormore_pairmatches = 0.;
2720        sums_mat.insert(pos_pair, sums);
2721        continue;
2722      }
2723      let long_v = v.to_usize().unwrap();
2724      let mut sum_multibranch = NEG_INFINITY;
2725      let mut sum_1st_pairmatches = sum_multibranch;
2726      let mut sum = sum_multibranch;
2727      if let Some(pos_pairs) = pos_pairs.get(&pos_pair) {
2728        for &(m, n) in pos_pairs {
2729          if computes_forward_sums {
2730            if !(i < m && k < n) {
2731              continue;
2732            }
2733          } else if !(m < j && n < l) {
2734            continue;
2735          }
2736          let pos_pair2 = if computes_forward_sums {
2737            (m - T::one(), n - T::one())
2738          } else {
2739            (m + T::one(), n + T::one())
2740          };
2741          let pos_quad2 = if computes_forward_sums {
2742            (m, u, n, v)
2743          } else {
2744            (u, m, v, n)
2745          };
2746          if let Some(&x) = alignfold_sums.sums_accessible_multibranch.get(&pos_quad2) {
2747            if let Some(y) = sums_mat2.get(&pos_pair2) {
2748              let z = x + y.sum_1ormore_pairmatches;
2749              logsumexp(&mut sum_multibranch, z);
2750              let z = x + y.sum_seqalign_multibranch;
2751              logsumexp(&mut sum_1st_pairmatches, z);
2752            }
2753          }
2754        }
2755      }
2756      let pos_pair2 = if computes_forward_sums {
2757        (u - T::one(), v - T::one())
2758      } else {
2759        (u + T::one(), v + T::one())
2760      };
2761      let long_pos_pair2 = (
2762        pos_pair2.0.to_usize().unwrap(),
2763        pos_pair2.1.to_usize().unwrap(),
2764      );
2765      let base2 = seq_pair.1[long_v];
2766      if match_probs.contains_key(&pos_pair) {
2767        let mut sums2 = LoopSums::new();
2768        let mut sum_seqalign2 = NEG_INFINITY;
2769        let mut sum_seqalign_multibranch2 = sum_seqalign2;
2770        let mut sum_multibranch2 = sum_seqalign2;
2771        let mut sum_1st_pairmatches2 = sum_seqalign2;
2772        let mut sum2 = sum_seqalign2;
2773        let loopmatch_score = alignfold_scores.match_scores[base][base2];
2774        let loopmatch_score_multibranch =
2775          loopmatch_score + 2. * alignfold_scores.multibranch_score_unpair;
2776        if let Some(x) = sums_mat.get(&pos_pair2) {
2777          let y = x.sum_multibranch + alignfold_scores.match2match_score;
2778          logsumexp(&mut sum_multibranch2, y);
2779          let y = x.sum_1st_pairmatches + alignfold_scores.match2match_score;
2780          logsumexp(&mut sum_1st_pairmatches2, y);
2781          let y = x.sum_seqalign_multibranch + alignfold_scores.match2match_score;
2782          logsumexp(&mut sum_seqalign_multibranch2, y);
2783          let y = x.sum_seqalign + alignfold_scores.match2match_score;
2784          logsumexp(&mut sum_seqalign2, y);
2785        }
2786        if let Some(x) = matchable_poss.get(&pos_pair2.0) {
2787          for &x in x {
2788            if computes_forward_sums && x >= pos_pair2.1
2789              || (!computes_forward_sums && x <= pos_pair2.1)
2790            {
2791              continue;
2792            }
2793            if let Some(y) = sums_mat.get(&(pos_pair2.0, x)) {
2794              let long_x = x.to_usize().unwrap();
2795              let z = if computes_forward_sums {
2796                range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
2797              } else {
2798                range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
2799              } + alignfold_scores.match2insert_score;
2800              let a = if computes_forward_sums {
2801                range_insert_scores.insert_scores_multibranch2[long_x + 1][long_pos_pair2.1]
2802              } else {
2803                range_insert_scores.insert_scores_multibranch2[long_pos_pair2.1][long_x - 1]
2804              } + alignfold_scores.match2insert_score;
2805              let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
2806              logsumexp(&mut sum_multibranch2, x);
2807              let x = y.sum_1st_pairmatches + alignfold_scores.match2insert_score + a;
2808              logsumexp(&mut sum_1st_pairmatches2, x);
2809              let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
2810              logsumexp(&mut sum_seqalign2, x);
2811              let x = y.sum_seqalign_multibranch + alignfold_scores.match2insert_score + a;
2812              logsumexp(&mut sum_seqalign_multibranch2, x);
2813            }
2814          }
2815        }
2816        if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
2817          for &x in x {
2818            if computes_forward_sums && x >= pos_pair2.0
2819              || (!computes_forward_sums && x <= pos_pair2.0)
2820            {
2821              continue;
2822            }
2823            if let Some(y) = sums_mat.get(&(x, pos_pair2.1)) {
2824              let long_x = x.to_usize().unwrap();
2825              let z = if computes_forward_sums {
2826                range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
2827              } else {
2828                range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
2829              } + alignfold_scores.match2insert_score;
2830              let a = if computes_forward_sums {
2831                range_insert_scores.insert_scores_multibranch[long_x + 1][long_pos_pair2.0]
2832              } else {
2833                range_insert_scores.insert_scores_multibranch[long_pos_pair2.0][long_x - 1]
2834              } + alignfold_scores.match2insert_score;
2835              let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
2836              logsumexp(&mut sum_multibranch2, x);
2837              let x = y.sum_1st_pairmatches + alignfold_scores.match2insert_score + a;
2838              logsumexp(&mut sum_1st_pairmatches2, x);
2839              let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
2840              logsumexp(&mut sum_seqalign2, x);
2841              let x = y.sum_seqalign_multibranch + alignfold_scores.match2insert_score + a;
2842              logsumexp(&mut sum_seqalign_multibranch2, x);
2843            }
2844          }
2845        }
2846        sums2.sum_multibranch = sum_multibranch2;
2847        logsumexp(&mut sum2, sum_multibranch2);
2848        sums2.sum_1st_pairmatches = sum_1st_pairmatches2;
2849        logsumexp(&mut sum2, sum_1st_pairmatches2);
2850        sums2.sum_1ormore_pairmatches = sum2;
2851        sums2.sum_seqalign = sum_seqalign2;
2852        sums2.sum_seqalign_multibranch = sum_seqalign_multibranch2;
2853        logsumexp(&mut sum2, sum_seqalign_multibranch2);
2854        sums2.sum_0ormore_pairmatches = sum2;
2855        if has_valid_sums(&sums2) {
2856          sums_mat2.insert(pos_pair2, sums2);
2857        }
2858        let term = sum_multibranch2 + loopmatch_score_multibranch;
2859        logsumexp(&mut sum_multibranch, term);
2860        sums.sum_multibranch = sum_multibranch;
2861        logsumexp(&mut sum, sum_multibranch);
2862        let term = sum_1st_pairmatches2 + loopmatch_score_multibranch;
2863        logsumexp(&mut sum_1st_pairmatches, term);
2864        sums.sum_1st_pairmatches = sum_1st_pairmatches;
2865        logsumexp(&mut sum, sum_1st_pairmatches);
2866        sums.sum_1ormore_pairmatches = sum;
2867        let sum_seqalign_multibranch = sum_seqalign_multibranch2 + loopmatch_score_multibranch;
2868        sums.sum_seqalign_multibranch = sum_seqalign_multibranch;
2869        logsumexp(&mut sum, sum_seqalign_multibranch);
2870        sums.sum_0ormore_pairmatches = sum;
2871        let sum_seqalign = sum_seqalign2 + loopmatch_score;
2872        sums.sum_seqalign = sum_seqalign;
2873        if has_valid_sums(&sums) {
2874          sums_mat.insert(pos_pair, sums);
2875        }
2876      }
2877    }
2878  }
2879  let mut final_sum_seqalign = NEG_INFINITY;
2880  let mut final_sum_multibranch = final_sum_seqalign;
2881  if computes_forward_sums {
2882    let pos_pair2 = rightmost_pos_pair;
2883    let long_pos_pair2 = (
2884      pos_pair2.0.to_usize().unwrap(),
2885      pos_pair2.1.to_usize().unwrap(),
2886    );
2887    if let Some(x) = sums_mat.get(&pos_pair2) {
2888      let y = x.sum_multibranch + alignfold_scores.match2match_score;
2889      logsumexp(&mut final_sum_multibranch, y);
2890      let y = x.sum_seqalign + alignfold_scores.match2match_score;
2891      logsumexp(&mut final_sum_seqalign, y);
2892    }
2893    if let Some(x) = matchable_poss.get(&pos_pair2.0) {
2894      for &x in x {
2895        if x >= pos_pair2.1 {
2896          continue;
2897        }
2898        if let Some(y) = sums_mat.get(&(pos_pair2.0, x)) {
2899          let long_x = x.to_usize().unwrap();
2900          let z = range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
2901            + alignfold_scores.match2insert_score;
2902          let a = range_insert_scores.insert_scores_multibranch2[long_x + 1][long_pos_pair2.1]
2903            + alignfold_scores.match2insert_score;
2904          let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
2905          logsumexp(&mut final_sum_multibranch, x);
2906          let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
2907          logsumexp(&mut final_sum_seqalign, x);
2908        }
2909      }
2910    }
2911    if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
2912      for &x in x {
2913        if x >= pos_pair2.0 {
2914          continue;
2915        }
2916        if let Some(y) = sums_mat.get(&(x, pos_pair2.1)) {
2917          let long_x = x.to_usize().unwrap();
2918          let z = range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
2919            + alignfold_scores.match2insert_score;
2920          let a = range_insert_scores.insert_scores_multibranch[long_x + 1][long_pos_pair2.0]
2921            + alignfold_scores.match2insert_score;
2922          let x = y.sum_multibranch + alignfold_scores.match2insert_score + a;
2923          logsumexp(&mut final_sum_multibranch, x);
2924          let x = y.sum_seqalign + alignfold_scores.match2insert_score + z;
2925          logsumexp(&mut final_sum_seqalign, x);
2926        }
2927      }
2928    }
2929    let mut sums2 = LoopSums::new();
2930    sums2.sum_multibranch = final_sum_multibranch;
2931    sums2.sum_seqalign = final_sum_seqalign;
2932    sums_mat2.insert(pos_pair2, sums2);
2933  }
2934  (final_sum_seqalign, final_sum_multibranch)
2935}
2936
2937pub fn get_2loop_sums<T>(inputs: Inputs2loopSumsGetter<T>) -> (SparseSumMat<T>, SparseSumMat<T>)
2938where
2939  T: HashIndex,
2940{
2941  let (
2942    seq_pair,
2943    alignfold_scores,
2944    match_probs,
2945    pos_quad,
2946    alignfold_sums,
2947    computes_forward_sums,
2948    pos_pairs,
2949    fold_scores_pair,
2950    range_insert_scores,
2951    matchable_poss,
2952    matchable_poss2,
2953  ) = inputs;
2954  let &(i, j, k, l) = pos_quad;
2955  let leftmost_pos_pair = if computes_forward_sums {
2956    (i, k)
2957  } else {
2958    (i + T::one(), k + T::one())
2959  };
2960  let rightmost_pos_pair = if computes_forward_sums {
2961    (j - T::one(), l - T::one())
2962  } else {
2963    (j, l)
2964  };
2965  let sums_hashed_poss = if computes_forward_sums {
2966    &alignfold_sums.forward_sums_hashed_poss2
2967  } else {
2968    &alignfold_sums.backward_sums_hashed_poss2
2969  };
2970  let sums = &sums_hashed_poss[&if computes_forward_sums {
2971    leftmost_pos_pair
2972  } else {
2973    rightmost_pos_pair
2974  }];
2975  let iter: Poss<T> = if computes_forward_sums {
2976    range(i, j).collect()
2977  } else {
2978    range_inclusive(i + T::one(), j).rev().collect()
2979  };
2980  let iter2: Poss<T> = if computes_forward_sums {
2981    range(k, l).collect()
2982  } else {
2983    range_inclusive(k + T::one(), l).rev().collect()
2984  };
2985  let mut sum_mat = SparseSumMat::<T>::default();
2986  let mut sum_mat2 = sum_mat.clone();
2987  for &u in iter.iter() {
2988    let long_u = u.to_usize().unwrap();
2989    let base = seq_pair.0[long_u];
2990    for &v in iter2.iter() {
2991      let pos_pair = (u, v);
2992      if (computes_forward_sums && u == i && v == k) || (!computes_forward_sums && u == j && v == l)
2993      {
2994        continue;
2995      }
2996      let long_v = v.to_usize().unwrap();
2997      let mut sum = NEG_INFINITY;
2998      if let Some(x) = pos_pairs.get(&pos_pair) {
2999        for &(m, n) in x {
3000          if computes_forward_sums {
3001            if !(i < m && k < n) {
3002              continue;
3003            }
3004          } else if !(m < j && n < l) {
3005            continue;
3006          }
3007          let pos_pair2 = if computes_forward_sums {
3008            (m - T::one(), n - T::one())
3009          } else {
3010            (m + T::one(), n + T::one())
3011          };
3012          let pos_quad2 = if computes_forward_sums {
3013            (m, u, n, v)
3014          } else {
3015            (u, m, v, n)
3016          };
3017          if pos_quad2.0 - i - T::one() + j - pos_quad2.1 - T::one()
3018            > T::from_usize(MAX_LOOP_LEN).unwrap()
3019          {
3020            continue;
3021          }
3022          if pos_quad2.2 - k - T::one() + l - pos_quad2.3 - T::one()
3023            > T::from_usize(MAX_LOOP_LEN).unwrap()
3024          {
3025            continue;
3026          }
3027          if let Some(&x) = alignfold_sums.sums_close.get(&pos_quad2) {
3028            if let Some(y) = sums.get(&pos_pair2) {
3029              let twoloop_score =
3030                fold_scores_pair.0.twoloop_scores[&(i, j, pos_quad2.0, pos_quad2.1)];
3031              let twoloop_score2 =
3032                fold_scores_pair.1.twoloop_scores[&(k, l, pos_quad2.2, pos_quad2.3)];
3033              let y = x + y.sum_seqalign + twoloop_score + twoloop_score2;
3034              logsumexp(&mut sum, y);
3035            }
3036          }
3037        }
3038      }
3039      let pos_pair2 = if computes_forward_sums {
3040        (u - T::one(), v - T::one())
3041      } else {
3042        (u + T::one(), v + T::one())
3043      };
3044      let long_pos_pair2 = (
3045        pos_pair2.0.to_usize().unwrap(),
3046        pos_pair2.1.to_usize().unwrap(),
3047      );
3048      let base2 = seq_pair.1[long_v];
3049      if match_probs.contains_key(&pos_pair) {
3050        let mut sum2 = NEG_INFINITY;
3051        let loopmatch_score = alignfold_scores.match_scores[base][base2];
3052        if let Some(&x) = sum_mat.get(&pos_pair2) {
3053          let x = x + alignfold_scores.match2match_score;
3054          logsumexp(&mut sum2, x);
3055        }
3056        if let Some(x) = matchable_poss.get(&pos_pair2.0) {
3057          for &x in x {
3058            if computes_forward_sums && x >= pos_pair2.1
3059              || (!computes_forward_sums && x <= pos_pair2.1)
3060            {
3061              continue;
3062            }
3063            if let Some(&y) = sum_mat.get(&(pos_pair2.0, x)) {
3064              let long_x = x.to_usize().unwrap();
3065              let z = if computes_forward_sums {
3066                range_insert_scores.insert_scores2[long_x + 1][long_pos_pair2.1]
3067              } else {
3068                range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
3069              } + alignfold_scores.match2insert_score;
3070              let z = y + alignfold_scores.match2insert_score + z;
3071              logsumexp(&mut sum2, z);
3072            }
3073          }
3074        }
3075        if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
3076          for &x in x {
3077            if computes_forward_sums && x >= pos_pair2.0
3078              || (!computes_forward_sums && x <= pos_pair2.0)
3079            {
3080              continue;
3081            }
3082            if let Some(&y) = sum_mat.get(&(x, pos_pair2.1)) {
3083              let long_x = x.to_usize().unwrap();
3084              let z = if computes_forward_sums {
3085                range_insert_scores.insert_scores[long_x + 1][long_pos_pair2.0]
3086              } else {
3087                range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
3088              } + alignfold_scores.match2insert_score;
3089              let z = y + alignfold_scores.match2insert_score + z;
3090              logsumexp(&mut sum2, z);
3091            }
3092          }
3093        }
3094        if sum2 > NEG_INFINITY {
3095          sum_mat2.insert(pos_pair2, sum2);
3096        }
3097        let term = sum2 + loopmatch_score;
3098        logsumexp(&mut sum, term);
3099        if sum > NEG_INFINITY {
3100          sum_mat.insert(pos_pair, sum);
3101        }
3102      }
3103    }
3104  }
3105  (sum_mat, sum_mat2)
3106}
3107
3108pub fn get_alignfold_probs<T>(inputs: InputsAlignfoldProbsGetter<T>) -> AlignfoldProbMats<T>
3109where
3110  T: HashIndex,
3111{
3112  let (
3113    seq_pair,
3114    alignfold_scores,
3115    max_basepair_span_pair,
3116    match_probs,
3117    alignfold_sums,
3118    produces_struct_profs,
3119    global_sum,
3120    trains_alignfold_scores,
3121    alignfold_counts_expected,
3122    pos_quads_hashed_lens,
3123    fold_scores_pair,
3124    produces_match_probs,
3125    forward_pos_pairs,
3126    backward_pos_pairs,
3127    range_insert_scores,
3128    matchable_poss,
3129    matchable_poss2,
3130  ) = inputs;
3131  let seq_len_pair = (
3132    T::from_usize(seq_pair.0.len()).unwrap(),
3133    T::from_usize(seq_pair.1.len()).unwrap(),
3134  );
3135  let mut alignfold_outside_sums = SumMat4d::<T>::default();
3136  let mut alignfold_probs = AlignfoldProbMats::<T>::new(&(
3137    seq_len_pair.0.to_usize().unwrap(),
3138    seq_len_pair.1.to_usize().unwrap(),
3139  ));
3140  let leftmost_pos_pair = (T::zero(), T::zero());
3141  let rightmost_pos_pair = (seq_len_pair.0 - T::one(), seq_len_pair.1 - T::one());
3142  let mut prob_coeffs_multibranch = SumMat4d::<T>::default();
3143  let mut prob_coeffs_multibranch2 = prob_coeffs_multibranch.clone();
3144  for substr_len in range_inclusive(
3145    T::from_usize(if trains_alignfold_scores {
3146      2
3147    } else {
3148      MIN_SPAN_HAIRPIN_CLOSE
3149    })
3150    .unwrap(),
3151    max_basepair_span_pair.0,
3152  )
3153  .rev()
3154  {
3155    for substr_len2 in range_inclusive(
3156      T::from_usize(if trains_alignfold_scores {
3157        2
3158      } else {
3159        MIN_SPAN_HAIRPIN_CLOSE
3160      })
3161      .unwrap(),
3162      max_basepair_span_pair.1,
3163    )
3164    .rev()
3165    {
3166      if let Some(pos_pairs) = pos_quads_hashed_lens.get(&(substr_len, substr_len2)) {
3167        for &(i, k) in pos_pairs {
3168          let (j, l) = (i + substr_len - T::one(), k + substr_len2 - T::one());
3169          let pos_quad = (i, j, k, l);
3170          if let Some(&sum_close) = alignfold_sums.sums_close.get(&pos_quad) {
3171            let (long_i, long_j, long_k, long_l) = (
3172              i.to_usize().unwrap(),
3173              j.to_usize().unwrap(),
3174              k.to_usize().unwrap(),
3175              l.to_usize().unwrap(),
3176            );
3177            let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
3178            let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
3179            let mismatch_pair = (seq_pair.0[long_i - 1], seq_pair.0[long_j + 1]);
3180            let mismatch_pair2 = (seq_pair.1[long_k - 1], seq_pair.1[long_l + 1]);
3181            let prob_coeff = sum_close - global_sum;
3182            let mut sum = NEG_INFINITY;
3183            let mut forward_term = sum;
3184            let mut forward_term_match = sum;
3185            let mut backward_term = sum;
3186            let pos_pair2 = (i - T::one(), k - T::one());
3187            if let Some(&x) = alignfold_sums.forward_sums_external2.get(&pos_pair2) {
3188              logsumexp(&mut forward_term, x);
3189            }
3190            if trains_alignfold_scores {
3191              if let Some(&x) = alignfold_sums.forward_sums_external.get(&pos_pair2) {
3192                let begins_sum = pos_pair2 == leftmost_pos_pair;
3193                let x = x
3194                  + if begins_sum {
3195                    alignfold_scores.init_match_score
3196                  } else {
3197                    alignfold_scores.match2match_score
3198                  };
3199                logsumexp(&mut forward_term_match, x);
3200              }
3201            }
3202            let pos_pair2 = (j + T::one(), l + T::one());
3203            if let Some(&x) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
3204              logsumexp(&mut backward_term, x);
3205            }
3206            let coeff = alignfold_sums.sums_accessible_external[&pos_quad] - sum_close;
3207            if trains_alignfold_scores {
3208              let begins_sum = (i - T::one(), k - T::one()) == leftmost_pos_pair;
3209              let x = prob_coeff + coeff + forward_term_match + backward_term;
3210              if begins_sum {
3211                logsumexp(&mut alignfold_counts_expected.init_match_score, x);
3212              } else {
3213                logsumexp(&mut alignfold_counts_expected.match2match_score, x);
3214              }
3215            }
3216            let sum_external = forward_term + backward_term;
3217            if sum_external > NEG_INFINITY {
3218              sum = coeff + sum_external;
3219              let x = prob_coeff + sum;
3220              if trains_alignfold_scores {
3221                // Count external loop accessible basepairings.
3222                logsumexp(
3223                  &mut alignfold_counts_expected.external_score_basepair,
3224                  (2. as Prob).ln() + x,
3225                );
3226                // Count helix ends.
3227                logsumexp(
3228                  &mut alignfold_counts_expected.helix_close_scores[basepair.1][basepair.0],
3229                  x,
3230                );
3231                logsumexp(
3232                  &mut alignfold_counts_expected.helix_close_scores[basepair2.1][basepair2.0],
3233                  x,
3234                );
3235                // Count external loop terminal mismatches.
3236                if j < seq_len_pair.0 - T::from_usize(2).unwrap() {
3237                  logsumexp(
3238                    &mut alignfold_counts_expected.dangling_scores_left[basepair.1][basepair.0]
3239                      [mismatch_pair.1],
3240                    x,
3241                  );
3242                }
3243                if i > T::one() {
3244                  logsumexp(
3245                    &mut alignfold_counts_expected.dangling_scores_right[basepair.1][basepair.0]
3246                      [mismatch_pair.0],
3247                    x,
3248                  );
3249                }
3250                if l < seq_len_pair.1 - T::from_usize(2).unwrap() {
3251                  logsumexp(
3252                    &mut alignfold_counts_expected.dangling_scores_left[basepair2.1][basepair2.0]
3253                      [mismatch_pair2.1],
3254                    x,
3255                  );
3256                }
3257                if k > T::one() {
3258                  logsumexp(
3259                    &mut alignfold_counts_expected.dangling_scores_right[basepair2.1][basepair2.0]
3260                      [mismatch_pair2.0],
3261                    x,
3262                  );
3263                }
3264              }
3265            }
3266            for substr_len3 in range_inclusive(
3267              substr_len + T::from_usize(2).unwrap(),
3268              (substr_len + T::from_usize(MAX_LOOP_LEN + 2).unwrap()).min(max_basepair_span_pair.0),
3269            ) {
3270              for substr_len4 in range_inclusive(
3271                substr_len2 + T::from_usize(2).unwrap(),
3272                (substr_len2 + T::from_usize(MAX_LOOP_LEN + 2).unwrap())
3273                  .min(max_basepair_span_pair.1),
3274              ) {
3275                if let Some(pos_pairs2) = pos_quads_hashed_lens.get(&(substr_len3, substr_len4)) {
3276                  for &(m, o) in pos_pairs2 {
3277                    let (n, p) = (m + substr_len3 - T::one(), o + substr_len4 - T::one());
3278                    if !(m < i && j < n && o < k && l < p) {
3279                      continue;
3280                    }
3281                    let (long_m, long_n, long_o, long_p) = (
3282                      m.to_usize().unwrap(),
3283                      n.to_usize().unwrap(),
3284                      o.to_usize().unwrap(),
3285                      p.to_usize().unwrap(),
3286                    );
3287                    let loop_len_pair = (long_i - long_m - 1, long_n - long_j - 1);
3288                    let loop_len_pair2 = (long_k - long_o - 1, long_p - long_l - 1);
3289                    if loop_len_pair.0 + loop_len_pair.1 > MAX_LOOP_LEN {
3290                      continue;
3291                    }
3292                    if loop_len_pair2.0 + loop_len_pair2.1 > MAX_LOOP_LEN {
3293                      continue;
3294                    }
3295                    let basepair3 = (seq_pair.0[long_m], seq_pair.0[long_n]);
3296                    let basepair4 = (seq_pair.1[long_o], seq_pair.1[long_p]);
3297                    let found_stack = loop_len_pair.0 == 0 && loop_len_pair.1 == 0;
3298                    let found_bulge =
3299                      (loop_len_pair.0 == 0 || loop_len_pair.1 == 0) && !found_stack;
3300                    let mismatch_pair3 = (seq_pair.0[long_m + 1], seq_pair.0[long_n - 1]);
3301                    let pos_quad2 = (m, n, o, p);
3302                    if let Some(&outside_sum) = alignfold_outside_sums.get(&pos_quad2) {
3303                      let found_stack2 = loop_len_pair2.0 == 0 && loop_len_pair2.1 == 0;
3304                      let found_bulge2 =
3305                        (loop_len_pair2.0 == 0 || loop_len_pair2.1 == 0) && !found_stack2;
3306                      let mismatch_pair4 = (seq_pair.1[long_o + 1], seq_pair.1[long_p - 1]);
3307                      let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(m, o)];
3308                      let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(m, o)];
3309                      let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(n, p)];
3310                      let mut forward_term = NEG_INFINITY;
3311                      let mut forward_term_match = forward_term;
3312                      let mut backward_term = forward_term;
3313                      let pos_pair2 = (i - T::one(), k - T::one());
3314                      if let Some(x) = forward_sums2.get(&pos_pair2) {
3315                        logsumexp(&mut forward_term, x.sum_seqalign);
3316                      }
3317                      if trains_alignfold_scores {
3318                        if let Some(x) = forward_sums.get(&pos_pair2) {
3319                          let term = x.sum_seqalign + alignfold_scores.match2match_score;
3320                          logsumexp(&mut forward_term_match, term);
3321                        }
3322                      }
3323                      let pos_pair2 = (j + T::one(), l + T::one());
3324                      if let Some(x) = backward_sums.get(&pos_pair2) {
3325                        logsumexp(&mut backward_term, x.sum_seqalign);
3326                      }
3327                      let pairmatch_score = alignfold_scores.match_scores[basepair3.0][basepair4.0]
3328                        + alignfold_scores.match_scores[basepair3.1][basepair4.1];
3329                      let twoloop_score = fold_scores_pair.0.twoloop_scores[&(m, n, i, j)];
3330                      let twoloop_score2 = fold_scores_pair.1.twoloop_scores[&(o, p, k, l)];
3331                      let coeff = pairmatch_score + twoloop_score + twoloop_score2 + outside_sum;
3332                      if trains_alignfold_scores {
3333                        let x = prob_coeff + coeff + forward_term_match + backward_term;
3334                        logsumexp(&mut alignfold_counts_expected.match2match_score, x);
3335                      }
3336                      let sum_2loop = forward_term + backward_term;
3337                      if sum_2loop > NEG_INFINITY {
3338                        let sum_2loop = coeff + sum_2loop;
3339                        logsumexp(&mut sum, sum_2loop);
3340                        let pairmatch_prob_2loop = prob_coeff + sum_2loop;
3341                        if produces_struct_profs {
3342                          let loop_len_pair = (long_i - long_m - 1, long_n - long_j - 1);
3343                          let found_bulge = (loop_len_pair.0 == 0) ^ (loop_len_pair.1 == 0);
3344                          let found_interior = loop_len_pair.0 > 0 && loop_len_pair.1 > 0;
3345                          for q in long_m + 1..long_i {
3346                            if found_bulge {
3347                              logsumexp(
3348                                &mut alignfold_probs.context_profs_pair.0[(q, CONTEXT_INDEX_BULGE)],
3349                                pairmatch_prob_2loop,
3350                              );
3351                            } else if found_interior {
3352                              logsumexp(
3353                                &mut alignfold_probs.context_profs_pair.0
3354                                  [(q, CONTEXT_INDEX_INTERIOR)],
3355                                pairmatch_prob_2loop,
3356                              );
3357                            }
3358                          }
3359                          for q in long_j + 1..long_n {
3360                            if found_bulge {
3361                              logsumexp(
3362                                &mut alignfold_probs.context_profs_pair.0[(q, CONTEXT_INDEX_BULGE)],
3363                                pairmatch_prob_2loop,
3364                              );
3365                            } else if found_interior {
3366                              logsumexp(
3367                                &mut alignfold_probs.context_profs_pair.0
3368                                  [(q, CONTEXT_INDEX_INTERIOR)],
3369                                pairmatch_prob_2loop,
3370                              );
3371                            }
3372                          }
3373                          let loop_len_pair = (long_k - long_o - 1, long_p - long_l - 1);
3374                          let found_bulge = (loop_len_pair.0 == 0) ^ (loop_len_pair.1 == 0);
3375                          let found_interior = loop_len_pair.0 > 0 && loop_len_pair.1 > 0;
3376                          for q in long_o + 1..long_k {
3377                            if found_bulge {
3378                              logsumexp(
3379                                &mut alignfold_probs.context_profs_pair.1[(q, CONTEXT_INDEX_BULGE)],
3380                                pairmatch_prob_2loop,
3381                              );
3382                            } else if found_interior {
3383                              logsumexp(
3384                                &mut alignfold_probs.context_profs_pair.1
3385                                  [(q, CONTEXT_INDEX_INTERIOR)],
3386                                pairmatch_prob_2loop,
3387                              );
3388                            }
3389                          }
3390                          for q in long_l + 1..long_p {
3391                            if found_bulge {
3392                              logsumexp(
3393                                &mut alignfold_probs.context_profs_pair.1[(q, CONTEXT_INDEX_BULGE)],
3394                                pairmatch_prob_2loop,
3395                              );
3396                            } else if found_interior {
3397                              logsumexp(
3398                                &mut alignfold_probs.context_profs_pair.1
3399                                  [(q, CONTEXT_INDEX_INTERIOR)],
3400                                pairmatch_prob_2loop,
3401                              );
3402                            }
3403                          }
3404                        }
3405                        if trains_alignfold_scores {
3406                          if found_stack {
3407                            // Count a stack.
3408                            let dict_min_stack = get_dict_min_stack(&basepair3, &basepair);
3409                            logsumexp(
3410                              &mut alignfold_counts_expected.stack_scores[dict_min_stack.0 .0]
3411                                [dict_min_stack.0 .1][dict_min_stack.1 .0][dict_min_stack.1 .1],
3412                              pairmatch_prob_2loop,
3413                            );
3414                          } else {
3415                            if found_bulge {
3416                              // Count a bulge loop length.
3417                              let bulge_len = if loop_len_pair.0 == 0 {
3418                                loop_len_pair.1
3419                              } else {
3420                                loop_len_pair.0
3421                              };
3422                              logsumexp(
3423                                &mut alignfold_counts_expected.bulge_scores_len[bulge_len - 1],
3424                                pairmatch_prob_2loop,
3425                              );
3426                              // Count a 0x1 bulge loop.
3427                              if bulge_len == 1 {
3428                                let mismatch = if loop_len_pair.0 == 0 {
3429                                  mismatch_pair3.1
3430                                } else {
3431                                  mismatch_pair3.0
3432                                };
3433                                logsumexp(
3434                                  &mut alignfold_counts_expected.bulge_scores_0x1[mismatch],
3435                                  pairmatch_prob_2loop,
3436                                );
3437                              }
3438                            } else {
3439                              // Count an interior loop length.
3440                              logsumexp(
3441                                &mut alignfold_counts_expected.interior_scores_len
3442                                  [loop_len_pair.0 + loop_len_pair.1 - 2],
3443                                pairmatch_prob_2loop,
3444                              );
3445                              let diff = get_diff(loop_len_pair.0, loop_len_pair.1);
3446                              if diff == 0 {
3447                                logsumexp(
3448                                  &mut alignfold_counts_expected.interior_scores_symmetric
3449                                    [loop_len_pair.0 - 1],
3450                                  pairmatch_prob_2loop,
3451                                );
3452                              } else {
3453                                logsumexp(
3454                                  &mut alignfold_counts_expected.interior_scores_asymmetric
3455                                    [diff - 1],
3456                                  pairmatch_prob_2loop,
3457                                );
3458                              }
3459                              // Count a 1x1 interior loop.
3460                              if loop_len_pair.0 == 1 && loop_len_pair.1 == 1 {
3461                                let dict_min_mismatch_pair3 = get_dict_min_pair(&mismatch_pair3);
3462                                logsumexp(
3463                                  &mut alignfold_counts_expected.interior_scores_1x1
3464                                    [dict_min_mismatch_pair3.0][dict_min_mismatch_pair3.1],
3465                                  pairmatch_prob_2loop,
3466                                );
3467                              }
3468                              // Count an explicit interior loop length pair.
3469                              if loop_len_pair.0 <= MAX_INTERIOR_EXPLICIT
3470                                && loop_len_pair.1 <= MAX_INTERIOR_EXPLICIT
3471                              {
3472                                let dict_min_len_pair = get_dict_min_pair(&loop_len_pair);
3473                                logsumexp(
3474                                  &mut alignfold_counts_expected.interior_scores_explicit
3475                                    [dict_min_len_pair.0 - 1][dict_min_len_pair.1 - 1],
3476                                  pairmatch_prob_2loop,
3477                                );
3478                              }
3479                            }
3480                            // Count helix ends.
3481                            logsumexp(
3482                              &mut alignfold_counts_expected.helix_close_scores[basepair.1]
3483                                [basepair.0],
3484                              pairmatch_prob_2loop,
3485                            );
3486                            logsumexp(
3487                              &mut alignfold_counts_expected.helix_close_scores[basepair3.0]
3488                                [basepair3.1],
3489                              pairmatch_prob_2loop,
3490                            );
3491                            // Count 2-loop terminal mismatches.
3492                            logsumexp(
3493                              &mut alignfold_counts_expected.terminal_mismatch_scores[basepair3.0]
3494                                [basepair3.1][mismatch_pair3.0][mismatch_pair3.1],
3495                              pairmatch_prob_2loop,
3496                            );
3497                            logsumexp(
3498                              &mut alignfold_counts_expected.terminal_mismatch_scores[basepair.1]
3499                                [basepair.0][mismatch_pair.1][mismatch_pair.0],
3500                              pairmatch_prob_2loop,
3501                            );
3502                          }
3503                          if found_stack2 {
3504                            // Count a stack.
3505                            let dict_min_stack2 = get_dict_min_stack(&basepair4, &basepair2);
3506                            logsumexp(
3507                              &mut alignfold_counts_expected.stack_scores[dict_min_stack2.0 .0]
3508                                [dict_min_stack2.0 .1][dict_min_stack2.1 .0][dict_min_stack2.1 .1],
3509                              pairmatch_prob_2loop,
3510                            );
3511                          } else {
3512                            if found_bulge2 {
3513                              // Count a bulge loop length.
3514                              let bulge_len2 = if loop_len_pair2.0 == 0 {
3515                                loop_len_pair2.1
3516                              } else {
3517                                loop_len_pair2.0
3518                              };
3519                              logsumexp(
3520                                &mut alignfold_counts_expected.bulge_scores_len[bulge_len2 - 1],
3521                                pairmatch_prob_2loop,
3522                              );
3523                              // Count a 0x1 bulge loop.
3524                              if bulge_len2 == 1 {
3525                                let mismatch2 = if loop_len_pair2.0 == 0 {
3526                                  mismatch_pair4.1
3527                                } else {
3528                                  mismatch_pair4.0
3529                                };
3530                                logsumexp(
3531                                  &mut alignfold_counts_expected.bulge_scores_0x1[mismatch2],
3532                                  pairmatch_prob_2loop,
3533                                );
3534                              }
3535                            } else {
3536                              // Count an interior loop length.
3537                              logsumexp(
3538                                &mut alignfold_counts_expected.interior_scores_len
3539                                  [loop_len_pair2.0 + loop_len_pair2.1 - 2],
3540                                pairmatch_prob_2loop,
3541                              );
3542                              let diff2 = get_diff(loop_len_pair2.0, loop_len_pair2.1);
3543                              if diff2 == 0 {
3544                                logsumexp(
3545                                  &mut alignfold_counts_expected.interior_scores_symmetric
3546                                    [loop_len_pair2.0 - 1],
3547                                  pairmatch_prob_2loop,
3548                                );
3549                              } else {
3550                                logsumexp(
3551                                  &mut alignfold_counts_expected.interior_scores_asymmetric
3552                                    [diff2 - 1],
3553                                  pairmatch_prob_2loop,
3554                                );
3555                              }
3556                              // Count a 1x1 interior loop.
3557                              if loop_len_pair2.0 == 1 && loop_len_pair2.1 == 1 {
3558                                let dict_min_mismatch_pair4 = get_dict_min_pair(&mismatch_pair4);
3559                                logsumexp(
3560                                  &mut alignfold_counts_expected.interior_scores_1x1
3561                                    [dict_min_mismatch_pair4.0][dict_min_mismatch_pair4.1],
3562                                  pairmatch_prob_2loop,
3563                                );
3564                              }
3565                              // Count an explicit interior loop length pair.
3566                              if loop_len_pair2.0 <= MAX_INTERIOR_EXPLICIT
3567                                && loop_len_pair2.1 <= MAX_INTERIOR_EXPLICIT
3568                              {
3569                                let dict_min_len_pair2 = get_dict_min_pair(&loop_len_pair2);
3570                                logsumexp(
3571                                  &mut alignfold_counts_expected.interior_scores_explicit
3572                                    [dict_min_len_pair2.0 - 1][dict_min_len_pair2.1 - 1],
3573                                  pairmatch_prob_2loop,
3574                                );
3575                              }
3576                            }
3577                            // Count helix ends.
3578                            logsumexp(
3579                              &mut alignfold_counts_expected.helix_close_scores[basepair2.1]
3580                                [basepair2.0],
3581                              pairmatch_prob_2loop,
3582                            );
3583                            logsumexp(
3584                              &mut alignfold_counts_expected.helix_close_scores[basepair4.0]
3585                                [basepair4.1],
3586                              pairmatch_prob_2loop,
3587                            );
3588                            // Count 2-loop terminal mismatches.
3589                            logsumexp(
3590                              &mut alignfold_counts_expected.terminal_mismatch_scores[basepair4.0]
3591                                [basepair4.1][mismatch_pair4.0][mismatch_pair4.1],
3592                              pairmatch_prob_2loop,
3593                            );
3594                            logsumexp(
3595                              &mut alignfold_counts_expected.terminal_mismatch_scores[basepair2.1]
3596                                [basepair2.0][mismatch_pair2.1][mismatch_pair2.0],
3597                              pairmatch_prob_2loop,
3598                            );
3599                          }
3600                        }
3601                      }
3602                    }
3603                  }
3604                }
3605              }
3606            }
3607            let sum_ratio = alignfold_sums.sums_accessible_multibranch[&pos_quad] - sum_close;
3608            for (pos_pair, forward_sums) in &alignfold_sums.forward_sums_hashed_poss {
3609              let &(u, v) = pos_pair;
3610              if !(u < i && v < k) {
3611                continue;
3612              }
3613              let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[pos_pair];
3614              let pos_quad2 = (u, j, v, l);
3615              let mut forward_term = NEG_INFINITY;
3616              let mut forward_term_match = forward_term;
3617              let mut forward_term2 = forward_term;
3618              let mut forward_term_match2 = forward_term;
3619              let pos_pair2 = (i - T::one(), k - T::one());
3620              if let Some(x) = forward_sums2.get(&pos_pair2) {
3621                logsumexp(&mut forward_term, x.sum_1ormore_pairmatches);
3622                logsumexp(&mut forward_term2, x.sum_seqalign_multibranch);
3623              }
3624              if trains_alignfold_scores {
3625                if let Some(x) = forward_sums.get(&pos_pair2) {
3626                  let y = x.sum_1ormore_pairmatches + alignfold_scores.match2match_score;
3627                  logsumexp(&mut forward_term_match, y);
3628                  let y = x.sum_seqalign_multibranch + alignfold_scores.match2match_score;
3629                  logsumexp(&mut forward_term_match2, y);
3630                }
3631              }
3632              let mut sum_multibranch = NEG_INFINITY;
3633              if let Some(x) = prob_coeffs_multibranch.get(&pos_quad2) {
3634                let x = x + sum_ratio;
3635                let y = x + forward_term;
3636                logsumexp(&mut sum_multibranch, y);
3637                if trains_alignfold_scores {
3638                  let y = prob_coeff + x + forward_term_match;
3639                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
3640                }
3641              }
3642              if let Some(x) = prob_coeffs_multibranch2.get(&pos_quad2) {
3643                let x = x + sum_ratio;
3644                let y = x + forward_term2;
3645                logsumexp(&mut sum_multibranch, y);
3646                if trains_alignfold_scores {
3647                  let y = prob_coeff + x + forward_term_match2;
3648                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
3649                }
3650              }
3651              if sum_multibranch > NEG_INFINITY {
3652                logsumexp(&mut sum, sum_multibranch);
3653                let pairmatch_prob_multibranch = prob_coeff + sum_multibranch;
3654                if trains_alignfold_scores {
3655                  // Count multi-loop terminal mismatches.
3656                  logsumexp(
3657                    &mut alignfold_counts_expected.dangling_scores_left[basepair.1][basepair.0]
3658                      [mismatch_pair.1],
3659                    pairmatch_prob_multibranch,
3660                  );
3661                  logsumexp(
3662                    &mut alignfold_counts_expected.dangling_scores_right[basepair.1][basepair.0]
3663                      [mismatch_pair.0],
3664                    pairmatch_prob_multibranch,
3665                  );
3666                  logsumexp(
3667                    &mut alignfold_counts_expected.dangling_scores_left[basepair2.1][basepair2.0]
3668                      [mismatch_pair2.1],
3669                    pairmatch_prob_multibranch,
3670                  );
3671                  logsumexp(
3672                    &mut alignfold_counts_expected.dangling_scores_right[basepair2.1][basepair2.0]
3673                      [mismatch_pair2.0],
3674                    pairmatch_prob_multibranch,
3675                  );
3676                  // Count helix ends.
3677                  logsumexp(
3678                    &mut alignfold_counts_expected.helix_close_scores[basepair.1][basepair.0],
3679                    pairmatch_prob_multibranch,
3680                  );
3681                  logsumexp(
3682                    &mut alignfold_counts_expected.helix_close_scores[basepair2.1][basepair2.0],
3683                    pairmatch_prob_multibranch,
3684                  );
3685                  // Count multi-loop closings.
3686                  logsumexp(
3687                    &mut alignfold_counts_expected.multibranch_score_base,
3688                    (2. as Prob).ln() + pairmatch_prob_multibranch,
3689                  );
3690                  // Count multi-loop closing basepairings.
3691                  logsumexp(
3692                    &mut alignfold_counts_expected.multibranch_score_basepair,
3693                    (2. as Prob).ln() + pairmatch_prob_multibranch,
3694                  );
3695                  // Count multi-loop accessible basepairings.
3696                  logsumexp(
3697                    &mut alignfold_counts_expected.multibranch_score_basepair,
3698                    (2. as Prob).ln() + pairmatch_prob_multibranch,
3699                  );
3700                }
3701              }
3702            }
3703            if sum > NEG_INFINITY {
3704              alignfold_outside_sums.insert(pos_quad, sum);
3705              let pairmatch_prob = prob_coeff + sum;
3706              if produces_match_probs {
3707                alignfold_probs
3708                  .pairmatch_probs
3709                  .insert(pos_quad, pairmatch_prob);
3710                match alignfold_probs.match_probs.get_mut(&(i, k)) {
3711                  Some(x) => {
3712                    logsumexp(x, pairmatch_prob);
3713                  }
3714                  None => {
3715                    alignfold_probs.match_probs.insert((i, k), pairmatch_prob);
3716                  }
3717                }
3718                match alignfold_probs.match_probs.get_mut(&(j, l)) {
3719                  Some(x) => {
3720                    logsumexp(x, pairmatch_prob);
3721                  }
3722                  None => {
3723                    alignfold_probs.match_probs.insert((j, l), pairmatch_prob);
3724                  }
3725                }
3726              }
3727              if trains_alignfold_scores {
3728                // Count base pairs.
3729                let dict_min_basepair = get_dict_min_pair(&basepair);
3730                let dict_min_basepair2 = get_dict_min_pair(&basepair2);
3731                logsumexp(
3732                  &mut alignfold_counts_expected.basepair_scores[dict_min_basepair.0]
3733                    [dict_min_basepair.1],
3734                  pairmatch_prob,
3735                );
3736                logsumexp(
3737                  &mut alignfold_counts_expected.basepair_scores[dict_min_basepair2.0]
3738                    [dict_min_basepair2.1],
3739                  pairmatch_prob,
3740                );
3741                // Count alignments.
3742                let dict_min_match = get_dict_min_pair(&(basepair.0, basepair2.0));
3743                logsumexp(
3744                  &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
3745                  pairmatch_prob,
3746                );
3747                let dict_min_match = get_dict_min_pair(&(basepair.1, basepair2.1));
3748                logsumexp(
3749                  &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
3750                  pairmatch_prob,
3751                );
3752              }
3753              match alignfold_probs.basepair_probs_pair.0.get_mut(&(i, j)) {
3754                Some(x) => {
3755                  logsumexp(x, pairmatch_prob);
3756                }
3757                None => {
3758                  alignfold_probs
3759                    .basepair_probs_pair
3760                    .0
3761                    .insert((i, j), pairmatch_prob);
3762                }
3763              }
3764              match alignfold_probs.basepair_probs_pair.1.get_mut(&(k, l)) {
3765                Some(x) => {
3766                  logsumexp(x, pairmatch_prob);
3767                }
3768                None => {
3769                  alignfold_probs
3770                    .basepair_probs_pair
3771                    .1
3772                    .insert((k, l), pairmatch_prob);
3773                }
3774              }
3775              if produces_struct_profs {
3776                logsumexp(
3777                  &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_BASEPAIR)],
3778                  pairmatch_prob,
3779                );
3780                logsumexp(
3781                  &mut alignfold_probs.context_profs_pair.0[(long_j, CONTEXT_INDEX_BASEPAIR)],
3782                  pairmatch_prob,
3783                );
3784                logsumexp(
3785                  &mut alignfold_probs.context_profs_pair.1[(long_k, CONTEXT_INDEX_BASEPAIR)],
3786                  pairmatch_prob,
3787                );
3788                logsumexp(
3789                  &mut alignfold_probs.context_profs_pair.1[(long_l, CONTEXT_INDEX_BASEPAIR)],
3790                  pairmatch_prob,
3791                );
3792              }
3793              let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
3794                + alignfold_scores.match_scores[basepair.1][basepair2.1];
3795              let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
3796              let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
3797              if trains_alignfold_scores {
3798                let mismatch_pair = (seq_pair.0[long_i + 1], seq_pair.0[long_j - 1]);
3799                let mismatch_pair2 = (seq_pair.1[long_k + 1], seq_pair.1[long_l - 1]);
3800                let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(i, k)];
3801                let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
3802                if substr_len.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
3803                  && substr_len2.to_usize().unwrap() - 2 <= MAX_LOOP_LEN
3804                {
3805                  let mut sum_seqalign = NEG_INFINITY;
3806                  let mut sum_seqalign_match = sum_seqalign;
3807                  let pos_pair2 = (j - T::one(), l - T::one());
3808                  if let Some(x) = forward_sums2.get(&pos_pair2) {
3809                    logsumexp(&mut sum_seqalign, x.sum_seqalign);
3810                  }
3811                  if let Some(x) = forward_sums.get(&pos_pair2) {
3812                    let x = x.sum_seqalign + alignfold_scores.match2match_score;
3813                    logsumexp(&mut sum_seqalign_match, x);
3814                  }
3815                  let hairpin_score = fold_scores_pair.0.hairpin_scores[&(i, j)];
3816                  let hairpin_score2 = fold_scores_pair.1.hairpin_scores[&(k, l)];
3817                  let prob = sum - global_sum
3818                    + sum_seqalign_match
3819                    + hairpin_score
3820                    + hairpin_score2
3821                    + pairmatch_score;
3822                  logsumexp(&mut alignfold_counts_expected.match2match_score, prob);
3823                  let pairmatch_prob_hairpin = sum - global_sum
3824                    + sum_seqalign
3825                    + hairpin_score
3826                    + hairpin_score2
3827                    + pairmatch_score;
3828                  logsumexp(
3829                    &mut alignfold_counts_expected.hairpin_scores_len[long_j - long_i - 1],
3830                    pairmatch_prob_hairpin,
3831                  );
3832                  logsumexp(
3833                    &mut alignfold_counts_expected.hairpin_scores_len[long_l - long_k - 1],
3834                    pairmatch_prob_hairpin,
3835                  );
3836                  logsumexp(
3837                    &mut alignfold_counts_expected.terminal_mismatch_scores[basepair.0][basepair.1]
3838                      [mismatch_pair.0][mismatch_pair.1],
3839                    pairmatch_prob_hairpin,
3840                  );
3841                  logsumexp(
3842                    &mut alignfold_counts_expected.terminal_mismatch_scores[basepair2.0]
3843                      [basepair2.1][mismatch_pair2.0][mismatch_pair2.1],
3844                    pairmatch_prob_hairpin,
3845                  );
3846                  // Count helix ends.
3847                  logsumexp(
3848                    &mut alignfold_counts_expected.helix_close_scores[basepair.0][basepair.1],
3849                    pairmatch_prob_hairpin,
3850                  );
3851                  logsumexp(
3852                    &mut alignfold_counts_expected.helix_close_scores[basepair2.0][basepair2.1],
3853                    pairmatch_prob_hairpin,
3854                  );
3855                }
3856                let mut sum_multibranch = NEG_INFINITY;
3857                let mut sum_multibranch_match = sum_multibranch;
3858                if let Some(x) = forward_sums2.get(&pos_pair2) {
3859                  logsumexp(&mut sum_multibranch, x.sum_multibranch);
3860                }
3861                if let Some(x) = forward_sums.get(&pos_pair2) {
3862                  let x = x.sum_multibranch + alignfold_scores.match2match_score;
3863                  logsumexp(&mut sum_multibranch_match, x);
3864                }
3865                let prob = sum - global_sum
3866                  + sum_multibranch_match
3867                  + multibranch_close_score
3868                  + multibranch_close_score2
3869                  + pairmatch_score;
3870                logsumexp(&mut alignfold_counts_expected.match2match_score, prob);
3871                let pairmatch_prob_multibranch = sum - global_sum
3872                  + sum_multibranch
3873                  + multibranch_close_score
3874                  + multibranch_close_score2
3875                  + pairmatch_score;
3876                // Count multi-loop terminal mismatches.
3877                logsumexp(
3878                  &mut alignfold_counts_expected.dangling_scores_left[basepair.0][basepair.1]
3879                    [mismatch_pair.0],
3880                  pairmatch_prob_multibranch,
3881                );
3882                logsumexp(
3883                  &mut alignfold_counts_expected.dangling_scores_right[basepair.0][basepair.1]
3884                    [mismatch_pair.1],
3885                  pairmatch_prob_multibranch,
3886                );
3887                logsumexp(
3888                  &mut alignfold_counts_expected.dangling_scores_left[basepair2.0][basepair2.1]
3889                    [mismatch_pair2.0],
3890                  pairmatch_prob_multibranch,
3891                );
3892                logsumexp(
3893                  &mut alignfold_counts_expected.dangling_scores_right[basepair2.0][basepair2.1]
3894                    [mismatch_pair2.1],
3895                  pairmatch_prob_multibranch,
3896                );
3897                // Count helix ends.
3898                logsumexp(
3899                  &mut alignfold_counts_expected.helix_close_scores[basepair.0][basepair.1],
3900                  pairmatch_prob_multibranch,
3901                );
3902                logsumexp(
3903                  &mut alignfold_counts_expected.helix_close_scores[basepair2.0][basepair2.1],
3904                  pairmatch_prob_multibranch,
3905                );
3906              }
3907              let coeff =
3908                sum + pairmatch_score + multibranch_close_score + multibranch_close_score2;
3909              let backward_sums = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
3910              for pos_pair in match_probs.keys() {
3911                let &(u, v) = pos_pair;
3912                if !(i < u && u < j && k < v && v < l) {
3913                  continue;
3914                }
3915                let mut backward_term = NEG_INFINITY;
3916                let mut backward_term2 = backward_term;
3917                let pos_pair2 = (u + T::one(), v + T::one());
3918                if let Some(x) = backward_sums.get(&pos_pair2) {
3919                  logsumexp(&mut backward_term, x.sum_0ormore_pairmatches);
3920                  logsumexp(&mut backward_term2, x.sum_1ormore_pairmatches);
3921                }
3922                let pos_quad2 = (i, u, k, v);
3923                let x = coeff + backward_term;
3924                match prob_coeffs_multibranch.get_mut(&pos_quad2) {
3925                  Some(y) => {
3926                    logsumexp(y, x);
3927                  }
3928                  None => {
3929                    prob_coeffs_multibranch.insert(pos_quad2, x);
3930                  }
3931                }
3932                let x = coeff + backward_term2;
3933                match prob_coeffs_multibranch2.get_mut(&pos_quad2) {
3934                  Some(y) => {
3935                    logsumexp(y, x);
3936                  }
3937                  None => {
3938                    prob_coeffs_multibranch2.insert(pos_quad2, x);
3939                  }
3940                }
3941              }
3942            }
3943          }
3944        }
3945      }
3946    }
3947  }
3948  for x in alignfold_probs.basepair_probs_pair.0.values_mut() {
3949    *x = expf(*x);
3950  }
3951  for x in alignfold_probs.basepair_probs_pair.1.values_mut() {
3952    *x = expf(*x);
3953  }
3954  let needs_twoloop_sums = produces_match_probs || trains_alignfold_scores;
3955  let needs_indel_info = produces_struct_profs || trains_alignfold_scores;
3956  if produces_struct_profs || needs_twoloop_sums {
3957    let mut unpair_probs_range_external =
3958      (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3959    let mut unpair_probs_range_hairpin =
3960      (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3961    let mut unpair_probs_range = (SparseProbMat::<T>::default(), SparseProbMat::<T>::default());
3962    for u in range(T::zero(), seq_len_pair.0 - T::one()) {
3963      let long_u = u.to_usize().unwrap();
3964      let base = seq_pair.0[long_u];
3965      for v in range(T::zero(), seq_len_pair.1 - T::one()) {
3966        let pos_pair = (u, v);
3967        let long_v = v.to_usize().unwrap();
3968        let base2 = seq_pair.1[long_v];
3969        let pos_pair2 = (u + T::one(), v + T::one());
3970        let long_pos_pair2 = (
3971          pos_pair2.0.to_usize().unwrap(),
3972          pos_pair2.1.to_usize().unwrap(),
3973        );
3974        let dict_min_match = get_dict_min_pair(&(base, base2));
3975        if match_probs.contains_key(&pos_pair) {
3976          let pos_pair_loopmatch = (u - T::one(), v - T::one());
3977          let mut loopmatch_prob_external = NEG_INFINITY;
3978          let loopmatch_score = alignfold_scores.match_scores[base][base2]
3979            + 2. * alignfold_scores.external_score_unpair;
3980          let mut backward_term = NEG_INFINITY;
3981          if let Some(&x) = alignfold_sums.backward_sums_external2.get(&pos_pair2) {
3982            logsumexp(&mut backward_term, x);
3983          }
3984          if let Some(&x) = alignfold_sums
3985            .forward_sums_external2
3986            .get(&pos_pair_loopmatch)
3987          {
3988            let term = loopmatch_score + x + backward_term - global_sum;
3989            logsumexp(&mut loopmatch_prob_external, term);
3990          }
3991          if produces_struct_profs {
3992            logsumexp(
3993              &mut alignfold_probs.context_profs_pair.0[(long_u, CONTEXT_INDEX_EXTERNAL)],
3994              loopmatch_prob_external,
3995            );
3996            logsumexp(
3997              &mut alignfold_probs.context_profs_pair.1[(long_v, CONTEXT_INDEX_EXTERNAL)],
3998              loopmatch_prob_external,
3999            );
4000          }
4001          if produces_match_probs {
4002            match alignfold_probs.loopmatch_probs.get_mut(&pos_pair) {
4003              Some(x) => {
4004                logsumexp(x, loopmatch_prob_external);
4005              }
4006              None => {
4007                alignfold_probs
4008                  .loopmatch_probs
4009                  .insert(pos_pair, loopmatch_prob_external);
4010              }
4011            }
4012          }
4013          if trains_alignfold_scores {
4014            logsumexp(
4015              &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
4016              loopmatch_prob_external,
4017            );
4018            logsumexp(
4019              &mut alignfold_counts_expected.external_score_unpair,
4020              (2. as Prob).ln() + loopmatch_prob_external,
4021            );
4022            if let Some(&x) = alignfold_sums
4023              .forward_sums_external
4024              .get(&pos_pair_loopmatch)
4025            {
4026              let begins_sum = pos_pair_loopmatch == leftmost_pos_pair;
4027              let x = x
4028                + if begins_sum {
4029                  alignfold_scores.init_match_score
4030                } else {
4031                  alignfold_scores.match2match_score
4032                };
4033              let x = loopmatch_score + x + backward_term - global_sum;
4034              if begins_sum {
4035                logsumexp(&mut alignfold_counts_expected.init_match_score, x);
4036              } else {
4037                logsumexp(&mut alignfold_counts_expected.match2match_score, x);
4038              }
4039            }
4040          }
4041        }
4042        if needs_indel_info {
4043          if let Some(&sum) = alignfold_sums.forward_sums_external.get(&pos_pair) {
4044            let begins_sum = pos_pair == leftmost_pos_pair;
4045            let forward_term = sum
4046              + if begins_sum {
4047                alignfold_scores.init_insert_score
4048              } else {
4049                alignfold_scores.match2insert_score
4050              };
4051            if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4052              for &x in x {
4053                if x <= pos_pair2.1 {
4054                  continue;
4055                }
4056                let pos_pair3 = (pos_pair2.0, x);
4057                if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
4058                  let long_x = x.to_usize().unwrap();
4059                  let ends_sum = pos_pair3 == rightmost_pos_pair;
4060                  let z = range_insert_scores.insert_scores_external2[long_pos_pair2.1][long_x - 1]
4061                    + if ends_sum {
4062                      0.
4063                    } else {
4064                      alignfold_scores.match2insert_score
4065                    };
4066                  let z = forward_term + y + z - global_sum;
4067                  if trains_alignfold_scores {
4068                    if begins_sum {
4069                      logsumexp(&mut alignfold_counts_expected.init_insert_score, z);
4070                    } else {
4071                      logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4072                    }
4073                    if !ends_sum {
4074                      logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4075                    }
4076                    logsumexp(
4077                      &mut alignfold_counts_expected.external_score_unpair,
4078                      ((long_x - long_pos_pair2.1) as Prob).ln() + z,
4079                    );
4080                    if pos_pair2.1 < x - T::one() {
4081                      logsumexp(
4082                        &mut alignfold_counts_expected.insert_extend_score,
4083                        ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + z,
4084                      );
4085                    }
4086                  }
4087                  let pos_pair4 = (pos_pair2.1, x - T::one());
4088                  if produces_struct_profs {
4089                    match unpair_probs_range_external.1.get_mut(&pos_pair4) {
4090                      Some(x) => {
4091                        logsumexp(x, z);
4092                      }
4093                      None => {
4094                        unpair_probs_range_external.1.insert(pos_pair4, z);
4095                      }
4096                    }
4097                  }
4098                  if trains_alignfold_scores {
4099                    match unpair_probs_range.1.get_mut(&pos_pair4) {
4100                      Some(x) => {
4101                        logsumexp(x, z);
4102                      }
4103                      None => {
4104                        unpair_probs_range.1.insert(pos_pair4, z);
4105                      }
4106                    }
4107                  }
4108                }
4109              }
4110            }
4111            if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4112              for &x in x {
4113                if x <= pos_pair2.0 {
4114                  continue;
4115                }
4116                let pos_pair3 = (x, pos_pair2.1);
4117                if let Some(&y) = alignfold_sums.backward_sums_external.get(&pos_pair3) {
4118                  let long_x = x.to_usize().unwrap();
4119                  let ends_sum = pos_pair3 == rightmost_pos_pair;
4120                  let z = range_insert_scores.insert_scores_external[long_pos_pair2.0][long_x - 1]
4121                    + if ends_sum {
4122                      0.
4123                    } else {
4124                      alignfold_scores.match2insert_score
4125                    };
4126                  let z = forward_term + y + z - global_sum;
4127                  if trains_alignfold_scores {
4128                    if begins_sum {
4129                      logsumexp(&mut alignfold_counts_expected.init_insert_score, z);
4130                    } else {
4131                      logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4132                    }
4133                    if !ends_sum {
4134                      logsumexp(&mut alignfold_counts_expected.match2insert_score, z);
4135                    }
4136                    logsumexp(
4137                      &mut alignfold_counts_expected.external_score_unpair,
4138                      ((long_x - long_pos_pair2.0) as Prob).ln() + z,
4139                    );
4140                    if pos_pair2.0 < x - T::one() {
4141                      logsumexp(
4142                        &mut alignfold_counts_expected.insert_extend_score,
4143                        ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + z,
4144                      );
4145                    }
4146                  }
4147                  let pos_pair4 = (pos_pair2.0, x - T::one());
4148                  if produces_struct_profs {
4149                    match unpair_probs_range_external.0.get_mut(&pos_pair4) {
4150                      Some(x) => {
4151                        logsumexp(x, z);
4152                      }
4153                      None => {
4154                        unpair_probs_range_external.0.insert(pos_pair4, z);
4155                      }
4156                    }
4157                  }
4158                  if trains_alignfold_scores {
4159                    match unpair_probs_range.0.get_mut(&pos_pair4) {
4160                      Some(x) => {
4161                        logsumexp(x, z);
4162                      }
4163                      None => {
4164                        unpair_probs_range.0.insert(pos_pair4, z);
4165                      }
4166                    }
4167                  }
4168                }
4169              }
4170            }
4171          }
4172        }
4173      }
4174    }
4175    for (pos_quad, &outside_sum) in &alignfold_outside_sums {
4176      let (i, j, k, l) = *pos_quad;
4177      let (long_i, long_j, long_k, long_l) = (
4178        i.to_usize().unwrap(),
4179        j.to_usize().unwrap(),
4180        k.to_usize().unwrap(),
4181        l.to_usize().unwrap(),
4182      );
4183      let basepair = (seq_pair.0[long_i], seq_pair.0[long_j]);
4184      let basepair2 = (seq_pair.1[long_k], seq_pair.1[long_l]);
4185      let hairpin_score = if j - i - T::one() <= T::from_usize(MAX_LOOP_LEN).unwrap() {
4186        fold_scores_pair.0.hairpin_scores[&(i, j)]
4187      } else {
4188        NEG_INFINITY
4189      };
4190      let hairpin_score2 = if l - k - T::one() <= T::from_usize(MAX_LOOP_LEN).unwrap() {
4191        fold_scores_pair.1.hairpin_scores[&(k, l)]
4192      } else {
4193        NEG_INFINITY
4194      };
4195      let multibranch_close_score = fold_scores_pair.0.multibranch_close_scores[&(i, j)];
4196      let multibranch_close_score2 = fold_scores_pair.1.multibranch_close_scores[&(k, l)];
4197      let pairmatch_score = alignfold_scores.match_scores[basepair.0][basepair2.0]
4198        + alignfold_scores.match_scores[basepair.1][basepair2.1];
4199      let prob_coeff = outside_sum - global_sum + pairmatch_score;
4200      let forward_sums = &alignfold_sums.forward_sums_hashed_poss[&(i, k)];
4201      let forward_sums2 = &alignfold_sums.forward_sums_hashed_poss2[&(i, k)];
4202      let backward_sums = &alignfold_sums.backward_sums_hashed_poss[&(j, l)];
4203      let backward_sums2 = &alignfold_sums.backward_sums_hashed_poss2[&(j, l)];
4204      let (forward_sums_2loop, forward_sums_2loop2) = if needs_twoloop_sums {
4205        let computes_forward_sums = true;
4206        get_2loop_sums((
4207          seq_pair,
4208          alignfold_scores,
4209          match_probs,
4210          pos_quad,
4211          alignfold_sums,
4212          computes_forward_sums,
4213          forward_pos_pairs,
4214          fold_scores_pair,
4215          range_insert_scores,
4216          matchable_poss,
4217          matchable_poss2,
4218        ))
4219      } else {
4220        (SparseSumMat::<T>::default(), SparseSumMat::<T>::default())
4221      };
4222      let (backward_sums_2loop, backward_sums_2loop2) = if needs_twoloop_sums {
4223        let computes_forward_sums = false;
4224        get_2loop_sums((
4225          seq_pair,
4226          alignfold_scores,
4227          match_probs,
4228          pos_quad,
4229          alignfold_sums,
4230          computes_forward_sums,
4231          backward_pos_pairs,
4232          fold_scores_pair,
4233          range_insert_scores,
4234          matchable_poss,
4235          matchable_poss2,
4236        ))
4237      } else {
4238        (SparseSumMat::<T>::default(), SparseSumMat::<T>::default())
4239      };
4240      if trains_alignfold_scores {
4241        let pos_pair2 = (j - T::one(), l - T::one());
4242        if let Some(&x) = forward_sums_2loop.get(&pos_pair2) {
4243          let x = prob_coeff + x + alignfold_scores.match2match_score;
4244          logsumexp(&mut alignfold_counts_expected.match2match_score, x);
4245        }
4246      }
4247      for u in range_inclusive(i, j) {
4248        let long_u = u.to_usize().unwrap();
4249        let base = seq_pair.0[long_u];
4250        for v in range_inclusive(k, l) {
4251          let pos_pair = (u, v);
4252          let long_v = v.to_usize().unwrap();
4253          let base2 = seq_pair.1[long_v];
4254          let pos_pair2 = (u + T::one(), v + T::one());
4255          let long_pos_pair2 = (
4256            pos_pair2.0.to_usize().unwrap(),
4257            pos_pair2.1.to_usize().unwrap(),
4258          );
4259          let dict_min_match = get_dict_min_pair(&(base, base2));
4260          let mut backward_term_seqalign = NEG_INFINITY;
4261          let mut backward_term_multibranch = backward_term_seqalign;
4262          let mut backward_term_1ormore_pairmatches = backward_term_seqalign;
4263          let mut backward_term_0ormore_pairmatches = backward_term_seqalign;
4264          let mut backward_term_2loop = backward_term_seqalign;
4265          if let Some(x) = backward_sums2.get(&pos_pair2) {
4266            logsumexp(&mut backward_term_seqalign, x.sum_seqalign);
4267            if needs_twoloop_sums {
4268              logsumexp(&mut backward_term_multibranch, x.sum_multibranch);
4269              logsumexp(
4270                &mut backward_term_1ormore_pairmatches,
4271                x.sum_1ormore_pairmatches,
4272              );
4273              logsumexp(
4274                &mut backward_term_0ormore_pairmatches,
4275                x.sum_0ormore_pairmatches,
4276              );
4277            }
4278          }
4279          if needs_twoloop_sums {
4280            if let Some(&x) = backward_sums_2loop2.get(&pos_pair2) {
4281              logsumexp(&mut backward_term_2loop, x);
4282            }
4283          }
4284          let prob_coeff_hairpin = prob_coeff + hairpin_score + hairpin_score2;
4285          let prob_coeff_multibranch =
4286            prob_coeff + multibranch_close_score + multibranch_close_score2;
4287          if match_probs.contains_key(&pos_pair) {
4288            let pos_pair_loopmatch = (u - T::one(), v - T::one());
4289            let loopmatch_score = alignfold_scores.match_scores[base][base2];
4290            let loopmatch_score_multibranch =
4291              loopmatch_score + 2. * alignfold_scores.multibranch_score_unpair;
4292            let mut loopmatch_prob_hairpin = NEG_INFINITY;
4293            let mut loopmatch_prob_multibranch = loopmatch_prob_hairpin;
4294            let mut loopmatch_prob_2loop = loopmatch_prob_hairpin;
4295            if let Some(x) = forward_sums2.get(&pos_pair_loopmatch) {
4296              let y =
4297                prob_coeff_hairpin + loopmatch_score + x.sum_seqalign + backward_term_seqalign;
4298              logsumexp(&mut loopmatch_prob_hairpin, y);
4299              if needs_twoloop_sums {
4300                let y = prob_coeff_multibranch
4301                  + loopmatch_score_multibranch
4302                  + x.sum_seqalign_multibranch
4303                  + backward_term_multibranch;
4304                logsumexp(&mut loopmatch_prob_multibranch, y);
4305                let y = prob_coeff_multibranch
4306                  + loopmatch_score_multibranch
4307                  + x.sum_1st_pairmatches
4308                  + backward_term_1ormore_pairmatches;
4309                logsumexp(&mut loopmatch_prob_multibranch, y);
4310                let y = prob_coeff_multibranch
4311                  + loopmatch_score_multibranch
4312                  + x.sum_multibranch
4313                  + backward_term_0ormore_pairmatches;
4314                logsumexp(&mut loopmatch_prob_multibranch, y);
4315                let y = prob_coeff + loopmatch_score + x.sum_seqalign + backward_term_2loop;
4316                logsumexp(&mut loopmatch_prob_2loop, y);
4317              }
4318            }
4319            if produces_struct_profs {
4320              logsumexp(
4321                &mut alignfold_probs.context_profs_pair.0[(long_u, CONTEXT_INDEX_HAIRPIN)],
4322                loopmatch_prob_hairpin,
4323              );
4324              logsumexp(
4325                &mut alignfold_probs.context_profs_pair.1[(long_v, CONTEXT_INDEX_HAIRPIN)],
4326                loopmatch_prob_hairpin,
4327              );
4328            }
4329            if needs_twoloop_sums {
4330              if let Some(&x) = forward_sums_2loop2.get(&pos_pair_loopmatch) {
4331                let x = prob_coeff + loopmatch_score + x + backward_term_seqalign;
4332                logsumexp(&mut loopmatch_prob_2loop, x);
4333              }
4334              let mut prob = NEG_INFINITY;
4335              logsumexp(&mut prob, loopmatch_prob_hairpin);
4336              logsumexp(&mut prob, loopmatch_prob_multibranch);
4337              logsumexp(&mut prob, loopmatch_prob_2loop);
4338              if produces_match_probs {
4339                match alignfold_probs.loopmatch_probs.get_mut(&pos_pair) {
4340                  Some(x) => {
4341                    logsumexp(x, prob);
4342                  }
4343                  None => {
4344                    alignfold_probs.loopmatch_probs.insert(pos_pair, prob);
4345                  }
4346                }
4347              }
4348              if trains_alignfold_scores {
4349                logsumexp(
4350                  &mut alignfold_counts_expected.match_scores[dict_min_match.0][dict_min_match.1],
4351                  prob,
4352                );
4353                logsumexp(
4354                  &mut alignfold_counts_expected.multibranch_score_unpair,
4355                  (2. as Prob).ln() + loopmatch_prob_multibranch,
4356                );
4357              }
4358            }
4359            if let Some(x) = forward_sums.get(&pos_pair_loopmatch) {
4360              let y = prob_coeff_hairpin
4361                + loopmatch_score
4362                + x.sum_seqalign
4363                + alignfold_scores.match2match_score
4364                + backward_term_seqalign;
4365              if trains_alignfold_scores {
4366                logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4367              }
4368              if needs_twoloop_sums {
4369                let y = prob_coeff_multibranch
4370                  + loopmatch_score_multibranch
4371                  + x.sum_seqalign_multibranch
4372                  + alignfold_scores.match2match_score
4373                  + backward_term_multibranch;
4374                if trains_alignfold_scores {
4375                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4376                }
4377                let y = prob_coeff_multibranch
4378                  + loopmatch_score_multibranch
4379                  + x.sum_1st_pairmatches
4380                  + alignfold_scores.match2match_score
4381                  + backward_term_1ormore_pairmatches;
4382                if trains_alignfold_scores {
4383                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4384                }
4385                let y = prob_coeff_multibranch
4386                  + loopmatch_score_multibranch
4387                  + x.sum_multibranch
4388                  + alignfold_scores.match2match_score
4389                  + backward_term_0ormore_pairmatches;
4390                if trains_alignfold_scores {
4391                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4392                }
4393                let y = prob_coeff
4394                  + loopmatch_score
4395                  + x.sum_seqalign
4396                  + alignfold_scores.match2match_score
4397                  + backward_term_2loop;
4398                if trains_alignfold_scores {
4399                  logsumexp(&mut alignfold_counts_expected.match2match_score, y);
4400                }
4401              }
4402            }
4403            if needs_indel_info {
4404              if let Some(sums) = forward_sums.get(&pos_pair) {
4405                let forward_term_seqalign = sums.sum_seqalign + alignfold_scores.match2insert_score;
4406                let forward_term_seqalign_multibranch =
4407                  sums.sum_seqalign_multibranch + alignfold_scores.match2insert_score;
4408                let forward_term_1st_pairmatches =
4409                  sums.sum_1st_pairmatches + alignfold_scores.match2insert_score;
4410                let forward_term_multibranch =
4411                  sums.sum_multibranch + alignfold_scores.match2insert_score;
4412                if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4413                  for &x in x {
4414                    if x <= pos_pair2.1 {
4415                      continue;
4416                    }
4417                    let mut insert_prob = NEG_INFINITY;
4418                    let mut insert_prob_multibranch = insert_prob;
4419                    let long_x = x.to_usize().unwrap();
4420                    let y = range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
4421                      + alignfold_scores.match2insert_score;
4422                    if let Some(z) = backward_sums.get(&(pos_pair2.0, x)) {
4423                      let a = prob_coeff_hairpin + forward_term_seqalign + y + z.sum_seqalign;
4424                      logsumexp(&mut insert_prob, a);
4425                      if produces_struct_profs {
4426                        let pos_pair4 = (pos_pair2.1, x - T::one());
4427                        match unpair_probs_range_hairpin.1.get_mut(&pos_pair4) {
4428                          Some(x) => {
4429                            logsumexp(x, a);
4430                          }
4431                          None => {
4432                            unpair_probs_range_hairpin.1.insert(pos_pair4, a);
4433                          }
4434                        }
4435                      }
4436                      if trains_alignfold_scores {
4437                        let y = range_insert_scores.insert_scores_multibranch2[long_pos_pair2.1]
4438                          [long_x - 1]
4439                          + alignfold_scores.match2insert_score;
4440                        let a = prob_coeff_multibranch
4441                          + forward_term_seqalign_multibranch
4442                          + y
4443                          + z.sum_multibranch;
4444                        logsumexp(&mut insert_prob_multibranch, a);
4445                        let a = prob_coeff_multibranch
4446                          + forward_term_1st_pairmatches
4447                          + y
4448                          + z.sum_1ormore_pairmatches;
4449                        logsumexp(&mut insert_prob_multibranch, a);
4450                        let a = prob_coeff_multibranch
4451                          + forward_term_multibranch
4452                          + y
4453                          + z.sum_0ormore_pairmatches;
4454                        logsumexp(&mut insert_prob_multibranch, a);
4455                        logsumexp(&mut insert_prob, insert_prob_multibranch);
4456                      }
4457                    }
4458                    if trains_alignfold_scores {
4459                      if let Some(&z) = backward_sums_2loop.get(&(pos_pair2.0, x)) {
4460                        let z = prob_coeff + forward_term_seqalign + y + z;
4461                        logsumexp(&mut insert_prob, z);
4462                      }
4463                      logsumexp(
4464                        &mut alignfold_counts_expected.match2insert_score,
4465                        (2. as Prob).ln() + insert_prob,
4466                      );
4467                      let pos_pair4 = (pos_pair2.1, x - T::one());
4468                      match unpair_probs_range.1.get_mut(&pos_pair4) {
4469                        Some(x) => {
4470                          logsumexp(x, insert_prob);
4471                        }
4472                        None => {
4473                          unpair_probs_range.1.insert(pos_pair4, insert_prob);
4474                        }
4475                      }
4476                      logsumexp(
4477                        &mut alignfold_counts_expected.multibranch_score_unpair,
4478                        ((long_x - long_pos_pair2.1) as Prob).ln() + insert_prob_multibranch,
4479                      );
4480                      if pos_pair2.1 < x - T::one() {
4481                        logsumexp(
4482                          &mut alignfold_counts_expected.insert_extend_score,
4483                          ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + insert_prob,
4484                        );
4485                      }
4486                    }
4487                  }
4488                }
4489                if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4490                  for &x in x {
4491                    if x <= pos_pair2.0 {
4492                      continue;
4493                    }
4494                    let mut insert_prob = NEG_INFINITY;
4495                    let mut insert_prob_multibranch = insert_prob;
4496                    let long_x = x.to_usize().unwrap();
4497                    let y = range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
4498                      + alignfold_scores.match2insert_score;
4499                    if let Some(z) = backward_sums.get(&(x, pos_pair2.1)) {
4500                      let a = prob_coeff_hairpin + forward_term_seqalign + y + z.sum_seqalign;
4501                      logsumexp(&mut insert_prob, a);
4502                      if produces_struct_profs {
4503                        let pos_pair4 = (pos_pair2.0, x - T::one());
4504                        match unpair_probs_range_hairpin.0.get_mut(&pos_pair4) {
4505                          Some(x) => {
4506                            logsumexp(x, a);
4507                          }
4508                          None => {
4509                            unpair_probs_range_hairpin.0.insert(pos_pair4, a);
4510                          }
4511                        }
4512                      }
4513                      if trains_alignfold_scores {
4514                        let y = range_insert_scores.insert_scores_multibranch[long_pos_pair2.0]
4515                          [long_x - 1]
4516                          + alignfold_scores.match2insert_score;
4517                        let a = prob_coeff_multibranch
4518                          + forward_term_seqalign_multibranch
4519                          + y
4520                          + z.sum_multibranch;
4521                        logsumexp(&mut insert_prob_multibranch, a);
4522                        let a = prob_coeff_multibranch
4523                          + forward_term_1st_pairmatches
4524                          + y
4525                          + z.sum_1ormore_pairmatches;
4526                        logsumexp(&mut insert_prob_multibranch, a);
4527                        let a = prob_coeff_multibranch
4528                          + forward_term_multibranch
4529                          + y
4530                          + z.sum_0ormore_pairmatches;
4531                        logsumexp(&mut insert_prob_multibranch, a);
4532                        logsumexp(&mut insert_prob, insert_prob_multibranch);
4533                      }
4534                    }
4535                    if let Some(&z) = backward_sums_2loop.get(&(x, pos_pair2.1)) {
4536                      let z = prob_coeff + forward_term_seqalign + y + z;
4537                      logsumexp(&mut insert_prob, z);
4538                    }
4539                    if trains_alignfold_scores {
4540                      logsumexp(
4541                        &mut alignfold_counts_expected.match2insert_score,
4542                        (2. as Prob).ln() + insert_prob,
4543                      );
4544                      let pos_pair4 = (pos_pair2.0, x - T::one());
4545                      match unpair_probs_range.0.get_mut(&pos_pair4) {
4546                        Some(x) => {
4547                          logsumexp(x, insert_prob);
4548                        }
4549                        None => {
4550                          unpair_probs_range.0.insert(pos_pair4, insert_prob);
4551                        }
4552                      }
4553                      logsumexp(
4554                        &mut alignfold_counts_expected.multibranch_score_unpair,
4555                        ((long_x - long_pos_pair2.0) as Prob).ln() + insert_prob_multibranch,
4556                      );
4557                      if pos_pair2.0 < x - T::one() {
4558                        logsumexp(
4559                          &mut alignfold_counts_expected.insert_extend_score,
4560                          ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + insert_prob,
4561                        );
4562                      }
4563                    }
4564                  }
4565                }
4566              }
4567              if trains_alignfold_scores {
4568                if let Some(&forward_sum) = forward_sums_2loop.get(&pos_pair) {
4569                  if let Some(x) = matchable_poss.get(&pos_pair2.0) {
4570                    for &x in x {
4571                      if x <= pos_pair2.1 {
4572                        continue;
4573                      }
4574                      let mut insert_prob = NEG_INFINITY;
4575                      let long_x = x.to_usize().unwrap();
4576                      let y = range_insert_scores.insert_scores2[long_pos_pair2.1][long_x - 1]
4577                        + alignfold_scores.match2insert_score;
4578                      if let Some(z) = backward_sums.get(&(pos_pair2.0, x)) {
4579                        let a = prob_coeff + forward_sum + y + z.sum_seqalign;
4580                        logsumexp(&mut insert_prob, a);
4581                        logsumexp(
4582                          &mut alignfold_counts_expected.match2insert_score,
4583                          (2. as Prob).ln() + insert_prob,
4584                        );
4585                        let pos_pair4 = (pos_pair2.1, x - T::one());
4586                        match unpair_probs_range.1.get_mut(&pos_pair4) {
4587                          Some(x) => {
4588                            logsumexp(x, insert_prob);
4589                          }
4590                          None => {
4591                            unpair_probs_range.1.insert(pos_pair4, insert_prob);
4592                          }
4593                        }
4594                        if pos_pair2.1 < x - T::one() {
4595                          logsumexp(
4596                            &mut alignfold_counts_expected.insert_extend_score,
4597                            ((long_x - long_pos_pair2.1 - 1) as Prob).ln() + insert_prob,
4598                          );
4599                        }
4600                      }
4601                    }
4602                  }
4603                  if let Some(x) = matchable_poss2.get(&pos_pair2.1) {
4604                    for &x in x {
4605                      if x <= pos_pair2.0 {
4606                        continue;
4607                      }
4608                      let mut insert_prob = NEG_INFINITY;
4609                      let long_x = x.to_usize().unwrap();
4610                      let y = range_insert_scores.insert_scores[long_pos_pair2.0][long_x - 1]
4611                        + alignfold_scores.match2insert_score;
4612                      if let Some(z) = backward_sums.get(&(x, pos_pair2.1)) {
4613                        let a = prob_coeff + forward_sum + y + z.sum_seqalign;
4614                        logsumexp(&mut insert_prob, a);
4615                        logsumexp(
4616                          &mut alignfold_counts_expected.match2insert_score,
4617                          (2. as Prob).ln() + insert_prob,
4618                        );
4619                        let pos_pair4 = (pos_pair2.0, x - T::one());
4620                        match unpair_probs_range.0.get_mut(&pos_pair4) {
4621                          Some(x) => {
4622                            logsumexp(x, insert_prob);
4623                          }
4624                          None => {
4625                            unpair_probs_range.0.insert(pos_pair4, insert_prob);
4626                          }
4627                        }
4628                        if pos_pair2.0 < x - T::one() {
4629                          logsumexp(
4630                            &mut alignfold_counts_expected.insert_extend_score,
4631                            ((long_x - long_pos_pair2.0 - 1) as Prob).ln() + insert_prob,
4632                          );
4633                        }
4634                      }
4635                    }
4636                  }
4637                }
4638              }
4639            }
4640          }
4641        }
4642      }
4643    }
4644    if needs_indel_info {
4645      for (x, &y) in &unpair_probs_range.0 {
4646        for i in range_inclusive(x.0, x.1) {
4647          let long_i = i.to_usize().unwrap();
4648          let x = seq_pair.0[long_i];
4649          logsumexp(&mut alignfold_counts_expected.insert_scores[x], y);
4650        }
4651      }
4652      for (x, &y) in &unpair_probs_range.1 {
4653        for i in range_inclusive(x.0, x.1) {
4654          let long_i = i.to_usize().unwrap();
4655          let x = seq_pair.1[long_i];
4656          logsumexp(&mut alignfold_counts_expected.insert_scores[x], y);
4657        }
4658      }
4659    }
4660    if produces_struct_profs {
4661      for (x, &y) in &unpair_probs_range_external.0 {
4662        for i in range_inclusive(x.0, x.1) {
4663          let long_i = i.to_usize().unwrap();
4664          logsumexp(
4665            &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_EXTERNAL)],
4666            y,
4667          );
4668        }
4669      }
4670      for (x, &y) in &unpair_probs_range_external.1 {
4671        for i in range_inclusive(x.0, x.1) {
4672          let long_i = i.to_usize().unwrap();
4673          logsumexp(
4674            &mut alignfold_probs.context_profs_pair.1[(long_i, CONTEXT_INDEX_EXTERNAL)],
4675            y,
4676          );
4677        }
4678      }
4679      for (x, &y) in &unpair_probs_range_hairpin.0 {
4680        for i in range_inclusive(x.0, x.1) {
4681          let long_i = i.to_usize().unwrap();
4682          logsumexp(
4683            &mut alignfold_probs.context_profs_pair.0[(long_i, CONTEXT_INDEX_HAIRPIN)],
4684            y,
4685          );
4686        }
4687      }
4688      for (x, &y) in &unpair_probs_range_hairpin.1 {
4689        for i in range_inclusive(x.0, x.1) {
4690          let long_i = i.to_usize().unwrap();
4691          logsumexp(
4692            &mut alignfold_probs.context_profs_pair.1[(long_i, CONTEXT_INDEX_HAIRPIN)],
4693            y,
4694          );
4695        }
4696      }
4697      alignfold_probs
4698        .context_profs_pair
4699        .0
4700        .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4701        .mapv_inplace(expf);
4702      let fold = 1.
4703        - alignfold_probs
4704          .context_profs_pair
4705          .0
4706          .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4707          .sum_axis(Axis(1));
4708      alignfold_probs
4709        .context_profs_pair
4710        .0
4711        .slice_mut(s![.., CONTEXT_INDEX_MULTIBRANCH])
4712        .assign(&fold);
4713      alignfold_probs
4714        .context_profs_pair
4715        .1
4716        .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4717        .mapv_inplace(expf);
4718      let fold = 1.
4719        - alignfold_probs
4720          .context_profs_pair
4721          .1
4722          .slice_mut(s![.., ..CONTEXT_INDEX_MULTIBRANCH])
4723          .sum_axis(Axis(1));
4724      alignfold_probs
4725        .context_profs_pair
4726        .1
4727        .slice_mut(s![.., CONTEXT_INDEX_MULTIBRANCH])
4728        .assign(&fold);
4729    }
4730    if produces_match_probs {
4731      for (x, y) in alignfold_probs.loopmatch_probs.iter_mut() {
4732        match alignfold_probs.match_probs.get_mut(x) {
4733          Some(x) => {
4734            logsumexp(x, *y);
4735            *x = expf(*x);
4736          }
4737          None => {
4738            alignfold_probs.match_probs.insert(*x, expf(*y));
4739          }
4740        }
4741        *y = expf(*y);
4742      }
4743      for x in alignfold_probs.pairmatch_probs.values_mut() {
4744        *x = expf(*x);
4745      }
4746    }
4747    if trains_alignfold_scores {
4748      for x in alignfold_counts_expected.hairpin_scores_len.iter_mut() {
4749        *x = expf(*x);
4750      }
4751      for x in alignfold_counts_expected.bulge_scores_len.iter_mut() {
4752        *x = expf(*x);
4753      }
4754      for x in alignfold_counts_expected.interior_scores_len.iter_mut() {
4755        *x = expf(*x);
4756      }
4757      for x in alignfold_counts_expected
4758        .interior_scores_symmetric
4759        .iter_mut()
4760      {
4761        *x = expf(*x);
4762      }
4763      for x in alignfold_counts_expected
4764        .interior_scores_asymmetric
4765        .iter_mut()
4766      {
4767        *x = expf(*x);
4768      }
4769      for x in alignfold_counts_expected.stack_scores.iter_mut() {
4770        for x in x.iter_mut() {
4771          for x in x.iter_mut() {
4772            for x in x.iter_mut() {
4773              *x = expf(*x);
4774            }
4775          }
4776        }
4777      }
4778      for x in alignfold_counts_expected
4779        .terminal_mismatch_scores
4780        .iter_mut()
4781      {
4782        for x in x.iter_mut() {
4783          for x in x.iter_mut() {
4784            for x in x.iter_mut() {
4785              *x = expf(*x);
4786            }
4787          }
4788        }
4789      }
4790      for x in alignfold_counts_expected.dangling_scores_left.iter_mut() {
4791        for x in x.iter_mut() {
4792          for x in x.iter_mut() {
4793            *x = expf(*x);
4794          }
4795        }
4796      }
4797      for x in alignfold_counts_expected.dangling_scores_right.iter_mut() {
4798        for x in x.iter_mut() {
4799          for x in x.iter_mut() {
4800            *x = expf(*x);
4801          }
4802        }
4803      }
4804      for x in alignfold_counts_expected
4805        .interior_scores_explicit
4806        .iter_mut()
4807      {
4808        for x in x.iter_mut() {
4809          *x = expf(*x);
4810        }
4811      }
4812      for x in alignfold_counts_expected.bulge_scores_0x1.iter_mut() {
4813        *x = expf(*x);
4814      }
4815      for x in alignfold_counts_expected.interior_scores_1x1.iter_mut() {
4816        for x in x.iter_mut() {
4817          *x = expf(*x);
4818        }
4819      }
4820      for x in alignfold_counts_expected.helix_close_scores.iter_mut() {
4821        for x in x.iter_mut() {
4822          *x = expf(*x);
4823        }
4824      }
4825      for x in alignfold_counts_expected.basepair_scores.iter_mut() {
4826        for x in x.iter_mut() {
4827          *x = expf(*x);
4828        }
4829      }
4830      alignfold_counts_expected.multibranch_score_base =
4831        expf(alignfold_counts_expected.multibranch_score_base);
4832      alignfold_counts_expected.multibranch_score_basepair =
4833        expf(alignfold_counts_expected.multibranch_score_basepair);
4834      alignfold_counts_expected.multibranch_score_unpair =
4835        expf(alignfold_counts_expected.multibranch_score_unpair);
4836      alignfold_counts_expected.external_score_basepair =
4837        expf(alignfold_counts_expected.external_score_basepair);
4838      alignfold_counts_expected.external_score_unpair =
4839        expf(alignfold_counts_expected.external_score_unpair);
4840      alignfold_counts_expected.match2match_score =
4841        expf(alignfold_counts_expected.match2match_score);
4842      alignfold_counts_expected.match2insert_score =
4843        expf(alignfold_counts_expected.match2insert_score);
4844      alignfold_counts_expected.insert_extend_score =
4845        expf(alignfold_counts_expected.insert_extend_score);
4846      alignfold_counts_expected.init_match_score = expf(alignfold_counts_expected.init_match_score);
4847      alignfold_counts_expected.init_insert_score =
4848        expf(alignfold_counts_expected.init_insert_score);
4849      for x in alignfold_counts_expected.insert_scores.iter_mut() {
4850        *x = expf(*x);
4851      }
4852      for x in alignfold_counts_expected.match_scores.iter_mut() {
4853        for x in x.iter_mut() {
4854          *x = expf(*x);
4855        }
4856      }
4857    }
4858  }
4859  alignfold_probs
4860}
4861
/// Returns the absolute difference `|x - y|` of two unsigned values without underflowing.
pub fn get_diff(x: usize, y: usize) -> usize {
  if x > y {
    x - y
  } else {
    y - x
  }
}
4865
4866pub fn get_hairpin_score(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4867  let a = z.1 - z.0 - 1;
4868  x.hairpin_scores_len_cumulative[a] + get_junction_score_single(x, y, z)
4869}
4870
4871pub fn get_twoloop_score(
4872  x: &AlignfoldScores,
4873  y: SeqSlice,
4874  z: &(usize, usize),
4875  a: &(usize, usize),
4876) -> Score {
4877  let b = (y[a.0], y[a.1]);
4878  let c = if z.0 + 1 == a.0 && z.1 - 1 == a.1 {
4879    get_stack_score(x, y, z, a)
4880  } else if z.0 + 1 == a.0 || z.1 - 1 == a.1 {
4881    get_bulge_score(x, y, z, a)
4882  } else {
4883    get_interior_score(x, y, z, a)
4884  };
4885  c + x.basepair_scores[b.0][b.1]
4886}
4887
4888pub fn get_stack_score(
4889  x: &AlignfoldScores,
4890  y: SeqSlice,
4891  z: &(usize, usize),
4892  a: &(usize, usize),
4893) -> Score {
4894  let b = (y[z.0], y[z.1]);
4895  let c = (y[a.0], y[a.1]);
4896  x.stack_scores[b.0][b.1][c.0][c.1]
4897}
4898
4899pub fn get_bulge_score(
4900  x: &AlignfoldScores,
4901  y: SeqSlice,
4902  z: &(usize, usize),
4903  a: &(usize, usize),
4904) -> Score {
4905  let b = a.0 - z.0 + z.1 - a.1 - 2;
4906  let c = if b == 1 {
4907    x.bulge_scores_0x1[if a.0 - z.0 - 1 == 1 {
4908      y[z.0 + 1]
4909    } else {
4910      y[z.1 - 1]
4911    }]
4912  } else {
4913    0.
4914  };
4915  c + x.bulge_scores_len_cumulative[b - 1]
4916    + get_junction_score_single(x, y, z)
4917    + get_junction_score_single(x, y, &(a.1, a.0))
4918}
4919
4920pub fn get_interior_score(
4921  x: &AlignfoldScores,
4922  y: SeqSlice,
4923  z: &(usize, usize),
4924  a: &(usize, usize),
4925) -> Score {
4926  let b = (a.0 - z.0 - 1, z.1 - a.1 - 1);
4927  let c = b.0 + b.1;
4928  let d = if b.0 == b.1 {
4929    let d = if c == 2 {
4930      x.interior_scores_1x1[y[z.0 + 1]][y[z.1 - 1]]
4931    } else {
4932      0.
4933    };
4934    d + x.interior_scores_symmetric_cumulative[b.0 - 1]
4935  } else {
4936    x.interior_scores_asymmetric_cumulative[get_abs_diff(b.0, b.1) - 1]
4937  };
4938  let e = if b.0 <= 4 && b.1 <= 4 {
4939    x.interior_scores_explicit[b.0 - 1][b.1 - 1]
4940  } else {
4941    0.
4942  };
4943  d + e
4944    + x.interior_scores_len_cumulative[c - 2]
4945    + get_junction_score_single(x, y, z)
4946    + get_junction_score_single(x, y, &(a.1, a.0))
4947}
4948
4949pub fn get_junction_score_single(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4950  let a = (y[z.0], y[z.1]);
4951  get_helix_close_score(x, &a) + get_terminal_mismatch_score(x, &a, &(y[z.0 + 1], y[z.1 - 1]))
4952}
4953
4954pub fn get_helix_close_score(x: &AlignfoldScores, y: &Basepair) -> Score {
4955  x.helix_close_scores[y.0][y.1]
4956}
4957
4958pub fn get_terminal_mismatch_score(x: &AlignfoldScores, y: &Basepair, z: &Basepair) -> Score {
4959  x.terminal_mismatch_scores[y.0][y.1][z.0][z.1]
4960}
4961
4962pub fn get_junction_score(x: &AlignfoldScores, y: SeqSlice, z: &(usize, usize)) -> Score {
4963  let a = (y[z.0], y[z.1]);
4964  let b = 1;
4965  let c = y.len() - 2;
4966  get_helix_close_score(x, &a)
4967    + if z.0 < c {
4968      x.dangling_scores_left[a.0][a.1][y[z.0 + 1]]
4969    } else {
4970      0.
4971    }
4972    + if z.1 > b {
4973      x.dangling_scores_right[a.0][a.1][y[z.1 - 1]]
4974    } else {
4975      0.
4976    }
4977}
4978
4979pub fn get_dict_min_stack(x: &Basepair, y: &Basepair) -> (Basepair, Basepair) {
4980  let z = (*x, *y);
4981  let a = ((y.1, y.0), (x.1, x.0));
4982  if z < a {
4983    z
4984  } else {
4985    a
4986  }
4987}
4988
/// Returns the lexicographically smaller of `x` and its flipped form `(x.1, x.0)`.
pub fn get_dict_min_pair(x: &(usize, usize)) -> (usize, usize) {
  let flipped = (x.1, x.0);
  (*x).min(flipped)
}
4997
4998pub fn get_num_unpairs_multibranch(
4999  x: &(usize, usize),
5000  y: &Vec<(usize, usize)>,
5001  z: SeqSlice,
5002) -> usize {
5003  let mut a = 0;
5004  let mut b = (0, 0);
5005  for y in y {
5006    let c = if b == (0, 0) { x.0 + 1 } else { b.1 + 1 };
5007    for &d in &z[c..y.0] {
5008      a += usize::from(d != PSEUDO_BASE);
5009    }
5010    b = *y;
5011  }
5012  for &c in &z[b.1 + 1..x.1] {
5013    a += usize::from(c != PSEUDO_BASE);
5014  }
5015  a
5016}
5017
5018pub fn get_num_unpairs_external(x: &Vec<(usize, usize)>, y: SeqSlice) -> usize {
5019  let mut z = 0;
5020  let mut a = (0, 0);
5021  for x in x {
5022    let b = if a == (0, 0) { 0 } else { a.1 + 1 };
5023    for &b in y[b..x.0].iter() {
5024      z += usize::from(b != PSEUDO_BASE);
5025    }
5026    a = *x;
5027  }
5028  for &b in y[a.1 + 1..y.len()].iter() {
5029    z += usize::from(b != PSEUDO_BASE);
5030  }
5031  z
5032}
5033
5034pub fn consprob_trained<T>(
5035  thread_pool: &mut Pool,
5036  seqs: &SeqSlices,
5037  min_basepair_prob: Prob,
5038  min_match_prob: Prob,
5039  produces_struct_profs: bool,
5040  produces_match_probs: bool,
5041  train_type: TrainType,
5042) -> (ProbMatSetsAvg<T>, MatchProbsHashedIds<T>)
5043where
5044  T: HashIndex,
5045{
5046  let trained = AlignfoldScores::load_trained_scores();
5047  let mut align_scores = AlignScores::new(0.);
5048  copy_alignfold_scores_align(&mut align_scores, &trained);
5049  let ref_align_scores = &align_scores;
5050  let mut fold_scores = FoldScoreSets::new(0.);
5051  copy_alignfold_scores_fold(&mut fold_scores, &trained);
5052  let ref_fold_scores = &fold_scores;
5053  let alignfold_scores = if matches!(train_type, TrainType::TrainedTransfer) {
5054    trained
5055  } else if matches!(train_type, TrainType::TrainedRandinit) {
5056    AlignfoldScores::load_trained_scores_randinit()
5057  } else {
5058    let mut transferred = AlignfoldScores::new(0.);
5059    transferred.transfer();
5060    transferred
5061  };
5062  let num_seqs = seqs.len();
5063  let mut basepair_prob_mats = vec![SparseProbMat::<T>::new(); num_seqs];
5064  let mut max_basepair_spans = vec![T::zero(); num_seqs];
5065  let mut fold_score_sets = vec![FoldScoresTrained::<T>::new(); num_seqs];
5066  let ref_alignfold_scores = &alignfold_scores;
5067  let uses_contra_model = true;
5068  let allows_short_hairpins = false;
5069  thread_pool.scoped(|scope| {
5070    for (x, y, z, a) in multizip((
5071      basepair_prob_mats.iter_mut(),
5072      max_basepair_spans.iter_mut(),
5073      seqs.iter(),
5074      fold_score_sets.iter_mut(),
5075    )) {
5076      let b = z.len();
5077      scope.execute(move || {
5078        let c = mccaskill_algo(
5079          &z[1..b - 1],
5080          uses_contra_model,
5081          allows_short_hairpins,
5082          ref_fold_scores,
5083        )
5084        .0;
5085        *x = filter_basepair_probs::<T>(&c, min_basepair_prob);
5086        *y = get_max_basepair_span::<T>(x);
5087        *a = FoldScoresTrained::<T>::set_curr_scores(ref_alignfold_scores, z, x);
5088      });
5089    }
5090  });
5091  let mut alignfold_probs_hashed_ids = AlignfoldProbsHashedIds::<T>::default();
5092  let mut match_probs_hashed_ids = SparseProbsHashedIds::<T>::default();
5093  for x in 0..num_seqs {
5094    for y in x + 1..num_seqs {
5095      let y = (x, y);
5096      alignfold_probs_hashed_ids.insert(y, AlignfoldProbMats::<T>::origin());
5097      match_probs_hashed_ids.insert(y, SparseProbMat::<T>::default());
5098    }
5099  }
5100  thread_pool.scoped(|x| {
5101    for (y, z) in match_probs_hashed_ids.iter_mut() {
5102      let y = (seqs[y.0], seqs[y.1]);
5103      x.execute(move || {
5104        *z = filter_match_probs(&durbin_algo(&y, ref_align_scores), min_match_prob);
5105      });
5106    }
5107  });
5108  let trains_alignfold_scores = false;
5109  thread_pool.scoped(|x| {
5110    let alignfold_scores = &alignfold_scores;
5111    for (y, z) in alignfold_probs_hashed_ids.iter_mut() {
5112      let seq_pair = (seqs[y.0], seqs[y.1]);
5113      let seq_len_pair = (
5114        T::from_usize(seq_pair.0.len()).unwrap(),
5115        T::from_usize(seq_pair.1.len()).unwrap(),
5116      );
5117      let max_basepair_span_pair = (max_basepair_spans[y.0], max_basepair_spans[y.1]);
5118      let basepair_probs_pair = (&basepair_prob_mats[y.0], &basepair_prob_mats[y.1]);
5119      let fold_scores_pair = (&fold_score_sets[y.0], &fold_score_sets[y.1]);
5120      let match_probs = &match_probs_hashed_ids[y];
5121      let (
5122        forward_pos_pairs,
5123        backward_pos_pairs,
5124        _,
5125        pos_quads_hashed_lens,
5126        matchable_poss,
5127        matchable_poss2,
5128      ) = get_sparse_poss(&basepair_probs_pair, match_probs, &seq_len_pair);
5129      x.execute(move || {
5130        *z = consprob_core::<T>((
5131          &seq_pair,
5132          alignfold_scores,
5133          &max_basepair_span_pair,
5134          match_probs,
5135          produces_struct_profs,
5136          trains_alignfold_scores,
5137          &mut AlignfoldScores::new(NEG_INFINITY),
5138          &forward_pos_pairs,
5139          &backward_pos_pairs,
5140          &pos_quads_hashed_lens,
5141          &fold_scores_pair,
5142          produces_match_probs,
5143          &matchable_poss,
5144          &matchable_poss2,
5145        ))
5146        .0;
5147      });
5148    }
5149  });
5150  let mut alignfold_prob_mats_avg = vec![AlignfoldProbMatsAvg::<T>::origin(); num_seqs];
5151  thread_pool.scoped(|x| {
5152    let y = &alignfold_probs_hashed_ids;
5153    for (z, a) in alignfold_prob_mats_avg.iter_mut().enumerate() {
5154      let b = seqs[z].len();
5155      x.execute(move || {
5156        *a = pair_probs2avg_probs::<T>(y, z, num_seqs, b, produces_struct_profs);
5157      });
5158    }
5159  });
5160  let mut match_probs_hashed_ids = MatchProbsHashedIds::<T>::default();
5161  if produces_match_probs {
5162    for x in 0..num_seqs {
5163      for y in x + 1..num_seqs {
5164        let y = (x, y);
5165        let z = &alignfold_probs_hashed_ids[&y];
5166        let mut a = MatchProbMats::<T>::new();
5167        a.loopmatch_probs = z.loopmatch_probs.clone();
5168        a.pairmatch_probs = z.pairmatch_probs.clone();
5169        a.match_probs = z.match_probs.clone();
5170        match_probs_hashed_ids.insert(y, a);
5171      }
5172    }
5173  }
5174  (alignfold_prob_mats_avg, match_probs_hashed_ids)
5175}
5176
5177pub fn constrain<T>(
5178  thread_pool: &mut Pool,
5179  train_data: &mut TrainData<T>,
5180  output_file_path: &Path,
5181  enables_randinit: bool,
5182  learning_tolerance: Score,
5183) where
5184  T: HashIndex,
5185{
5186  let mut alignfold_scores = AlignfoldScores::new(0.);
5187  if enables_randinit {
5188    alignfold_scores.rand_init();
5189  } else {
5190    alignfold_scores.transfer();
5191  }
5192  for x in train_data.iter_mut() {
5193    x.set_curr_scores(&alignfold_scores);
5194  }
5195  let mut old_alignfold_scores = alignfold_scores.clone();
5196  let mut old_cost = INFINITY;
5197  let mut old_accuracy = NEG_INFINITY;
5198  let mut costs = Probs::new();
5199  let mut accuracies = costs.clone();
5200  let mut epoch = 0;
5201  let mut regularizers = Regularizers::from(vec![1.; alignfold_scores.len()]);
5202  let num_data = train_data.len() as Score;
5203  let produces_struct_profs = false;
5204  let trains_alignfold_scores = true;
5205  let produces_match_probs = true;
5206  loop {
5207    thread_pool.scoped(|scope| {
5208      let alignfold_scores = &alignfold_scores;
5209      for train_datum in train_data.iter_mut() {
5210        train_datum.alignfold_counts_expected = AlignfoldScores::new(NEG_INFINITY);
5211        let seq_pair = (&train_datum.seq_pair.0[..], &train_datum.seq_pair.1[..]);
5212        let max_basepair_span_pair = &train_datum.max_basepair_span_pair;
5213        let alignfold_counts_expected = &mut train_datum.alignfold_counts_expected;
5214        let accuracy = &mut train_datum.accuracy;
5215        let global_sum = &mut train_datum.global_sum;
5216        let forward_pos_pairs = &train_datum.forward_pos_pairs;
5217        let backward_pos_pairs = &train_datum.backward_pos_pairs;
5218        let pos_quads_hashed_lens = &train_datum.pos_quads_hashed_lens;
5219        let matchable_poss = &train_datum.matchable_poss;
5220        let matchable_poss2 = &train_datum.matchable_poss2;
5221        let match_probs = &train_datum.match_probs;
5222        let alignfold = &train_datum.alignfold;
5223        let fold_scores_pair = (
5224          &train_datum.fold_scores_pair.0,
5225          &train_datum.fold_scores_pair.1,
5226        );
5227        scope.execute(move || {
5228          let x = consprob_core::<T>((
5229            &seq_pair,
5230            alignfold_scores,
5231            max_basepair_span_pair,
5232            match_probs,
5233            produces_struct_profs,
5234            trains_alignfold_scores,
5235            alignfold_counts_expected,
5236            forward_pos_pairs,
5237            backward_pos_pairs,
5238            pos_quads_hashed_lens,
5239            &fold_scores_pair,
5240            produces_match_probs,
5241            matchable_poss,
5242            matchable_poss2,
5243          ));
5244          *accuracy = get_accuracy_expected::<T>(&seq_pair, alignfold, &x.0.match_probs);
5245          *global_sum = x.1;
5246        });
5247      }
5248    });
5249    let accuracy = train_data.iter().map(|x| x.accuracy).sum::<Score>() / train_data.len() as Score;
5250    let accuracy_change = accuracy - old_accuracy;
5251    if accuracy_change <= learning_tolerance {
5252      alignfold_scores = old_alignfold_scores;
5253      println!(
5254        "Accuracy change {accuracy_change} is <= learning tolerance {learning_tolerance}; training finished"
5255      );
5256      break;
5257    }
5258    old_alignfold_scores = alignfold_scores.clone();
5259    alignfold_scores.update(train_data, &mut regularizers);
5260    for x in train_data.iter_mut() {
5261      x.set_curr_scores(&alignfold_scores);
5262    }
5263    let cost = alignfold_scores.get_cost(&train_data[..], &regularizers);
5264    let avg_cost_change = (cost - old_cost) / num_data;
5265    if avg_cost_change >= 0. {
5266      println!("Average cost change {avg_cost_change} is not negative; training finished");
5267      alignfold_scores = old_alignfold_scores;
5268      break;
5269    }
5270    costs.push(cost);
5271    accuracies.push(accuracy);
5272    println!("Epoch {} finished (current cost = {}, current accuracy = {}, average cost change = {}, accuracy change = {})", epoch + 1, cost, accuracy, avg_cost_change, accuracy_change);
5273    epoch += 1;
5274    old_cost = cost;
5275    old_accuracy = accuracy;
5276  }
5277  write_alignfold_scores_trained(&alignfold_scores, enables_randinit);
5278  write_logs(&costs, &accuracies, output_file_path);
5279}
5280
5281pub fn gapped2ungapped(x: &Seq) -> Seq {
5282  x.iter().filter(|&&x| x != PSEUDO_BASE).copied().collect()
5283}
5284
5285pub fn bytes2seq_gapped(x: &[u8]) -> Seq {
5286  let mut y = Seq::new();
5287  for &x in x {
5288    let x = convert_char(x);
5289    y.push(x);
5290  }
5291  y
5292}
5293
5294pub fn convert_char(x: u8) -> Base {
5295  match x {
5296    A_LOWER | A_UPPER => A,
5297    C_LOWER | C_UPPER => C,
5298    G_LOWER | G_UPPER => G,
5299    U_LOWER | U_UPPER => U,
5300    _ => PSEUDO_BASE,
5301  }
5302}
5303
5304pub fn get_mismatch_pair(x: SeqSlice, y: &(usize, usize), z: bool) -> (usize, usize) {
5305  let mut a = if z {
5306    (x[y.1], x[y.0])
5307  } else {
5308    (PSEUDO_BASE, PSEUDO_BASE)
5309  };
5310  if z {
5311    for &b in &x[y.0 + 1..y.1] {
5312      if b != PSEUDO_BASE {
5313        a.0 = b;
5314        break;
5315      }
5316    }
5317    for i in (y.0 + 1..y.1).rev() {
5318      let x = x[i];
5319      if x != PSEUDO_BASE {
5320        a.1 = x;
5321        break;
5322      }
5323    }
5324  } else {
5325    for i in (0..y.0).rev() {
5326      let x = x[i];
5327      if x != PSEUDO_BASE {
5328        a.0 = x;
5329        break;
5330      }
5331    }
5332    let b = x.len();
5333    for &x in &x[y.1 + 1..b] {
5334      if x != PSEUDO_BASE {
5335        a.1 = x;
5336        break;
5337      }
5338    }
5339  }
5340  a
5341}
5342
5343pub fn get_hairpin_len(x: SeqSlice, y: &(usize, usize)) -> usize {
5344  let mut z = 0;
5345  for &x in &x[y.0 + 1..y.1] {
5346    if x == PSEUDO_BASE {
5347      continue;
5348    }
5349    z += 1;
5350  }
5351  z
5352}
5353
5354pub fn get_2loop_len_pair(x: SeqSlice, y: &(usize, usize), z: &(usize, usize)) -> (usize, usize) {
5355  let mut a = (0, 0);
5356  for &x in &x[y.0 + 1..z.0] {
5357    if x == PSEUDO_BASE {
5358      continue;
5359    }
5360    a.0 += 1;
5361  }
5362  for &x in &x[z.1 + 1..y.1] {
5363    if x == PSEUDO_BASE {
5364      continue;
5365    }
5366    a.1 += 1;
5367  }
5368  a
5369}
5370
/// Rebuilds an `AlignfoldScores` from the flat parameter vector `source`.
///
/// The traversal order below defines the parameter-vector layout and mirrors `struct2vec`
/// step for step. Any change here must be made there too, and in `print_train_info`, which
/// counts the same groups with the same filters.
///
/// * `source` - flat parameter vector; elements are consumed front to back.
/// * `uses_cumulative_scores` - when true, the five loop-length segments are stored into the
///   `*_cumulative` fields; otherwise into the plain fields.
///
/// Pair-indexed tables only receive values at canonical base pairs (per
/// `has_canonical_basepair`). Tables filtered through `get_dict_min_stack` /
/// `get_dict_min_pair` only receive a value at the representative entry. Every other entry
/// keeps whatever `AlignfoldScores::new(0.)` set it to (presumably 0; confirm).
///
/// # Panics
/// Panics if `source` is too short for the layout. Via the final `assert!`, it also panics
/// if the number of consumed elements differs from `target.len()`. Elements of `source`
/// beyond that count are ignored.
pub fn vec2struct(source: &Scores, uses_cumulative_scores: bool) -> AlignfoldScores {
  let mut target = AlignfoldScores::new(0.);
  let mut offset = 0;
  // Loop-length groups: plain or cumulative variant depending on `uses_cumulative_scores`.
  let len = target.hairpin_scores_len.len();
  for i in 0..len {
    let x = source[offset + i];
    if uses_cumulative_scores {
      target.hairpin_scores_len_cumulative[i] = x;
    } else {
      target.hairpin_scores_len[i] = x;
    }
  }
  offset += len;
  let len = target.bulge_scores_len.len();
  for i in 0..len {
    let x = source[offset + i];
    if uses_cumulative_scores {
      target.bulge_scores_len_cumulative[i] = x;
    } else {
      target.bulge_scores_len[i] = x;
    }
  }
  offset += len;
  let len = target.interior_scores_len.len();
  for i in 0..len {
    let x = source[offset + i];
    if uses_cumulative_scores {
      target.interior_scores_len_cumulative[i] = x;
    } else {
      target.interior_scores_len[i] = x;
    }
  }
  offset += len;
  let len = target.interior_scores_symmetric.len();
  for i in 0..len {
    let x = source[offset + i];
    if uses_cumulative_scores {
      target.interior_scores_symmetric_cumulative[i] = x;
    } else {
      target.interior_scores_symmetric[i] = x;
    }
  }
  offset += len;
  let len = target.interior_scores_asymmetric.len();
  for i in 0..len {
    let x = source[offset + i];
    if uses_cumulative_scores {
      target.interior_scores_asymmetric_cumulative[i] = x;
    } else {
      target.interior_scores_asymmetric[i] = x;
    }
  }
  offset += len;
  // Stacking: canonical outer and inner pairs, representative entry only.
  let len = target.stack_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        for l in 0..len {
          if !has_canonical_basepair(&(k, l)) {
            continue;
          }
          let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
          if ((i, j), (k, l)) != dict_min_stack {
            continue;
          }
          target.stack_scores[i][j][k][l] = source[offset];
          offset += 1;
        }
      }
    }
  }
  // Terminal mismatch: canonical closing pair, every mismatch base combination.
  let len = target.terminal_mismatch_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        for l in 0..len {
          target.terminal_mismatch_scores[i][j][k][l] = source[offset];
          offset += 1;
        }
      }
    }
  }
  // Dangling ends (left, then right): canonical pair, every dangling base.
  let len = target.dangling_scores_left.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        target.dangling_scores_left[i][j][k] = source[offset];
        offset += 1;
      }
    }
  }
  let len = target.dangling_scores_right.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        target.dangling_scores_right[i][j][k] = source[offset];
        offset += 1;
      }
    }
  }
  // Helix closing: every canonical pair.
  let len = target.helix_close_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      target.helix_close_scores[i][j] = source[offset];
      offset += 1;
    }
  }
  // Base pairing: canonical pairs, representative entry only.
  let len = target.basepair_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      let dict_min_basepair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_basepair {
        continue;
      }
      target.basepair_scores[i][j] = source[offset];
      offset += 1;
    }
  }
  // Explicit interior-loop lengths: representative (i, j) entry only.
  let len = target.interior_scores_explicit.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_len_pair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_len_pair {
        continue;
      }
      target.interior_scores_explicit[i][j] = source[offset];
      offset += 1;
    }
  }
  // Bulge 0x1 (dense), then interior 1x1 (representative entry only).
  let len = target.bulge_scores_0x1.len();
  for i in 0..len {
    target.bulge_scores_0x1[i] = source[offset + i];
  }
  offset += len;
  let len = target.interior_scores_1x1.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_basepair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_basepair {
        continue;
      }
      target.interior_scores_1x1[i][j] = source[offset];
      offset += 1;
    }
  }
  // Multi-loop and external-loop scalars, then alignment transition scalars.
  target.multibranch_score_base = source[offset];
  offset += 1;
  target.multibranch_score_basepair = source[offset];
  offset += 1;
  target.multibranch_score_unpair = source[offset];
  offset += 1;
  target.external_score_basepair = source[offset];
  offset += 1;
  target.external_score_unpair = source[offset];
  offset += 1;
  target.match2match_score = source[offset];
  offset += 1;
  target.match2insert_score = source[offset];
  offset += 1;
  target.init_match_score = source[offset];
  offset += 1;
  target.insert_extend_score = source[offset];
  offset += 1;
  target.init_insert_score = source[offset];
  offset += 1;
  // Insert emissions (dense), then match emissions (representative entry only).
  let len = target.insert_scores.len();
  for i in 0..len {
    target.insert_scores[i] = source[offset + i];
  }
  offset += len;
  let len = target.match_scores.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_match = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_match {
        continue;
      }
      target.match_scores[i][j] = source[offset];
      offset += 1;
    }
  }
  // Sanity check: the traversal must consume exactly `AlignfoldScores::len()` elements.
  assert!(offset == target.len());
  target
}
5573
/// Flattens `source` into one parameter vector of length `source.len()`.
///
/// This is the mirror image of `vec2struct`: groups are visited in exactly the same order
/// with the same filters, so the two functions must be kept in sync.
///
/// * `uses_cumulative_scores` - when true, the five loop-length segments are read from the
///   `*_cumulative` fields; otherwise from the plain fields.
///
/// Pair-indexed tables only contribute entries at canonical base pairs (per
/// `has_canonical_basepair`). Tables filtered through `get_dict_min_stack` /
/// `get_dict_min_pair` only contribute their representative entry; presumably the other
/// entries of each class are tied to it (confirm).
///
/// # Panics
/// Panics if the traversal writes past the end of the buffer. Via the final `assert!`, it
/// also panics if the traversal does not end exactly at `source.len()`.
pub fn struct2vec(source: &AlignfoldScores, uses_cumulative_scores: bool) -> Scores {
  let mut target = vec![0.; source.len()];
  let mut offset = 0;
  // Loop-length groups: plain or cumulative variant depending on `uses_cumulative_scores`.
  let len = source.hairpin_scores_len.len();
  for i in 0..len {
    target[offset + i] = if uses_cumulative_scores {
      source.hairpin_scores_len_cumulative[i]
    } else {
      source.hairpin_scores_len[i]
    };
  }
  offset += len;
  let len = source.bulge_scores_len.len();
  for i in 0..len {
    target[offset + i] = if uses_cumulative_scores {
      source.bulge_scores_len_cumulative[i]
    } else {
      source.bulge_scores_len[i]
    };
  }
  offset += len;
  let len = source.interior_scores_len.len();
  for i in 0..len {
    target[offset + i] = if uses_cumulative_scores {
      source.interior_scores_len_cumulative[i]
    } else {
      source.interior_scores_len[i]
    };
  }
  offset += len;
  let len = source.interior_scores_symmetric.len();
  for i in 0..len {
    target[offset + i] = if uses_cumulative_scores {
      source.interior_scores_symmetric_cumulative[i]
    } else {
      source.interior_scores_symmetric[i]
    };
  }
  offset += len;
  let len = source.interior_scores_asymmetric.len();
  for i in 0..len {
    target[offset + i] = if uses_cumulative_scores {
      source.interior_scores_asymmetric_cumulative[i]
    } else {
      source.interior_scores_asymmetric[i]
    };
  }
  offset += len;
  // Stacking: canonical outer and inner pairs, representative entry only.
  let len = source.stack_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        for l in 0..len {
          if !has_canonical_basepair(&(k, l)) {
            continue;
          }
          let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
          if ((i, j), (k, l)) != dict_min_stack {
            continue;
          }
          target[offset] = source.stack_scores[i][j][k][l];
          offset += 1;
        }
      }
    }
  }
  // Terminal mismatch: canonical closing pair, every mismatch base combination.
  let len = source.terminal_mismatch_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        for l in 0..len {
          target[offset] = source.terminal_mismatch_scores[i][j][k][l];
          offset += 1;
        }
      }
    }
  }
  // Dangling ends (left, then right): canonical pair, every dangling base.
  let len = source.dangling_scores_left.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        target[offset] = source.dangling_scores_left[i][j][k];
        offset += 1;
      }
    }
  }
  let len = source.dangling_scores_right.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      for k in 0..len {
        target[offset] = source.dangling_scores_right[i][j][k];
        offset += 1;
      }
    }
  }
  // Helix closing: every canonical pair.
  let len = source.helix_close_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      target[offset] = source.helix_close_scores[i][j];
      offset += 1;
    }
  }
  // Base pairing: canonical pairs, representative entry only.
  let len = source.basepair_scores.len();
  for i in 0..len {
    for j in 0..len {
      if !has_canonical_basepair(&(i, j)) {
        continue;
      }
      let dict_min_basepair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_basepair {
        continue;
      }
      target[offset] = source.basepair_scores[i][j];
      offset += 1;
    }
  }
  // Explicit interior-loop lengths: representative (i, j) entry only.
  let len = source.interior_scores_explicit.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_len_pair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_len_pair {
        continue;
      }
      target[offset] = source.interior_scores_explicit[i][j];
      offset += 1;
    }
  }
  // Bulge 0x1 (dense), then interior 1x1 (representative entry only).
  let len = source.bulge_scores_0x1.len();
  for (i, &x) in source.bulge_scores_0x1.iter().enumerate() {
    target[offset + i] = x;
  }
  offset += len;
  let len = source.interior_scores_1x1.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_basepair = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_basepair {
        continue;
      }
      target[offset] = source.interior_scores_1x1[i][j];
      offset += 1;
    }
  }
  // Multi-loop and external-loop scalars, then alignment transition scalars.
  target[offset] = source.multibranch_score_base;
  offset += 1;
  target[offset] = source.multibranch_score_basepair;
  offset += 1;
  target[offset] = source.multibranch_score_unpair;
  offset += 1;
  target[offset] = source.external_score_basepair;
  offset += 1;
  target[offset] = source.external_score_unpair;
  offset += 1;
  target[offset] = source.match2match_score;
  offset += 1;
  target[offset] = source.match2insert_score;
  offset += 1;
  target[offset] = source.init_match_score;
  offset += 1;
  target[offset] = source.insert_extend_score;
  offset += 1;
  target[offset] = source.init_insert_score;
  offset += 1;
  // Insert emissions (dense), then match emissions (representative entry only).
  let len = source.insert_scores.len();
  for (i, &x) in source.insert_scores.iter().enumerate() {
    target[offset + i] = x;
  }
  offset += len;
  let len = source.match_scores.len();
  for i in 0..len {
    for j in 0..len {
      let dict_min_match = get_dict_min_pair(&(i, j));
      if (i, j) != dict_min_match {
        continue;
      }
      target[offset] = source.match_scores[i][j];
      offset += 1;
    }
  }
  // Sanity check: the traversal must fill exactly `AlignfoldScores::len()` slots.
  assert!(offset == source.len());
  Array::from(target)
}
5771
5772pub fn scores2bfgs_scores(x: &Scores) -> BfgsScores {
5773  let x: Vec<BfgsScore> = x.to_vec().iter().map(|x| *x as BfgsScore).collect();
5774  BfgsScores::from(x)
5775}
5776
5777pub fn bfgs_scores2scores(x: &BfgsScores) -> Scores {
5778  let x: Vec<Score> = x.to_vec().iter().map(|x| *x as Score).collect();
5779  Scores::from(x)
5780}
5781
5782pub fn get_regularizer(x: usize, y: Score) -> Regularizer {
5783  (x as Score / 2. + GAMMA_DISTRO_ALPHA) / (y / 2. + GAMMA_DISTRO_BETA)
5784}
5785
5786pub fn write_alignfold_scores_trained(alignfold_scores: &AlignfoldScores, enables_randinit: bool) {
5787  let mut writer = BufWriter::new(
5788    File::create(if enables_randinit {
5789      TRAINED_SCORES_FILE_RANDINIT
5790    } else {
5791      TRAINED_SCORES_FILE
5792    })
5793    .unwrap(),
5794  );
5795  let mut buf = format!("use AlignfoldScores;\nimpl AlignfoldScores {{\npub fn load_trained_scores{}() -> AlignfoldScores {{\nAlignfoldScores {{\nhairpin_scores_len: ", if enables_randinit {"_randinit"} else {""});
5796  buf.push_str(&format!(
5797    "{:?},\nbulge_scores_len: ",
5798    &alignfold_scores.hairpin_scores_len
5799  ));
5800  buf.push_str(&format!(
5801    "{:?},\ninterior_scores_len: ",
5802    &alignfold_scores.bulge_scores_len
5803  ));
5804  buf.push_str(&format!(
5805    "{:?},\ninterior_scores_symmetric: ",
5806    &alignfold_scores.interior_scores_len
5807  ));
5808  buf.push_str(&format!(
5809    "{:?},\ninterior_scores_asymmetric: ",
5810    &alignfold_scores.interior_scores_symmetric
5811  ));
5812  buf.push_str(&format!(
5813    "{:?},\nstack_scores: ",
5814    &alignfold_scores.interior_scores_asymmetric
5815  ));
5816  buf.push_str(&format!(
5817    "{:?},\nterminal_mismatch_scores: ",
5818    &alignfold_scores.stack_scores
5819  ));
5820  buf.push_str(&format!(
5821    "{:?},\ndangling_scores_left: ",
5822    &alignfold_scores.terminal_mismatch_scores
5823  ));
5824  buf.push_str(&format!(
5825    "{:?},\ndangling_scores_right: ",
5826    &alignfold_scores.dangling_scores_left
5827  ));
5828  buf.push_str(&format!(
5829    "{:?},\nhelix_close_scores: ",
5830    &alignfold_scores.dangling_scores_right
5831  ));
5832  buf.push_str(&format!(
5833    "{:?},\nbasepair_scores: ",
5834    &alignfold_scores.helix_close_scores
5835  ));
5836  buf.push_str(&format!(
5837    "{:?},\ninterior_scores_explicit: ",
5838    &alignfold_scores.basepair_scores
5839  ));
5840  buf.push_str(&format!(
5841    "{:?},\nbulge_scores_0x1: ",
5842    &alignfold_scores.interior_scores_explicit
5843  ));
5844  buf.push_str(&format!(
5845    "{:?},\ninterior_scores_1x1: ",
5846    &alignfold_scores.bulge_scores_0x1
5847  ));
5848  buf.push_str(&format!(
5849    "{:?},\nmultibranch_score_base: ",
5850    &alignfold_scores.interior_scores_1x1
5851  ));
5852  buf.push_str(&format!(
5853    "{:?},\nmultibranch_score_basepair: ",
5854    alignfold_scores.multibranch_score_base
5855  ));
5856  buf.push_str(&format!(
5857    "{:?},\nmultibranch_score_unpair: ",
5858    alignfold_scores.multibranch_score_basepair
5859  ));
5860  buf.push_str(&format!(
5861    "{:?},\nexternal_score_basepair: ",
5862    alignfold_scores.multibranch_score_unpair
5863  ));
5864  buf.push_str(&format!(
5865    "{:?},\nexternal_score_unpair: ",
5866    alignfold_scores.external_score_basepair
5867  ));
5868  buf.push_str(&format!(
5869    "{:?},\nmatch2match_score: ",
5870    alignfold_scores.external_score_unpair
5871  ));
5872  buf.push_str(&format!(
5873    "{:?},\nmatch2insert_score: ",
5874    alignfold_scores.match2match_score
5875  ));
5876  buf.push_str(&format!(
5877    "{:?},\ninsert_extend_score: ",
5878    alignfold_scores.match2insert_score
5879  ));
5880  buf.push_str(&format!(
5881    "{:?},\ninit_match_score: ",
5882    alignfold_scores.insert_extend_score
5883  ));
5884  buf.push_str(&format!(
5885    "{:?},\ninit_insert_score: ",
5886    alignfold_scores.init_match_score
5887  ));
5888  buf.push_str(&format!(
5889    "{:?},\ninsert_scores: ",
5890    alignfold_scores.init_insert_score
5891  ));
5892  buf.push_str(&format!(
5893    "{:?},\nmatch_scores: ",
5894    alignfold_scores.insert_scores
5895  ));
5896  buf.push_str(&format!(
5897    "{:?},\nhairpin_scores_len_cumulative: ",
5898    alignfold_scores.match_scores
5899  ));
5900  buf.push_str(&format!(
5901    "{:?},\nbulge_scores_len_cumulative: ",
5902    &alignfold_scores.hairpin_scores_len_cumulative
5903  ));
5904  buf.push_str(&format!(
5905    "{:?},\ninterior_scores_len_cumulative: ",
5906    &alignfold_scores.bulge_scores_len_cumulative
5907  ));
5908  buf.push_str(&format!(
5909    "{:?},\ninterior_scores_symmetric_cumulative: ",
5910    &alignfold_scores.interior_scores_len_cumulative
5911  ));
5912  buf.push_str(&format!(
5913    "{:?},\ninterior_scores_asymmetric_cumulative: ",
5914    &alignfold_scores.interior_scores_symmetric_cumulative
5915  ));
5916  buf.push_str(&format!(
5917    "{:?},",
5918    &alignfold_scores.interior_scores_asymmetric_cumulative
5919  ));
5920  buf.push_str(&String::from("\n}\n}\n}"));
5921  let _ = writer.write_all(buf.as_bytes());
5922}
5923
5924pub fn write_logs(x: &Probs, y: &Probs, z: &Path) {
5925  let mut z = BufWriter::new(File::create(z).unwrap());
5926  let mut a = String::new();
5927  for (x, y) in x.iter().zip(y.iter()) {
5928    a.push_str(&format!("{x},{y}\n"));
5929  }
5930  let _ = z.write_all(a.as_bytes());
5931}
5932
5933pub fn copy_alignfold_scores_align(x: &mut AlignScores, y: &AlignfoldScores) {
5934  x.match2match_score = y.match2match_score;
5935  x.match2insert_score = y.match2insert_score;
5936  x.insert_extend_score = y.insert_extend_score;
5937  x.init_match_score = y.init_match_score;
5938  x.init_insert_score = y.init_insert_score;
5939  let len = x.insert_scores.len();
5940  for i in 0..len {
5941    x.insert_scores[i] = y.insert_scores[i];
5942  }
5943  let len = x.match_scores.len();
5944  for i in 0..len {
5945    for j in 0..len {
5946      x.match_scores[i][j] = y.match_scores[i][j];
5947    }
5948  }
5949}
5950
5951pub fn copy_alignfold_scores_fold(
5952  fold_scores: &mut FoldScoreSets,
5953  alignfold_scores: &AlignfoldScores,
5954) {
5955  for (x, &y) in fold_scores
5956    .hairpin_scores_len
5957    .iter_mut()
5958    .zip(alignfold_scores.hairpin_scores_len.iter())
5959  {
5960    *x = y;
5961  }
5962  for (x, &y) in fold_scores
5963    .bulge_scores_len
5964    .iter_mut()
5965    .zip(alignfold_scores.bulge_scores_len.iter())
5966  {
5967    *x = y;
5968  }
5969  for (x, &y) in fold_scores
5970    .interior_scores_len
5971    .iter_mut()
5972    .zip(alignfold_scores.interior_scores_len.iter())
5973  {
5974    *x = y;
5975  }
5976  for (x, &y) in fold_scores
5977    .interior_scores_symmetric
5978    .iter_mut()
5979    .zip(alignfold_scores.interior_scores_symmetric.iter())
5980  {
5981    *x = y;
5982  }
5983  for (x, &y) in fold_scores
5984    .interior_scores_asymmetric
5985    .iter_mut()
5986    .zip(alignfold_scores.interior_scores_asymmetric.iter())
5987  {
5988    *x = y;
5989  }
5990  let len = fold_scores.stack_scores.len();
5991  for i in 0..len {
5992    for j in 0..len {
5993      if !has_canonical_basepair(&(i, j)) {
5994        continue;
5995      }
5996      for k in 0..len {
5997        for l in 0..len {
5998          if !has_canonical_basepair(&(k, l)) {
5999            continue;
6000          }
6001          fold_scores.stack_scores[i][j][k][l] = alignfold_scores.stack_scores[i][j][k][l];
6002        }
6003      }
6004    }
6005  }
6006  let len = fold_scores.terminal_mismatch_scores.len();
6007  for i in 0..len {
6008    for j in 0..len {
6009      if !has_canonical_basepair(&(i, j)) {
6010        continue;
6011      }
6012      for k in 0..len {
6013        for l in 0..len {
6014          fold_scores.terminal_mismatch_scores[i][j][k][l] =
6015            alignfold_scores.terminal_mismatch_scores[i][j][k][l];
6016        }
6017      }
6018    }
6019  }
6020  let len = fold_scores.dangling_scores_left.len();
6021  for i in 0..len {
6022    for j in 0..len {
6023      if !has_canonical_basepair(&(i, j)) {
6024        continue;
6025      }
6026      for k in 0..len {
6027        fold_scores.dangling_scores_left[i][j][k] = alignfold_scores.dangling_scores_left[i][j][k];
6028      }
6029    }
6030  }
6031  let len = fold_scores.dangling_scores_right.len();
6032  for i in 0..len {
6033    for j in 0..len {
6034      if !has_canonical_basepair(&(i, j)) {
6035        continue;
6036      }
6037      for k in 0..len {
6038        fold_scores.dangling_scores_right[i][j][k] =
6039          alignfold_scores.dangling_scores_right[i][j][k];
6040      }
6041    }
6042  }
6043  let len = fold_scores.helix_close_scores.len();
6044  for i in 0..len {
6045    for j in 0..len {
6046      if !has_canonical_basepair(&(i, j)) {
6047        continue;
6048      }
6049      fold_scores.helix_close_scores[i][j] = alignfold_scores.helix_close_scores[i][j];
6050    }
6051  }
6052  let len = fold_scores.basepair_scores.len();
6053  for i in 0..len {
6054    for j in 0..len {
6055      if !has_canonical_basepair(&(i, j)) {
6056        continue;
6057      }
6058      fold_scores.basepair_scores[i][j] = alignfold_scores.basepair_scores[i][j];
6059    }
6060  }
6061  let len = fold_scores.interior_scores_explicit.len();
6062  for i in 0..len {
6063    for j in 0..len {
6064      fold_scores.interior_scores_explicit[i][j] = alignfold_scores.interior_scores_explicit[i][j];
6065    }
6066  }
6067  for (x, &y) in fold_scores
6068    .bulge_scores_0x1
6069    .iter_mut()
6070    .zip(alignfold_scores.bulge_scores_0x1.iter())
6071  {
6072    *x = y;
6073  }
6074  let len = fold_scores.interior_scores_1x1.len();
6075  for i in 0..len {
6076    for j in 0..len {
6077      fold_scores.interior_scores_1x1[i][j] = alignfold_scores.interior_scores_1x1[i][j];
6078    }
6079  }
6080  fold_scores.multibranch_score_base = alignfold_scores.multibranch_score_base;
6081  fold_scores.multibranch_score_basepair = alignfold_scores.multibranch_score_basepair;
6082  fold_scores.multibranch_score_unpair = alignfold_scores.multibranch_score_unpair;
6083  fold_scores.external_score_basepair = alignfold_scores.external_score_basepair;
6084  fold_scores.external_score_unpair = alignfold_scores.external_score_unpair;
6085  fold_scores.accumulate();
6086}
6087
6088pub fn print_train_info(alignfold_scores: &AlignfoldScores) {
6089  let mut num_groups = 0;
6090  println!("Training the parameter groups below");
6091  println!("-----------------------------------");
6092  println!("Groups from the CONTRAfold model...");
6093  println!(
6094    "Hairpin loop length (group size {})",
6095    alignfold_scores.hairpin_scores_len.len()
6096  );
6097  num_groups += 1;
6098  println!(
6099    "Bulge loop length (group size {})",
6100    alignfold_scores.bulge_scores_len.len()
6101  );
6102  num_groups += 1;
6103  println!(
6104    "Interior loop length (group size {})",
6105    alignfold_scores.interior_scores_len.len()
6106  );
6107  num_groups += 1;
6108  println!(
6109    "Interior loop length symmetric (group size {})",
6110    alignfold_scores.interior_scores_symmetric.len()
6111  );
6112  num_groups += 1;
6113  println!(
6114    "Interior loop length asymmetric (group size {})",
6115    alignfold_scores.interior_scores_asymmetric.len()
6116  );
6117  num_groups += 1;
6118  let mut group_size = 0;
6119  let len = alignfold_scores.stack_scores.len();
6120  for i in 0..len {
6121    for j in 0..len {
6122      if !has_canonical_basepair(&(i, j)) {
6123        continue;
6124      }
6125      for k in 0..len {
6126        for l in 0..len {
6127          if !has_canonical_basepair(&(k, l)) {
6128            continue;
6129          }
6130          let dict_min_stack = get_dict_min_stack(&(i, j), &(k, l));
6131          if ((i, j), (k, l)) != dict_min_stack {
6132            continue;
6133          }
6134          group_size += 1;
6135        }
6136      }
6137    }
6138  }
6139  println!("Stacking (group size {group_size})");
6140  num_groups += 1;
6141  let mut group_size = 0;
6142  let len = alignfold_scores.terminal_mismatch_scores.len();
6143  for i in 0..len {
6144    for j in 0..len {
6145      if !has_canonical_basepair(&(i, j)) {
6146        continue;
6147      }
6148      for _ in 0..len {
6149        for _ in 0..len {
6150          group_size += 1;
6151        }
6152      }
6153    }
6154  }
6155  println!("Terminal mismatch (group size {group_size})");
6156  num_groups += 1;
6157  let mut group_size = 0;
6158  let len = alignfold_scores.dangling_scores_left.len();
6159  for i in 0..len {
6160    for j in 0..len {
6161      if !has_canonical_basepair(&(i, j)) {
6162        continue;
6163      }
6164      for _ in 0..len {
6165        group_size += 1;
6166      }
6167    }
6168  }
6169  let len = alignfold_scores.dangling_scores_right.len();
6170  for i in 0..len {
6171    for j in 0..len {
6172      if !has_canonical_basepair(&(i, j)) {
6173        continue;
6174      }
6175      for _ in 0..len {
6176        group_size += 1;
6177      }
6178    }
6179  }
6180  println!("Dangling (group size {group_size})");
6181  num_groups += 1;
6182  let mut group_size = 0;
6183  let len = alignfold_scores.helix_close_scores.len();
6184  for i in 0..len {
6185    for j in 0..len {
6186      if !has_canonical_basepair(&(i, j)) {
6187        continue;
6188      }
6189      group_size += 1;
6190    }
6191  }
6192  println!("Helix end (group size {group_size})");
6193  num_groups += 1;
6194  let mut group_size = 0;
6195  let len = alignfold_scores.basepair_scores.len();
6196  for i in 0..len {
6197    for j in 0..len {
6198      if !has_canonical_basepair(&(i, j)) {
6199        continue;
6200      }
6201      let dict_min_basepair = get_dict_min_pair(&(i, j));
6202      if (i, j) != dict_min_basepair {
6203        continue;
6204      }
6205      group_size += 1;
6206    }
6207  }
6208  println!("Base-pairing (group size {group_size})");
6209  num_groups += 1;
6210  let mut group_size = 0;
6211  let len = alignfold_scores.interior_scores_explicit.len();
6212  for i in 0..len {
6213    for j in 0..len {
6214      let dict_min_len_pair = get_dict_min_pair(&(i, j));
6215      if (i, j) != dict_min_len_pair {
6216        continue;
6217      }
6218      group_size += 1;
6219    }
6220  }
6221  println!("Interior loop length explicit (group size {group_size})");
6222  num_groups += 1;
6223  println!(
6224    "Bulge loop length 0x1 (group size {})",
6225    alignfold_scores.bulge_scores_0x1.len()
6226  );
6227  num_groups += 1;
6228  let mut group_size = 0;
6229  let len = alignfold_scores.interior_scores_1x1.len();
6230  for i in 0..len {
6231    for j in 0..len {
6232      let dict_min_basepair = get_dict_min_pair(&(i, j));
6233      if (i, j) != dict_min_basepair {
6234        continue;
6235      }
6236      group_size += 1;
6237    }
6238  }
6239  println!("Interior loop length 1x1 (group size {group_size})");
6240  num_groups += 1;
6241  println!("Multi-loop length (group size {GROUP_SIZE_MULTIBRANCH})");
6242  num_groups += 1;
6243  println!("External-loop length (group size {GROUP_SIZE_EXTERNAL})");
6244  num_groups += 1;
6245  println!("-----------------------------------");
6246  println!("Groups from the CONTRAlign model...");
6247  println!("Match transition (group size {GROUP_SIZE_MATCH_TRANSITION})");
6248  num_groups += 1;
6249  println!("Insert transition (group size {GROUP_SIZE_INSERT_TRANSITION})");
6250  num_groups += 1;
6251  println!(
6252    "Insert emission (group size {})",
6253    alignfold_scores.insert_scores.len()
6254  );
6255  num_groups += 1;
6256  let mut group_size = 0;
6257  let len = alignfold_scores.match_scores.len();
6258  for i in 0..len {
6259    for j in 0..len {
6260      let dict_min_match = get_dict_min_pair(&(i, j));
6261      if (i, j) != dict_min_match {
6262        continue;
6263      }
6264      group_size += 1;
6265    }
6266  }
6267  println!("Match emission (group size {group_size})");
6268  num_groups += 1;
6269  println!("-----------------------------------");
6270  println!(
6271    "Total # scoring parameters (from {} groups) to be trained: {}",
6272    num_groups,
6273    alignfold_scores.len()
6274  );
6275}