1use bio::bio_types::strand::Strand;
2
3use crate::{
4 constants::{
5 CODON_LENGTH, MINIMUM_EDGE_GENE_LENGTH, MINIMUM_GENE_LENGTH, READING_FRAMES, STT_NOD,
6 },
7 node::validation::{
8 check_start_codon, is_edge_gene, is_reverse_edge_gene, is_valid_gene, is_valid_reverse_gene,
9 },
10 sequence::{encoded::EncodedSequence, is_stop},
11 types::{CodonType, Mask, Node, NodePosition, OrphosError, Training},
12};
13
/// Fixed-size array holding one value of type `T` per reading frame
/// (`READING_FRAMES` entries).
type ReadingFrameArray<T> = [T; READING_FRAMES];
16
/// Per-strand scanning state shared by the strand processors and
/// `handle_remaining_starts`. All arrays are indexed by reading frame.
struct StrandProcessingContext {
    /// For each frame, the position of the stop codon most recently seen by
    /// the backward scan. Until a stop is seen the slot holds the initial
    /// value written by `initialize_arrays`.
    last_stop_positions: ReadingFrameArray<usize>,
    /// For each frame, whether a start or edge node has been pushed since the
    /// last stop codon in that frame.
    has_start_codon: ReadingFrameArray<bool>,
    /// For each frame, the minimum-length argument handed to the gene
    /// validity checks: `MINIMUM_EDGE_GENE_LENGTH` initially, replaced by
    /// `MINIMUM_GENE_LENGTH` once a stop codon has been seen in the frame.
    minimum_distances: ReadingFrameArray<usize>,
    /// Length of the sequence being scanned, in bases.
    sequence_length: usize,
    /// Forwarded to the edge-gene checks; when `false`, the initial stop
    /// positions are pulled back inside the sequence.
    closed: bool,
}
25
26impl StrandProcessingContext {
27 fn new(sequence_length: usize, sequence_length_mod: usize, closed: bool) -> Self {
29 let mut context = Self {
30 last_stop_positions: ReadingFrameArray::default(),
31 has_start_codon: ReadingFrameArray::default(),
32 minimum_distances: ReadingFrameArray::default(),
33 sequence_length,
34 closed,
35 };
36
37 context.initialize_arrays(sequence_length_mod);
38 context
39 }
40
41 fn initialize_arrays(&mut self, sequence_length_mod: usize) {
43 for i in 0..READING_FRAMES {
45 self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] =
47 self.sequence_length + i; self.has_start_codon[i % READING_FRAMES] = false; self.minimum_distances[i % READING_FRAMES] = MINIMUM_EDGE_GENE_LENGTH; if !self.closed && self.sequence_length > 0 {
52 while self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] + 2
53 > self.sequence_length - 1
54 {
55 self.last_stop_positions[(i + sequence_length_mod) % READING_FRAMES] = self
56 .last_stop_positions[(i + sequence_length_mod) % READING_FRAMES]
57 .saturating_sub(3);
58 }
59 }
60 }
61 }
62}
63
64pub fn add_nodes(
78 encoded_sequence: &EncodedSequence,
79 nodes: &mut Vec<Node>,
80 closed: bool,
81 training: &Training,
82) -> Result<usize, OrphosError> {
83 let sequence_length = encoded_sequence.sequence_length;
84 nodes.clear();
86 nodes.reserve(STT_NOD);
87
88 let sequence_length_mod = sequence_length % READING_FRAMES;
89
90 let mut context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);
91 process_forward_strand(
92 &encoded_sequence.forward_sequence,
93 nodes,
94 &mut context,
95 &encoded_sequence.masks,
96 training,
97 );
98 handle_remaining_starts(
99 &encoded_sequence.forward_sequence,
100 nodes,
101 &context,
102 Strand::Forward,
103 training,
104 );
105
106 let mut context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);
107 process_reverse_strand(
108 &encoded_sequence.reverse_complement_sequence,
109 nodes,
110 &mut context,
111 &encoded_sequence.masks,
112 training,
113 );
114 handle_remaining_starts(
115 &encoded_sequence.reverse_complement_sequence,
116 nodes,
117 &context,
118 Strand::Reverse,
119 training,
120 );
121
122 Ok(nodes.len())
123}
124
125fn process_forward_strand(
130 encoded_sequence: &[u8],
131 nodes: &mut Vec<Node>,
132 context: &mut StrandProcessingContext,
133 masks: &[Mask],
134 training: &Training,
135) {
136 let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
137
138 for position_index in (0..=scanning_start_position).rev() {
139 let reading_frame_index = position_index % READING_FRAMES;
140
141 if is_stop(encoded_sequence, position_index, training) {
142 if context.has_start_codon[reading_frame_index] {
143 let node = create_stop_node(
144 context.last_stop_positions[reading_frame_index],
145 position_index as isize,
146 Strand::Forward,
147 encoded_sequence,
148 training,
149 );
150 nodes.push(node);
151 }
152
153 context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
154 context.last_stop_positions[reading_frame_index] = position_index;
155 context.has_start_codon[reading_frame_index] = false;
156 continue;
157 }
158
159 if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
160 continue;
161 }
162
163 if let Some(codon_type) = check_start_codon(encoded_sequence, position_index, training) {
165 if is_valid_gene(
166 position_index,
167 context.last_stop_positions[reading_frame_index],
168 context.minimum_distances[reading_frame_index],
169 masks,
170 ) {
171 let node = create_start_node(
172 position_index,
173 codon_type,
174 context.last_stop_positions[reading_frame_index] as isize,
175 Strand::Forward,
176 );
177 context.has_start_codon[reading_frame_index] = true;
178 nodes.push(node);
179 }
180 } else if is_edge_gene(
181 position_index,
182 context.last_stop_positions[reading_frame_index],
183 context.closed,
184 masks,
185 ) {
186 let node = create_edge_node(
187 position_index,
188 context.last_stop_positions[reading_frame_index] as isize,
189 Strand::Forward,
190 );
191 context.has_start_codon[reading_frame_index] = true;
192 nodes.push(node);
193 }
194 }
195}
196
197fn process_reverse_strand(
199 reverse_complement_encoded_sequence: &[u8],
200 nodes: &mut Vec<Node>,
201 context: &mut StrandProcessingContext,
202 masks: &[Mask],
203 training: &Training,
204) {
205 let scanning_start_position = context.sequence_length.saturating_sub(CODON_LENGTH);
206
207 for position_index in (0..=scanning_start_position).rev() {
208 let reading_frame_index = position_index % READING_FRAMES;
209
210 if is_stop(
211 reverse_complement_encoded_sequence,
212 position_index,
213 training,
214 ) {
215 if context.has_start_codon[reading_frame_index] {
216 let node = create_reverse_stop_node(
217 context.last_stop_positions[reading_frame_index],
218 position_index as isize,
219 context.sequence_length,
220 reverse_complement_encoded_sequence,
221 training,
222 );
223 nodes.push(node);
224 }
225
226 context.minimum_distances[reading_frame_index] = MINIMUM_GENE_LENGTH;
227 context.last_stop_positions[reading_frame_index] = position_index;
228 context.has_start_codon[reading_frame_index] = false;
229 continue;
230 }
231
232 if context.last_stop_positions[reading_frame_index] >= context.sequence_length {
233 continue;
234 }
235
236 if let Some(codon_type) = check_start_codon(
238 reverse_complement_encoded_sequence,
239 position_index,
240 training,
241 ) {
242 if is_valid_reverse_gene(
243 position_index,
244 context.last_stop_positions[reading_frame_index],
245 context.minimum_distances[reading_frame_index],
246 context.sequence_length,
247 masks,
248 ) {
249 let node = create_reverse_start_node(
250 position_index,
251 codon_type,
252 context.last_stop_positions[reading_frame_index] as isize,
253 context.sequence_length,
254 );
255
256 context.has_start_codon[reading_frame_index] = true;
257 nodes.push(node);
258 }
259 } else if is_reverse_edge_gene(
260 position_index,
261 context.last_stop_positions[reading_frame_index],
262 context.sequence_length,
263 context.closed,
264 masks,
265 ) {
266 let node = create_reverse_edge_node(
267 position_index,
268 context.last_stop_positions[reading_frame_index] as isize,
269 context.sequence_length,
270 );
271
272 context.has_start_codon[reading_frame_index] = true;
273 nodes.push(node);
274 }
275 }
276}
277
278fn handle_remaining_starts(
280 encoded_sequence: &[u8],
281 nodes: &mut Vec<Node>,
282 context: &StrandProcessingContext,
283 strand: Strand,
284 training: &Training,
285) {
286 for i in 0..READING_FRAMES {
287 if context.has_start_codon[i % READING_FRAMES] {
288 let (position_index, stop_value, is_edge) = match strand {
289 Strand::Forward => {
290 let is_edge = !is_stop(
291 encoded_sequence,
292 context.last_stop_positions[i % READING_FRAMES],
293 training,
294 );
295 let stop_val = (i as isize) - 6;
298 (
299 context.last_stop_positions[i % READING_FRAMES],
300 stop_val,
301 is_edge,
302 )
303 }
304 Strand::Reverse => {
305 let is_edge = !is_stop(
306 encoded_sequence,
307 context.last_stop_positions[i % READING_FRAMES],
308 training,
309 );
310 let position_index = context.sequence_length
311 - context.last_stop_positions[i % READING_FRAMES]
312 - 1;
313 let stop_val = (context.sequence_length + 5 - i) as isize;
314 (position_index, stop_val, is_edge)
315 }
316 Strand::Unknown => unreachable!("Unknown strand should not be processed"),
317 };
318
319 nodes.push(Node {
320 position: NodePosition {
321 index: position_index,
322 codon_type: CodonType::Stop,
323 strand,
324 is_edge,
325 stop_value,
326 },
327 ..Node::default()
328 });
329 }
330 }
331}
332
333fn create_stop_node(
335 index: usize,
336 stop_value: isize,
337 strand: Strand,
338 encoded_sequence: &[u8],
339 training: &Training,
340) -> Node {
341 let is_edge = !is_stop(encoded_sequence, index, training);
342
343 Node {
344 position: NodePosition {
345 index,
346 codon_type: CodonType::Stop,
347 strand,
348 is_edge,
349 stop_value,
350 },
351 ..Node::default()
352 }
353}
354
355fn create_reverse_stop_node(
357 index: usize,
358 stop_value: isize,
359 sequence_length: usize,
360 reverse_complement_encoded_sequence: &[u8],
361 training: &Training,
362) -> Node {
363 let is_edge = !is_stop(reverse_complement_encoded_sequence, index, training);
364
365 Node {
366 position: NodePosition {
367 index: sequence_length - index - 1,
368 codon_type: CodonType::Stop,
369 strand: Strand::Reverse,
370 is_edge,
371 stop_value: sequence_length as isize - stop_value - 1,
372 },
373 ..Node::default()
374 }
375}
376
377fn create_start_node(
379 position_index: usize,
380 codon_type: CodonType,
381 stop_value: isize,
382 strand: Strand,
383) -> Node {
384 Node {
385 position: NodePosition {
386 index: position_index,
387 codon_type,
388 strand,
389 is_edge: false,
390 stop_value,
391 },
392 ..Node::default()
393 }
394}
395
396fn create_reverse_start_node(
398 position: usize,
399 codon_type: CodonType,
400 stop_value: isize,
401 sequence_length: usize,
402) -> Node {
403 Node {
404 position: NodePosition {
405 index: sequence_length - position - 1,
406 codon_type,
407 strand: Strand::Reverse,
408 is_edge: false,
409 stop_value: sequence_length as isize - stop_value - 1,
410 },
411 ..Node::default()
412 }
413}
414
415fn create_edge_node(index: usize, stop_value: isize, strand: Strand) -> Node {
417 Node {
418 position: NodePosition {
419 index,
420 codon_type: CodonType::Atg,
421 strand,
422 is_edge: true,
423 stop_value,
424 },
425 ..Node::default()
426 }
427}
428
429fn create_reverse_edge_node(position: usize, stop_value: isize, sequence_length: usize) -> Node {
431 Node {
432 position: NodePosition {
433 index: sequence_length - position - 1,
434 codon_type: CodonType::Atg,
435 strand: Strand::Reverse,
436 is_edge: true,
437 stop_value: sequence_length as isize - stop_value - 1,
438 },
439 ..Node::default()
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sequence::encode_sequence;

    /// Encodes `input` into the packed byte representation expected by the
    /// codon helpers, without masking.
    fn get_encoded_sequence(input: &[u8]) -> Vec<u8> {
        let sequence_length = input.len();
        let mut seq = vec![0u8; (sequence_length * 2).div_ceil(8)];
        let mut unknown_sequence = vec![0u8; sequence_length.div_ceil(8)];
        let mut masks = Vec::new();
        let _ = encode_sequence(input, &mut seq, &mut unknown_sequence, &mut masks, false).unwrap();
        seq
    }

    /// Builds a full `EncodedSequence` (both strands, no masks) for `input`.
    fn create_test_encoded_sequence(input: &[u8]) -> EncodedSequence {
        EncodedSequence::without_masking(input)
    }

    /// Training parameters with neutral weights and translation table 11.
    fn create_test_training() -> Training {
        Training {
            gc_content: 0.5,
            translation_table: 11,
            uses_shine_dalgarno: true,
            start_type_weights: [1.0, 2.0, 3.0],
            rbs_weights: Box::new([1.0; 28]),
            upstream_composition: Box::new([[0.25; 4]; 32]),
            motif_weights: Box::new([[[1.0; 4096]; 4]; 4]),
            no_motif_weight: 0.5,
            start_weight_factor: 4.35,
            gc_bias_factors: [1.0; 3],
            gene_dicodon_table: Box::new([0.0; 4096]),
            total_dicodons: 0,
        }
    }

    #[test]
    fn test_add_nodes_closed_sequence() {
        let sequence = b"ATGAAATAAGTGAAATAG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, true, &training);

        // The reported count must match what was actually written.
        assert!(matches!(result, Ok(count) if count == nodes.len()));
    }

    // Despite its name this test uses an unmasked sequence; what it really
    // exercises is the open-ended (`closed == false`) path.
    #[test]
    fn test_add_nodes_with_masks() {
        let sequence = b"ATGAAATAAGTGAAATAG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, &training);

        assert!(matches!(result, Ok(count) if count == nodes.len()));
    }

    #[test]
    fn test_add_nodes_clears_existing_nodes() {
        let sequence = b"ATGAAATAAGTGAAATAG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let training = create_test_training();

        let mut fresh = Vec::new();
        let fresh_result = add_nodes(&encoded_sequence, &mut fresh, true, &training);

        // Pre-existing content must not survive or influence the result.
        let mut reused = vec![Node::default(), Node::default()];
        let reused_result = add_nodes(&encoded_sequence, &mut reused, true, &training);

        assert!(matches!(fresh_result, Ok(count) if count == fresh.len()));
        assert!(matches!(reused_result, Ok(count) if count == reused.len()));
        assert_eq!(fresh.len(), reused.len());
    }

    #[test]
    fn test_initialize_strand_arrays() {
        let sequence_length = 100;
        let sequence_length_mod = 1;
        let closed = false;

        let context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
        for &val in &context.last_stop_positions {
            assert!(val <= sequence_length + READING_FRAMES);
        }
        // Open sequence: each seed is stepped back by whole codons until a
        // codon starting there fits inside the sequence, staying in frame.
        assert_eq!(context.last_stop_positions, [96, 97, 95]);
        for (frame, &val) in context.last_stop_positions.iter().enumerate() {
            assert_eq!(val % READING_FRAMES, frame);
        }
    }

    #[test]
    fn test_initialize_strand_arrays_closed() {
        let sequence_length = 10;
        let sequence_length_mod = 0;
        let closed = true;

        let context = StrandProcessingContext::new(sequence_length, sequence_length_mod, closed);

        assert_eq!(context.has_start_codon, [false; READING_FRAMES]);
        for &dist in &context.minimum_distances {
            assert_eq!(dist, MINIMUM_EDGE_GENE_LENGTH);
        }
        // Closed sequence: the out-of-range seeds are kept as they are.
        assert_eq!(context.last_stop_positions, [10, 11, 12]);
    }

    #[test]
    fn test_create_start_node() {
        let node = create_start_node(100, CodonType::Atg, 200, Strand::Forward);

        assert_eq!(node.position.index, 100);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 200);
        assert!(!node.position.is_edge);
    }

    #[test]
    fn test_create_reverse_start_node() {
        let sequence_length = 1000;
        let node = create_reverse_start_node(100, CodonType::Gtg, 200, sequence_length);

        assert_eq!(node.position.index, sequence_length - 100 - 1);
        assert_eq!(node.position.codon_type, CodonType::Gtg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 200 - 1) as isize
        );
        assert!(!node.position.is_edge);
    }

    #[test]
    fn test_create_edge_node() {
        let node = create_edge_node(50, 150, Strand::Forward);

        assert_eq!(node.position.index, 50);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 150);
        assert!(node.position.is_edge);
    }

    #[test]
    fn test_create_reverse_edge_node() {
        let sequence_length = 1000;
        let node = create_reverse_edge_node(50, 150, sequence_length);

        assert_eq!(node.position.index, sequence_length - 50 - 1);
        assert_eq!(node.position.codon_type, CodonType::Atg);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(
            node.position.stop_value,
            (sequence_length - 150 - 1) as isize
        );
        assert!(node.position.is_edge);
    }

    #[test]
    fn test_create_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();

        let node = create_stop_node(0, 3, Strand::Forward, &encoded_seq, &training);

        assert_eq!(node.position.index, 0);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Forward);
        assert_eq!(node.position.stop_value, 3);
    }

    #[test]
    fn test_create_reverse_stop_node() {
        let sequence = b"TAAGGG";
        let encoded_seq = get_encoded_sequence(sequence);
        let training = create_test_training();
        let sequence_length = sequence.len();

        let node = create_reverse_stop_node(0, 3, sequence_length, &encoded_seq, &training);

        assert_eq!(node.position.index, sequence_length - 1);
        assert_eq!(node.position.codon_type, CodonType::Stop);
        assert_eq!(node.position.strand, Strand::Reverse);
        assert_eq!(node.position.stop_value, (sequence_length - 3 - 1) as isize);
    }

    #[test]
    fn test_handle_remaining_starts_forward() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false);
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [true, false, true];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert_eq!(nodes.len(), 2);
        assert_eq!(nodes[0].position.strand, Strand::Forward);
        assert_eq!(nodes[1].position.strand, Strand::Forward);

        // Frame 0: anchored at its last stop, stop value is `frame - 6`.
        assert_eq!(nodes[0].position.codon_type, CodonType::Stop);
        assert_eq!(nodes[0].position.index, 0);
        assert_eq!(nodes[0].position.stop_value, -6);

        // Frame 1 has no pending start, so the second node belongs to frame 2.
        assert_eq!(nodes[1].position.codon_type, CodonType::Stop);
        assert_eq!(nodes[1].position.index, 6);
        assert_eq!(nodes[1].position.stop_value, -4);
    }

    #[test]
    fn test_handle_remaining_starts_reverse() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false);
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [false, true, false];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Reverse,
            &training,
        );

        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].position.strand, Strand::Reverse);
        assert_eq!(nodes[0].position.codon_type, CodonType::Stop);
        // Frame 1: index is mirrored (6 - 3 - 1), stop value is 6 + 5 - 1.
        assert_eq!(nodes[0].position.index, 2);
        assert_eq!(nodes[0].position.stop_value, 10);
    }

    #[test]
    fn test_handle_remaining_starts_no_starts() {
        let sequence = b"ATGAAA";
        let encoded_seq = get_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let mut context = StrandProcessingContext::new(sequence.len(), 0, false);
        context.last_stop_positions = [0, READING_FRAMES, 6];
        context.has_start_codon = [false, false, false];

        handle_remaining_starts(
            &encoded_seq,
            &mut nodes,
            &context,
            Strand::Forward,
            &training,
        );

        assert!(nodes.is_empty());
    }

    #[test]
    fn test_add_nodes_short_sequence() {
        let sequence = b"ATG";
        let encoded_sequence = create_test_encoded_sequence(sequence);
        let mut nodes = Vec::new();
        let training = create_test_training();

        let result = add_nodes(&encoded_sequence, &mut nodes, false, &training);

        assert!(matches!(result, Ok(count) if count == nodes.len()));
    }
}