1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
/// Multiple linear regression model of the form `y = b0 + b1*x1 + ... + bp*xp`.
///
/// Every field is written by [`MultipleLinearRegression::fit`]; a model created
/// with [`MultipleLinearRegression::new`] has empty `coefficients` and zeroed
/// statistics.
///
/// NOTE(review): the type derives `Serialize`/`Deserialize` and is written with
/// bincode by `save_binary`; presumably the field order is part of that on-disk
/// layout, so confirm before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleLinearRegression<T = f64>
where
    T: Float + Debug + Default + Serialize,
{
    /// Fitted coefficients: index 0 is the intercept, index `i + 1` is the
    /// weight applied to feature `i`. Empty until a fit has succeeded.
    pub coefficients: Vec<T>,
    /// Coefficient of determination, `1 - SS_residual / SS_total`, computed by
    /// `fit` when the responses have non-zero total sum of squares.
    pub r_squared: T,
    /// Adjusted R², `1 - (1 - R²) * (n - 1) / (n - p - 1)`, computed by `fit`
    /// when `n > p + 1`.
    pub adjusted_r_squared: T,
    /// Residual standard error, `sqrt(SS_residual / (n - p - 1))`, computed by
    /// `fit` when `n > p + 1`.
    pub standard_error: T,
    /// Number of observations (rows of X) passed to `fit`.
    pub n: usize,
    /// Number of features per observation, taken from the first row of X.
    pub p: usize,
}
30
31impl<T> Default for MultipleLinearRegression<T>
32where
33 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl<T> MultipleLinearRegression<T>
41where
42 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44 pub fn new() -> Self {
46 Self {
47 coefficients: Vec::new(),
48 r_squared: T::zero(),
49 adjusted_r_squared: T::zero(),
50 standard_error: T::zero(),
51 n: 0,
52 p: 0,
53 }
54 }
55
56 pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72 where
73 U: NumCast + Copy,
74 V: NumCast + Copy,
75 {
76 if x_values.is_empty() || y_values.is_empty() {
78 return Err(StatsError::empty_data(
79 "Cannot fit regression with empty arrays",
80 ));
81 }
82
83 if x_values.len() != y_values.len() {
84 return Err(StatsError::dimension_mismatch(format!(
85 "Number of observations in X and Y must match (got {} and {})",
86 x_values.len(),
87 y_values.len()
88 )));
89 }
90
91 self.n = x_values.len();
92
93 if x_values.is_empty() {
95 return Err(StatsError::empty_data("X values array is empty"));
96 }
97
98 self.p = x_values[0].len();
99
100 for (i, row) in x_values.iter().enumerate() {
101 if row.len() != self.p {
102 return Err(StatsError::invalid_input(format!(
103 "All rows in X must have the same number of features (row {} has {} features, expected {})",
104 i,
105 row.len(),
106 self.p
107 )));
108 }
109 }
110
111 let m = self.p + 1;
116 let mut augmented_x: Vec<T> = Vec::with_capacity(self.n * m);
117 for (row_idx, row) in x_values.iter().enumerate() {
118 augmented_x.push(T::one()); for (col_idx, &x) in row.iter().enumerate() {
120 let cast = T::from(x).ok_or_else(|| {
121 StatsError::conversion_error(format!(
122 "Failed to cast X value at row {row_idx}, column {col_idx} to type T"
123 ))
124 })?;
125 augmented_x.push(cast);
126 }
127 }
128
129 let y_cast: Vec<T> = y_values
130 .iter()
131 .enumerate()
132 .map(|(i, &y)| {
133 T::from(y).ok_or_else(|| {
134 StatsError::conversion_error(format!(
135 "Failed to cast Y value at index {i} to type T"
136 ))
137 })
138 })
139 .collect::<StatsResult<Vec<T>>>()?;
140
141 let xt_x = matrix_multiply_transpose_flat::<T>(&augmented_x, self.n, m);
143
144 let xt_y = vector_multiply_transpose_flat::<T>(&augmented_x, &y_cast, self.n, m);
146
147 self.coefficients = solve_linear_system_flat::<T>(&xt_x, &xt_y, m)?;
149
150 let n_as_t = T::from(self.n).ok_or_else(|| {
152 StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
153 })?;
154 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
155
156 let mut ss_total = T::zero();
157 let mut ss_residual = T::zero();
158
159 for i in 0..self.n {
160 let row_start = i * m + 1;
164 let row_end = i * m + m;
165 let predicted = self.predict_t(&augmented_x[row_start..row_end]);
166 let residual = y_cast[i] - predicted;
167
168 ss_residual = ss_residual + (residual * residual);
169 let diff = y_cast[i] - y_mean;
170 ss_total = ss_total + (diff * diff);
171 }
172
173 if ss_total > T::zero() {
175 self.r_squared = T::one() - (ss_residual / ss_total);
176
177 if self.n > self.p + 1 {
179 let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
180 StatsError::conversion_error(format!(
181 "Failed to convert {} to type T",
182 self.n - 1
183 ))
184 })?;
185 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
186 StatsError::conversion_error(format!(
187 "Failed to convert {} to type T",
188 self.n - self.p - 1
189 ))
190 })?;
191
192 self.adjusted_r_squared =
193 T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
194 }
195 }
196
197 if self.n > self.p + 1 {
199 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
200 StatsError::conversion_error(format!(
201 "Failed to convert {} to type T",
202 self.n - self.p - 1
203 ))
204 })?;
205 self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
206 }
207
208 Ok(())
209 }
210
211 fn predict_t(&self, x: &[T]) -> T {
213 if x.len() != self.p || self.coefficients.is_empty() {
214 return T::nan();
215 }
216
217 let mut result = self.coefficients[0];
219
220 for (i, &xi) in x.iter().enumerate().take(self.p) {
222 result = result + (self.coefficients[i + 1] * xi);
223 }
224
225 result
226 }
227
228 pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
258 where
259 U: NumCast + Copy,
260 {
261 if self.coefficients.is_empty() {
262 return Err(StatsError::not_fitted(
263 "Model has not been fitted. Call fit() before predicting.",
264 ));
265 }
266
267 if x.len() != self.p {
268 return Err(StatsError::dimension_mismatch(format!(
269 "Expected {} features, but got {}",
270 self.p,
271 x.len()
272 )));
273 }
274
275 let x_cast: StatsResult<Vec<T>> = x
277 .iter()
278 .enumerate()
279 .map(|(i, &val)| {
280 T::from(val).ok_or_else(|| {
281 StatsError::conversion_error(format!(
282 "Failed to convert feature value at index {} to type T",
283 i
284 ))
285 })
286 })
287 .collect();
288
289 Ok(self.predict_t(&x_cast?))
290 }
291
292 pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
322 where
323 U: NumCast + Copy,
324 {
325 x_values.iter().map(|x| self.predict(x)).collect()
326 }
327
328 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
336 let file = File::create(path)?;
337 serde_json::to_writer(file, self).map_err(io::Error::other)
339 }
340
341 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
349 let file = File::create(path)?;
350 bincode::serialize_into(file, self).map_err(io::Error::other)
352 }
353
354 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
362 let file = File::open(path)?;
363 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
365 }
366
367 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
375 let file = File::open(path)?;
376 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
378 }
379
380 pub fn to_json(&self) -> Result<String, String> {
385 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
386 }
387
388 pub fn from_json(json: &str) -> Result<Self, String> {
396 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
397 }
398}
399
400fn matrix_multiply_transpose_flat<T>(a: &[T], n_rows: usize, n_cols: usize) -> Vec<T>
405where
406 T: Float,
407{
408 let mut result = vec![T::zero(); n_cols * n_cols];
409 for k in 0..n_rows {
412 let row_off = k * n_cols;
413 for i in 0..n_cols {
414 let a_ki = a[row_off + i];
415 if a_ki == T::zero() {
416 continue;
417 }
418 let dst_off = i * n_cols;
419 for j in 0..n_cols {
420 result[dst_off + j] = result[dst_off + j] + a_ki * a[row_off + j];
421 }
422 }
423 }
424 result
425}
426
427fn vector_multiply_transpose_flat<T>(a: &[T], y: &[T], n_rows: usize, n_cols: usize) -> Vec<T>
429where
430 T: Float,
431{
432 let mut result = vec![T::zero(); n_cols];
433 for k in 0..n_rows {
434 let row_off = k * n_cols;
435 let yk = y[k];
436 for i in 0..n_cols {
437 result[i] = result[i] + a[row_off + i] * yk;
438 }
439 }
440 result
441}
442
443fn solve_linear_system_flat<T>(a: &[T], b: &[T], n: usize) -> StatsResult<Vec<T>>
447where
448 T: Float + Debug,
449{
450 if a.len() != n * n || b.len() != n {
451 return Err(StatsError::dimension_mismatch(format!(
452 "Invalid matrix dimensions for linear system solving: A is {n}×{n} ({} elems), b has {} elements",
453 a.len(),
454 b.len()
455 )));
456 }
457 let w = n + 1;
458 let mut aug: Vec<T> = Vec::with_capacity(n * w);
459 for i in 0..n {
460 aug.extend_from_slice(&a[i * n..(i + 1) * n]);
461 aug.push(b[i]);
462 }
463
464 let epsilon: T = T::from(1e-10).ok_or_else(|| {
465 StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
466 })?;
467
468 for i in 0..n {
469 let mut max_row = i;
471 let mut max_val = aug[i * w + i].abs();
472 for j in (i + 1)..n {
473 let abs_val = aug[j * w + i].abs();
474 if abs_val > max_val {
475 max_row = j;
476 max_val = abs_val;
477 }
478 }
479 if max_val < epsilon {
480 return Err(StatsError::mathematical_error(
481 "Matrix is singular or near-singular, cannot solve linear system",
482 ));
483 }
484 if max_row != i {
485 for c in 0..w {
486 aug.swap(i * w + c, max_row * w + c);
487 }
488 }
489
490 let pivot = aug[i * w + i];
492 for j in (i + 1)..n {
493 let factor = aug[j * w + i] / pivot;
494 for k in i..w {
495 aug[j * w + k] = aug[j * w + k] - factor * aug[i * w + k];
496 }
497 }
498 }
499
500 let mut x = vec![T::zero(); n];
502 for i in (0..n).rev() {
503 let mut sum = aug[i * w + n];
504 for j in (i + 1)..n {
505 sum = sum - aug[i * w + j] * x[j];
506 }
507 x[i] = sum / aug[i * w + i];
508 }
509 Ok(x)
510}
511
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four observations that lie exactly on `y = 1 + 2*x1 + 3*x2`.
    fn reference_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// Model fitted on [`reference_data`].
    fn fitted_reference_model() -> MultipleLinearRegression<f64> {
        let (x, y) = reference_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();
        model
    }

    /// Asserts that a round-tripped model carries the same fit as the original.
    fn assert_same_model(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &expected) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, expected, Some(1e-6)));
        }
        assert!(approx_equal(
            loaded.r_squared,
            original.r_squared,
            Some(1e-6)
        ));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        let (x, y) = reference_data();

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // 1 + 2*5 + 3*4 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_reference_model();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_reference_model();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_reference_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        // The model was fitted with two features; one is not enough.
        let wrong_features = vec![1.0];
        let result = model.predict(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_singular_matrix() {
        // Every feature column is a multiple of the first, so XᵀX is singular
        // and elimination runs out of usable pivots.
        let x = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        let err = model.fit(&x, &y).unwrap_err();
        assert!(matches!(err, StatsError::MathematicalError { .. }));
        // A failed fit must not leave coefficients behind.
        assert!(model.coefficients.is_empty());
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_t_coefficients_empty() {
        // Exercise the private helper directly: on an unfitted model it must
        // signal "no prediction" with NaN instead of panicking.
        let model = MultipleLinearRegression::<f64>::new();
        let prediction = model.predict_t(&[1.0, 2.0]);
        assert!(prediction.is_nan());

        // The public entry point reports the same situation as an error.
        let result = model.predict(&[1.0, 2.0]);
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_fit_x_values_empty_after_check() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];
        let result = model.fit(&x, &y);
        assert!(result.is_err());
        // Nothing was fitted, so the model must still be unfitted.
        assert!(model.coefficients.is_empty());
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        model.fit(&x, &y).unwrap();

        // One feature where the model expects two.
        let wrong_features = vec![vec![1.0]];
        let result = model.predict_many(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = MultipleLinearRegression::<f64>::new();
        // The responses satisfy y = x1 + x2 exactly.
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 7.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 11.0, Some(1e-6)));
    }
}