Skip to main content

rs_stats/regression/
linear_regression.rs

1// src/regression/linear_regression.rs
2
3use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::fs::File;
10use std::io::{self};
11use std::path::Path;
12
13/// Linear regression model that fits a line to data points.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LinearRegression<T = f64>
16where
17    T: Float + Debug + Default + Serialize,
18{
19    /// Slope of the regression line (coefficient of x)
20    pub slope: T,
21    /// Y-intercept of the regression line
22    pub intercept: T,
23    /// Coefficient of determination (R²) - goodness of fit
24    pub r_squared: T,
25    /// Standard error of the estimate
26    pub standard_error: T,
27    /// Number of data points used for regression
28    pub n: usize,
29}
30
31impl<T> Default for LinearRegression<T>
32where
33    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl<T> LinearRegression<T>
41where
42    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44    /// Create a new linear regression model without fitting any data
45    pub fn new() -> Self {
46        Self {
47            slope: T::zero(),
48            intercept: T::zero(),
49            r_squared: T::zero(),
50            standard_error: T::zero(),
51            n: 0,
52        }
53    }
54
55    /// Fit a linear model to the provided x and y data points
56    ///
57    /// # Arguments
58    /// * `x_values` - Independent variable values
59    /// * `y_values` - Dependent variable values (observations)
60    ///
61    /// # Returns
62    /// * `StatsResult<()>` - Ok if successful, Err with StatsError if the inputs are invalid
63    ///
64    /// # Errors
65    /// Returns `StatsError::DimensionMismatch` if X and Y arrays have different lengths.
66    /// Returns `StatsError::EmptyData` if the input arrays are empty.
67    /// Returns `StatsError::ConversionError` if value conversion fails.
68    /// Returns `StatsError::InvalidParameter` if there's no variance in X values.
69    pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
70    where
71        U: NumCast + Copy,
72        V: NumCast + Copy,
73    {
74        // Validate inputs
75        if x_values.len() != y_values.len() {
76            return Err(StatsError::dimension_mismatch(format!(
77                "X and Y arrays must have the same length (got {} and {})",
78                x_values.len(),
79                y_values.len()
80            )));
81        }
82
83        if x_values.is_empty() {
84            return Err(StatsError::empty_data(
85                "Cannot fit regression with empty arrays",
86            ));
87        }
88
89        let n = x_values.len();
90        self.n = n;
91
92        // Convert input arrays to T type
93        let x_cast: Vec<T> = x_values
94            .iter()
95            .enumerate()
96            .map(|(i, &x)| {
97                T::from(x).ok_or_else(|| {
98                    StatsError::conversion_error(format!(
99                        "Failed to cast X value at index {} to type T",
100                        i
101                    ))
102                })
103            })
104            .collect::<StatsResult<Vec<T>>>()?;
105
106        let y_cast: Vec<T> = y_values
107            .iter()
108            .enumerate()
109            .map(|(i, &y)| {
110                T::from(y).ok_or_else(|| {
111                    StatsError::conversion_error(format!(
112                        "Failed to cast Y value at index {} to type T",
113                        i
114                    ))
115                })
116            })
117            .collect::<StatsResult<Vec<T>>>()?;
118
119        // Calculate means
120        let n_as_t = T::from(n).ok_or_else(|| {
121            StatsError::conversion_error(format!("Failed to convert {} to type T", n))
122        })?;
123        let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
124        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
125
126        // Calculate variance and covariance
127        let mut sum_xy = T::zero();
128        let mut sum_xx = T::zero();
129        let mut sum_yy = T::zero();
130
131        for i in 0..n {
132            let x_diff = x_cast[i] - x_mean;
133            let y_diff = y_cast[i] - y_mean;
134
135            sum_xy = sum_xy + (x_diff * y_diff);
136            sum_xx = sum_xx + (x_diff * x_diff);
137            sum_yy = sum_yy + (y_diff * y_diff);
138        }
139
140        // Check if there's any variance in x
141        if sum_xx == T::zero() {
142            return Err(StatsError::invalid_parameter(
143                "No variance in X values, cannot fit regression line",
144            ));
145        }
146
147        // Calculate slope and intercept
148        self.slope = sum_xy / sum_xx;
149        self.intercept = y_mean - (self.slope * x_mean);
150
151        // Calculate R²
152        self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
153
154        // Calculate residuals and standard error
155        let mut sum_squared_residuals = T::zero();
156        for i in 0..n {
157            let predicted = self.predict_t(x_cast[i]);
158            let residual = y_cast[i] - predicted;
159            sum_squared_residuals = sum_squared_residuals + (residual * residual);
160        }
161
162        // Standard error of the estimate
163        if n > 2 {
164            let two = T::from(2)
165                .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
166            let n_minus_two = n_as_t - two;
167            self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
168        } else {
169            self.standard_error = T::zero();
170        }
171
172        Ok(())
173    }
174
175    /// Predict y value for a given x using the fitted model (internal version with type T)
176    fn predict_t(&self, x: T) -> T {
177        self.intercept + (self.slope * x)
178    }
179
180    /// Predict y value for a given x using the fitted model
181    ///
182    /// # Arguments
183    /// * `x` - The x value to predict for
184    ///
185    /// # Returns
186    /// * `StatsResult<T>` - The predicted y value
187    ///
188    /// # Errors
189    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
190    /// Returns `StatsError::ConversionError` if type conversion fails.
191    ///
192    /// # Examples
193    /// ```
194    /// use rs_stats::regression::linear_regression::LinearRegression;
195    ///
196    /// let mut model = LinearRegression::<f64>::new();
197    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
198    ///
199    /// let prediction = model.predict(4.0).unwrap();
200    /// assert!((prediction - 8.0).abs() < 1e-10);
201    /// ```
202    pub fn predict<U>(&self, x: U) -> StatsResult<T>
203    where
204        U: NumCast + Copy,
205    {
206        if self.n == 0 {
207            return Err(StatsError::not_fitted(
208                "Model has not been fitted. Call fit() before predicting.",
209            ));
210        }
211
212        let x_cast: T = T::from(x)
213            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
214
215        Ok(self.predict_t(x_cast))
216    }
217
218    /// Calculate predictions for multiple x values
219    ///
220    /// # Arguments
221    /// * `x_values` - Slice of x values to predict for
222    ///
223    /// # Returns
224    /// * `StatsResult<Vec<T>>` - Vector of predicted y values
225    ///
226    /// # Errors
227    /// Returns `StatsError::NotFitted` if the model has not been fitted.
228    /// Returns `StatsError::ConversionError` if type conversion fails for any value.
229    ///
230    /// # Examples
231    /// ```
232    /// use rs_stats::regression::linear_regression::LinearRegression;
233    ///
234    /// let mut model = LinearRegression::<f64>::new();
235    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
236    ///
237    /// let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
238    /// assert_eq!(predictions.len(), 2);
239    /// ```
240    pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
241    where
242        U: NumCast + Copy + Send + Sync,
243        T: Send + Sync,
244    {
245        #[cfg(feature = "parallel")]
246        {
247            x_values.par_iter().map(|&x| self.predict(x)).collect()
248        }
249        #[cfg(not(feature = "parallel"))]
250        {
251            x_values.iter().map(|&x| self.predict(x)).collect()
252        }
253    }
254
255    /// Calculate confidence intervals for the regression line
256    ///
257    /// # Arguments
258    /// * `x` - The x value to calculate confidence interval for
259    /// * `confidence_level` - Confidence level (0.95 for 95% confidence)
260    ///
261    /// # Returns
262    /// * `StatsResult<(T, T)>` - Tuple of (lower_bound, upper_bound), or an error if invalid
263    ///
264    /// # Errors
265    /// Returns `StatsError::InvalidInput` if there are fewer than 3 data points.
266    /// Returns `StatsError::InvalidParameter` if confidence level is not supported (only 0.90, 0.95, 0.99).
267    /// Returns `StatsError::ConversionError` if value conversion fails.
268    pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
269    where
270        U: NumCast + Copy,
271    {
272        if self.n < 3 {
273            return Err(StatsError::invalid_input(
274                "Need at least 3 data points to calculate confidence interval",
275            ));
276        }
277
278        let x_cast: T = T::from(x)
279            .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
280
281        // Get the t-critical value based on degrees of freedom and confidence level
282        // For simplicity, we'll use a normal approximation with standard errors
283        let z_score: T = match confidence_level {
284            0.90 => T::from(1.645).ok_or_else(|| {
285                StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
286            })?,
287            0.95 => T::from(1.96).ok_or_else(|| {
288                StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
289            })?,
290            0.99 => T::from(2.576).ok_or_else(|| {
291                StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
292            })?,
293            _ => {
294                return Err(StatsError::invalid_parameter(format!(
295                    "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
296                    confidence_level
297                )));
298            }
299        };
300
301        let predicted = self.predict_t(x_cast);
302        let margin = z_score * self.standard_error;
303
304        Ok((predicted - margin, predicted + margin))
305    }
306
307    /// Get the correlation coefficient (r)
308    ///
309    /// The correlation coefficient ranges from -1 to 1, indicating the strength
310    /// and direction of the linear relationship between x and y.
311    ///
312    /// # Returns
313    /// * `StatsResult<T>` - The correlation coefficient
314    ///
315    /// # Errors
316    /// Returns `StatsError::NotFitted` if the model has not been fitted (n == 0).
317    ///
318    /// # Examples
319    /// ```
320    /// use rs_stats::regression::linear_regression::LinearRegression;
321    ///
322    /// let mut model = LinearRegression::<f64>::new();
323    /// model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
324    ///
325    /// let r = model.correlation_coefficient().unwrap();
326    /// assert!((r - 1.0).abs() < 1e-10); // Perfect positive correlation
327    /// ```
328    pub fn correlation_coefficient(&self) -> StatsResult<T> {
329        if self.n == 0 {
330            return Err(StatsError::not_fitted(
331                "Model has not been fitted. Call fit() before getting correlation coefficient.",
332            ));
333        }
334        let r = self.r_squared.sqrt();
335        Ok(if self.slope >= T::zero() { r } else { -r })
336    }
337
338    /// Save the model to a file
339    ///
340    /// # Arguments
341    /// * `path` - Path where to save the model
342    ///
343    /// # Returns
344    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
345    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346        let file = File::create(path)?;
347        // Use JSON format for human-readability
348        serde_json::to_writer(file, self).map_err(io::Error::other)
349    }
350
351    /// Save the model in binary format
352    ///
353    /// # Arguments
354    /// * `path` - Path where to save the model
355    ///
356    /// # Returns
357    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
358    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359        let file = File::create(path)?;
360        // Use bincode for more compact binary format
361        bincode::serialize_into(file, self).map_err(io::Error::other)
362    }
363
364    /// Load a model from a file
365    ///
366    /// # Arguments
367    /// * `path` - Path to the saved model file
368    ///
369    /// # Returns
370    /// * `Result<Self, io::Error>` - Loaded model or IO error
371    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372        let file = File::open(path)?;
373        // Try to load as JSON format
374        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375    }
376
377    /// Load a model from a binary file
378    ///
379    /// # Arguments
380    /// * `path` - Path to the saved model file
381    ///
382    /// # Returns
383    /// * `Result<Self, io::Error>` - Loaded model or IO error
384    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385        let file = File::open(path)?;
386        // Try to load as bincode format
387        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388    }
389
390    /// Save the model to a string in JSON format
391    ///
392    /// # Returns
393    /// * `Result<String, String>` - JSON string representation or error message
394    pub fn to_json(&self) -> Result<String, String> {
395        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396    }
397
398    /// Load a model from a JSON string
399    ///
400    /// # Arguments
401    /// * `json` - JSON string containing the model data
402    ///
403    /// # Returns
404    /// * `Result<Self, String>` - Loaded model or error message
405    pub fn from_json(json: &str) -> Result<Self, String> {
406        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407    }
408}
409
#[cfg(test)]
mod tests {
    // Unit tests for LinearRegression: fitting, prediction, intervals and persistence.
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Absolute-tolerance comparison for results that are known in closed form.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-9
    }

    /// Model fitted on the exact line `y = 2x` for `x = 1..=5`.
    fn fitted_doubling_model() -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();
        model
    }

    /// Model fitted on noisy data with hand-computed statistics:
    /// Sxx = 5, Sxy = 4, Syy = 5, so slope = 0.8, intercept = 0.5,
    /// R² = 0.64 and the residual sum of squares is 1.8, giving a standard
    /// error of sqrt(1.8 / 2) = sqrt(0.9).
    fn fitted_noisy_model() -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0], &[1.0, 3.0, 2.0, 4.0])
            .unwrap();
        model
    }

    /// Assert that every persisted field survived a save/load round trip.
    fn assert_same_model(loaded: &LinearRegression<f64>, original: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, original.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, original.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert!(approx_equal(
            loaded.standard_error,
            original.standard_error,
            Some(1e-6)
        ));
        assert_eq!(loaded.n, original.n);
    }

    #[test]
    fn test_simple_regression_f64() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
        assert_eq!(model.n, 5);
    }

    #[test]
    fn test_simple_regression_f32() {
        let x = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32];
        let y = vec![2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32];

        let mut model = LinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    #[test]
    fn test_integer_data() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_mixed_types() {
        let x = vec![1u32, 2u32, 3u32, 4u32, 5u32];
        let y = vec![2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    #[test]
    fn test_noisy_data_statistics() {
        let model = fitted_noisy_model();

        assert!(close(model.slope, 0.8));
        assert!(close(model.intercept, 0.5));
        assert!(close(model.r_squared, 0.64));
        assert!(close(model.standard_error, 0.9f64.sqrt()));
        assert!(close(model.correlation_coefficient().unwrap(), 0.8));
        assert_eq!(model.n, 4);
    }

    #[test]
    fn test_prediction() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.predict(6u32).unwrap(), 12.0, Some(1e-6)));
        assert!(approx_equal(model.predict(0i32).unwrap(), 0.0, Some(1e-6)));
    }

    #[test]
    fn test_invalid_inputs() {
        let x = vec![1, 2, 3];
        let y = vec![2, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_empty_inputs() {
        let empty: [f64; 0] = [];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&empty, &empty);

        assert!(matches!(result.unwrap_err(), StatsError::EmptyData { .. }));
    }

    #[test]
    fn test_constant_x() {
        let x = vec![1, 1, 1];
        let y = vec![2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_doubling_model();

        assert!(model.save(&file_path).is_ok());

        let loaded = LinearRegression::<f64>::load(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_doubling_model();

        assert!(model.save_binary(&file_path).is_ok());

        let loaded = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_doubling_model();

        let json_str = model.to_json();
        assert!(json_str.is_ok());

        let loaded = LinearRegression::<f64>::from_json(&json_str.unwrap());
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_load_nonexistent_file() {
        let result = LinearRegression::<f64>::load("/nonexistent/path/model.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_binary_nonexistent_file() {
        let result = LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_json_invalid_json() {
        let invalid_json = "{invalid json}";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_json_missing_fields() {
        // Syntactically valid JSON that lacks required fields must be rejected.
        let partial_json = r#"{"slope": 1.0}"#;
        let result = LinearRegression::<f64>::from_json(partial_json);
        assert!(
            result.is_err(),
            "Deserializing JSON with missing fields should return error"
        );
    }

    #[test]
    fn test_predict_when_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict(5.0);
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0], &[2.0, 4.0]).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        // With n <= 2 there are no residual degrees of freedom, so the error is reported as 0.
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0], &[2.0, 4.0]).unwrap();

        assert_eq!(model.standard_error, 0.0);
    }

    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        // A perfect fit has zero residuals and hence zero standard error.
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();
        assert!(close(model.standard_error, 0.0));

        // Noisy data: sqrt(SSR / (n - 2)) = sqrt(1.8 / 2).
        let noisy = fitted_noisy_model();
        assert!(close(noisy.standard_error, 0.9f64.sqrt()));
    }

    #[test]
    fn test_confidence_interval_n_less_than_three() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0], &[2.0, 4.0]).unwrap();

        let result = model.confidence_interval(3.0, 0.95);
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_unsupported_level() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0])
            .unwrap();

        let result = model.confidence_interval(3.0, 0.85);
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_supported_levels() {
        // Noisy data gives a non-zero standard error, so the band has real width.
        let model = fitted_noisy_model();
        let predicted = model.predict(3.0).unwrap();
        assert!(close(predicted, 2.9));

        for (level, z) in [(0.90, 1.645), (0.95, 1.96), (0.99, 2.576)] {
            let result = model.confidence_interval(3.0, level);
            assert!(
                result.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = result.unwrap();
            let margin = z * model.standard_error;
            assert!(close(lower, predicted - margin), "lower bound at level {}", level);
            assert!(close(upper, predicted + margin), "upper bound at level {}", level);
        }
    }

    #[test]
    fn test_correlation_coefficient_positive_slope() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(close(r, 1.0), "Perfect positive relation should give r = 1");
    }

    #[test]
    fn test_correlation_coefficient_negative_slope() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[6.0, 4.0, 2.0]).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(close(r, -1.0), "Perfect negative relation should give r = -1");
    }

    #[test]
    fn test_correlation_coefficient_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.correlation_coefficient();
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict_many(&[1.0, 2.0, 3.0]);
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_invalid_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");
        std::fs::write(&file_path, "invalid json content").unwrap();

        let result = LinearRegression::<f64>::load(&file_path);
        assert!(result.is_err(), "Loading invalid JSON should return error");
    }
}