Skip to main content

cyanea_seq/
rna_structure.rs

1//! RNA secondary structure prediction.
2//!
3//! Predicts which bases in an RNA sequence pair to form stems, hairpins,
4//! internal loops, and multi-branch loops. Provides:
5//!
6//! - **Dot-bracket notation** — parse and emit `(((...)))` structures
7//! - **Nussinov algorithm** — maximize base pair count (O(n³))
8//! - **Zuker MFE** — minimum free energy with Turner nearest-neighbor parameters
9//! - **McCaskill partition function** — base pair probability matrix
10//! - **Structure comparison** — base pair distance, mountain distance
11
12use cyanea_core::{CyaneaError, Result};
13
14// ── Constants ────────────────────────────────────────────────────
15
/// Gas constant in kcal/(mol·K).
const R: f64 = 0.001987;

/// Default temperature in Kelvin (37 °C).
///
/// The loop-length extrapolation terms in `hairpin_energy` and
/// `internal_loop_energy` always use this value, even when a caller supplies
/// its own temperature.
const DEFAULT_T: f64 = 310.15;

/// Multi-branch loop offset (kcal/mol), charged once per closed multiloop.
const ML_A: f64 = 3.4;
/// Multi-branch loop per-helix penalty (kcal/mol), charged for the closing
/// helix and for every helix inside the loop.
const ML_B: f64 = 0.4;
/// Multi-branch loop per-unpaired-base penalty (kcal/mol); 0.0 makes unpaired
/// bases inside a multiloop free.
const ML_C: f64 = 0.0;

/// Large energy value representing an impossible state.
///
/// Energies (and sums involving this value) are treated as impossible once
/// they reach `INF / 2.0`.
const INF: f64 = 1e18;

/// Minimum hairpin loop size (bases between closing pair).
///
/// A pair `(i, j)` is only considered when `j - i > MIN_HAIRPIN`.
const MIN_HAIRPIN: usize = 3;
34
35// ── Energy tables (simplified Turner 2004) ───────────────────────
36
/// Hairpin loop initiation energies indexed by size (3..=30), kcal/mol.
///
/// Entries 0–2 are placeholders: `hairpin_energy` rejects loops smaller than
/// `MIN_HAIRPIN` and extrapolates logarithmically beyond size 30.
/// NOTE(review): sizes 3–9 repeat the pattern 5.4/5.6/5.7 — confirm against
/// the Turner 2004 hairpin initiation table.
const HAIRPIN_INIT: [f64; 31] = [
    0.0, 0.0, 0.0, // 0, 1, 2 — unused
    5.4, 5.6, 5.7, 5.4, 5.6, 5.7, 5.4, // 3–9
    5.6, 5.7, 5.8, 5.9, 5.9, 6.0, 6.1, // 10–16
    6.1, 6.2, 6.2, 6.3, 6.3, 6.3, 6.4, // 17–23
    6.4, 6.4, 6.5, 6.5, 6.5, 6.5, 6.6, // 24–30
];

/// Internal loop initiation energies indexed by size (1..=30), kcal/mol.
///
/// The index is the total number of unpaired bases on both sides. This table is
/// only used when both sides are non-empty, so sizes 2 and 3 (1×1 and 1×2
/// loops) receive 0.0 initiation and pay only the asymmetry term.
/// NOTE(review): the full Turner model prices those small loops with dedicated
/// tables; this simplified model omits them.
const INTERNAL_INIT: [f64; 31] = [
    0.0, // 0 — unused
    0.0, 0.0, 0.0, 1.1, 2.0, 2.0, 2.1, 2.3, 2.4, 2.5, // 1–10
    2.6, 2.7, 2.8, 2.9, 2.9, 3.0, 3.1, 3.1, 3.2, 3.2, // 11–20
    3.3, 3.3, 3.4, 3.4, 3.4, 3.5, 3.5, 3.5, 3.6, 3.6, // 21–30
];

