rs_stats/regression/
linear_regression.rs

1// src/regression/linear_regression.rs
2
3use num_traits::{Float, NumCast};
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6use std::fs::File;
7use std::io::{self};
8use std::path::Path;
9
10/// Linear regression model that fits a line to data points.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct LinearRegression<T = f64>
13where
14    T: Float + Debug + Default + Serialize,
15{
16    /// Slope of the regression line (coefficient of x)
17    pub slope: T,
18    /// Y-intercept of the regression line
19    pub intercept: T,
20    /// Coefficient of determination (R²) - goodness of fit
21    pub r_squared: T,
22    /// Standard error of the estimate
23    pub standard_error: T,
24    /// Number of data points used for regression
25    pub n: usize,
26}
27
28impl<T> Default for LinearRegression<T>
29where
30    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
31{
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl<T> LinearRegression<T>
38where
39    T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
40{
41    /// Create a new linear regression model without fitting any data
42    pub fn new() -> Self {
43        Self {
44            slope: T::zero(),
45            intercept: T::zero(),
46            r_squared: T::zero(),
47            standard_error: T::zero(),
48            n: 0,
49        }
50    }
51
52    /// Fit a linear model to the provided x and y data points
53    ///
54    /// # Arguments
55    /// * `x_values` - Independent variable values
56    /// * `y_values` - Dependent variable values (observations)
57    ///
58    /// # Returns
59    /// * `Result<(), String>` - Ok if successful, Err with message if the inputs are invalid
60    pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> Result<(), String>
61    where
62        U: NumCast + Copy,
63        V: NumCast + Copy,
64    {
65        // Validate inputs
66        if x_values.len() != y_values.len() {
67            return Err("X and Y arrays must have the same length".to_string());
68        }
69
70        if x_values.is_empty() {
71            return Err("Cannot fit regression with empty arrays".to_string());
72        }
73
74        let n = x_values.len();
75        self.n = n;
76
77        // Convert input arrays to T type
78        let x_cast: Vec<T> = x_values
79            .iter()
80            .map(|&x| T::from(x).ok_or_else(|| "Failed to cast X value".to_string()))
81            .collect::<Result<Vec<T>, String>>()?;
82
83        let y_cast: Vec<T> = y_values
84            .iter()
85            .map(|&y| T::from(y).ok_or_else(|| "Failed to cast Y value".to_string()))
86            .collect::<Result<Vec<T>, String>>()?;
87
88        // Calculate means
89        let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / T::from(n).unwrap();
90        let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / T::from(n).unwrap();
91
92        // Calculate variance and covariance
93        let mut sum_xy = T::zero();
94        let mut sum_xx = T::zero();
95        let mut sum_yy = T::zero();
96
97        for i in 0..n {
98            let x_diff = x_cast[i] - x_mean;
99            let y_diff = y_cast[i] - y_mean;
100
101            sum_xy = sum_xy + (x_diff * y_diff);
102            sum_xx = sum_xx + (x_diff * x_diff);
103            sum_yy = sum_yy + (y_diff * y_diff);
104        }
105
106        // Check if there's any variance in x
107        if sum_xx == T::zero() {
108            return Err("No variance in X values, cannot fit regression line".to_string());
109        }
110
111        // Calculate slope and intercept
112        self.slope = sum_xy / sum_xx;
113        self.intercept = y_mean - (self.slope * x_mean);
114
115        // Calculate R²
116        self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
117
118        // Calculate residuals and standard error
119        let mut sum_squared_residuals = T::zero();
120        for i in 0..n {
121            let predicted = self.predict_t(x_cast[i]);
122            let residual = y_cast[i] - predicted;
123            sum_squared_residuals = sum_squared_residuals + (residual * residual);
124        }
125
126        // Standard error of the estimate
127        if n > 2 {
128            let two = T::from(2).unwrap();
129            self.standard_error = (sum_squared_residuals / (T::from(n).unwrap() - two)).sqrt();
130        } else {
131            self.standard_error = T::zero();
132        }
133
134        Ok(())
135    }
136
137    /// Predict y value for a given x using the fitted model (internal version with type T)
138    fn predict_t(&self, x: T) -> T {
139        self.intercept + (self.slope * x)
140    }
141
142    /// Predict y value for a given x using the fitted model
143    ///
144    /// # Arguments
145    /// * `x` - The x value to predict for
146    ///
147    /// # Returns
148    /// * The predicted y value
149    pub fn predict<U>(&self, x: U) -> T
150    where
151        U: NumCast + Copy,
152    {
153        let x_cast: T = match T::from(x) {
154            Some(val) => val,
155            None => return T::nan(),
156        };
157
158        self.predict_t(x_cast)
159    }
160
161    /// Calculate predictions for multiple x values
162    ///
163    /// # Arguments
164    /// * `x_values` - Slice of x values to predict for
165    ///
166    /// # Returns
167    /// * Vector of predicted y values
168    pub fn predict_many<U>(&self, x_values: &[U]) -> Vec<T>
169    where
170        U: NumCast + Copy,
171    {
172        x_values.iter().map(|&x| self.predict(x)).collect()
173    }
174
175    /// Calculate confidence intervals for the regression line
176    ///
177    /// # Arguments
178    /// * `x` - The x value to calculate confidence interval for
179    /// * `confidence_level` - Confidence level (0.95 for 95% confidence)
180    ///
181    /// # Returns
182    /// * `Option<(T, T)>` - Tuple of (lower_bound, upper_bound) or None if not enough data
183    pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> Option<(T, T)>
184    where
185        U: NumCast + Copy,
186    {
187        if self.n < 3 {
188            return None;
189        }
190
191        let x_cast: T = T::from(x)?;
192
193        // Get the t-critical value based on degrees of freedom and confidence level
194        // For simplicity, we'll use a normal approximation with standard errors
195        let z_score: T = match confidence_level {
196            0.90 => T::from(1.645).unwrap(),
197            0.95 => T::from(1.96).unwrap(),
198            0.99 => T::from(2.576).unwrap(),
199            _ => return None, // Only supporting common confidence levels for simplicity
200        };
201
202        let predicted = self.predict_t(x_cast);
203        let margin = z_score * self.standard_error;
204
205        Some((predicted - margin, predicted + margin))
206    }
207
208    /// Get the correlation coefficient (r)
209    pub fn correlation_coefficient(&self) -> T {
210        let r = self.r_squared.sqrt();
211        if self.slope >= T::zero() { r } else { -r }
212    }
213
214    /// Save the model to a file
215    ///
216    /// # Arguments
217    /// * `path` - Path where to save the model
218    ///
219    /// # Returns
220    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
221    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
222        let file = File::create(path)?;
223        // Use JSON format for human-readability
224        serde_json::to_writer(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
225    }
226
227    /// Save the model in binary format
228    ///
229    /// # Arguments
230    /// * `path` - Path where to save the model
231    ///
232    /// # Returns
233    /// * `Result<(), io::Error>` - Ok if successful, Err with IO error if saving fails
234    pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
235        let file = File::create(path)?;
236        // Use bincode for more compact binary format
237        bincode::serialize_into(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
238    }
239
240    /// Load a model from a file
241    ///
242    /// # Arguments
243    /// * `path` - Path to the saved model file
244    ///
245    /// # Returns
246    /// * `Result<Self, io::Error>` - Loaded model or IO error
247    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
248        let file = File::open(path)?;
249        // Try to load as JSON format
250        serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
251    }
252
253    /// Load a model from a binary file
254    ///
255    /// # Arguments
256    /// * `path` - Path to the saved model file
257    ///
258    /// # Returns
259    /// * `Result<Self, io::Error>` - Loaded model or IO error
260    pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
261        let file = File::open(path)?;
262        // Try to load as bincode format
263        bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
264    }
265
266    /// Save the model to a string in JSON format
267    ///
268    /// # Returns
269    /// * `Result<String, String>` - JSON string representation or error message
270    pub fn to_json(&self) -> Result<String, String> {
271        serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
272    }
273
274    /// Load a model from a JSON string
275    ///
276    /// # Arguments
277    /// * `json` - JSON string containing the model data
278    ///
279    /// # Returns
280    /// * `Result<Self, String>` - Loaded model or error message
281    pub fn from_json(json: &str) -> Result<Self, String> {
282        serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::numeric::approx_equal;
    use tempfile::tempdir;

    /// Fit a fresh `f64` model to the exact line `y = 2x` over `x = 1..=5`.
    fn fitted_doubling_model() -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();
        model
    }

    /// Assert that a round-tripped model carries the same parameters as its source.
    fn assert_same_parameters(loaded: &LinearRegression<f64>, model: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, model.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, model.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, model.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, model.n);
    }

    /// Exact `y = 2x` data in `f64` recovers slope 2, intercept 0 and R² 1.
    #[test]
    fn test_simple_regression_f64() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    /// Same exact data, with the model and inputs in `f32`.
    #[test]
    fn test_simple_regression_f32() {
        let xs = [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32];
        let ys = [2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32];

        let mut model = LinearRegression::<f32>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    /// Integer inputs are converted to the model's float type.
    #[test]
    fn test_integer_data() {
        let xs = [1, 2, 3, 4, 5];
        let ys = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    /// X and Y may use different numeric types; noisy data lands near `y = 2x`.
    #[test]
    fn test_mixed_types() {
        let xs = [1u32, 2u32, 3u32, 4u32, 5u32];
        let ys = [2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    /// Predictions accept any castable numeric type.
    #[test]
    fn test_prediction() {
        let xs = [1, 2, 3, 4, 5];
        let ys = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&xs, &ys).unwrap();

        assert!(approx_equal(model.predict(6u32), 12.0, Some(1e-6)));
        assert!(approx_equal(model.predict(0i32), 0.0, Some(1e-6)));
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn test_invalid_inputs() {
        let xs = [1, 2, 3];
        let ys = [2, 4];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_err());
    }

    /// Identical X values have no variance, so fitting fails.
    #[test]
    fn test_constant_x() {
        let xs = [1, 1, 1];
        let ys = [2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_err());
    }

    /// A model saved as JSON loads back with the same parameters.
    #[test]
    fn test_save_load_json() {
        // The temporary directory is removed when `dir` is dropped.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");

        let model = fitted_doubling_model();
        assert!(model.save(&file_path).is_ok());

        let loaded_model = LinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_parameters(&loaded_model.unwrap(), &model);
    }

    /// A model saved in binary format loads back with the same parameters.
    #[test]
    fn test_save_load_binary() {
        // The temporary directory is removed when `dir` is dropped.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");

        let model = fitted_doubling_model();
        assert!(model.save_binary(&file_path).is_ok());

        let loaded_model = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_parameters(&loaded_model.unwrap(), &model);
    }

    /// Serializing to a JSON string and back preserves the parameters.
    #[test]
    fn test_json_serialization() {
        let model = fitted_doubling_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = LinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());

        assert_same_parameters(&loaded_model.unwrap(), &model);
    }
}