rs_stats/regression/
linear_regression.rs

1// src/regression/linear_regression.rs
2
3use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
11/// Linear regression model that fits a line to data points.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct LinearRegression<T = f64>
14where
15    T: Float + Debug + Default + Serialize,
16{
17    /// Slope of the regression line (coefficient of x)
18    pub slope: T,
19    /// Y-intercept of the regression line
20    pub intercept: T,
21    /// Coefficient of determination (R²) - goodness of fit
22    pub r_squared: T,
23    /// Standard error of the estimate
24    pub standard_error: T,
25    /// Number of data points used for regression
26    pub n: usize,
27}
28
29impl<T> Default for LinearRegression<T>
30where
31    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
32{
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38impl<T> LinearRegression<T>
39where
40    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
41{
42    /// Create a new linear regression model without fitting any data
43    pub fn new() -> Self {
44        Self {
45            slope: T::zero(),
46            intercept: T::zero(),
47            r_squared: T::zero(),
48            standard_error: T::zero(),
49            n: 0,
50        }
51    }
52
53    /// Fit a linear model to the provided x and y data points
54    ///
55    /// # Arguments
56    /// * `x_values` - Independent variable values
57    /// * `y_values` - Dependent variable values (observations)
58    ///
59    /// # Returns
60    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
61    ///
62    /// # Errors
63    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
64    /// Returns `StatsError::EmptyData` if the input arrays are empty.
65    /// Returns `StatsError::ConversionError` if value conversion fails.
66    /// Returns `StatsError::InvalidParameter` if there's no variance in X values.
67    pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
68    where
69        U: NumCast + Copy,
70        V: NumCast + Copy,
71    {
72        // Validate inputs
73        if x_values.len() != y_values.len() {
74            return Err(StatsError::dimension_mismatch(format!(
75                "X and Y arrays must have the same length (got {} and {})",
76                x_values.len(),
77                y_values.len()
78            )));
79        }
80
81        if x_values.is_empty() {
82            return Err(StatsError::empty_data(
83                "Cannot fit regression with empty arrays",
84            ));
85        }
86
87        let n = x_values.len();
88        self.n = n;
89
90        // Convert input arrays to T type
91        let x_cast: Vec<T> = x_values
92            .iter()
93            .enumerate()
94            .map(|(i, &x)| {
95                T::from(x).ok_or_else(|| {
96                    StatsError::conversion_error(format!(
97                        "Failed to cast X value at index {} to type T",
98                        i
99                    ))
100                })
101            })
102            .collect::<StatsResult<Vec<T>>>()?;
103
104        let y_cast: Vec<T> = y_values
105            .iter()
106            .enumerate()
107            .map(|(i, &y)| {
108                T::from(y).ok_or_else(|| {
109                    StatsError::conversion_error(format!(
110                        "Failed to cast Y value at index {} to type T",
111                        i
112                    ))
113                })
114            })
115            .collect::<StatsResult<Vec<T>>>()?;
116
117        // Calculate means
118        let n_as_t = T::from(n).ok_or_else(|| {
119            StatsError::conversion_error(format!("Failed to convert {} to type T", n))
120        })?;
121        let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
122        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
123
124        // Calculate variance and covariance
125        let mut sum_xy = T::zero();
126        let mut sum_xx = T::zero();
127        let mut sum_yy = T::zero();
128
129        for i in 0..n {
130            let x_diff = x_cast[i] - x_mean;
131            let y_diff = y_cast[i] - y_mean;
132
133            sum_xy = sum_xy + (x_diff * y_diff);
134            sum_xx = sum_xx + (x_diff * x_diff);
135            sum_yy = sum_yy + (y_diff * y_diff);
136        }
137
138        // Check if there's any variance in x
139        if sum_xx == T::zero() {
140            return Err(StatsError::invalid_parameter(
141                "No variance in X values, cannot fit regression line",
142            ));
143        }
144
145        // Calculate slope and intercept
146        self.slope = sum_xy / sum_xx;
147        self.intercept = y_mean - (self.slope * x_mean);
148
149        // Calculate R²
150        self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
151
152        // Calculate residuals and standard error
153        let mut sum_squared_residuals = T::zero();
154        for i in 0..n {
155            let predicted = self.predict_t(x_cast[i]);
156            let residual = y_cast[i] - predicted;
157            sum_squared_residuals = sum_squared_residuals + (residual * residual);
158        }
159
160        // Standard error of the estimate
161        if n > 2 {
162            let two = T::from(2)
163                .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
164            let n_minus_two = n_as_t - two;
165            self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
166        } else {
167            self.standard_error = T::zero();
168        }
169
170        Ok(())
171    }
172
173    /// Predict y value for a given x using the fitted model (internal version with type T)
174    fn predict_t(&self, x: T) -> T {
175        self.intercept + (self.slope * x)
176    }
177
178    /// Predict y value for a given x using the fitted model
179    ///
180    /// # Arguments
181    /// * `x` - The x value to predict for
182    ///
183    /// # Returns
184    /// * `StatsResult<T>` - The predicted y value
185    ///
186    /// # Errors
187    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
188    /// Returns `StatsError::ConversionError` if type conversion fails.
189    ///
190    /// # Examples
191    /// ```
192    /// use rs_stats::regression::linear_regression::LinearRegression;
193    ///
194    /// let mut model = LinearRegression::<f64>::new();
195    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
196    ///
197    /// let prediction = model.predict(4.0).unwrap();
198    /// assert!((prediction - 8.0).abs() < 1e-10);
199    /// ```
200    pub fn predict<U>(&self, x: U) -> StatsResult<T>
201    where
202        U: NumCast + Copy,
203    {
204        if self.n == 0 {
205            return Err(StatsError::not_fitted(
206                "Model has not been fitted. Call fit() before predicting.",
207            ));
208        }
209
210        let x_cast: T = T::from(x)
211            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
212
213        Ok(self.predict_t(x_cast))
214    }
215
216    /// Calculate predictions for multiple x values
217    ///
218    /// # Arguments
219    /// * `x_values` - Slice of x values to predict for
220    ///
221    /// # Returns
222    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
223    ///
224    /// # Errors
225    /// Returns `StatsError::NotFitted` if the model has not been fitted.
226    /// Returns `StatsError::ConversionError` if type conversion fails for any value.
227    ///
228    /// # Examples
229    /// ```
230    /// use rs_stats::regression::linear_regression::LinearRegression;
231    ///
232    /// let mut model = LinearRegression::<f64>::new();
233    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
234    ///
235    /// let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
236    /// assert_eq!(predictions.len(), 2);
237    /// ```
238    pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
239    where
240        U: NumCast + Copy,
241    {
242        x_values.iter().map(|&x| self.predict(x)).collect()
243    }
244
245    /// Calculate confidence intervals for the regression line
246    ///
247    /// # Arguments
248    /// * `x` - The x value to calculate confidence interval for
249    /// * `confidence_level` - Confidence level (0.95 for 95% confidence)
250    ///
251    /// # Returns
252    /// * `StatsResult<(T, T)>` - Tuple of (lower_bound, upper_bound), or an error if invalid
253    ///
254    /// # Errors
255    /// Returns `StatsError::InvalidInput` if there are fewer than 3 data points.
256    /// Returns `StatsError::InvalidParameter` if confidence level is not supported (only 0.90, 0.95, 0.99).
257    /// Returns `StatsError::ConversionError` if value conversion fails.
258    pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
259    where
260        U: NumCast + Copy,
261    {
262        if self.n < 3 {
263            return Err(StatsError::invalid_input(
264                "Need at least 3 data points to calculate confidence interval",
265            ));
266        }
267
268        let x_cast: T = T::from(x)
269            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
270
271        // Get the t-critical value based on degrees of freedom and confidence level
272        // For simplicity, we'll use a normal approximation with standard errors
273        let z_score: T = match confidence_level {
274            0.90 => T::from(1.645).ok_or_else(|| {
275                StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
276            })?,
277            0.95 => T::from(1.96).ok_or_else(|| {
278                StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
279            })?,
280            0.99 => T::from(2.576).ok_or_else(|| {
281                StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
282            })?,
283            _ => {
284                return Err(StatsError::invalid_parameter(format!(
285                    "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
286                    confidence_level
287                )));
288            }
289        };
290
291        let predicted = self.predict_t(x_cast);
292        let margin = z_score * self.standard_error;
293
294        Ok((predicted - margin, predicted + margin))
295    }
296
297    /// Get the correlation coefficient (r)
298    ///
299    /// The correlation coefficient ranges from -1 to 1, indicating the strength
300    /// and direction of the linear relationship between x and y.
301    ///
302    /// # Returns
303    /// * `StatsResult<T>` - The correlation coefficient
304    ///
305    /// # Errors
306    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
307    ///
308    /// # Examples
309    /// ```
310    /// use rs_stats::regression::linear_regression::LinearRegression;
311    ///
312    /// let mut model = LinearRegression::<f64>::new();
313    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
314    ///
315    /// let r = model.correlation_coefficient().unwrap();
316    /// assert!((r - 1.0).abs() < 1e-10); // Perfect positive correlation
317    /// ```
318    pub fn correlation_coefficient(&self) -> StatsResult<T> {
319        if self.n == 0 {
320            return Err(StatsError::not_fitted(
321                "Model has not been fitted. Call fit() before getting correlation coefficient.",
322            ));
323        }
324        let r = self.r_squared.sqrt();
325        Ok(if self.slope >= T::zero() { r } else { -r })
326    }
327
328    /// Save the model to a file
329    ///
330    /// # Arguments
331    /// * `path` - Path where to save the model
332    ///
333    /// # Returns
334    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
335    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
336        let file = File::create(path)?;
337        // Use JSON format for human-readability
338        serde_json::to_writer(file, self).map_err(io::Error::other)
339    }
340
341    /// Save the model in binary format
342    ///
343    /// # Arguments
344    /// * `path` - Path where to save the model
345    ///
346    /// # Returns
347    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
348    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
349        let file = File::create(path)?;
350        // Use bincode for more compact binary format
351        bincode::serialize_into(file, self).map_err(io::Error::other)
352    }
353
354    /// Load a model from a file
355    ///
356    /// # Arguments
357    /// * `path` - Path to the saved model file
358    ///
359    /// # Returns
360    /// * `Result<Self, io::Error>` - Loaded model or IO error
361    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
362        let file = File::open(path)?;
363        // Try to load as JSON format
364        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
365    }
366
367    /// Load a model from a binary file
368    ///
369    /// # Arguments
370    /// * `path` - Path to the saved model file
371    ///
372    /// # Returns
373    /// * `Result<Self, io::Error>` - Loaded model or IO error
374    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
375        let file = File::open(path)?;
376        // Try to load as bincode format
377        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
378    }
379
380    /// Save the model to a string in JSON format
381    ///
382    /// # Returns
383    /// * `Result<String, String>` - JSON string representation or error message
384    pub fn to_json(&self) -> Result<String, String> {
385        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
386    }
387
388    /// Load a model from a JSON string
389    ///
390    /// # Arguments
391    /// * `json` - JSON string containing the model data
392    ///
393    /// # Returns
394    /// * `Result<Self, String>` - Loaded model or error message
395    pub fn from_json(json: &str) -> Result<Self, String> {
396        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Fits a fresh `f64` model on the given points, panicking if the fit fails.
    fn fitted(x: &[f64], y: &[f64]) -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model.fit(x, y).unwrap();
        model
    }

    /// The `y = 2x` model over x = 1..=5 shared by the persistence tests.
    fn doubling_model() -> LinearRegression<f64> {
        fitted(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
    }

    /// Asserts that a reloaded model carries the same parameters as the one saved.
    fn assert_same_parameters(loaded: &LinearRegression<f64>, saved: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, saved.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, saved.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, saved.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, saved.n);
    }

    #[test]
    fn test_simple_regression_f64() {
        let mut model = LinearRegression::<f64>::new();
        let outcome = model.fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0]);

        assert!(outcome.is_ok());
        for (actual, expected) in [
            (model.slope, 2.0),
            (model.intercept, 0.0),
            (model.r_squared, 1.0),
        ] {
            assert!(approx_equal(actual, expected, Some(1e-6)));
        }
    }

    #[test]
    fn test_simple_regression_f32() {
        let mut model = LinearRegression::<f32>::new();
        let outcome = model.fit(
            &[1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32],
            &[2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32],
        );

        assert!(outcome.is_ok());
        for (actual, expected) in [
            (model.slope, 2.0f32),
            (model.intercept, 0.0f32),
            (model.r_squared, 1.0f32),
        ] {
            assert!(approx_equal(actual, expected, Some(1e-6)));
        }
    }

    #[test]
    fn test_integer_data() {
        let mut model = LinearRegression::<f64>::new();
        let outcome = model.fit(&[1, 2, 3, 4, 5], &[2, 4, 6, 8, 10]);

        assert!(outcome.is_ok());
        for (actual, expected) in [
            (model.slope, 2.0),
            (model.intercept, 0.0),
            (model.r_squared, 1.0),
        ] {
            assert!(approx_equal(actual, expected, Some(1e-6)));
        }
    }

    #[test]
    fn test_mixed_types() {
        // Unsigned integer predictors paired with noisy floating-point observations.
        let mut model = LinearRegression::<f64>::new();
        let outcome = model.fit(&[1u32, 2u32, 3u32, 4u32, 5u32], &[2.1, 3.9, 6.2, 7.8, 10.1]);

        assert!(outcome.is_ok());
        assert!(1.9 < model.slope && model.slope < 2.1);
        assert!(-0.1 < model.intercept && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    #[test]
    fn test_prediction() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1, 2, 3, 4, 5], &[2, 4, 6, 8, 10]).unwrap();

        let at_six = model.predict(6u32).unwrap();
        let at_zero = model.predict(0i32).unwrap();
        assert!(approx_equal(at_six, 12.0, Some(1e-6)));
        assert!(approx_equal(at_zero, 0.0, Some(1e-6)));
    }

    #[test]
    fn test_invalid_inputs() {
        // Three x values against two y values must be rejected.
        let mut model = LinearRegression::<f64>::new();
        let outcome = model.fit(&[1, 2, 3], &[2, 4]);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_constant_x() {
        // Identical x values leave no variance to regress on.
        let mut model = LinearRegression::<f64>::new();
        let outcome = model.fit(&[1, 1, 1], &[2, 3, 4]);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_save_load_json() {
        // Round-trip through a JSON file in a temporary directory.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = doubling_model();

        assert!(model.save(&file_path).is_ok());

        let reloaded = LinearRegression::<f64>::load(&file_path);
        assert!(reloaded.is_ok());
        assert_same_parameters(&reloaded.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        // Round-trip through a binary file in a temporary directory.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = doubling_model();

        assert!(model.save_binary(&file_path).is_ok());

        let reloaded = LinearRegression::<f64>::load_binary(&file_path);
        assert!(reloaded.is_ok());
        assert_same_parameters(&reloaded.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        // Round-trip through an in-memory JSON string.
        let model = doubling_model();

        let encoded = model.to_json();
        assert!(encoded.is_ok());
        let json_str = encoded.unwrap();

        let decoded = LinearRegression::<f64>::from_json(&json_str);
        assert!(decoded.is_ok());
        assert_same_parameters(&decoded.unwrap(), &model);
    }

    #[test]
    fn test_load_nonexistent_file() {
        // A missing JSON file must surface as an error.
        assert!(LinearRegression::<f64>::load("/nonexistent/path/model.json").is_err());
    }

    #[test]
    fn test_load_binary_nonexistent_file() {
        // A missing binary file must surface as an error.
        assert!(LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin").is_err());
    }

    #[test]
    fn test_from_json_invalid_json() {
        // Malformed JSON text must be rejected.
        let outcome = LinearRegression::<f64>::from_json("{invalid json}");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_predict_when_not_fitted() {
        // An unfitted model refuses to predict.
        let outcome = LinearRegression::<f64>::new().predict(5.0);
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_save_invalid_path() {
        // Saving into a directory that does not exist must fail.
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let target = std::path::Path::new("/nonexistent/directory/model.json");
        let outcome = model.save(target);
        assert!(
            outcome.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        // With n = 2 the standard error branch stores exactly zero.
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        assert_eq!(model.standard_error, 0.0);
    }

    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        // With n > 2 the standard error is computed from the residuals.
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        assert!(model.standard_error >= 0.0);
    }

    #[test]
    fn test_confidence_interval_n_less_than_three() {
        // Two points are not enough for a confidence interval.
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let outcome = model.confidence_interval(3.0, 0.95);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_unsupported_level() {
        // 0.85 is not one of the supported confidence levels.
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        let outcome = model.confidence_interval(3.0, 0.85);
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_supported_levels() {
        // Every supported level yields an ordered (lower, upper) pair.
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        [0.90, 0.95, 0.99].into_iter().for_each(|level| {
            let outcome = model.confidence_interval(3.0, level);
            assert!(
                outcome.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = outcome.unwrap();
            assert!(lower <= upper, "Lower bound should be <= upper bound");
        });
    }

    #[test]
    fn test_correlation_coefficient_positive_slope() {
        // Increasing data gives a non-negative correlation.
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r >= 0.0,
            "Correlation should be positive for positive slope"
        );
    }

    #[test]
    fn test_correlation_coefficient_negative_slope() {
        // Decreasing data gives a non-positive correlation.
        let model = fitted(&[1.0, 2.0, 3.0], &[6.0, 4.0, 2.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r <= 0.0,
            "Correlation should be negative for negative slope"
        );
    }

    #[test]
    fn test_correlation_coefficient_not_fitted() {
        // An unfitted model has no correlation coefficient to report.
        let outcome = LinearRegression::<f64>::new().correlation_coefficient();
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        // Batch prediction on an unfitted model fails with NotFitted.
        let outcome = LinearRegression::<f64>::new().predict_many(&[1.0, 2.0, 3.0]);
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_success() {
        // Batch prediction returns one value per input, in order.
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_invalid_json() {
        // A file that exists but holds malformed JSON must fail to load.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");
        std::fs::write(&file_path, "invalid json content").unwrap();

        let outcome = LinearRegression::<f64>::load(&file_path);
        assert!(outcome.is_err(), "Loading invalid JSON should return error");
    }

    #[test]
    fn test_from_json_invalid() {
        // Plain text that is not JSON at all must be rejected.
        let outcome = LinearRegression::<f64>::from_json("not valid json");
        assert!(
            outcome.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }
}