1use std::os::raw::c_int;
7use std::sync::Arc;
8
9use rayon::prelude::*;
10
11use crate::types::{Gene, Node, Training, MAX_GENES, MAX_SEQ, NUM_META};
12use super::convert::gene_to_predicted;
13use super::encode::SequenceBuffer;
14use super::types::{PredictedGene, ProdigalConfig, ProdigalError};
15
16use super::predict::{sort_nodes, validate_config};
17
18const STACK_SIZE: usize = 32 * 1024 * 1024; use crate::node::{add_nodes, reset_node_scores, score_nodes, record_overlapping_starts};
21use crate::dprog::{dprog, eliminate_bad_genes};
22use crate::gene::{add_genes, tweak_final_starts, record_gene_data};
23
24pub struct MetaPredictor {
29 pool: rayon::ThreadPool,
30 models: Arc<Vec<Training>>,
31 config: ProdigalConfig,
32}
33
34impl MetaPredictor {
35 pub fn new() -> Result<Self, ProdigalError> {
37 Self::with_config(ProdigalConfig::default())
38 }
39
40 pub fn with_config(config: ProdigalConfig) -> Result<Self, ProdigalError> {
42 validate_config(&config)?;
43
44 let pool = rayon::ThreadPoolBuilder::new()
45 .stack_size(STACK_SIZE)
46 .build()
47 .map_err(|e| ProdigalError::Io(std::io::Error::new(
48 std::io::ErrorKind::Other, e.to_string(),
49 )))?;
50
51 let models = Arc::new(load_meta_models());
52
53 Ok(MetaPredictor { pool, models, config })
54 }
55
56 pub fn predict(&self, seq: &[u8]) -> Result<Vec<PredictedGene>, ProdigalError> {
58 if seq.is_empty() {
59 return Err(ProdigalError::EmptySequence);
60 }
61 if seq.len() > MAX_SEQ {
62 return Err(ProdigalError::SequenceTooLong {
63 length: seq.len(),
64 max: MAX_SEQ,
65 });
66 }
67
68 let models = &self.models;
69 let config = &self.config;
70
71 self.pool.install(|| predict_parallel(seq, models, config))
72 }
73}
74
75fn load_meta_models() -> Vec<Training> {
76 let mut models: Vec<Training> = Vec::with_capacity(NUM_META);
77 for i in 0..NUM_META {
78 let mut tinf: Training = unsafe { std::mem::zeroed() };
79 unsafe {
80 crate::training_data::load_metagenome(i, &mut tinf as *mut Training);
81 }
82 models.push(tinf);
83 }
84 models
85}
86
87struct TransTableGroup {
89 model_indices: Vec<usize>,
91 nodes: Vec<Node>,
93 nn: c_int,
95}
96
97fn predict_parallel(
98 seq: &[u8],
99 models: &[Training],
100 config: &ProdigalConfig,
101) -> Result<Vec<PredictedGene>, ProdigalError> {
102 let closed = if config.closed_ends { 1 } else { 0 };
103
104 let mut buf = SequenceBuffer::new();
105 let (slen, gc) = unsafe { buf.encode(seq, config.mask_n_runs) };
106 if slen == 0 {
107 return Err(ProdigalError::EmptySequence);
108 }
109 buf.ensure_node_capacity(slen);
110
111 let mut low = 0.88495 * gc - 0.0102337;
113 if low > 0.65 { low = 0.65; }
114 let mut high = 0.86596 * gc + 0.1131991;
115 if high < 0.35 { high = 0.35; }
116
117 let mut groups: Vec<TransTableGroup> = Vec::new();
120 let mut nn: c_int = 0;
121
122 for i in 0..NUM_META {
123 let need_rebuild = i == 0 || models[i].trans_table != models[i - 1].trans_table;
124
125 if need_rebuild {
126 let mut tinf_copy = models[i].clone();
129 unsafe {
130 buf.clear_nodes(nn);
131 nn = add_nodes(
132 buf.seq.as_mut_ptr(), buf.rseq.as_mut_ptr(), slen,
133 buf.nodes.as_mut_ptr(), closed,
134 buf.masks.as_mut_ptr(), buf.nmask,
135 &mut tinf_copy,
136 );
137 }
138 sort_nodes(&mut buf.nodes[..nn as usize]);
139
140 groups.push(TransTableGroup {
141 model_indices: Vec::new(),
142 nodes: buf.nodes[..nn as usize].to_vec(),
143 nn,
144 });
145 }
146
147 if models[i].gc >= low && models[i].gc <= high {
149 groups.last_mut().unwrap().model_indices.push(i);
150 }
151 }
152
153 let seq_addr = buf.seq.as_ptr() as usize;
158 let rseq_addr = buf.rseq.as_ptr() as usize;
159
160 struct ModelScore {
161 phase: usize,
162 score: f64,
163 }
164
165 let best = groups.par_iter().flat_map(|group| {
166 group.model_indices.par_iter().map(|&model_idx| {
167 let mut nodes = group.nodes.clone();
168 let nn = group.nn;
169 let mut tinf = models[model_idx].clone();
171
172 unsafe {
173 reset_node_scores(nodes.as_mut_ptr(), nn);
174 score_nodes(
175 seq_addr as *mut u8, rseq_addr as *mut u8, slen,
176 nodes.as_mut_ptr(), nn, &mut tinf, closed, 1,
177 );
178 record_overlapping_starts(nodes.as_mut_ptr(), nn, &mut tinf, 1);
179 let ipath = dprog(nodes.as_mut_ptr(), nn, &mut tinf, 1);
180 if ipath < 0 || ipath >= nn {
181 return ModelScore { phase: model_idx, score: f64::NEG_INFINITY };
182 }
183 ModelScore {
184 phase: model_idx,
185 score: nodes[ipath as usize].score,
186 }
187 }
188 })
189 })
190 .reduce(
191 || ModelScore { phase: 0, score: f64::NEG_INFINITY },
192 |a, b| if a.score >= b.score { a } else { b },
193 );
194
195 if best.score == f64::NEG_INFINITY {
196 return Ok(Vec::new());
197 }
198
199 let mut tinf = models[best.phase].clone();
202
203 let best_group = groups.iter().find(|g| g.model_indices.contains(&best.phase)).unwrap();
205 let mut nodes = best_group.nodes.clone();
206 let nn = best_group.nn;
207 let mut genes: Vec<Gene> = vec![unsafe { std::mem::zeroed() }; MAX_GENES];
208
209 unsafe {
210 reset_node_scores(nodes.as_mut_ptr(), nn);
211 score_nodes(
212 buf.seq.as_mut_ptr(), buf.rseq.as_mut_ptr(), slen,
213 nodes.as_mut_ptr(), nn, &mut tinf, closed, 1,
214 );
215 record_overlapping_starts(nodes.as_mut_ptr(), nn, &mut tinf, 1);
216 let ipath = dprog(nodes.as_mut_ptr(), nn, &mut tinf, 1);
217 eliminate_bad_genes(nodes.as_mut_ptr(), ipath, &mut tinf);
218 let ng = add_genes(genes.as_mut_ptr(), nodes.as_mut_ptr(), ipath);
219 tweak_final_starts(genes.as_mut_ptr(), ng, nodes.as_mut_ptr(), nn, &mut tinf);
220 record_gene_data(genes.as_mut_ptr(), ng, nodes.as_mut_ptr(), &mut tinf, 1);
221
222 let mut result = Vec::with_capacity(ng as usize);
223 for i in 0..ng {
224 result.push(gene_to_predicted(&genes[i as usize], nodes.as_ptr(), &tinf));
225 }
226 Ok(result)
227 }
228}