// nrps_rs/svm/vectors.rs
// License: GNU Affero General Public License v3 or later
// A copy of GNU AGPL v3 should have been included in this software package in LICENSE.txt.

4use crate::errors::NrpsError;
5
6pub trait Vector {
7    fn values(&self) -> &Vec<f64>;
8    fn dim(&self) -> usize {
9        self.values().len()
10    }
11    fn square_dist<T: Vector>(&self, other: &T) -> Result<f64, NrpsError> {
12        let temp = element_subtract(self.values(), other.values())?;
13        dot(&temp, &temp)
14    }
15
16    fn dist<T: Vector>(&self, other: &T) -> Result<f64, NrpsError> {
17        Ok(self.square_dist(other)?.sqrt())
18    }
19
20    fn similarity<T: Vector>(&self, other: &T) -> Result<f64, NrpsError> {
21        dot(self.values(), other.values())
22    }
23}
24
/// A plain vector of `f64` feature values.
///
/// Derives `Clone` and `PartialEq` so callers can copy vectors and compare them
/// element-wise (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct FeatureVector {
    values: Vec<f64>,
}

30impl FeatureVector {
31    pub fn new(values: Vec<f64>) -> FeatureVector {
32        FeatureVector { values }
33    }
34}
35
36impl Vector for FeatureVector {
37    fn values(&self) -> &Vec<f64> {
38        &self.values
39    }
40}
41
/// A support vector: its dense values plus the `yalpha` coefficient read from the model.
///
/// NOTE(review): `yalpha` is presumably the label times the Lagrange multiplier (y·α);
/// confirm this against the model file format.
/// Derives `Clone` and `PartialEq` so callers can copy and compare support vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct SupportVector {
    values: Vec<f64>,
    pub yalpha: f64,
}

48impl SupportVector {
49    pub fn new(values: Vec<f64>, yalpha: f64) -> Self {
50        SupportVector { values, yalpha }
51    }
52    pub fn from_line(line: String, dimension: usize) -> Result<Self, NrpsError> {
53        let yalpha: f64;
54        let mut values = vec![0.0; dimension];
55        let parts: Vec<&str> = line.split(char::is_whitespace).collect();
56        if parts.len() < 2 {
57            return Err(NrpsError::InvalidFeatureLine(line));
58        }
59        yalpha = parts[0].parse::<f64>()?;
60
61        for token in parts[1..].iter() {
62            if token == &"#" {
63                break;
64            }
65            let value_parts: Vec<&str> = token.splitn(2, ":").collect();
66            let idx = value_parts[0].parse::<usize>()? - 1;
67            if idx > dimension - 1 {
68                return Err(NrpsError::InvalidFeatureLine(line));
69            }
70            let value = value_parts[1].parse::<f64>()?;
71            values[idx] = value;
72        }
73
74        Ok(SupportVector { values, yalpha })
75    }
76}
77
78impl Vector for SupportVector {
79    fn values(&self) -> &Vec<f64> {
80        &self.values
81    }
82}
83
84fn dot(a: &Vec<f64>, b: &Vec<f64>) -> Result<f64, NrpsError> {
85    if a.len() != b.len() {
86        return Err(NrpsError::DimensionMismatch {
87            first: a.len(),
88            second: b.len(),
89        });
90    }
91    Ok(a.iter()
92        .zip(b.iter())
93        .fold(0.0, |sum, (el_a, el_b)| sum + el_a * el_b))
94}
95
96fn element_subtract(a: &Vec<f64>, b: &Vec<f64>) -> Result<Vec<f64>, NrpsError> {
97    if a.len() != b.len() {
98        return Err(NrpsError::DimensionMismatch {
99            first: a.len(),
100            second: b.len(),
101        });
102    }
103    Ok(a.iter()
104        .zip(b.iter())
105        .map(|(el_a, el_b)| el_a - el_b)
106        .collect())
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_square_dist() {
        let v1 = FeatureVector::new(vec![1.0, 0.0, 1.0]);
        let v2 = FeatureVector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(v1.square_dist(&v2).unwrap(), 8.0);
    }

    #[test]
    fn test_dist() {
        let v1 = FeatureVector::new(vec![1.0, 0.0, 1.0]);
        let v2 = FeatureVector::new(vec![1.0, 2.0, 1.0]);
        assert_eq!(v1.dist(&v2).unwrap(), 2.0);
    }

    #[test]
    fn test_similarity() {
        let v1 = FeatureVector::new(vec![1.0, 0.0, 1.0]);
        let v2 = FeatureVector::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(v1.similarity(&v2).unwrap(), 4.0);
    }

    #[test]
    fn test_dim() {
        assert_eq!(FeatureVector::new(vec![1.0, 2.0, 3.0]).dim(), 3);
        assert_eq!(FeatureVector::new(Vec::new()).dim(), 0);
    }

    #[test]
    fn test_dimension_mismatch() {
        // Every binary operation must refuse operands of different lengths.
        let v1 = FeatureVector::new(vec![1.0, 2.0]);
        let v2 = FeatureVector::new(vec![1.0, 2.0, 3.0]);
        assert!(matches!(
            v1.square_dist(&v2),
            Err(NrpsError::DimensionMismatch { first: 2, second: 3 })
        ));
        assert!(matches!(
            v1.dist(&v2),
            Err(NrpsError::DimensionMismatch { first: 2, second: 3 })
        ));
        assert!(matches!(
            v1.similarity(&v2),
            Err(NrpsError::DimensionMismatch { first: 2, second: 3 })
        ));
    }

    #[test]
    fn test_element_subtract() {
        let v1 = FeatureVector::new(vec![3.0, 2.0]);
        let v2 = FeatureVector::new(vec![1.0, -2.0]);
        let expected = vec![2.0_f64, 4.0];
        assert_eq!(element_subtract(v1.values(), v2.values()).unwrap(), expected);
    }

    #[test]
    fn test_from_line() {
        let line = String::from("10 1:-1.6023999 3:-0.55470002 5:-0.63520002 # some junk");
        let v1 = SupportVector::from_line(line, 5).unwrap();
        assert_eq!(v1.yalpha, 10.0);
        assert_eq!(v1.values, [-1.6023999, 0., -0.55470002, 0., -0.63520002]);
    }

    #[test]
    fn test_from_line_rejects_index_past_dimension() {
        let line = String::from("1 6:1.0");
        assert!(matches!(
            SupportVector::from_line(line, 5),
            Err(NrpsError::InvalidFeatureLine(_))
        ));
    }

    #[test]
    fn test_from_line_rejects_line_without_features() {
        assert!(matches!(
            SupportVector::from_line(String::from("10"), 5),
            Err(NrpsError::InvalidFeatureLine(_))
        ));
    }
}
152}