1pub mod blin;
5pub mod rausch;
6pub mod wold;
7
8use crate::predictors::predictions::PredictionCategory;
9
/// Feature encoding scheme that [`encode`] applies to a sequence.
///
/// Each variant dispatches to the sibling module of the same name
/// (`blin`, `rausch`, `wold`).
///
/// The enum is a plain fieldless selector, so it derives the common value
/// traits: callers can copy it, compare it with `==`, and use it as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FeatureEncoding {
    Blin,
    Rausch,
    Wold,
}
16
17pub fn encode(
18 sequence: &String,
19 encoding: &FeatureEncoding,
20 category: &PredictionCategory,
21) -> Vec<f64> {
22 let legacy_categories = &[
23 PredictionCategory::LargeClusterV1,
24 PredictionCategory::SmallClusterV1,
25 ];
26 match encoding {
27 FeatureEncoding::Blin => blin::encode(sequence),
28 FeatureEncoding::Rausch => {
29 if legacy_categories.contains(category) {
30 rausch::legacy_encode(sequence)
31 } else {
32 rausch::encode(sequence)
33 }
34 }
35 FeatureEncoding::Wold => wold::encode(sequence),
36 }
37}
38
39pub fn get_value(map: &phf::Map<char, f64>, c: char, mean: f64, stdev: f64, use_mean: bool) -> f64 {
40 if let Some(value) = map.get(&c) {
41 return normalise(value.clone(), mean, stdev);
42 }
43 if use_mean {
44 return mean;
45 }
46 normalise(0.0, mean, stdev)
47}
48
/// Standardise `value` against a distribution with the given `mean` and
/// `stdev`, i.e. return its z-score `(value - mean) / stdev`.
///
/// `stdev` is not validated. A zero `stdev` produces an infinite or NaN result.
fn normalise(value: f64, mean: f64, stdev: f64) -> f64 {
    let deviation = value - mean;
    deviation / stdev
}
52
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use phf::phf_map;

    // Small lookup table whose values make the normalised results easy to verify
    // with mean = 2 and stdev = 2.
    static TEST_MAP: phf::Map<char, f64> = phf_map! {
        'A' => 0.00,
        'R' => 4.00,
        'K' => 2.00,
    };
    const TEST_MEAN: f64 = 2.0;
    const TEST_STDEV: f64 = 2.0;

    #[test]
    fn test_get_value() {
        // Each case is (character, use_mean, expected result).
        let cases: [(char, bool, f64); 5] = [
            // Characters in the map are normalised: (value - 2) / 2.
            ('A', true, -1.0),
            ('R', true, 1.0),
            ('K', true, 0.0),
            // Missing character with `use_mean`: the raw mean comes back.
            ('-', true, 2.0),
            // Missing character without `use_mean`: 0.0 is normalised.
            ('-', false, -1.0),
        ];
        for &(c, use_mean, expected) in &cases {
            assert_approx_eq!(
                get_value(&TEST_MAP, c, TEST_MEAN, TEST_STDEV, use_mean),
                expected
            );
        }
    }
}