1use cyanea_core::{CyaneaError, Result};
13
/// Gas constant multiplied by a temperature to form the `R * T` term of
/// Boltzmann factors and of the large-loop extrapolation.
/// NOTE(review): 0.001987 matches the gas constant in kcal/(mol*K) — confirm units.
const R: f64 = 0.001987;

/// Temperature used by the large-loop extrapolation in the loop energy functions.
/// NOTE(review): presumably Kelvin (310.15 K = 37 °C) — confirm.
const DEFAULT_T: f64 = 310.15;

/// Multiloop penalty added (together with `ML_B`) when a pair closes a multiloop.
const ML_A: f64 = 3.4;
/// Multiloop penalty added for each helix that branches off a multiloop.
const ML_B: f64 = 0.4;
/// Multiloop penalty added for each unpaired base inside a multiloop.
const ML_C: f64 = 0.0;

/// Sentinel for "impossible" energies; callers compare against `INF / 2.0`
/// so that adding finite terms to the sentinel is still recognised.
const INF: f64 = 1e18;

/// Minimum number of unpaired bases a hairpin loop must contain.
const MIN_HAIRPIN: usize = 3;

/// Hairpin initiation energy indexed by the number of unpaired bases (0..=30).
/// Entries below `MIN_HAIRPIN` are placeholders: `hairpin_energy` rejects
/// those sizes before indexing.
/// NOTE(review): sizes 3..=11 repeat the triple 5.4, 5.6, 5.7 three times —
/// verify against the parameter table these values were taken from.
const HAIRPIN_INIT: [f64; 31] = [
    0.0, 0.0, 0.0, 5.4, 5.6, 5.7, 5.4, 5.6, 5.7, 5.4, 5.6, 5.7, 5.8, 5.9, 5.9, 6.0, 6.1, 6.1, 6.2,
    6.2, 6.3, 6.3, 6.3, 6.4, 6.4, 6.4, 6.5, 6.5, 6.5, 6.5, 6.6,
];

/// Internal-loop initiation energy indexed by the total number of unpaired
/// bases on both sides of the loop (0..=30).
const INTERNAL_INIT: [f64; 31] = [
    0.0, 0.0, 0.0, 0.0, 1.1, 2.0, 2.0, 2.1, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 2.9, 3.0, 3.1, 3.1,
    3.2, 3.2, 3.3, 3.3, 3.4, 3.4, 3.4, 3.5, 3.5, 3.5, 3.6, 3.6,
];

