1use bio::bio_types::strand::Strand;
2
3use crate::{
4 constants::{
5 CODON_LENGTH, MINIMUM_EDGE_GENE_LENGTH, MINIMUM_GENE_LENGTH, READING_FRAMES, STT_NOD,
6 },
7 node::validation::{
8 check_start_codon, is_edge_gene, is_reverse_edge_gene, is_valid_gene, is_valid_reverse_gene,
9 },
10 sequence::{encoded::EncodedSequence, is_stop},
11 types::{CodonType, Mask, Node, NodePosition, OrphosError, Training},
12};
13
/// Fixed-size array holding one value of type `T` per reading frame
/// (`READING_FRAMES` entries); indexed by `position % READING_FRAMES`.
type ReadingFrameArray<T> = [T; READING_FRAMES];
16
/// Mutable bookkeeping carried through one strand's right-to-left scan.
///
/// The three arrays hold one entry per reading frame and are indexed by
/// `position % READING_FRAMES`.
struct StrandProcessingContext {
    /// Most recent stop position recorded for each frame. Starts as a value
    /// derived from `sequence_length` and is overwritten whenever the scan
    /// meets a stop codon in that frame.
    last_stop_positions: ReadingFrameArray<usize>,
    /// Whether a start (or edge) node has been emitted in the frame since the
    /// last stop was recorded; cleared again at every stop.
    has_start_codon: ReadingFrameArray<bool>,
    /// Minimum length handed to the gene validity check for each frame:
    /// `MINIMUM_EDGE_GENE_LENGTH` initially, `MINIMUM_GENE_LENGTH` once a real
    /// stop has been seen in the frame.
    minimum_distances: ReadingFrameArray<usize>,
    /// Length of the sequence being scanned.
    sequence_length: usize,
    /// When `true`, the initial stop positions are not pulled back inside the
    /// sequence and edge-gene checks are told the ends are closed.
    closed: bool,
    /// When `true`, initial stops are seeded from the start of the sequence
    /// and no trailing stop nodes are added after the scan.
    circular: bool,
}
26
27impl StrandProcessingContext {
28 fn new(
30 sequence_length: usize,
31 sequence_length_mod: usize,
32 closed: bool,
33 circular: bool,
34 ) -> Self {
35 let mut context = Self {
36 last_stop_positions: ReadingFrameArray::default(),
37 has_start_codon: ReadingFrameArray::default(),
38 minimum_distances: ReadingFrameArray::default(),
39 sequence_length,
40 closed,
41 circular,
42 };
43
44 context.initialize_arrays(sequence_length_mod);
45 context
46 }
47
48 fn initialize_arrays(&mut self, sequence_length_mod: usize) {
50 for i in 0..READING_FRAMES {
52 self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] =
54 self.sequence_length + i; self.has_start_codon[i % READING_FRAMES] = false; self.minimum_distances[i % READING_FRAMES] = MINIMUM_EDGE_GENE_LENGTH; if !self.closed && self.sequence_length > 0 {
59 while self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] + 2
60 > self.sequence_length - 1
61 {
62 self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] = self
63 .last_stop_positions[(i + sequence_length_mod) % READING_FRAMES]
64 .saturating_sub(3);
65 }
66 }
67 }
68 }
69
70 fn seed_circular_stops(&mut self, encoded_sequence: &[u8], training: &Training) {
71 if !self.circular || self.sequence_length < CODON_LENGTH {
72 return;
73 }
74
75 let mut first_stop_per_frame: [Option<usize>; READING_FRAMES] = [None; READING_FRAMES];
76 let scanning_start_position = self.sequence_length - CODON_LENGTH;
77
78 for position_index in 0..=scanning_start_position {
79 if is_stop(encoded_sequence, position_index, training) {
80 let frame = position_index % READING_FRAMES;
81 if first_stop_per_frame[frame].is_none() {
82 first_stop_per_frame[frame] = Some(position_index);
83 }
84 }
85 }
86
87 for (frame, first_stop) in first_stop_per_frame.iter().enumerate() {
88 if let Some(stop_position) = first_stop {
89 self.last_stop_positions[frame] = *stop_position;
90 self.minimum_distances[frame] = MINIMUM_GENE_LENGTH;
91 }
92 }
93 }
94}
95
96pub fn add_nodes(
111 encoded_sequence: &EncodedSequence,
112 nodes: &mut Vec<Node>,
113 closed: bool,
114 circular: bool,
115 training: &Training,
116) -> Result<usize, OrphosError> {
117 let sequence_length = encoded_sequence.sequence_length;
118 nodes.clear();
120 nodes.reserve(STT_NOD);
121
122 let sequence_length_mod = sequence_length % READING_FRAMES;
123
124 let mut context =
125 StrandProcessingContext::new(sequence_length, sequence_length_mod, closed, circular);
126 context.seed_circular_stops(&encoded_sequence.forward_sequence, training);
127 process_forward_strand(
128 &encoded_sequence.forward_sequence,
129 nodes,
130 &mut context,
131 &encoded_sequence.masks,
132 training,
133 );
134 handle_remaining_starts(
135 &encoded_sequence.forward_sequence,
136 nodes,
137 &context,
138 Strand::Forward,
139 training,
140 );
141
142 let mut context =
143 StrandProcessingContext::new(sequence_length, sequence_length_mod, closed, circular);
144 context.seed_circular_stops(&encoded_sequence.reverse_complement_sequence, training);
145 process_reverse_strand(
146 &encoded_sequence.reverse_complement_sequence,
147 nodes,
148 &mut context,
149 &encoded_sequence.masks,
150 training,
151 );
152 handle_remaining_starts(
153 &encoded_sequence.reverse_complement_sequence,
154 nodes,
155 &context,
156 Strand::Reverse,
157 training,
158 );
159
160 Ok(nodes.len())
161}
162
163fn process_forward_strand(
168 encoded_sequence: &[u8],
169 nodes: &mut Vec<Node>,
170 context: &mut StrandProcessingContext,
171 masks: &[Mask],
172 training: &Training,
173) {
174 let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
175
176 for position_index in (0..=scanning_start_position).rev() {
177 let reading_frame_index = position_index % READING_FRAMES;
178
179 if is_stop(encoded_sequence, position_index, training) {
180 if context.has_start_codon[reading_frame_index] {
181 let node = create_stop_node(
182 context.last_stop_positions[reading_frame_index],
183 position_index as isize,
184 Strand::Forward,
185 encoded_sequence,
186 training,
187 );
188 nodes.push(node);
189 }
190
191 context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
192 context.last_stop_positions[reading_frame_index] = position_index;
193 context.has_start_codon[reading_frame_index] = false;
194 continue;
195 }
196
197 if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
198 continue;
199 }
200
201 if let Some(codon_type) = check_start_codon(encoded_sequence, position_index, training) {
203 if is_valid_gene(
204 position_index,
205 context.last_stop_positions[reading_frame_index],
206 context.minimum_distances[reading_frame_index],
207 context.sequence_length,
208 context.circular,
209 masks,
210 ) {
211 let node = create_start_node(
212 position_index,
213 codon_type,
214 context.last_stop_positions[reading_frame_index] as isize,
215 Strand::Forward,
216 );
217 context.has_start_codon[reading_frame_index] = true;
218 nodes.push(node);
219 }
220 } else if is_edge_gene(
221 position_index,
222 context.last_stop_positions[reading_frame_index],
223 context.closed || context.circular,
224 masks,
225 ) {
226 let node = create_edge_node(
227 position_index,
228 context.last_stop_positions[reading_frame_index] as isize,
229 Strand::Forward,
230 );
231 context.has_start_codon[reading_frame_index] = true;
232 nodes.push(node);
233 }
234 }
235}
236
237fn process_reverse_strand(
239 reverse_complement_encoded_sequence: &[u8],
240 nodes: &mut Vec<Node>,
241 context: &mut StrandProcessingContext,
242 masks: &[Mask],
243 training: &Training,
244) {
245 let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
246
247 for position_index in (0..=scanning_start_position).rev() {
248 let reading_frame_index = position_index % READING_FRAMES;
249
250 if is_stop(
251 reverse_complement_encoded_sequence,
252 position_index,
253 training,
254 ) {
255 if context.has_start_codon[reading_frame_index] {
256 let node = create_reverse_stop_node(
257 context.last_stop_positions[reading_frame_index],
258 position_index as isize,
259 context.sequence_length,
260 reverse_complement_encoded_sequence,
261 training,
262 );
263 nodes.push(node);
264 }
265
266 context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
267 context.last_stop_positions[reading_frame_index] = position_index;
268 context.has_start_codon[reading_frame_index] = false;
269 continue;
270 }
271
272 if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
273 continue;
274 }
275
276 if let Some(codon_type) = check_start_codon(
278 reverse_complement_encoded_sequence,
279 position_index,
280 training,
281 ) {
282 if is_valid_reverse_gene(
283 position_index,
284 context.last_stop_positions[reading_frame_index],
285 context.minimum_distances[reading_frame_index],
286 context.sequence_length,
287 context.circular,
288 masks,
289 ) {
290 let node = create_reverse_start_node(
291 position_index,
292 codon_type,
293 context.last_stop_positions[reading_frame_index] as isize,
294 context.sequence_length,
295 );
296
297 context.has_start_codon[reading_frame_index] = true;
298 nodes.push(node);
299 }
300 } else if is_reverse_edge_gene(
301 position_index,
302 context.last_stop_positions[reading_frame_index],
303 context.sequence_length,
304 context.closed || context.circular,
305 masks,
306 ) {
307 let node = create_reverse_edge_node(
308 position_index,
309 context.last_stop_positions[reading_frame_index] as isize,
310 context.sequence_length,
311 );
312
313 context.has_start_codon[reading_frame_index] = true;
314 nodes.push(node);
315 }
316 }
317}
318
319fn handle_remaining_starts(
321 encoded_sequence: &[u8],
322 nodes: &mut Vec<Node>,
323 context: &StrandProcessingContext,
324 strand: Strand,
325 training: &Training,
326) {
327 if context.circular {
328 return;
329 }
330
331 for i in 0..READING_FRAMES {
332 if context.has_start_codon[i % READING_FRAMES] {
333 let (position_index, stop_value, is_edge) = match strand {
334 Strand::Forward => {
335 let is_edge = !is_stop(
336 encoded_sequence,
337 context.last_stop_positions[i % READING_FRAMES],
338 training,
339 );
340 let stop_val = (i as isize) - 6;
343 (
344 context.last_stop_positions[i % READING_FRAMES],
345 stop_val,
346 is_edge,
347 )
348 }
349 Strand::Reverse => {
350 let is_edge = !is_stop(
351 encoded_sequence,
352 context.last_stop_positions[i % READING_FRAMES],
353 training,
354 );
355 let position_index = context.sequence_length
356 - context.last_stop_positions[i % READING_FRAMES]
357 - 1;
358 let stop_val = (context.sequence_length + 5 - i) as isize;
359 (position_index, stop_val, is_edge)
360 }
361 Strand::Unknown => unreachable!("Unknown strand should not be processed"),
362 };
363
364 nodes.push(Node {
365 position: NodePosition {
366 index: position_index,
367 codon_type: CodonType::Stop,
368 strand,
369 is_edge,
370 stop_value,
371 },
372 ..Node::default()
373 });
374 }
375 }
376}
377
378fn create_stop_node(
380 index: usize,
381 stop_value: isize,
382 strand: Strand,
383 encoded_sequence: &[u8],
384 training: &Training,
385) -> Node {
386 let is_edge = !is_stop(encoded_sequence, index, training);
387
388 Node {
389 position: NodePosition {
390 index,
391 codon_type: CodonType::Stop,
392 strand,
393 is_edge,
394 stop_value,
395 },
396 ..Node::default()
397 }
398}
399
400fn create_reverse_stop_node(
402 index: usize,
403 stop_value: isize,
404 sequence_length: usize,
405 reverse_complement_encoded_sequence: &[u8],
406 training: &Training,
407) -> Node {
408 let is_edge = !is_stop(reverse_complement_encoded_sequence, index, training);
409
410 Node {
411 position: NodePosition {
412 index: sequence_length - index - 1,
413 codon_type: CodonType::Stop,
414 strand: Strand::Reverse,
415 is_edge,
416 stop_value: sequence_length as isize - stop_value - 1,
417 },
418 ..Node::default()
419 }
420}
421
422fn create_start_node(
424 position_index: usize,
425 codon_type: CodonType,
426 stop_value: isize,
427 strand: Strand,
428) -> Node {
429 Node {
430 position: NodePosition {
431 index: position_index,
432 codon_type,
433 strand,
434 is_edge: false,
435 stop_value,
436 },
437 ..Node::default()
438 }
439}
440
441fn create_reverse_start_node(
443 position: usize,
444 codon_type: CodonType,
445 stop_value: isize,
446 sequence_length: usize,
447) -> Node {
448 Node {
449 position: NodePosition {
450 index: sequence_length - position - 1,
451 codon_type,
452 strand: Strand::Reverse,
453 is_edge: false,
454 stop_value: sequence_length as isize - stop_value - 1,
455 },
456 ..Node::default()
457 }
458}
459
460fn create_edge_node(index: usize, stop_value: isize, strand: Strand) -> Node {
462 Node {
463 position: NodePosition {
464 index,
465 codon_type: CodonType::Atg,
466 strand,
467 is_edge: true,
468 stop_value,
469 },
470 ..Node::default()
471 }
472}
473
474fn create_reverse_edge_node(position: usize, stop_value: isize, sequence_length: usize) -> Node {
476 Node {
477 position: NodePosition {
478 index: sequence_length - position - 1,
479 codon_type: CodonType::Atg,
480 strand: Strand::Reverse,
481 is_edge: true,
482 stop_value: sequence_length as isize - stop_value - 1,
483 },
484 ..Node::default()
485 }
486}
487
// Unit tests for node construction: context initialisation, the individual
// node constructors, trailing-stop handling and the `add_nodes` entry point.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sequence::encode_sequence;

    /// Encodes `input` with `encode_sequence` and returns only the sequence
    /// buffer.
    ///
    /// The buffer is sized at two bits per input byte (rounded up to whole
    /// bytes); the unknown-base buffer and the mask list are discarded.
    fn get_encoded_sequence(input: &[u8]) -> Vec<u8> {
        let sequence_length = input.len();
        let mut seq = vec![0u8; (sequence_length * 2).div_ceil(8)];
        let mut unknown_sequence = vec![0u8; sequence_length.div_ceil(8)];
        let mut masks = Vec::new();
        let _ = encode_sequence(input, &mut seq, &mut unknown_sequence, &mut masks, false).unwrap();
        seq
    }

    /// Builds a full `EncodedSequence` for `input` via
    /// `EncodedSequence::without_masking`.
    fn create_test_encoded_sequence(input: &[u8]) -> EncodedSequence {
        EncodedSequence::without_masking(input)
    }

    /// Returns a `Training` fixture with fixed, arbitrary parameter values
    /// and translation table 11.
    fn create_test_training() -> Training {
        Training {
            gc_content: 0.5,
            translation_table: 11,
            uses_shine_dalgarno: true,
            start_type_weights: [1.0, 2.0, 3.0],
            rbs_weights: Box::new([1.0; 28]),
            upstream_composition: Box::new([[0.25; 4]; 32]),
            motif_weights: Box::new([[[1.0; 4096]; 4]; 4]),
            no_motif_weight: 0.5,
            start_weight_factor: 4.35,
            gc_bias_factors: [1.0; 3],
            gene_dicodon_table: Box::new([0.0; 4096]),
            total_dicodons: 0,
        }
    }

    /// Smoke test: `add_nodes` succeeds on a short sequence with
    /// `closed = true`.
    #[test]
    fn test_add_nodes_closed_sequence() {
        let sequence = b"ATGAAATAAGTGAAATAG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, true, false, &training);

        assert!(result.is_ok());
    }

    /// Smoke test: `add_nodes` succeeds on a short sequence with
    /// `closed = false`.
    // NOTE(review): despite the name, the fixture is built with
    // `without_masking`, so presumably no masks are exercised here - confirm.
    #[test]
    fn test_add_nodes_with_masks() {
        let sequence = b"ATGAAATAAGTGAAATAG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, false, &training);

        assert!(result.is_ok());
    }

    /// A fresh open (non-closed) context has no starts flagged, edge-gene
    /// minimum distances everywhere, and initial stops no larger than
    /// `sequence_length + READING_FRAMES`.
    #[test]
    fn test_initialize_strand_arrays() {
        let sequence_length = 100;
        let sequence_length_mod = 1;
        let closed = false;
        let circular = false;

        let context =
            StrandProcessingContext::new(sequence_length, sequence_length_mod, closed, circular);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
        for &val in &context.last_stop_positions {
            assert!(val <= sequence_length + READING_FRAMES);
        }
    }

    /// A fresh closed context has no starts flagged and edge-gene minimum
    /// distances everywhere.
    #[test]
    fn test_initialize_strand_arrays_closed() {
        let sequence_length = 10;
        let sequence_length_mod = 0;
        let closed = true;
        let circular = false;

        let context =
            StrandProcessingContext::new(sequence_length, sequence_length_mod, closed, circular);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
    }

    /// Forward start nodes store their arguments unchanged and are never
    /// edges.
    #[test]
    fn test_create_start_node() {
        let node = create_start_node(100, CodonType::Atg, 200, Strand::Forward);

        assert_eq!(node.position.index, 100);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 200);
        assert!(!node.position.is_edge);
    }

    /// Reverse start nodes mirror both index and stop value to
    /// `sequence_length - x - 1`.
    #[test]
    fn test_create_reverse_start_node() {
        let sequence_length = 1000;
        let node = create_reverse_start_node(100, CodonType::Gtg, 200, sequence_length);

        assert_eq!(node.position.index, sequence_length - 100 - 1);
        assert_eq!(node.position.codon_type, CodonType::Gtg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 200 - 1) as isize
        );
        assert!(!node.position.is_edge);
    }

    /// Forward edge nodes are recorded as `Atg` with `is_edge` set.
    #[test]
    fn test_create_edge_node() {
        let node = create_edge_node(50, 150, Strand::Forward);

        assert_eq!(node.position.index, 50);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 150);
        assert!(node.position.is_edge);
    }

    /// Reverse edge nodes are recorded as `Atg` with `is_edge` set and
    /// mirrored coordinates.
    #[test]
    fn test_create_reverse_edge_node() {
        let sequence_length = 1000;
        let node = create_reverse_edge_node(50, 150, sequence_length);

        assert_eq!(node.position.index, sequence_length - 50 - 1);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 150 - 1) as isize
        );
        assert!(node.position.is_edge);
    }

    /// Forward stop nodes keep the index and stop value they are given.
    #[test]
    fn test_create_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();

        let node = create_stop_node(0, 3, Strand::Forward, &encoded_seq, &training);

        assert_eq!(node.position.index, 0);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 3);
    }

    /// Reverse stop nodes mirror the index and stop value to
    /// `sequence_length - x - 1`.
    #[test]
    fn test_create_reverse_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();
        let sequence_length = sequence.len();

        let node = create_reverse_stop_node(0, 3, sequence_length, &encoded_seq, &training);

        assert_eq!(node.position.index, sequence_length - 1);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(node.position.stop_value, (sequence_length - 3 - 1) as isize);
    }

    /// Frames 0 and 2 are flagged as having an open start, so two forward
    /// stop nodes are expected.
    #[test]
    fn test_handle_remaining_starts_forward() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false, false);
        // Override the computed state with hand-picked values.
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [true, false, true];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].position.strand, Strand::Forward);
        assert_eq!(nodes[1].position.strand, Strand::Forward);
    }

    /// Only frame 1 is flagged, so exactly one reverse stop node is expected.
    #[test]
    fn test_handle_remaining_starts_reverse() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false, false);
        // Override the computed state with hand-picked values.
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [false, true, false];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Reverse,
            &training,
        );

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].position.strand, Strand::Reverse);
    }

    /// With no frame flagged, no nodes are added.
    #[test]
    fn test_handle_remaining_starts_no_starts() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false, false);
        // Override the computed state with hand-picked values.
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [false, false, false];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert!(nodes.is_empty());
    }

    /// Smoke test: `add_nodes` succeeds on a sequence exactly one codon long.
    #[test]
    fn test_add_nodes_short_sequence() {
        let sequence = b"ATG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, false, &training);

        assert!(result.is_ok());
    }

    /// A start at 80 whose only in-frame stop is at 20 (before it in the
    /// sequence) must be paired with that stop in circular mode and must not
    /// be in closed linear mode.
    #[test]
    fn test_add_nodes_circular_detects_wrapped_forward_start() {
        let mut sequence = vec![b'C'; 150];
        // Positions 20 and 80 share a reading frame (both are 2 mod 3).
        sequence[20..23].copy_from_slice(b"TAA");
        sequence[80..83].copy_from_slice(b"ATG");

        let encoded_sequence = create_test_encoded_sequence(&sequence);
        let training = create_test_training();

        let mut linear_nodes = Vec::new();
        let _ = add_nodes(&encoded_sequence, &mut linear_nodes, true, false, &training).unwrap();

        let mut circular_nodes = Vec::new();
        let _ = add_nodes(
            &encoded_sequence,
            &mut circular_nodes,
            false,
            true,
            &training,
        )
        .unwrap();

        let has_wrapped_forward_start_linear = linear_nodes.iter().any(|n| {
            n.position.strand == Strand::Forward
                && n.position.codon_type != CodonType::Stop
                && n.position.index == 80
                && n.position.stop_value == 20
        });
        let has_wrapped_forward_start_circular = circular_nodes.iter().any(|n| {
            n.position.strand == Strand::Forward
                && n.position.codon_type != CodonType::Stop
                && n.position.index == 80
                && n.position.stop_value == 20
        });

        assert!(!has_wrapped_forward_start_linear);
        assert!(has_wrapped_forward_start_circular);
    }
}