Skip to main content

rs_stats/regression/
multiple_linear_regression.rs

1// src/regression/multiple_linear_regression.rs
2
3use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
11/// Multiple linear regression model that fits a hyperplane to multivariate data points.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct MultipleLinearRegression<T = f64>
14where
15    T: Float + Debug + Default + Serialize,
16{
17    /// Regression coefficients, including intercept as the first element
18    pub coefficients: Vec<T>,
19    /// Coefficient of determination (R²) - goodness of fit
20    pub r_squared: T,
21    /// Adjusted R² which accounts for the number of predictors
22    pub adjusted_r_squared: T,
23    /// Standard error of the estimate
24    pub standard_error: T,
25    /// Number of data points used for regression
26    pub n: usize,
27    /// Number of predictor variables (excluding intercept)
28    pub p: usize,
29}
30
31impl<T> Default for MultipleLinearRegression<T>
32where
33    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl<T> MultipleLinearRegression<T>
41where
42    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44    /// Create a new multiple linear regression model without fitting any data
45    pub fn new() -> Self {
46        Self {
47            coefficients: Vec::new(),
48            r_squared: T::zero(),
49            adjusted_r_squared: T::zero(),
50            standard_error: T::zero(),
51            n: 0,
52            p: 0,
53        }
54    }
55
56    /// Fit a multiple linear regression model to the provided data
57    ///
58    /// # Arguments
59    /// * `x_values` - 2D array where each row is an observation and each column is a predictor
60    /// * `y_values` - Dependent variable values (observations)
61    ///
62    /// # Returns
63    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
64    ///
65    /// # Errors
66    /// Returns `StatsError::EmptyData` if input arrays are empty.
67    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
68    /// Returns `StatsError::InvalidInput` if rows in X have inconsistent lengths.
69    /// Returns `StatsError::ConversionError` if value conversion fails.
70    /// Returns `StatsError::MathematicalError` if the linear system cannot be solved.
71    pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72    where
73        U: NumCast + Copy,
74        V: NumCast + Copy,
75    {
76        // Validate inputs
77        if x_values.is_empty() || y_values.is_empty() {
78            return Err(StatsError::empty_data(
79                "Cannot fit regression with empty arrays",
80            ));
81        }
82
83        if x_values.len() != y_values.len() {
84            return Err(StatsError::dimension_mismatch(format!(
85                "Number of observations in X and Y must match (got {} and {})",
86                x_values.len(),
87                y_values.len()
88            )));
89        }
90
91        self.n = x_values.len();
92
93        // Check that all rows in x_values have the same length
94        if x_values.is_empty() {
95            return Err(StatsError::empty_data("X values array is empty"));
96        }
97
98        self.p = x_values[0].len();
99
100        for (i, row) in x_values.iter().enumerate() {
101            if row.len() != self.p {
102                return Err(StatsError::invalid_input(format!(
103                    "All rows in X must have the same number of features (row {} has {} features, expected {})",
104                    i,
105                    row.len(),
106                    self.p
107                )));
108            }
109        }
110
111        // Build the augmented design matrix `X' = [1 | X]` directly in
112        // row-major flat storage — n_samples × (p+1) — without going
113        // through an intermediate `Vec<Vec<T>>`. Saves n + 1 allocations
114        // and gives matrix_multiply_transpose contiguous row strides.
115        let m = self.p + 1;
116        let mut augmented_x: Vec<T> = Vec::with_capacity(self.n * m);
117        for (row_idx, row) in x_values.iter().enumerate() {
118            augmented_x.push(T::one()); // intercept column
119            for (col_idx, &x) in row.iter().enumerate() {
120                let cast = T::from(x).ok_or_else(|| {
121                    StatsError::conversion_error(format!(
122                        "Failed to cast X value at row {row_idx}, column {col_idx} to type T"
123                    ))
124                })?;
125                augmented_x.push(cast);
126            }
127        }
128
129        let y_cast: Vec<T> = y_values
130            .iter()
131            .enumerate()
132            .map(|(i, &y)| {
133                T::from(y).ok_or_else(|| {
134                    StatsError::conversion_error(format!(
135                        "Failed to cast Y value at index {i} to type T"
136                    ))
137                })
138            })
139            .collect::<StatsResult<Vec<T>>>()?;
140
141        // X^T · X — m × m matrix in flat row-major storage.
142        let xt_x = matrix_multiply_transpose_flat::<T>(&augmented_x, self.n, m);
143
144        // X^T · y — length-m vector.
145        let xt_y = vector_multiply_transpose_flat::<T>(&augmented_x, &y_cast, self.n, m);
146
147        // Solve the normal equations (X^T·X) · β = X^T·y.
148        self.coefficients = solve_linear_system_flat::<T>(&xt_x, &xt_y, m)?;
149
150        // Compute R² and standard error in a single pass over the rows.
151        let n_as_t = T::from(self.n).ok_or_else(|| {
152            StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
153        })?;
154        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
155
156        let mut ss_total = T::zero();
157        let mut ss_residual = T::zero();
158
159        for i in 0..self.n {
160            // The features (excluding the intercept column) live at
161            // augmented_x[i*m + 1 .. i*m + m]; predict_t expects exactly
162            // those p values.
163            let row_start = i * m + 1;
164            let row_end = i * m + m;
165            let predicted = self.predict_t(&augmented_x[row_start..row_end]);
166            let residual = y_cast[i] - predicted;
167
168            ss_residual = ss_residual + (residual * residual);
169            let diff = y_cast[i] - y_mean;
170            ss_total = ss_total + (diff * diff);
171        }
172
173        // Calculate R² and adjusted R²
174        if ss_total > T::zero() {
175            self.r_squared = T::one() - (ss_residual / ss_total);
176
177            // Adjusted R² = 1 - [(1 - R²) * (n - 1) / (n - p - 1)]
178            if self.n > self.p + 1 {
179                let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
180                    StatsError::conversion_error(format!(
181                        "Failed to convert {} to type T",
182                        self.n - 1
183                    ))
184                })?;
185                let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
186                    StatsError::conversion_error(format!(
187                        "Failed to convert {} to type T",
188                        self.n - self.p - 1
189                    ))
190                })?;
191
192                self.adjusted_r_squared =
193                    T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
194            }
195        }
196
197        // Calculate standard error
198        if self.n > self.p + 1 {
199            let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
200                StatsError::conversion_error(format!(
201                    "Failed to convert {} to type T",
202                    self.n - self.p - 1
203                ))
204            })?;
205            self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
206        }
207
208        Ok(())
209    }
210
211    /// Predict y value for a given set of x values using the fitted model (internal version with type T)
212    fn predict_t(&self, x: &[T]) -> T {
213        if x.len() != self.p || self.coefficients.is_empty() {
214            return T::nan();
215        }
216
217        // First coefficient is the intercept
218        let mut result = self.coefficients[0];
219
220        // Add the weighted features
221        for (i, &xi) in x.iter().enumerate().take(self.p) {
222            result = result + (self.coefficients[i + 1] * xi);
223        }
224
225        result
226    }
227
228    /// Predict y value for a given set of x values using the fitted model
229    ///
230    /// # Arguments
231    /// * `x` - Vector of x values for prediction (must match the number of features used during fitting)
232    ///
233    /// # Returns
234    /// * `StatsResult<T>` - The predicted y value
235    ///
236    /// # Errors
237    /// Returns `StatsError::NotFitted` if the model has not been fitted (coefficients is empty).
238    /// Returns `StatsError::DimensionMismatch` if the number of features doesn't match the model (x.len() != p).
239    /// Returns `StatsError::ConversionError` if type conversion fails.
240    ///
241    /// # Examples
242    /// ```
243    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
244    ///
245    /// let mut model = MultipleLinearRegression::<f64>::new();
246    /// let x = vec![
247    ///     vec![1.0, 2.0],
248    ///     vec![2.0, 1.0],
249    ///     vec![3.0, 3.0],
250    ///     vec![4.0, 2.0],
251    /// ];
252    /// let y = vec![5.0, 4.0, 9.0, 8.0];
253    /// model.fit(&x, &y).unwrap();
254    ///
255    /// let prediction = model.predict(&[3.0, 4.0]).unwrap();
256    /// ```
257    pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
258    where
259        U: NumCast + Copy,
260    {
261        if self.coefficients.is_empty() {
262            return Err(StatsError::not_fitted(
263                "Model has not been fitted. Call fit() before predicting.",
264            ));
265        }
266
267        if x.len() != self.p {
268            return Err(StatsError::dimension_mismatch(format!(
269                "Expected {} features, but got {}",
270                self.p,
271                x.len()
272            )));
273        }
274
275        // Convert input to T type
276        let x_cast: StatsResult<Vec<T>> = x
277            .iter()
278            .enumerate()
279            .map(|(i, &val)| {
280                T::from(val).ok_or_else(|| {
281                    StatsError::conversion_error(format!(
282                        "Failed to convert feature value at index {} to type T",
283                        i
284                    ))
285                })
286            })
287            .collect();
288
289        Ok(self.predict_t(&x_cast?))
290    }
291
292    /// Calculate predictions for multiple observations
293    ///
294    /// # Arguments
295    /// * `x_values` - 2D array of feature values for prediction
296    ///
297    /// # Returns
298    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
299    ///
300    /// # Errors
301    /// Returns `StatsError::NotFitted` if the model has not been fitted.
302    /// Returns an error if any prediction fails (dimension mismatch or conversion error).
303    ///
304    /// # Examples
305    /// ```
306    /// use rs_stats::regression::multiple_linear_regression::MultipleLinearRegression;
307    ///
308    /// let mut model = MultipleLinearRegression::<f64>::new();
309    /// let x = vec![
310    ///     vec![1.0, 2.0],
311    ///     vec![2.0, 1.0],
312    ///     vec![3.0, 3.0],
313    ///     vec![4.0, 2.0],
314    /// ];
315    /// let y = vec![5.0, 4.0, 9.0, 8.0];
316    /// model.fit(&x, &y).unwrap();
317    ///
318    /// let predictions = model.predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]]).unwrap();
319    /// assert_eq!(predictions.len(), 2);
320    /// ```
321    pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
322    where
323        U: NumCast + Copy,
324    {
325        x_values.iter().map(|x| self.predict(x)).collect()
326    }
327
328    /// Save the model to a file
329    ///
330    /// # Arguments
331    /// * `path` - Path where to save the model
332    ///
333    /// # Returns
334    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
335    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
336        let file = File::create(path)?;
337        // Use JSON format for human-readability
338        serde_json::to_writer(file, self).map_err(io::Error::other)
339    }
340
341    /// Save the model in binary format
342    ///
343    /// # Arguments
344    /// * `path` - Path where to save the model
345    ///
346    /// # Returns
347    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
348    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
349        let file = File::create(path)?;
350        // Use bincode for more compact binary format
351        bincode::serialize_into(file, self).map_err(io::Error::other)
352    }
353
354    /// Load a model from a file
355    ///
356    /// # Arguments
357    /// * `path` - Path to the saved model file
358    ///
359    /// # Returns
360    /// * `Result<Self, io::Error>` - Loaded model or IO error
361    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
362        let file = File::open(path)?;
363        // Try to load as JSON format
364        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
365    }
366
367    /// Load a model from a binary file
368    ///
369    /// # Arguments
370    /// * `path` - Path to the saved model file
371    ///
372    /// # Returns
373    /// * `Result<Self, io::Error>` - Loaded model or IO error
374    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
375        let file = File::open(path)?;
376        // Try to load as bincode format
377        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
378    }
379
380    /// Save the model to a string in JSON format
381    ///
382    /// # Returns
383    /// * `Result<String, String>` - JSON string representation or error message
384    pub fn to_json(&self) -> Result<String, String> {
385        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
386    }
387
388    /// Load a model from a JSON string
389    ///
390    /// # Arguments
391    /// * `json` - JSON string containing the model data
392    ///
393    /// # Returns
394    /// * `Result<Self, String>` - Loaded model or error message
395    pub fn from_json(json: &str) -> Result<Self, String> {
396        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
397    }
398}
399
400// ── Flat row-major linear-algebra helpers ─────────────────────────────────────
401
402/// `Aᵀ · A` where `a` is row-major `n_rows × n_cols`. Returns the
403/// `n_cols × n_cols` Gram matrix in row-major flat storage.
404fn matrix_multiply_transpose_flat<T>(a: &[T], n_rows: usize, n_cols: usize) -> Vec<T>
405where
406    T: Float,
407{
408    let mut result = vec![T::zero(); n_cols * n_cols];
409    // Iterate over rows in the outer loop — both reads (a[k*n_cols + ..])
410    // are contiguous, and the inner write avoids column-stride misses.
411    for k in 0..n_rows {
412        let row_off = k * n_cols;
413        for i in 0..n_cols {
414            let a_ki = a[row_off + i];
415            if a_ki == T::zero() {
416                continue;
417            }
418            let dst_off = i * n_cols;
419            for j in 0..n_cols {
420                result[dst_off + j] = result[dst_off + j] + a_ki * a[row_off + j];
421            }
422        }
423    }
424    result
425}
426
427/// `Aᵀ · y` where `a` is row-major `n_rows × n_cols`.
428fn vector_multiply_transpose_flat<T>(a: &[T], y: &[T], n_rows: usize, n_cols: usize) -> Vec<T>
429where
430    T: Float,
431{
432    let mut result = vec![T::zero(); n_cols];
433    for k in 0..n_rows {
434        let row_off = k * n_cols;
435        let yk = y[k];
436        for i in 0..n_cols {
437            result[i] = result[i] + a[row_off + i] * yk;
438        }
439    }
440    result
441}
442
443/// Gaussian elimination with partial pivoting on a flat `n × n` matrix
444/// `a` and right-hand-side vector `b` of length `n`. Returns the
445/// solution `x` of `a · x = b`. The augmented buffer is built once.
446fn solve_linear_system_flat<T>(a: &[T], b: &[T], n: usize) -> StatsResult<Vec<T>>
447where
448    T: Float + Debug,
449{
450    if a.len() != n * n || b.len() != n {
451        return Err(StatsError::dimension_mismatch(format!(
452            "Invalid matrix dimensions for linear system solving: A is {n}×{n} ({} elems), b has {} elements",
453            a.len(),
454            b.len()
455        )));
456    }
457    let w = n + 1;
458    let mut aug: Vec<T> = Vec::with_capacity(n * w);
459    for i in 0..n {
460        aug.extend_from_slice(&a[i * n..(i + 1) * n]);
461        aug.push(b[i]);
462    }
463
464    let epsilon: T = T::from(1e-10).ok_or_else(|| {
465        StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
466    })?;
467
468    for i in 0..n {
469        // Partial pivot
470        let mut max_row = i;
471        let mut max_val = aug[i * w + i].abs();
472        for j in (i + 1)..n {
473            let abs_val = aug[j * w + i].abs();
474            if abs_val > max_val {
475                max_row = j;
476                max_val = abs_val;
477            }
478        }
479        if max_val < epsilon {
480            return Err(StatsError::mathematical_error(
481                "Matrix is singular or near-singular, cannot solve linear system",
482            ));
483        }
484        if max_row != i {
485            for c in 0..w {
486                aug.swap(i * w + c, max_row * w + c);
487            }
488        }
489
490        // Eliminate below
491        let pivot = aug[i * w + i];
492        for j in (i + 1)..n {
493            let factor = aug[j * w + i] / pivot;
494            for k in i..w {
495                aug[j * w + k] = aug[j * w + k] - factor * aug[i * w + k];
496            }
497        }
498    }
499
500    // Back substitution
501    let mut x = vec![T::zero(); n];
502    for i in (0..n).rev() {
503        let mut sum = aug[i * w + n];
504        for j in (i + 1)..n {
505            sum = sum - aug[i * w + j] * x[j];
506        }
507        x[i] = sum / aug[i * w + i];
508    }
509    Ok(x)
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four observations lying exactly on the plane `y = 1 + 2*x1 + 3*x2`.
    fn plane_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// Same design as `plane_data`, but `y = x1 + x2` (zero intercept).
    fn sum_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let (x, _) = plane_data();
        (x, vec![3.0, 3.0, 6.0, 6.0])
    }

    /// Fit a fresh `f64` model, panicking if the fit fails.
    fn fit_f64(x: &[Vec<f64>], y: &[f64]) -> MultipleLinearRegression<f64> {
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(x, y).unwrap();
        model
    }

    /// Assert that a round-tripped model carries the same fitted parameters.
    fn assert_same_fit(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &want) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, want, Some(1e-6)));
        }
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        // y = 2*x1 + 3*x2 + 1
        let (x, y) = plane_data();

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        // y = 2*x1 + 3*x2 + 1, computed in single precision
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        // Integer inputs (different types for X and Y) are cast to f64.
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6))); // intercept
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6))); // x1 coefficient
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6))); // x2 coefficient
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        // y = 2*x1 + 3*x2 + 1
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // 1 + 2*5 + 3*4 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        // Three points determine the plane y = 1 + 2*x1 + 3*x2 exactly.
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");

        let (x, y) = plane_data();
        let model = fit_f64(&x, &y);

        assert!(model.save(&file_path).is_ok());

        let loaded = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded.is_ok());
        assert_same_fit(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");

        let (x, y) = plane_data();
        let model = fit_f64(&x, &y);

        assert!(model.save_binary(&file_path).is_ok());

        let loaded = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded.is_ok());
        assert_same_fit(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let (x, y) = plane_data();
        let model = fit_f64(&x, &y);

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded.is_ok());
        assert_same_fit(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        // predict() on a model that was never fitted must return NotFitted.
        let model = MultipleLinearRegression::<f64>::new();

        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        // Four observations so the 3×3 normal equations are well-posed.
        let (x, y) = sum_data();
        let model = fit_f64(&x, &y);

        // Model expects 2 features; pass only 1.
        let wrong_features = vec![1.0];
        let result = model.predict(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_singular_matrix() {
        // Every row is a multiple of [1, 2, 3], so the three features are
        // linearly dependent (feature 2 = 2 × feature 1, feature 3 = 3 ×
        // feature 1) and the normal-equation matrix is singular.
        let x = vec![
            vec![1.0, 2.0, 3.0], // 1 × [1, 2, 3]
            vec![2.0, 4.0, 6.0], // 2 × [1, 2, 3]
            vec![3.0, 6.0, 9.0], // 3 × [1, 2, 3]
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);
        // Depending on rounding this may succeed or be rejected; it must not
        // panic, and a rejection must be reported as a mathematical error.
        match result {
            Ok(_) => assert!(!model.coefficients.is_empty()),
            Err(e) => assert!(matches!(e, StatsError::MathematicalError { .. })),
        }
    }

    #[test]
    fn test_save_invalid_path() {
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        let model = fit_f64(&x, &y);

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_t_coefficients_empty() {
        // predict_t is private; the empty-coefficients path is reached
        // through predict(), which reports NotFitted first.
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_fit_x_values_empty_after_check() {
        // Empty X and Y must be rejected up front with EmptyData.
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];
        let result = model.fit(&x, &y);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        let model = fit_f64(&x, &y);

        // Model expects 2 features; pass only 1.
        let wrong_features = vec![vec![1.0]];
        let result = model.predict_many(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_predict_many_success() {
        // y = x1 + x2 exactly, so predictions are the feature sums.
        let (x, y) = sum_data();
        let model = fit_f64(&x, &y);

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 7.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 11.0, Some(1e-6)));
    }
}