orphos_core/node/
creation.rs

1use bio::bio_types::strand::Strand;
2
3use crate::{
4    constants::{
5        CODON_LENGTH, MINIMUM_EDGE_GENE_LENGTH, MINIMUM_GENE_LENGTH, READING_FRAMES, STT_NOD,
6    },
7    node::validation::{
8        check_start_codon, is_edge_gene, is_reverse_edge_gene, is_valid_gene, is_valid_reverse_gene,
9    },
10    sequence::{encoded::EncodedSequence, is_stop},
11    types::{CodonType, Mask, Node, NodePosition, OrphosError, Training},
12};
13
14/// Type alias for arrays indexed by reading frame
15type ReadingFrameArray<T> = [T; READING_FRAMES];
16
17/// Context for processing DNA strands containing all tracking state
18struct StrandProcessingContext {
19    last_stop_positions: ReadingFrameArray<usize>,
20    has_start_codon: ReadingFrameArray<bool>,
21    minimum_distances: ReadingFrameArray<usize>,
22    sequence_length: usize,
23    closed: bool,
24}
25
26impl StrandProcessingContext {
27    /// Create a new processing context with initialized arrays
28    fn new(sequence_length: usize, sequence_length_mod: usize, closed: bool) -> Self {
29        let mut context = Self {
30            last_stop_positions: ReadingFrameArray::default(),
31            has_start_codon: ReadingFrameArray::default(),
32            minimum_distances: ReadingFrameArray::default(),
33            sequence_length,
34            closed,
35        };
36
37        context.initialize_arrays(sequence_length_mod);
38        context
39    }
40
41    /// Initialize the tracking arrays for strand processing
42    fn initialize_arrays(&mut self, sequence_length_mod: usize) {
43        // Initialize arrays with different indexing patterns to match original C code
44        for i in 0..READING_FRAMES {
45            // Different indexing patterns to match C code exactly
46            self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] =
47                self.sequence_length + i; // Uses (i+slmod)%3 for last
48            self.has_start_codon[i % READING_FRAMES] = false; // Uses i%3 for saw_start  
49            self.minimum_distances[i % READING_FRAMES] = MINIMUM_EDGE_GENE_LENGTH; // Uses i%3 for min_dist
50
51            if !self.closed && self.sequence_length > 0 {
52                while self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] + 2
53                    > self.sequence_length - 1
54                {
55                    self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] = self
56                        .last_stop_positions[(i + sequence_length_mod) % READING_FRAMES]
57                        .saturating_sub(3);
58                }
59            }
60        }
61    }
62}
63
64/// Add nodes for start and stop codons in both directions
65///
66/// This function identifies potential start and stop codons in the sequence
67/// and creates corresponding nodes for gene prediction analysis.
68///
69/// # Arguments
70/// * `encoded_sequence` - The complete encoded sequence data
71/// * `nodes` - Vector to store the created nodes
72/// * `closed` - Whether to treat sequence as circular/closed
73/// * `training` - Training data for scoring parameters
74///
75/// # Returns
76/// Number of nodes created, or error if processing fails
77pub fn add_nodes(
78    encoded_sequence: &EncodedSequence,
79    nodes: &mut Vec<Node>,
80    closed: bool,
81    training: &Training,
82) -> Result<usize, OrphosError> {
83    let sequence_length = encoded_sequence.sequence_length;
84    // Clear the nodes vector
85    nodes.clear();
86    nodes.reserve(STT_NOD);
87
88    let sequence_length_mod = sequence_length % READING_FRAMES;
89
90    let mut context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);
91    process_forward_strand(
92        &encoded_sequence.forward_sequence,
93        nodes,
94        &mut context,
95        &encoded_sequence.masks,
96        training,
97    );
98    handle_remaining_starts(
99        &encoded_sequence.forward_sequence,
100        nodes,
101        &context,
102        Strand::Forward,
103        training,
104    );
105
106    let mut context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);
107    process_reverse_strand(
108        &encoded_sequence.reverse_complement_sequence,
109        nodes,
110        &mut context,
111        &encoded_sequence.masks,
112        training,
113    );
114    handle_remaining_starts(
115        &encoded_sequence.reverse_complement_sequence,
116        nodes,
117        &context,
118        Strand::Reverse,
119        training,
120    );
121
122    Ok(nodes.len())
123}
124
125/// Process the forward strand to identify potential genes
126///
127/// Scans the forward strand from 3' to 5' end, looking for start and stop codons
128/// and creating nodes for valid gene boundaries that meet length requirements.
129fn process_forward_strand(
130    encoded_sequence: &[u8],
131    nodes: &mut Vec<Node>,
132    context: &mut StrandProcessingContext,
133    masks: &[Mask],
134    training: &Training,
135) {
136    let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
137
138    for position_index in (0..=scanning_start_position).rev() {
139        let reading_frame_index = position_index % READING_FRAMES;
140
141        if is_stop(encoded_sequence, position_index, training) {
142            if context.has_start_codon[reading_frame_index] {
143                let node = create_stop_node(
144                    context.last_stop_positions[reading_frame_index],
145                    position_index as isize,
146                    Strand::Forward,
147                    encoded_sequence,
148                    training,
149                );
150                nodes.push(node);
151            }
152
153            context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
154            context.last_stop_positions[reading_frame_index] = position_index;
155            context.has_start_codon[reading_frame_index] = false;
156            continue;
157        }
158
159        if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
160            continue;
161        }
162
163        // Check for start codons with unified logic
164        if let Some(codon_type) = check_start_codon(encoded_sequence, position_index, training) {
165            if is_valid_gene(
166                position_index,
167                context.last_stop_positions[reading_frame_index],
168                context.minimum_distances[reading_frame_index],
169                masks,
170            ) {
171                let node = create_start_node(
172                    position_index,
173                    codon_type,
174                    context.last_stop_positions[reading_frame_index] as isize,
175                    Strand::Forward,
176                );
177                context.has_start_codon[reading_frame_index] = true;
178                nodes.push(node);
179            }
180        } else if is_edge_gene(
181            position_index,
182            context.last_stop_positions[reading_frame_index],
183            context.closed,
184            masks,
185        ) {
186            let node = create_edge_node(
187                position_index,
188                context.last_stop_positions[reading_frame_index] as isize,
189                Strand::Forward,
190            );
191            context.has_start_codon[reading_frame_index] = true;
192            nodes.push(node);
193        }
194    }
195}
196
197/// Process reverse strand to find genes
198fn process_reverse_strand(
199    reverse_complement_encoded_sequence: &[u8],
200    nodes: &mut Vec<Node>,
201    context: &mut StrandProcessingContext,
202    masks: &[Mask],
203    training: &Training,
204) {
205    let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
206
207    for position_index in (0..=scanning_start_position).rev() {
208        let reading_frame_index = position_index % READING_FRAMES;
209
210        if is_stop(
211            reverse_complement_encoded_sequence,
212            position_index,
213            training,
214        ) {
215            if context.has_start_codon[reading_frame_index] {
216                let node = create_reverse_stop_node(
217                    context.last_stop_positions[reading_frame_index],
218                    position_index as isize,
219                    context.sequence_length,
220                    reverse_complement_encoded_sequence,
221                    training,
222                );
223                nodes.push(node);
224            }
225
226            context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
227            context.last_stop_positions[reading_frame_index] = position_index;
228            context.has_start_codon[reading_frame_index] = false;
229            continue;
230        }
231
232        if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
233            continue;
234        }
235
236        // Check for start codons on reverse strand
237        if let Some(codon_type) = check_start_codon(
238            reverse_complement_encoded_sequence,
239            position_index,
240            training,
241        ) {
242            if is_valid_reverse_gene(
243                position_index,
244                context.last_stop_positions[reading_frame_index],
245                context.minimum_distances[reading_frame_index],
246                context.sequence_length,
247                masks,
248            ) {
249                let node = create_reverse_start_node(
250                    position_index,
251                    codon_type,
252                    context.last_stop_positions[reading_frame_index] as isize,
253                    context.sequence_length,
254                );
255
256                context.has_start_codon[reading_frame_index] = true;
257                nodes.push(node);
258            }
259        } else if is_reverse_edge_gene(
260            position_index,
261            context.last_stop_positions[reading_frame_index],
262            context.sequence_length,
263            context.closed,
264            masks,
265        ) {
266            let node = create_reverse_edge_node(
267                position_index,
268                context.last_stop_positions[reading_frame_index] as isize,
269                context.sequence_length,
270            );
271
272            context.has_start_codon[reading_frame_index] = true;
273            nodes.push(node);
274        }
275    }
276}
277
278/// Handle remaining starts at the end of strand processing
279fn handle_remaining_starts(
280    encoded_sequence: &[u8],
281    nodes: &mut Vec<Node>,
282    context: &StrandProcessingContext,
283    strand: Strand,
284    training: &Training,
285) {
286    for i in 0..READING_FRAMES {
287        if context.has_start_codon[i % READING_FRAMES] {
288            let (position_index, stop_value, is_edge) = match strand {
289                Strand::Forward => {
290                    let is_edge = !is_stop(
291                        encoded_sequence,
292                        context.last_stop_positions[i % READING_FRAMES],
293                        training,
294                    );
295                    // C code encodes last few forward stops with negative stop_val (i-6)
296                    // to allow earliest start connections across sequence edge.
297                    let stop_val = (i as isize) - 6;
298                    (
299                        context.last_stop_positions[i % READING_FRAMES],
300                        stop_val,
301                        is_edge,
302                    )
303                }
304                Strand::Reverse => {
305                    let is_edge = !is_stop(
306                        encoded_sequence,
307                        context.last_stop_positions[i % READING_FRAMES],
308                        training,
309                    );
310                    let position_index = context.sequence_length
311                        - context.last_stop_positions[i % READING_FRAMES]
312                        - 1;
313                    let stop_val = (context.sequence_length + 5 - i) as isize;
314                    (position_index, stop_val, is_edge)
315                }
316                Strand::Unknown => unreachable!("Unknown strand should not be processed"),
317            };
318
319            nodes.push(Node {
320                position: NodePosition {
321                    index: position_index,
322                    codon_type: CodonType::Stop,
323                    strand,
324                    is_edge,
325                    stop_value,
326                },
327                ..Node::default()
328            });
329        }
330    }
331}
332
333/// Create a stop node for forward strand
334fn create_stop_node(
335    index: usize,
336    stop_value: isize,
337    strand: Strand,
338    encoded_sequence: &[u8],
339    training: &Training,
340) -> Node {
341    let is_edge = !is_stop(encoded_sequence, index, training);
342
343    Node {
344        position: NodePosition {
345            index,
346            codon_type: CodonType::Stop,
347            strand,
348            is_edge,
349            stop_value,
350        },
351        ..Node::default()
352    }
353}
354
355/// Create a stop node for reverse strand
356fn create_reverse_stop_node(
357    index: usize,
358    stop_value: isize,
359    sequence_length: usize,
360    reverse_complement_encoded_sequence: &[u8],
361    training: &Training,
362) -> Node {
363    let is_edge = !is_stop(reverse_complement_encoded_sequence, index, training);
364
365    Node {
366        position: NodePosition {
367            index: sequence_length - index - 1,
368            codon_type: CodonType::Stop,
369            strand: Strand::Reverse,
370            is_edge,
371            stop_value: sequence_length as isize - stop_value - 1,
372        },
373        ..Node::default()
374    }
375}
376
377/// Create a start node
378fn create_start_node(
379    position_index: usize,
380    codon_type: CodonType,
381    stop_value: isize,
382    strand: Strand,
383) -> Node {
384    Node {
385        position: NodePosition {
386            index: position_index,
387            codon_type,
388            strand,
389            is_edge: false,
390            stop_value,
391        },
392        ..Node::default()
393    }
394}
395
396/// Create a reverse start node
397fn create_reverse_start_node(
398    position: usize,
399    codon_type: CodonType,
400    stop_value: isize,
401    sequence_length: usize,
402) -> Node {
403    Node {
404        position: NodePosition {
405            index: sequence_length - position - 1,
406            codon_type,
407            strand: Strand::Reverse,
408            is_edge: false,
409            stop_value: sequence_length as isize - stop_value - 1,
410        },
411        ..Node::default()
412    }
413}
414
415/// Create an edge node
416fn create_edge_node(index: usize, stop_value: isize, strand: Strand) -> Node {
417    Node {
418        position: NodePosition {
419            index,
420            codon_type: CodonType::Atg,
421            strand,
422            is_edge: true,
423            stop_value,
424        },
425        ..Node::default()
426    }
427}
428
429/// Create a reverse edge node
430fn create_reverse_edge_node(position: usize, stop_value: isize, sequence_length: usize) -> Node {
431    Node {
432        position: NodePosition {
433            index: sequence_length - position - 1,
434            codon_type: CodonType::Atg,
435            strand: Strand::Reverse,
436            is_edge: true,
437            stop_value: sequence_length as isize - stop_value - 1,
438        },
439        ..Node::default()
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sequence::encode_sequence;

    /// Input shared by the `add_nodes` smoke tests: `ATG AAA TAA` followed by `GTG AAA TAG`.
    const TWO_ORF_SEQUENCE: &[u8] = b"ATGAAATAAGTGAAATAG";

    /// Encode `input` into the packed byte form expected by the codon helpers.
    fn get_encoded_sequence(input: &[u8]) -> Vec<u8> {
        let sequence_length = input.len();
        let mut seq = vec![0u8; (sequence_length * 2).div_ceil(8)];
        let mut unknown_sequence = vec![0u8; sequence_length.div_ceil(8)];
        let mut masks = Vec::new();
        let _ = encode_sequence(input, &mut seq, &mut unknown_sequence, &mut masks, false).unwrap();
        seq
    }

    /// Build a full `EncodedSequence` without any masked regions.
    fn create_test_encoded_sequence(input: &[u8]) -> EncodedSequence {
        EncodedSequence::without_masking(input)
    }

    /// Training parameters with arbitrary but fixed values (translation table 11).
    fn create_test_training() -> Training {
        Training {
            gc_content: 0.5,
            translation_table: 11,
            uses_shine_dalgarno: true,
            start_type_weights: [1.0, 2.0, 3.0],
            rbs_weights: Box::new([1.0; 28]),
            upstream_composition: Box::new([[0.25; 4]; 32]),
            motif_weights: Box::new([[[1.0; 4096]; 4]; 4]),
            no_motif_weight: 0.5,
            start_weight_factor: 4.35,
            gc_bias_factors: [1.0; 3],
            gene_dicodon_table: Box::new([0.0; 4096]),
            total_dicodons: 0,
        }
    }

    /// Context for a six-base sequence with hand-picked stop positions and start flags.
    fn context_with_starts(has_start_codon: [bool; READING_FRAMES]) -> StrandProcessingContext {
        let mut context = StrandProcessingContext::new(6, 0, false);
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = has_start_codon;
        context
    }

    #[test]
    fn test_add_nodes_closed_sequence() {
        let encoded_sequence = create_test_encoded_sequence(TWO_ORF_SEQUENCE);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, true, &training);

        assert!(result.is_ok());
        // The reported count must match what was actually stored.
        assert_eq!(result.ok(), Some(nodes.len()));
    }

    #[test]
    fn test_add_nodes_open_sequence() {
        let encoded_sequence = create_test_encoded_sequence(TWO_ORF_SEQUENCE);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, &training);

        assert!(result.is_ok());
        assert_eq!(result.ok(), Some(nodes.len()));
    }

    #[test]
    fn test_add_nodes_clears_previous_nodes() {
        let encoded_sequence = create_test_encoded_sequence(TWO_ORF_SEQUENCE);
        let training = create_test_training();

        let mut fresh = Vec::new();
        let fresh_count = add_nodes(&encoded_sequence, &mut fresh, true, &training).ok();

        // Stale content must be discarded rather than appended to.
        let mut reused = vec![Node::default(), Node::default()];
        let reused_count = add_nodes(&encoded_sequence, &mut reused, true, &training).ok();

        assert_eq!(fresh_count, reused_count);
        assert_eq!(fresh.len(), reused.len());
    }

    #[test]
    fn test_initialize_strand_arrays() {
        let sequence_length = 100;
        let sequence_length_mod = 1;
        let closed = false;

        let context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
        // Open sequence: each seed is stepped back by whole codons until the
        // codon starting there fits, so the frame of every slot is preserved.
        assert_eq!(context.last_stop_positions, [96, 97, 95]);
        for (frame, &val) in context.last_stop_positions.iter().enumerate() {
            assert_eq!(val % READING_FRAMES, frame);
            assert!(val + CODON_LENGTH <= sequence_length);
        }
    }

    #[test]
    fn test_initialize_strand_arrays_closed() {
        let sequence_length = 10;
        let sequence_length_mod = 0;
        let closed = true;

        let context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
        // Closed sequence: the seeds stay just past the end of the sequence.
        assert_eq!(context.last_stop_positions, [10, 11, 12]);
    }

    #[test]
    fn test_create_start_node() {
        let node = create_start_node(100, CodonType::Atg, 200, Strand::Forward);

        assert_eq!(node.position.index, 100);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 200);
        assert!(!node.position.is_edge);
    }

    #[test]
    fn test_create_reverse_start_node() {
        let sequence_length = 1000;
        let node = create_reverse_start_node(100, CodonType::Gtg, 200, sequence_length);

        assert_eq!(node.position.index, sequence_length - 100 - 1);
        assert_eq!(node.position.codon_type, CodonType::Gtg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 200 - 1) as isize
        );
        assert!(!node.position.is_edge);
    }

    #[test]
    fn test_create_edge_node() {
        let node = create_edge_node(50, 150, Strand::Forward);

        assert_eq!(node.position.index, 50);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 150);
        assert!(node.position.is_edge);
    }

    #[test]
    fn test_create_reverse_edge_node() {
        let sequence_length = 1000;
        let node = create_reverse_edge_node(50, 150, sequence_length);

        assert_eq!(node.position.index, sequence_length - 50 - 1);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 150 - 1) as isize
        );
        assert!(node.position.is_edge);
    }

    #[test]
    fn test_create_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();

        let node = create_stop_node(0, 3, Strand::Forward, &encoded_seq, &training);

        assert_eq!(node.position.index, 0);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 3);
    }

    #[test]
    fn test_create_reverse_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();
        let sequence_length = sequence.len();

        let node = create_reverse_stop_node(0, 3, sequence_length, &encoded_seq, &training);

        assert_eq!(node.position.index, sequence_length - 1);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(node.position.stop_value, (sequence_length - 3 - 1) as isize);
    }

    #[test]
    fn test_handle_remaining_starts_forward() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        // Frames 0 and 2 still have an open start.
        let context = context_with_starts([true, false, true]);

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert_eq!(nodes.len(), 2);
        for node in &nodes {
            assert_eq!(node.position.strand, Strand::Forward);
            assert_eq!(node.position.codon_type, CodonType::Stop);
        }
        // Forward nodes sit at the recorded stop and carry `frame - 6` as stop value.
        assert_eq!(nodes[0].position.index, 0);
        assert_eq!(nodes[0].position.stop_value, -6);
        assert_eq!(nodes[1].position.index, 6);
        assert_eq!(nodes[1].position.stop_value, -4);
    }

    #[test]
    fn test_handle_remaining_starts_reverse() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        // Only frame 1 still has an open start.
        let context = context_with_starts([false, true, false]);

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Reverse,
            &training,
        );

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].position.strand, Strand::Reverse);
        assert_eq!(nodes[0].position.codon_type, CodonType::Stop);
        // Index is mirrored (6 - 3 - 1); stop value is `sequence_length + 5 - frame`.
        assert_eq!(nodes[0].position.index, 2);
        assert_eq!(nodes[0].position.stop_value, 10);
    }

    #[test]
    fn test_handle_remaining_starts_no_starts() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let context = context_with_starts([false, false, false]);

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert!(nodes.is_empty());
    }

    #[test]
    fn test_add_nodes_short_sequence() {
        let sequence = b"ATG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, &training);

        assert!(result.is_ok());
        assert_eq!(result.ok(), Some(nodes.len()));
    }
}