nrps_rs/
config.rs

1// License: GNU Affero General Public License v3 or later
2// A copy of GNU AGPL v3 should have been included in this software package in LICENSE.txt.
3
4use std::convert::From;
5use std::env;
6use std::io::Read;
7use std::path::PathBuf;
8
9use clap::Parser;
10use serde::Deserialize;
11use toml;
12
13use crate::errors::NrpsError;
14use crate::predictors::predictions::PredictionCategory;
15
16#[derive(Parser, Debug)]
17#[command(author, version, about, long_about = None)]
18pub struct Cli {
19    /// Signature file to run predictions on
20    pub signatures: PathBuf,
21
22    /// Number of results to return per category
23    #[arg(short, long)]
24    pub count: Option<usize>,
25
26    /// Runs the NRPSPredictor2 fungal models
27    #[arg(short = 'F', long, default_value_t = false)]
28    pub fungal: bool,
29
30    /// Sets a custom config file
31    #[arg(short = 'C', long, value_name = "FILE")]
32    pub config: Option<PathBuf>,
33
34    /// Overrides the config file settings for the Stachelhaus signature file
35    #[arg(short, long, value_name = "FILE")]
36    pub stachelhaus_signatures: Option<PathBuf>,
37
38    /// Overrides the config file settings for the SVM model dir
39    #[arg(short, long, value_name = "DIR")]
40    pub model_dir: Option<PathBuf>,
41
42    /// Disable v3 models
43    #[arg(short = '3', long)]
44    pub skip_v3: bool,
45
46    /// Disable v2 models
47    #[arg(short = '2', long)]
48    pub skip_v2: bool,
49
50    /// Disable v1 models
51    #[arg(short = '1', long)]
52    pub skip_v1: bool,
53
54    /// Disable Stachelhaus lookups
55    #[arg(short = 'S', long)]
56    pub skip_stachelhaus: bool,
57
58    /// Disable printing new-style AA34 Stachelhaus results
59    #[arg(long)]
60    pub skip_new_stachelhaus_output: bool,
61}
62
63#[derive(Debug, Deserialize)]
64struct ParsedConfig {
65    pub model_dir: Option<String>,
66    pub stachelhaus_signatures: Option<String>,
67    pub count: Option<usize>,
68    pub skip_v3: Option<bool>,
69    pub skip_v2: Option<bool>,
70    pub skip_v1: Option<bool>,
71    pub skip_stachelhaus: Option<bool>,
72    pub skip_new_stachelhaus_output: Option<bool>,
73}
74
/// Effective program settings, built from defaults, the config file and
/// command line overrides (see `parse_config`).
#[derive(Debug, PartialEq)]
pub struct Config {
    /// Directory holding the SVM models; read through `Config::model_dir`.
    model_dir: PathBuf,
    /// Stachelhaus signature file; read through `Config::stachelhaus_signatures`.
    stachelhaus_signatures: PathBuf,
    /// `true` while `stachelhaus_signatures` is derived from `model_dir`, in which
    /// case `set_model_dir` re-derives it; `set_stachelhaus_signatures` clears it.
    stach_sig_derived: bool,
    /// Number of results to return per prediction category.
    pub count: usize,
    /// Adds the fungal v2 category in `categories` (unless v2 is skipped).
    pub fungal: bool,
    /// Leave out the v3 categories.
    pub skip_v3: bool,
    /// Leave out the v2 categories (including the fungal one).
    pub skip_v2: bool,
    /// Leave out the v1 categories.
    pub skip_v1: bool,
    /// Leave out the Stachelhaus category.
    pub skip_stachelhaus: bool,
    /// Disable printing new-style AA34 Stachelhaus results.
    pub skip_new_stachelhaus_output: bool,
}
88
/// Returns the default Stachelhaus signature file for `model_dir`:
/// `signatures.tsv` directly inside that directory.
///
/// Takes `&Path` so that `&PathBuf` as well as borrowed paths can be passed.
fn set_stach_from_model_dir(model_dir: &std::path::Path) -> PathBuf {
    model_dir.join("signatures.tsv")
}
94
95impl Config {
96    pub fn new() -> Self {
97        let mut model_dir: PathBuf;
98        model_dir = env::current_dir().unwrap();
99        model_dir.push("data");
100        model_dir.push("models");
101        let stachelhaus_signatures = set_stach_from_model_dir(&model_dir);
102
103        Config {
104            model_dir,
105            stachelhaus_signatures,
106            stach_sig_derived: true,
107            count: 1,
108            fungal: false,
109            skip_v3: false,
110            skip_v2: false,
111            skip_v1: false,
112            skip_stachelhaus: false,
113            skip_new_stachelhaus_output: false,
114        }
115    }
116
117    pub fn model_dir(&self) -> &PathBuf {
118        &self.model_dir
119    }
120
121    pub fn set_model_dir(&mut self, model_dir: PathBuf) {
122        self.model_dir = model_dir;
123        if self.stach_sig_derived {
124            self.stachelhaus_signatures = set_stach_from_model_dir(&self.model_dir);
125        }
126    }
127
128    pub fn stachelhaus_signatures(&self) -> &PathBuf {
129        &self.stachelhaus_signatures
130    }
131
132    pub fn set_stachelhaus_signatures(&mut self, stachelhaus_signatures: PathBuf) {
133        self.stach_sig_derived = false;
134        self.stachelhaus_signatures = stachelhaus_signatures;
135    }
136
137    pub fn categories(&self) -> Vec<PredictionCategory> {
138        let mut categories: Vec<PredictionCategory> = Vec::with_capacity(12);
139        if !self.skip_v3 {
140            categories.extend_from_slice(&[
141                PredictionCategory::ThreeClusterV3,
142                PredictionCategory::LargeClusterV3,
143                PredictionCategory::SmallClusterV3,
144                PredictionCategory::SingleV3,
145            ]);
146        }
147
148        if !self.skip_stachelhaus {
149            categories.push(PredictionCategory::Stachelhaus);
150        }
151
152        if !self.skip_v2 {
153            categories.extend_from_slice(&[
154                PredictionCategory::ThreeClusterV2,
155                PredictionCategory::LargeClusterV2,
156                PredictionCategory::SmallClusterV2,
157                PredictionCategory::SingleV2,
158            ]);
159        }
160
161        if self.fungal && !self.skip_v2 {
162            categories.push(PredictionCategory::ThreeClusterFungalV2);
163        }
164
165        if !self.skip_v1 {
166            categories.extend_from_slice(&[
167                PredictionCategory::LargeClusterV1,
168                PredictionCategory::SmallClusterV1,
169            ]);
170        }
171
172        categories
173    }
174}
175
176impl From<ParsedConfig> for Config {
177    fn from(item: ParsedConfig) -> Self {
178        let mut config = Config::new();
179
180        if let Some(dir_str) = item.model_dir {
181            config.set_model_dir(PathBuf::from(dir_str));
182        }
183
184        if let Some(file_name) = item.stachelhaus_signatures {
185            config.set_stachelhaus_signatures(PathBuf::from(file_name));
186        }
187
188        if let Some(count) = item.count {
189            config.count = count;
190        }
191
192        if let Some(skip_v3) = item.skip_v3 {
193            config.skip_v3 = skip_v3;
194        }
195
196        if let Some(skip_v2) = item.skip_v2 {
197            config.skip_v2 = skip_v2;
198        }
199
200        if let Some(skip_v1) = item.skip_v1 {
201            config.skip_v1 = skip_v1;
202        }
203
204        if let Some(skip_stachelhaus) = item.skip_stachelhaus {
205            config.skip_stachelhaus = skip_stachelhaus;
206        }
207
208        if let Some(skip_new_stach) = item.skip_new_stachelhaus_output {
209            config.skip_new_stachelhaus_output = skip_new_stach;
210        }
211
212        config
213    }
214}
215
216pub fn parse_config<R>(mut reader: R, args: &Cli) -> Result<Config, NrpsError>
217where
218    R: Read,
219{
220    let mut raw_config = String::new();
221    reader.read_to_string(&mut raw_config)?;
222    let parsed_config: ParsedConfig = toml::from_str(&raw_config)?;
223    let mut config = Config::from(parsed_config);
224    if let Some(md) = &args.model_dir {
225        config.model_dir = md.clone();
226        config.stachelhaus_signatures = set_stach_from_model_dir(&config.model_dir);
227    }
228    if let Some(stach) = &args.stachelhaus_signatures {
229        config.stachelhaus_signatures = stach.clone();
230    }
231    if let Some(mut count_val) = args.count {
232        if count_val < 1 {
233            count_val = 1;
234        }
235        config.count = count_val;
236    }
237
238    config.skip_v3 = args.skip_v3;
239    config.skip_v2 = args.skip_v2;
240    config.skip_v1 = args.skip_v1;
241    config.skip_stachelhaus = args.skip_stachelhaus;
242    config.skip_new_stachelhaus_output = args.skip_new_stachelhaus_output;
243
244    Ok(config)
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    use rstest::{fixture, rstest};

    /// The model directory `Config::new()` uses by default: `<cwd>/data/models`.
    fn default_model_dir() -> PathBuf {
        env::current_dir().unwrap().join("data").join("models")
    }

    /// Builds the expected config: defaults plus the given paths and derived flag.
    fn expected_with_paths(model_dir: PathBuf, stach: PathBuf, derived: bool) -> Config {
        let mut config = Config::new();
        config.set_model_dir(model_dir);
        config.set_stachelhaus_signatures(stach);
        config.stach_sig_derived = derived;
        config
    }

    // Command line arguments with nothing overridden.
    #[fixture]
    fn args() -> Cli {
        Cli {
            signatures: PathBuf::from("foo.sig"),
            count: None,
            fungal: false,
            config: None,
            stachelhaus_signatures: None,
            model_dir: None,
            skip_v3: false,
            skip_v2: false,
            skip_v1: false,
            skip_stachelhaus: false,
            skip_new_stachelhaus_output: false,
        }
    }

    #[rstest]
    fn test_model_dir_set(args: Cli) {
        let expected = expected_with_paths(
            PathBuf::from("/foo"),
            PathBuf::from("/foo/signatures.tsv"),
            true,
        );
        let got = parse_config("model_dir = '/foo'".as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_model_dir_default(args: Cli) {
        let model_dir = default_model_dir();
        let stach = model_dir.join("signatures.tsv");
        let expected = expected_with_paths(model_dir, stach, true);
        let got = parse_config("".as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_stach_extra(args: Cli) {
        let expected = expected_with_paths(
            default_model_dir(),
            PathBuf::from("/foo/signatures.tsv"),
            false,
        );
        let toml_src = "stachelhaus_signatures = '/foo/signatures.tsv'";
        let got = parse_config(toml_src.as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_override_model_dir(mut args: Cli) {
        let model_dir = PathBuf::from("/foo");
        let stach = model_dir.join("signatures.tsv");
        args.model_dir = Some(model_dir.clone());

        let expected = expected_with_paths(model_dir, stach, true);
        let got = parse_config("".as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_override_stach(mut args: Cli) {
        let stach = PathBuf::from("/bar/signatures.tsv");
        args.stachelhaus_signatures = Some(stach.clone());

        let expected = expected_with_paths(PathBuf::from("/foo"), stach, true);
        let got = parse_config("model_dir = '/foo'".as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_override_both(mut args: Cli) {
        let model_dir = PathBuf::from("/foo");
        let stach = PathBuf::from("/bar/signatures.tsv");
        args.model_dir = Some(model_dir.clone());
        args.stachelhaus_signatures = Some(stach.clone());

        let expected = expected_with_paths(model_dir, stach, false);
        let toml_src = "stachelhaus_signatures = '/baz/signatures.tsv'";
        let got = parse_config(toml_src.as_bytes(), &args).unwrap();
        assert_eq!(expected, got);
    }

    #[rstest]
    fn test_skip_v3(mut args: Cli) {
        args.skip_v3 = true;
        let expected = Config {
            skip_v3: true,
            ..Config::new()
        };
        assert_eq!(expected, parse_config("".as_bytes(), &args).unwrap());
    }

    #[rstest]
    fn test_skip_v2(mut args: Cli) {
        args.skip_v2 = true;
        let expected = Config {
            skip_v2: true,
            ..Config::new()
        };
        assert_eq!(expected, parse_config("".as_bytes(), &args).unwrap());
    }

    #[rstest]
    fn test_skip_v1(mut args: Cli) {
        args.skip_v1 = true;
        let expected = Config {
            skip_v1: true,
            ..Config::new()
        };
        assert_eq!(expected, parse_config("".as_bytes(), &args).unwrap());
    }

    #[rstest]
    fn test_skip_stachelhaus(mut args: Cli) {
        args.skip_stachelhaus = true;
        let expected = Config {
            skip_stachelhaus: true,
            ..Config::new()
        };
        assert_eq!(expected, parse_config("".as_bytes(), &args).unwrap());
    }
}