rs_stats/regression/
multiple_linear_regression.rs

// src/regression/multiple_linear_regression.rs

use crate::error::{StatsError, StatsResult};
use num_traits::{Float, NumCast};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Write};
use std::path::Path;

11/// Multiple linear regression model that fits a hyperplane to multivariate data points.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct MultipleLinearRegression<T = f64>
14where
15    T: Float + Debug + Default + Serialize,
16{
17    /// Regression coefficients, including intercept as the first element
18    pub coefficients: Vec<T>,
19    /// Coefficient of determination (R²) - goodness of fit
20    pub r_squared: T,
21    /// Adjusted R² which accounts for the number of predictors
22    pub adjusted_r_squared: T,
23    /// Standard error of the estimate
24    pub standard_error: T,
25    /// Number of data points used for regression
26    pub n: usize,
27    /// Number of predictor variables (excluding intercept)
28    pub p: usize,
29}
30
31impl<T> Default for MultipleLinearRegression<T>
32where
33    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl<T> MultipleLinearRegression<T>
41where
42    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44    /// Create a new multiple linear regression model without fitting any data
45    pub fn new() -> Self {
46        Self {
47            coefficients: Vec::new(),
48            r_squared: T::zero(),
49            adjusted_r_squared: T::zero(),
50            standard_error: T::zero(),
51            n: 0,
52            p: 0,
53        }
54    }
55
56    /// Fit a multiple linear regression model to the provided data
57    ///
58    /// # Arguments
59    /// * `x_values` - 2D array where each row is an observation and each column is a predictor
60    /// * `y_values` - Dependent variable values (observations)
61    ///
62    /// # Returns
63    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
64    ///
65    /// # Errors
66    /// Returns `StatsError::EmptyData` if input arrays are empty.
67    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
68    /// Returns `StatsError::InvalidInput` if rows in X have inconsistent lengths.
69    /// Returns `StatsError::ConversionError` if value conversion fails.
70    /// Returns `StatsError::MathematicalError` if the linear system cannot be solved.
71    pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72    where
73        U: NumCast + Copy,
74        V: NumCast + Copy,
75    {
76        // Validate inputs
77        if x_values.is_empty() || y_values.is_empty() {
78            return Err(StatsError::empty_data(
79                "Cannot fit regression with empty arrays",
80            ));
81        }
82
83        if x_values.len() != y_values.len() {
84            return Err(StatsError::dimension_mismatch(format!(
85                "Number of observations in X and Y must match (got {} and {})",
86                x_values.len(),
87                y_values.len()
88            )));
89        }
90
91        self.n = x_values.len();
92
93        // Check that all rows in x_values have the same length
94        if x_values.is_empty() {
95            return Err(StatsError::empty_data("X values array is empty"));
96        }
97
98        self.p = x_values[0].len();
99
100        for (i, row) in x_values.iter().enumerate() {
101            if row.len() != self.p {
102                return Err(StatsError::invalid_input(format!(
103                    "All rows in X must have the same number of features (row {} has {} features, expected {})",
104                    i,
105                    row.len(),
106                    self.p
107                )));
108            }
109        }
110
111        // Convert input arrays to T type
112        let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
113        for (row_idx, row) in x_values.iter().enumerate() {
114            let row_cast: StatsResult<Vec<T>> = row
115                .iter()
116                .enumerate()
117                .map(|(col_idx, &x)| {
118                    T::from(x).ok_or_else(|| {
119                        StatsError::conversion_error(format!(
120                            "Failed to cast X value at row {}, column {} to type T",
121                            row_idx, col_idx
122                        ))
123                    })
124                })
125                .collect();
126            x_cast.push(row_cast?);
127        }
128
129        let y_cast: Vec<T> = y_values
130            .iter()
131            .enumerate()
132            .map(|(i, &y)| {
133                T::from(y).ok_or_else(|| {
134                    StatsError::conversion_error(format!(
135                        "Failed to cast Y value at index {} to type T",
136                        i
137                    ))
138                })
139            })
140            .collect::<StatsResult<Vec<T>>>()?;
141
142        // Augment the X matrix with a column of 1s for the intercept
143        let mut augmented_x = Vec::with_capacity(self.n);
144        for row in &x_cast {
145            let mut augmented_row = Vec::with_capacity(self.p + 1);
146            augmented_row.push(T::one()); // Intercept term
147            augmented_row.extend_from_slice(row);
148            augmented_x.push(augmented_row);
149        }
150
151        // Compute X^T * X
152        let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
153
154        // Compute X^T * y
155        let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
156
157        // Solve the normal equations: (X^T * X) * β = X^T * y for β
158        match self.solve_linear_system(&xt_x, &xt_y) {
159            Ok(solution) => {
160                self.coefficients = solution;
161            }
162            Err(e) => return Err(e),
163        }
164
165        // Calculate fitted values and R²
166        let n_as_t = T::from(self.n).ok_or_else(|| {
167            StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
168        })?;
169        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
170
171        let mut ss_total = T::zero();
172        let mut ss_residual = T::zero();
173
174        for i in 0..self.n {
175            let predicted = self.predict_t(&x_cast[i]);
176            let residual = y_cast[i] - predicted;
177
178            ss_residual = ss_residual + (residual * residual);
179            let diff = y_cast[i] - y_mean;
180            ss_total = ss_total + (diff * diff);
181        }
182
183        // Calculate R² and adjusted R²
184        if ss_total > T::zero() {
185            self.r_squared = T::one() - (ss_residual / ss_total);
186
187            // Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)]
188            if self.n > self.p + 1 {
189                let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
190                    StatsError::conversion_error(format!(
191                        "Failed to convert {} to type T",
192                        self.n - 1
193                    ))
194                })?;
195                let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
196                    StatsError::conversion_error(format!(
197                        "Failed to convert {} to type T",
198                        self.n - self.p - 1
199                    ))
200                })?;
201
202                self.adjusted_r_squared =
203                    T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
204            }
205        }
206
207        // Calculate standard error
208        if self.n > self.p + 1 {
209            let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
210                StatsError::conversion_error(format!(
211                    "Failed to convert {} to type T",
212                    self.n - self.p - 1
213                ))
214            })?;
215            self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
216        }
217
218        Ok(())
219    }
220
221    /// Predict y value for a given set of x values using the fitted model (internal version with type T)
222    fn predict_t(&self, x: &[T]) -> T {
223        if x.len() != self.p || self.coefficients.is_empty() {
224            return T::nan();
225        }
226
227        // First coefficient is the intercept
228        let mut result = self.coefficients[0];
229
230        // Add the weighted features
231        for (i, &xi) in x.iter().enumerate().take(self.p) {
232            result = result + (self.coefficients[i + 1] * xi);
233        }
234
235        result
236    }
237
238    /// Predict y value for a given set of x values using the fitted model
239    ///
240    /// # Arguments
241    /// * `x` - Vector of x values for prediction (must match the number of features used during fitting)
242    ///
243    /// # Returns
244    /// * `StatsResult<T>` - The predicted y value
245    ///
246    /// # Errors
247    /// Returns `StatsError::NotFitted` if the model has not been fitted (coefficients is empty).
248    /// Returns `StatsError::DimensionMismatch` if the number of features doesn't match the model (x.len() != p).
249    /// Returns `StatsError::ConversionError` if type conversion fails.
250    ///
251    /// # Examples
252    /// ```
253    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
254    ///
255    /// let mut model = MultipleLinearRegression::<f64>::new();
256    /// let x = vec![
257    ///     vec![1.0, 2.0],
258    ///     vec![2.0, 1.0],
259    ///     vec![3.0, 3.0],
260    ///     vec![4.0, 2.0],
261    /// ];
262    /// let y = vec![5.0, 4.0, 9.0, 8.0];
263    /// model.fit(&x, &y).unwrap();
264    ///
265    /// let prediction = model.predict(&[3.0, 4.0]).unwrap();
266    /// ```
267    pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
268    where
269        U: NumCast + Copy,
270    {
271        if self.coefficients.is_empty() {
272            return Err(StatsError::not_fitted(
273                "Model has not been fitted. Call fit() before predicting.",
274            ));
275        }
276
277        if x.len() != self.p {
278            return Err(StatsError::dimension_mismatch(format!(
279                "Expected {} features, but got {}",
280                self.p,
281                x.len()
282            )));
283        }
284
285        // Convert input to T type
286        let x_cast: StatsResult<Vec<T>> = x
287            .iter()
288            .enumerate()
289            .map(|(i, &val)| {
290                T::from(val).ok_or_else(|| {
291                    StatsError::conversion_error(format!(
292                        "Failed to convert feature value at index {} to type T",
293                        i
294                    ))
295                })
296            })
297            .collect();
298
299        Ok(self.predict_t(&x_cast?))
300    }
301
302    /// Calculate predictions for multiple observations
303    ///
304    /// # Arguments
305    /// * `x_values` - 2D array of feature values for prediction
306    ///
307    /// # Returns
308    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
309    ///
310    /// # Errors
311    /// Returns `StatsError::NotFitted` if the model has not been fitted.
312    /// Returns an error if any prediction fails (dimension mismatch or conversion error).
313    ///
314    /// # Examples
315    /// ```
316    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
317    ///
318    /// let mut model = MultipleLinearRegression::<f64>::new();
319    /// let x = vec![
320    ///     vec![1.0, 2.0],
321    ///     vec![2.0, 1.0],
322    ///     vec![3.0, 3.0],
323    ///     vec![4.0, 2.0],
324    /// ];
325    /// let y = vec![5.0, 4.0, 9.0, 8.0];
326    /// model.fit(&x, &y).unwrap();
327    ///
328    /// let predictions = model.predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]]).unwrap();
329    /// assert_eq!(predictions.len(), 2);
330    /// ```
331    pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
332    where
333        U: NumCast + Copy,
334    {
335        x_values.iter().map(|x| self.predict(x)).collect()
336    }
337
338    /// Save the model to a file
339    ///
340    /// # Arguments
341    /// * `path` - Path where to save the model
342    ///
343    /// # Returns
344    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
345    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346        let file = File::create(path)?;
347        // Use JSON format for human-readability
348        serde_json::to_writer(file, self).map_err(io::Error::other)
349    }
350
351    /// Save the model in binary format
352    ///
353    /// # Arguments
354    /// * `path` - Path where to save the model
355    ///
356    /// # Returns
357    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
358    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359        let file = File::create(path)?;
360        // Use bincode for more compact binary format
361        bincode::serialize_into(file, self).map_err(io::Error::other)
362    }
363
364    /// Load a model from a file
365    ///
366    /// # Arguments
367    /// * `path` - Path to the saved model file
368    ///
369    /// # Returns
370    /// * `Result<Self, io::Error>` - Loaded model or IO error
371    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372        let file = File::open(path)?;
373        // Try to load as JSON format
374        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375    }
376
377    /// Load a model from a binary file
378    ///
379    /// # Arguments
380    /// * `path` - Path to the saved model file
381    ///
382    /// # Returns
383    /// * `Result<Self, io::Error>` - Loaded model or IO error
384    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385        let file = File::open(path)?;
386        // Try to load as bincode format
387        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388    }
389
390    /// Save the model to a string in JSON format
391    ///
392    /// # Returns
393    /// * `Result<String, String>` - JSON string representation or error message
394    pub fn to_json(&self) -> Result<String, String> {
395        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396    }
397
398    /// Load a model from a JSON string
399    ///
400    /// # Arguments
401    /// * `json` - JSON string containing the model data
402    ///
403    /// # Returns
404    /// * `Result<Self, String>` - Loaded model or error message
405    pub fn from_json(json: &str) -> Result<Self, String> {
406        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407    }
408
409    // Helper function: Matrix multiplication where one matrix is transposed: A^T * B
410    fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
411        let a_rows = a.len();
412        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
413        let b_rows = b.len();
414        let b_cols = if b_rows > 0 { b[0].len() } else { 0 };
415
416        // Result will be a_cols × b_cols
417        let mut result = vec![vec![T::zero(); b_cols]; a_cols];
418
419        for (i, result_row) in result.iter_mut().enumerate().take(a_cols) {
420            for (j, result_elem) in result_row.iter_mut().enumerate().take(b_cols) {
421                let mut sum = T::zero();
422                for k in 0..a_rows {
423                    sum = sum + (a[k][i] * b[k][j]);
424                }
425                *result_elem = sum;
426            }
427        }
428
429        result
430    }
431
432    // Helper function: Multiply transposed matrix by vector: A^T * y
433    fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
434        let a_rows = a.len();
435        let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
436
437        let mut result = vec![T::zero(); a_cols];
438
439        for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
440            let mut sum = T::zero();
441            for j in 0..a_rows {
442                sum = sum + (a[j][i] * y[j]);
443            }
444            *result_item = sum;
445        }
446
447        result
448    }
449
450    // Helper function: Solve a system of linear equations using Gaussian elimination
451    fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> StatsResult<Vec<T>> {
452        let n = a.len();
453        if n == 0 || a[0].len() != n || b.len() != n {
454            return Err(StatsError::dimension_mismatch(format!(
455                "Invalid matrix dimensions for linear system solving: A is {}x{}, b has {} elements",
456                n,
457                if n > 0 { a[0].len() } else { 0 },
458                b.len()
459            )));
460        }
461
462        // Create augmented matrix [A|b]
463        let mut aug = Vec::with_capacity(n);
464        for i in 0..n {
465            let mut row = a[i].clone();
466            row.push(b[i]);
467            aug.push(row);
468        }
469
470        // Gaussian elimination with partial pivoting
471        for i in 0..n {
472            // Find pivot
473            let mut max_row = i;
474            let mut max_val = aug[i][i].abs();
475
476            for (j, row) in aug.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
477                let abs_val = row[i].abs();
478                if abs_val > max_val {
479                    max_row = j;
480                    max_val = abs_val;
481                }
482            }
483
484            let epsilon: T = T::from(1e-10).ok_or_else(|| {
485                StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
486            })?;
487            if max_val < epsilon {
488                return Err(StatsError::mathematical_error(
489                    "Matrix is singular or near-singular, cannot solve linear system",
490                ));
491            }
492
493            // Swap rows if needed
494            if max_row != i {
495                aug.swap(i, max_row);
496            }
497
498            // Eliminate below
499            for j in (i + 1)..n {
500                let factor = aug[j][i] / aug[i][i];
501
502                for k in i..(n + 1) {
503                    aug[j][k] = aug[j][k] - (factor * aug[i][k]);
504                }
505            }
506        }
507
508        // Back substitution
509        let mut x = vec![T::zero(); n];
510        for i in (0..n).rev() {
511            let mut sum = aug[i][n];
512
513            for (j, &x_val) in x.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
514                sum = sum - (aug[i][j] * x_val);
515            }
516
517            x[i] = sum / aug[i][i];
518        }
519
520        Ok(x)
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four observations lying exactly on the plane y = 1 + 2*x1 + 3*x2.
    fn plane_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// A model fitted to `plane_data()`.
    fn fitted_plane_model() -> MultipleLinearRegression<f64> {
        let (x, y) = plane_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();
        model
    }

    /// Assert that a reloaded model carries the same fitted parameters as the original.
    fn assert_same_fit(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &want) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, want, Some(1e-6)));
        }
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let (x, y) = plane_data();

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
        assert!(approx_equal(model.adjusted_r_squared, 1.0, Some(1e-6)));
        // An exact fit leaves (numerically) no residual spread.
        assert!(model.standard_error.abs() < 1e-6);
        assert_eq!(model.n, 4);
        assert_eq!(model.p, 2);
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        // Integer inputs are converted to f64: y = 2*x1 + 3*x2 + 1
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        // Simple case: y = 2*x1 + 3*x2 + 1
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // Test prediction: 1 + 2*5 + 3*4 = 1 + 10 + 12 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        // Three points, three parameters: the plane is determined exactly.
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_plane_model();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_plane_model();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_plane_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        // predict() on a model that was never fitted must return NotFitted
        // (this also covers predict_t's empty-coefficient guard).
        let model = MultipleLinearRegression::<f64>::new();

        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        // Four points (y = x1 + x2) keep XᵀX non-singular.
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        // One feature where the model expects two.
        let wrong_features = vec![1.0];
        let result = model.predict(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_singular_matrix() {
        // Every row satisfies x2 = 2*x1 and x3 = x1 + x2, and there are only 3
        // observations for 4 parameters, so XᵀX is rank-deficient.
        let x = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        // Rounding may or may not push a pivot above the threshold; either outcome is
        // acceptable as long as fitting does not panic.
        match result {
            Ok(_) => {
                assert!(!model.coefficients.is_empty());
            }
            Err(e) => {
                assert!(matches!(e, StatsError::MathematicalError { .. }));
            }
        }
    }

    #[test]
    fn test_fit_empty_input() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];
        let result = model.fit(&x, &y);
        assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
    }

    #[test]
    fn test_fit_length_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![1.0, 2.0, 3.0];
        let result = model.fit(&x, &y);
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_ragged_rows() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![3.0]];
        let y = vec![1.0, 2.0];
        let result = model.fit(&x, &y);
        assert!(matches!(result.unwrap_err(), StatsError::InvalidInput { .. }));
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        model.fit(&x, &y).unwrap();

        // One feature where the model expects two.
        let wrong_features = vec![vec![1.0]];
        let result = model.predict_many(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
    }
}