/// Bulge-loop initiation energy indexed by the number of unpaired bases in
/// the bulge (0..=30).
const BULGE_INIT: [f64; 31] = [
    0.0, 3.8, 2.8, 3.2, 3.6, 4.0, 4.4, 4.6, 4.7, 4.8, 4.9, 5.0, 5.1, 5.2, 5.3, 5.4, 5.4, 5.5, 5.5,
    5.6, 5.6, 5.7, 5.7, 5.8, 5.8, 5.8, 5.9, 5.9, 5.9, 6.0, 6.0,
];
61
/// Reports whether bases `a` and `b` (uppercase ASCII) can form a pair:
/// A-U, G-C or the G-U wobble, in either orientation.
fn can_pair(a: u8, b: u8) -> bool {
    match a {
        b'A' => b == b'U',
        b'C' => b == b'G',
        b'G' => b == b'C' || b == b'U',
        b'U' => b == b'A' || b == b'G',
        _ => false,
    }
}
74
/// Maps an ordered base pair to its row/column index in the stacking table:
/// AU=0, UA=1, GC=2, CG=3, GU=4, UG=5. Returns `None` for anything else.
fn pair_index(a: u8, b: u8) -> Option<usize> {
    const ORDER: [(u8, u8); 6] = [
        (b'A', b'U'),
        (b'U', b'A'),
        (b'G', b'C'),
        (b'C', b'G'),
        (b'G', b'U'),
        (b'U', b'G'),
    ];
    ORDER.iter().position(|&pair| pair == (a, b))
}
88
/// Stacking energies for two adjacent base pairs.
///
/// The row is the `pair_index` of the first pair passed to `stacking_energy`
/// and the column is the `pair_index` of the second pair; both axes use the
/// order AU, UA, GC, CG, GU, UG.
const STACKING: [[f64; 6]; 6] = [
    // first pair AU
    [-0.9, -1.1, -2.2, -2.1, -0.6, -1.4],
    // first pair UA
    [-1.3, -0.9, -2.4, -2.1, -1.0, -0.7],
    // first pair GC
    [-2.4, -2.1, -3.3, -2.4, -1.5, -1.5],
    // first pair CG
    [-2.1, -2.1, -2.4, -3.4, -1.4, -2.1],
    // first pair GU
    [-1.3, -1.0, -2.5, -1.5, -0.5, -1.3],
    // first pair UG
    [-1.0, -0.7, -1.5, -1.5, -0.3, -0.5],
];
106
107fn stacking_energy(i5: u8, j5: u8, i3: u8, j3: u8) -> f64 {
110 match (pair_index(i5, j5), pair_index(i3, j3)) {
111 (Some(a), Some(b)) => STACKING[a][b],
112 _ => INF,
113 }
114}
115
116fn hairpin_energy(seq: &[u8], i: usize, j: usize) -> f64 {
118 let size = j - i - 1;
119 if size < MIN_HAIRPIN {
120 return INF;
121 }
122 let init = if size <= 30 {
123 HAIRPIN_INIT[size]
124 } else {
125 HAIRPIN_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
126 };
127 let mismatch = if size >= 4 {
129 terminal_mismatch(seq[i], seq[j], seq[i + 1], seq[j - 1])
130 } else {
131 0.0
132 };
133 init + mismatch
134}
135
/// Simplified terminal-mismatch term based only on how many of the two
/// mismatched bases `ni` and `nj` are purines (A or G): -0.8 for two,
/// -0.4 for one, 0.0 for none. The closing-pair bases are ignored.
fn terminal_mismatch(_ci: u8, _cj: u8, ni: u8, nj: u8) -> f64 {
    let purines = [ni, nj]
        .iter()
        .filter(|&&base| matches!(base, b'A' | b'G'))
        .count();

    match purines {
        2 => -0.8,
        1 => -0.4,
        _ => 0.0,
    }
}
149
150fn internal_loop_energy(seq: &[u8], i: usize, j: usize, p: usize, q: usize) -> f64 {
152 let left = p - i - 1;
153 let right = j - q - 1;
154
155 if left == 0 && right == 0 {
156 return INF; }
158
159 if left == 0 || right == 0 {
161 let size = left + right;
163 let init = if size <= 30 {
164 BULGE_INIT[size]
165 } else {
166 BULGE_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
167 };
168 if size == 1 {
170 return init + stacking_energy(seq[i], seq[j], seq[p], seq[q]);
171 }
172 return init;
173 }
174
175 let size = left + right;
177 let init = if size <= 30 {
178 INTERNAL_INIT[size]
179 } else {
180 INTERNAL_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
181 };
182 let asymmetry = 0.3 * ((left as f64) - (right as f64)).abs();
184 let asymmetry = asymmetry.min(3.0); init + asymmetry
187}
188
/// An RNA secondary structure stored as a partner table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RnaSecondaryStructure {
    /// `pairs[i]` is `Some(j)` when position `i` is paired with position `j`
    /// and `None` when position `i` is unpaired.
    pub pairs: Vec<Option<usize>>,
    /// Number of positions in the structure.
    pub length: usize,
}
204
205impl RnaSecondaryStructure {
206 pub fn from_dot_bracket(s: &str) -> Result<Self> {
224 let n = s.len();
225 let mut pairs = vec![None; n];
226 let mut stack = Vec::new();
227
228 for (i, ch) in s.chars().enumerate() {
229 match ch {
230 '(' => stack.push(i),
231 ')' => {
232 let j = stack.pop().ok_or_else(|| {
233 CyaneaError::Parse("unmatched ')' in dot-bracket string".into())
234 })?;
235 pairs[j] = Some(i);
236 pairs[i] = Some(j);
237 }
238 '.' => {}
239 _ => {
240 return Err(CyaneaError::Parse(format!(
241 "invalid character '{}' in dot-bracket string",
242 ch
243 )));
244 }
245 }
246 }
247
248 if !stack.is_empty() {
249 return Err(CyaneaError::Parse("unmatched '(' in dot-bracket string".into()));
250 }
251
252 Ok(Self { pairs, length: n })
253 }
254
255 pub fn to_dot_bracket(&self) -> String {
266 let mut out = vec!['.'; self.length];
267 for (i, partner) in self.pairs.iter().enumerate() {
268 if let Some(j) = partner {
269 if i < *j {
270 out[i] = '(';
271 out[*j] = ')';
272 }
273 }
274 }
275 out.into_iter().collect()
276 }
277
278 pub fn base_pairs(&self) -> Vec<(usize, usize)> {
280 let mut bps: Vec<(usize, usize)> = self
281 .pairs
282 .iter()
283 .enumerate()
284 .filter_map(|(i, p)| p.map(|j| (i, j)))
285 .filter(|(i, j)| i < j)
286 .collect();
287 bps.sort();
288 bps
289 }
290
291 pub fn is_paired(&self, i: usize) -> bool {
293 i < self.length && self.pairs[i].is_some()
294 }
295
296 pub fn partner(&self, i: usize) -> Option<usize> {
298 if i < self.length {
299 self.pairs[i]
300 } else {
301 None
302 }
303 }
304
305 pub fn num_pairs(&self) -> usize {
307 self.pairs.iter().filter(|p| p.is_some()).count() / 2
308 }
309}
310
/// Result of the [`nussinov`] base-pair maximisation.
#[derive(Debug, Clone)]
pub struct NussinovResult {
    /// The structure recovered by the traceback.
    pub structure: RnaSecondaryStructure,
    /// Value of the dynamic-programming table for the whole sequence,
    /// i.e. the maximum number of pairs found.
    pub max_pairs: usize,
}
321
322pub fn nussinov(seq: &[u8], min_loop_size: usize) -> Result<NussinovResult> {
340 let seq = normalize_rna(seq)?;
341 let n = seq.len();
342
343 if n == 0 {
344 return Err(CyaneaError::InvalidInput("empty sequence".into()));
345 }
346
347 let mut m = vec![0i32; n * n];
349 let idx = |i: usize, j: usize| i * n + j;
350
351 for len in 2..=n {
353 for i in 0..=n - len {
354 let j = i + len - 1;
355 let mut best = if i + 1 <= j { m[idx(i + 1, j)] } else { 0 };
357 if j > 0 {
359 best = best.max(m[idx(i, j - 1)]);
360 }
361 if can_pair(seq[i], seq[j]) && j - i > min_loop_size {
363 let inner = if i + 1 <= j.saturating_sub(1) {
364 m[idx(i + 1, j - 1)]
365 } else {
366 0
367 };
368 best = best.max(inner + 1);
369 }
370 for k in (i + 1)..j {
372 best = best.max(m[idx(i, k)] + m[idx(k + 1, j)]);
373 }
374 m[idx(i, j)] = best;
375 }
376 }
377
378 let mut pairs = vec![None; n];
380 nussinov_traceback(&seq, &m, n, min_loop_size, 0, n - 1, &mut pairs);
381
382 let max_pairs = m[idx(0, n - 1)] as usize;
383 Ok(NussinovResult {
384 structure: RnaSecondaryStructure {
385 pairs,
386 length: n,
387 },
388 max_pairs,
389 })
390}
391
392fn nussinov_traceback(
393 seq: &[u8],
394 m: &[i32],
395 n: usize,
396 min_loop_size: usize,
397 i: usize,
398 j: usize,
399 pairs: &mut [Option<usize>],
400) {
401 if i >= j || (j - i) < 1 {
402 return;
403 }
404 let idx = |a: usize, b: usize| a * n + b;
405 let val = m[idx(i, j)];
406
407 if i + 1 <= j && m[idx(i + 1, j)] == val {
409 nussinov_traceback(seq, m, n, min_loop_size, i + 1, j, pairs);
410 return;
411 }
412
413 if can_pair(seq[i], seq[j]) && j - i > min_loop_size {
415 let inner = if i + 1 <= j.saturating_sub(1) {
416 m[idx(i + 1, j - 1)]
417 } else {
418 0
419 };
420 if inner + 1 == val {
421 pairs[i] = Some(j);
422 pairs[j] = Some(i);
423 if j > 0 && i + 1 < j {
424 nussinov_traceback(seq, m, n, min_loop_size, i + 1, j - 1, pairs);
425 }
426 return;
427 }
428 }
429
430 for k in (i + 1)..j {
432 if m[idx(i, k)] + m[idx(k + 1, j)] == val {
433 nussinov_traceback(seq, m, n, min_loop_size, i, k, pairs);
434 nussinov_traceback(seq, m, n, min_loop_size, k + 1, j, pairs);
435 return;
436 }
437 }
438
439 if j > 0 {
441 nussinov_traceback(seq, m, n, min_loop_size, i, j - 1, pairs);
442 }
443}
444
/// Result of the [`zuker_mfe`] minimum-free-energy prediction.
#[derive(Debug, Clone)]
pub struct MfeResult {
    /// The structure recovered by the traceback.
    pub structure: RnaSecondaryStructure,
    /// Minimum energy found for the whole sequence (never positive, because
    /// the unstructured state scores 0).
    pub energy: f64,
}
455
/// Predicts a minimum-free-energy structure with a Zuker-style dynamic
/// program over three `n x n` tables:
///
/// * `v[i][j]`  — best energy of `[i, j]` given that `i` pairs with `j`,
/// * `wm[i][j]` — best energy of `[i, j]` as part of a multiloop (at least
///   one branch),
/// * `w[i][j]`  — best energy of `[i, j]` with no constraint (at most 0).
///
/// Sequences shorter than 5 bases are returned unstructured with energy 0.
///
/// # Errors
///
/// Returns `CyaneaError::InvalidInput` for an empty sequence or when the
/// sequence contains a character `normalize_rna` rejects.
pub fn zuker_mfe(seq: &[u8]) -> Result<MfeResult> {
    let seq = normalize_rna(seq)?;
    let n = seq.len();

    if n == 0 {
        return Err(CyaneaError::InvalidInput("empty sequence".into()));
    }
    if n < 5 {
        // Too short to close a hairpin of MIN_HAIRPIN bases.
        return Ok(MfeResult {
            structure: RnaSecondaryStructure {
                pairs: vec![None; n],
                length: n,
            },
            energy: 0.0,
        });
    }

    // Row-major index into the flattened n x n tables.
    let idx = |i: usize, j: usize| i * n + j;

    // v and wm start "impossible"; w starts at 0 (the unstructured state).
    let mut v = vec![INF; n * n];
    let mut w = vec![0.0_f64; n * n];
    let mut wm = vec![INF; n * n];

    // Fill by increasing interval length so sub-intervals are ready.
    for len in 2..=n {
        for i in 0..=n - len {
            let j = i + len - 1;

            // --- v: (i, j) is a pair -------------------------------------
            if can_pair(seq[i], seq[j]) && j - i > MIN_HAIRPIN {
                let mut best_v = INF;

                // (i, j) closes a hairpin.
                best_v = best_v.min(hairpin_energy(&seq, i, j));

                // (i, j) stacks directly on (i + 1, j - 1). Both branches
                // apply the same update; the second additionally skips an
                // impossible inner pair.
                if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) && j - 1 - (i + 1) >= MIN_HAIRPIN {
                    let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
                    best_v = best_v.min(v[idx(i + 1, j - 1)] + stack);
                } else if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) {
                    let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
                    if v[idx(i + 1, j - 1)] < INF / 2.0 {
                        best_v = best_v.min(v[idx(i + 1, j - 1)] + stack);
                    }
                }

                // (i, j) closes a bulge or internal loop with inner pair
                // (p, q); the search window is capped at 30 positions per side.
                let max_left = (j - i - 1).min(30);
                for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
                    let max_right = (j - i - 1 - (p - i - 1)).min(30);
                    let q_min = if p + MIN_HAIRPIN + 1 > j {
                        // No room left for an inner pair starting at p.
                        continue;
                    } else {
                        (j - max_right).max(p + MIN_HAIRPIN + 1)
                    };
                    for q in q_min..j {
                        if !can_pair(seq[p], seq[q]) {
                            continue;
                        }
                        // The directly stacked pair was handled above.
                        if p == i + 1 && q == j - 1 {
                            continue;
                        }
                        if v[idx(p, q)] >= INF / 2.0 {
                            continue;
                        }
                        let il_e = internal_loop_energy(&seq, i, j, p, q);
                        best_v = best_v.min(v[idx(p, q)] + il_e);
                    }
                }

                // (i, j) closes a multiloop over the enclosed interval.
                if j > i + 2 && wm[idx(i + 1, j - 1)] < INF / 2.0 {
                    best_v = best_v.min(wm[idx(i + 1, j - 1)] + ML_A + ML_B);
                }

                v[idx(i, j)] = best_v;
            }

            // --- wm: interval inside a multiloop --------------------------
            {
                let mut best_wm = INF;

                // i is an unpaired multiloop base.
                if i + 1 <= j && wm[idx(i + 1, j)] < INF / 2.0 {
                    best_wm = best_wm.min(wm[idx(i + 1, j)] + ML_C);
                }

                // j is an unpaired multiloop base.
                if j > 0 && i <= j - 1 && wm[idx(i, j - 1)] < INF / 2.0 {
                    best_wm = best_wm.min(wm[idx(i, j - 1)] + ML_C);
                }

                // The whole interval is a single branch closed by (i, j).
                if v[idx(i, j)] < INF / 2.0 {
                    best_wm = best_wm.min(v[idx(i, j)] + ML_B);
                }

                // Two adjacent multiloop segments.
                for k in (i + 1)..j {
                    if wm[idx(i, k)] < INF / 2.0 && wm[idx(k + 1, j)] < INF / 2.0 {
                        best_wm = best_wm.min(wm[idx(i, k)] + wm[idx(k + 1, j)]);
                    }
                }

                wm[idx(i, j)] = best_wm;
            }

            // --- w: unconstrained interval --------------------------------
            {
                // 0.0 is the energy of leaving the interval unstructured.
                let mut best_w: f64 = 0.0;
                if i + 1 <= j {
                    best_w = best_w.min(w[idx(i + 1, j)]);
                }

                if j > 0 && i <= j - 1 {
                    best_w = best_w.min(w[idx(i, j - 1)]);
                }

                if v[idx(i, j)] < INF / 2.0 {
                    best_w = best_w.min(v[idx(i, j)]);
                }

                for k in (i + 1)..j {
                    best_w = best_w.min(w[idx(i, k)] + w[idx(k + 1, j)]);
                }

                w[idx(i, j)] = best_w;
            }
        }
    }

    let energy = w[idx(0, n - 1)];
    // Defensive: w never exceeds 0, so the sentinel branch is not expected.
    let energy = if energy >= INF / 2.0 { 0.0 } else { energy };

    let mut pairs = vec![None; n];
    zuker_traceback_w(&seq, &v, &w, &wm, n, 0, n - 1, &mut pairs);

    Ok(MfeResult {
        structure: RnaSecondaryStructure {
            pairs,
            length: n,
        },
        energy,
    })
}
630
631fn zuker_traceback_w(
632 seq: &[u8],
633 v: &[f64],
634 w: &[f64],
635 wm: &[f64],
636 n: usize,
637 i: usize,
638 j: usize,
639 pairs: &mut [Option<usize>],
640) {
641 if i >= j {
642 return;
643 }
644 let idx = |a: usize, b: usize| a * n + b;
645 let val = w[idx(i, j)];
646 let eps = 1e-9;
647
648 if val.abs() < eps {
650 return;
651 }
652
653 if v[idx(i, j)] < INF / 2.0 && (v[idx(i, j)] - val).abs() < eps {
655 pairs[i] = Some(j);
656 pairs[j] = Some(i);
657 zuker_traceback_v(seq, v, w, wm, n, i, j, pairs);
658 return;
659 }
660
661 if i + 1 <= j && (w[idx(i + 1, j)] - val).abs() < eps {
663 zuker_traceback_w(seq, v, w, wm, n, i + 1, j, pairs);
664 return;
665 }
666
667 if j > 0 && i <= j - 1 && (w[idx(i, j - 1)] - val).abs() < eps {
669 zuker_traceback_w(seq, v, w, wm, n, i, j - 1, pairs);
670 return;
671 }
672
673 for k in (i + 1)..j {
675 if (w[idx(i, k)] + w[idx(k + 1, j)] - val).abs() < eps {
676 zuker_traceback_w(seq, v, w, wm, n, i, k, pairs);
677 zuker_traceback_w(seq, v, w, wm, n, k + 1, j, pairs);
678 return;
679 }
680 }
681}
682
/// Traces back the paired table `v` for a pair `(i, j)` that the caller has
/// already recorded in `pairs`.
///
/// Candidates are tried in the same order the fill considers them and the
/// first whose energy matches `v[i][j]` within 1e-9 is followed: hairpin,
/// stack on `(i + 1, j - 1)`, bulge/internal loop with inner pair `(p, q)`,
/// and finally a multiloop (delegated to `zuker_traceback_wm`). If nothing
/// matches, the function returns without recording further pairs.
fn zuker_traceback_v(
    seq: &[u8],
    v: &[f64],
    w: &[f64],
    wm: &[f64],
    n: usize,
    i: usize,
    j: usize,
    pairs: &mut [Option<usize>],
) {
    // Row-major index into the flattened n x n tables.
    let idx = |a: usize, b: usize| a * n + b;
    let val = v[idx(i, j)];
    let eps = 1e-9;

    // Hairpin: nothing is paired inside (i, j).
    if (hairpin_energy(seq, i, j) - val).abs() < eps {
        return;
    }

    // Stack: (i + 1, j - 1) is paired directly inside (i, j).
    if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) && v[idx(i + 1, j - 1)] < INF / 2.0 {
        let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
        if (v[idx(i + 1, j - 1)] + stack - val).abs() < eps {
            pairs[i + 1] = Some(j - 1);
            pairs[j - 1] = Some(i + 1);
            zuker_traceback_v(seq, v, w, wm, n, i + 1, j - 1, pairs);
            return;
        }
    }

    // Bulge / internal loop: same 30-position search window as the fill.
    let max_left = (j - i - 1).min(30);
    for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
        let max_right = (j - i - 1 - (p - i - 1)).min(30);
        let q_min_val = (j.saturating_sub(max_right)).max(p + MIN_HAIRPIN + 1);
        // The range is empty when q_min_val >= j.
        for q in q_min_val..j {
            if !can_pair(seq[p], seq[q]) || v[idx(p, q)] >= INF / 2.0 {
                continue;
            }
            // The directly stacked pair was handled above.
            if p == i + 1 && q == j - 1 {
                continue;
            }
            let il_e = internal_loop_energy(seq, i, j, p, q);
            if (v[idx(p, q)] + il_e - val).abs() < eps {
                pairs[p] = Some(q);
                pairs[q] = Some(p);
                zuker_traceback_v(seq, v, w, wm, n, p, q, pairs);
                return;
            }
        }
    }

    // Multiloop: the branches are recovered from the wm table.
    if j > i + 2 && wm[idx(i + 1, j - 1)] < INF / 2.0 {
        if (wm[idx(i + 1, j - 1)] + ML_A + ML_B - val).abs() < eps {
            zuker_traceback_wm(seq, v, w, wm, n, i + 1, j - 1, pairs);
        }
    }
}
742
743fn zuker_traceback_wm(
744 seq: &[u8],
745 v: &[f64],
746 w: &[f64],
747 wm: &[f64],
748 n: usize,
749 i: usize,
750 j: usize,
751 pairs: &mut [Option<usize>],
752) {
753 if i >= j {
754 return;
755 }
756 let idx = |a: usize, b: usize| a * n + b;
757 let val = wm[idx(i, j)];
758 let eps = 1e-9;
759
760 if val >= INF / 2.0 {
761 return;
762 }
763
764 if v[idx(i, j)] < INF / 2.0 && (v[idx(i, j)] + ML_B - val).abs() < eps {
766 pairs[i] = Some(j);
767 pairs[j] = Some(i);
768 zuker_traceback_v(seq, v, w, wm, n, i, j, pairs);
769 return;
770 }
771
772 if i + 1 <= j && wm[idx(i + 1, j)] < INF / 2.0 && (wm[idx(i + 1, j)] + ML_C - val).abs() < eps
774 {
775 zuker_traceback_wm(seq, v, w, wm, n, i + 1, j, pairs);
776 return;
777 }
778
779 if j > 0 && i <= j - 1 && wm[idx(i, j - 1)] < INF / 2.0
781 && (wm[idx(i, j - 1)] + ML_C - val).abs() < eps
782 {
783 zuker_traceback_wm(seq, v, w, wm, n, i, j - 1, pairs);
784 return;
785 }
786
787 for k in (i + 1)..j {
789 if wm[idx(i, k)] < INF / 2.0
790 && wm[idx(k + 1, j)] < INF / 2.0
791 && (wm[idx(i, k)] + wm[idx(k + 1, j)] - val).abs() < eps
792 {
793 zuker_traceback_wm(seq, v, w, wm, n, i, k, pairs);
794 zuker_traceback_wm(seq, v, w, wm, n, k + 1, j, pairs);
795 return;
796 }
797 }
798}
799
/// Result of the [`mccaskill`] partition-function computation.
#[derive(Debug, Clone)]
pub struct PartitionResult {
    /// Row-major `length x length` matrix; entry `i * length + j` is the
    /// probability stored for the pair `(i, j)`.
    pub pair_probabilities: Vec<f64>,
    /// Sequence length (the side of the probability matrix).
    pub length: usize,
    /// Ensemble energy derived from the partition function.
    pub ensemble_energy: f64,
}

