// rs_stats: src/regression/multiple_linear_regression.rs

use crate::error::{StatsError, StatsResult};
use num_traits::{Float, NumCast};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Write};
use std::path::Path;
11/// Multiple linear regression model that fits a hyperplane to multivariate data points.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct MultipleLinearRegression<T = f64>
14where
15    T: Float + Debug + Default + Serialize,
16{
17    /// Regression coefficients, including intercept as the first element
18    pub coefficients: Vec<T>,
19    /// Coefficient of determination (R²) - goodness of fit
20    pub r_squared: T,
21    /// Adjusted R² which accounts for the number of predictors
22    pub adjusted_r_squared: T,
23    /// Standard error of the estimate
24    pub standard_error: T,
25    /// Number of data points used for regression
26    pub n: usize,
27    /// Number of predictor variables (excluding intercept)
28    pub p: usize,
29}
31impl<T> Default for MultipleLinearRegression<T>
32where
33    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35    fn default() -> Self {
36        Self::new()
37    }
38}
40impl<T> MultipleLinearRegression<T>
41where
42    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44    /// Create a new multiple linear regression model without fitting any data
45    pub fn new() -> Self {
46        Self {
47            coefficients: Vec::new(),
48            r_squared: T::zero(),
49            adjusted_r_squared: T::zero(),
50            standard_error: T::zero(),
51            n: 0,
52            p: 0,
53        }
54    }
55
56    /// Fit a multiple linear regression model to the provided data
57    ///
58    /// # Arguments
59    /// * `x_values` - 2D array where each row is an observation and each column is a predictor
60    /// * `y_values` - Dependent variable values (observations)
61    ///
62    /// # Returns
63    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
64    ///
65    /// # Errors
66    /// Returns `StatsError::EmptyData` if input arrays are empty.
67    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
68    /// Returns `StatsError::InvalidInput` if rows in X have inconsistent lengths.
69    /// Returns `StatsError::ConversionError` if value conversion fails.
70    /// Returns `StatsError::MathematicalError` if the linear system cannot be solved.
71    pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72    where
73        U: NumCast + Copy,
74        V: NumCast + Copy,
75    {
76        // Validate inputs
77        if x_values.is_empty() || y_values.is_empty() {
78            return Err(StatsError::empty_data(
79                "Cannot fit regression with empty arrays",
80            ));
81        }
82
83        if x_values.len() != y_values.len() {
84            return Err(StatsError::dimension_mismatch(format!(
85                "Number of observations in X and Y must match (got {} and {})",
86                x_values.len(),
87                y_values.len()
88            )));
89        }
90
91        self.n = x_values.len();
92
93        // Check that all rows in x_values have the same length
94        if x_values.is_empty() {
95            return Err(StatsError::empty_data("X values array is empty"));
96        }
97
98        self.p = x_values[0].len();
99
100        for (i, row) in x_values.iter().enumerate() {
101            if row.len() != self.p {
102                return Err(StatsError::invalid_input(format!(
103                    "All rows in X must have the same number of features (row {} has {} features, expected {})",
104                    i,
105                    row.len(),
106                    self.p
107                )));
108            }
109        }
110
111        // Convert input arrays to T type
112        let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
113        for (row_idx, row) in x_values.iter().enumerate() {
114            let row_cast: StatsResult<Vec<T>> = row
115                .iter()
116                .enumerate()
117                .map(|(col_idx, &x)| {
118                    T::from(x).ok_or_else(|| {
119                        StatsError::conversion_error(format!(
120                            "Failed to cast X value at row {}, column {} to type T",
121                            row_idx, col_idx
122                        ))
123                    })
124                })
125                .collect();
126            x_cast.push(row_cast?);
127        }
128
129        let y_cast: Vec<T> = y_values
130            .iter()
131            .enumerate()
132            .map(|(i, &y)| {
133                T::from(y).ok_or_else(|| {
134                    StatsError::conversion_error(format!(
135                        "Failed to cast Y value at index {} to type T",
136                        i
137                    ))
138                })
139            })
140            .collect::<StatsResult<Vec<T>>>()?;
141
142        // Augment the X matrix with a column of 1s for the intercept
143        let mut augmented_x = Vec::with_capacity(self.n);
144        for row in &x_cast {
145            let mut augmented_row = Vec::with_capacity(self.p + 1);
146            augmented_row.push(T::one()); // Intercept term
147            augmented_row.extend_from_slice(row);
148            augmented_x.push(augmented_row);
149        }
150
151        // Compute X^T * X
152        let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
153
154        // Compute X^T * y
155        let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
156
157        // Solve the normal equations: (X^T * X) * β = X^T * y for β
158        match self.solve_linear_system(&xt_x, &xt_y) {
159            Ok(solution) => {
160                self.coefficients = solution;
161            }
162            Err(e) => return Err(e),
163        }
164
165        // Calculate fitted values and R²
166        let n_as_t = T::from(self.n).ok_or_else(|| {
167            StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
168        })?;
169        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
170
171        let mut ss_total = T::zero();
172        let mut ss_residual = T::zero();
173
174        for i in 0..self.n {
175            let predicted = self.predict_t(&x_cast[i]);
176            let residual = y_cast[i] - predicted;
177
178            ss_residual = ss_residual + (residual * residual);
179            let diff = y_cast[i] - y_mean;
180            ss_total = ss_total + (diff * diff);
181        }
182
183        // Calculate R² and adjusted R²
184        if ss_total > T::zero() {
185            self.r_squared = T::one() - (ss_residual / ss_total);
186
187            // Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)]
188            if self.n > self.p + 1 {
189                let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
190                    StatsError::conversion_error(format!(
191                        "Failed to convert {} to type T",
192                        self.n - 1
193                    ))
194                })?;
195                let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
196                    StatsError::conversion_error(format!(
197                        "Failed to convert {} to type T",
198                        self.n - self.p - 1
199                    ))
200                })?;
201
202                self.adjusted_r_squared =
203                    T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
204            }
205        }
206
207        // Calculate standard error
208        if self.n > self.p + 1 {
209            let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
210                StatsError::conversion_error(format!(
211                    "Failed to convert {} to type T",
212                    self.n - self.p - 1
213                ))
214            })?;
215            self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
216        }
217
218        Ok(())
219    }
220
221    /// Predict y value for a given set of x values using the fitted model (internal version with type T)
222    fn predict_t(&self, x: &[T]) -> T {
223        if x.len() != self.p || self.coefficients.is_empty() {
224            return T::nan();
225        }
226
227        // First coefficient is the intercept
228        let mut result = self.coefficients[0];
229
230        // Add the weighted features
231        for (i, &xi) in x.iter().enumerate().take(self.p) {
232            result = result + (self.coefficients[i + 1] * xi);
233        }
234
235        result
236    }
237
238    /// Predict y value for a given set of x values using the fitted model
239    ///
240    /// # Arguments
241    /// * `x` - Vector of x values for prediction (must match the number of features used during fitting)
242    ///
243    /// # Returns
244    /// * `StatsResult<T>` - The predicted y value
245    ///
246    /// # Errors
247    /// Returns `StatsError::NotFitted` if the model has not been fitted (coefficients is empty).
248    /// Returns `StatsError::DimensionMismatch` if the number of features doesn't match the model (x.len() != p).
249    /// Returns `StatsError::ConversionError` if type conversion fails.
250    ///
251    /// # Examples
252    /// ```
253    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
254    ///
255    /// let mut model = MultipleLinearRegression::<f64>::new();
256    /// let x = vec![
257    ///     vec![1.0, 2.0],
258    ///     vec![2.0, 1.0],
259    ///     vec![3.0, 3.0],
260    ///     vec![4.0, 2.0],
261    /// ];
262    /// let y = vec![5.0, 4.0, 9.0, 8.0];
263    /// model.fit(&x, &y).unwrap();
264    ///
265    /// let prediction = model.predict(&[3.0, 4.0]).unwrap();
266    /// ```
267    pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
268    where
269        U: NumCast + Copy,
270    {
271        if self.coefficients.is_empty() {
272            return Err(StatsError::not_fitted(
273                "Model has not been fitted. Call fit() before predicting.",
274            ));
275        }
276
277        if x.len() != self.p {
278            return Err(StatsError::dimension_mismatch(format!(
279                "Expected {} features, but got {}",
280                self.p,
281                x.len()
282            )));
283        }
284
285        // Convert input to T type
286        let x_cast: StatsResult<Vec<T>> = x
287            .iter()
288            .enumerate()
289            .map(|(i, &val)| {
290                T::from(val).ok_or_else(|| {
291                    StatsError::conversion_error(format!(
292                        "Failed to convert feature value at index {} to type T",
293                        i
294                    ))
295                })
296            })
297            .collect();
298
299        Ok(self.predict_t(&x_cast?))
300    }
301
302    /// Calculate predictions for multiple observations
303    ///
304    /// # Arguments
305    /// * `x_values` - 2D array of feature values for prediction
306    ///
307    /// # Returns
308    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
309    ///
310    /// # Errors
311    /// Returns `StatsError::NotFitted` if the model has not been fitted.
312    /// Returns an error if any prediction fails (dimension mismatch or conversion error).
313    ///
314    /// # Examples
315    /// ```
316    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
317    ///
318    /// let mut model = MultipleLinearRegression::<f64>::new();
319    /// let x = vec![
320    ///     vec![1.0, 2.0],
321    ///     vec![2.0, 1.0],
322    ///     vec![3.0, 3.0],
323    ///     vec![4.0, 2.0],
324    /// ];
325    /// let y = vec![5.0, 4.0, 9.0, 8.0];
326    /// model.fit(&x, &y).unwrap();
327    ///
328    /// let predictions = model.predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]]).unwrap();
329    /// assert_eq!(predictions.len(), 2);
330    /// ```
331    pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
332    where
333        U: NumCast + Copy,
334    {
335        x_values.iter().map(|x| self.predict(x)).collect()
336    }
337
338    /// Save the model to a file
339    ///
340    /// # Arguments
341    /// * `path` - Path where to save the model
342    ///
343    /// # Returns
344    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
345    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346        let file = File::create(path)?;
347        // Use JSON format for human-readability
348        serde_json::to_writer(file, self).map_err(io::Error::other)
349    }
350
351    /// Save the model in binary format
352    ///
353    /// # Arguments
354    /// * `path` - Path where to save the model
355    ///
356    /// # Returns
357    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
358    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359        let file = File::create(path)?;
360        // Use bincode for more compact binary format
361        bincode::serialize_into(file, self).map_err(io::Error::other)
362    }
363
364    /// Load a model from a file
365    ///
366    /// # Arguments
367    /// * `path` - Path to the saved model file
368    ///
369    /// # Returns
370    /// * `Result<Self, io::Error>` - Loaded model or IO error
371    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372        let file = File::open(path)?;
373        // Try to load as JSON format
374        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375    }
376
377    /// Load a model from a binary file
378    ///
379    /// # Arguments
380    /// * `path` - Path to the saved model file
381    ///
382    /// # Returns
383    /// * `Result<Self, io::Error>` - Loaded model or IO error
384    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385        let file = File::open(path)?;
386        // Try to load as bincode format
387        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388    }
389
390    /// Save the model to a string in JSON format
391    ///
392    /// # Returns
393    /// * `Result<String, String>` - JSON string representation or error message
394    pub fn to_json(&self) -> Result<String, String> {
395        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396    }
397
398    /// Load a model from a JSON string
399    ///
400    /// # Arguments
401    /// * `json` - JSON string containing the model data
402    ///
403    /// # Returns
404    /// * `Result<Self, String>` - Loaded model or error message
405    pub fn from_json(json: &str) -> Result<Self, String> {
406        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407    }
408
409    // Helper function: Matrix multiplication where one matrix is transposed: A^T * B
410    // Loop order: k (rows of A) in outer loop for cache locality —
411    // both a[k] and b[k] are contiguous row accesses, avoiding column-stride misses.
412    fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
413        let a_rows = a.len();
414        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
415        let b_cols = if !b.is_empty() { b[0].len() } else { 0 };
416
417        // Result will be a_cols × b_cols
418        let mut result = vec![vec![T::zero(); b_cols]; a_cols];
419
420        // Cache-friendly loop: iterate over shared dimension (k) in the outer loop
421        for k in 0..a_rows {
422            let a_row = &a[k];
423            let b_row = &b[k];
424            for i in 0..a_cols {
425                let a_ki = a_row[i];
426                let result_row = &mut result[i];
427                for j in 0..b_cols {
428                    result_row[j] = result_row[j] + (a_ki * b_row[j]);
429                }
430            }
431        }
432
433        result
434    }
435
436    // Helper function: Multiply transposed matrix by vector: A^T * y
437    fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
438        let a_rows = a.len();
439        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
440
441        let mut result = vec![T::zero(); a_cols];
442
443        for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
444            let mut sum = T::zero();
445            for j in 0..a_rows {
446                sum = sum + (a[j][i] * y[j]);
447            }
448            *result_item = sum;
449        }
450
451        result
452    }
453
454    // Helper function: Solve a system of linear equations using Gaussian elimination
455    fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> StatsResult<Vec<T>> {
456        let n = a.len();
457        if n == 0 || a[0].len() != n || b.len() != n {
458            return Err(StatsError::dimension_mismatch(format!(
459                "Invalid matrix dimensions for linear system solving: A is {}x{}, b has {} elements",
460                n,
461                if n > 0 { a[0].len() } else { 0 },
462                b.len()
463            )));
464        }
465
466        // Create augmented matrix [A|b] — allocate once with correct capacity
467        let mut aug: Vec<Vec<T>> = Vec::with_capacity(n);
468        for i in 0..n {
469            let mut row = Vec::with_capacity(n + 1);
470            row.extend_from_slice(&a[i]);
471            row.push(b[i]);
472            aug.push(row);
473        }
474
475        // Gaussian elimination with partial pivoting
476        for i in 0..n {
477            // Find pivot — direct index range, no skip/take overhead
478            let mut max_row = i;
479            let mut max_val = aug[i][i].abs();
480
481            #[allow(clippy::needless_range_loop)]
482            for j in (i + 1)..n {
483                let abs_val = aug[j][i].abs();
484                if abs_val > max_val {
485                    max_row = j;
486                    max_val = abs_val;
487                }
488            }
489
490            let epsilon: T = T::from(1e-10).ok_or_else(|| {
491                StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
492            })?;
493            if max_val < epsilon {
494                return Err(StatsError::mathematical_error(
495                    "Matrix is singular or near-singular, cannot solve linear system",
496                ));
497            }
498
499            // Swap rows if needed
500            if max_row != i {
501                aug.swap(i, max_row);
502            }
503
504            // Eliminate below
505            for j in (i + 1)..n {
506                let factor = aug[j][i] / aug[i][i];
507
508                for k in i..(n + 1) {
509                    aug[j][k] = aug[j][k] - (factor * aug[i][k]);
510                }
511            }
512        }
513
514        // Back substitution — direct range indexing
515        let mut x = vec![T::zero(); n];
516        for i in (0..n).rev() {
517            let mut sum = aug[i][n];
518
519            #[allow(clippy::needless_range_loop)]
520            for j in (i + 1)..n {
521                sum = sum - (aug[i][j] * x[j]);
522            }
523
524            x[i] = sum / aug[i][i];
525        }
526
527        Ok(x)
528    }
529}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four exact observations of the plane y = 1 + 2*x1 + 3*x2.
    fn plane_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// A model fitted to `plane_data()`.
    fn fitted_plane_model() -> MultipleLinearRegression<f64> {
        let (x, y) = plane_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();
        model
    }

    /// Assert that a reloaded model carries the same parameters as the original.
    fn assert_same_model(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&l, &o) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(l, o, Some(1e-6)));
        }
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let (x, y) = plane_data();

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        // Integer inputs are cast to f64: y = 2*x1 + 3*x2 + 1
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_goodness_of_fit_statistics() {
        // Noisy single-predictor data. OLS gives y = 0.5 + 0.8*x, residuals
        // (-0.3, 0.9, -0.9, 0.3), so SS_res = 1.8 and SS_tot = 5.0 with n = 4, p = 1.
        let x = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let y = vec![1.0, 3.0, 2.0, 4.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.coefficients[0], 0.5, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 0.8, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 0.64, Some(1e-6))); // 1 - 1.8 / 5
        assert!(approx_equal(model.adjusted_r_squared, 0.46, Some(1e-6))); // 1 - 0.36 * 3 / 2
        assert!(approx_equal(
            model.standard_error,
            0.9_f64.sqrt(), // sqrt(1.8 / 2)
            Some(1e-6)
        ));
    }

    #[test]
    fn test_intercept_only_model() {
        // With zero predictors the fitted intercept is the mean of y.
        let x: Vec<Vec<f64>> = vec![vec![]; 3];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert_eq!(model.p, 0);
        assert_eq!(model.coefficients.len(), 1);
        assert!(approx_equal(model.coefficients[0], 2.0, Some(1e-6)));
        let no_features: [f64; 0] = [];
        assert!(approx_equal(
            model.predict(&no_features).unwrap(),
            2.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // Test prediction: 1 + 2*5 + 3*4 = 1 + 10 + 12 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_plane_model();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_plane_model();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_plane_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());
        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        // predict() must refuse to run before the model has been fitted.
        let model = MultipleLinearRegression::<f64>::new();

        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        // Four observations of two independent predictors keep XᵀX non-singular.
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let wrong_features = vec![1.0]; // Should be 2 features
        let result = model.predict(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_singular_matrix() {
        // Every row is a multiple of [1, 2, 3], so the three features are perfectly
        // collinear. With only 3 observations for 4 coefficients (intercept + 3 slopes),
        // XᵀX is singular.
        let x = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        // Floating-point rounding decides whether the singularity is detected; either way
        // the call must not panic, and any error must be a mathematical one.
        match result {
            Ok(_) => {
                assert!(!model.coefficients.is_empty());
            }
            Err(e) => {
                assert!(matches!(e, StatsError::MathematicalError { .. }));
            }
        }
    }

    #[test]
    fn test_fit_empty_input() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(StatsError::EmptyData { .. })));
    }

    #[test]
    fn test_fit_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![1.0, 2.0, 3.0];
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(StatsError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_fit_ragged_rows() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![3.0]];
        let y = vec![1.0, 2.0];
        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_t_coefficients_empty() {
        // predict_t is private; exercise its empty-coefficients path through predict.
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        model.fit(&x, &y).unwrap();

        let wrong_features = vec![vec![1.0]]; // Should be 2 features
        let result = model.predict_many(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
    }
}