1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use rayon::prelude::*;
6use serde::{Deserialize, Serialize};
7use std::fmt::Debug;
8use std::fs::File;
9use std::io::{self};
10use std::path::Path;
11
/// Simple (single-predictor) linear regression model of the form
/// `y = intercept + slope * x`.
///
/// A model starts out with every field at zero (see `new`) and is populated
/// by `fit`. The struct derives `Serialize`/`Deserialize`, which is what the
/// JSON and binary persistence helpers rely on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression<T = f64>
where
    T: Float + Debug + Default + Serialize,
{
    /// Slope of the fitted line (covariation of X and Y divided by the
    /// variation of X).
    pub slope: T,
    /// Value of the fitted line at `x = 0` (`mean(y) - slope * mean(x)`).
    pub intercept: T,
    /// Coefficient of determination computed by `fit`.
    pub r_squared: T,
    /// Residual standard error, `sqrt(SSR / (n - 2))`; `fit` stores zero
    /// when there are two or fewer data points.
    pub standard_error: T,
    /// Number of data points recorded by `fit`; `0` for a model created by
    /// `new`, which the prediction methods treat as "not fitted".
    pub n: usize,
}
29
30impl<T> Default for LinearRegression<T>
31where
32 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
33{
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl<T> LinearRegression<T>
40where
41 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
42{
43 pub fn new() -> Self {
45 Self {
46 slope: T::zero(),
47 intercept: T::zero(),
48 r_squared: T::zero(),
49 standard_error: T::zero(),
50 n: 0,
51 }
52 }
53
54 pub fn fit<U, V>(&mut self, x_values: &[U], y_values: &[V]) -> StatsResult<()>
69 where
70 U: NumCast + Copy,
71 V: NumCast + Copy,
72 {
73 if x_values.len() != y_values.len() {
75 return Err(StatsError::dimension_mismatch(format!(
76 "X and Y arrays must have the same length (got {} and {})",
77 x_values.len(),
78 y_values.len()
79 )));
80 }
81
82 if x_values.is_empty() {
83 return Err(StatsError::empty_data(
84 "Cannot fit regression with empty arrays",
85 ));
86 }
87
88 let n = x_values.len();
89 self.n = n;
90
91 let x_cast: Vec<T> = x_values
93 .iter()
94 .enumerate()
95 .map(|(i, &x)| {
96 T::from(x).ok_or_else(|| {
97 StatsError::conversion_error(format!(
98 "Failed to cast X value at index {} to type T",
99 i
100 ))
101 })
102 })
103 .collect::<StatsResult<Vec<T>>>()?;
104
105 let y_cast: Vec<T> = y_values
106 .iter()
107 .enumerate()
108 .map(|(i, &y)| {
109 T::from(y).ok_or_else(|| {
110 StatsError::conversion_error(format!(
111 "Failed to cast Y value at index {} to type T",
112 i
113 ))
114 })
115 })
116 .collect::<StatsResult<Vec<T>>>()?;
117
118 let n_as_t = T::from(n).ok_or_else(|| {
120 StatsError::conversion_error(format!("Failed to convert {} to type T", n))
121 })?;
122 let x_mean = x_cast.iter().fold(T::zero(), |acc, &x| acc + x) / n_as_t;
123 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
124
125 let mut sum_xy = T::zero();
127 let mut sum_xx = T::zero();
128 let mut sum_yy = T::zero();
129
130 for i in 0..n {
131 let x_diff = x_cast[i] - x_mean;
132 let y_diff = y_cast[i] - y_mean;
133
134 sum_xy = sum_xy + (x_diff * y_diff);
135 sum_xx = sum_xx + (x_diff * x_diff);
136 sum_yy = sum_yy + (y_diff * y_diff);
137 }
138
139 if sum_xx == T::zero() {
141 return Err(StatsError::invalid_parameter(
142 "No variance in X values, cannot fit regression line",
143 ));
144 }
145
146 self.slope = sum_xy / sum_xx;
148 self.intercept = y_mean - (self.slope * x_mean);
149
150 self.r_squared = (sum_xy * sum_xy) / (sum_xx * sum_yy);
152
153 let mut sum_squared_residuals = T::zero();
155 for i in 0..n {
156 let predicted = self.predict_t(x_cast[i]);
157 let residual = y_cast[i] - predicted;
158 sum_squared_residuals = sum_squared_residuals + (residual * residual);
159 }
160
161 if n > 2 {
163 let two = T::from(2)
164 .ok_or_else(|| StatsError::conversion_error("Failed to convert 2 to type T"))?;
165 let n_minus_two = n_as_t - two;
166 self.standard_error = (sum_squared_residuals / n_minus_two).sqrt();
167 } else {
168 self.standard_error = T::zero();
169 }
170
171 Ok(())
172 }
173
174 fn predict_t(&self, x: T) -> T {
176 self.intercept + (self.slope * x)
177 }
178
179 pub fn predict<U>(&self, x: U) -> StatsResult<T>
202 where
203 U: NumCast + Copy,
204 {
205 if self.n == 0 {
206 return Err(StatsError::not_fitted(
207 "Model has not been fitted. Call fit() before predicting.",
208 ));
209 }
210
211 let x_cast: T = T::from(x)
212 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
213
214 Ok(self.predict_t(x_cast))
215 }
216
217 pub fn predict_many<U>(&self, x_values: &[U]) -> StatsResult<Vec<T>>
240 where
241 U: NumCast + Copy + Send + Sync,
242 T: Send + Sync,
243 {
244 x_values.par_iter().map(|&x| self.predict(x)).collect()
245 }
246
247 pub fn confidence_interval<U>(&self, x: U, confidence_level: f64) -> StatsResult<(T, T)>
261 where
262 U: NumCast + Copy,
263 {
264 if self.n < 3 {
265 return Err(StatsError::invalid_input(
266 "Need at least 3 data points to calculate confidence interval",
267 ));
268 }
269
270 let x_cast: T = T::from(x)
271 .ok_or_else(|| StatsError::conversion_error("Failed to convert x value to type T"))?;
272
273 let z_score: T = match confidence_level {
276 0.90 => T::from(1.645).ok_or_else(|| {
277 StatsError::conversion_error("Failed to convert z-score 1.645 to type T")
278 })?,
279 0.95 => T::from(1.96).ok_or_else(|| {
280 StatsError::conversion_error("Failed to convert z-score 1.96 to type T")
281 })?,
282 0.99 => T::from(2.576).ok_or_else(|| {
283 StatsError::conversion_error("Failed to convert z-score 2.576 to type T")
284 })?,
285 _ => {
286 return Err(StatsError::invalid_parameter(format!(
287 "Unsupported confidence level: {}. Supported values: 0.90, 0.95, 0.99",
288 confidence_level
289 )));
290 }
291 };
292
293 let predicted = self.predict_t(x_cast);
294 let margin = z_score * self.standard_error;
295
296 Ok((predicted - margin, predicted + margin))
297 }
298
299 pub fn correlation_coefficient(&self) -> StatsResult<T> {
321 if self.n == 0 {
322 return Err(StatsError::not_fitted(
323 "Model has not been fitted. Call fit() before getting correlation coefficient.",
324 ));
325 }
326 let r = self.r_squared.sqrt();
327 Ok(if self.slope >= T::zero() { r } else { -r })
328 }
329
330 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
338 let file = File::create(path)?;
339 serde_json::to_writer(file, self).map_err(io::Error::other)
341 }
342
343 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
351 let file = File::create(path)?;
352 bincode::serialize_into(file, self).map_err(io::Error::other)
354 }
355
356 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
364 let file = File::open(path)?;
365 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
367 }
368
369 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
377 let file = File::open(path)?;
378 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
380 }
381
382 pub fn to_json(&self) -> Result<String, String> {
387 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
388 }
389
390 pub fn from_json(json: &str) -> Result<Self, String> {
398 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Returns an `f64` model fitted to the given samples; panics if fitting fails.
    fn fitted(x: &[f64], y: &[f64]) -> LinearRegression<f64> {
        let mut model = LinearRegression::<f64>::new();
        model.fit(x, y).unwrap();
        model
    }

    /// Model fitted to `y = 2x` over `x = 1..=5`, used by the persistence tests.
    fn doubling_model() -> LinearRegression<f64> {
        fitted(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0, 4.0, 6.0, 8.0, 10.0])
    }

    /// Asserts that a round-tripped model carries the same fit as the source model.
    fn assert_same_model(loaded: &LinearRegression<f64>, model: &LinearRegression<f64>) {
        assert!(approx_equal(loaded.slope, model.slope, Some(1e-6)));
        assert!(approx_equal(loaded.intercept, model.intercept, Some(1e-6)));
        assert!(approx_equal(loaded.r_squared, model.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, model.n);
    }

    // ---- fitting -----------------------------------------------------------

    #[test]
    fn test_simple_regression_f64() {
        let xs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys = [2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_regression_f32() {
        let xs: [f32; 5] = [1.0, 2.0, 3.0, 4.0, 5.0];
        let ys: [f32; 5] = [2.0, 4.0, 6.0, 8.0, 10.0];

        let mut model = LinearRegression::<f32>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0f32, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0f32, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-6)));
    }

    #[test]
    fn test_integer_data() {
        let xs: [i32; 5] = [1, 2, 3, 4, 5];
        let ys: [i32; 5] = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(approx_equal(model.slope, 2.0, Some(1e-6)));
        assert!(approx_equal(model.intercept, 0.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_mixed_types() {
        let xs: [u32; 5] = [1, 2, 3, 4, 5];
        let ys: [f64; 5] = [2.1, 3.9, 6.2, 7.8, 10.1];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_ok());

        assert!(model.slope > 1.9 && model.slope < 2.1);
        assert!(model.intercept > -0.1 && model.intercept < 0.1);
        assert!(model.r_squared > 0.99);
    }

    #[test]
    fn test_invalid_inputs() {
        // Mismatched lengths must be rejected.
        let xs: [i32; 3] = [1, 2, 3];
        let ys: [i32; 2] = [2, 4];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_err());
    }

    #[test]
    fn test_constant_x() {
        // Zero variance in X makes the slope undefined.
        let xs: [i32; 3] = [1, 1, 1];
        let ys: [i32; 3] = [2, 3, 4];

        let mut model = LinearRegression::<f64>::new();
        assert!(model.fit(&xs, &ys).is_err());
    }

    #[test]
    fn test_fit_standard_error_n_less_than_or_equal_two() {
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        assert_eq!(model.standard_error, 0.0);
    }

    #[test]
    fn test_fit_standard_error_n_greater_than_two() {
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        assert!(model.standard_error >= 0.0);
    }

    // ---- prediction --------------------------------------------------------

    #[test]
    fn test_prediction() {
        let xs: [i32; 5] = [1, 2, 3, 4, 5];
        let ys: [i32; 5] = [2, 4, 6, 8, 10];

        let mut model = LinearRegression::<f64>::new();
        model.fit(&xs, &ys).unwrap();

        let at_six = model.predict(6u32).unwrap();
        let at_zero = model.predict(0i32).unwrap();
        assert!(approx_equal(at_six, 12.0, Some(1e-6)));
        assert!(approx_equal(at_zero, 0.0, Some(1e-6)));
    }

    #[test]
    fn test_predict_when_not_fitted() {
        let model = LinearRegression::<f64>::new();

        let outcome = model.predict(5.0);
        assert!(matches!(outcome, Err(StatsError::NotFitted { .. })));
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = LinearRegression::<f64>::new();

        let outcome = model.predict_many(&[1.0, 2.0, 3.0]);
        assert!(matches!(outcome, Err(StatsError::NotFitted { .. })));
    }

    #[test]
    fn test_predict_many_success() {
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let predictions = model.predict_many(&[4.0, 5.0]).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!((predictions[0] - 8.0).abs() < 1e-10);
        assert!((predictions[1] - 10.0).abs() < 1e-10);
    }

    // ---- confidence intervals ----------------------------------------------

    #[test]
    fn test_confidence_interval_n_less_than_three() {
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let outcome = model.confidence_interval(3.0, 0.95);
        assert!(matches!(outcome, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_confidence_interval_unsupported_level() {
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        let outcome = model.confidence_interval(3.0, 0.85);
        assert!(matches!(outcome, Err(StatsError::InvalidParameter { .. })));
    }

    #[test]
    fn test_confidence_interval_supported_levels() {
        let model = fitted(&[1.0, 2.0, 3.0, 4.0], &[2.0, 4.0, 6.0, 8.0]);

        for level in [0.90, 0.95, 0.99] {
            let outcome = model.confidence_interval(3.0, level);
            assert!(
                outcome.is_ok(),
                "Confidence level {} should be supported",
                level
            );
            let (lower, upper) = outcome.unwrap();
            assert!(lower <= upper, "Lower bound should be <= upper bound");
        }
    }

    // ---- correlation -------------------------------------------------------

    #[test]
    fn test_correlation_coefficient_positive_slope() {
        let model = fitted(&[1.0, 2.0, 3.0], &[2.0, 4.0, 6.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r >= 0.0,
            "Correlation should be positive for positive slope"
        );
    }

    #[test]
    fn test_correlation_coefficient_negative_slope() {
        let model = fitted(&[1.0, 2.0, 3.0], &[6.0, 4.0, 2.0]);

        let r = model.correlation_coefficient().unwrap();
        assert!(
            r <= 0.0,
            "Correlation should be negative for negative slope"
        );
    }

    #[test]
    fn test_correlation_coefficient_not_fitted() {
        let model = LinearRegression::<f64>::new();

        let outcome = model.correlation_coefficient();
        assert!(matches!(outcome, Err(StatsError::NotFitted { .. })));
    }

    // ---- persistence -------------------------------------------------------

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = doubling_model();

        assert!(model.save(&file_path).is_ok());

        let loaded = LinearRegression::<f64>::load(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = doubling_model();

        assert!(model.save_binary(&file_path).is_ok());

        let loaded = LinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = doubling_model();

        let json = model.to_json();
        assert!(json.is_ok());

        let loaded = LinearRegression::<f64>::from_json(&json.unwrap());
        assert!(loaded.is_ok());
        assert_same_model(&loaded.unwrap(), &model);
    }

    #[test]
    fn test_load_nonexistent_file() {
        assert!(LinearRegression::<f64>::load("/nonexistent/path/model.json").is_err());
    }

    #[test]
    fn test_load_binary_nonexistent_file() {
        assert!(LinearRegression::<f64>::load_binary("/nonexistent/path/model.bin").is_err());
    }

    #[test]
    fn test_from_json_invalid_json() {
        assert!(LinearRegression::<f64>::from_json("{invalid json}").is_err());
    }

    #[test]
    fn test_save_invalid_path() {
        let model = fitted(&[1.0, 2.0], &[2.0, 4.0]);

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        assert!(
            model.save(invalid_path).is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_invalid_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("invalid.json");
        std::fs::write(&file_path, "invalid json content").unwrap();

        let outcome = LinearRegression::<f64>::load(&file_path);
        assert!(outcome.is_err(), "Loading invalid JSON should return error");
    }

    #[test]
    fn test_from_json_invalid() {
        let outcome = LinearRegression::<f64>::from_json("not valid json");
        assert!(
            outcome.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }
}