/// Bulge loop initiation energies indexed by size (1..=30), kcal/mol.
///
/// The index is the number of unpaired bases on the single bulged side. Index 0
/// is unused: with no unpaired bases the two pairs simply stack.
const BULGE_INIT: [f64; 31] = [
    0.0, // 0 — unused
    3.8, 2.8, 3.2, 3.6, 4.0, 4.4, 4.6, 4.7, 4.8, 4.9, // 1–10
    5.0, 5.1, 5.2, 5.3, 5.4, 5.4, 5.5, 5.5, 5.6, 5.6, // 11–20
    5.7, 5.7, 5.8, 5.8, 5.8, 5.9, 5.9, 5.9, 6.0, 6.0, // 21–30
];
61
/// Whether bases `a` and `b` can pair: Watson–Crick A–U / G–C or wobble G–U,
/// in either orientation.
///
/// Bases are compared as uppercase ASCII bytes; anything else never pairs.
fn can_pair(a: u8, b: u8) -> bool {
    // Order the two bytes so each pair type needs only one pattern.
    let (low, high) = if a <= b { (a, b) } else { (b, a) };
    matches!((low, high), (b'A', b'U') | (b'C', b'G') | (b'G', b'U'))
}
74
/// Map an ordered base pair to its row/column in the `STACKING` table.
///
/// Order: AU=0, UA=1, GC=2, CG=3, GU=4, UG=5. Returns `None` for any
/// other combination (including lowercase bases).
fn pair_index(a: u8, b: u8) -> Option<usize> {
    const ORDERED_PAIRS: [(u8, u8); 6] = [
        (b'A', b'U'),
        (b'U', b'A'),
        (b'G', b'C'),
        (b'C', b'G'),
        (b'G', b'U'),
        (b'U', b'G'),
    ];
    ORDERED_PAIRS.iter().position(|&pair| pair == (a, b))
}
88
/// Stacking energies (kcal/mol, 37 °C) from Turner 1998/2004.
/// Indexed by [closing pair index][enclosed pair index].
/// Pair indices: AU=0, UA=1, GC=2, CG=3, GU=4, UG=5 (see `pair_index`).
///
/// For closing pair (i, j) and enclosed pair (i+1, j-1), the row is
/// `pair_index(seq[i], seq[j])` and the column is
/// `pair_index(seq[i+1], seq[j-1])`.
/// NOTE(review): nearest-neighbour stacks should be symmetric under strand
/// swap (entry [a][b] equal to the entry for the reversed pairs), but e.g.
/// [AU][GC] = -2.2 while [CG][UA] = -2.1 — confirm these values against the
/// NNDB Turner tables.
const STACKING: [[f64; 6]; 6] = [
    // closing AU
    [-0.9, -1.1, -2.2, -2.1, -0.6, -1.4],
    // closing UA
    [-1.3, -0.9, -2.4, -2.1, -1.0, -0.7],
    // closing GC
    [-2.4, -2.1, -3.3, -2.4, -1.5, -1.5],
    // closing CG
    [-2.1, -2.1, -2.4, -3.4, -1.4, -2.1],
    // closing GU
    [-1.3, -1.0, -2.5, -1.5, -0.5, -1.3],
    // closing UG
    [-1.0, -0.7, -1.5, -1.5, -0.3, -0.5],
];
106
107/// Look up the stacking energy for two consecutive base pairs.
108/// (i5, j5) is the outer (closing) pair, (i3, j3) is the inner (enclosed) pair.
109fn stacking_energy(i5: u8, j5: u8, i3: u8, j3: u8) -> f64 {
110    match (pair_index(i5, j5), pair_index(i3, j3)) {
111        (Some(a), Some(b)) => STACKING[a][b],
112        _ => INF,
113    }
114}
115
116/// Hairpin loop energy for closing pair (i, j).
117fn hairpin_energy(seq: &[u8], i: usize, j: usize) -> f64 {
118    let size = j - i - 1;
119    if size < MIN_HAIRPIN {
120        return INF;
121    }
122    let init = if size <= 30 {
123        HAIRPIN_INIT[size]
124    } else {
125        HAIRPIN_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
126    };
127    // Terminal mismatch bonus for hairpins ≥ 4
128    let mismatch = if size >= 4 {
129        terminal_mismatch(seq[i], seq[j], seq[i + 1], seq[j - 1])
130    } else {
131        0.0
132    };
133    init + mismatch
134}
135
/// Simplified terminal mismatch energy (kcal/mol) for the two unpaired bases
/// `ni` and `nj` adjacent to a closing pair.
///
/// Counts how many of the two are purines (uppercase `A` or `G`). Two purines
/// give -0.8, one purine gives -0.4, and none gives 0.0. The closing-pair bases
/// `_ci` and `_cj` are accepted but currently ignored.
fn terminal_mismatch(_ci: u8, _cj: u8, ni: u8, nj: u8) -> f64 {
    let purine_count = [ni, nj]
        .iter()
        .filter(|&&base| matches!(base, b'A' | b'G'))
        .count();
    match purine_count {
        2 => -0.8,
        1 => -0.4,
        _ => 0.0,
    }
}
149
150/// Internal/bulge loop energy for outer pair (i,j) and inner pair (p,q).
151fn internal_loop_energy(seq: &[u8], i: usize, j: usize, p: usize, q: usize) -> f64 {
152    let left = p - i - 1;
153    let right = j - q - 1;
154
155    if left == 0 && right == 0 {
156        return INF; // stacking, not internal
157    }
158
159    // 1×1 or 1×2 internal loops: use stacking + asymmetry
160    if left == 0 || right == 0 {
161        // Bulge loop
162        let size = left + right;
163        let init = if size <= 30 {
164            BULGE_INIT[size]
165        } else {
166            BULGE_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
167        };
168        // Stacking bonus for single-nucleotide bulge
169        if size == 1 {
170            return init + stacking_energy(seq[i], seq[j], seq[p], seq[q]);
171        }
172        return init;
173    }
174
175    // Internal loop
176    let size = left + right;
177    let init = if size <= 30 {
178        INTERNAL_INIT[size]
179    } else {
180        INTERNAL_INIT[30] + 1.75 * R * DEFAULT_T * ((size as f64) / 30.0).ln()
181    };
182    // Asymmetry penalty
183    let asymmetry = 0.3 * ((left as f64) - (right as f64)).abs();
184    let asymmetry = asymmetry.min(3.0); // capped
185
186    init + asymmetry
187}
188
189// ── Dot-bracket notation & structure representation ──────────────
190
191/// An RNA secondary structure as a pair table.
192///
193/// Each position `i` in the sequence is either paired to some position `j`
194/// (`pairs[i] = Some(j)`) or unpaired (`pairs[i] = None`). Pairs are
195/// non-crossing: if `i` pairs with `j`, no pair `(p, q)` exists with
196/// `i < p < j < q`.
197#[derive(Debug, Clone, PartialEq, Eq)]
198pub struct RnaSecondaryStructure {
199    /// Pair table: `pairs[i] = Some(j)` if position `i` is paired with `j`.
200    pub pairs: Vec<Option<usize>>,
201    /// Length of the sequence.
202    pub length: usize,
203}
204
205impl RnaSecondaryStructure {
206    /// Parse a dot-bracket string into a secondary structure.
207    ///
208    /// `(` and `)` denote paired bases; `.` denotes unpaired bases.
209    ///
210    /// # Errors
211    ///
212    /// Returns an error if parentheses are unbalanced or the string contains
213    /// characters other than `(`, `)`, and `.`.
214    ///
215    /// # Example
216    ///
217    /// ```
218    /// use cyanea_seq::rna_structure::RnaSecondaryStructure;
219    ///
220    /// let s = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
221    /// assert_eq!(s.num_pairs(), 3);
222    /// ```
223    pub fn from_dot_bracket(s: &str) -> Result<Self> {
224        let n = s.len();
225        let mut pairs = vec![None; n];
226        let mut stack = Vec::new();
227
228        for (i, ch) in s.chars().enumerate() {
229            match ch {
230                '(' => stack.push(i),
231                ')' => {
232                    let j = stack.pop().ok_or_else(|| {
233                        CyaneaError::Parse("unmatched ')' in dot-bracket string".into())
234                    })?;
235                    pairs[j] = Some(i);
236                    pairs[i] = Some(j);
237                }
238                '.' => {}
239                _ => {
240                    return Err(CyaneaError::Parse(format!(
241                        "invalid character '{}' in dot-bracket string",
242                        ch
243                    )));
244                }
245            }
246        }
247
248        if !stack.is_empty() {
249            return Err(CyaneaError::Parse("unmatched '(' in dot-bracket string".into()));
250        }
251
252        Ok(Self { pairs, length: n })
253    }
254
255    /// Convert this structure to dot-bracket notation.
256    ///
257    /// # Example
258    ///
259    /// ```
260    /// use cyanea_seq::rna_structure::RnaSecondaryStructure;
261    ///
262    /// let s = RnaSecondaryStructure::from_dot_bracket("..((..))..").unwrap();
263    /// assert_eq!(s.to_dot_bracket(), "..((..))..");
264    /// ```
265    pub fn to_dot_bracket(&self) -> String {
266        let mut out = vec!['.'; self.length];
267        for (i, partner) in self.pairs.iter().enumerate() {
268            if let Some(j) = partner {
269                if i < *j {
270                    out[i] = '(';
271                    out[*j] = ')';
272                }
273            }
274        }
275        out.into_iter().collect()
276    }
277
278    /// Return sorted list of base pairs `(i, j)` where `i < j`.
279    pub fn base_pairs(&self) -> Vec<(usize, usize)> {
280        let mut bps: Vec<(usize, usize)> = self
281            .pairs
282            .iter()
283            .enumerate()
284            .filter_map(|(i, p)| p.map(|j| (i, j)))
285            .filter(|(i, j)| i < j)
286            .collect();
287        bps.sort();
288        bps
289    }
290
291    /// Check whether position `i` is paired.
292    pub fn is_paired(&self, i: usize) -> bool {
293        i < self.length && self.pairs[i].is_some()
294    }
295
296    /// Return the pairing partner of position `i`, if any.
297    pub fn partner(&self, i: usize) -> Option<usize> {
298        if i < self.length {
299            self.pairs[i]
300        } else {
301            None
302        }
303    }
304
305    /// Number of base pairs in the structure.
306    pub fn num_pairs(&self) -> usize {
307        self.pairs.iter().filter(|p| p.is_some()).count() / 2
308    }
309}
310
311// ── Nussinov algorithm ──────────────────────────────────────────
312
/// Result of the Nussinov maximum base pair algorithm.
#[derive(Debug, Clone)]
pub struct NussinovResult {
    /// The predicted secondary structure: one optimal traceback. When several
    /// structures reach the maximum, the traceback's fixed search order picks
    /// one.
    pub structure: RnaSecondaryStructure,
    /// Maximum number of base pairs found (the DP optimum for the whole
    /// sequence).
    pub max_pairs: usize,
}
321
322/// Predict RNA secondary structure by maximizing base pair count (Nussinov algorithm).
323///
324/// Uses O(n³) dynamic programming. Valid pairs are AU, GC, and GU (wobble).
325/// No pair may close a loop smaller than `min_loop_size` bases.
326///
327/// # Errors
328///
329/// Returns an error if the sequence is empty.
330///
331/// # Example
332///
333/// ```
334/// use cyanea_seq::rna_structure::nussinov;
335///
336/// let result = nussinov(b"GGGGCCCC", 3).unwrap();
337/// assert!(result.max_pairs >= 2);
338/// ```
339pub fn nussinov(seq: &[u8], min_loop_size: usize) -> Result<NussinovResult> {
340    let seq = normalize_rna(seq)?;
341    let n = seq.len();
342
343    if n == 0 {
344        return Err(CyaneaError::InvalidInput("empty sequence".into()));
345    }
346
347    // DP table: M[i][j] = max pairs in subsequence [i..=j]
348    let mut m = vec![0i32; n * n];
349    let idx = |i: usize, j: usize| i * n + j;
350
351    // Fill bottom-up by increasing subsequence length
352    for len in 2..=n {
353        for i in 0..=n - len {
354            let j = i + len - 1;
355            // i unpaired
356            let mut best = if i + 1 <= j { m[idx(i + 1, j)] } else { 0 };
357            // j unpaired
358            if j > 0 {
359                best = best.max(m[idx(i, j - 1)]);
360            }
361            // i,j pair
362            if can_pair(seq[i], seq[j]) && j - i > min_loop_size {
363                let inner = if i + 1 <= j.saturating_sub(1) {
364                    m[idx(i + 1, j - 1)]
365                } else {
366                    0
367                };
368                best = best.max(inner + 1);
369            }
370            // bifurcation
371            for k in (i + 1)..j {
372                best = best.max(m[idx(i, k)] + m[idx(k + 1, j)]);
373            }
374            m[idx(i, j)] = best;
375        }
376    }
377
378    // Traceback
379    let mut pairs = vec![None; n];
380    nussinov_traceback(&seq, &m, n, min_loop_size, 0, n - 1, &mut pairs);
381
382    let max_pairs = m[idx(0, n - 1)] as usize;
383    Ok(NussinovResult {
384        structure: RnaSecondaryStructure {
385            pairs,
386            length: n,
387        },
388        max_pairs,
389    })
390}
391
392fn nussinov_traceback(
393    seq: &[u8],
394    m: &[i32],
395    n: usize,
396    min_loop_size: usize,
397    i: usize,
398    j: usize,
399    pairs: &mut [Option<usize>],
400) {
401    if i >= j || (j - i) < 1 {
402        return;
403    }
404    let idx = |a: usize, b: usize| a * n + b;
405    let val = m[idx(i, j)];
406
407    // i unpaired
408    if i + 1 <= j && m[idx(i + 1, j)] == val {
409        nussinov_traceback(seq, m, n, min_loop_size, i + 1, j, pairs);
410        return;
411    }
412
413    // i,j pair
414    if can_pair(seq[i], seq[j]) && j - i > min_loop_size {
415        let inner = if i + 1 <= j.saturating_sub(1) {
416            m[idx(i + 1, j - 1)]
417        } else {
418            0
419        };
420        if inner + 1 == val {
421            pairs[i] = Some(j);
422            pairs[j] = Some(i);
423            if j > 0 && i + 1 < j {
424                nussinov_traceback(seq, m, n, min_loop_size, i + 1, j - 1, pairs);
425            }
426            return;
427        }
428    }
429
430    // bifurcation
431    for k in (i + 1)..j {
432        if m[idx(i, k)] + m[idx(k + 1, j)] == val {
433            nussinov_traceback(seq, m, n, min_loop_size, i, k, pairs);
434            nussinov_traceback(seq, m, n, min_loop_size, k + 1, j, pairs);
435            return;
436        }
437    }
438
439    // j unpaired (fallback)
440    if j > 0 {
441        nussinov_traceback(seq, m, n, min_loop_size, i, j - 1, pairs);
442    }
443}
444
445// ── Zuker MFE algorithm ─────────────────────────────────────────
446
/// Result of the Zuker minimum free energy algorithm.
#[derive(Debug, Clone)]
pub struct MfeResult {
    /// The predicted MFE secondary structure (one optimal traceback).
    pub structure: RnaSecondaryStructure,
    /// Minimum free energy in kcal/mol (negative = stable). It is 0.0 when no
    /// structure beats the unfolded chain, including sequences too short to
    /// fold.
    pub energy: f64,
}
455
456/// Predict RNA secondary structure by minimizing free energy (Zuker algorithm).
457///
458/// Uses simplified Turner 2004 nearest-neighbor thermodynamic parameters.
459/// Suitable for sequences up to ~500 nt.
460///
461/// # Errors
462///
463/// Returns an error if the sequence is empty or shorter than 5 bases.
464///
465/// # Example
466///
467/// ```
468/// use cyanea_seq::rna_structure::zuker_mfe;
469///
470/// let result = zuker_mfe(b"GGGAAACCC").unwrap();
471/// assert!(result.energy <= 0.0);
472/// ```
473pub fn zuker_mfe(seq: &[u8]) -> Result<MfeResult> {
474    let seq = normalize_rna(seq)?;
475    let n = seq.len();
476
477    if n == 0 {
478        return Err(CyaneaError::InvalidInput("empty sequence".into()));
479    }
480    if n < 5 {
481        // Too short to form any structure
482        return Ok(MfeResult {
483            structure: RnaSecondaryStructure {
484                pairs: vec![None; n],
485                length: n,
486            },
487            energy: 0.0,
488        });
489    }
490
491    let idx = |i: usize, j: usize| i * n + j;
492
493    // V[i,j] = MFE where (i,j) form a base pair
494    let mut v = vec![INF; n * n];
495    // W[i,j] = MFE of any structure on [i..=j]
496    let mut w = vec![0.0_f64; n * n];
497    // WM[i,j] = MFE of multi-branch loop region on [i..=j]
498    let mut wm = vec![INF; n * n];
499
500    // Fill bottom-up by increasing subsequence length
501    for len in 2..=n {
502        for i in 0..=n - len {
503            let j = i + len - 1;
504
505            // ── V(i,j): only valid if (i,j) can pair ──
506            if can_pair(seq[i], seq[j]) && j - i > MIN_HAIRPIN {
507                let mut best_v = INF;
508
509                // Hairpin
510                best_v = best_v.min(hairpin_energy(&seq, i, j));
511
512                // Stacking
513                if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) && j - 1 - (i + 1) >= MIN_HAIRPIN {
514                    let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
515                    best_v = best_v.min(v[idx(i + 1, j - 1)] + stack);
516                } else if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) {
517                    // Inner pair exists but too close — still allow stacking if V is valid
518                    let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
519                    if v[idx(i + 1, j - 1)] < INF / 2.0 {
520                        best_v = best_v.min(v[idx(i + 1, j - 1)] + stack);
521                    }
522                }
523
524                // Internal loop / bulge
525                // Iterate over all inner pairs (p, q) with i < p < q < j
526                let max_left = (j - i - 1).min(30);
527                for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
528                    let max_right = (j - i - 1 - (p - i - 1)).min(30);
529                    let q_min = if p + MIN_HAIRPIN + 1 > j {
530                        continue;
531                    } else {
532                        (j - max_right).max(p + MIN_HAIRPIN + 1)
533                    };
534                    for q in q_min..j {
535                        if !can_pair(seq[p], seq[q]) {
536                            continue;
537                        }
538                        if p == i + 1 && q == j - 1 {
539                            continue; // stacking, handled above
540                        }
541                        if v[idx(p, q)] >= INF / 2.0 {
542                            continue;
543                        }
544                        let il_e = internal_loop_energy(&seq, i, j, p, q);
545                        best_v = best_v.min(v[idx(p, q)] + il_e);
546                    }
547                }
548
549                // Multi-branch loop
550                if j > i + 2 && wm[idx(i + 1, j - 1)] < INF / 2.0 {
551                    best_v = best_v.min(wm[idx(i + 1, j - 1)] + ML_A + ML_B);
552                }
553
554                v[idx(i, j)] = best_v;
555            }
556
557            // ── WM(i,j): multi-branch loop region ──
558            {
559                let mut best_wm = INF;
560
561                // i unpaired
562                if i + 1 <= j && wm[idx(i + 1, j)] < INF / 2.0 {
563                    best_wm = best_wm.min(wm[idx(i + 1, j)] + ML_C);
564                }
565
566                // j unpaired
567                if j > 0 && i <= j - 1 && wm[idx(i, j - 1)] < INF / 2.0 {
568                    best_wm = best_wm.min(wm[idx(i, j - 1)] + ML_C);
569                }
570
571                // Helix starting at (i,j)
572                if v[idx(i, j)] < INF / 2.0 {
573                    best_wm = best_wm.min(v[idx(i, j)] + ML_B);
574                }
575
576                // Concatenation
577                for k in (i + 1)..j {
578                    if wm[idx(i, k)] < INF / 2.0 && wm[idx(k + 1, j)] < INF / 2.0 {
579                        best_wm = best_wm.min(wm[idx(i, k)] + wm[idx(k + 1, j)]);
580                    }
581                }
582
583                wm[idx(i, j)] = best_wm;
584            }
585
586            // ── W(i,j): any structure on [i..=j] ──
587            {
588                let mut best_w: f64 = 0.0; // no structure
589
590                // i unpaired
591                if i + 1 <= j {
592                    best_w = best_w.min(w[idx(i + 1, j)]);
593                }
594
595                // j unpaired
596                if j > 0 && i <= j - 1 {
597                    best_w = best_w.min(w[idx(i, j - 1)]);
598                }
599
600                // (i,j) pair
601                if v[idx(i, j)] < INF / 2.0 {
602                    best_w = best_w.min(v[idx(i, j)]);
603                }
604
605                // bifurcation
606                for k in (i + 1)..j {
607                    best_w = best_w.min(w[idx(i, k)] + w[idx(k + 1, j)]);
608                }
609
610                w[idx(i, j)] = best_w;
611            }
612        }
613    }
614
615    let energy = w[idx(0, n - 1)];
616    let energy = if energy >= INF / 2.0 { 0.0 } else { energy };
617
618    // Traceback
619    let mut pairs = vec![None; n];
620    zuker_traceback_w(&seq, &v, &w, &wm, n, 0, n - 1, &mut pairs);
621
622    Ok(MfeResult {
623        structure: RnaSecondaryStructure {
624            pairs,
625            length: n,
626        },
627        energy,
628    })
629}
630
631fn zuker_traceback_w(
632    seq: &[u8],
633    v: &[f64],
634    w: &[f64],
635    wm: &[f64],
636    n: usize,
637    i: usize,
638    j: usize,
639    pairs: &mut [Option<usize>],
640) {
641    if i >= j {
642        return;
643    }
644    let idx = |a: usize, b: usize| a * n + b;
645    let val = w[idx(i, j)];
646    let eps = 1e-9;
647
648    // No structure
649    if val.abs() < eps {
650        return;
651    }
652
653    // (i,j) pair
654    if v[idx(i, j)] < INF / 2.0 && (v[idx(i, j)] - val).abs() < eps {
655        pairs[i] = Some(j);
656        pairs[j] = Some(i);
657        zuker_traceback_v(seq, v, w, wm, n, i, j, pairs);
658        return;
659    }
660
661    // i unpaired
662    if i + 1 <= j && (w[idx(i + 1, j)] - val).abs() < eps {
663        zuker_traceback_w(seq, v, w, wm, n, i + 1, j, pairs);
664        return;
665    }
666
667    // j unpaired
668    if j > 0 && i <= j - 1 && (w[idx(i, j - 1)] - val).abs() < eps {
669        zuker_traceback_w(seq, v, w, wm, n, i, j - 1, pairs);
670        return;
671    }
672
673    // bifurcation
674    for k in (i + 1)..j {
675        if (w[idx(i, k)] + w[idx(k + 1, j)] - val).abs() < eps {
676            zuker_traceback_w(seq, v, w, wm, n, i, k, pairs);
677            zuker_traceback_w(seq, v, w, wm, n, k + 1, j, pairs);
678            return;
679        }
680    }
681}
682
683fn zuker_traceback_v(
684    seq: &[u8],
685    v: &[f64],
686    w: &[f64],
687    wm: &[f64],
688    n: usize,
689    i: usize,
690    j: usize,
691    pairs: &mut [Option<usize>],
692) {
693    let idx = |a: usize, b: usize| a * n + b;
694    let val = v[idx(i, j)];
695    let eps = 1e-9;
696
697    // Hairpin
698    if (hairpin_energy(seq, i, j) - val).abs() < eps {
699        return;
700    }
701
702    // Stacking
703    if i + 1 < j && j > 0 && can_pair(seq[i + 1], seq[j - 1]) && v[idx(i + 1, j - 1)] < INF / 2.0 {
704        let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
705        if (v[idx(i + 1, j - 1)] + stack - val).abs() < eps {
706            pairs[i + 1] = Some(j - 1);
707            pairs[j - 1] = Some(i + 1);
708            zuker_traceback_v(seq, v, w, wm, n, i + 1, j - 1, pairs);
709            return;
710        }
711    }
712
713    // Internal loop / bulge
714    let max_left = (j - i - 1).min(30);
715    for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
716        let max_right = (j - i - 1 - (p - i - 1)).min(30);
717        let q_min_val = (j.saturating_sub(max_right)).max(p + MIN_HAIRPIN + 1);
718        for q in q_min_val..j {
719            if !can_pair(seq[p], seq[q]) || v[idx(p, q)] >= INF / 2.0 {
720                continue;
721            }
722            if p == i + 1 && q == j - 1 {
723                continue;
724            }
725            let il_e = internal_loop_energy(seq, i, j, p, q);
726            if (v[idx(p, q)] + il_e - val).abs() < eps {
727                pairs[p] = Some(q);
728                pairs[q] = Some(p);
729                zuker_traceback_v(seq, v, w, wm, n, p, q, pairs);
730                return;
731            }
732        }
733    }
734
735    // Multi-branch loop
736    if j > i + 2 && wm[idx(i + 1, j - 1)] < INF / 2.0 {
737        if (wm[idx(i + 1, j - 1)] + ML_A + ML_B - val).abs() < eps {
738            zuker_traceback_wm(seq, v, w, wm, n, i + 1, j - 1, pairs);
739        }
740    }
741}
742
743fn zuker_traceback_wm(
744    seq: &[u8],
745    v: &[f64],
746    w: &[f64],
747    wm: &[f64],
748    n: usize,
749    i: usize,
750    j: usize,
751    pairs: &mut [Option<usize>],
752) {
753    if i >= j {
754        return;
755    }
756    let idx = |a: usize, b: usize| a * n + b;
757    let val = wm[idx(i, j)];
758    let eps = 1e-9;
759
760    if val >= INF / 2.0 {
761        return;
762    }
763
764    // Helix at (i,j)
765    if v[idx(i, j)] < INF / 2.0 && (v[idx(i, j)] + ML_B - val).abs() < eps {
766        pairs[i] = Some(j);
767        pairs[j] = Some(i);
768        zuker_traceback_v(seq, v, w, wm, n, i, j, pairs);
769        return;
770    }
771
772    // i unpaired
773    if i + 1 <= j && wm[idx(i + 1, j)] < INF / 2.0 && (wm[idx(i + 1, j)] + ML_C - val).abs() < eps
774    {
775        zuker_traceback_wm(seq, v, w, wm, n, i + 1, j, pairs);
776        return;
777    }
778
779    // j unpaired
780    if j > 0 && i <= j - 1 && wm[idx(i, j - 1)] < INF / 2.0
781        && (wm[idx(i, j - 1)] + ML_C - val).abs() < eps
782    {
783        zuker_traceback_wm(seq, v, w, wm, n, i, j - 1, pairs);
784        return;
785    }
786
787    // Concatenation
788    for k in (i + 1)..j {
789        if wm[idx(i, k)] < INF / 2.0
790            && wm[idx(k + 1, j)] < INF / 2.0
791            && (wm[idx(i, k)] + wm[idx(k + 1, j)] - val).abs() < eps
792        {
793            zuker_traceback_wm(seq, v, w, wm, n, i, k, pairs);
794            zuker_traceback_wm(seq, v, w, wm, n, k + 1, j, pairs);
795            return;
796        }
797    }
798}
799
800// ── McCaskill partition function ────────────────────────────────
801
/// Result of the McCaskill partition function algorithm.
#[derive(Debug, Clone)]
pub struct PartitionResult {
    /// Base pair probability matrix, stored as a flat row-major
    /// `length × length` vector; entry `i * length + j` is P(i pairs with j).
    pub pair_probabilities: Vec<f64>,
    /// Sequence length (the matrix dimension).
    pub length: usize,
    /// Ensemble free energy −RT·ln(Z), in kcal/mol.
    pub ensemble_energy: f64,
}

