use num_traits::{Float, NumCast};
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::fs::File;
use std::io::{self, BufReader, BufWriter, Write};
use std::path::Path;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct MultipleLinearRegression<T = f64>
13where
14 T: Float + Debug + Default + Serialize,
15{
16 pub coefficients: Vec<T>,
18 pub r_squared: T,
20 pub adjusted_r_squared: T,
22 pub standard_error: T,
24 pub n: usize,
26 pub p: usize,
28}
29
30impl<T> Default for MultipleLinearRegression<T>
31where
32 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
33{
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl<T> MultipleLinearRegression<T>
40where
41 T: Float + Debug + Default + NumCast + Serialize + for<'de> Deserialize<'de>,
42{
43 pub fn new() -> Self {
45 Self {
46 coefficients: Vec::new(),
47 r_squared: T::zero(),
48 adjusted_r_squared: T::zero(),
49 standard_error: T::zero(),
50 n: 0,
51 p: 0,
52 }
53 }
54
55 pub fn fit<U, V>(&mut self, x_values: &[Vec<U>], y_values: &[V]) -> Result<(), String>
64 where
65 U: NumCast + Copy,
66 V: NumCast + Copy,
67 {
68 if x_values.is_empty() || y_values.is_empty() {
70 return Err("Cannot fit regression with empty arrays".to_string());
71 }
72
73 if x_values.len() != y_values.len() {
74 return Err("Number of observations in X and Y must match".to_string());
75 }
76
77 self.n = x_values.len();
78
79 if x_values.is_empty() {
81 return Err("X values array is empty".to_string());
82 }
83
84 self.p = x_values[0].len();
85
86 for row in x_values {
87 if row.len() != self.p {
88 return Err("All rows in X must have the same number of features".to_string());
89 }
90 }
91
92 let mut x_cast: Vec<Vec<T>> = Vec::with_capacity(self.n);
94 for row in x_values {
95 let row_cast: Result<Vec<T>, String> = row
96 .iter()
97 .map(|&x| T::from(x).ok_or_else(|| "Failed to cast X value".to_string()))
98 .collect();
99 x_cast.push(row_cast?);
100 }
101
102 let y_cast: Vec<T> = y_values
103 .iter()
104 .map(|&y| T::from(y).ok_or_else(|| "Failed to cast Y value".to_string()))
105 .collect::<Result<Vec<T>, String>>()?;
106
107 let mut augmented_x = Vec::with_capacity(self.n);
109 for row in &x_cast {
110 let mut augmented_row = Vec::with_capacity(self.p + 1);
111 augmented_row.push(T::one()); augmented_row.extend_from_slice(row);
113 augmented_x.push(augmented_row);
114 }
115
116 let xt_x = self.matrix_multiply_transpose(&augmented_x, &augmented_x);
118
119 let xt_y = self.vector_multiply_transpose(&augmented_x, &y_cast);
121
122 match self.solve_linear_system(&xt_x, &xt_y) {
124 Ok(solution) => {
125 self.coefficients = solution;
126 }
127 Err(e) => return Err(e),
128 }
129
130 let y_mean = y_cast.iter().fold(T::zero(), |acc, &y| acc + y) / T::from(self.n).unwrap();
132
133 let mut ss_total = T::zero();
134 let mut ss_residual = T::zero();
135
136 for i in 0..self.n {
137 let predicted = self.predict_t(&x_cast[i]);
138 let residual = y_cast[i] - predicted;
139
140 ss_residual = ss_residual + (residual * residual);
141 let diff = y_cast[i] - y_mean;
142 ss_total = ss_total + (diff * diff);
143 }
144
145 if ss_total > T::zero() {
147 self.r_squared = T::one() - (ss_residual / ss_total);
148
149 if self.n > self.p + 1 {
151 let n_minus_1 = T::from(self.n - 1).unwrap();
152 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).unwrap();
153
154 self.adjusted_r_squared =
155 T::one() - ((T::one() - self.r_squared) * n_minus_1 / n_minus_p_minus_1);
156 }
157 }
158
159 if self.n > self.p + 1 {
161 let n_minus_p_minus_1 = T::from(self.n - self.p - 1).unwrap();
162 self.standard_error = (ss_residual / n_minus_p_minus_1).sqrt();
163 }
164
165 Ok(())
166 }
167
168 fn predict_t(&self, x: &[T]) -> T {
170 if x.len() != self.p || self.coefficients.is_empty() {
171 return T::nan();
172 }
173
174 let mut result = self.coefficients[0];
176
177 for (i, &xi) in x.iter().enumerate().take(self.p) {
179 result = result + (self.coefficients[i + 1] * xi);
180 }
181
182 result
183 }
184
185 pub fn predict<U>(&self, x: &[U]) -> T
193 where
194 U: NumCast + Copy,
195 {
196 if x.len() != self.p {
197 return T::nan();
198 }
199
200 let x_cast: Result<Vec<T>, ()> = x.iter().map(|&val| T::from(val).ok_or(())).collect();
202
203 match x_cast {
204 Ok(x_t) => self.predict_t(&x_t),
205 Err(_) => T::nan(),
206 }
207 }
208
209 pub fn predict_many<U>(&self, x_values: &[Vec<U>]) -> Vec<T>
217 where
218 U: NumCast + Copy,
219 {
220 x_values.iter().map(|x| self.predict(x)).collect()
221 }
222
223 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
231 let file = File::create(path)?;
232 serde_json::to_writer(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
234 }
235
236 pub fn save_binary<P: AsRef<Path>>(&self, path: P) -> Result<(), io::Error> {
244 let file = File::create(path)?;
245 bincode::serialize_into(file, self).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
247 }
248
249 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
257 let file = File::open(path)?;
258 serde_json::from_reader(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
260 }
261
262 pub fn load_binary<P: AsRef<Path>>(path: P) -> Result<Self, io::Error> {
270 let file = File::open(path)?;
271 bincode::deserialize_from(file).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
273 }
274
275 pub fn to_json(&self) -> Result<String, String> {
280 serde_json::to_string(self).map_err(|e| format!("Failed to serialize model: {}", e))
281 }
282
283 pub fn from_json(json: &str) -> Result<Self, String> {
291 serde_json::from_str(json).map_err(|e| format!("Failed to deserialize model: {}", e))
292 }
293
294 fn matrix_multiply_transpose(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Vec<Vec<T>> {
296 let a_rows = a.len();
297 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
298 let b_rows = b.len();
299 let b_cols = if b_rows > 0 { b[0].len() } else { 0 };
300
301 let mut result = vec![vec![T::zero(); b_cols]; a_cols];
303
304 for (i, result_row) in result.iter_mut().enumerate().take(a_cols) {
305 for (j, result_elem) in result_row.iter_mut().enumerate().take(b_cols) {
306 let mut sum = T::zero();
307 for k in 0..a_rows {
308 sum = sum + (a[k][i] * b[k][j]);
309 }
310 *result_elem = sum;
311 }
312 }
313
314 result
315 }
316
317 fn vector_multiply_transpose(&self, a: &[Vec<T>], y: &[T]) -> Vec<T> {
319 let a_rows = a.len();
320 let a_cols = if a_rows > 0 { a[0].len() } else { 0 };
321
322 let mut result = vec![T::zero(); a_cols];
323
324 for (i, result_item) in result.iter_mut().enumerate().take(a_cols) {
325 let mut sum = T::zero();
326 for j in 0..a_rows {
327 sum = sum + (a[j][i] * y[j]);
328 }
329 *result_item = sum;
330 }
331
332 result
333 }
334
335 fn solve_linear_system(&self, a: &[Vec<T>], b: &[T]) -> Result<Vec<T>, String> {
337 let n = a.len();
338 if n == 0 || a[0].len() != n || b.len() != n {
339 return Err("Invalid matrix dimensions for linear system solving".to_string());
340 }
341
342 let mut aug = Vec::with_capacity(n);
344 for i in 0..n {
345 let mut row = a[i].clone();
346 row.push(b[i]);
347 aug.push(row);
348 }
349
350 for i in 0..n {
352 let mut max_row = i;
354 let mut max_val = aug[i][i].abs();
355
356 for (j, row) in aug.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
357 let abs_val = row[i].abs();
358 if abs_val > max_val {
359 max_row = j;
360 max_val = abs_val;
361 }
362 }
363
364 let epsilon: T = T::from(1e-10).unwrap();
365 if max_val < epsilon {
366 return Err("Matrix is singular or near-singular".to_string());
367 }
368
369 if max_row != i {
371 aug.swap(i, max_row);
372 }
373
374 for j in (i + 1)..n {
376 let factor = aug[j][i] / aug[i][i];
377
378 for k in i..(n + 1) {
379 aug[j][k] = aug[j][k] - (factor * aug[i][k]);
380 }
381 }
382 }
383
384 let mut x = vec![T::zero(); n];
386 for i in (0..n).rev() {
387 let mut sum = aug[i][n];
388
389 for (j, &x_val) in x.iter().enumerate().skip(i + 1).take(n - (i + 1)) {
390 sum = sum - (aug[i][j] * x_val);
391 }
392
393 x[i] = sum / aug[i][i];
394 }
395
396 Ok(x)
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::numeric::approx_equal;
    use tempfile::tempdir;

    /// Four observations generated exactly by `y = 1 + 2*x1 + 3*x2`.
    fn sample_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 1.0],
            vec![3.0, 3.0],
            vec![4.0, 2.0],
        ];
        let y = vec![9.0, 8.0, 16.0, 15.0];
        (x, y)
    }

    /// Returns a fresh `f64` model fitted on `sample_data()`.
    fn fitted_sample_model() -> MultipleLinearRegression<f64> {
        let (x, y) = sample_data();
        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).expect("sample data should fit");
        model
    }

    /// Asserts that `loaded` carries the same fitted state as `original`.
    fn assert_same_model(
        loaded: &MultipleLinearRegression<f64>,
        original: &MultipleLinearRegression<f64>,
    ) {
        assert_eq!(loaded.coefficients.len(), original.coefficients.len());
        for (&got, &want) in loaded.coefficients.iter().zip(&original.coefficients) {
            assert!(approx_equal(got, want, Some(1e-6)));
        }
        assert!(approx_equal(loaded.r_squared, original.r_squared, Some(1e-6)));
        assert_eq!(loaded.n, original.n);
        assert_eq!(loaded.p, original.p);
    }

    #[test]
    fn test_simple_multi_regression_f64() {
        let model = fitted_sample_model();

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_simple_multi_regression_f32() {
        let x = vec![
            vec![1.0f32, 2.0f32],
            vec![2.0f32, 1.0f32],
            vec![3.0f32, 3.0f32],
            vec![4.0f32, 2.0f32],
        ];
        let y = vec![9.0f32, 8.0f32, 16.0f32, 15.0f32];

        let mut model = MultipleLinearRegression::<f32>::new();
        model.fit(&x, &y).expect("f32 sample data should fit");

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[1], 2.0f32, Some(1e-4)));
        assert!(approx_equal(model.coefficients[2], 3.0f32, Some(1e-4)));
        assert!(approx_equal(model.r_squared, 1.0f32, Some(1e-4)));
    }

    #[test]
    fn test_integer_data() {
        let x = vec![
            vec![1u32, 2u32],
            vec![2u32, 1u32],
            vec![3u32, 3u32],
            vec![4u32, 2u32],
        ];
        let y = vec![9i32, 8i32, 16i32, 15i32];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).expect("integer sample data should fit");

        assert_eq!(model.coefficients.len(), 3);
        assert!(approx_equal(model.coefficients[0], 1.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[1], 2.0, Some(1e-6)));
        assert!(approx_equal(model.coefficients[2], 3.0, Some(1e-6)));
        assert!(approx_equal(model.r_squared, 1.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3], vec![4, 2]];
        let y = vec![9, 8, 16, 15];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        assert!(approx_equal(model.predict(&[5u32, 4u32]), 23.0, Some(1e-6)));
    }

    #[test]
    fn test_prediction_many() {
        let x = vec![vec![1, 2], vec![2, 1], vec![3, 3]];
        let y = vec![9, 8, 16];

        let mut model = MultipleLinearRegression::<f64>::new();
        model.fit(&x, &y).unwrap();

        let new_x = vec![vec![1u32, 2u32], vec![5u32, 4u32]];
        let predictions = model.predict_many(&new_x);

        assert_eq!(predictions.len(), 2);
        assert!(approx_equal(predictions[0], 9.0, Some(1e-6)));
        assert!(approx_equal(predictions[1], 23.0, Some(1e-6)));
    }

    #[test]
    fn test_predict_wrong_feature_count_is_nan() {
        let model = fitted_sample_model();

        assert!(model.predict(&[1.0f64]).is_nan());
        assert!(model.predict(&[1.0f64, 2.0, 3.0]).is_nan());
    }

    #[test]
    fn test_fit_rejects_empty_input() {
        let x: Vec<Vec<f64>> = Vec::new();
        let y: Vec<f64> = Vec::new();
        let mut model = MultipleLinearRegression::<f64>::new();

        assert_eq!(
            model.fit(&x, &y).unwrap_err(),
            "Cannot fit regression with empty arrays"
        );
    }

    #[test]
    fn test_fit_rejects_mismatched_lengths() {
        let (x, y) = sample_data();
        let mut model = MultipleLinearRegression::<f64>::new();

        assert_eq!(
            model.fit(&x, &y[..3]).unwrap_err(),
            "Number of observations in X and Y must match"
        );
    }

    #[test]
    fn test_fit_rejects_ragged_rows() {
        let x = vec![vec![1.0, 2.0], vec![3.0]];
        let y = vec![1.0, 2.0];
        let mut model = MultipleLinearRegression::<f64>::new();

        assert_eq!(
            model.fit(&x, &y).unwrap_err(),
            "All rows in X must have the same number of features"
        );
    }

    #[test]
    fn test_save_load_json() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.json");
        let model = fitted_sample_model();

        model.save(&file_path).expect("saving JSON should succeed");
        let loaded = MultipleLinearRegression::<f64>::load(&file_path)
            .expect("loading JSON should succeed");

        assert_same_model(&loaded, &model);
    }

    #[test]
    fn test_save_load_binary() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("model.bin");
        let model = fitted_sample_model();

        model
            .save_binary(&file_path)
            .expect("saving binary should succeed");
        let loaded = MultipleLinearRegression::<f64>::load_binary(&file_path)
            .expect("loading binary should succeed");

        assert_same_model(&loaded, &model);
    }

    #[test]
    fn test_json_serialization() {
        let model = fitted_sample_model();

        let json_str = model.to_json().expect("serializing to JSON should succeed");
        let loaded = MultipleLinearRegression::<f64>::from_json(&json_str)
            .expect("deserializing from JSON should succeed");

        assert_same_model(&loaded, &model);
    }
}