orphos_core/training/
mod.rs

1//! Training algorithms for gene prediction models.
2//!
3//! This module implements the unsupervised machine learning algorithms that
4//! train Orphos's statistical models from genome sequences.
5//!
6//! ## Overview
7//!
8//! Training extracts statistical patterns from genes predicted in an initial pass:
9//!
10//! 1. **Initial gene finding**: Find high-confidence genes using basic models
11//! 2. **Codon usage**: Calculate dicodon frequencies in predicted genes
12//! 3. **Start codon preference**: Learn ATG/GTG/TTG usage patterns
13//! 4. **RBS detection**: Identify ribosome binding site motifs (Shine-Dalgarno)
14//! 5. **Upstream composition**: Analyze nucleotide patterns near start codons
15//! 6. **GC bias**: Detect reading frame preferences based on GC content
16//!
17//! ## Training Modes
18//!
19//! - **Shine-Dalgarno (SD)**: For organisms with canonical RBS motifs
20//! - **Non-SD**: For organisms without RBS or with alternative start recognition
21//!
22//! The mode is auto-detected based on the strength of SD signals in the training data.
23//!
24//! ## Modules
25//!
26//! - [`sd_training`]: Shine-Dalgarno motif training
27//! - [`non_sd_training`]: Alternative start recognition training
28//! - [`common`]: Shared training utilities
29//!
30//! ## Examples
31//!
32//! Training is normally performed automatically by the `OrphosAnalyzer`, but
33//! can be done manually for advanced use cases:
34//!
35//! ```rust,no_run
36//! use orphos_core::engine::UntrainedOrphos;
37//! use orphos_core::config::OrphosConfig;
38//! use orphos_core::sequence::encoded::EncodedSequence;
39//!
40//! let mut orphos = UntrainedOrphos::new();
41//! let sequence = b"ATGAAACGCATTAGCACCACCATT...";
42//! let encoded = EncodedSequence::without_masking(sequence);
43//!
44//! // Train on the genome
45//! let trained = orphos.train_single_genome(&encoded)?;
46//!
47//! // Training data is now stored in the TrainedOrphos instance
48//! # Ok::<(), orphos_core::types::OrphosError>(())
49//! ```
50
51pub mod common;
52pub mod non_sd_training;
53pub mod sd_training;
54
55use bio::bio_types::strand::Strand;
56use rayon::prelude::*;
57
58use crate::{
59    constants::{
60        GC_HIGH_AT_FREQ, GC_HIGH_GC_FREQ, GC_LOW_AT_FREQ, GC_LOW_GC_FREQ, GENE_RATIO_THRESHOLD,
61        HIGH_GC_FREQ, INITIAL_SCORE_THRESHOLD, LOW_GC_FREQ, MAX_GC_CONTENT, MAX_MOTIF_INDEX,
62        MAX_TRAINING_ITERATIONS_NONSD, MAX_TRAINING_ITERATIONS_SD, MIN_GC_CONTENT,
63        MIN_MOTIF_LENGTH, NUM_BASES, NUM_CODON_TYPES, NUM_MOTIF_SIZES, NUM_RBS_WEIGHTS,
64        RBS_WEIGHT_STRONG_THRESHOLD, RBS_WEIGHT_THRESHOLD_HIGH, RBS_WEIGHT_THRESHOLD_LOW,
65        THRESHOLD_DIVISOR, UPSTREAM_END_POS, UPSTREAM_MOTIF_COVERAGE_THRESHOLD, UPSTREAM_POSITIONS,
66        UPSTREAM_SKIP_END, UPSTREAM_START_POS, WEIGHT_CLAMP_MAX, WEIGHT_CLAMP_MIN,
67    },
68    sequence::calculate_kmer_index,
69    types::{CodonType, MotifWeights, Node, OrphosError, Training},
70};
71
72/// Checks if training should use Shine-Dalgarno motifs.
73///
74/// Analyzes the RBS weights learned during training to determine whether
75/// the organism uses canonical Shine-Dalgarno ribosome binding sites.
76///
77/// # Arguments
78///
79/// * `training` - Training data with RBS weights
80///
81/// # Returns
82///
83/// `true` if SD motifs should be used, `false` for non-SD start recognition.
84///
85/// # Algorithm
86///
87/// The decision is based on:
88/// - Strength of canonical SD motifs (AGGAGG variants)
89/// - Absence of strong SD signals indicates non-SD organism
90/// - Threshold-based classification using empirically determined cutoffs
91#[must_use]
92pub fn should_use_sd(training: &Training) -> bool {
93    if training.rbs_weights[0] >= 0.0 {
94        return false;
95    }
96    if training.rbs_weights[16] < RBS_WEIGHT_THRESHOLD_HIGH
97        && training.rbs_weights[13] < RBS_WEIGHT_THRESHOLD_HIGH
98        && training.rbs_weights[15] < RBS_WEIGHT_THRESHOLD_HIGH
99        && (training.rbs_weights[0] >= RBS_WEIGHT_THRESHOLD_LOW
100            || (training.rbs_weights[22] < RBS_WEIGHT_STRONG_THRESHOLD
101                && training.rbs_weights[24] < RBS_WEIGHT_STRONG_THRESHOLD
102                && training.rbs_weights[27] < RBS_WEIGHT_STRONG_THRESHOLD))
103    {
104        return false;
105    }
106    true
107}
108
109fn count_upstream_composition(
110    sequence: &[u8],
111    sequence_length: usize,
112    strand: Strand,
113    position: usize,
114    training: &mut Training,
115) {
116    let start_position = match strand {
117        Strand::Forward => position,
118        Strand::Reverse => sequence_length - 1 - position,
119        Strand::Unknown => unreachable!(),
120    };
121
122    let mut count = 0;
123    for upstream_index in UPSTREAM_START_POS..UPSTREAM_END_POS {
124        if upstream_index > 2 && upstream_index < UPSTREAM_SKIP_END {
125            continue;
126        }
127
128        if start_position >= upstream_index {
129            let base_index = calculate_kmer_index(1, sequence, start_position - upstream_index);
130            training.upstream_composition[count][base_index] += 1.0;
131        }
132        count += 1;
133    }
134}
135
136/// Convert upstream composition counts to log likelihood scores
137fn convert_upstream_composition_to_log_scores(training: &mut Training) {
138    training
139        .upstream_composition
140        .par_iter_mut()
141        .take(UPSTREAM_POSITIONS)
142        .for_each(|position_data| {
143            let sum: f64 = position_data.iter().sum();
144
145            if sum > 0.0 {
146                for (j, value) in position_data.iter_mut().enumerate().take(NUM_BASES) {
147                    *value /= sum;
148
149                    let bg_freq = if training.gc_content > MIN_GC_CONTENT
150                        && training.gc_content < MAX_GC_CONTENT
151                    {
152                        if j == 0 || j == 3 {
153                            (1.0 - training.gc_content) / 2.0
154                        } else {
155                            training.gc_content / 2.0
156                        }
157                    } else if training.gc_content <= MIN_GC_CONTENT {
158                        if j == 0 || j == 3 {
159                            LOW_GC_FREQ
160                        } else {
161                            HIGH_GC_FREQ
162                        }
163                    } else if j == 0 || j == 3 {
164                        HIGH_GC_FREQ
165                    } else {
166                        LOW_GC_FREQ
167                    };
168
169                    *value = (*value / bg_freq)
170                        .ln()
171                        .clamp(WEIGHT_CLAMP_MIN, WEIGHT_CLAMP_MAX);
172                }
173            } else {
174                *position_data = [0.0; NUM_BASES];
175            }
176        });
177}
178
179/// Update motif counts during training
180fn update_motif_counts(
181    motif_counts: &mut MotifWeights,
182    zero_motif_count: &mut f64,
183    sequence: &[u8],
184    reverse_sequence: &[u8],
185    sequence_length: usize,
186    node: &Node,
187    stage: usize,
188) {
189    if node.position.codon_type == CodonType::Stop || node.position.is_edge {
190        return;
191    }
192
193    let is_zero = node.motif_info.best_motif.length == 0;
194    if is_zero {
195        *zero_motif_count += 1.0;
196        return;
197    }
198
199    let (working_sequence, start_position) = match node.position.strand {
200        Strand::Forward => (sequence, node.position.index),
201        Strand::Reverse => (reverse_sequence, sequence_length - 1 - node.position.index),
202        Strand::Unknown => unreachable!(),
203    };
204
205    match stage {
206        0 => {
207            for (motif_length_index, motif_counts_length) in
208                motif_counts.iter_mut().enumerate().take(NUM_MOTIF_SIZES)
209            {
210                let motif_length = motif_length_index + MIN_MOTIF_LENGTH;
211                let j_start = start_position as isize - 18 - motif_length_index as isize;
212                let j_end = start_position as isize - 6 - motif_length_index as isize;
213                for j in j_start..=j_end {
214                    if j < 0 {
215                        continue;
216                    }
217                    let motif_index =
218                        calculate_kmer_index(motif_length, working_sequence, j as usize);
219                    for spacer_data in motif_counts_length.iter_mut().take(NUM_MOTIF_SIZES) {
220                        spacer_data[motif_index] += 1.0;
221                    }
222                }
223            }
224        }
225        1 => {
226            let motif = &node.motif_info.best_motif;
227            motif_counts[motif.length - MIN_MOTIF_LENGTH][motif.space_index][motif.index] += 1.0;
228
229            for (submotif_length_index, motif_counts_length) in motif_counts
230                .iter_mut()
231                .enumerate()
232                .take(motif.length - MIN_MOTIF_LENGTH)
233            {
234                let submotif_length = submotif_length_index + MIN_MOTIF_LENGTH;
235                let j_start = start_position as isize - (motif.spacer + motif.length) as isize;
236                let j_end = start_position as isize - (motif.spacer + submotif_length) as isize;
237                for j in j_start..=j_end {
238                    if j < 0 {
239                        continue;
240                    }
241                    let spacer_index =
242                        get_spacer_index(j as usize, start_position, submotif_length_index);
243                    let motif_index =
244                        calculate_kmer_index(submotif_length, working_sequence, j as usize);
245                    motif_counts_length[spacer_index][motif_index] += 1.0;
246                }
247            }
248        }
249        2 => {
250            let motif = &node.motif_info.best_motif;
251            motif_counts[motif.length - MIN_MOTIF_LENGTH][motif.space_index][motif.index] += 1.0;
252        }
253        _ => {}
254    }
255}
256
257#[allow(clippy::needless_range_loop)]
258fn build_coverage_map(
259    real_motifs: &[[[f64; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES],
260    good_motifs: &mut [[[i32; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES],
261    number_of_genes: f64,
262    _stage: usize,
263) {
264    let threshold = UPSTREAM_MOTIF_COVERAGE_THRESHOLD;
265
266    // Initialize all as not good
267    *good_motifs = [[[0; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES];
268
269    // 3-base motifs: mark as good if above threshold
270    for spacer_index in 0..NUM_MOTIF_SIZES {
271        for motif_index in 0..64 {
272            if real_motifs[0][spacer_index][motif_index] / number_of_genes >= threshold {
273                for alternative_spacer_index in 0..NUM_MOTIF_SIZES {
274                    good_motifs[0][alternative_spacer_index][motif_index] = 1;
275                }
276            }
277        }
278    }
279
280    // 4-base motifs: must contain two valid 3-base motifs
281    for spacer_index in 0..NUM_MOTIF_SIZES {
282        for motif_index in 0..256 {
283            let decomposition_0 = (motif_index & 252) >> 2;
284            let decomposition_1 = motif_index & 63;
285            if good_motifs[0][spacer_index][decomposition_0] != 0
286                && good_motifs[0][spacer_index][decomposition_1] != 0
287            {
288                good_motifs[1][spacer_index][motif_index] = 1;
289            }
290        }
291    }
292
293    // 5-base motifs: interior mismatch allowed if entire 5-base motif
294    // represents 3 valid 3-base motifs (if mismatch converted)
295    for spacer_index in 0..NUM_MOTIF_SIZES {
296        for motif_index in 0..1024 {
297            let decomp0 = (motif_index & 1008) >> 4; // top 3 bases
298            let decomp1 = (motif_index & 252) >> 2; // middle 3 bases
299            let decomp2 = motif_index & 63; // bottom 3 bases
300            if good_motifs[0][spacer_index][decomp0] == 0
301                || good_motifs[0][spacer_index][decomp1] == 0
302                || good_motifs[0][spacer_index][decomp2] == 0
303            {
304                continue;
305            }
306            good_motifs[2][spacer_index][motif_index] = 1;
307
308            let mut tmp = motif_index;
309            for k in (0..=16).step_by(16) {
310                tmp ^= k;
311                for l in (0..=32).step_by(32) {
312                    tmp ^= l;
313                    if good_motifs[2][spacer_index][tmp] == 0 {
314                        good_motifs[2][spacer_index][tmp] = 2; // good with mismatch
315                    }
316                }
317            }
318        }
319    }
320
321    // 6-base motifs: must contain two valid 5-base motifs
322    for spacer_index in 0..NUM_MOTIF_SIZES {
323        for motif_index in 0..4096 {
324            let decomp0 = (motif_index & 4092) >> 2; // top 5 bases
325            let decomp1 = motif_index & 1023; // bottom 5 bases
326            if good_motifs[2][spacer_index][decomp0] == 0
327                || good_motifs[2][spacer_index][decomp1] == 0
328            {
329                continue;
330            }
331            if good_motifs[2][spacer_index][decomp0] == 1
332                && good_motifs[2][spacer_index][decomp1] == 1
333            {
334                good_motifs[3][spacer_index][motif_index] = 1;
335            } else {
336                good_motifs[3][spacer_index][motif_index] = 2; // good with mismatch
337            }
338        }
339    }
340}
341
342/// Update motif weights based on real vs background counts
343fn update_motif_weights(
344    motif_real: &[[[f64; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES],
345    motif_background: &[[[f64; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES],
346    zero_motif_real: f64,
347    zero_motif_background: f64,
348    motif_good: &[[[i32; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES],
349    stage: usize,
350    training: &mut Training,
351) {
352    // Stage is unused in the weight update after aligning with the C logic
353    // (zbg accumulation no longer depends on stage).
354    let _ = stage;
355    // 1) sum = sum(mreal) + zreal
356    // 2) For bad motifs: zreal += mreal; zbg += mreal; set mreal=0 and mbg=0
357    // 3) Normalize mreal by sum and compute mot_wt = log(mreal/bg) or -4 if bg==0
358    // 4) zreal /= sum; no_mot = log(zreal/zbg) or -4 if zbg==0; clamp in [-4,4]
359
360    let mut sum_real = zero_motif_real;
361    for motif in motif_real.iter() {
362        for motif_row in motif.iter() {
363            for &value in motif_row.iter().take(MAX_MOTIF_INDEX) {
364                sum_real += value;
365            }
366        }
367    }
368
369    if sum_real == 0.0 {
370        // If no real counts, zero all motif weights and no_mot as in C
371        training
372            .motif_weights
373            .fill([[0.0; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]);
374        training.no_motif_weight = 0.0;
375        return;
376    }
377
378    // Make a mutable copy of background so we can zero out bad motifs (matches C)
379    let mut bg = *motif_background;
380
381    // counts of bad motifs to zbg), then zero their slots in the background matrix.
382    let mut zreal = zero_motif_real;
383    let mut zbg = zero_motif_background; // already normalized earlier
384
385    let mut _good_real_sum = 0.0f64;
386    let mut _bad_real_sum = 0.0f64;
387
388    // We'll accumulate a temporary array for normalized real frequencies of good motifs
389    let mut real_freqs = [[[0.0f64; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES];
390
391    for l in 0..NUM_MOTIF_SIZES {
392        for s in 0..NUM_MOTIF_SIZES {
393            for idx in 0..MAX_MOTIF_INDEX {
394                let r = motif_real[l][s][idx];
395                if motif_good[l][s][idx] == 0 {
396                    // Bad motif: fold raw real counts into zreal and zbg
397                    // (C code: zreal += mreal; zbg += mreal)
398                    zreal += r;
399                    zbg += r;
400                    _bad_real_sum += r;
401                    // Zero out background slot (so bg==0 triggers -4.0)
402                    bg[l][s][idx] = 0.0;
403                    // real_freqs remains 0 here
404                } else {
405                    // Good motif: keep and normalize by sum later
406                    _good_real_sum += r;
407                    real_freqs[l][s][idx] = r / sum_real;
408                }
409            }
410        }
411    }
412
413    for l in 0..NUM_MOTIF_SIZES {
414        for s in 0..NUM_MOTIF_SIZES {
415            for idx in 0..MAX_MOTIF_INDEX {
416                let rf = real_freqs[l][s][idx];
417                let b = bg[l][s][idx];
418                let w = if b != 0.0 {
419                    (rf / b).ln()
420                } else {
421                    WEIGHT_CLAMP_MIN
422                };
423                training.motif_weights[l][s][idx] = w.clamp(WEIGHT_CLAMP_MIN, WEIGHT_CLAMP_MAX);
424            }
425        }
426    }
427
428    let zreal_freq = zreal / sum_real;
429    let mut no_mot = if zbg != 0.0 {
430        (zreal_freq / zbg).ln()
431    } else {
432        WEIGHT_CLAMP_MIN
433    };
434    no_mot = no_mot.clamp(WEIGHT_CLAMP_MIN, WEIGHT_CLAMP_MAX);
435    training.no_motif_weight = no_mot;
436
437    // removed TAP motif_weights_debug
438}
439
/// Maps the distance between a motif and the start position to a spacer bin.
///
/// With `reach = position_index + motif_length_index`, the bins are tested in
/// this order:
///
/// * `3` when `reach + 16 <= start_position` (13-15bp spacer)
/// * `2` when `reach + 14 <= start_position` (11-12bp spacer)
/// * `1` when `reach + 7 >= start_position` (3-4bp spacer)
/// * `0` otherwise (5-10bp spacer)
const fn get_spacer_index(
    position_index: usize,
    start_position: usize,
    motif_length_index: usize,
) -> usize {
    let reach = position_index + motif_length_index;

    if reach + 16 <= start_position {
        return 3;
    }
    if reach + 14 <= start_position {
        return 2;
    }
    if reach + 7 >= start_position {
        return 1;
    }
    0
}
456
457pub fn load_training_file(_file: &str) -> Training {
458    Training::default()
459}
460
461pub const fn write_training_file(_file: &str, _training: &Training) -> Result<(), OrphosError> {
462    Ok(())
463}
464
/// Unit tests for the shared training helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Motif, NodeMotifInfo, NodePosition, NodeScores, NodeState};
    use bio::bio_types::strand::Strand;

    /// Forward-strand fixture shared by the sequence based tests.
    const FORWARD_SEQUENCE: &[u8] = b"ATGCGATGCGATGCGATGCGATGCGATGCGATGCGATGCGATGCG";
    /// Buffer handed in as the reverse-strand sequence.
    const REVERSE_SEQUENCE: &[u8] = b"CGATGCGATGCGATGCGATGCGATGCGATGCGATGCGATGCGATG";

    type MotifCounts = [[[f64; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES];

    /// Heap-allocated, all-zero motif count grid.
    fn zeroed_motif_counts() -> Box<MotifCounts> {
        Box::new([[[0.0; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES])
    }

    /// Sum of every entry of a motif count grid.
    fn total_motif_counts(counts: &MotifCounts) -> f64 {
        counts.iter().flatten().flatten().sum()
    }

    /// Sum of every upstream composition count.
    fn total_upstream_counts(training: &Training) -> f64 {
        training
            .upstream_composition
            .iter()
            .flat_map(|row| row.iter())
            .sum()
    }

    fn create_test_training() -> Training {
        Training {
            gc_content: 0.5,
            translation_table: 11,
            uses_shine_dalgarno: true,
            start_type_weights: [0.0; 3],
            rbs_weights: Box::new([0.0; 28]),
            upstream_composition: Box::new([[0.0; 4]; 32]),
            motif_weights: Box::new([[[0.0; 4096]; 4]; 4]),
            no_motif_weight: 0.0,
            start_weight_factor: 4.35,
            gc_bias_factors: [1.0; 3],
            gene_dicodon_table: Box::new([0.0; 4096]),
            total_dicodons: 0,
        }
    }

    fn create_test_node_with_motif(
        index: usize,
        strand: Strand,
        codon_type: CodonType,
        motif_length: usize,
        motif_index: usize,
    ) -> Node {
        Node {
            position: NodePosition {
                index,
                strand,
                codon_type,
                stop_value: (index + 100) as isize,
                is_edge: false,
            },
            scores: NodeScores {
                gc_content: 0.5,
                coding_score: 5.0,
                start_score: 2.0,
                ribosome_binding_score: 1.0,
                type_score: 1.5,
                upstream_score: 0.5,
                total_score: 10.0,
                gc_frame_scores: [1.0, 2.0, 3.0],
            },
            state: NodeState::default(),
            motif_info: NodeMotifInfo {
                ribosome_binding_sites: [0, 0],
                best_motif: Motif {
                    index: motif_index,
                    length: motif_length,
                    space_index: 0,
                    spacer: 8,
                    score: 2.0,
                },
            },
        }
    }

    #[test]
    fn test_should_use_sd_positive_rbs_weight() {
        let mut training = create_test_training();
        training.rbs_weights[0] = 1.0;

        assert!(!should_use_sd(&training));
    }

    #[test]
    fn test_should_use_sd_low_key_motifs() {
        let mut training = create_test_training();
        training.rbs_weights[0] = -1.0;
        training.rbs_weights[16] = 0.5;
        training.rbs_weights[13] = 0.5;
        training.rbs_weights[15] = 0.5;

        assert!(!should_use_sd(&training));
    }

    #[test]
    fn test_should_use_sd_high_key_motifs() {
        let mut training = create_test_training();
        training.rbs_weights[0] = -1.0;
        training.rbs_weights[16] = 2.0;
        training.rbs_weights[13] = 2.0;
        training.rbs_weights[15] = 2.0;

        assert!(should_use_sd(&training));
    }

    #[test]
    fn test_count_upstream_composition_forward() {
        let mut training = create_test_training();

        count_upstream_composition(
            FORWARD_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            Strand::Forward,
            20,
            &mut training,
        );

        // Some composition must have been counted.
        assert!(total_upstream_counts(&training) > 0.0);
    }

    #[test]
    fn test_count_upstream_composition_reverse() {
        let mut training = create_test_training();

        count_upstream_composition(
            FORWARD_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            Strand::Reverse,
            20,
            &mut training,
        );

        // Some composition must have been counted.
        assert!(total_upstream_counts(&training) > 0.0);
    }

    #[test]
    fn test_count_upstream_composition_edge_position() {
        let sequence = b"ATGCGATGC";
        let mut training = create_test_training();
        let position = 2; // Near start of sequence

        count_upstream_composition(
            sequence,
            sequence.len(),
            Strand::Forward,
            position,
            &mut training,
        );

        // Only offsets that do not reach before the sequence start can be
        // counted, so at most `position + 1` bases are tallied.
        assert!(total_upstream_counts(&training) <= (position + 1) as f64);
    }

    #[test]
    fn test_convert_upstream_composition_to_log_scores() {
        let mut training = create_test_training();

        training.upstream_composition[0] = [10.0, 5.0, 3.0, 2.0];
        training.upstream_composition[1] = [1.0, 1.0, 1.0, 1.0];

        convert_upstream_composition_to_log_scores(&mut training);

        assert_ne!(training.upstream_composition[0], [10.0, 5.0, 3.0, 2.0]);

        for position in &*training.upstream_composition {
            for &value in position {
                assert!(value >= WEIGHT_CLAMP_MIN);
                assert!(value <= WEIGHT_CLAMP_MAX);
            }
        }
    }

    #[test]
    fn test_convert_upstream_composition_zero_sum() {
        let mut training = create_test_training();

        training.upstream_composition[0] = [0.0, 0.0, 0.0, 0.0];

        convert_upstream_composition_to_log_scores(&mut training);

        assert_eq!(training.upstream_composition[0], [0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_update_motif_counts_stage_0() {
        let mut motif_counts = zeroed_motif_counts();
        let mut zero_motif_count = 0.0;

        let node = create_test_node_with_motif(20, Strand::Forward, CodonType::Atg, 4, 10);

        update_motif_counts(
            &mut motif_counts,
            &mut zero_motif_count,
            FORWARD_SEQUENCE,
            REVERSE_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            &node,
            0,
        );

        assert!(total_motif_counts(&motif_counts) > 0.0);
    }

    #[test]
    fn test_update_motif_counts_stage_2_counts_only_best_motif() {
        let mut motif_counts = zeroed_motif_counts();
        let mut zero_motif_count = 0.0;

        let node = create_test_node_with_motif(20, Strand::Forward, CodonType::Atg, 4, 10);

        update_motif_counts(
            &mut motif_counts,
            &mut zero_motif_count,
            FORWARD_SEQUENCE,
            REVERSE_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            &node,
            2,
        );

        // Exactly one count, at the best motif's own slot.
        assert_eq!(motif_counts[4 - MIN_MOTIF_LENGTH][0][10], 1.0);
        assert_eq!(total_motif_counts(&motif_counts), 1.0);
        assert_eq!(zero_motif_count, 0.0);
    }

    #[test]
    fn test_update_motif_counts_zero_motif() {
        let mut motif_counts = zeroed_motif_counts();
        let mut zero_motif_count = 0.0;

        let node = create_test_node_with_motif(20, Strand::Forward, CodonType::Atg, 0, 0);

        update_motif_counts(
            &mut motif_counts,
            &mut zero_motif_count,
            FORWARD_SEQUENCE,
            REVERSE_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            &node,
            0,
        );

        // Only the zero motif counter moves.
        assert_eq!(zero_motif_count, 1.0);
        assert_eq!(total_motif_counts(&motif_counts), 0.0);
    }

    #[test]
    fn test_update_motif_counts_stop_codon() {
        let mut motif_counts = zeroed_motif_counts();
        let mut zero_motif_count = 0.0;

        let node = create_test_node_with_motif(20, Strand::Forward, CodonType::Stop, 4, 10);

        update_motif_counts(
            &mut motif_counts,
            &mut zero_motif_count,
            FORWARD_SEQUENCE,
            REVERSE_SEQUENCE,
            FORWARD_SEQUENCE.len(),
            &node,
            0,
        );

        // Should not count anything for stop codons
        assert_eq!(total_motif_counts(&motif_counts), 0.0);
        assert_eq!(zero_motif_count, 0.0);
    }

    #[test]
    fn test_build_coverage_map() {
        let real_motifs = zeroed_motif_counts();
        let mut good_motifs = Box::new([[[0; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES]);
        // Stale marks from a previous round; the call must clear them.
        good_motifs[0][0][0] = 1;
        good_motifs[2][1][7] = 2;
        let number_of_genes = 100.0;

        build_coverage_map(&real_motifs, &mut good_motifs, number_of_genes, 0);

        // With no observed motifs, every entry must be "not good" (0).
        let has_good_motifs = good_motifs
            .iter()
            .flat_map(|l| l.iter())
            .flat_map(|s| s.iter())
            .any(|&val| val != 0);

        assert!(!has_good_motifs);
    }

    #[test]
    fn test_get_spacer_index() {
        let start_position = 100;
        let motif_length_index = 1;

        // Test different spacer distances
        assert_eq!(get_spacer_index(83, start_position, motif_length_index), 3); // 13-15bp
        assert_eq!(get_spacer_index(85, start_position, motif_length_index), 2); // 11-12bp
        assert_eq!(get_spacer_index(95, start_position, motif_length_index), 1); // 3-4bp
        assert_eq!(get_spacer_index(90, start_position, motif_length_index), 0); // 5-10bp
    }

    #[test]
    fn test_load_training_file() {
        let training = load_training_file("dummy_file.txt");

        // Should return default training
        assert_eq!(training.translation_table, 11);
        assert_eq!(training.gc_content, 0.5);
    }

    #[test]
    fn test_write_training_file() {
        let training = create_test_training();

        // Should not panic (placeholder implementation)
        let result = write_training_file("dummy_output.txt", &training);
        assert!(result.is_ok());
    }

    #[test]
    fn test_update_motif_weights_basic() {
        let motif_real = zeroed_motif_counts();
        let motif_background = zeroed_motif_counts();
        let zero_motif_real = 1.0;
        let zero_motif_background = 0.5;
        let motif_good = Box::new([[[1; MAX_MOTIF_INDEX]; NUM_MOTIF_SIZES]; NUM_MOTIF_SIZES]);
        let mut training = create_test_training();

        update_motif_weights(
            &motif_real,
            &motif_background,
            zero_motif_real,
            zero_motif_background,
            &motif_good,
            0,
            &mut training,
        );

        // Every motif is good but has zero background, so each weight is the
        // minimum.
        assert!(training
            .motif_weights
            .iter()
            .flatten()
            .flatten()
            .all(|&weight| weight == WEIGHT_CLAMP_MIN));

        // All real mass sits in the no-motif bucket: ln((1 / 1) / 0.5).
        let expected_no_motif = (zero_motif_real / zero_motif_real / zero_motif_background)
            .ln()
            .clamp(WEIGHT_CLAMP_MIN, WEIGHT_CLAMP_MAX);
        assert_eq!(training.no_motif_weight, expected_no_motif);
    }
}