impl PartitionResult {
    /// Probability that positions `i` and `j` pair with each other.
    ///
    /// Positions outside the sequence yield 0.0.
    pub fn pair_probability(&self, i: usize, j: usize) -> f64 {
        let n = self.length;
        if i.max(j) >= n {
            return 0.0;
        }
        self.pair_probabilities[i * n + j]
    }

    /// Probability that position `i` is unpaired.
    ///
    /// Computed as one minus the sum of row `i` of the pair probability
    /// matrix, clamped at zero. Positions outside the sequence yield 0.0.
    pub fn unpaired_probability(&self, i: usize) -> f64 {
        let n = self.length;
        if i >= n {
            return 0.0;
        }
        let row_start = i * n;
        let paired: f64 = self.pair_probabilities[row_start..row_start + n]
            .iter()
            .sum();
        (1.0 - paired).max(0.0)
    }
}
833
834/// Compute base pair probabilities via the McCaskill inside-outside algorithm.
835///
836/// Uses the same simplified Turner energy model as [`zuker_mfe`].
837/// Computations are performed in log-space for numerical stability.
838///
839/// # Arguments
840///
841/// * `seq` — RNA sequence (A, U, G, C)
842/// * `temperature` — temperature in Kelvin (e.g., 310.15 for 37 °C)
843///
844/// # Errors
845///
846/// Returns an error if the sequence is empty or temperature is not positive.
847///
848/// # Example
849///
850/// ```
851/// use cyanea_seq::rna_structure::mccaskill;
852///
853/// let result = mccaskill(b"GGGAAACCC", 310.15).unwrap();
854/// assert!(result.pair_probability(0, 8) > 0.0);
855/// ```
856pub fn mccaskill(seq: &[u8], temperature: f64) -> Result<PartitionResult> {
857    let seq = normalize_rna(seq)?;
858    let n = seq.len();
859
860    if n == 0 {
861        return Err(CyaneaError::InvalidInput("empty sequence".into()));
862    }
863    if temperature <= 0.0 {
864        return Err(CyaneaError::InvalidInput(
865            "temperature must be positive".into(),
866        ));
867    }
868
869    let rt = R * temperature;
870
871    if n < 5 {
872        return Ok(PartitionResult {
873            pair_probabilities: vec![0.0; n * n],
874            length: n,
875            ensemble_energy: 0.0,
876        });
877    }
878
879    let idx = |i: usize, j: usize| i * n + j;
880    let boltz = |e: f64| -> f64 {
881        if e >= INF / 2.0 {
882            0.0
883        } else {
884            (-e / rt).exp()
885        }
886    };
887
888    // Inside partition functions (stored as Boltzmann weights, not logs)
889    // Q[i,j] = partition function for subsequence [i..=j]
890    // Qb[i,j] = partition function for subsequence [i..=j] where (i,j) form a pair
891    let mut q = vec![0.0_f64; n * n];
892    let mut qb = vec![0.0_f64; n * n];
893    let mut qm = vec![0.0_f64; n * n];
894
895    // Base cases: Q[i,i] = 1, Q[i,j] = 1 for j < i
896    for i in 0..n {
897        q[idx(i, i)] = 1.0;
898        if i + 1 < n {
899            q[idx(i + 1, i)] = 1.0; // empty
900        }
901    }
902
903    // Fill inside tables bottom-up
904    for len in 2..=n {
905        for i in 0..=n - len {
906            let j = i + len - 1;
907
908            // Qb(i,j): (i,j) must be a valid pair
909            if can_pair(seq[i], seq[j]) && j - i > MIN_HAIRPIN {
910                let mut qb_val = 0.0;
911
912                // Hairpin
913                qb_val += boltz(hairpin_energy(&seq, i, j));
914
915                // Stacking
916                if i + 1 < j && can_pair(seq[i + 1], seq[j - 1]) {
917                    let stack = stacking_energy(seq[i], seq[j], seq[i + 1], seq[j - 1]);
918                    qb_val += qb[idx(i + 1, j - 1)] * boltz(stack);
919                }
920
921                // Internal loops / bulges
922                let max_left = (j - i - 1).min(30);
923                for p in (i + 1)..=(i + max_left).min(j.saturating_sub(1)) {
924                    let max_right = (j - i - 1 - (p - i - 1)).min(30);
925                    let q_min = (j.saturating_sub(max_right)).max(p + MIN_HAIRPIN + 1);
926                    for qi in q_min..j {
927                        if !can_pair(seq[p], seq[qi]) {
928                            continue;
929                        }
930                        if p == i + 1 && qi == j - 1 {
931                            continue; // stacking
932                        }
933                        let il_e = internal_loop_energy(&seq, i, j, p, qi);
934                        qb_val += qb[idx(p, qi)] * boltz(il_e);
935                    }
936                }
937
938                // Multi-branch loop
939                if j > i + 2 {
940                    qb_val += qm[idx(i + 1, j - 1)] * boltz(ML_A + ML_B);
941                }
942
943                qb[idx(i, j)] = qb_val;
944            }
945
946            // QM(i,j): multi-branch region
947            {
948                let mut qm_val = 0.0;
949
950                // i unpaired
951                if i + 1 <= j {
952                    qm_val += qm[idx(i + 1, j)] * boltz(ML_C);
953                }
954
955                // j unpaired
956                if j > 0 && i <= j - 1 {
957                    qm_val += qm[idx(i, j - 1)] * boltz(ML_C);
958                }
959
960                // Helix starting here
961                if qb[idx(i, j)] > 0.0 {
962                    qm_val += qb[idx(i, j)] * boltz(ML_B);
963                }
964
965                // Concatenation
966                for k in (i + 1)..j {
967                    qm_val += qm[idx(i, k)] * qm[idx(k + 1, j)];
968                }
969
970                qm[idx(i, j)] = qm_val;
971            }
972
973            // Q(i,j): full partition
974            {
975                let mut q_val = 1.0; // empty structure
976
977                for d in i..=j {
978                    for e in (d + MIN_HAIRPIN + 1)..=j {
979                        if !can_pair(seq[d], seq[e]) || qb[idx(d, e)] == 0.0 {
980                            continue;
981                        }
982                        let q_left = if d > i { q[idx(i, d - 1)] } else { 1.0 };
983                        let q_right = if e < j { q[idx(e + 1, j)] } else { 1.0 };
984                        q_val += q_left * qb[idx(d, e)] * q_right;
985                    }
986                }
987
988                q[idx(i, j)] = q_val;
989            }
990        }
991    }
992
993    let z = q[idx(0, n - 1)];
994    let ensemble_energy = if z > 0.0 { -rt * z.ln() } else { 0.0 };
995
996    // Outside algorithm for pair probabilities
997    let mut prob = vec![0.0_f64; n * n];
998
999    if z > 0.0 {
1000        for i in 0..n {
1001            for j in (i + MIN_HAIRPIN + 1)..n {
1002                if qb[idx(i, j)] == 0.0 {
1003                    continue;
1004                }
1005                let q_left = if i > 0 { q[idx(0, i - 1)] } else { 1.0 };
1006                let q_right = if j < n - 1 { q[idx(j + 1, n - 1)] } else { 1.0 };
1007                let p_ij = q_left * qb[idx(i, j)] * q_right / z;
1008                let p_ij = p_ij.min(1.0).max(0.0);
1009                prob[idx(i, j)] = p_ij;
1010                prob[idx(j, i)] = p_ij;
1011            }
1012        }
1013    }
1014
1015    Ok(PartitionResult {
1016        pair_probabilities: prob,
1017        length: n,
1018        ensemble_energy,
1019    })
1020}
1021
1022// ── Structure comparison ────────────────────────────────────────
1023
1024/// Compute the base pair distance between two structures.
1025///
1026/// The base pair distance is the size of the symmetric difference of the
1027/// two base pair sets: the number of pairs in one structure but not the other.
1028///
1029/// # Errors
1030///
1031/// Returns an error if the structures have different lengths.
1032///
1033/// # Example
1034///
1035/// ```
1036/// use cyanea_seq::rna_structure::{RnaSecondaryStructure, base_pair_distance};
1037///
1038/// let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
1039/// let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
1040/// assert_eq!(base_pair_distance(&a, &b).unwrap(), 0);
1041/// ```
1042pub fn base_pair_distance(
1043    a: &RnaSecondaryStructure,
1044    b: &RnaSecondaryStructure,
1045) -> Result<usize> {
1046    if a.length != b.length {
1047        return Err(CyaneaError::InvalidInput(format!(
1048            "structure lengths differ: {} vs {}",
1049            a.length, b.length
1050        )));
1051    }
1052
1053    let bp_a: std::collections::HashSet<(usize, usize)> = a.base_pairs().into_iter().collect();
1054    let bp_b: std::collections::HashSet<(usize, usize)> = b.base_pairs().into_iter().collect();
1055
1056    let only_a = bp_a.difference(&bp_b).count();
1057    let only_b = bp_b.difference(&bp_a).count();
1058    Ok(only_a + only_b)
1059}
1060
1061/// Compute the mountain distance between two structures.
1062///
1063/// The mountain representation of a structure assigns to each position
1064/// the number of base pairs enclosing it. The mountain distance is the
1065/// L1 norm of the difference of the two mountain vectors.
1066///
1067/// # Errors
1068///
1069/// Returns an error if the structures have different lengths.
1070///
1071/// # Example
1072///
1073/// ```
1074/// use cyanea_seq::rna_structure::{RnaSecondaryStructure, mountain_distance};
1075///
1076/// let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
1077/// let b = RnaSecondaryStructure::from_dot_bracket("((...))..").unwrap();
1078/// assert!(mountain_distance(&a, &b).unwrap() > 0.0);
1079/// ```
1080pub fn mountain_distance(
1081    a: &RnaSecondaryStructure,
1082    b: &RnaSecondaryStructure,
1083) -> Result<f64> {
1084    if a.length != b.length {
1085        return Err(CyaneaError::InvalidInput(format!(
1086            "structure lengths differ: {} vs {}",
1087            a.length, b.length
1088        )));
1089    }
1090
1091    let ma = mountain_vector(a);
1092    let mb = mountain_vector(b);
1093
1094    let dist: f64 = ma
1095        .iter()
1096        .zip(mb.iter())
1097        .map(|(x, y)| (*x as f64 - *y as f64).abs())
1098        .sum();
1099    Ok(dist)
1100}
1101
1102/// Build the mountain vector: m[i] = number of base pairs enclosing position i.
1103fn mountain_vector(s: &RnaSecondaryStructure) -> Vec<i32> {
1104    let mut m = vec![0i32; s.length];
1105    let mut depth = 0i32;
1106    for i in 0..s.length {
1107        if let Some(j) = s.pairs[i] {
1108            if i < j {
1109                depth += 1;
1110            } else {
1111                depth -= 1;
1112            }
1113        }
1114        m[i] = depth;
1115    }
1116    m
1117}
1118
1119// ── Helpers ─────────────────────────────────────────────────────
1120
1121/// Normalize sequence to uppercase RNA (T → U).
1122fn normalize_rna(seq: &[u8]) -> Result<Vec<u8>> {
1123    seq.iter()
1124        .map(|&b| match b {
1125            b'A' | b'a' => Ok(b'A'),
1126            b'U' | b'u' => Ok(b'U'),
1127            b'G' | b'g' => Ok(b'G'),
1128            b'C' | b'c' => Ok(b'C'),
1129            b'T' | b't' => Ok(b'U'), // DNA T → RNA U
1130            _ => Err(CyaneaError::InvalidInput(format!(
1131                "invalid nucleotide '{}' in RNA sequence",
1132                b as char
1133            ))),
1134        })
1135        .collect()
1136}
1137
1138// ── Tests ───────────────────────────────────────────────────────
1139
#[cfg(test)]
mod tests {
    use super::*;

