Skip to main content

prodigal_rs/api/
meta_predictor.rs

1//! Reusable metagenomic gene predictor with parallel model evaluation.
2//!
3//! Caches the 50 metagenomic models and uses a rayon thread pool
4//! to score qualifying models in parallel.
5
6use std::os::raw::c_int;
7use std::sync::Arc;
8
9use rayon::prelude::*;
10
11use crate::types::{Gene, Node, Training, MAX_GENES, MAX_SEQ, NUM_META};
12use super::convert::gene_to_predicted;
13use super::encode::SequenceBuffer;
14use super::types::{PredictedGene, ProdigalConfig, ProdigalError};
15
16use super::predict::{sort_nodes, validate_config};
17
/// Per-thread stack size, in bytes, requested for the predictor's rayon pool (32 MB).
const STACK_SIZE: usize = 32 << 20;
19
20use crate::node::{add_nodes, reset_node_scores, score_nodes, record_overlapping_starts};
21use crate::dprog::{dprog, eliminate_bad_genes};
22use crate::gene::{add_genes, tweak_final_starts, record_gene_data};
23
24/// Reusable metagenomic gene predictor.
25///
26/// Pre-loads 50 metagenomic models and evaluates qualifying models
27/// in parallel using a rayon thread pool with large stacks.
28pub struct MetaPredictor {
29    pool: rayon::ThreadPool,
30    models: Arc<Vec<Training>>,
31    config: ProdigalConfig,
32}
33
34impl MetaPredictor {
35    /// Create a new predictor with default config.
36    pub fn new() -> Result<Self, ProdigalError> {
37        Self::with_config(ProdigalConfig::default())
38    }
39
40    /// Create a new predictor with custom config.
41    pub fn with_config(config: ProdigalConfig) -> Result<Self, ProdigalError> {
42        validate_config(&config)?;
43
44        let pool = rayon::ThreadPoolBuilder::new()
45            .stack_size(STACK_SIZE)
46            .build()
47            .map_err(|e| ProdigalError::Io(std::io::Error::new(
48                std::io::ErrorKind::Other, e.to_string(),
49            )))?;
50
51        let models = Arc::new(load_meta_models());
52
53        Ok(MetaPredictor { pool, models, config })
54    }
55
56    /// Predict genes in the given sequence.
57    pub fn predict(&self, seq: &[u8]) -> Result<Vec<PredictedGene>, ProdigalError> {
58        if seq.is_empty() {
59            return Err(ProdigalError::EmptySequence);
60        }
61        if seq.len() > MAX_SEQ {
62            return Err(ProdigalError::SequenceTooLong {
63                length: seq.len(),
64                max: MAX_SEQ,
65            });
66        }
67
68        let models = &self.models;
69        let config = &self.config;
70
71        self.pool.install(|| predict_parallel(seq, models, config))
72    }
73}
74
75fn load_meta_models() -> Vec<Training> {
76    let mut models: Vec<Training> = Vec::with_capacity(NUM_META);
77    for i in 0..NUM_META {
78        let mut tinf: Training = unsafe { std::mem::zeroed() };
79        unsafe {
80            crate::training_data::load_metagenome(i, &mut tinf as *mut Training);
81        }
82        models.push(tinf);
83    }
84    models
85}
86
87/// Nodes built for one translation table, shared across models in that group.
88struct TransTableGroup {
89    /// Indices into the models array for qualifying (GC-filtered) models.
90    model_indices: Vec<usize>,
91    /// Template node array (built once by add_nodes + sort).
92    nodes: Vec<Node>,
93    /// Number of valid nodes.
94    nn: c_int,
95}
96
97fn predict_parallel(
98    seq: &[u8],
99    models: &[Training],
100    config: &ProdigalConfig,
101) -> Result<Vec<PredictedGene>, ProdigalError> {
102    let closed = if config.closed_ends { 1 } else { 0 };
103
104    let mut buf = SequenceBuffer::new();
105    let (slen, gc) = unsafe { buf.encode(seq, config.mask_n_runs) };
106    if slen == 0 {
107        return Err(ProdigalError::EmptySequence);
108    }
109    buf.ensure_node_capacity(slen);
110
111    // GC window for model selection
112    let mut low = 0.88495 * gc - 0.0102337;
113    if low > 0.65 { low = 0.65; }
114    let mut high = 0.86596 * gc + 0.1131991;
115    if high < 0.35 { high = 0.35; }
116
117    // Phase 1: Build node arrays per translation table group,
118    // filtering models by GC range.
119    let mut groups: Vec<TransTableGroup> = Vec::new();
120    let mut nn: c_int = 0;
121
122    for i in 0..NUM_META {
123        let need_rebuild = i == 0 || models[i].trans_table != models[i - 1].trans_table;
124
125        if need_rebuild {
126            // Build nodes for this translation table
127            // We need a mutable Training pointer but add_nodes only reads from it
128            let mut tinf_copy = models[i].clone();
129            unsafe {
130                buf.clear_nodes(nn);
131                nn = add_nodes(
132                    buf.seq.as_mut_ptr(), buf.rseq.as_mut_ptr(), slen,
133                    buf.nodes.as_mut_ptr(), closed,
134                    buf.masks.as_mut_ptr(), buf.nmask,
135                    &mut tinf_copy,
136                );
137            }
138            sort_nodes(&mut buf.nodes[..nn as usize]);
139
140            groups.push(TransTableGroup {
141                model_indices: Vec::new(),
142                nodes: buf.nodes[..nn as usize].to_vec(),
143                nn,
144            });
145        }
146
147        // GC filter
148        if models[i].gc >= low && models[i].gc <= high {
149            groups.last_mut().unwrap().model_indices.push(i);
150        }
151    }
152
153    // Phase 2: Score all qualifying models in parallel.
154    // Each task gets its own copy of the node array.
155    // Safety: seq/rseq buffers are immutable during parallel scoring,
156    // and the usize round-trip preserves the pointer value.
157    let seq_addr = buf.seq.as_ptr() as usize;
158    let rseq_addr = buf.rseq.as_ptr() as usize;
159
160    struct ModelScore {
161        phase: usize,
162        score: f64,
163    }
164
165    let best = groups.par_iter().flat_map(|group| {
166        group.model_indices.par_iter().map(|&model_idx| {
167            let mut nodes = group.nodes.clone();
168            let nn = group.nn;
169            // Clone the Training struct for this thread (needed for mutable API)
170            let mut tinf = models[model_idx].clone();
171
172            unsafe {
173                reset_node_scores(nodes.as_mut_ptr(), nn);
174                score_nodes(
175                    seq_addr as *mut u8, rseq_addr as *mut u8, slen,
176                    nodes.as_mut_ptr(), nn, &mut tinf, closed, 1,
177                );
178                record_overlapping_starts(nodes.as_mut_ptr(), nn, &mut tinf, 1);
179                let ipath = dprog(nodes.as_mut_ptr(), nn, &mut tinf, 1);
180                if ipath < 0 || ipath >= nn {
181                    return ModelScore { phase: model_idx, score: f64::NEG_INFINITY };
182                }
183                ModelScore {
184                    phase: model_idx,
185                    score: nodes[ipath as usize].score,
186                }
187            }
188        })
189    })
190    .reduce(
191        || ModelScore { phase: 0, score: f64::NEG_INFINITY },
192        |a, b| if a.score >= b.score { a } else { b },
193    );
194
195    if best.score == f64::NEG_INFINITY {
196        return Ok(Vec::new());
197    }
198
199    // Phase 3: Re-run the best model to extract genes.
200    // Need to rebuild nodes for the best model's translation table.
201    let mut tinf = models[best.phase].clone();
202
203    // Find the group that contains the best model to get its nodes
204    let best_group = groups.iter().find(|g| g.model_indices.contains(&best.phase)).unwrap();
205    let mut nodes = best_group.nodes.clone();
206    let nn = best_group.nn;
207    let mut genes: Vec<Gene> = vec![unsafe { std::mem::zeroed() }; MAX_GENES];
208
209    unsafe {
210        reset_node_scores(nodes.as_mut_ptr(), nn);
211        score_nodes(
212            buf.seq.as_mut_ptr(), buf.rseq.as_mut_ptr(), slen,
213            nodes.as_mut_ptr(), nn, &mut tinf, closed, 1,
214        );
215        record_overlapping_starts(nodes.as_mut_ptr(), nn, &mut tinf, 1);
216        let ipath = dprog(nodes.as_mut_ptr(), nn, &mut tinf, 1);
217        eliminate_bad_genes(nodes.as_mut_ptr(), ipath, &mut tinf);
218        let ng = add_genes(genes.as_mut_ptr(), nodes.as_mut_ptr(), ipath);
219        tweak_final_starts(genes.as_mut_ptr(), ng, nodes.as_mut_ptr(), nn, &mut tinf);
220        record_gene_data(genes.as_mut_ptr(), ng, nodes.as_mut_ptr(), &mut tinf, 1);
221
222        let mut result = Vec::with_capacity(ng as usize);
223        for i in 0..ng {
224            result.push(gene_to_predicted(&genes[i as usize], nodes.as_ptr(), &tinf));
225        }
226        Ok(result)
227    }
228}