1use crate::error::{StatsError, StatsResult};
4use num_traits::{Float, NumCast};
5use serde::{Deserialize, Serialize};
6use std::fmt::Debug;
7use std::fs::File;
8use std::io::{self};
9use std::path::Path;
10
/// Multiple linear regression model whose parameters and fit statistics are
/// stored as plain public fields so the whole model can be serialized.
///
/// `T` is the floating-point type used for all stored values (defaults to `f64`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultipleLinearRegression<T = f64>
where
    T: Float + Debug + Default + Serialize,
{
    /// Fitted coefficients: index 0 is the intercept and index `i + 1` is the
    /// weight applied to feature `i`. Empty until `fit` has succeeded.
    pub coefficients: Vec<T>,
    /// Coefficient of determination, computed by `fit` as
    /// `1 - ss_residual / ss_total` when the total sum of squares is positive.
    pub r_squared: T,
    /// R² adjusted for the number of features, computed by `fit` as
    /// `1 - (1 - r_squared) * (n - 1) / (n - p - 1)` when `n > p + 1`.
    pub adjusted_r_squared: T,
    /// Residual standard error, computed by `fit` as
    /// `sqrt(ss_residual / (n - p - 1))` when `n > p + 1`.
    pub standard_error: T,
    /// Number of observations (rows of X) passed to `fit`.
    pub n: usize,
    /// Number of features per observation (columns of X, excluding the intercept).
    pub p: usize,
}
30
31impl<T> Default for MultipleLinearRegression<T>
32where
33 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
34{
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl<T> MultipleLinearRegression<T>
41where
42 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
43{
44 pub fn new() -> Self {
46 Self {
47 coefficients: Vec::new(),
48 r_squared: T::zero(),
49 adjusted_r_squared: T::zero(),
50 standard_error: T::zero(),
51 n: 0,
52 p: 0,
53 }
54 }
55
56 pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> StatsResult<()>
72 where
73 U: NumCast + Copy,
74 V: NumCast + Copy,
75 {
76 if x_values.is_empty() || y_values.is_empty() {
78 return Err(StatsError::empty_data(
79 "Cannot fit regression with empty arrays",
80 ));
81 }
82
83 if x_values.len() != y_values.len() {
84 return Err(StatsError::dimension_mismatch(format!(
85 "Number of observations in X and Y must match (got {} and {})",
86 x_values.len(),
87 y_values.len()
88 )));
89 }
90
91 self.n = x_values.len();
92
93 if x_values.is_empty() {
95 return Err(StatsError::empty_data("X values array is empty"));
96 }
97
98 self.p = x_values[0].len();
99
100 for (i, row) in x_values.iter().enumerate() {
101 if row.len() != self.p {
102 return Err(StatsError::invalid_input(format!(
103 "All rows in X must have the same number of features (row {} has {} features, expected {})",
104 i,
105 row.len(),
106 self.p
107 )));
108 }
109 }
110
111 let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
113 for (row_idx, row) in x_values.iter().enumerate() {
114 let row_cast: StatsResult<Vec<T>> = row
115 .iter()
116 .enumerate()
117 .map(|(col_idx, &x)| {
118 T::from(x).ok_or_else(|| {
119 StatsError::conversion_error(format!(
120 "Failed to cast X value at row {}, column {} to type T",
121 row_idx, col_idx
122 ))
123 })
124 })
125 .collect();
126 x_cast.push(row_cast?);
127 }
128
129 let y_cast: Vec<T> = y_values
130 .iter()
131 .enumerate()
132 .map(|(i, &y)| {
133 T::from(y).ok_or_else(|| {
134 StatsError::conversion_error(format!(
135 "Failed to cast Y value at index {} to type T",
136 i
137 ))
138 })
139 })
140 .collect::<StatsResult<Vec<T>>>()?;
141
142 let mut augmented_x = Vec::with_capacity(self.n);
144 for row in &x_cast {
145 let mut augmented_row = Vec::with_capacity(self.p + 1);
146 augmented_row.push(T::one()); augmented_row.extend_from_slice(row);
148 augmented_x.push(augmented_row);
149 }
150
151 let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
153
154 let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
156
157 match self.solve_linear_system(&xt_x, &xt_y) {
159 Ok(solution) => {
160 self.coefficients = solution;
161 }
162 Err(e) => return Err(e),
163 }
164
165 let n_as_t = T::from(self.n).ok_or_else(|| {
167 StatsError::conversion_error(format!("Failed to convert {} to type T", self.n))
168 })?;
169 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / n_as_t;
170
171 let mut ss_total = T::zero();
172 let mut ss_residual = T::zero();
173
174 for i in 0..self.n {
175 let predicted = self.predict_t(&x_cast[i]);
176 let residual = y_cast[i] - predicted;
177
178 ss_residual = ss_residual + (residual * residual);
179 let diff = y_cast[i] - y_mean;
180 ss_total = ss_total + (diff * diff);
181 }
182
183 if ss_total > T::zero() {
185 self.r_squared = T::one() - (ss_residual / ss_total);
186
187 if self.n > self.p + 1 {
189 let n_minus_1 = T::from(self.n - 1).ok_or_else(|| {
190 StatsError::conversion_error(format!(
191 "Failed to convert {} to type T",
192 self.n - 1
193 ))
194 })?;
195 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
196 StatsError::conversion_error(format!(
197 "Failed to convert {} to type T",
198 self.n - self.p - 1
199 ))
200 })?;
201
202 self.adjusted_r_squared =
203 T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
204 }
205 }
206
207 if self.n > self.p + 1 {
209 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).ok_or_else(|| {
210 StatsError::conversion_error(format!(
211 "Failed to convert {} to type T",
212 self.n - self.p - 1
213 ))
214 })?;
215 self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
216 }
217
218 Ok(())
219 }
220
221 fn predict_t(&self, x: &[T]) -> T {
223 if x.len() != self.p || self.coefficients.is_empty() {
224 return T::nan();
225 }
226
227 let mut result = self.coefficients[0];
229
230 for (i, &xi) in x.iter().enumerate().take(self.p) {
232 result = result + (self.coefficients[i + 1] * xi);
233 }
234
235 result
236 }
237
238 pub fn predict<U>(&self, x: &[U]) -> StatsResult<T>
268 where
269 U: NumCast + Copy,
270 {
271 if self.coefficients.is_empty() {
272 return Err(StatsError::not_fitted(
273 "Model has not been fitted. Call fit() before predicting.",
274 ));
275 }
276
277 if x.len() != self.p {
278 return Err(StatsError::dimension_mismatch(format!(
279 "Expected {} features, but got {}",
280 self.p,
281 x.len()
282 )));
283 }
284
285 let x_cast: StatsResult<Vec<T>> = x
287 .iter()
288 .enumerate()
289 .map(|(i, &val)| {
290 T::from(val).ok_or_else(|| {
291 StatsError::conversion_error(format!(
292 "Failed to convert feature value at index {} to type T",
293 i
294 ))
295 })
296 })
297 .collect();
298
299 Ok(self.predict_t(&x_cast?))
300 }
301
302 pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> StatsResult<Vec<T>>
332 where
333 U: NumCast + Copy,
334 {
335 x_values.iter().map(|x| self.predict(x)).collect()
336 }
337
338 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
346 let file = File::create(path)?;
347 serde_json::to_writer(file, self).map_err(io::Error::other)
349 }
350
351 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
359 let file = File::create(path)?;
360 bincode::serialize_into(file, self).map_err(io::Error::other)
362 }
363
364 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
372 let file = File::open(path)?;
373 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
375 }
376
377 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
385 let file = File::open(path)?;
386 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
388 }
389
390 pub fn to_json(&self) -> Result<String, String> {
395 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
396 }
397
398 pub fn from_json(json: &str) -> Result<Self, String> {
406 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
407 }
408
409 fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
411 let a_rows = a.len();
412 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
413 let b_rows = b.len();
414 let b_cols = if b_rows > 0 { b[0].len() } else { 0 };
415
416 let mut result = vec![vec![T::zero(); b_cols]; a_cols];
418
419 for (i, result_row) in result.iter_mut().enumerate().take(a_cols) {
420 for (j, result_elem) in result_row.iter_mut().enumerate().take(b_cols) {
421 let mut sum = T::zero();
422 for k in 0..a_rows {
423 sum = sum + (a[k][i] * b[k][j]);
424 }
425 *result_elem = sum;
426 }
427 }
428
429 result
430 }
431
432 fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
434 let a_rows = a.len();
435 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
436
437 let mut result = vec![T::zero(); a_cols];
438
439 for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
440 let mut sum = T::zero();
441 for j in 0..a_rows {
442 sum = sum + (a[j][i] * y[j]);
443 }
444 *result_item = sum;
445 }
446
447 result
448 }
449
450 fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> StatsResult<Vec<T>> {
452 let n = a.len();
453 if n == 0 || a[0].len() != n || b.len() != n {
454 return Err(StatsError::dimension_mismatch(format!(
455 "Invalid matrix dimensions for linear system solving: A is {}x{}, b has {} elements",
456 n,
457 if n > 0 { a[0].len() } else { 0 },
458 b.len()
459 )));
460 }
461
462 let mut aug = Vec::with_capacity(n);
464 for i in 0..n {
465 let mut row = a[i].clone();
466 row.push(b[i]);
467 aug.push(row);
468 }
469
470 for i in 0..n {
472 let mut max_row = i;
474 let mut max_val = aug[i][i].abs();
475
476 for (j, row) in aug.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
477 let abs_val = row[i].abs();
478 if abs_val > max_val {
479 max_row = j;
480 max_val = abs_val;
481 }
482 }
483
484 let epsilon: T = T::from(1e-10).ok_or_else(|| {
485 StatsError::conversion_error("Failed to convert epsilon (1e-10) to type T")
486 })?;
487 if max_val < epsilon {
488 return Err(StatsError::mathematical_error(
489 "Matrix is singular or near-singular, cannot solve linear system",
490 ));
491 }
492
493 if max_row != i {
495 aug.swap(i, max_row);
496 }
497
498 for j in (i + 1)..n {
500 let factor = aug[j][i] / aug[i][i];
501
502 for k in i..(n + 1) {
503 aug[j][k] = aug[j][k] - (factor * aug[i][k]);
504 }
505 }
506 }
507
508 let mut x = vec![T::zero(); n];
510 for i in (0..n).rev() {
511 let mut sum = aug[i][n];
512
513 for (j, &x_val) in x.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
514 sum = sum - (aug[i][j] * x_val);
515 }
516
517 x[i] = sum / aug[i][i];
518 }
519
520 Ok(x)
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::approx_equal;
    use tempfile::tempdir;

    /// Four observations generated exactly by `y = 1 + 2*x1 + 3*x2`.
    fn exact_plane_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        (
            vec![
                vec![1.0, 2.0],
                vec![2.0, 1.0],
                vec![3.0, 3.0],
                vec![4.0, 2.0],
            ],
            vec![9.0, 8.0, 16.0, 15.0],
        )
    }

    /// An `f64` model fitted on [`exact_plane_data`].
    fn fitted_exact_plane_model() -> MultipleLinearRegression<f64> {
        let (x, y) = exact_plane_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();
        model
    }

    /// Asserts that a round-tripped model matches the one it was saved from.
    fn assert_same_model(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &want) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, want, Some(1e-6)));
        }
        assert!(approx_equal(
            loaded.r_squared,
            original.r_squared,
            Some(1e-6)
        ));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        let (x, y) = exact_plane_data();

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        // Inputs of differing integer types are cast to the model's float type.
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        let result = model.fit(&x, &y);

        assert!(result.is_ok());
        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        // 1 + 2*5 + 3*4 = 23
        assert!(approx_equal(
            model.predict(&[5u32, 4u32]).unwrap(),
            23.0,
            Some(1e-6)
        ));
    }

    #[test]
    fn test_prediction_many() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];

        let predictions = model.predict_many(&new_x).unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");

        let model = fitted_exact_plane_model();

        let save_result = model.save(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");

        let model = fitted_exact_plane_model();

        let save_result = model.save_binary(&file_path);
        assert!(save_result.is_ok());

        let loaded_model = MultipleLinearRegression::<f64>::load_binary(&file_path);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_exact_plane_model();

        let json_result = model.to_json();
        assert!(json_result.is_ok());
        let json_str = json_result.unwrap();

        let loaded_model = MultipleLinearRegression::<f64>::from_json(&json_str);
        assert!(loaded_model.is_ok());

        assert_same_model(&loaded_model.unwrap(), &model);
    }

    #[test]
    fn test_predict_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        // The model was fitted with two features; supply only one.
        let wrong_features = vec![1.0];
        let result = model.predict(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_fit_singular_matrix() {
        // Every feature column is a multiple of the first, so the normal
        // equations are rank-deficient.
        let x = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];
        let y = vec![1.0, 2.0, 3.0];

        let mut model = MultipleLinearRegression::<f64>::new();
        // Either outcome is accepted: a reported mathematical error, or a
        // successful fit that produced coefficients.
        match model.fit(&x, &y) {
            Ok(()) => assert!(!model.coefficients.is_empty()),
            Err(e) => assert!(matches!(e, StatsError::MathematicalError { .. })),
        }
    }

    #[test]
    fn test_save_invalid_path() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0], vec![2.0]];
        let y = vec![2.0, 4.0];
        model.fit(&x, &y).unwrap();

        let invalid_path = std::path::Path::new("/nonexistent/directory/model.json");
        let result = model.save(invalid_path);
        assert!(
            result.is_err(),
            "Saving to invalid path should return error"
        );
    }

    #[test]
    fn test_load_nonexistent_file() {
        let nonexistent_path = std::path::Path::new("/nonexistent/file.json");
        let result = MultipleLinearRegression::<f64>::load(nonexistent_path);
        assert!(
            result.is_err(),
            "Loading non-existent file should return error"
        );
    }

    #[test]
    fn test_from_json_invalid() {
        let invalid_json = "not valid json";
        let result = MultipleLinearRegression::<f64>::from_json(invalid_json);
        assert!(
            result.is_err(),
            "Deserializing invalid JSON should return error"
        );
    }

    #[test]
    fn test_predict_t_coefficients_empty() {
        let model = MultipleLinearRegression::<f64>::new();
        let features = vec![1.0, 2.0];
        let result = model.predict(&features);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_fit_x_values_empty_after_check() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];
        let result = model.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_predict_many_not_fitted() {
        let model = MultipleLinearRegression::<f64>::new();
        let result = model.predict_many(&[vec![1.0, 2.0]]);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), StatsError::NotFitted { .. }));
    }

    #[test]
    fn test_predict_many_dimension_mismatch() {
        let mut model = MultipleLinearRegression::<f64>::new();
        let x = vec![vec![1.0, 2.0], vec![2.0, 1.0], vec![3.0, 3.0]];
        let y = vec![3.0, 3.0, 6.0];
        model.fit(&x, &y).unwrap();

        // The model was fitted with two features; supply only one.
        let wrong_features = vec![vec![1.0]];
        let result = model.predict_many(&wrong_features);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            StatsError::DimensionMismatch { .. }
        ));
    }

    #[test]
    fn test_predict_many_success() {
        let mut model = MultipleLinearRegression::<f64>::new();
        // Responses satisfy y = x1 + x2 exactly.
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![3.0, 3.0, 6.0, 6.0];
        model.fit(&x, &y).unwrap();

        let predictions = model
            .predict_many(&[vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 7.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 11.0, Some(1e-6)));
    }
}