    // ── Dot-bracket ──

    #[test]
    fn dot_bracket_simple() {
        let s = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert_eq!(s.length, 9);
        assert_eq!(s.num_pairs(), 3);
        assert_eq!(s.pairs[0], Some(8));
        assert_eq!(s.pairs[8], Some(0));
        assert!(s.is_paired(0));
        assert!(!s.is_paired(3));
    }

    #[test]
    fn dot_bracket_with_unpaired() {
        let s = RnaSecondaryStructure::from_dot_bracket("..((..))..").unwrap();
        assert_eq!(s.length, 10);
        assert_eq!(s.num_pairs(), 2);
        assert!(!s.is_paired(0));
        assert!(s.is_paired(2));
        assert_eq!(s.partner(2), Some(7));
    }

    #[test]
    fn dot_bracket_roundtrip() {
        let input = "(((..((.....))...)))";
        let s = RnaSecondaryStructure::from_dot_bracket(input).unwrap();
        assert_eq!(s.to_dot_bracket(), input);

        // Parse again
        let s2 = RnaSecondaryStructure::from_dot_bracket(&s.to_dot_bracket()).unwrap();
        assert_eq!(s.pairs, s2.pairs);
    }

    #[test]
    fn dot_bracket_base_pairs() {
        let s = RnaSecondaryStructure::from_dot_bracket("((.()))").unwrap();
        let bps = s.base_pairs();
        assert_eq!(bps.len(), 3);
        // All pairs should have i < j
        for (i, j) in &bps {
            assert!(i < j);
        }
    }