impl PartitionResult {
    /// Probability stored for the pair `(i, j)`; `0.0` when either index is
    /// outside the sequence.
    pub fn pair_probability(&self, i: usize, j: usize) -> f64 {
        let n = self.length;
        if i < n && j < n {
            self.pair_probabilities[i * n + j]
        } else {
            0.0
        }
    }

    /// One minus the summed pair probabilities of row `i`, floored at `0.0`;
    /// `0.0` when `i` is outside the sequence.
    pub fn unpaired_probability(&self, i: usize) -> f64 {
        if i >= self.length {
            return 0.0;
        }
        let row_start = i * self.length;
        let row = &self.pair_probabilities[row_start..row_start + self.length];
        let paired: f64 = row.iter().sum();
        (1.0 - paired).max(0.0)
    }
}
833
834pub fn mccaskill(seq: &[u8], temperature: f64) -> Result<PartitionResult> {
857 let seq = normalize_rna(seq)?;
858 let n = seq.len();
859
860 if n == 0 {
861 return Err(CyaneaError::InvalidInput("empty sequence".into()));
862 }
863 if temperature <= 0.0 {
864 return Err(CyaneaError::InvalidInput(
865 "temperature must be positive".into(),
866 ));
867 }
868
869 let rt = R * temperature;
870
871 if n < 5 {
872 return Ok(PartitionResult {
873 pair_probabilities: vec![0.0; n * n],
874 length: n,
875 ensemble_energy: 0.0,
876 });
877 }
878
879 let idx = |i: usize, j: usize| i * n + j;
880 let boltz = |e: f64| -> f64 {
881 if e >= INF / 2.0 {
882 0.0
883 } else {
884 (-e / rt).exp()
885 }
886 };
887
888 let mut q = vec![0.0_f64; n * n];
892 let mut qb = vec![0.0_f64; n * n];
893 let mut qm = vec![0.0_f64; n * n];
894
895 for i in 0..n {
897 q[idx(i, i)] = 1.0;
898 if i + 1 < n {
899 q[idx(i + 1, i)] = 1.0; }
901 }
902
903 for len in 2..=n {
905 for i in 0..=n - len {
906 let j = i + len - 1;
907
908 if can_pair(seq[i], seq[j]) && j - i > MIN_HAIRPIN {
910 let mut qb_val = 0.0;
911
912 qb_val += boltz(hairpin_energy(&seq, i, j));
914
915 if i + 1 < j && can_pair(seq[i + 1], seq[j - 1]) {
917 let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
918 qb_val += qb[idx(i + 1, j - 1)] * boltz(stack);
919 }
920
921 let max_left = (j - i - 1).min(30);
923 for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
924 let max_right = (j - i - 1 - (p - i - 1)).min(30);
925 let q_min = (j.saturating_sub(max_right)).max(p + MIN_HAIRPIN + 1);
926 for qi in q_min..j {
927 if !can_pair(seq[p], seq[qi]) {
928 continue;
929 }
930 if p == i + 1 && qi == j - 1 {
931 continue; }
933 let il_e = internal_loop_energy(&seq, i, j, p, qi);
934 qb_val += qb[idx(p, qi)] * boltz(il_e);
935 }
936 }
937
938 if j > i + 2 {
940 qb_val += qm[idx(i + 1, j - 1)] * boltz(ML_A + ML_B);
941 }
942
943 qb[idx(i, j)] = qb_val;
944 }
945
946 {
948 let mut qm_val = 0.0;
949
950 if i + 1 <= j {
952 qm_val += qm[idx(i + 1, j)] * boltz(ML_C);
953 }
954
955 if j > 0 && i <= j - 1 {
957 qm_val += qm[idx(i, j - 1)] * boltz(ML_C);
958 }
959
960 if qb[idx(i, j)] > 0.0 {
962 qm_val += qb[idx(i, j)] * boltz(ML_B);
963 }
964
965 for k in (i + 1)..j {
967 qm_val += qm[idx(i, k)] * qm[idx(k + 1, j)];
968 }
969
970 qm[idx(i, j)] = qm_val;
971 }
972
973 {
975 let mut q_val = 1.0; for d in i..=j {
978 for e in (d + MIN_HAIRPIN + 1)..=j {
979 if !can_pair(seq[d], seq[e]) || qb[idx(d, e)] == 0.0 {
980 continue;
981 }
982 let q_left = if d > i { q[idx(i, d - 1)] } else { 1.0 };
983 let q_right = if e < j { q[idx(e + 1, j)] } else { 1.0 };
984 q_val += q_left * qb[idx(d, e)] * q_right;
985 }
986 }
987
988 q[idx(i, j)] = q_val;
989 }
990 }
991 }
992
993 let z = q[idx(0, n - 1)];
994 let ensemble_energy = if z > 0.0 { -rt * z.ln() } else { 0.0 };
995
996 let mut prob = vec![0.0_f64; n * n];
998
999 if z > 0.0 {
1000 for i in 0..n {
1001 for j in (i + MIN_HAIRPIN + 1)..n {
1002 if qb[idx(i, j)] == 0.0 {
1003 continue;
1004 }
1005 let q_left = if i > 0 { q[idx(0, i - 1)] } else { 1.0 };
1006 let q_right = if j < n - 1 { q[idx(j + 1, n - 1)] } else { 1.0 };
1007 let p_ij = q_left * qb[idx(i, j)] * q_right / z;
1008 let p_ij = p_ij.min(1.0).max(0.0);
1009 prob[idx(i, j)] = p_ij;
1010 prob[idx(j, i)] = p_ij;
1011 }
1012 }
1013 }
1014
1015 Ok(PartitionResult {
1016 pair_probabilities: prob,
1017 length: n,
1018 ensemble_energy,
1019 })
1020}
1021
1022pub fn base_pair_distance(
1043 a: &RnaSecondaryStructure,
1044 b: &RnaSecondaryStructure,
1045) -> Result<usize> {
1046 if a.length != b.length {
1047 return Err(CyaneaError::InvalidInput(format!(
1048 "structure lengths differ: {} vs {}",
1049 a.length, b.length
1050 )));
1051 }
1052
1053 let bp_a: std::collections::HashSet<(usize, usize)> = a.base_pairs().into_iter().collect();
1054 let bp_b: std::collections::HashSet<(usize, usize)> = b.base_pairs().into_iter().collect();
1055
1056 let only_a = bp_a.difference(&bp_b).count();
1057 let only_b = bp_b.difference(&bp_a).count();
1058 Ok(only_a + only_b)
1059}
1060
1061pub fn mountain_distance(
1081 a: &RnaSecondaryStructure,
1082 b: &RnaSecondaryStructure,
1083) -> Result<f64> {
1084 if a.length != b.length {
1085 return Err(CyaneaError::InvalidInput(format!(
1086 "structure lengths differ: {} vs {}",
1087 a.length, b.length
1088 )));
1089 }
1090
1091 let ma = mountain_vector(a);
1092 let mb = mountain_vector(b);
1093
1094 let dist: f64 = ma
1095 .iter()
1096 .zip(mb.iter())
1097 .map(|(x, y)| (*x as f64 - *y as f64).abs())
1098 .sum();
1099 Ok(dist)
1100}
1101
1102fn mountain_vector(s: &RnaSecondaryStructure) -> Vec<i32> {
1104 let mut m = vec![0i32; s.length];
1105 let mut depth = 0i32;
1106 for i in 0..s.length {
1107 if let Some(j) = s.pairs[i] {
1108 if i < j {
1109 depth += 1;
1110 } else {
1111 depth -= 1;
1112 }
1113 }
1114 m[i] = depth;
1115 }
1116 m
1117}
1118
1119fn normalize_rna(seq: &[u8]) -> Result<Vec<u8>> {
1123 seq.iter()
1124 .map(|&b| match b {
1125 b'A' | b'a' => Ok(b'A'),
1126 b'U' | b'u' => Ok(b'U'),
1127 b'G' | b'g' => Ok(b'G'),
1128 b'C' | b'c' => Ok(b'C'),
1129 b'T' | b't' => Ok(b'U'), _ => Err(CyaneaError::InvalidInput(format!(
1131 "invalid nucleotide '{}' in RNA sequence",
1132 b as char
1133 ))),
1134 })
1135 .collect()
1136}
1137
/// Unit tests for dot-bracket handling, the three folding algorithms, the
/// structure distances and the private helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- dot-bracket parsing and rendering ---------------------------------

    #[test]
    fn dot_bracket_simple() {
        let s = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert_eq!(s.length, 9);
        assert_eq!(s.num_pairs(), 3);
        assert_eq!(s.pairs[0], Some(8));
        assert_eq!(s.pairs[8], Some(0));
        assert!(s.is_paired(0));
        assert!(!s.is_paired(3));
    }

    #[test]
    fn dot_bracket_with_unpaired() {
        let s = RnaSecondaryStructure::from_dot_bracket("..((..))..").unwrap();
        assert_eq!(s.length, 10);
        assert_eq!(s.num_pairs(), 2);
        assert!(!s.is_paired(0));
        assert!(s.is_paired(2));
        assert_eq!(s.partner(2), Some(7));
    }

    #[test]
    fn dot_bracket_roundtrip() {
        let input = "(((..((.....))...)))";
        let s = RnaSecondaryStructure::from_dot_bracket(input).unwrap();
        assert_eq!(s.to_dot_bracket(), input);

        // Parsing the rendered string must give the same partner table.
        let s2 = RnaSecondaryStructure::from_dot_bracket(&s.to_dot_bracket()).unwrap();
        assert_eq!(s.pairs, s2.pairs);
    }

    #[test]
    fn dot_bracket_base_pairs() {
        let s = RnaSecondaryStructure::from_dot_bracket("((.()))").unwrap();
        let bps = s.base_pairs();
        assert_eq!(bps.len(), 3);
        // Every reported pair is oriented 5' -> 3'.
        for (i, j) in &bps {
            assert!(i < j);
        }
    }

    #[test]
    fn dot_bracket_unmatched_open() {
        assert!(RnaSecondaryStructure::from_dot_bracket("((...)))(").is_err());
    }

    #[test]
    fn dot_bracket_unmatched_close() {
        assert!(RnaSecondaryStructure::from_dot_bracket(")((..))").is_err());
    }

    #[test]
    fn dot_bracket_invalid_char() {
        assert!(RnaSecondaryStructure::from_dot_bracket("((..x..))").is_err());
    }

    #[test]
    fn dot_bracket_empty() {
        let s = RnaSecondaryStructure::from_dot_bracket("").unwrap();
        assert_eq!(s.length, 0);
        assert_eq!(s.num_pairs(), 0);
    }

    // ---- Nussinov ----------------------------------------------------------

    #[test]
    fn nussinov_gcaucg() {
        let r = nussinov(b"GCAUCG", 3).unwrap();
        assert!(r.max_pairs >= 1);
    }

    #[test]
    fn nussinov_perfect_stem() {
        let r = nussinov(b"GGGGCCCC", 3).unwrap();
        assert!(r.max_pairs >= 2);
    }

    #[test]
    fn nussinov_no_pairs() {
        let r = nussinov(b"AAAAAA", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
        assert_eq!(r.structure.num_pairs(), 0);
    }

    #[test]
    fn nussinov_min_loop_enforced() {
        // Invalid nucleotides are rejected rather than folded.
        assert!(nussinov(b"AXXU", 3).is_err());

        // A and U could pair, but only two bases separate them, which is
        // below the requested minimum loop size of 3.
        let r = nussinov(b"AGCU", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
    }

    #[test]
    fn nussinov_short_sequence() {
        let r = nussinov(b"AUG", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
    }

    #[test]
    fn nussinov_empty() {
        assert!(nussinov(b"", 3).is_err());
    }

    #[test]
    fn nussinov_structure_valid() {
        let r = nussinov(b"GGGAAACCC", 3).unwrap();
        let bps = r.structure.base_pairs();
        // Any two pairs must be either disjoint or nested, never crossing.
        for (idx_a, &(i1, j1)) in bps.iter().enumerate() {
            for &(i2, j2) in bps.iter().skip(idx_a + 1) {
                assert!(j1 <= i2 || i2 >= i1 && j2 <= j1,
                    "crossing pairs: ({},{}) and ({},{})", i1, j1, i2, j2);
            }
        }
    }

    #[test]
    fn nussinov_lowercase_and_dna() {
        let r = nussinov(b"gggaaaccc", 3).unwrap();
        assert!(r.max_pairs > 0);

        // T is treated as U, so both inputs must fold identically.
        let r2 = nussinov(b"GGGAAATCC", 3).unwrap();
        let r3 = nussinov(b"GGGAAAUCC", 3).unwrap();
        assert_eq!(r2.max_pairs, r3.max_pairs);
    }

    // ---- Zuker MFE ---------------------------------------------------------

    #[test]
    fn zuker_simple_hairpin() {
        let r = zuker_mfe(b"GGGAAACCC").unwrap();
        assert!(r.energy < 0.0, "energy should be negative, got {}", r.energy);
        assert!(r.structure.num_pairs() > 0);
    }

    #[test]
    fn zuker_gc_stronger_than_au() {
        let gc = zuker_mfe(b"GGGCAAAGCCC").unwrap();
        let au = zuker_mfe(b"AAAUAAAUUUU").unwrap();
        assert!(
            gc.energy <= au.energy,
            "GC energy ({}) should be <= AU energy ({})",
            gc.energy,
            au.energy
        );
    }

    #[test]
    fn zuker_no_structure() {
        let r = zuker_mfe(b"AAAAAA").unwrap();
        assert_eq!(r.structure.num_pairs(), 0);
        assert!((r.energy - 0.0).abs() < 1e-6);
    }

    #[test]
    fn zuker_energy_nonpositive() {
        for seq in &[b"GCGCGCGC" as &[u8], b"AUGCAUGC", b"GGGAAACCC", b"CCCCGGGGG"] {
            let r = zuker_mfe(seq).unwrap();
            assert!(
                r.energy <= 1e-9,
                "energy should be <= 0, got {} for {:?}",
                r.energy,
                std::str::from_utf8(seq).unwrap()
            );
        }
    }

    #[test]
    fn zuker_valid_structure() {
        let r = zuker_mfe(b"GGGAAACCC").unwrap();
        let bps = r.structure.base_pairs();
        for &(i, j) in &bps {
            assert!(j - i > MIN_HAIRPIN, "pair ({},{}) violates min loop size", i, j);
            // The partner table must be symmetric.
            assert_eq!(r.structure.pairs[i], Some(j));
            assert_eq!(r.structure.pairs[j], Some(i));
        }
    }

    #[test]
    fn zuker_short_sequence() {
        let r = zuker_mfe(b"AUGC").unwrap();
        assert_eq!(r.energy, 0.0);
        assert_eq!(r.structure.num_pairs(), 0);
    }

    #[test]
    fn zuker_empty() {
        assert!(zuker_mfe(b"").is_err());
    }

    // ---- McCaskill partition function --------------------------------------

    #[test]
    fn mccaskill_strong_stem() {
        let r = mccaskill(b"GGGAAACCC", 310.15).unwrap();
        let p = r.pair_probability(0, 8);
        assert!(p > 0.01, "pair prob(0,8) = {} should be > 0.01", p);
    }

    #[test]
    fn mccaskill_no_pairs() {
        let r = mccaskill(b"AAAAAA", 310.15).unwrap();
        for i in 0..r.length {
            for j in 0..r.length {
                assert!(
                    r.pair_probability(i, j) < 0.01,
                    "pair prob({},{}) = {} should be < 0.01",
                    i,
                    j,
                    r.pair_probability(i, j)
                );
            }
        }
    }

    #[test]
    fn mccaskill_probabilities_sum() {
        let r = mccaskill(b"GGGAAACCC", 310.15).unwrap();
        for i in 0..r.length {
            let paired: f64 = (0..r.length).map(|j| r.pair_probability(i, j)).sum();
            let unpaired = r.unpaired_probability(i);
            let total = paired + unpaired;
            assert!(
                (total - 1.0).abs() < 0.1,
                "probability sum at position {} = {} (paired={}, unpaired={})",
                i,
                total,
                paired,
                unpaired
            );
        }
    }

    #[test]
    fn mccaskill_temperature_effect() {
        let low_t = mccaskill(b"GGGAAACCC", 300.0).unwrap();
        let high_t = mccaskill(b"GGGAAACCC", 370.0).unwrap();
        // Total pairing probability over the upper triangle.
        let low_p: f64 = (0..low_t.length)
            .flat_map(|i| (i + 1..low_t.length).map(move |j| (i, j)))
            .map(|(i, j)| low_t.pair_probability(i, j))
            .sum();
        let high_p: f64 = (0..high_t.length)
            .flat_map(|i| (i + 1..high_t.length).map(move |j| (i, j)))
            .map(|(i, j)| high_t.pair_probability(i, j))
            .sum();
        assert!(
            low_p >= high_p - 0.01,
            "lower T should give more pairing: {} vs {}",
            low_p,
            high_p
        );
    }

    #[test]
    fn mccaskill_deterministic() {
        let r1 = mccaskill(b"GCGCGCGCGC", 310.15).unwrap();
        let r2 = mccaskill(b"GCGCGCGCGC", 310.15).unwrap();
        assert_eq!(r1.pair_probabilities, r2.pair_probabilities);
        assert_eq!(r1.ensemble_energy, r2.ensemble_energy);
    }

    #[test]
    fn mccaskill_empty() {
        assert!(mccaskill(b"", 310.15).is_err());
    }

    #[test]
    fn mccaskill_invalid_temperature() {
        assert!(mccaskill(b"GGGAAACCC", 0.0).is_err());
        assert!(mccaskill(b"GGGAAACCC", -10.0).is_err());
    }

    #[test]
    fn mccaskill_short_sequence() {
        let r = mccaskill(b"AUGC", 310.15).unwrap();
        for i in 0..r.length {
            for j in 0..r.length {
                assert!((r.pair_probability(i, j) - 0.0).abs() < 1e-10);
            }
        }
    }

    // ---- structure distances -----------------------------------------------

    #[test]
    fn bp_distance_identical() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert_eq!(base_pair_distance(&a, &b).unwrap(), 0);
    }

    #[test]
    fn bp_distance_completely_different() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("........").unwrap();
        assert_eq!(base_pair_distance(&a, &b).unwrap(), 2);
    }

    #[test]
    fn bp_distance_symmetric() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket(".((..).)").unwrap();
        assert_eq!(
            base_pair_distance(&a, &b).unwrap(),
            base_pair_distance(&b, &a).unwrap()
        );
    }

    #[test]
    fn bp_distance_different_lengths() {
        let a = RnaSecondaryStructure::from_dot_bracket("((..))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!(base_pair_distance(&a, &b).is_err());
    }

    #[test]
    fn mountain_distance_identical() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!((mountain_distance(&a, &b).unwrap() - 0.0).abs() < 1e-10);
    }

    #[test]
    fn mountain_distance_symmetric() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket(".((..).)").unwrap();
        let d1 = mountain_distance(&a, &b).unwrap();
        let d2 = mountain_distance(&b, &a).unwrap();
        assert!((d1 - d2).abs() < 1e-10);
    }

    #[test]
    fn mountain_distance_nonnegative() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("((...))..").unwrap();
        assert!(mountain_distance(&a, &b).unwrap() >= 0.0);
    }

    #[test]
    fn mountain_distance_different_lengths() {
        let a = RnaSecondaryStructure::from_dot_bracket("((..))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!(mountain_distance(&a, &b).is_err());
    }

    // ---- private helpers ---------------------------------------------------

    #[test]
    fn can_pair_valid() {
        assert!(can_pair(b'A', b'U'));
        assert!(can_pair(b'U', b'A'));
        assert!(can_pair(b'G', b'C'));
        assert!(can_pair(b'C', b'G'));
        assert!(can_pair(b'G', b'U'));
        assert!(can_pair(b'U', b'G'));
    }

    #[test]
    fn can_pair_invalid() {
        assert!(!can_pair(b'A', b'A'));
        assert!(!can_pair(b'A', b'C'));
        assert!(!can_pair(b'A', b'G'));
        assert!(!can_pair(b'C', b'U'));
    }

    #[test]
    fn stacking_energy_values() {
        let e = stacking_energy(b'C', b'G', b'C', b'G');
        assert!((e - (-3.4)).abs() < 1e-10, "CG/CG stack = {}", e);
        let e2 = stacking_energy(b'G', b'C', b'G', b'C');
        assert!((e2 - (-3.3)).abs() < 1e-10, "GC/GC stack = {}", e2);
        let e3 = stacking_energy(b'A', b'U', b'U', b'A');
        assert!((e3 - (-1.1)).abs() < 1e-10, "AU/UA stack = {}", e3);
    }

    #[test]
    fn normalize_rna_dna_input() {
        let r = normalize_rna(b"ATGC").unwrap();
        assert_eq!(r, b"AUGC");
    }

    #[test]
    fn normalize_rna_lowercase() {
        let r = normalize_rna(b"augc").unwrap();
        assert_eq!(r, b"AUGC");
    }

    #[test]
    fn normalize_rna_invalid() {
        assert!(normalize_rna(b"AXGC").is_err());
    }
}