1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct MultipleLinearRegression<T = f64>
14where
15 T: Float + Debug + Default + Serialize,
16{
17 pub coefficients: Vec<T>,
19 pub r_squared: T,
21 pub adjusted_r_squared: T,
23 pub standard_error: T,
25 pub n: usize,
27 pub p: usize,
29}
30
31impl<T> Default for MultipleLinearRegression<T>
32where
33 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl<T> MultipleLinearRegression<T>
41where
42 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44 pub fn new() -> Self {
46 Self {
47 coefficients: Vec::new(),
48 r_squared: T::zero(),
49 adjusted_r_squared: T::zero(),
50 standard_error: T::zero(),
51 n: 0,
52 p: 0,
53 }
54 }
55
56 pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72 where
73 U: NumCast + Copy,
74 V: NumCast + Copy,
75 {
76 if x_values.is_empty() || y_values.is_empty() {
78 return Err(StatsError::empty_data(
79 "Cannot fit regression with empty arrays",
80 ));
81 }
82
83 if x_values.len() != y_values.len() {
84 return Err(StatsError::dimension_mismatch(format!(
85 "Number of observations in X and Y must match (got {} and {})",
86 x_values.len(),
87 y_values.len()
88 )));
89 }
90
91 self.n = x_values.len();
92
93 if x_values.is_empty() {
95 return Err(StatsError::empty_data("X values array is empty"));
96 }
97
98 self.p = x_values[0].len();
99
100 for (i, row) in x_values.iter().enumerate() {
101 if row.len() != self.p {
102 return Err(StatsError::invalid_input(format!(
103 "All rows in X must have the same number of features (row {} has {} features, expected {})",
104 i,
105 row.len(),
106 self.p
107 )));
108 }
109 }
110
111 let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
113 for (row_idx, row) in x_values.iter().enumerate() {
114 let row_cast: StatsResult<Vec<T>> = row
115 .iter()
116 .enumerate()
117 .map(|(col_idx, &x)| {
118 T::from(x).ok_or_else(|| {
119 StatsError::conversion_error(format!(
120 "Failed to cast X value at row {}, column {} to type T",
121 row_idx, col_idx
122 ))
123 })
124 })
125 .collect();
126 x_cast.push(row_cast?);
127 }
128
129 let y_cast: Vec<T> = y_values
130 .iter()
131 .enumerate()
132 .map(|(i, &y)| {
133 T::from(y).ok_or_else(|| {
134 StatsError::conversion_error(format!(
135 "Failed to cast Y value at index {} to type T",
136 i
137 ))
138 })
139 })
140 .collect::<StatsResult<Vec<T>>>()?;
141
142 let mut augmented_x = Vec::with_capacity(self.n);
144 for row in &x_cast {
145 let mut augmented_row = Vec::with_capacity(self.p + 1);
146 augmented_row.push(T::one()); augmented_row.extend_from_slice(row);
148 augmented_x.push(augmented_row);
149 }
150
151 let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
153
154 let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
156
157 match self.solve_linear_system(&xt_x, &xt_y) {
159 Ok(solution) => {
160 self.coefficients = solution;
161 }
162 Err(e) => return Err(e),
163 }
164
165 let n_as_t = T::from(self.n).ok_or_else(|| {
167 StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
168 })?;
169 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
170
171 let mut ss_total = T::zero();
172 let mut ss_residual = T::zero();
173
174 for i in 0..self.n {
175 let predicted = self.predict_t(&x_cast[i]);
176 let residual = y_cast[i] - predicted;
177
178 ss_residual = ss_residual + (residual * residual);
179 let diff = y_cast[i] - y_mean;
180 ss_total = ss_total + (diff * diff);
181 }
182
183 if ss_total > T::zero() {
185 self.r_squared = T::one() - (ss_residual / ss_total);
186
187 if self.n > self.p + 1 {
189 let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
190 StatsError::conversion_error(format!(
191 "Failed to convert {} to type T",
192 self.n - 1
193 ))
194 })?;
195 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
196 StatsError::conversion_error(format!(
197 "Failed to convert {} to type T",
198 self.n - self.p - 1
199 ))
200 })?;
201
202 self.adjusted_r_squared =
203 T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
204 }
205 }
206
207 if self.n > self.p + 1 {
209 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
210 StatsError::conversion_error(format!(
211 "Failed to convert {} to type T",
212 self.n - self.p - 1
213 ))
214 })?;
215 self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
216 }
217
218 Ok(())
219 }
220
221 fn predict_t(&self, x: &[T]) -> T {
223 if x.len() != self.p || self.coefficients.is_empty() {
224 return T::nan();
225 }
226
227 let mut result = self.coefficients[0];
229
230 for (i, &xi) in x.iter().enumerate().take(self.p) {
232 result = result + (self.coefficients[i + 1] * xi);
233 }
234
235 result
236 }
237
238 pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
268 where
269 U: NumCast + Copy,
270 {
271 if self.coefficients.is_empty() {
272 return Err(StatsError::not_fitted(
273 "Model has not been fitted. Call fit() before predicting.",
274 ));
275 }
276
277 if x.len() != self.p {
278 return Err(StatsError::dimension_mismatch(format!(
279 "Expected {} features, but got {}",
280 self.p,
281 x.len()
282 )));
283 }
284
285 let x_cast: StatsResult<Vec<T>> = x
287 .iter()
288 .enumerate()
289 .map(|(i, &val)| {
290 T::from(val).ok_or_else(|| {
291 StatsError::conversion_error(format!(
292 "Failed to convert feature value at index {} to type T",
293 i
294 ))
295 })
296 })
297 .collect();
298
299 Ok(self.predict_t(&x_cast?))
300 }
301
302 pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
332 where
333 U: NumCast + Copy,
334 {
335 x_values.iter().map(|x| self.predict(x)).collect()
336 }
337
338 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346 let file = File::create(path)?;
347 serde_json::to_writer(file, self).map_err(io::Error::other)
349 }
350
351 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359 let file = File::create(path)?;
360 bincode::serialize_into(file, self).map_err(io::Error::other)
362 }
363
364 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372 let file = File::open(path)?;
373 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375 }
376
377 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385 let file = File::open(path)?;
386 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388 }
389
390 pub fn to_json(&self) -> Result<String, String> {
395 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396 }
397
398 pub fn from_json(json: &str) -> Result<Self, String> {
406 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407 }
408
409 fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
413 let a_rows = a.len();
414 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
415 let b_cols = if !b.is_empty() { b[0].len() } else { 0 };
416
417 let mut result = vec![vec![T::zero(); b_cols]; a_cols];
419
420 for k in 0..a_rows {
422 let a_row = &a[k];
423 let b_row = &b[k];
424 for i in 0..a_cols {
425 let a_ki = a_row[i];
426 let result_row = &mut result[i];
427 for j in 0..b_cols {
428 result_row[j] = result_row[j] + (a_ki * b_row[j]);
429 }
430 }
431 }
432
433 result
434 }
435
436 fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
438 let a_rows = a.len();
439 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
440
441 let mut result = vec![T::zero(); a_cols];
442
443 for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
444 let mut sum = T::zero();
445 for j in 0..a_rows {
446 sum = sum + (a[j][i] * y[j]);
447 }
448 *result_item = sum;
449 }
450
451 result
452 }
453
454 fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> StatsResult<Vec<T>> {
456 let n = a.len();
457 if n == 0 || a[0].len() != n || b.len() != n {
458 return Err(StatsError::dimension_mismatch(format!(
459 "Invalid matrix dimensions for linear system solving: A is {}x{}, b has {} elements",
460 n,
461 if n > 0 { a[0].len() } else { 0 },
462 b.len()
463 )));
464 }
465
466 let mut aug: Vec<Vec<T>> = Vec::with_capacity(n);
468 for i in 0..n {
469 let mut row = Vec::with_capacity(n + 1);
470 row.extend_from_slice(&a[i]);
471 row.push(b[i]);
472 aug.push(row);
473 }
474
475 for i in 0..n {
477 let mut max_row = i;
479 let mut max_val = aug[i][i].abs();
480
481 #[allow(clippy::needless_range_loop)]
482 for j in (i + 1)..n {
483 let abs_val = aug[j][i].abs();
484 if abs_val > max_val {
485 max_row = j;
486 max_val = abs_val;
487 }
488 }
489
490 let epsilon: T = T::from(1e-10).ok_or_else(|| {
491 StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
492 })?;
493 if max_val < epsilon {
494 return Err(StatsError::mathematical_error(
495 "Matrix is singular or near-singular, cannot solve linear system",
496 ));
497 }
498
499 if max_row != i {
501 aug.swap(i, max_row);
502 }
503
504 for j in (i + 1)..n {
506 let factor = aug[j][i] / aug[i][i];
507
508 for k in i..(n + 1) {
509 aug[j][k] = aug[j][k] - (factor * aug[i][k]);
510 }
511 }
512 }
513
514 let mut x = vec![T::zero(); n];
516 for i in (0..n).rev() {
517 let mut sum = aug[i][n];
518
519 #[allow(clippy::needless_range_loop)]
520 for j in (i + 1)..n {
521 sum = sum - (aug[i][j] * x[j]);
522 }
523
524 x[i] = sum / aug[i][i];
525 }
526
527 Ok(x)
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four noise-free observations of `y = 1 + 2*x1 + 3*x2`.
    fn sample_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// A model fitted to `sample_data()`.
    fn fitted_sample_model() -> MultipleLinearRegression<f64> {
        let (x, y) = sample_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).expect("sample data should fit");
        model
    }

    /// Asserts that two `f64` models carry the same fitted state.
    fn assert_models_match(
        actual: &MultipleLinearRegression<f64>,
        expected: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(actual.coefficients.len(), expected.coefficients.len());
        for (a, e) in actual.coefficients.iter().zip(&expected.coefficients) {
            assert!(approx_equal(*a, *e, Some(1e-6)));
        }
        assert!(approx_equal(actual.r_squared, expected.r_squared, Some(1e-6)));
        assert_eq!(actual.n, expected.n);
        assert_eq!(actual.p, expected.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        let (x, y) = sample_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).expect("fit should succeed");

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        model.fit(&x, &y).expect("fit should succeed");

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).expect("fit should succeed");

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // 1 + 2*5 + 3*4 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];
        let predictions = model.predict_many(&new_x).unwrap();

        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_sample_model();

        model
            .save(&file_path)
            .expect("saving into a temp dir should succeed");
        let loaded = MultipleLinearRegression::<f64>::load(&file_path)
            .expect("loading a freshly saved model should succeed");

        assert_models_match(&loaded, &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_sample_model();

        model
            .save_binary(&file_path)
            .expect("saving into a temp dir should succeed");
        let loaded = MultipleLinearRegression::<f64>::load_binary(&file_path)
            .expect("loading a freshly saved model should succeed");

        assert_models_match(&loaded, &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_sample_model();

        let json_str = model.to_json().expect("serialization should succeed");
        let loaded = MultipleLinearRegression::<f64>::from_json(&json_str)
            .expect("round-tripping through JSON should succeed");

        assert_models_match(&loaded, &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];

        let result = model.predict(&features);
        assert!(matches!(result, Err(StatsError::NotFitted { .. })));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        // The model expects two features; pass only one.
        let wrong_features = vec![1.0];
        let result = model.predict(&wrong_features);
        assert!(matches!(result, Err(StatsError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_fit_length_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![1.0, 2.0, 3.0];

        let result = model.fit(&x, &y);
        assert!(matches!(result, Err(StatsError::DimensionMismatch { .. })));
        assert!(model.coefficients.is_empty());
    }

    #[test]
    fn test_fit_ragged_rows() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![3.0]];
        let y = vec![1.0, 2.0];

        assert!(model.fit(&x, &y).is_err());
        assert!(model.coefficients.is_empty());
    }

    #[test]
    fn test_fit_singular_matrix() {
        // The second and third features are exact multiples of the first, so X^T X is
        // rank-deficient and elimination reaches an exactly-zero pivot.
        let x = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(matches!(result, Err(StatsError::MathematicalError { .. })));
        assert!(model.coefficients.is_empty());
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_t_coefficients_empty() {
        let model = MultipleLinearRegression::<f64>::new();

        // The internal evaluator signals "unfitted" with NaN rather than an error.
        assert!(model.predict_t(&[]).is_nan());
        assert!(model.predict_t(&[1.0, 2.0]).is_nan());

        // The public API turns that into a NotFitted error.
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(matches!(result, Err(StatsError::NotFitted { .. })));
    }

    #[test]
    fn test_fit_empty_input() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];

        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(matches!(result, Err(StatsError::NotFitted { .. })));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        model.fit(&x, &y).unwrap();

        // The model expects two features per row; pass only one.
        let wrong_features = vec![vec![1.0]];
        let result = model.predict_many(&wrong_features);
        assert!(matches!(result, Err(StatsError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
    }
}