Skip to main content

rs_stats/regression/
linear_regression.rs

1// src/regression/linear_regression.rs
2
3use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use rayon::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::fmt::Debug;
8use std::fs::File;
9use std::io::{self};
10use std::path::Path;
11
/// Linear regression model that fits a line to data points.
///
/// The model describes the line `y = intercept + slope * x`. A freshly
/// constructed model has every numeric field set to zero and `n == 0`; the
/// fields are filled in by `fit`.
///
/// The type parameter `T` is the floating-point type used for the fitted
/// parameters and defaults to `f64`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression<T = f64>
where
    T: Float + Debug + Default + Serialize,
{
    /// Slope of the regression line (coefficient of x)
    pub slope: T,
    /// Y-intercept of the regression line
    pub intercept: T,
    /// Coefficient of determination (R²) - goodness of fit
    pub r_squared: T,
    /// Standard error of the estimate.
    ///
    /// `fit` only computes this when more than two data points are supplied;
    /// otherwise it is set to zero.
    pub standard_error: T,
    /// Number of data points used for regression.
    ///
    /// `0` for a freshly constructed model, which `predict` and
    /// `correlation_coefficient` treat as "not fitted".
    pub n: usize,
}
29
30impl<T> Default for LinearRegression<T>
31where
32    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
33{
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl<T> LinearRegression<T>
40where
41    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
42{
43    /// Create a new linear regression model without fitting any data
44    pub fn new() -> Self {
45        Self {
46            slope: T::zero(),
47            intercept: T::zero(),
48            r_squared: T::zero(),
49            standard_error: T::zero(),
50            n: 0,
51        }
52    }
53
54    /// Fit a linear model to the provided x and y data points
55    ///
56    /// # Arguments
57    /// * `x_values` - Independent variable values
58    /// * `y_values` - Dependent variable values (observations)
59    ///
60    /// # Returns
61    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
62    ///
63    /// # Errors
64    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
65    /// Returns `StatsError::EmptyData` if the input arrays are empty.
66    /// Returns `StatsError::ConversionError` if value conversion fails.
67    /// Returns `StatsError::InvalidParameter` if there's no variance in X values.
68    pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
69    where
70        U: NumCast + Copy,
71        V: NumCast + Copy,
72    {
73        // Validate inputs
74        if x_values.len() != y_values.len() {
75            return Err(StatsError::dimension_mismatch(format!(
76                "X and Y arrays must have the same length (got {} and {})",
77                x_values.len(),
78                y_values.len()
79            )));
80        }
81
82        if x_values.is_empty() {
83            return Err(StatsError::empty_data(
84                "Cannot fit regression with empty arrays",
85            ));
86        }
87
88        let n = x_values.len();
89        self.n = n;
90
91        // Convert input arrays to T type
92        let x_cast: Vec<T> = x_values
93            .iter()
94            .enumerate()
95            .map(|(i, &x)| {
96                T::from(x).ok_or_else(|| {
97                    StatsError::conversion_error(format!(
98                        "Failed to cast X value at index {} to type T",
99                        i
100                    ))
101                })
102            })
103            .collect::<StatsResult<Vec<T>>>()?;
104
105        let y_cast: Vec<T> = y_values
106            .iter()
107            .enumerate()
108            .map(|(i, &y)| {
109                T::from(y).ok_or_else(|| {
110                    StatsError::conversion_error(format!(
111                        "Failed to cast Y value at index {} to type T",
112                        i
113                    ))
114                })
115            })
116            .collect::<StatsResult<Vec<T>>>()?;
117
118        // Calculate means
119        let n_as_t = T::from(n).ok_or_else(|| {
120            StatsError::conversion_error(format!("Failed to convert {} to type T", n))
121        })?;
122        let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
123        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
124
125        // Calculate variance and covariance
126        let mut sum_xy = T::zero();
127        let mut sum_xx = T::zero();
128        let mut sum_yy = T::zero();
129
130        for i in 0..n {
131            let x_diff = x_cast[i] - x_mean;
132            let y_diff = y_cast[i] - y_mean;
133
134            sum_xy = sum_xy + (x_diff * y_diff);
135            sum_xx = sum_xx + (x_diff * x_diff);
136            sum_yy = sum_yy + (y_diff * y_diff);
137        }
138
139        // Check if there's any variance in x
140        if sum_xx == T::zero() {
141            return Err(StatsError::invalid_parameter(
142                "No variance in X values, cannot fit regression line",
143            ));
144        }
145
146        // Calculate slope and intercept
147        self.slope = sum_xy / sum_xx;
148        self.intercept = y_mean - (self.slope * x_mean);
149
150        // Calculate R²
151        self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
152
153        // Calculate residuals and standard error
154        let mut sum_squared_residuals = T::zero();
155        for i in 0..n {
156            let predicted = self.predict_t(x_cast[i]);
157            let residual = y_cast[i] - predicted;
158            sum_squared_residuals = sum_squared_residuals + (residual * residual);
159        }
160
161        // Standard error of the estimate
162        if n > 2 {
163            let two = T::from(2)
164                .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
165            let n_minus_two = n_as_t - two;
166            self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
167        } else {
168            self.standard_error = T::zero();
169        }
170
171        Ok(())
172    }
173
174    /// Predict y value for a given x using the fitted model (internal version with type T)
175    fn predict_t(&self, x: T) -> T {
176        self.intercept + (self.slope * x)
177    }
178
179    /// Predict y value for a given x using the fitted model
180    ///
181    /// # Arguments
182    /// * `x` - The x value to predict for
183    ///
184    /// # Returns
185    /// * `StatsResult<T>` - The predicted y value
186    ///
187    /// # Errors
188    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
189    /// Returns `StatsError::ConversionError` if type conversion fails.
190    ///
191    /// # Examples
192    /// ```
193    /// use rs_stats::regression::linear_regression::LinearRegression;
194    ///
195    /// let mut model = LinearRegression::<f64>::new();
196    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
197    ///
198    /// let prediction = model.predict(4.0).unwrap();
199    /// assert!((prediction - 8.0).abs() < 1e-10);
200    /// ```
201    pub fn predict<U>(&self, x: U) -> StatsResult<T>
202    where
203        U: NumCast + Copy,
204    {
205        if self.n == 0 {
206            return Err(StatsError::not_fitted(
207                "Model has not been fitted. Call fit() before predicting.",
208            ));
209        }
210
211        let x_cast: T = T::from(x)
212            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
213
214        Ok(self.predict_t(x_cast))
215    }
216
217    /// Calculate predictions for multiple x values
218    ///
219    /// # Arguments
220    /// * `x_values` - Slice of x values to predict for
221    ///
222    /// # Returns
223    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
224    ///
225    /// # Errors
226    /// Returns `StatsError::NotFitted` if the model has not been fitted.
227    /// Returns `StatsError::ConversionError` if type conversion fails for any value.
228    ///
229    /// # Examples
230    /// ```
231    /// use rs_stats::regression::linear_regression::LinearRegression;
232    ///
233    /// let mut model = LinearRegression::<f64>::new();
234    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
235    ///
236    /// let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
237    /// assert_eq!(predictions.len(), 2);
238    /// ```
239    pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
240    where
241        U: NumCast + Copy + Send + Sync,
242        T: Send + Sync,
243    {
244        x_values.par_iter().map(|&x| self.predict(x)).collect()
245    }
246
247    /// Calculate confidence intervals for the regression line
248    ///
249    /// # Arguments
250    /// * `x` - The x value to calculate confidence interval for
251    /// * `confidence_level` - Confidence level (0.95 for 95% confidence)
252    ///
253    /// # Returns
254    /// * `StatsResult<(T, T)>` - Tuple of (lower_bound, upper_bound), or an error if invalid
255    ///
256    /// # Errors
257    /// Returns `StatsError::InvalidInput` if there are fewer than 3 data points.
258    /// Returns `StatsError::InvalidParameter` if confidence level is not supported (only 0.90, 0.95, 0.99).
259    /// Returns `StatsError::ConversionError` if value conversion fails.
260    pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
261    where
262        U: NumCast + Copy,
263    {
264        if self.n < 3 {
265            return Err(StatsError::invalid_input(
266                "Need at least 3 data points to calculate confidence interval",
267            ));
268        }
269
270        let x_cast: T = T::from(x)
271            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
272
273        // Get the t-critical value based on degrees of freedom and confidence level
274        // For simplicity, we'll use a normal approximation with standard errors
275        let z_score: T = match confidence_level {
276            0.90 => T::from(1.645).ok_or_else(|| {
277                StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
278            })?,
279            0.95 => T::from(1.96).ok_or_else(|| {
280                StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
281            })?,
282            0.99 => T::from(2.576).ok_or_else(|| {
283                StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
284            })?,
285            _ => {
286                return Err(StatsError::invalid_parameter(format!(
287                    "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
288                    confidence_level
289                )));
290            }
291        };
292
293        let predicted = self.predict_t(x_cast);
294        let margin = z_score * self.standard_error;
295
296        Ok((predicted - margin, predicted + margin))
297    }
298
299    /// Get the correlation coefficient (r)
300    ///
301    /// The correlation coefficient ranges from -1 to 1, indicating the strength
302    /// and direction of the linear relationship between x and y.
303    ///
304    /// # Returns
305    /// * `StatsResult<T>` - The correlation coefficient
306    ///
307    /// # Errors
308    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
309    ///
310    /// # Examples
311    /// ```
312    /// use rs_stats::regression::linear_regression::LinearRegression;
313    ///
314    /// let mut model = LinearRegression::<f64>::new();
315    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
316    ///
317    /// let r = model.correlation_coefficient().unwrap();
318    /// assert!((r - 1.0).abs() < 1e-10); // Perfect positive correlation
319    /// ```
320    pub fn correlation_coefficient(&self) -> StatsResult<T> {
321        if self.n == 0 {
322            return Err(StatsError::not_fitted(
323                "Model has not been fitted. Call fit() before getting correlation coefficient.",
324            ));
325        }
326        let r = self.r_squared.sqrt();
327        Ok(if self.slope >= T::zero() { r } else { -r })
328    }
329
330    /// Save the model to a file
331    ///
332    /// # Arguments
333    /// * `path` - Path where to save the model
334    ///
335    /// # Returns
336    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
337    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
338        let file = File::create(path)?;
339        // Use JSON format for human-readability
340        serde_json::to_writer(file, self).map_err(io::Error::other)
341    }
342
343    /// Save the model in binary format
344    ///
345    /// # Arguments
346    /// * `path` - Path where to save the model
347    ///
348    /// # Returns
349    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
350    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
351        let file = File::create(path)?;
352        // Use bincode for more compact binary format
353        bincode::serialize_into(file, self).map_err(io::Error::other)
354    }
355
356    /// Load a model from a file
357    ///
358    /// # Arguments
359    /// * `path` - Path to the saved model file
360    ///
361    /// # Returns
362    /// * `Result<Self, io::Error>` - Loaded model or IO error
363    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
364        let file = File::open(path)?;
365        // Try to load as JSON format
366        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
367    }
368
369    /// Load a model from a binary file
370    ///
371    /// # Arguments
372    /// * `path` - Path to the saved model file
373    ///
374    /// # Returns
375    /// * `Result<Self, io::Error>` - Loaded model or IO error
376    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
377        let file = File::open(path)?;
378        // Try to load as bincode format
379        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
380    }
381
382    /// Save the model to a string in JSON format
383    ///
384    /// # Returns
385    /// * `Result<String, String>` - JSON string representation or error message
386    pub fn to_json(&self) -> Result<String, String> {
387        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
388    }
389
390    /// Load a model from a JSON string
391    ///
392    /// # Arguments
393    /// * `json` - JSON string containing the model data
394    ///
395    /// # Returns
396    /// * `Result<Self, String>` - Loaded model or error message
397    pub fn from_json(json: &str) -> Result<Self, String> {
398        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// X values of the shared perfect-fit fixture (`y = 2x`).
    const XS: [f64; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
    /// Y values of the shared perfect-fit fixture (`y = 2x`).
    const YS: [f64; 5] = [2.0, 4.0, 6.0, 8.0, 10.0];

    /// Fit an `f64` model to the given data, panicking if the fit fails.
    fn fitted(x: &[f64], y: &[f64]) -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model.fit(x, y).expect("fixture data must be fittable");
        model
    }

    /// Assert that a round-tripped model carries the same parameters as the original.
    fn assert_same_parameters(loaded: &LinearRegression<f64>, original: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, original.slope, Some(1e-6)));
        assert!(approx_equal(
            loaded.intercept,
            original.intercept,
            Some(1e-6)
        ));
        assert!(approx_equal(
            loaded.r_squared,
            original.r_squared,
            Some(1e-6)
        ));
        assert!(approx_equal(
            loaded.standard_error,
            original.standard_error,
            Some(1e-6)
        ));
        assert_eq!(loaded.n, original.n);
    }

    #[test]
    fn test_simple_regression_f64() {
        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&XS, &YS);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
        assert_eq!(model.n, 5);
    }

    #[test]
    fn test_simple_regression_f32() {
        let x = [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32];
        let y = [2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32];

        let mut model = LinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    #[test]
    fn test_integer_data() {
        let x = [1, 2, 3, 4, 5];
        let y = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_mixed_types() {
        let x = [1u32, 2u32, 3u32, 4u32, 5u32];
        let y = [2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    #[test]
    fn test_prediction() {
        let x = [1, 2, 3, 4, 5];
        let y = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.predict(6u32).unwrap(), 12.0, Some(1e-6)));
        assert!(approx_equal(model.predict(0i32).unwrap(), 0.0, Some(1e-6)));
    }

    #[test]
    fn test_invalid_inputs() {
        let x = [1, 2, 3];
        let y = [2, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        // A length mismatch is rejected before the model is touched.
        assert_eq!(model.n, 0);
    }

    #[test]
    fn test_constant_x() {
        let x = [1, 1, 1];
        let y = [2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_save_load_json() {
        // Create a temporary directory
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");

        // Create and fit a model
        let model = fitted(&XS, &YS);

        // Save the model
        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        // Load the model
        let loaded_model = LinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());

        // Check that the loaded model has the same parameters
        assert_same_parameters(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        // Create a temporary directory
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");

        // Create and fit a model
        let model = fitted(&XS, &YS);

        // Save the model
        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        // Load the model
        let loaded_model = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());

        // Check that the loaded model has the same parameters
        assert_same_parameters(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        // Create and fit a model
        let model = fitted(&XS, &YS);

        // Serialize to JSON string
        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        // Deserialize from JSON string
        let loaded_model = LinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());

        // Check that the loaded model has the same parameters
        assert_same_parameters(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_load_nonexistent_file() {
        // Test loading from a file that doesn't exist
        let result = LinearRegression::<f64>::load("/nonexistent/path/model.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_binary_nonexistent_file() {
        // Test loading from a binary file that doesn't exist
        let result = LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_json_invalid_json() {
        // Test deserializing from invalid JSON
        let invalid_json = "{invalid json}";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_when_not_fitted() {
        // Test that predict returns an error when model is not fitted
        let model = LinearRegression::<f64>::new();
        let result = model.predict(5.0);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_save_invalid_path() {
        // Test saving to an invalid path (non-existent directory)
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        // Test the branch where n <= 2 (standard_error = 0)
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        // When n = 2, standard_error should be 0
        assert_eq!(model.standard_error, 0.0);
    }

    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        // Test the branch where n > 2 (standard_error calculated)
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        // The data lie exactly on y = 2x, so every residual is zero
        assert!(model.standard_error.abs() < 1e-10);

        // Data that are off the line must yield a strictly positive error
        let noisy = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.5, 5.5, 8.0]);
        assert!(noisy.standard_error > 0.0);
        assert!(noisy.standard_error.is_finite());
    }

    #[test]
    fn test_confidence_interval_n_less_than_three() {
        // Test confidence_interval with n < 3
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let result = model.confidence_interval(3.0, 0.95);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_unsupported_level() {
        // Test confidence_interval with unsupported confidence level
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        let result = model.confidence_interval(3.0, 0.85);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_supported_levels() {
        // Test all supported confidence levels
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        for level in [0.90, 0.95, 0.99] {
            let result = model.confidence_interval(3.0, level);
            assert!(
                result.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = result.unwrap();
            assert!(lower <= upper, "Lower bound should be <= upper bound");
            // The interval is centred on the prediction (y = 2 * 3 = 6)
            assert!(
                lower <= 6.0 + 1e-10 && 6.0 - 1e-10 <= upper,
                "Interval should contain the predicted value"
            );
        }
    }

    #[test]
    fn test_correlation_coefficient_positive_slope() {
        // Test correlation_coefficient with positive slope
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            (r - 1.0).abs() < 1e-10,
            "Correlation should be +1 for a perfect increasing line"
        );
    }

    #[test]
    fn test_correlation_coefficient_negative_slope() {
        // Test correlation_coefficient with negative slope
        let model = fitted(&[1.0, 2.0, 3.0], &[6.0, 4.0, 2.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            (r + 1.0).abs() < 1e-10,
            "Correlation should be -1 for a perfect decreasing line"
        );
    }

    #[test]
    fn test_correlation_coefficient_not_fitted() {
        // Test correlation_coefficient when model is not fitted
        let model = LinearRegression::<f64>::new();
        let result = model.correlation_coefficient();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        // Test predict_many when model is not fitted
        let model = LinearRegression::<f64>::new();
        let result = model.predict_many(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_success() {
        // Test predict_many with valid data
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_invalid_json() {
        // Test loading invalid JSON
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");

        // Write invalid JSON
        std::fs::write(&file_path, "invalid json content").unwrap();

        let result = LinearRegression::<f64>::load(&file_path);
        assert!(result.is_err(), "Loading invalid JSON should return error");
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_from_json_invalid() {
        // Test deserializing invalid JSON string
        let invalid_json = "not valid json";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }
}