1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7use serde::{Deserialize, Serialize};
8use std::fmt::Debug;
9use std::fs::File;
10use std::io::{self};
11use std::path::Path;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct LinearRegression<T = f64>
16where
17 T: Float + Debug + Default + Serialize,
18{
19 pub slope: T,
21 pub intercept: T,
23 pub r_squared: T,
25 pub standard_error: T,
27 pub n: usize,
29}
30
31impl<T> Default for LinearRegression<T>
32where
33 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl<T> LinearRegression<T>
41where
42 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44 pub fn new() -> Self {
46 Self {
47 slope: T::zero(),
48 intercept: T::zero(),
49 r_squared: T::zero(),
50 standard_error: T::zero(),
51 n: 0,
52 }
53 }
54
55 pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
70 where
71 U: NumCast + Copy,
72 V: NumCast + Copy,
73 {
74 if x_values.len() != y_values.len() {
76 return Err(StatsError::dimension_mismatch(format!(
77 "X and Y arrays must have the same length (got {} and {})",
78 x_values.len(),
79 y_values.len()
80 )));
81 }
82
83 if x_values.is_empty() {
84 return Err(StatsError::empty_data(
85 "Cannot fit regression with empty arrays",
86 ));
87 }
88
89 let n = x_values.len();
90 self.n = n;
91
92 let x_cast: Vec<T> = x_values
94 .iter()
95 .enumerate()
96 .map(|(i, &x)| {
97 T::from(x).ok_or_else(|| {
98 StatsError::conversion_error(format!(
99 "Failed to cast X value at index {} to type T",
100 i
101 ))
102 })
103 })
104 .collect::<StatsResult<Vec<T>>>()?;
105
106 let y_cast: Vec<T> = y_values
107 .iter()
108 .enumerate()
109 .map(|(i, &y)| {
110 T::from(y).ok_or_else(|| {
111 StatsError::conversion_error(format!(
112 "Failed to cast Y value at index {} to type T",
113 i
114 ))
115 })
116 })
117 .collect::<StatsResult<Vec<T>>>()?;
118
119 let n_as_t = T::from(n).ok_or_else(|| {
121 StatsError::conversion_error(format!("Failed to convert {} to type T", n))
122 })?;
123 let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
124 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
125
126 let mut sum_xy = T::zero();
128 let mut sum_xx = T::zero();
129 let mut sum_yy = T::zero();
130
131 for i in 0..n {
132 let x_diff = x_cast[i] - x_mean;
133 let y_diff = y_cast[i] - y_mean;
134
135 sum_xy = sum_xy + (x_diff * y_diff);
136 sum_xx = sum_xx + (x_diff * x_diff);
137 sum_yy = sum_yy + (y_diff * y_diff);
138 }
139
140 if sum_xx == T::zero() {
142 return Err(StatsError::invalid_parameter(
143 "No variance in X values, cannot fit regression line",
144 ));
145 }
146
147 self.slope = sum_xy / sum_xx;
149 self.intercept = y_mean - (self.slope * x_mean);
150
151 self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
153
154 let mut sum_squared_residuals = T::zero();
156 for i in 0..n {
157 let predicted = self.predict_t(x_cast[i]);
158 let residual = y_cast[i] - predicted;
159 sum_squared_residuals = sum_squared_residuals + (residual * residual);
160 }
161
162 if n > 2 {
164 let two = T::from(2)
165 .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
166 let n_minus_two = n_as_t - two;
167 self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
168 } else {
169 self.standard_error = T::zero();
170 }
171
172 Ok(())
173 }
174
175 fn predict_t(&self, x: T) -> T {
177 self.intercept + (self.slope * x)
178 }
179
180 pub fn predict<U>(&self, x: U) -> StatsResult<T>
203 where
204 U: NumCast + Copy,
205 {
206 if self.n == 0 {
207 return Err(StatsError::not_fitted(
208 "Model has not been fitted. Call fit() before predicting.",
209 ));
210 }
211
212 let x_cast: T = T::from(x)
213 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
214
215 Ok(self.predict_t(x_cast))
216 }
217
218 pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
241 where
242 U: NumCast + Copy + Send + Sync,
243 T: Send + Sync,
244 {
245 #[cfg(feature = "parallel")]
246 {
247 x_values.par_iter().map(|&x| self.predict(x)).collect()
248 }
249 #[cfg(not(feature = "parallel"))]
250 {
251 x_values.iter().map(|&x| self.predict(x)).collect()
252 }
253 }
254
255 pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
269 where
270 U: NumCast + Copy,
271 {
272 if self.n < 3 {
273 return Err(StatsError::invalid_input(
274 "Need at least 3 data points to calculate confidence interval",
275 ));
276 }
277
278 let x_cast: T = T::from(x)
279 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
280
281 let z_score: T = match confidence_level {
284 0.90 => T::from(1.645).ok_or_else(|| {
285 StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
286 })?,
287 0.95 => T::from(1.96).ok_or_else(|| {
288 StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
289 })?,
290 0.99 => T::from(2.576).ok_or_else(|| {
291 StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
292 })?,
293 _ => {
294 return Err(StatsError::invalid_parameter(format!(
295 "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
296 confidence_level
297 )));
298 }
299 };
300
301 let predicted = self.predict_t(x_cast);
302 let margin = z_score * self.standard_error;
303
304 Ok((predicted - margin, predicted + margin))
305 }
306
307 pub fn correlation_coefficient(&self) -> StatsResult<T> {
329 if self.n == 0 {
330 return Err(StatsError::not_fitted(
331 "Model has not been fitted. Call fit() before getting correlation coefficient.",
332 ));
333 }
334 let r = self.r_squared.sqrt();
335 Ok(if self.slope >= T::zero() { r } else { -r })
336 }
337
338 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346 let file = File::create(path)?;
347 serde_json::to_writer(file, self).map_err(io::Error::other)
349 }
350
351 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359 let file = File::create(path)?;
360 bincode::serialize_into(file, self).map_err(io::Error::other)
362 }
363
364 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372 let file = File::open(path)?;
373 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375 }
376
377 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385 let file = File::open(path)?;
386 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388 }
389
390 pub fn to_json(&self) -> Result<String, String> {
395 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396 }
397
398 pub fn from_json(json: &str) -> Result<Self, String> {
406 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Fits an `f64` model to the perfectly linear data `y = 2x` for `x = 1..=5`.
    fn fitted_doubling_model() -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();
        model
    }

    /// Asserts that a round-tripped model carries the same fitted statistics.
    fn assert_same_fit(loaded: &LinearRegression<f64>, original: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, original.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, original.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert!(approx_equal(
            loaded.standard_error,
            original.standard_error,
            Some(1e-6)
        ));
        assert_eq!(loaded.n, original.n);
    }

    #[test]
    fn test_simple_regression_f64() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_regression_f32() {
        let x = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32];
        let y = vec![2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32];

        let mut model = LinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    #[test]
    fn test_integer_data() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_mixed_types() {
        let x = vec![1u32, 2u32, 3u32, 4u32, 5u32];
        let y = vec![2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    /// Hand-computed OLS on noisy data:
    /// x = [1, 2, 3, 4], y = [1, 3, 2, 4] gives Sxy = 4, Sxx = 5, Syy = 5,
    /// slope = 0.8, intercept = 0.5, R² = 0.64, SSR = 1.8, SE = sqrt(1.8 / 2).
    #[test]
    fn test_known_values_noisy_data() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0], &[1.0, 3.0, 2.0, 4.0])
            .unwrap();

        assert!(approx_equal(model.slope, 0.8, Some(1e-9)));
        assert!(approx_equal(model.intercept, 0.5, Some(1e-9)));
        assert!(approx_equal(model.r_squared, 0.64, Some(1e-9)));
        assert!(approx_equal(
            model.standard_error,
            0.9_f64.sqrt(),
            Some(1e-9)
        ));
        assert!(approx_equal(
            model.correlation_coefficient().unwrap(),
            0.8,
            Some(1e-9)
        ));
        assert_eq!(model.n, 4);
    }

    #[test]
    fn test_prediction() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.predict(6u32).unwrap(), 12.0, Some(1e-6)));
        assert!(approx_equal(model.predict(0i32).unwrap(), 0.0, Some(1e-6)));
    }

    #[test]
    fn test_invalid_inputs() {
        let x = vec![1, 2, 3];
        let y = vec![2, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
    }

    #[test]
    fn test_empty_inputs() {
        let x: Vec<f64> = Vec::new();
        let y: Vec<f64> = Vec::new();

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_constant_x() {
        let x = vec![1, 1, 1];
        let y = vec![2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_doubling_model();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = LinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_doubling_model();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_doubling_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = LinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());
        assert_same_fit(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_load_nonexistent_file() {
        let result = LinearRegression::<f64>::load("/nonexistent/path/model.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_load_binary_nonexistent_file() {
        let result = LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }

    #[test]
    fn test_from_json_invalid_json() {
        let invalid_json = "{invalid json}";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_when_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict(5.0);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0], &[2.0, 4.0]).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        assert_eq!(model.standard_error, 0.0);
    }

    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 4.0, 6.0];
        model.fit(&x, &y).unwrap();

        // A perfect linear fit has zero residuals, hence zero standard error.
        assert!(model.standard_error.abs() < 1e-10);
    }

    #[test]
    fn test_confidence_interval_n_less_than_three() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let result = model.confidence_interval(3.0, 0.95);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_unsupported_level() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        model.fit(&x, &y).unwrap();

        let result = model.confidence_interval(3.0, 0.85);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    #[test]
    fn test_confidence_interval_supported_levels() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        model.fit(&x, &y).unwrap();

        for level in [0.90, 0.95, 0.99] {
            let result = model.confidence_interval(3.0, level);
            assert!(
                result.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = result.unwrap();
            assert!(lower <= upper, "Lower bound should be <= upper bound");
            // Perfect fit: zero standard error collapses the interval onto y(3) = 6.
            assert!(approx_equal(lower, 6.0, Some(1e-9)));
            assert!(approx_equal(upper, 6.0, Some(1e-9)));
        }
    }

    #[test]
    fn test_confidence_interval_known_values() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0], &[1.0, 3.0, 2.0, 4.0])
            .unwrap();

        let se = 0.9_f64.sqrt();
        let (lower, upper) = model.confidence_interval(5.0, 0.95).unwrap();
        // Predicted value at x = 5 is 0.5 + 0.8 * 5 = 4.5.
        assert!(approx_equal(lower, 4.5 - 1.96 * se, Some(1e-9)));
        assert!(approx_equal(upper, 4.5 + 1.96 * se, Some(1e-9)));

        // Higher confidence must produce a strictly wider interval.
        let width = |level| {
            let (lo, hi) = model.confidence_interval(5.0, level).unwrap();
            hi - lo
        };
        assert!(width(0.90) < width(0.95));
        assert!(width(0.95) < width(0.99));
    }

    #[test]
    fn test_correlation_coefficient_positive_slope() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 4.0, 6.0];
        model.fit(&x, &y).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r >= 0.0,
            "Correlation should be positive for positive slope"
        );
        assert!(approx_equal(r, 1.0, Some(1e-9)));
    }

    #[test]
    fn test_correlation_coefficient_negative_slope() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![6.0, 4.0, 2.0];
        model.fit(&x, &y).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r <= 0.0,
            "Correlation should be negative for negative slope"
        );
        assert!(approx_equal(r, -1.0, Some(1e-9)));
    }

    #[test]
    fn test_correlation_coefficient_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.correlation_coefficient();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict_many(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_load_invalid_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");

        std::fs::write(&file_path, "invalid json content").unwrap();

        let result = LinearRegression::<f64>::load(&file_path);
        assert!(result.is_err(), "Loading invalid JSON should return error");
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }
}