    #[test]
    fn dot_bracket_unmatched_open() {
        assert!(RnaSecondaryStructure::from_dot_bracket("((...)))(").is_err());
    }

    #[test]
    fn dot_bracket_unmatched_close() {
        assert!(RnaSecondaryStructure::from_dot_bracket(")((..))").is_err());
    }

    #[test]
    fn dot_bracket_invalid_char() {
        assert!(RnaSecondaryStructure::from_dot_bracket("((..x..))").is_err());
    }

    #[test]
    fn dot_bracket_empty() {
        let s = RnaSecondaryStructure::from_dot_bracket("").unwrap();
        assert_eq!(s.length, 0);
        assert_eq!(s.num_pairs(), 0);
    }

    // ── Nussinov ──

    #[test]
    fn nussinov_gcaucg() {
        let r = nussinov(b"GCAUCG", 3).unwrap();
        // With min_loop_size=3 (pairs need j - i > 3) the only candidates are
        // G(0)-C(4) and C(1)-G(5). They cross, so exactly one can be chosen.
        assert_eq!(r.max_pairs, 1);
    }

    #[test]
    fn nussinov_perfect_stem() {
        let r = nussinov(b"GGGGCCCC", 3).unwrap();
        // Can form at most 2 pairs with min_loop_size=3: e.g. G(0)-C(7) and
        // G(1)-C(6); a third nested pair would need j - i <= 3.
        assert_eq!(r.max_pairs, 2);
    }

