nrps_rs/encodings/
mod.rs

1// License: GNU Affero General Public License v3 or later
2// A copy of GNU AGPL v3 should have been included in this software package in LICENSE.txt.
3
4pub mod blin;
5pub mod rausch;
6pub mod wold;
7
8use crate::predictors::predictions::PredictionCategory;
9
/// Feature-encoding scheme used to turn a sequence into a numeric feature vector.
///
/// Each variant selects the submodule of the same name when passed to [`encode`].
// Unit-only enum: deriving the value traits lets callers copy, compare and hash
// an encoding choice instead of threading references around.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureEncoding {
    /// Dispatches to `blin::encode`.
    Blin,
    /// Dispatches to `rausch::legacy_encode` for the V1 cluster prediction
    /// categories and to `rausch::encode` otherwise.
    Rausch,
    /// Dispatches to `wold::encode`.
    Wold,
}
16
17pub fn encode(
18    sequence: &String,
19    encoding: &FeatureEncoding,
20    category: &PredictionCategory,
21) -> Vec<f64> {
22    let legacy_categories = &[
23        PredictionCategory::LargeClusterV1,
24        PredictionCategory::SmallClusterV1,
25    ];
26    match encoding {
27        FeatureEncoding::Blin => blin::encode(sequence),
28        FeatureEncoding::Rausch => {
29            if legacy_categories.contains(category) {
30                rausch::legacy_encode(sequence)
31            } else {
32                rausch::encode(sequence)
33            }
34        }
35        FeatureEncoding::Wold => wold::encode(sequence),
36    }
37}
38
39pub fn get_value(map: &phf::Map<char, f64>, c: char, mean: f64, stdev: f64, use_mean: bool) -> f64 {
40    if let Some(value) = map.get(&c) {
41        return normalise(value.clone(), mean, stdev);
42    }
43    if use_mean {
44        return mean;
45    }
46    normalise(0.0, mean, stdev)
47}
48
/// Standardise `value` against the given `mean` and `stdev`, returning its z-score.
///
/// There is no guard against `stdev == 0.0`; IEEE-754 division then produces
/// ±infinity, or NaN when `value == mean`.
fn normalise(value: f64, mean: f64, stdev: f64) -> f64 {
    let deviation = value - mean;
    deviation / stdev
}
52
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use phf::phf_map;

    // With mean 2.0 and stdev 2.0: 'A' normalises to -1, 'R' to +1, 'K' to 0.
    static TEST_MAP: phf::Map<char, f64> = phf_map! {
        'A' => 0.00,
        'R' => 4.00,
        'K' => 2.00,
    };
    const TEST_MEAN: f64 = 2.0;
    const TEST_STDEV: f64 = 2.0;

    #[test]
    fn test_get_value() {
        // (residue, use_mean, expected)
        let cases = [
            ('A', true, -1.0),
            ('R', true, 1.0),
            ('K', true, 0.0),
            // Missing residue with `use_mean`: the raw mean comes back.
            ('-', true, 2.0),
            // Missing residue without `use_mean`: 0.0 is normalised.
            ('-', false, -1.0),
        ];
        for &(residue, use_mean, expected) in cases.iter() {
            assert_approx_eq!(
                get_value(&TEST_MAP, residue, TEST_MEAN, TEST_STDEV, use_mean),
                expected
            );
        }
    }
}