nrps_rs/
lib.rs

1// License: GNU Affero General Public License v3 or later
2// A copy of GNU AGPL v3 should have been included in this software package in LICENSE.txt.
3
4pub mod config;
5pub mod encodings;
6pub mod errors;
7pub mod predictors;
8pub mod svm;
9
10use std::fs::File;
11use std::io::{self, BufRead, BufReader};
12use std::path::PathBuf;
13
14use errors::NrpsError;
15use predictors::predictions::ADomain;
16use predictors::stachelhaus::predict_stachelhaus;
17use predictors::{load_models, Predictor};
18
19pub fn run_on_file(
20    config: &config::Config,
21    signature_file: PathBuf,
22) -> Result<Vec<ADomain>, NrpsError> {
23    let mut domains = parse_domains(signature_file)?;
24    run(config, &mut domains)?;
25    Ok(domains)
26}
27
28pub fn run(config: &config::Config, domains: &mut Vec<ADomain>) -> Result<(), NrpsError> {
29    if !config.skip_stachelhaus {
30        predict_stachelhaus(&config, domains)?;
31    }
32
33    let models = load_models(&config)?;
34    let predictor = Predictor { models };
35    predictor.predict(domains)?;
36    Ok(())
37}
38
39pub fn run_on_strings(
40    config: &config::Config,
41    lines: Vec<String>,
42) -> Result<Vec<ADomain>, NrpsError> {
43    let mut domains = Vec::with_capacity(lines.len());
44
45    for line in lines.iter() {
46        domains.push(parse_domain(line.to_string())?);
47    }
48
49    run(config, &mut domains)?;
50
51    Ok(domains)
52}
53
54pub fn print_results(config: &config::Config, domains: &Vec<ADomain>) -> Result<(), NrpsError> {
55    if config.count < 1 {
56        return Err(NrpsError::CountError(config.count));
57    }
58
59    let categories = config.categories();
60
61    let cat_strings: Vec<String> = categories.iter().map(|c| format!("{c:?}")).collect();
62
63    let mut headers: Vec<String> = Vec::with_capacity(3);
64
65    headers.push("Name\t8A signature\tStachelhaus signature".to_string());
66    if !config.skip_stachelhaus && !config.skip_new_stachelhaus_output {
67        headers.push(
68            [
69                "Full Stachelhaus match",
70                "AA10 score",
71                "AA10 signature matched",
72                "AA34 score",
73            ]
74            .join("\t")
75            .to_string(),
76        );
77    }
78    headers.push(cat_strings.join("\t"));
79    println!("{}", headers.join("\t"));
80
81    for domain in domains.iter() {
82        let mut best_predictions: Vec<String> = Vec::new();
83        for cat in categories.iter() {
84            let mut best = domain
85                .get_best_n(&cat, config.count)
86                .iter()
87                .fold("".to_string(), |acc, new| {
88                    format!("{acc}|{}({:.2})", new.name, new.score)
89                })
90                .trim_matches('|')
91                .to_string();
92            if best == "" {
93                best = "N/A".to_string();
94            }
95            best_predictions.push(best)
96        }
97        let mut line: Vec<String> = Vec::with_capacity(5);
98        line.push(domain.name.to_string());
99        line.push(domain.aa34.to_string());
100        line.push(domain.aa10.to_string());
101        if !config.skip_stachelhaus && !config.skip_new_stachelhaus_output {
102            line.push(domain.stach_predictions.to_table());
103        }
104        line.push(best_predictions.join("\t"));
105        println!("{}", line.join("\t"));
106    }
107
108    Ok(())
109}
110
111pub fn parse_domains(signature_file: PathBuf) -> Result<Vec<ADomain>, NrpsError> {
112    if signature_file == PathBuf::from("-") {
113        let reader = BufReader::new(io::stdin());
114        return parse_domains_from_reader(reader);
115    }
116
117    if !signature_file.exists() {
118        let err = format!("'{}' doesn't exist", signature_file.display());
119        return Err(NrpsError::SignatureFileError(err));
120    }
121
122    let handle = File::open(signature_file)?;
123    let reader = BufReader::new(handle);
124
125    parse_domains_from_reader(reader)
126}
127
128fn parse_domains_from_reader<R>(reader: R) -> Result<Vec<ADomain>, NrpsError>
129where
130    R: BufRead,
131{
132    let mut domains = Vec::new();
133
134    for line_res in reader.lines() {
135        let line = line_res?.trim().to_string();
136        if line == "" {
137            continue;
138        }
139
140        domains.push(parse_domain(line)?);
141    }
142
143    Ok(domains)
144}
145
146pub fn parse_domain(line: String) -> Result<ADomain, NrpsError> {
147    let parts: Vec<&str> = line.split("\t").collect();
148    if parts.len() < 2 {
149        return Err(NrpsError::SignatureError(line));
150    }
151    if parts[0].len() != 34 {
152        return Err(NrpsError::SignatureError(line));
153    }
154
155    let name: String;
156    match parts.len() {
157        2 => name = parts[1].to_string(),
158        _ => name = format!("{}_{}", parts[2], parts[1]),
159    }
160    Ok(ADomain::new(name, parts[0].to_string()))
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// 34-residue signatures shared by the tests below.
    const SIG_A: &str = "LDASFDASLFEMYLLTGGDRNMYGPTEATMCATW";
    const SIG_B: &str = "LEPAFDISLFEVHLLTGGDRHLYGPTEATLCATW";

    #[test]
    fn test_parse_domains() {
        let two_parts = BufReader::new("LDASFDASLFEMYLLTGGDRNMYGPTEATMCATW\tbpsA_A1".as_bytes());
        let three_parts =
            BufReader::new("LEPAFDISLFEVHLLTGGDRHLYGPTEATLCATW\tHpg\tCAC48361.1.A1".as_bytes());
        let too_short = BufReader::new("LDASFDASLFEMYLLTGGDRNMYGPTEATMCATW".as_bytes());

        let expected_two = Vec::from([ADomain::new(
            "bpsA_A1".to_string(),
            "LDASFDASLFEMYLLTGGDRNMYGPTEATMCATW".to_string(),
        )]);

        let expected_three = Vec::from([ADomain::new(
            "CAC48361.1.A1_Hpg".to_string(),
            "LEPAFDISLFEVHLLTGGDRHLYGPTEATLCATW".to_string(),
        )]);

        let got_two = parse_domains_from_reader(two_parts).unwrap();
        assert_eq!(expected_two, got_two);

        let got_three = parse_domains_from_reader(three_parts).unwrap();
        assert_eq!(expected_three, got_three);

        let got_error = parse_domains_from_reader(too_short);
        assert!(got_error.is_err());
    }

    /// Blank and whitespace-only lines are skipped; other lines are trimmed before parsing.
    #[test]
    fn test_parse_domains_skips_blank_lines_and_trims() {
        let input = format!("\n   \n  {SIG_A}\tbpsA_A1  \n\n{SIG_B}\tHpg\tCAC48361.1.A1\n");
        let got = parse_domains_from_reader(BufReader::new(input.as_bytes())).unwrap();

        let expected = vec![
            ADomain::new("bpsA_A1".to_string(), SIG_A.to_string()),
            ADomain::new("CAC48361.1.A1_Hpg".to_string(), SIG_B.to_string()),
        ];
        assert_eq!(expected, got);
    }

    /// The signature column must be exactly 34 bytes long.
    #[test]
    fn test_parse_domain_rejects_wrong_signature_length() {
        assert!(parse_domain("LDASFDASLF\tbpsA_A1".to_string()).is_err());
        assert!(parse_domain(format!("{SIG_A}X\tbpsA_A1")).is_err());
    }

    /// Columns beyond the third do not affect the generated name.
    #[test]
    fn test_parse_domain_ignores_extra_columns() {
        let got = parse_domain(format!("{SIG_B}\tHpg\tCAC48361.1.A1\textra")).unwrap();
        let expected = ADomain::new("CAC48361.1.A1_Hpg".to_string(), SIG_B.to_string());
        assert_eq!(expected, got);
    }
}
193}