1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
/// Parameters and summary statistics of a simple (one-predictor) linear
/// regression `y = intercept + slope * x`.
///
/// A freshly constructed model has every numeric field set to zero and
/// `n == 0`; the fields are populated by `fit`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression<T = f64>
where
    T: Float + Debug + Default + Serialize,
{
    /// Slope of the fitted line.
    pub slope: T,
    /// Value of the fitted line at `x = 0`.
    pub intercept: T,
    /// Coefficient of determination computed by `fit`.
    pub r_squared: T,
    /// Residual standard error computed by `fit`; zero when the model was
    /// fitted on two or fewer points.
    pub standard_error: T,
    /// Number of data points used for fitting; `0` marks an unfitted model.
    pub n: usize,
}
28
29impl<T> Default for LinearRegression<T>
30where
31 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
32{
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl<T> LinearRegression<T>
39where
40 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
41{
42 pub fn new() -> Self {
44 Self {
45 slope: T::zero(),
46 intercept: T::zero(),
47 r_squared: T::zero(),
48 standard_error: T::zero(),
49 n: 0,
50 }
51 }
52
53 pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
68 where
69 U: NumCast + Copy,
70 V: NumCast + Copy,
71 {
72 if x_values.len() != y_values.len() {
74 return Err(StatsError::dimension_mismatch(format!(
75 "X and Y arrays must have the same length (got {} and {})",
76 x_values.len(),
77 y_values.len()
78 )));
79 }
80
81 if x_values.is_empty() {
82 return Err(StatsError::empty_data(
83 "Cannot fit regression with empty arrays",
84 ));
85 }
86
87 let n = x_values.len();
88 self.n = n;
89
90 let x_cast: Vec<T> = x_values
92 .iter()
93 .enumerate()
94 .map(|(i, &x)| {
95 T::from(x).ok_or_else(|| {
96 StatsError::conversion_error(format!(
97 "Failed to cast X value at index {} to type T",
98 i
99 ))
100 })
101 })
102 .collect::<StatsResult<Vec<T>>>()?;
103
104 let y_cast: Vec<T> = y_values
105 .iter()
106 .enumerate()
107 .map(|(i, &y)| {
108 T::from(y).ok_or_else(|| {
109 StatsError::conversion_error(format!(
110 "Failed to cast Y value at index {} to type T",
111 i
112 ))
113 })
114 })
115 .collect::<StatsResult<Vec<T>>>()?;
116
117 let n_as_t = T::from(n).ok_or_else(|| {
119 StatsError::conversion_error(format!("Failed to convert {} to type T", n))
120 })?;
121 let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
122 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
123
124 let mut sum_xy = T::zero();
126 let mut sum_xx = T::zero();
127 let mut sum_yy = T::zero();
128
129 for i in 0..n {
130 let x_diff = x_cast[i] - x_mean;
131 let y_diff = y_cast[i] - y_mean;
132
133 sum_xy = sum_xy + (x_diff * y_diff);
134 sum_xx = sum_xx + (x_diff * x_diff);
135 sum_yy = sum_yy + (y_diff * y_diff);
136 }
137
138 if sum_xx == T::zero() {
140 return Err(StatsError::invalid_parameter(
141 "No variance in X values, cannot fit regression line",
142 ));
143 }
144
145 self.slope = sum_xy / sum_xx;
147 self.intercept = y_mean - (self.slope * x_mean);
148
149 self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
151
152 let mut sum_squared_residuals = T::zero();
154 for i in 0..n {
155 let predicted = self.predict_t(x_cast[i]);
156 let residual = y_cast[i] - predicted;
157 sum_squared_residuals = sum_squared_residuals + (residual * residual);
158 }
159
160 if n > 2 {
162 let two = T::from(2)
163 .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
164 let n_minus_two = n_as_t - two;
165 self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
166 } else {
167 self.standard_error = T::zero();
168 }
169
170 Ok(())
171 }
172
173 fn predict_t(&self, x: T) -> T {
175 self.intercept + (self.slope * x)
176 }
177
178 pub fn predict<U>(&self, x: U) -> StatsResult<T>
201 where
202 U: NumCast + Copy,
203 {
204 if self.n == 0 {
205 return Err(StatsError::not_fitted(
206 "Model has not been fitted. Call fit() before predicting.",
207 ));
208 }
209
210 let x_cast: T = T::from(x)
211 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
212
213 Ok(self.predict_t(x_cast))
214 }
215
216 pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
239 where
240 U: NumCast + Copy,
241 {
242 x_values.iter().map(|&x| self.predict(x)).collect()
243 }
244
245 pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
259 where
260 U: NumCast + Copy,
261 {
262 if self.n < 3 {
263 return Err(StatsError::invalid_input(
264 "Need at least 3 data points to calculate confidence interval",
265 ));
266 }
267
268 let x_cast: T = T::from(x)
269 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
270
271 let z_score: T = match confidence_level {
274 0.90 => T::from(1.645).ok_or_else(|| {
275 StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
276 })?,
277 0.95 => T::from(1.96).ok_or_else(|| {
278 StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
279 })?,
280 0.99 => T::from(2.576).ok_or_else(|| {
281 StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
282 })?,
283 _ => {
284 return Err(StatsError::invalid_parameter(format!(
285 "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
286 confidence_level
287 )));
288 }
289 };
290
291 let predicted = self.predict_t(x_cast);
292 let margin = z_score * self.standard_error;
293
294 Ok((predicted - margin, predicted + margin))
295 }
296
297 pub fn correlation_coefficient(&self) -> StatsResult<T> {
319 if self.n == 0 {
320 return Err(StatsError::not_fitted(
321 "Model has not been fitted. Call fit() before getting correlation coefficient.",
322 ));
323 }
324 let r = self.r_squared.sqrt();
325 Ok(if self.slope >= T::zero() { r } else { -r })
326 }
327
328 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
336 let file = File::create(path)?;
337 serde_json::to_writer(file, self).map_err(io::Error::other)
339 }
340
341 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
349 let file = File::create(path)?;
350 bincode::serialize_into(file, self).map_err(io::Error::other)
352 }
353
354 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
362 let file = File::open(path)?;
363 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
365 }
366
367 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
375 let file = File::open(path)?;
376 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
378 }
379
380 pub fn to_json(&self) -> Result<String, String> {
385 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
386 }
387
388 pub fn from_json(json: &str) -> Result<Self, String> {
396 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
397 }
398}
399
/// Unit tests for `LinearRegression`: fitting, prediction, statistics,
/// error paths, and JSON/binary persistence.
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `approx_equal` is defined elsewhere; presumably it compares
    // two floats within the optional tolerance — confirm against crate::utils.
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Exact line y = 2x with f64 data: slope 2, intercept 0, R² 1.
    #[test]
    fn test_simple_regression_f64() {
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    /// Same exact line as above, with an f32 model and f32 inputs.
    #[test]
    fn test_simple_regression_f32() {
        let x = vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32];
        let y = vec![2.0f32, 4.0f32, 6.0f32, 8.0f32, 10.0f32];

        let mut model = LinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    /// Integer inputs are cast to the model's float type before fitting.
    #[test]
    fn test_integer_data() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    /// X and Y may have different element types (u32 vs. f64); the data is
    /// noisy, so only loose bounds around y = 2x are checked.
    #[test]
    fn test_mixed_types() {
        let x = vec![1u32, 2u32, 3u32, 4u32, 5u32];
        let y = vec![2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    /// Predictions accept input types other than the ones used for fitting.
    #[test]
    fn test_prediction() {
        let x = vec![1, 2, 3, 4, 5];
        let y = vec![2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.predict(6u32).unwrap(), 12.0, Some(1e-6)));
        assert!(approx_equal(model.predict(0i32).unwrap(), 0.0, Some(1e-6)));
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn test_invalid_inputs() {
        let x = vec![1, 2, 3];
        let y = vec![2, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
    }

    /// Identical X values give zero X variance, so fitting must fail.
    #[test]
    fn test_constant_x() {
        let x = vec![1, 1, 1];
        let y = vec![2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
    }

    /// A fitted model survives a save/load round trip through a JSON file.
    #[test]
    fn test_save_load_json() {
        // Temporary directory is removed when `dir` is dropped.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");

        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = LinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());
        let loaded = loaded_model.unwrap();

        assert!(approx_equal(loaded.slope, model.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, model.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, model.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, model.n);
    }

    /// A fitted model survives a save/load round trip through a binary file.
    #[test]
    fn test_save_load_binary() {
        // Temporary directory is removed when `dir` is dropped.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");

        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());
        let loaded = loaded_model.unwrap();

        assert!(approx_equal(loaded.slope, model.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, model.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, model.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, model.n);
    }

    /// A fitted model survives an in-memory JSON string round trip.
    #[test]
    fn test_json_serialization() {
        let mut model = LinearRegression::<f64>::new();
        model
            .fit(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
            .unwrap();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = LinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());
        let loaded = loaded_model.unwrap();

        assert!(approx_equal(loaded.slope, model.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, model.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, model.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, model.n);
    }

    /// Loading JSON from a missing file is an error.
    /// Assumes `/nonexistent` does not exist on the test machine.
    #[test]
    fn test_load_nonexistent_file() {
        let result = LinearRegression::<f64>::load("/nonexistent/path/model.json");
        assert!(result.is_err());
    }

    /// Loading binary data from a missing file is an error.
    /// Assumes `/nonexistent` does not exist on the test machine.
    #[test]
    fn test_load_binary_nonexistent_file() {
        let result = LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin");
        assert!(result.is_err());
    }

    /// Malformed JSON text is rejected by `from_json`.
    #[test]
    fn test_from_json_invalid_json() {
        let invalid_json = "{invalid json}";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(result.is_err());
    }

    /// Predicting with a never-fitted model yields `StatsError::NotFitted`.
    #[test]
    fn test_predict_when_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict(5.0);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    /// Saving into a directory that does not exist is an error.
    /// Assumes `/nonexistent` does not exist on the test machine.
    #[test]
    fn test_save_invalid_path() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0], &[2.0, 4.0]).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    /// With only two points the standard error is reported as exactly zero.
    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        assert_eq!(model.standard_error, 0.0);
    }

    /// With three points the standard error is computed; only its
    /// non-negativity is checked here.
    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 4.0, 6.0];
        model.fit(&x, &y).unwrap();

        assert!(model.standard_error >= 0.0);
    }

    /// A model fitted on two points cannot produce a confidence interval and
    /// reports `StatsError::InvalidInput`.
    #[test]
    fn test_confidence_interval_n_less_than_three() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let result = model.confidence_interval(3.0, 0.95);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidInput { .. }
        ));
    }

    /// A confidence level outside the supported set (here 0.85) yields
    /// `StatsError::InvalidParameter`.
    #[test]
    fn test_confidence_interval_unsupported_level() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        model.fit(&x, &y).unwrap();

        let result = model.confidence_interval(3.0, 0.85);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::InvalidParameter { .. }
        ));
    }

    /// Each supported level (0.90, 0.95, 0.99) produces an ordered interval.
    #[test]
    fn test_confidence_interval_supported_levels() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = vec![2.0, 4.0, 6.0, 8.0];
        model.fit(&x, &y).unwrap();

        for level in [0.90, 0.95, 0.99] {
            let result = model.confidence_interval(3.0, level);
            assert!(
                result.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = result.unwrap();
            assert!(lower <= upper, "Lower bound should be <= upper bound");
        }
    }

    /// An increasing relationship gives a non-negative correlation.
    #[test]
    fn test_correlation_coefficient_positive_slope() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![2.0, 4.0, 6.0];
        model.fit(&x, &y).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r >= 0.0,
            "Correlation should be positive for positive slope"
        );
    }

    /// A decreasing relationship gives a non-positive correlation.
    #[test]
    fn test_correlation_coefficient_negative_slope() {
        let mut model = LinearRegression::<f64>::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![6.0, 4.0, 2.0];
        model.fit(&x, &y).unwrap();

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r <= 0.0,
            "Correlation should be negative for negative slope"
        );
    }

    /// Requesting the correlation of a never-fitted model yields
    /// `StatsError::NotFitted`.
    #[test]
    fn test_correlation_coefficient_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.correlation_coefficient();
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    /// Batch prediction on a never-fitted model yields `StatsError::NotFitted`.
    #[test]
    fn test_predict_many_not_fitted() {
        let model = LinearRegression::<f64>::new();
        let result = model.predict_many(&[1.0, 2.0, 3.0]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    /// Batch prediction returns one value per input, in order.
    #[test]
    fn test_predict_many_success() {
        let mut model = LinearRegression::<f64>::new();
        model.fit(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]).unwrap();

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    /// A file that exists but does not contain valid JSON fails to load.
    #[test]
    fn test_load_invalid_json() {
        // Temporary directory is removed when `dir` is dropped.
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");

        std::fs::write(&file_path, "invalid json content").unwrap();

        let result = LinearRegression::<f64>::load(&file_path);
        assert!(result.is_err(), "Loading invalid JSON should return error");
    }

    /// Plain non-JSON text is rejected by `from_json`.
    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = LinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }
}