    #[test]
    fn nussinov_no_pairs() {
        let r = nussinov(b"AAAAAA", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
        assert_eq!(r.structure.num_pairs(), 0);
    }

    #[test]
    fn nussinov_min_loop_enforced() {
        // A(0) and U(3) enclose only two bases (j - i = 3 = min_loop_size);
        // a pair needs j - i > min_loop_size, so nothing can form.
        let r = nussinov(b"AGCU", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
    }

    #[test]
    fn nussinov_short_sequence() {
        let r = nussinov(b"AUG", 3).unwrap();
        assert_eq!(r.max_pairs, 0);
    }

    #[test]
    fn nussinov_empty() {
        assert!(nussinov(b"", 3).is_err());
    }

    #[test]
    fn nussinov_structure_valid() {
        let r = nussinov(b"GGGAAACCC", 3).unwrap();
        // Verify no crossing pairs. Two pairs cross exactly when one endpoint of
        // the second lies strictly inside the first and the other does not; this
        // check does not depend on the order `base_pairs()` returns pairs in.
        let bps = r.structure.base_pairs();
        for (idx_a, &(i1, j1)) in bps.iter().enumerate() {
            for &(i2, j2) in bps.iter().skip(idx_a + 1) {
                let i2_inside = i1 < i2 && i2 < j1;
                let j2_inside = i1 < j2 && j2 < j1;
                assert!(
                    i2_inside == j2_inside,
                    "crossing pairs: ({},{}) and ({},{})",
                    i1,
                    j1,
                    i2,
                    j2
                );
            }
        }
    }

    #[test]
    fn nussinov_lowercase_and_dna() {
        let r = nussinov(b"gggaaaccc", 3).unwrap();
        assert!(r.max_pairs > 0);

        // T is treated as U
        let r2 = nussinov(b"GGGAAATCC", 3).unwrap();
        let r3 = nussinov(b"GGGAAAUCC", 3).unwrap();
        assert_eq!(r2.max_pairs, r3.max_pairs);
    }

    // ── Zuker MFE ──

    #[test]
    fn zuker_simple_hairpin() {
        let r = zuker_mfe(b"GGGAAACCC").unwrap();
        // Should form a stem-loop with negative energy
        assert!(r.energy < 0.0, "energy should be negative, got {}", r.energy);
        assert!(r.structure.num_pairs() > 0);
    }

    #[test]
    fn zuker_gc_stronger_than_au() {
        let gc = zuker_mfe(b"GGGCAAAGCCC").unwrap();
        let au = zuker_mfe(b"AAAUAAAUUUU").unwrap();
        // GC-rich should have more negative (stronger) energy
        assert!(
            gc.energy <= au.energy,
            "GC energy ({}) should be <= AU energy ({})",
            gc.energy,
            au.energy
        );
    }

    #[test]
    fn zuker_no_structure() {
        let r = zuker_mfe(b"AAAAAA").unwrap();
        assert_eq!(r.structure.num_pairs(), 0);
        assert!(r.energy.abs() < 1e-6);
    }

    #[test]
    fn zuker_energy_nonpositive() {
        for seq in &[b"GCGCGCGC" as &[u8], b"AUGCAUGC", b"GGGAAACCC", b"CCCCGGGGG"] {
            let r = zuker_mfe(seq).unwrap();
            assert!(
                r.energy <= 1e-9,
                "energy should be <= 0, got {} for {:?}",
                r.energy,
                std::str::from_utf8(seq).unwrap()
            );
        }
    }

    #[test]
    fn zuker_valid_structure() {
        let r = zuker_mfe(b"GGGAAACCC").unwrap();
        // All pairs respect min_loop_size
        let bps = r.structure.base_pairs();
        for &(i, j) in &bps {
            assert!(j - i > MIN_HAIRPIN, "pair ({},{}) violates min loop size", i, j);
            // Both must pair with each other
            assert_eq!(r.structure.pairs[i], Some(j));
            assert_eq!(r.structure.pairs[j], Some(i));
        }
    }

    #[test]
    fn zuker_short_sequence() {
        let r = zuker_mfe(b"AUGC").unwrap();
        assert_eq!(r.energy, 0.0);
        assert_eq!(r.structure.num_pairs(), 0);
    }

    #[test]
    fn zuker_empty() {
        assert!(zuker_mfe(b"").is_err());
    }

    // ── McCaskill partition function ──

    #[test]
    fn mccaskill_strong_stem() {
        let r = mccaskill(b"GGGAAACCC", 310.15).unwrap();
        // The terminal G-C pair should have significant probability
        let p = r.pair_probability(0, 8);
        assert!(p > 0.01, "pair prob(0,8) = {} should be > 0.01", p);
    }

    #[test]
    fn mccaskill_no_pairs() {
        let r = mccaskill(b"AAAAAA", 310.15).unwrap();
        // All pair probabilities should be near 0
        for i in 0..r.length {
            for j in 0..r.length {
                assert!(
                    r.pair_probability(i, j) < 0.01,
                    "pair prob({},{}) = {} should be < 0.01",
                    i,
                    j,
                    r.pair_probability(i, j)
                );
            }
        }
    }

    #[test]
    fn mccaskill_probabilities_sum() {
        let r = mccaskill(b"GGGAAACCC", 310.15).unwrap();
        for i in 0..r.length {
            let paired: f64 = (0..r.length).map(|j| r.pair_probability(i, j)).sum();
            let unpaired = r.unpaired_probability(i);
            let total = paired + unpaired;
            assert!(
                (total - 1.0).abs() < 0.1,
                "probability sum at position {} = {} (paired={}, unpaired={})",
                i,
                total,
                paired,
                unpaired
            );
        }
    }

    #[test]
    fn mccaskill_temperature_effect() {
        let low_t = mccaskill(b"GGGAAACCC", 300.0).unwrap();
        let high_t = mccaskill(b"GGGAAACCC", 370.0).unwrap();
        // Higher temperature → less structured → lower pair probabilities
        let low_p: f64 = (0..low_t.length)
            .flat_map(|i| (i + 1..low_t.length).map(move |j| (i, j)))
            .map(|(i, j)| low_t.pair_probability(i, j))
            .sum();
        let high_p: f64 = (0..high_t.length)
            .flat_map(|i| (i + 1..high_t.length).map(move |j| (i, j)))
            .map(|(i, j)| high_t.pair_probability(i, j))
            .sum();
        assert!(
            low_p >= high_p - 0.01,
            "lower T should give more pairing: {} vs {}",
            low_p,
            high_p
        );
    }

    #[test]
    fn mccaskill_deterministic() {
        let r1 = mccaskill(b"GCGCGCGCGC", 310.15).unwrap();
        let r2 = mccaskill(b"GCGCGCGCGC", 310.15).unwrap();
        assert_eq!(r1.pair_probabilities, r2.pair_probabilities);
        assert_eq!(r1.ensemble_energy, r2.ensemble_energy);
    }

    #[test]
    fn mccaskill_empty() {
        assert!(mccaskill(b"", 310.15).is_err());
    }

    #[test]
    fn mccaskill_invalid_temperature() {
        assert!(mccaskill(b"GGGAAACCC", 0.0).is_err());
        assert!(mccaskill(b"GGGAAACCC", -10.0).is_err());
    }

    #[test]
    fn mccaskill_short_sequence() {
        let r = mccaskill(b"AUGC", 310.15).unwrap();
        // Too short for any pairs
        for i in 0..r.length {
            for j in 0..r.length {
                assert!(r.pair_probability(i, j).abs() < 1e-10);
            }
        }
    }

    // ── Structure comparison ──

    #[test]
    fn bp_distance_identical() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert_eq!(base_pair_distance(&a, &b).unwrap(), 0);
    }

    #[test]
    fn bp_distance_completely_different() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("........").unwrap();
        assert_eq!(base_pair_distance(&a, &b).unwrap(), 2); // 2 pairs in a, 0 in b
    }

