1use crate::{
2 loss_functions::gradient_mse,
3 util::{validate_not_empty, validate_same_length},
4 Error,
5};
6
/// A linear model whose prediction for a sample row is
/// `intercept + Σ weights[i] * row[i]`.
///
/// The derived [`Default`] value is an untrained model: a zero intercept and
/// no weights (the same value `LinearModel::new()` produces).
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LinearModel {
    /// Constant term added to every prediction.
    pub intercept: f64,
    /// One coefficient per input feature, matched positionally against the
    /// entries of each sample row.
    pub weights: Vec<f64>,
}
18
19impl LinearModel {
20 pub fn new() -> Self {
22 LinearModel {
23 intercept: 0.0,
24 weights: Vec::new(),
25 }
26 }
27
28 pub fn fit(&mut self, x: &Vec<Vec<f64>>, y: &Vec<f64>) -> Result<&Self, Error> {
45 validate_not_empty(x)?;
46 validate_not_empty(y)?;
47
48 validate_same_length(x, y)?;
49
50 let first_row = x.first().unwrap();
52 x.iter()
53 .skip(1)
54 .try_for_each(|row| validate_same_length(first_row, row))?;
55
56 Ok(self._fit(x, y, 1000, 0.01, 0.000001))
57 }
58
59 fn _fit(
60 &mut self,
61 x: &Vec<Vec<f64>>,
62 y: &Vec<f64>,
63 max_iter: u32,
64 learning_rate: f64,
65 tol: f64,
66 ) -> &Self {
67 self.weights = vec![0.0; x.first().unwrap().len()];
69
70 for _ in 0..max_iter {
71 let y_hat = self._predict(x);
72
73 let gradient = gradient_mse(x, &y_hat, y).unwrap();
74 if gradient.iter().all(|g| g.abs() < tol) {
75 break;
76 }
77
78 self.weights = self
80 .weights
81 .iter()
82 .zip(&gradient)
83 .map(|(a, g)| a - learning_rate * g)
84 .collect();
85
86 self.intercept -= learning_rate * gradient.last().unwrap();
88 }
89
90 self
91 }
92
93 pub fn predict(&self, x: &Vec<Vec<f64>>) -> Result<Vec<f64>, Error> {
108 validate_not_empty(x)?;
109
110 x.iter()
112 .try_for_each(|row| validate_same_length(row, &self.weights))?;
113
114 Ok(self._predict(x))
115 }
116
117 fn _predict(&self, x: &Vec<Vec<f64>>) -> Vec<f64> {
118 x.iter()
119 .map(|row| {
120 row.iter()
121 .zip(&self.weights)
122 .map(|(x, a)| x * a)
123 .sum::<f64>()
124 + self.intercept
125 })
126 .collect()
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean squared error between `predictions` and `targets`.
    fn mse(predictions: &[f64], targets: &[f64]) -> f64 {
        predictions
            .iter()
            .zip(targets)
            .map(|(p, t)| (p - t).powi(2))
            .sum::<f64>()
            / targets.len() as f64
    }

    /// Asserts that `predictions` are finite and fit `targets` strictly better
    /// than the all-zero model that training starts from.
    fn assert_improves_on_zero_model(predictions: &[f64], targets: &[f64]) {
        assert!(predictions.iter().all(|p| p.is_finite()));

        let zeros = vec![0.0; targets.len()];
        assert!(mse(predictions, targets) < mse(&zeros, targets));
    }

    #[test]
    fn test_fit_empty_vector() {
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];

        assert_eq!(
            LinearModel::new().fit(&x, &y).unwrap_err(),
            Error::EmptyVector
        );
    }

    #[test]
    fn test_fit_dimension_mismatch() {
        let x: Vec<Vec<f64>> = vec![vec![1.0, 2.0, 3.0]];
        let y: Vec<f64> = vec![0.0; 3];

        assert_eq!(
            LinearModel::new().fit(&x, &y).unwrap_err(),
            Error::DimensionMismatch
        );
    }

    #[test]
    fn test_fit() {
        let x = vec![vec![1.0], vec![2.0], vec![5.0], vec![6.0]];
        let y = vec![1.0, 2.0, 5.0, 6.0];

        let mut model = LinearModel::new();

        model.fit(&x, &y).unwrap();

        assert_eq!(model.weights.len(), 1);
        assert_improves_on_zero_model(&model.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_fit_with_intercept() {
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];

        let mut model = LinearModel::new();

        model.fit(&x, &y).unwrap();

        assert_eq!(model.weights.len(), 1);
        assert_improves_on_zero_model(&model.predict(&x).unwrap(), &y);
    }

    #[test]
    fn test_predict_empty_vector() {
        let x: Vec<Vec<f64>> = vec![];

        assert_eq!(
            LinearModel::new().predict(&x).unwrap_err(),
            Error::EmptyVector
        );
    }

    #[test]
    fn test_predict_different_row_length() {
        let mut model = LinearModel::new();

        model.intercept = 1.0;
        model.weights = vec![1.0, 1.0];

        let x: Vec<Vec<f64>> = vec![vec![1.0, 1.0], vec![1.0]];

        assert_eq!(model.predict(&x).unwrap_err(), Error::DimensionMismatch);
    }

    #[test]
    fn test_predict() {
        let mut model = LinearModel::new();

        model.intercept = 1.0;
        model.weights = vec![1.0];

        let x: Vec<Vec<f64>> = vec![vec![1.0], vec![2.0], vec![3.0]];
        let expected: Vec<f64> = vec![2.0, 3.0, 4.0];

        assert_eq!(model.predict(&x).unwrap(), expected);
    }

    #[test]
    fn test_predict_with_multiple_coefficients() {
        let mut model = LinearModel::new();

        model.intercept = 1.0;
        model.weights = vec![1.0, 2.0];

        let x: Vec<Vec<f64>> = vec![vec![1.0, 1.0], vec![2.0, 2.0], vec![3.0, 3.0]];
        let expected: Vec<f64> = vec![4.0, 7.0, 10.0];

        assert_eq!(model.predict(&x).unwrap(), expected);
    }
}
231}