rs_stats/regression/
multiple_linear_regression.rs

1// src/regression/multiple_linear_regression.rs
2
3use num_traits::{Float, NumCast};
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::fs::File;
7use std::io::{self};
8use std::path::Path;
9
10/// Multiple linear regression model that fits a hyperplane to multivariate data points.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MultipleLinearRegression<T = f64>
13where
14    T: Float + Debug + Default + Serialize,
15{
16    /// Regression coefficients, including intercept as the first element
17    pub coefficients: Vec<T>,
18    /// Coefficient of determination (R²) - goodness of fit
19    pub r_squared: T,
20    /// Adjusted R² which accounts for the number of predictors
21    pub adjusted_r_squared: T,
22    /// Standard error of the estimate
23    pub standard_error: T,
24    /// Number of data points used for regression
25    pub n: usize,
26    /// Number of predictor variables (excluding intercept)
27    pub p: usize,
28}
29
30impl<T> Default for MultipleLinearRegression<T>
31where
32    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
33{
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl<T> MultipleLinearRegression<T>
40where
41    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
42{
43    /// Create a new multiple linear regression model without fitting any data
44    pub fn new() -> Self {
45        Self {
46            coefficients: Vec::new(),
47            r_squared: T::zero(),
48            adjusted_r_squared: T::zero(),
49            standard_error: T::zero(),
50            n: 0,
51            p: 0,
52        }
53    }
54
55    /// Fit a multiple linear regression model to the provided data
56    ///
57    /// # Arguments
58    /// * `x_values` - 2D array where each row is an observation and each column is a predictor
59    /// * `y_values` - Dependent variable values (observations)
60    ///
61    /// # Returns
62    /// * `Result<(), String>` - Ok if successful, Err with message if the inputs are invalid
63    pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> Result<(), String>
64    where
65        U: NumCast + Copy,
66        V: NumCast + Copy,
67    {
68        // Validate inputs
69        if x_values.is_empty() || y_values.is_empty() {
70            return Err("Cannot fit regression with empty arrays".to_string());
71        }
72
73        if x_values.len() != y_values.len() {
74            return Err("Number of observations in X and Y must match".to_string());
75        }
76
77        self.n = x_values.len();
78
79        // Check that all rows in x_values have the same length
80        if x_values.is_empty() {
81            return Err("X values array is empty".to_string());
82        }
83
84        self.p = x_values[0].len();
85
86        for row in x_values {
87            if row.len() != self.p {
88                return Err("All rows in X must have the same number of features".to_string());
89            }
90        }
91
92        // Convert input arrays to T type
93        let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
94        for row in x_values {
95            let row_cast: Result<Vec<T>, String> = row
96                .iter()
97                .map(|&x| T::from(x).ok_or_else(|| "Failed to cast X value".to_string()))
98                .collect();
99            x_cast.push(row_cast?);
100        }
101
102        let y_cast: Vec<T> = y_values
103            .iter()
104            .map(|&y| T::from(y).ok_or_else(|| "Failed to cast Y value".to_string()))
105            .collect::<Result<Vec<T>, String>>()?;
106
107        // Augment the X matrix with a column of 1s for the intercept
108        let mut augmented_x = Vec::with_capacity(self.n);
109        for row in &x_cast {
110            let mut augmented_row = Vec::with_capacity(self.p + 1);
111            augmented_row.push(T::one()); // Intercept term
112            augmented_row.extend_from_slice(row);
113            augmented_x.push(augmented_row);
114        }
115
116        // Compute X^T * X
117        let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
118
119        // Compute X^T * y
120        let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
121
122        // Solve the normal equations: (X^T * X) * β = X^T * y for β
123        match self.solve_linear_system(&xt_x, &xt_y) {
124            Ok(solution) => {
125                self.coefficients = solution;
126            }
127            Err(e) => return Err(e),
128        }
129
130        // Calculate fitted values and R²
131        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / T::from(self.n).unwrap();
132
133        let mut ss_total = T::zero();
134        let mut ss_residual = T::zero();
135
136        for i in 0..self.n {
137            let predicted = self.predict_t(&x_cast[i]);
138            let residual = y_cast[i] - predicted;
139
140            ss_residual = ss_residual + (residual * residual);
141            let diff = y_cast[i] - y_mean;
142            ss_total = ss_total + (diff * diff);
143        }
144
145        // Calculate R² and adjusted R²
146        if ss_total > T::zero() {
147            self.r_squared = T::one() - (ss_residual / ss_total);
148
149            // Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)]
150            if self.n > self.p + 1 {
151                let n_minus_1 = T::from(self.n - 1).unwrap();
152                let n_minus_p_minus_1 = T::from(self.n - self.p - 1).unwrap();
153
154                self.adjusted_r_squared =
155                    T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
156            }
157        }
158
159        // Calculate standard error
160        if self.n > self.p + 1 {
161            let n_minus_p_minus_1 = T::from(self.n - self.p - 1).unwrap();
162            self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
163        }
164
165        Ok(())
166    }
167
168    /// Predict y value for a given set of x values using the fitted model (internal version with type T)
169    fn predict_t(&self, x: &[T]) -> T {
170        if x.len() != self.p || self.coefficients.is_empty() {
171            return T::nan();
172        }
173
174        // First coefficient is the intercept
175        let mut result = self.coefficients[0];
176
177        // Add the weighted features
178        for (i, &xi) in x.iter().enumerate().take(self.p) {
179            result = result + (self.coefficients[i + 1] * xi);
180        }
181
182        result
183    }
184
185    /// Predict y value for a given set of x values using the fitted model
186    ///
187    /// # Arguments
188    /// * `x` - Vector of x values for prediction
189    ///
190    /// # Returns
191    /// * The predicted y value
192    pub fn predict<U>(&self, x: &[U]) -> T
193    where
194        U: NumCast + Copy,
195    {
196        if x.len() != self.p {
197            return T::nan();
198        }
199
200        // Convert input to T type
201        let x_cast: Result<Vec<T>, ()> = x.iter().map(|&val| T::from(val).ok_or(())).collect();
202
203        match x_cast {
204            Ok(x_t) => self.predict_t(&x_t),
205            Err(_) => T::nan(),
206        }
207    }
208
209    /// Calculate predictions for multiple observations
210    ///
211    /// # Arguments
212    /// * `x_values` - 2D array of feature values for prediction
213    ///
214    /// # Returns
215    /// * Vector of predicted y values
216    pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> Vec<T>
217    where
218        U: NumCast + Copy,
219    {
220        x_values.iter().map(|x| self.predict(x)).collect()
221    }
222
223    /// Save the model to a file
224    ///
225    /// # Arguments
226    /// * `path` - Path where to save the model
227    ///
228    /// # Returns
229    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
230    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
231        let file = File::create(path)?;
232        // Use JSON format for human-readability
233        serde_json::to_writer(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
234    }
235
236    /// Save the model in binary format
237    ///
238    /// # Arguments
239    /// * `path` - Path where to save the model
240    ///
241    /// # Returns
242    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
243    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
244        let file = File::create(path)?;
245        // Use bincode for more compact binary format
246        bincode::serialize_into(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
247    }
248
249    /// Load a model from a file
250    ///
251    /// # Arguments
252    /// * `path` - Path to the saved model file
253    ///
254    /// # Returns
255    /// * `Result<Self, io::Error>` - Loaded model or IO error
256    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
257        let file = File::open(path)?;
258        // Try to load as JSON format
259        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
260    }
261
262    /// Load a model from a binary file
263    ///
264    /// # Arguments
265    /// * `path` - Path to the saved model file
266    ///
267    /// # Returns
268    /// * `Result<Self, io::Error>` - Loaded model or IO error
269    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
270        let file = File::open(path)?;
271        // Try to load as bincode format
272        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
273    }
274
275    /// Save the model to a string in JSON format
276    ///
277    /// # Returns
278    /// * `Result<String, String>` - JSON string representation or error message
279    pub fn to_json(&self) -> Result<String, String> {
280        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
281    }
282
283    /// Load a model from a JSON string
284    ///
285    /// # Arguments
286    /// * `json` - JSON string containing the model data
287    ///
288    /// # Returns
289    /// * `Result<Self, String>` - Loaded model or error message
290    pub fn from_json(json: &str) -> Result<Self, String> {
291        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
292    }
293
294    // Helper function: Matrix multiplication where one matrix is transposed: A^T * B
295    fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
296        let a_rows = a.len();
297        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
298        let b_rows = b.len();
299        let b_cols = if b_rows > 0 { b[0].len() } else { 0 };
300
301        // Result will be a_cols × b_cols
302        let mut result = vec![vec![T::zero(); b_cols]; a_cols];
303
304        for (i, result_row) in result.iter_mut().enumerate().take(a_cols) {
305            for (j, result_elem) in result_row.iter_mut().enumerate().take(b_cols) {
306                let mut sum = T::zero();
307                for k in 0..a_rows {
308                    sum = sum + (a[k][i] * b[k][j]);
309                }
310                *result_elem = sum;
311            }
312        }
313
314        result
315    }
316
317    // Helper function: Multiply transposed matrix by vector: A^T * y
318    fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
319        let a_rows = a.len();
320        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
321
322        let mut result = vec![T::zero(); a_cols];
323
324        for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
325            let mut sum = T::zero();
326            for j in 0..a_rows {
327                sum = sum + (a[j][i] * y[j]);
328            }
329            *result_item = sum;
330        }
331
332        result
333    }
334
335    // Helper function: Solve a system of linear equations using Gaussian elimination
336    fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> Result<Vec<T>, String> {
337        let n = a.len();
338        if n == 0 || a[0].len() != n || b.len() != n {
339            return Err("Invalid matrix dimensions for linear system solving".to_string());
340        }
341
342        // Create augmented matrix [A|b]
343        let mut aug = Vec::with_capacity(n);
344        for i in 0..n {
345            let mut row = a[i].clone();
346            row.push(b[i]);
347            aug.push(row);
348        }
349
350        // Gaussian elimination with partial pivoting
351        for i in 0..n {
352            // Find pivot
353            let mut max_row = i;
354            let mut max_val = aug[i][i].abs();
355
356            for (j, row) in aug.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
357                let abs_val = row[i].abs();
358                if abs_val > max_val {
359                    max_row = j;
360                    max_val = abs_val;
361                }
362            }
363
364            let epsilon: T = T::from(1e-10).unwrap();
365            if max_val < epsilon {
366                return Err("Matrix is singular or near-singular".to_string());
367            }
368
369            // Swap rows if needed
370            if max_row != i {
371                aug.swap(i, max_row);
372            }
373
374            // Eliminate below
375            for j in (i + 1)..n {
376                let factor = aug[j][i] / aug[i][i];
377
378                for k in i..(n + 1) {
379                    aug[j][k] = aug[j][k] - (factor * aug[i][k]);
380                }
381            }
382        }
383
384        // Back substitution
385        let mut x = vec![T::zero(); n];
386        for i in (0..n).rev() {
387            let mut sum = aug[i][n];
388
389            for (j, &x_val) in x.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
390                sum = sum - (aug[i][j] * x_val);
391            }
392
393            x[i] = sum / aug[i][i];
394        }
395
396        Ok(x)
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::numeric::approx_equal;
    use tempfile::tempdir;

    /// Predictor rows of the shared noise-free data set `y = 1 + 2*x1 + 3*x2`.
    fn plane_x() -> Vec<Vec<f64>> {
        vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ]
    }

    /// Responses matching `plane_x`.
    fn plane_y() -> Vec<f64> {
        vec![9.0, 8.0, 16.0, 15.0]
    }

    /// An `f64` model fitted on the shared data set.
    fn fitted_plane_model() -> MultipleLinearRegression<f64> {
        let mut model = MultipleLinearRegression::<f64>::new();
        model
            .fit(&plane_x(), &plane_y())
            .expect("the reference data set must be fittable");
        model
    }

    /// Asserts that `model` recovered intercept 1, slopes 2 and 3, and a perfect fit.
    fn assert_plane_recovered(model: &MultipleLinearRegression<f64>) {
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    /// Asserts that a reloaded model carries the same parameters as the original.
    fn assert_same_model(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &expected) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, expected, Some(1e-6)));
        }
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&plane_x(), &plane_y());

        assert!(result.is_ok());
        assert_plane_recovered(&model);
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        // Same plane, but fed as unsigned predictors and signed responses.
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_plane_recovered(&model);
    }

    #[test]
    fn test_prediction() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // Test prediction: 1 + 2*5 + 3*4 = 1 + 10 + 12 = 23
        assert!(approx_equal(model.predict(&[5u32, 4u32]), 23.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction_many() {
        // Three observations for three unknowns: the system is exactly determined.
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x);
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_plane_model();

        assert!(model.save(&file_path).is_ok());

        let loaded = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_plane_model();

        assert!(model.save_binary(&file_path).is_ok());

        let loaded = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_plane_model();

        // Serialize to JSON string
        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        // Deserialize from JSON string
        let loaded = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x: Vec<Vec<f64>> = Vec::new();
        let y: Vec<f64> = Vec::new();

        let mut model = MultipleLinearRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_fit_rejects_mismatched_lengths() {
        let y_short = vec![9.0, 8.0, 16.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        assert!(model.fit(&plane_x(), &y_short).is_err());
    }

    #[test]
    fn test_fit_rejects_ragged_rows() {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];

        let mut model = MultipleLinearRegression::<f64>::new();
        assert!(model.fit(&x, &plane_y()).is_err());
    }

    #[test]
    fn test_predict_wrong_dimension_is_nan() {
        let model = fitted_plane_model();

        assert!(model.predict(&[1.0]).is_nan());
        assert!(model.predict(&[1.0, 2.0, 3.0]).is_nan());
    }
}