    #[test]
    fn bp_distance_symmetric() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket(".((..).)").unwrap();
        assert_eq!(
            base_pair_distance(&a, &b).unwrap(),
            base_pair_distance(&b, &a).unwrap()
        );
    }

    #[test]
    fn bp_distance_counts_both_sides() {
        // a = {(0,7), (1,6)}, b = {(1,7), (2,5)}: no pair in common.
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket(".((..).)").unwrap();
        assert_eq!(base_pair_distance(&a, &b).unwrap(), 4);
    }

    #[test]
    fn bp_distance_different_lengths() {
        let a = RnaSecondaryStructure::from_dot_bracket("((..))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!(base_pair_distance(&a, &b).is_err());
    }

    #[test]
    fn mountain_distance_identical() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!(mountain_distance(&a, &b).unwrap().abs() < 1e-10);
    }

    #[test]
    fn mountain_distance_symmetric() {
        let a = RnaSecondaryStructure::from_dot_bracket("((....))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket(".((..).)").unwrap();
        let d1 = mountain_distance(&a, &b).unwrap();
        let d2 = mountain_distance(&b, &a).unwrap();
        assert!((d1 - d2).abs() < 1e-10);
        // Mountain vectors: [1,2,2,2,2,2,1,0] vs [0,1,2,2,2,1,1,0].
        assert!((d1 - 3.0).abs() < 1e-10, "mountain distance = {}", d1);
    }

    #[test]
    fn mountain_distance_known_value() {
        // Mountain vectors: [1,2,3,3,3,3,2,1,0] vs [1,2,2,2,2,1,0,0,0].
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("((...))..").unwrap();
        let d = mountain_distance(&a, &b).unwrap();
        assert!((d - 8.0).abs() < 1e-10, "mountain distance = {}", d);
    }

    #[test]
    fn mountain_distance_nonnegative() {
        let a = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("((...))..").unwrap();
        assert!(mountain_distance(&a, &b).unwrap() >= 0.0);
    }

    #[test]
    fn mountain_distance_different_lengths() {
        let a = RnaSecondaryStructure::from_dot_bracket("((..))").unwrap();
        let b = RnaSecondaryStructure::from_dot_bracket("(((...)))").unwrap();
        assert!(mountain_distance(&a, &b).is_err());
    }

    // ── Energy model helpers ──

    #[test]
    fn can_pair_valid() {
        assert!(can_pair(b'A', b'U'));
        assert!(can_pair(b'U', b'A'));
        assert!(can_pair(b'G', b'C'));
        assert!(can_pair(b'C', b'G'));
        assert!(can_pair(b'G', b'U'));
        assert!(can_pair(b'U', b'G'));
    }

    #[test]
    fn can_pair_invalid() {
        assert!(!can_pair(b'A', b'A'));
        assert!(!can_pair(b'A', b'C'));
        assert!(!can_pair(b'A', b'G'));
        assert!(!can_pair(b'C', b'U'));
    }

    #[test]
    fn stacking_energy_values() {
        // CG closing, CG enclosed: STACKING[3][3] = -3.4
        let e = stacking_energy(b'C', b'G', b'C', b'G');
        assert!((e - (-3.4)).abs() < 1e-10, "CG/CG stack = {}", e);
        // GC closing, GC enclosed: STACKING[2][2] = -3.3
        let e2 = stacking_energy(b'G', b'C', b'G', b'C');
        assert!((e2 - (-3.3)).abs() < 1e-10, "GC/GC stack = {}", e2);
        // AU closing, UA enclosed: STACKING[0][1] = -1.1
        let e3 = stacking_energy(b'A', b'U', b'U', b'A');
        assert!((e3 - (-1.1)).abs() < 1e-10, "AU/UA stack = {}", e3);
    }

    #[test]
    fn normalize_rna_dna_input() {
        let r = normalize_rna(b"ATGC").unwrap();
        assert_eq!(r, b"AUGC");
    }

    #[test]
    fn normalize_rna_lowercase() {
        let r = normalize_rna(b"augc").unwrap();
        assert_eq!(r, b"AUGC");
    }

    #[test]
    fn normalize_rna_invalid() {
        assert!(normalize_rna(b"AXGC").is_err());
    }
}