clear_ml/
models.rs

1use crate::{
2    loss_functions::gradient_mse,
3    util::{validate_not_empty, validate_same_length},
4    Error,
5};
6
/// Linear model.
///
/// # Formula
///
/// y = a0(intercept) + a1*x1 + a2*x2 + ... + an*xn
///
/// The derived `Default` value has a zero intercept and an empty weight
/// vector. `Clone` and `PartialEq` are derived so models can be copied and
/// compared field by field.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LinearModel {
    /// Constant term `a0`, added to every prediction.
    pub intercept: f64,
    /// Coefficients `a1..an`; the i-th weight multiplies the i-th value of a
    /// feature row.
    pub weights: Vec<f64>,
}
18
19impl LinearModel {
20    /// Create a new linear model.
21    pub fn new() -> Self {
22        LinearModel {
23            intercept: 0.0,
24            weights: Vec::new(),
25        }
26    }
27
28    /// Fit linear model.
29    ///
30    /// # Arguments
31    ///
32    /// * `x` - Matrix of features
33    /// * `y` - Vector of actual values
34    ///
35    /// # Requirements
36    ///
37    /// * `x` must have the same length as `y`
38    /// * `x`'s rows must all be same length
39    ///
40    /// # Returns
41    ///
42    /// * `Ok(&Self)` if successful
43    /// * `Err(Error)` if unsuccessful
44    pub fn fit(&mut self, x: &Vec<Vec<f64>>, y: &Vec<f64>) -> Result<&Self, Error> {
45        validate_not_empty(x)?;
46        validate_not_empty(y)?;
47
48        validate_same_length(x, y)?;
49
50        // validate all rows have same length
51        let first_row = x.first().unwrap();
52        x.iter()
53            .skip(1)
54            .try_for_each(|row| validate_same_length(first_row, row))?;
55
56        Ok(self._fit(x, y, 1000, 0.01, 0.000001))
57    }
58
59    fn _fit(
60        &mut self,
61        x: &Vec<Vec<f64>>,
62        y: &Vec<f64>,
63        max_iter: u32,
64        learning_rate: f64,
65        tol: f64,
66    ) -> &Self {
67        // initialize coefficients to 0
68        self.weights = vec![0.0; x.first().unwrap().len()];
69
70        for _ in 0..max_iter {
71            let y_hat = self._predict(x);
72
73            let gradient = gradient_mse(x, &y_hat, y).unwrap();
74            if gradient.iter().all(|g| g.abs() < tol) {
75                break;
76            }
77
78            // update coefficients
79            self.weights = self
80                .weights
81                .iter()
82                .zip(&gradient)
83                .map(|(a, g)| a - learning_rate * g)
84                .collect();
85
86            // update intercept
87            self.intercept -= learning_rate * gradient.last().unwrap();
88        }
89
90        self
91    }
92
93    /// Predict using the linear model.
94    ///
95    /// # Arguments
96    ///
97    /// * `x` - Matrix of features
98    ///
99    /// # Requirements
100    ///
101    /// * `x`'s rows must all match the length of the coefficients of the `LinearModel`
102    ///
103    /// # Returns
104    ///
105    /// * `Ok(Vec<f64>)` if successful
106    /// * `Err(Error)` if unsuccessful
107    pub fn predict(&self, x: &Vec<Vec<f64>>) -> Result<Vec<f64>, Error> {
108        validate_not_empty(x)?;
109
110        // validate all rows are same length as coefficients
111        x.iter()
112            .try_for_each(|row| validate_same_length(row, &self.weights))?;
113
114        Ok(self._predict(x))
115    }
116
117    fn _predict(&self, x: &Vec<Vec<f64>>) -> Vec<f64> {
118        x.iter()
119            .map(|row| {
120                row.iter()
121                    .zip(&self.weights)
122                    .map(|(x, a)| x * a)
123                    .sum::<f64>()
124                    + self.intercept
125            })
126            .collect()
127    }
128}
129
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    /// Builds a model with hand-picked parameters, bypassing `fit`.
    fn model_with(intercept: f64, weights: Vec<f64>) -> LinearModel {
        let mut model = LinearModel::new();
        model.intercept = intercept;
        model.weights = weights;
        model
    }

    /// Fitting on empty inputs must be rejected with `Error::EmptyVector`.
    #[test]
    fn test_fit_empty_vector() {
        let features: Vec<Vec<f64>> = Vec::new();
        let targets: Vec<f64> = Vec::new();

        let mut model = LinearModel::new();
        let outcome = model.fit(&features, &targets);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), Error::EmptyVector);
    }

    /// One feature row against three targets must be rejected with
    /// `Error::DimensionMismatch`.
    #[test]
    fn test_fit_dimension_mismatch() {
        let features: Vec<Vec<f64>> = vec![vec![1.0, 2.0, 3.0]];
        let targets: Vec<f64> = vec![0.0; 3];

        let mut model = LinearModel::new();
        let outcome = model.fit(&features, &targets);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), Error::DimensionMismatch);
    }

    /// Fitting a single-feature data set must succeed.
    #[test]
    fn test_fit() {
        let features = vec![vec![1.0], vec![2.0], vec![5.0], vec![6.0]];
        let targets = vec![1.0, 2.0, 5.0, 6.0];

        LinearModel::new().fit(&features, &targets).unwrap();

        // TODO: write more descriptive tests
    }

    /// Fitting a two-sample data set must succeed.
    ///
    /// NOTE(review): despite the name, the targets are exactly `2 * x`, so
    /// the underlying relationship has no intercept.
    #[test]
    fn test_fit_with_intercept() {
        let features = vec![vec![1.0], vec![2.0]];
        let targets = vec![2.0, 4.0];

        LinearModel::new().fit(&features, &targets).unwrap();
    }

    /// Predicting on an empty matrix must be rejected with
    /// `Error::EmptyVector`.
    #[test]
    fn test_predict_empty_vector() {
        let features: Vec<Vec<f64>> = Vec::new();

        let outcome = LinearModel::new().predict(&features);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), Error::EmptyVector);
    }

    /// A row shorter than the weight vector must be rejected with
    /// `Error::DimensionMismatch`.
    #[test]
    fn test_predict_different_row_length() {
        let model = model_with(1.0, vec![1.0, 1.0]);
        let features: Vec<Vec<f64>> = vec![vec![1.0, 1.0], vec![1.0]];

        let outcome = model.predict(&features);

        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), Error::DimensionMismatch);
    }

    /// With intercept 1 and weight 1, each prediction is `x + 1`.
    #[test]
    fn test_predict() {
        let model = model_with(1.0, vec![1.0]);
        let features: Vec<Vec<f64>> = vec![vec![1.0], vec![2.0], vec![3.0]];

        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions, vec![2.0, 3.0, 4.0]);
    }

    /// With intercept 1 and weights (1, 2), each prediction is
    /// `1 + x1 + 2 * x2`.
    #[test]
    fn test_predict_with_multiple_coefficients() {
        let model = model_with(1.0, vec![1.0, 2.0]);
        let features: Vec<Vec<f64>> = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]];

        let predictions = model.predict(&features).unwrap();

        assert_eq!(predictions, vec![4.0, 7.0, 10.0]);
    }
}
231}