1use crate::error::Result;
6use crate::metrics::r_squared;
7use crate::primitives::{Matrix, Vector};
8use crate::traits::Estimator;
9use serde::{Deserialize, Serialize};
10use std::fs;
11use std::path::Path;
12
/// Ordinary least-squares linear regression.
///
/// `fit` solves the normal equations `(XᵀX)·β = Xᵀy` with a Cholesky solve,
/// optionally prepending a column of ones so the first parameter becomes the
/// intercept.
///
/// NOTE: field order is part of the `bincode` encoding produced by `save` and
/// read by `load`; reordering fields would break previously saved models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinearRegression {
    /// Fitted feature weights; `None` until the model is fitted or loaded.
    coefficients: Option<Vector<f32>>,
    /// Bias term; starts at `0.0` and `fit` resets it to `0.0` when
    /// `fit_intercept` is false.
    intercept: f32,
    /// Whether `fit` augments the design matrix with a leading column of ones.
    fit_intercept: bool,
}
63
64impl Default for LinearRegression {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl LinearRegression {
71 #[must_use]
73 pub fn new() -> Self {
74 Self {
75 coefficients: None,
76 intercept: 0.0,
77 fit_intercept: true,
78 }
79 }
80
81 #[must_use]
83 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
84 self.fit_intercept = fit_intercept;
85 self
86 }
87
88 #[must_use]
94 pub fn coefficients(&self) -> &Vector<f32> {
95 self.coefficients
96 .as_ref()
97 .expect("Model not fitted. Call fit() first.")
98 }
99
100 #[must_use]
102 pub fn intercept(&self) -> f32 {
103 self.intercept
104 }
105
106 #[must_use]
108 pub fn is_fitted(&self) -> bool {
109 self.coefficients.is_some()
110 }
111
112 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
118 let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
119 fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
120 Ok(())
121 }
122
123 pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
129 let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
130 let model =
131 bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
132 Ok(model)
133 }
134
135 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
150 use crate::serialization::safetensors;
151 use std::collections::BTreeMap;
152
153 let coefficients = self
155 .coefficients
156 .as_ref()
157 .ok_or("Cannot save unfitted model. Call fit() first.")?;
158
159 let mut tensors = BTreeMap::new();
161
162 let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
164 let coef_shape = vec![coefficients.len()];
165 tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
166
167 let intercept_data = vec![self.intercept];
169 let intercept_shape = vec![1];
170 tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
171
172 safetensors::save_safetensors(path, &tensors)?;
174 Ok(())
175 }
176
177 pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
186 use crate::serialization::safetensors;
187
188 let (metadata, raw_data) = safetensors::load_safetensors(path)?;
190
191 let coef_meta = metadata
193 .get("coefficients")
194 .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
195 let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
196
197 let intercept_meta = metadata
199 .get("intercept")
200 .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
201 let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
202
203 if intercept_data.len() != 1 {
205 return Err(format!(
206 "Invalid intercept tensor: expected 1 value, got {}",
207 intercept_data.len()
208 ));
209 }
210
211 Ok(Self {
213 coefficients: Some(Vector::from_vec(coef_data)),
214 intercept: intercept_data[0],
215 fit_intercept: true, })
217 }
218
219 fn add_intercept_column(x: &Matrix<f32>) -> Matrix<f32> {
221 let (n_rows, n_cols) = x.shape();
222 let mut data = Vec::with_capacity(n_rows * (n_cols + 1));
223
224 for i in 0..n_rows {
225 data.push(1.0); for j in 0..n_cols {
227 data.push(x.get(i, j));
228 }
229 }
230
231 Matrix::from_vec(n_rows, n_cols + 1, data)
232 .expect("Internal error: failed to create design matrix")
233 }
234}
235
236impl Estimator for LinearRegression {
237 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
248 let (n_samples, n_features) = x.shape();
249
250 if n_samples != y.len() {
251 return Err("Number of samples must match target length".into());
252 }
253
254 if n_samples == 0 {
255 return Err("Cannot fit with zero samples".into());
256 }
257
258 let required_samples = if self.fit_intercept {
262 n_features + 1
263 } else {
264 n_features
265 };
266
267 if n_samples < required_samples {
268 return Err(
269 "Insufficient samples: LinearRegression requires at least as many samples as \
270 features (plus 1 if fitting intercept). Consider using Ridge regression or \
271 collecting more training data"
272 .into(),
273 );
274 }
275
276 let x_design = if self.fit_intercept {
278 Self::add_intercept_column(x)
279 } else {
280 x.clone()
281 };
282
283 let xt = x_design.transpose();
285 let xtx = xt.matmul(&x_design)?;
286
287 let xty = xt.matvec(y)?;
289
290 let beta = xtx.cholesky_solve(&xty)?;
292
293 if self.fit_intercept {
295 self.intercept = beta[0];
296 self.coefficients = Some(beta.slice(1, n_features + 1));
297 } else {
298 self.intercept = 0.0;
299 self.coefficients = Some(beta);
300 }
301
302 Ok(())
303 }
304
305 fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
311 let coefficients = self
312 .coefficients
313 .as_ref()
314 .expect("Model not fitted. Call fit() first.");
315
316 let result = x
317 .matvec(coefficients)
318 .expect("Matrix dimensions don't match coefficients");
319
320 result.add_scalar(self.intercept)
321 }
322
323 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
325 let y_pred = self.predict(x);
326 r_squared(&y_pred, y)
327 }
328}
329
/// Linear regression with an L2 (ridge) penalty on the coefficients.
///
/// `fit` adds `alpha` to the diagonal of `XᵀX` for every coefficient (never for
/// the intercept column) before solving the normal equations with a Cholesky
/// solve.
///
/// NOTE: field order is part of the `bincode` encoding produced by `save` and
/// read by `load`; reordering fields would break previously saved models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ridge {
    /// Regularization strength added to the penalized diagonal entries of `XᵀX`.
    alpha: f32,
    /// Fitted feature weights; `None` until the model is fitted or loaded.
    coefficients: Option<Vector<f32>>,
    /// Bias term; `fit` resets it to `0.0` when `fit_intercept` is false.
    intercept: f32,
    /// Whether `fit` prepends a column of ones to the design matrix.
    fit_intercept: bool,
}
385
386impl Ridge {
387 #[must_use]
394 pub fn new(alpha: f32) -> Self {
395 Self {
396 alpha,
397 coefficients: None,
398 intercept: 0.0,
399 fit_intercept: true,
400 }
401 }
402
403 #[must_use]
405 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
406 self.fit_intercept = fit_intercept;
407 self
408 }
409
410 #[must_use]
412 pub fn alpha(&self) -> f32 {
413 self.alpha
414 }
415
416 #[must_use]
422 pub fn coefficients(&self) -> &Vector<f32> {
423 self.coefficients
424 .as_ref()
425 .expect("Model not fitted. Call fit() first.")
426 }
427
428 #[must_use]
430 pub fn intercept(&self) -> f32 {
431 self.intercept
432 }
433
434 #[must_use]
436 pub fn is_fitted(&self) -> bool {
437 self.coefficients.is_some()
438 }
439
440 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
446 let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
447 fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
448 Ok(())
449 }
450
451 pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
457 let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
458 let model =
459 bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
460 Ok(model)
461 }
462
463 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
478 use crate::serialization::safetensors;
479 use std::collections::BTreeMap;
480
481 let coefficients = self
483 .coefficients
484 .as_ref()
485 .ok_or("Cannot save unfitted model. Call fit() first.")?;
486
487 let mut tensors = BTreeMap::new();
489
490 let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
492 let coef_shape = vec![coefficients.len()];
493 tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
494
495 let intercept_data = vec![self.intercept];
497 let intercept_shape = vec![1];
498 tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
499
500 let alpha_data = vec![self.alpha];
502 let alpha_shape = vec![1];
503 tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));
504
505 safetensors::save_safetensors(path, &tensors)?;
507 Ok(())
508 }
509
510 pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
519 use crate::serialization::safetensors;
520
521 let (metadata, raw_data) = safetensors::load_safetensors(path)?;
523
524 let coef_meta = metadata
526 .get("coefficients")
527 .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
528 let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
529
530 let intercept_meta = metadata
532 .get("intercept")
533 .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
534 let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
535
536 let alpha_meta = metadata
538 .get("alpha")
539 .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
540 let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;
541
542 if intercept_data.len() != 1 {
544 return Err(format!(
545 "Expected intercept tensor to have 1 element, got {}",
546 intercept_data.len()
547 ));
548 }
549
550 if alpha_data.len() != 1 {
551 return Err(format!(
552 "Expected alpha tensor to have 1 element, got {}",
553 alpha_data.len()
554 ));
555 }
556
557 Ok(Self {
559 alpha: alpha_data[0],
560 coefficients: Some(Vector::from_vec(coef_data)),
561 intercept: intercept_data[0],
562 fit_intercept: true, })
564 }
565}
566
567impl Estimator for Ridge {
568 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
576 let (n_samples, n_features) = x.shape();
577
578 if n_samples != y.len() {
579 return Err("Number of samples must match target length".into());
580 }
581
582 if n_samples == 0 {
583 return Err("Cannot fit with zero samples".into());
584 }
585
586 let x_design = if self.fit_intercept {
588 LinearRegression::add_intercept_column(x)
589 } else {
590 x.clone()
591 };
592
593 let n_params = if self.fit_intercept {
594 n_features + 1
595 } else {
596 n_features
597 };
598
599 let xt = x_design.transpose();
601 let mut xtx = xt.matmul(&x_design)?;
602
603 for i in 0..n_params {
606 if self.fit_intercept && i == 0 {
608 continue;
609 }
610 let current = xtx.get(i, i);
611 xtx.set(i, i, current + self.alpha);
612 }
613
614 let xty = xt.matvec(y)?;
616
617 let beta = xtx.cholesky_solve(&xty)?;
619
620 if self.fit_intercept {
622 self.intercept = beta[0];
623 self.coefficients = Some(beta.slice(1, n_features + 1));
624 } else {
625 self.intercept = 0.0;
626 self.coefficients = Some(beta);
627 }
628
629 Ok(())
630 }
631
632 fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
638 let coefficients = self
639 .coefficients
640 .as_ref()
641 .expect("Model not fitted. Call fit() first.");
642
643 let result = x
644 .matvec(coefficients)
645 .expect("Matrix dimensions don't match coefficients");
646
647 result.add_scalar(self.intercept)
648 }
649
650 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
652 let y_pred = self.predict(x);
653 r_squared(&y_pred, y)
654 }
655}
656
/// Linear regression with an L1 (lasso) penalty, fitted by cyclic coordinate
/// descent with soft-thresholding.
///
/// NOTE: field order is part of the `bincode` encoding produced by `save` and
/// read by `load`; reordering fields would break previously saved models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Lasso {
    /// L1 penalty strength; used as the soft-threshold in each coordinate step.
    alpha: f32,
    /// Fitted feature weights; `None` until the model is fitted or loaded.
    coefficients: Option<Vector<f32>>,
    /// Bias term; `fit` resets it to `0.0` when `fit_intercept` is false.
    intercept: f32,
    /// Whether `fit` centers the data and recovers an intercept afterwards.
    fit_intercept: bool,
    /// Maximum number of full coordinate-descent sweeps.
    max_iter: usize,
    /// `fit` stops early once the largest coefficient change in a sweep is
    /// below this value.
    tol: f32,
}
716
717impl Lasso {
718 #[must_use]
725 pub fn new(alpha: f32) -> Self {
726 Self {
727 alpha,
728 coefficients: None,
729 intercept: 0.0,
730 fit_intercept: true,
731 max_iter: 1000,
732 tol: 1e-4,
733 }
734 }
735
736 #[must_use]
738 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
739 self.fit_intercept = fit_intercept;
740 self
741 }
742
743 #[must_use]
745 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
746 self.max_iter = max_iter;
747 self
748 }
749
750 #[must_use]
752 pub fn with_tol(mut self, tol: f32) -> Self {
753 self.tol = tol;
754 self
755 }
756
757 #[must_use]
759 pub fn alpha(&self) -> f32 {
760 self.alpha
761 }
762
763 #[must_use]
769 pub fn coefficients(&self) -> &Vector<f32> {
770 self.coefficients
771 .as_ref()
772 .expect("Model not fitted. Call fit() first.")
773 }
774
775 #[must_use]
777 pub fn intercept(&self) -> f32 {
778 self.intercept
779 }
780
781 #[must_use]
783 pub fn is_fitted(&self) -> bool {
784 self.coefficients.is_some()
785 }
786
787 fn soft_threshold(x: f32, lambda: f32) -> f32 {
789 if x > lambda {
790 x - lambda
791 } else if x < -lambda {
792 x + lambda
793 } else {
794 0.0
795 }
796 }
797
798 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
804 let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
805 fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
806 Ok(())
807 }
808
809 pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
815 let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
816 let model =
817 bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
818 Ok(model)
819 }
820
821 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
836 use crate::serialization::safetensors;
837 use std::collections::BTreeMap;
838
839 let coefficients = self
841 .coefficients
842 .as_ref()
843 .ok_or("Cannot save unfitted model. Call fit() first.")?;
844
845 let mut tensors = BTreeMap::new();
847
848 let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
850 let coef_shape = vec![coefficients.len()];
851 tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
852
853 let intercept_data = vec![self.intercept];
855 let intercept_shape = vec![1];
856 tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
857
858 let alpha_data = vec![self.alpha];
860 let alpha_shape = vec![1];
861 tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));
862
863 let max_iter_data = vec![self.max_iter as f32];
865 let max_iter_shape = vec![1];
866 tensors.insert("max_iter".to_string(), (max_iter_data, max_iter_shape));
867
868 let tol_data = vec![self.tol];
870 let tol_shape = vec![1];
871 tensors.insert("tol".to_string(), (tol_data, tol_shape));
872
873 safetensors::save_safetensors(path, &tensors)?;
875 Ok(())
876 }
877
878 pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
887 use crate::serialization::safetensors;
888
889 let (metadata, raw_data) = safetensors::load_safetensors(path)?;
891
892 let coef_meta = metadata
894 .get("coefficients")
895 .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
896 let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
897
898 let intercept_meta = metadata
900 .get("intercept")
901 .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
902 let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
903
904 let alpha_meta = metadata
906 .get("alpha")
907 .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
908 let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;
909
910 let max_iter_meta = metadata
912 .get("max_iter")
913 .ok_or("Missing 'max_iter' tensor in SafeTensors file")?;
914 let max_iter_data = safetensors::extract_tensor(&raw_data, max_iter_meta)?;
915
916 let tol_meta = metadata
918 .get("tol")
919 .ok_or("Missing 'tol' tensor in SafeTensors file")?;
920 let tol_data = safetensors::extract_tensor(&raw_data, tol_meta)?;
921
922 if intercept_data.len() != 1 {
924 return Err(format!(
925 "Expected intercept tensor to have 1 element, got {}",
926 intercept_data.len()
927 ));
928 }
929
930 if alpha_data.len() != 1 {
931 return Err(format!(
932 "Expected alpha tensor to have 1 element, got {}",
933 alpha_data.len()
934 ));
935 }
936
937 if max_iter_data.len() != 1 {
938 return Err(format!(
939 "Expected max_iter tensor to have 1 element, got {}",
940 max_iter_data.len()
941 ));
942 }
943
944 if tol_data.len() != 1 {
945 return Err(format!(
946 "Expected tol tensor to have 1 element, got {}",
947 tol_data.len()
948 ));
949 }
950
951 Ok(Self {
953 alpha: alpha_data[0],
954 coefficients: Some(Vector::from_vec(coef_data)),
955 intercept: intercept_data[0],
956 fit_intercept: true, max_iter: max_iter_data[0] as usize,
958 tol: tol_data[0],
959 })
960 }
961}
962
963impl Estimator for Lasso {
964 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
970 let (n_samples, n_features) = x.shape();
971
972 if n_samples != y.len() {
973 return Err("Number of samples must match target length".into());
974 }
975
976 if n_samples == 0 {
977 return Err("Cannot fit with zero samples".into());
978 }
979
980 let (x_centered, y_centered, y_mean) = if self.fit_intercept {
982 let mut x_mean = vec![0.0; n_features];
984 let mut y_sum = 0.0;
985
986 for i in 0..n_samples {
987 for (j, mean_j) in x_mean.iter_mut().enumerate() {
988 *mean_j += x.get(i, j);
989 }
990 y_sum += y[i];
991 }
992
993 for mean in &mut x_mean {
994 *mean /= n_samples as f32;
995 }
996 let y_mean = y_sum / n_samples as f32;
997
998 let mut x_data = vec![0.0; n_samples * n_features];
1000 let mut y_data = vec![0.0; n_samples];
1001
1002 for i in 0..n_samples {
1003 for j in 0..n_features {
1004 x_data[i * n_features + j] = x.get(i, j) - x_mean[j];
1005 }
1006 y_data[i] = y[i] - y_mean;
1007 }
1008
1009 (
1010 Matrix::from_vec(n_samples, n_features, x_data)
1011 .expect("Valid matrix dimensions for property test"),
1012 Vector::from_vec(y_data),
1013 y_mean,
1014 )
1015 } else {
1016 (x.clone(), y.clone(), 0.0)
1017 };
1018
1019 let mut beta = vec![0.0; n_features];
1021
1022 let mut col_norms_sq = vec![0.0; n_features];
1024 for (j, norm_sq) in col_norms_sq.iter_mut().enumerate() {
1025 for i in 0..n_samples {
1026 let val = x_centered.get(i, j);
1027 *norm_sq += val * val;
1028 }
1029 }
1030
1031 for _ in 0..self.max_iter {
1033 let mut max_change = 0.0f32;
1034
1035 for j in 0..n_features {
1036 if col_norms_sq[j] < 1e-10 {
1037 continue; }
1039
1040 let mut rho = 0.0;
1042 for i in 0..n_samples {
1043 let mut pred = 0.0;
1044 for (k, &beta_k) in beta.iter().enumerate() {
1045 if k != j {
1046 pred += x_centered.get(i, k) * beta_k;
1047 }
1048 }
1049 let residual = y_centered[i] - pred;
1050 rho += x_centered.get(i, j) * residual;
1051 }
1052
1053 let old_beta = beta[j];
1055 beta[j] = Self::soft_threshold(rho, self.alpha) / col_norms_sq[j];
1056
1057 let change = (beta[j] - old_beta).abs();
1058 if change > max_change {
1059 max_change = change;
1060 }
1061 }
1062
1063 if max_change < self.tol {
1065 break;
1066 }
1067 }
1068
1069 if self.fit_intercept {
1071 let mut intercept = y_mean;
1072 let mut x_mean = vec![0.0; n_features];
1073 for j in 0..n_features {
1074 for i in 0..n_samples {
1075 x_mean[j] += x.get(i, j);
1076 }
1077 x_mean[j] /= n_samples as f32;
1078 intercept -= beta[j] * x_mean[j];
1079 }
1080 self.intercept = intercept;
1081 } else {
1082 self.intercept = 0.0;
1083 }
1084
1085 self.coefficients = Some(Vector::from_vec(beta));
1086 Ok(())
1087 }
1088
1089 fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
1095 let coefficients = self
1096 .coefficients
1097 .as_ref()
1098 .expect("Model not fitted. Call fit() first.");
1099
1100 let result = x
1101 .matvec(coefficients)
1102 .expect("Matrix dimensions don't match coefficients");
1103
1104 result.add_scalar(self.intercept)
1105 }
1106
1107 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
1109 let y_pred = self.predict(x);
1110 r_squared(&y_pred, y)
1111 }
1112}
1113
/// Linear regression with a mixed L1/L2 (elastic-net) penalty, fitted by
/// cyclic coordinate descent.
///
/// The L1 part is `alpha * l1_ratio` and the L2 part is
/// `alpha * (1 - l1_ratio)`.
///
/// NOTE: field order is part of the `bincode` encoding produced by `save` and
/// read by `load`; reordering fields would break previously saved models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElasticNet {
    /// Overall penalty strength, split between L1 and L2 by `l1_ratio`.
    alpha: f32,
    /// Fraction of `alpha` applied as L1; `new` clamps it to `[0.0, 1.0]`.
    l1_ratio: f32,
    /// Fitted feature weights; `None` until the model is fitted or loaded.
    coefficients: Option<Vector<f32>>,
    /// Bias term; `fit` resets it to `0.0` when `fit_intercept` is false.
    intercept: f32,
    /// Whether `fit` centers the data and recovers an intercept afterwards.
    fit_intercept: bool,
    /// Maximum number of full coordinate-descent sweeps.
    max_iter: usize,
    /// `fit` stops early once the largest coefficient change in a sweep is
    /// below this value.
    tol: f32,
}
1176
1177impl ElasticNet {
1178 #[must_use]
1185 pub fn new(alpha: f32, l1_ratio: f32) -> Self {
1186 Self {
1187 alpha,
1188 l1_ratio: l1_ratio.clamp(0.0, 1.0),
1189 coefficients: None,
1190 intercept: 0.0,
1191 fit_intercept: true,
1192 max_iter: 1000,
1193 tol: 1e-4,
1194 }
1195 }
1196
1197 #[must_use]
1199 pub fn with_intercept(mut self, fit_intercept: bool) -> Self {
1200 self.fit_intercept = fit_intercept;
1201 self
1202 }
1203
1204 #[must_use]
1206 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
1207 self.max_iter = max_iter;
1208 self
1209 }
1210
1211 #[must_use]
1213 pub fn with_tol(mut self, tol: f32) -> Self {
1214 self.tol = tol;
1215 self
1216 }
1217
1218 #[must_use]
1220 pub fn alpha(&self) -> f32 {
1221 self.alpha
1222 }
1223
1224 #[must_use]
1226 pub fn l1_ratio(&self) -> f32 {
1227 self.l1_ratio
1228 }
1229
1230 #[must_use]
1236 pub fn coefficients(&self) -> &Vector<f32> {
1237 self.coefficients
1238 .as_ref()
1239 .expect("Model not fitted. Call fit() first.")
1240 }
1241
1242 #[must_use]
1244 pub fn intercept(&self) -> f32 {
1245 self.intercept
1246 }
1247
1248 #[must_use]
1250 pub fn is_fitted(&self) -> bool {
1251 self.coefficients.is_some()
1252 }
1253
1254 pub fn save<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
1260 let bytes = bincode::serialize(self).map_err(|e| format!("Serialization failed: {e}"))?;
1261 fs::write(path, bytes).map_err(|e| format!("File write failed: {e}"))?;
1262 Ok(())
1263 }
1264
1265 pub fn load<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
1271 let bytes = fs::read(path).map_err(|e| format!("File read failed: {e}"))?;
1272 let model =
1273 bincode::deserialize(&bytes).map_err(|e| format!("Deserialization failed: {e}"))?;
1274 Ok(model)
1275 }
1276
1277 pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
1292 use crate::serialization::safetensors;
1293 use std::collections::BTreeMap;
1294
1295 let coefficients = self
1297 .coefficients
1298 .as_ref()
1299 .ok_or("Cannot save unfitted model. Call fit() first.")?;
1300
1301 let mut tensors = BTreeMap::new();
1303
1304 let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
1306 let coef_shape = vec![coefficients.len()];
1307 tensors.insert("coefficients".to_string(), (coef_data, coef_shape));
1308
1309 let intercept_data = vec![self.intercept];
1311 let intercept_shape = vec![1];
1312 tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));
1313
1314 let alpha_data = vec![self.alpha];
1316 let alpha_shape = vec![1];
1317 tensors.insert("alpha".to_string(), (alpha_data, alpha_shape));
1318
1319 let l1_ratio_data = vec![self.l1_ratio];
1321 let l1_ratio_shape = vec![1];
1322 tensors.insert("l1_ratio".to_string(), (l1_ratio_data, l1_ratio_shape));
1323
1324 let max_iter_data = vec![self.max_iter as f32];
1326 let max_iter_shape = vec![1];
1327 tensors.insert("max_iter".to_string(), (max_iter_data, max_iter_shape));
1328
1329 let tol_data = vec![self.tol];
1331 let tol_shape = vec![1];
1332 tensors.insert("tol".to_string(), (tol_data, tol_shape));
1333
1334 safetensors::save_safetensors(path, &tensors)?;
1336 Ok(())
1337 }
1338
1339 pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
1348 use crate::serialization::safetensors;
1349
1350 let (metadata, raw_data) = safetensors::load_safetensors(path)?;
1352
1353 let coef_meta = metadata
1355 .get("coefficients")
1356 .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
1357 let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;
1358
1359 let intercept_meta = metadata
1361 .get("intercept")
1362 .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
1363 let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;
1364
1365 let alpha_meta = metadata
1367 .get("alpha")
1368 .ok_or("Missing 'alpha' tensor in SafeTensors file")?;
1369 let alpha_data = safetensors::extract_tensor(&raw_data, alpha_meta)?;
1370
1371 let l1_ratio_meta = metadata
1373 .get("l1_ratio")
1374 .ok_or("Missing 'l1_ratio' tensor in SafeTensors file")?;
1375 let l1_ratio_data = safetensors::extract_tensor(&raw_data, l1_ratio_meta)?;
1376
1377 let max_iter_meta = metadata
1379 .get("max_iter")
1380 .ok_or("Missing 'max_iter' tensor in SafeTensors file")?;
1381 let max_iter_data = safetensors::extract_tensor(&raw_data, max_iter_meta)?;
1382
1383 let tol_meta = metadata
1385 .get("tol")
1386 .ok_or("Missing 'tol' tensor in SafeTensors file")?;
1387 let tol_data = safetensors::extract_tensor(&raw_data, tol_meta)?;
1388
1389 if intercept_data.len() != 1 {
1391 return Err(format!(
1392 "Expected intercept tensor to have 1 element, got {}",
1393 intercept_data.len()
1394 ));
1395 }
1396
1397 if alpha_data.len() != 1 {
1398 return Err(format!(
1399 "Expected alpha tensor to have 1 element, got {}",
1400 alpha_data.len()
1401 ));
1402 }
1403
1404 if l1_ratio_data.len() != 1 {
1405 return Err(format!(
1406 "Expected l1_ratio tensor to have 1 element, got {}",
1407 l1_ratio_data.len()
1408 ));
1409 }
1410
1411 if max_iter_data.len() != 1 {
1412 return Err(format!(
1413 "Expected max_iter tensor to have 1 element, got {}",
1414 max_iter_data.len()
1415 ));
1416 }
1417
1418 if tol_data.len() != 1 {
1419 return Err(format!(
1420 "Expected tol tensor to have 1 element, got {}",
1421 tol_data.len()
1422 ));
1423 }
1424
1425 Ok(Self {
1427 alpha: alpha_data[0],
1428 l1_ratio: l1_ratio_data[0],
1429 coefficients: Some(Vector::from_vec(coef_data)),
1430 intercept: intercept_data[0],
1431 fit_intercept: true, max_iter: max_iter_data[0] as usize,
1433 tol: tol_data[0],
1434 })
1435 }
1436}
1437
1438impl Estimator for ElasticNet {
1439 fn fit(&mut self, x: &Matrix<f32>, y: &Vector<f32>) -> Result<()> {
1445 let (n_samples, n_features) = x.shape();
1446
1447 if n_samples != y.len() {
1448 return Err("Number of samples must match target length".into());
1449 }
1450
1451 if n_samples == 0 {
1452 return Err("Cannot fit with zero samples".into());
1453 }
1454
1455 let (x_centered, y_centered, y_mean) = if self.fit_intercept {
1457 let mut x_mean = vec![0.0; n_features];
1458 let mut y_sum = 0.0;
1459
1460 for i in 0..n_samples {
1461 for (j, mean_j) in x_mean.iter_mut().enumerate() {
1462 *mean_j += x.get(i, j);
1463 }
1464 y_sum += y[i];
1465 }
1466
1467 for mean in &mut x_mean {
1468 *mean /= n_samples as f32;
1469 }
1470 let y_mean = y_sum / n_samples as f32;
1471
1472 let mut x_data = vec![0.0; n_samples * n_features];
1473 let mut y_data = vec![0.0; n_samples];
1474
1475 for i in 0..n_samples {
1476 for j in 0..n_features {
1477 x_data[i * n_features + j] = x.get(i, j) - x_mean[j];
1478 }
1479 y_data[i] = y[i] - y_mean;
1480 }
1481
1482 (
1483 Matrix::from_vec(n_samples, n_features, x_data)
1484 .expect("Valid matrix dimensions for property test"),
1485 Vector::from_vec(y_data),
1486 y_mean,
1487 )
1488 } else {
1489 (x.clone(), y.clone(), 0.0)
1490 };
1491
1492 let mut beta = vec![0.0; n_features];
1494
1495 let mut col_norms_sq = vec![0.0; n_features];
1497 for (j, norm_sq) in col_norms_sq.iter_mut().enumerate() {
1498 for i in 0..n_samples {
1499 let val = x_centered.get(i, j);
1500 *norm_sq += val * val;
1501 }
1502 }
1503
1504 let l1_penalty = self.alpha * self.l1_ratio;
1506 let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
1507
1508 for _ in 0..self.max_iter {
1510 let mut max_change = 0.0f32;
1511
1512 for j in 0..n_features {
1513 if col_norms_sq[j] < 1e-10 {
1514 continue;
1515 }
1516
1517 let mut rho = 0.0;
1519 for i in 0..n_samples {
1520 let mut pred = 0.0;
1521 for (k, &beta_k) in beta.iter().enumerate() {
1522 if k != j {
1523 pred += x_centered.get(i, k) * beta_k;
1524 }
1525 }
1526 let residual = y_centered[i] - pred;
1527 rho += x_centered.get(i, j) * residual;
1528 }
1529
1530 let old_beta = beta[j];
1532 let denom = col_norms_sq[j] + l2_penalty;
1533 beta[j] = Lasso::soft_threshold(rho, l1_penalty) / denom;
1534
1535 let change = (beta[j] - old_beta).abs();
1536 if change > max_change {
1537 max_change = change;
1538 }
1539 }
1540
1541 if max_change < self.tol {
1542 break;
1543 }
1544 }
1545
1546 if self.fit_intercept {
1548 let mut intercept = y_mean;
1549 let mut x_mean = vec![0.0; n_features];
1550 for j in 0..n_features {
1551 for i in 0..n_samples {
1552 x_mean[j] += x.get(i, j);
1553 }
1554 x_mean[j] /= n_samples as f32;
1555 intercept -= beta[j] * x_mean[j];
1556 }
1557 self.intercept = intercept;
1558 } else {
1559 self.intercept = 0.0;
1560 }
1561
1562 self.coefficients = Some(Vector::from_vec(beta));
1563 Ok(())
1564 }
1565
1566 fn predict(&self, x: &Matrix<f32>) -> Vector<f32> {
1572 let coefficients = self
1573 .coefficients
1574 .as_ref()
1575 .expect("Model not fitted. Call fit() first.");
1576
1577 let result = x
1578 .matvec(coefficients)
1579 .expect("Matrix dimensions don't match coefficients");
1580
1581 result.add_scalar(self.intercept)
1582 }
1583
1584 fn score(&self, x: &Matrix<f32>, y: &Vector<f32>) -> f32 {
1586 let y_pred = self.predict(x);
1587 r_squared(&y_pred, y)
1588 }
1589}
1590
1591#[cfg(test)]
1592mod tests {
1593 use super::*;
1594
1595 #[test]
1596 fn test_new() {
1597 let model = LinearRegression::new();
1598 assert!(!model.is_fitted());
1599 assert!(model.fit_intercept);
1600 }
1601
1602 #[test]
1603 fn test_simple_regression() {
1604 let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
1606 .expect("Valid matrix dimensions for test");
1607 let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
1608
1609 let mut model = LinearRegression::new();
1610 model
1611 .fit(&x, &y)
1612 .expect("Fit should succeed with valid test data");
1613
1614 assert!(model.is_fitted());
1615
1616 let coef = model.coefficients();
1618 assert!((coef[0] - 2.0).abs() < 1e-4);
1619 assert!((model.intercept() - 1.0).abs() < 1e-4);
1620
1621 let predictions = model.predict(&x);
1623 for i in 0..4 {
1624 assert!((predictions[i] - y[i]).abs() < 1e-4);
1625 }
1626
1627 let r2 = model.score(&x, &y);
1629 assert!((r2 - 1.0).abs() < 1e-4);
1630 }
1631
1632 #[test]
1633 fn test_multivariate_regression() {
1634 let x = Matrix::from_vec(4, 2, vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0])
1636 .expect("Valid matrix dimensions for test");
1637 let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0]);
1638
1639 let mut model = LinearRegression::new();
1640 model
1641 .fit(&x, &y)
1642 .expect("Fit should succeed with valid test data");
1643
1644 let coef = model.coefficients();
1645 assert!((coef[0] - 2.0).abs() < 1e-4);
1646 assert!((coef[1] - 3.0).abs() < 1e-4);
1647 assert!((model.intercept() - 1.0).abs() < 1e-4);
1648
1649 let r2 = model.score(&x, &y);
1650 assert!((r2 - 1.0).abs() < 1e-4);
1651 }
1652
1653 #[test]
1654 fn test_no_intercept() {
1655 let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
1657 .expect("Valid matrix dimensions for test");
1658 let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
1659
1660 let mut model = LinearRegression::new().with_intercept(false);
1661 model
1662 .fit(&x, &y)
1663 .expect("Fit should succeed with valid test data");
1664
1665 let coef = model.coefficients();
1666 assert!((coef[0] - 2.0).abs() < 1e-4);
1667 assert!((model.intercept() - 0.0).abs() < 1e-4);
1668 }
1669
1670 #[test]
1671 fn test_predict_new_data() {
1672 let x_train =
1674 Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
1675 let y_train = Vector::from_slice(&[2.0, 3.0, 4.0]);
1676
1677 let mut model = LinearRegression::new();
1678 model
1679 .fit(&x_train, &y_train)
1680 .expect("Fit should succeed with valid test data");
1681
1682 let x_test =
1683 Matrix::from_vec(2, 1, vec![4.0, 5.0]).expect("Valid matrix dimensions for test");
1684 let predictions = model.predict(&x_test);
1685
1686 assert!((predictions[0] - 5.0).abs() < 1e-4);
1687 assert!((predictions[1] - 6.0).abs() < 1e-4);
1688 }
1689
1690 #[test]
1691 fn test_dimension_mismatch_error() {
1692 let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
1693 let y = Vector::from_slice(&[1.0, 2.0]); let mut model = LinearRegression::new();
1696 let result = model.fit(&x, &y);
1697 assert!(result.is_err());
1698 }
1699
1700 #[test]
1701 fn test_empty_data_error() {
1702 let x = Matrix::from_vec(0, 2, vec![]).expect("Valid matrix dimensions for test");
1703 let y = Vector::from_vec(vec![]);
1704
1705 let mut model = LinearRegression::new();
1706 let result = model.fit(&x, &y);
1707 assert!(result.is_err());
1708 }
1709
/// Noisy samples scattered around `y = 2x + 1` should give roughly those parameters
/// and an R² that is high but strictly below 1.
#[test]
fn test_with_noise() {
    let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.1, 4.9, 7.2, 8.8, 11.1]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope_err = (lr.coefficients()[0] - 2.0).abs();
    let intercept_err = (lr.intercept() - 1.0).abs();
    assert!(slope_err < 0.2);
    assert!(intercept_err < 0.5);

    // Noise means the fit is good but never perfect.
    let r2 = lr.score(&x, &y);
    assert!(r2 > 0.95 && r2 < 1.0);
}
1732
/// `Default` must yield an unfitted model.
#[test]
fn test_default() {
    let fresh: LinearRegression = Default::default();
    assert!(!fresh.is_fitted());
}
1738
/// Cloning a fitted model keeps it fitted and preserves the intercept.
#[test]
fn test_clone() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut original = LinearRegression::new();
    original
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let copy = original.clone();
    let drift = (copy.intercept() - original.intercept()).abs();
    assert!(copy.is_fitted());
    assert!(drift < 1e-6);
}
1754
/// R² of a fit on the exact line `y = 2x + 1` must not exceed 1.
#[test]
fn test_score_range() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(lr.score(&x, &y) <= 1.0);
}
1770
/// On noise-free data generated by `y = 1 + 2·x1 + 3·x2`, every in-sample prediction
/// must match its target to within 1e-3.
#[test]
fn test_prediction_invariant() {
    let samples = vec![1.0, 1.0, 2.0, 3.0, 3.0, 2.0, 4.0, 5.0, 5.0, 4.0];
    let x = Matrix::from_vec(5, 2, samples).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[6.0, 14.0, 13.0, 24.0, 23.0]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let fitted = lr.predict(&x);
    assert!((0..y.len()).all(|i| (fitted[i] - y[i]).abs() < 1e-3));
}
1791
/// One coefficient is learned per input column; the intercept is not counted.
#[test]
fn test_coefficients_length_invariant() {
    let n_features = 3;
    let design = vec![
        1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0,
    ];
    let x = Matrix::from_vec(6, n_features, design).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0, 3.0, 3.0, 5.0, 4.0]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert_eq!(lr.coefficients().len(), n_features);
}
1814
/// Recovers `y = 2x + 3` from 100 noise-free samples at `x = 0, 1, ..., 99`.
#[test]
fn test_larger_dataset() {
    let n: usize = 100;
    let xs: Vec<f32> = (0..n).map(|i| i as f32).collect();
    let ys: Vec<f32> = xs.iter().map(|&v| 2.0 * v + 3.0).collect();

    let x = Matrix::from_vec(n, 1, xs).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(ys);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope = lr.coefficients()[0];
    assert!((slope - 2.0).abs() < 1e-3);
    assert!((lr.intercept() - 3.0).abs() < 1e-3);
}
1840
/// The two points `(1, 3)` and `(2, 5)` determine the line `y = 2x + 1` exactly.
///
/// Despite the test's name, it uses two samples: the smallest dataset that pins
/// down both the slope and the intercept.
#[test]
fn test_single_sample_single_feature() {
    let y = Vector::from_slice(&[3.0, 5.0]);
    let x = Matrix::from_vec(2, 1, vec![1.0, 2.0]).expect("Valid matrix dimensions for test");

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let (slope, bias) = (lr.coefficients()[0], lr.intercept());
    assert!((slope - 2.0).abs() < 1e-4);
    assert!((bias - 1.0).abs() < 1e-4);
}
1857
/// Negative inputs and a negative slope: the data lie on `y = -2x + 1`.
#[test]
fn test_negative_values() {
    let x = Matrix::from_vec(4, 1, vec![-2.0, -1.0, 0.0, 1.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[5.0, 3.0, 1.0, -1.0]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope = lr.coefficients()[0];
    assert!((slope + 2.0).abs() < 1e-4);
    assert!((lr.intercept() - 1.0).abs() < 1e-4);
}
1874
/// Inputs in the thousands on `y = 2x + 1`.
///
/// The slope is checked to 1e-2. The intercept bound (within 10) is deliberately
/// loose, presumably to absorb f32 round-off at this magnitude.
#[test]
fn test_large_values() {
    let x = Matrix::from_vec(3, 1, vec![1000.0, 2000.0, 3000.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2001.0, 4001.0, 6001.0]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope_err = (lr.coefficients()[0] - 2.0).abs();
    let intercept_err = (lr.intercept() - 1.0).abs();
    assert!(slope_err < 1e-2);
    assert!(intercept_err < 10.0);
}
1891
/// Inputs on the order of 1e-3 (`y = 2x + 0.001`). Only the slope is checked.
#[test]
fn test_small_values() {
    let y = Vector::from_slice(&[0.003, 0.005, 0.007]);
    let x = Matrix::from_vec(3, 1, vec![0.001, 0.002, 0.003])
        .expect("Valid matrix dimensions for test");

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope = lr.coefficients()[0];
    assert!((slope - 2.0).abs() < 1e-2);
}
1907
/// Data on `y = 2x` fitted with the intercept enabled: the learned intercept must
/// come out as (numerically) zero.
#[test]
fn test_zero_intercept_data() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope = lr.coefficients()[0];
    assert!((slope - 2.0).abs() < 1e-4);
    assert!(lr.intercept().abs() < 1e-4);
}
1924
/// A constant target (`y = 5`) must produce a zero slope and an intercept of 5.
#[test]
fn test_constant_target() {
    let y = Vector::from_slice(&[5.0, 5.0, 5.0]);
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let flat_slope = lr.coefficients()[0];
    assert!(flat_slope.abs() < 1e-4);
    assert!((lr.intercept() - 5.0).abs() < 1e-4);
}
1942
/// On slightly noisy linear data, R² must be positive and at most 1.
#[test]
fn test_r2_score_bounds() {
    let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.1, 3.9, 6.1, 7.9, 10.1]);

    let mut lr = LinearRegression::new();
    lr.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let r2 = lr.score(&x, &y);
    assert!(r2 > 0.0 && r2 <= 1.0);
}
1959
/// Trained on `y = 2x` over `x ∈ {1, 2, 3}`, the model must extrapolate to
/// `f(10) = 20`.
#[test]
fn test_extrapolation() {
    let mut lr = LinearRegression::new();
    {
        let x_train = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0])
            .expect("Valid matrix dimensions for test");
        let y_train = Vector::from_slice(&[2.0, 4.0, 6.0]);
        lr.fit(&x_train, &y_train)
            .expect("Fit should succeed with valid test data");
    }

    let far_point = Matrix::from_vec(1, 1, vec![10.0]).expect("Valid matrix dimensions for test");
    let guess = lr.predict(&far_point)[0];
    assert!((guess - 20.0).abs() < 1e-4);
}
1978
/// Three samples cannot determine five coefficients plus an intercept.
///
/// `fit` must fail, and the error text must mention samples or features.
#[test]
fn test_underdetermined_system_with_intercept() {
    let wide = vec![1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let x = Matrix::from_vec(3, 5, wide).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);

    let mut lr = LinearRegression::new();
    let error_str = lr
        .fit(&x, &y)
        .expect_err("Should fail when underdetermined system with intercept")
        .to_string();

    assert!(
        error_str.contains("samples") || error_str.contains("features"),
        "Error message should mention samples or features: {error_str}"
    );
}
2005
/// Three samples cannot determine five coefficients even without an intercept.
///
/// `fit` must fail, and the error text must mention samples or features.
#[test]
fn test_underdetermined_system_without_intercept() {
    let wide = vec![1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let x = Matrix::from_vec(3, 5, wide).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);

    let mut lr = LinearRegression::new().with_intercept(false);
    let error_str = lr
        .fit(&x, &y)
        .expect_err("Should fail when underdetermined system without intercept")
        .to_string();

    assert!(
        error_str.contains("samples") || error_str.contains("features"),
        "Error message should be helpful: {error_str}"
    );
}
2032
/// Four samples against three features plus an intercept is exactly as many
/// equations as unknowns, and `fit` must accept it.
#[test]
fn test_exactly_determined_system() {
    let rows = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0];
    let x = Matrix::from_vec(4, 3, rows).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(vec![1.0, 2.0, 3.0, 6.0]);

    let mut lr = LinearRegression::new();
    assert!(
        lr.fit(&x, &y).is_ok(),
        "Exactly determined system should work"
    );
}
2051
/// Round-trips a fitted `LinearRegression` through `save`/`load`.
///
/// The restored model must have identical coefficients, intercept and predictions.
#[test]
fn test_save_load_binary() {
    use std::fs;

    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let mut model = LinearRegression::new();
    model
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    // Use the platform temp dir instead of a hard-coded `/tmp` (which does not exist
    // on Windows). A process-unique name stops concurrent test runs from overwriting
    // each other's file.
    let path = std::env::temp_dir().join(format!(
        "test_linear_regression_{}.bin",
        std::process::id()
    ));
    model.save(&path).expect("Failed to save model");
    let loaded = LinearRegression::load(&path);
    // Remove the file before any assertion so a failing check does not leak it.
    fs::remove_file(&path).ok();
    let loaded_model = loaded.expect("Failed to load model");

    let original_pred = model.predict(&x);
    let loaded_pred = loaded_model.predict(&x);
    for i in 0..original_pred.len() {
        assert!(
            (original_pred[i] - loaded_pred[i]).abs() < 1e-6,
            "Loaded model predictions don't match original"
        );
    }

    assert_eq!(
        model.coefficients().len(),
        loaded_model.coefficients().len()
    );
    for i in 0..model.coefficients().len() {
        assert!((model.coefficients()[i] - loaded_model.coefficients()[i]).abs() < 1e-6);
    }
    assert!((model.intercept() - loaded_model.intercept()).abs() < 1e-6);
}
2098
/// `with_intercept(false)` returns the configured model.
///
/// Once fitted, that model passes through the origin, so its prediction at `x = 0`
/// is 0.
#[test]
fn test_with_intercept_returns_self() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut through_origin = LinearRegression::new().with_intercept(false);
    through_origin
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let origin = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
    let at_zero = through_origin.predict(&origin)[0];
    assert!(
        at_zero.abs() < 1e-6,
        "Model without intercept should predict 0 at x=0, got {}",
        at_zero
    );
}
2130
/// Fits `y = 2x + 1` with and without an intercept.
///
/// The first model must learn a clearly non-zero intercept; the second must keep
/// its intercept at (numerically) zero.
#[test]
fn test_with_intercept_builder_chain() {
    let x = Matrix::from_vec(2, 1, vec![1.0, 2.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0]);

    let fit_with = |flag: bool| {
        let mut lr = LinearRegression::new().with_intercept(flag);
        lr.fit(&x, &y)
            .expect("Fit should succeed with valid test data");
        lr
    };
    let with_int = fit_with(true);
    let without_int = fit_with(false);

    assert!(
        with_int.intercept().abs() > 0.1,
        "Model with intercept should have non-zero intercept"
    );
    assert!(
        without_int.intercept().abs() < 1e-6,
        "Model without intercept should have zero intercept, got {}",
        without_int.intercept()
    );
}
2163
/// A newly constructed `Ridge` is unfitted and reports the `alpha` it was given.
#[test]
fn test_ridge_new() {
    let alpha = 1.0;
    let ridge = Ridge::new(alpha);
    assert!(!ridge.is_fitted());
    assert!((ridge.alpha() - alpha).abs() < 1e-6);
}
2171
/// With `alpha = 0`, Ridge must fit the exact line `y = 2x + 1` almost perfectly
/// (R² > 0.99).
#[test]
fn test_ridge_simple_regression() {
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut reg = Ridge::new(0.0);
    reg.fit(&x, &y)
        .expect("Fit should succeed with valid test data");
    assert!(reg.is_fitted());

    let r2 = reg.score(&x, &y);
    assert!(r2 > 0.99);
}
2190
/// Two identical columns with `y = 4·x`.
///
/// Raising `alpha` from 0.01 to 100 must shrink the squared L2 norm of the
/// coefficient vector.
#[test]
fn test_ridge_regularization_shrinks_coefficients() {
    let x = Matrix::from_vec(5, 2, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[4.0, 8.0, 12.0, 16.0, 20.0]);

    let fit_ridge = |alpha| {
        let mut reg = Ridge::new(alpha);
        reg.fit(&x, &y)
            .expect("Fit should succeed with valid test data");
        reg
    };
    let low_reg = fit_ridge(0.01);
    let high_reg = fit_ridge(100.0);

    let (low_coef, high_coef) = (low_reg.coefficients(), high_reg.coefficients());
    let low_norm: f32 = (0..low_coef.len()).map(|i| low_coef[i] * low_coef[i]).sum();
    let high_norm: f32 = (0..high_coef.len()).map(|i| high_coef[i] * high_coef[i]).sum();

    assert!(
        high_norm < low_norm,
        "High regularization should shrink coefficients: {high_norm} < {low_norm}"
    );
}
2223
/// Light regularization on data from `y = 1 + 2·x1 + 3·x2` must still give R² > 0.95.
#[test]
fn test_ridge_multivariate() {
    let samples = vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0];
    let x = Matrix::from_vec(5, 2, samples).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0, 16.0]);

    let mut reg = Ridge::new(0.1);
    reg.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(reg.score(&x, &y) > 0.95);
}
2239
/// With the intercept disabled, a fitted Ridge model must report an intercept of 0.
#[test]
fn test_ridge_no_intercept() {
    let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut reg = Ridge::new(0.1).with_intercept(false);
    reg.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(reg.intercept().abs() < 1e-6);
}
2254
/// Ridge `fit` must reject a target vector whose length differs from the row count.
#[test]
fn test_ridge_dimension_mismatch_error() {
    let mut reg = Ridge::new(1.0);

    // Three samples but only two targets.
    let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0]);

    assert!(reg.fit(&x, &y).is_err());
}
2264
/// Ridge `fit` must refuse a dataset that contains no samples.
#[test]
fn test_ridge_empty_data_error() {
    let mut reg = Ridge::new(1.0);
    let x = Matrix::from_vec(0, 2, Vec::new()).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(Vec::new());

    assert!(reg.fit(&x, &y).is_err());
}
2274
/// Unlike plain OLS here, Ridge must accept a system with more features (5) than
/// samples (3).
#[test]
fn test_ridge_underdetermined_system() {
    let wide = vec![1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let x = Matrix::from_vec(3, 5, wide).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(vec![10.0, 20.0, 30.0]);

    let mut reg = Ridge::new(10.0);
    assert!(
        reg.fit(&x, &y).is_ok(),
        "Ridge should handle underdetermined systems"
    );
}
2297
/// Cloning a fitted Ridge model keeps it fitted and preserves `alpha` and the
/// intercept.
#[test]
fn test_ridge_clone() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut original = Ridge::new(0.5);
    original
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let copy = original.clone();
    let alpha_drift = (copy.alpha() - original.alpha()).abs();
    let intercept_drift = (copy.intercept() - original.intercept()).abs();
    assert!(copy.is_fitted());
    assert!(alpha_drift < 1e-6);
    assert!(intercept_drift < 1e-6);
}
2314
/// With `alpha = 0`, Ridge has no penalty and must agree with ordinary least
/// squares on both the slope and the intercept.
#[test]
fn test_ridge_alpha_zero_equals_ols() {
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut ols = LinearRegression::new();
    ols.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let mut ridge = Ridge::new(0.0);
    ridge
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let slope_gap = (ridge.coefficients()[0] - ols.coefficients()[0]).abs();
    assert!(slope_gap < 1e-4, "Ridge with alpha=0 should equal OLS");
    assert!((ridge.intercept() - ols.intercept()).abs() < 1e-4);
}
2338
/// Round-trips a fitted `Ridge` model through `save`/`load`.
///
/// The restored model must keep its `alpha` and reproduce the original predictions.
#[test]
fn test_ridge_save_load() {
    use std::fs;

    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut model = Ridge::new(0.5);
    model
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    // Use the platform temp dir instead of a hard-coded `/tmp` (which does not exist
    // on Windows). A process-unique name stops concurrent test runs from overwriting
    // each other's file.
    let path = std::env::temp_dir().join(format!("test_ridge_{}.bin", std::process::id()));
    model.save(&path).expect("Failed to save model");
    let loaded = Ridge::load(&path);
    // Remove the file before any assertion so a failing check does not leak it.
    fs::remove_file(&path).ok();
    let loaded = loaded.expect("Failed to load model");

    assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
    let original_pred = model.predict(&x);
    let loaded_pred = loaded.predict(&x);

    for i in 0..original_pred.len() {
        assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
    }
}
2369
/// A Ridge model built with `with_intercept(false)` must predict 0 at `x = 0`.
#[test]
fn test_ridge_with_intercept_builder() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut through_origin = Ridge::new(1.0).with_intercept(false);
    through_origin
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let origin = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
    let at_zero = through_origin.predict(&origin)[0];
    assert!(
        at_zero.abs() < 1e-6,
        "Ridge without intercept should predict 0 at x=0"
    );
}
2392
/// Even on a degenerate all-ones design, Ridge learns one coefficient per column.
#[test]
fn test_ridge_coefficients_length() {
    let n_features = 3;
    let x = Matrix::from_vec(5, n_features, vec![1.0; 15])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);

    let mut reg = Ridge::new(1.0);
    reg.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert_eq!(reg.coefficients().len(), n_features);
}
2405
/// A newly constructed `Lasso` is unfitted and reports the `alpha` it was given.
#[test]
fn test_lasso_new() {
    let alpha = 1.0;
    let lasso = Lasso::new(alpha);
    assert!(!lasso.is_fitted());
    assert!((lasso.alpha() - alpha).abs() < 1e-6);
}
2413
/// With a small `alpha`, Lasso must fit the exact line `y = 2x + 1` with R² > 0.98.
#[test]
fn test_lasso_simple_regression() {
    let x = Matrix::from_vec(5, 1, vec![1.0, 2.0, 3.0, 4.0, 5.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0]);

    let mut lasso = Lasso::new(0.01);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");
    assert!(lasso.is_fitted());

    let r2 = lasso.score(&x, &y);
    assert!(r2 > 0.98, "R² should be > 0.98, got {r2}");
}
2431
/// Only the first column tracks the target, and the L1 penalty is strong
/// (`alpha = 1.0`).
///
/// At least one coefficient must therefore be driven to (near) zero.
#[test]
fn test_lasso_produces_sparsity() {
    let rows = vec![
        1.0, 0.1, 0.2, 2.0, 0.2, 0.1, 3.0, 0.1, 0.3, 4.0, 0.3, 0.1, 5.0, 0.2, 0.2, 6.0, 0.1, 0.1,
    ];
    let x = Matrix::from_vec(6, 3, rows).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

    let mut lasso = Lasso::new(1.0);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let coef = lasso.coefficients();
    let non_zero = (0..coef.len()).filter(|&i| coef[i].abs() > 1e-4).count();

    assert!(
        non_zero < coef.len(),
        "Lasso should produce sparse solution, got {} non-zero out of {}",
        non_zero,
        coef.len()
    );
}
2469
/// Light L1 penalty on data from `y = 1 + 2·x1 + 3·x2` must still give R² > 0.95.
#[test]
fn test_lasso_multivariate() {
    let samples = vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0];
    let x = Matrix::from_vec(6, 2, samples).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[6.0, 8.0, 9.0, 11.0, 16.0, 21.0]);

    let mut lasso = Lasso::new(0.01);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let r2 = lasso.score(&x, &y);
    assert!(r2 > 0.95, "R² should be > 0.95, got {r2}");
}
2489
/// With the intercept disabled, a fitted Lasso model must report an intercept of 0.
#[test]
fn test_lasso_no_intercept() {
    let y = Vector::from_slice(&[2.0, 4.0, 6.0, 8.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut lasso = Lasso::new(0.01).with_intercept(false);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(lasso.intercept().abs() < 1e-6);
}
2504
/// Lasso `fit` must reject a target vector whose length differs from the row count.
#[test]
fn test_lasso_dimension_mismatch_error() {
    let mut lasso = Lasso::new(1.0);

    // Three samples but only two targets.
    let x = Matrix::from_vec(3, 2, vec![1.0; 6]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0]);

    assert!(lasso.fit(&x, &y).is_err());
}
2514
/// Lasso `fit` must refuse a dataset that contains no samples.
#[test]
fn test_lasso_empty_data_error() {
    let mut lasso = Lasso::new(1.0);
    let x = Matrix::from_vec(0, 2, Vec::new()).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(Vec::new());

    assert!(lasso.fit(&x, &y).is_err());
}
2524
/// Cloning a fitted Lasso model keeps it fitted and preserves `alpha` and the
/// intercept.
#[test]
fn test_lasso_clone() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut original = Lasso::new(0.5);
    original
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let copy = original.clone();
    let alpha_drift = (copy.alpha() - original.alpha()).abs();
    let intercept_drift = (copy.intercept() - original.intercept()).abs();
    assert!(copy.is_fitted());
    assert!(alpha_drift < 1e-6);
    assert!(intercept_drift < 1e-6);
}
2541
/// Round-trips a fitted `Lasso` model through `save`/`load`.
///
/// The restored model must keep its `alpha` and reproduce the original predictions.
#[test]
fn test_lasso_save_load() {
    use std::fs;

    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut model = Lasso::new(0.1);
    model
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    // Use the platform temp dir instead of a hard-coded `/tmp` (which does not exist
    // on Windows). A process-unique name stops concurrent test runs from overwriting
    // each other's file.
    let path = std::env::temp_dir().join(format!("test_lasso_{}.bin", std::process::id()));
    model.save(&path).expect("Failed to save model");
    let loaded = Lasso::load(&path);
    // Remove the file before any assertion so a failing check does not leak it.
    fs::remove_file(&path).ok();
    let loaded = loaded.expect("Failed to load model");

    assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
    let original_pred = model.predict(&x);
    let loaded_pred = loaded.predict(&x);

    for i in 0..original_pred.len() {
        assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
    }
}
2571
/// A Lasso model built with `with_intercept(false)` must predict 0 at `x = 0`.
#[test]
fn test_lasso_with_intercept_builder() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut through_origin = Lasso::new(1.0).with_intercept(false);
    through_origin
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let origin = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
    let at_zero = through_origin.predict(&origin)[0];
    assert!(
        at_zero.abs() < 1e-6,
        "Lasso without intercept should predict 0 at x=0"
    );
}
2593
/// Even on a degenerate all-ones design, Lasso learns one coefficient per column.
#[test]
fn test_lasso_coefficients_length() {
    let n_features = 3;
    let x = Matrix::from_vec(5, n_features, vec![1.0; 15])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);

    let mut lasso = Lasso::new(0.1);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert_eq!(lasso.coefficients().len(), n_features);
}
2606
/// A Lasso model configured with `with_max_iter(100)` still fits successfully.
#[test]
fn test_lasso_with_max_iter() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut capped = Lasso::new(0.1).with_max_iter(100);
    capped
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(capped.is_fitted());
}
2620
/// A Lasso model configured with a tight tolerance (`1e-6`) still fits successfully.
#[test]
fn test_lasso_with_tol() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut precise = Lasso::new(0.1).with_tol(1e-6);
    precise
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(precise.is_fitted());
}
2634
/// Checks a few points of `Lasso::soft_threshold` with threshold 2.
///
/// Inputs outside `[-2, 2]` move toward zero by 2; inputs inside the band become 0.
#[test]
fn test_lasso_soft_threshold() {
    // (input, threshold, expected)
    let cases = [
        (5.0, 2.0, 3.0),
        (-5.0, 2.0, -3.0),
        (1.0, 2.0, 0.0),
        (-1.0, 2.0, 0.0),
    ];
    for &(input, threshold, expected) in &cases {
        assert!((Lasso::soft_threshold(input, threshold) - expected).abs() < 1e-6);
    }
}
2643
/// A new `ElasticNet` is unfitted and reports the `alpha` and `l1_ratio` it was
/// built with.
#[test]
fn test_elastic_net_new() {
    let (alpha, l1_ratio) = (1.0, 0.5);
    let net = ElasticNet::new(alpha, l1_ratio);
    assert!(!net.is_fitted());
    assert!((net.alpha() - alpha).abs() < 1e-6);
    assert!((net.l1_ratio() - l1_ratio).abs() < 1e-6);
}
2653
/// Light ElasticNet regularization on the exact line `y = 2x + 1`.
///
/// The fitted parameters must land close to slope 2 (within 0.5) and intercept 1
/// (within 1).
#[test]
fn test_elastic_net_simple() {
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut net = ElasticNet::new(0.01, 0.5);
    net.fit(&x, &y)
        .expect("Fit should succeed with valid test data");
    assert!(net.is_fitted());

    let slope_err = (net.coefficients()[0] - 2.0).abs();
    let intercept_err = (net.intercept() - 1.0).abs();
    assert!(slope_err < 0.5);
    assert!(intercept_err < 1.0);
}
2673
/// Two-feature data from `y = 2·x1 + 3·x2`: every in-sample prediction must be
/// within 1.0 of its target.
#[test]
fn test_elastic_net_multivariate() {
    let samples = vec![1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 2.0];
    let x = Matrix::from_vec(4, 2, samples).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[5.0, 7.0, 8.0, 10.0]);

    let mut net = ElasticNet::new(0.01, 0.5);
    net.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let fitted = net.predict(&x);
    assert!((0..4).all(|i| (fitted[i] - y[i]).abs() < 1.0));
}
2691
/// With `l1_ratio = 1.0`, `ElasticNet` must land within 0.1 of a `Lasso` that has
/// the same `alpha`.
#[test]
fn test_elastic_net_l1_ratio_pure_l1() {
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut lasso = Lasso::new(0.1);
    lasso
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let mut elastic = ElasticNet::new(0.1, 1.0);
    elastic
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let gap = (elastic.coefficients()[0] - lasso.coefficients()[0]).abs();
    assert!(gap < 0.1);
}
2714
/// With `l1_ratio = 0.0`, `ElasticNet` must land within a loose 0.5 of a `Ridge`
/// that has the same `alpha`.
#[test]
fn test_elastic_net_l1_ratio_pure_l2() {
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut ridge = Ridge::new(0.1);
    ridge
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let mut elastic = ElasticNet::new(0.1, 0.0);
    elastic
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let gap = (elastic.coefficients()[0] - ridge.coefficients()[0]).abs();
    assert!(gap < 0.5);
}
2737
/// ElasticNet `fit` must reject a target vector whose length differs from the row
/// count.
#[test]
fn test_elastic_net_dimension_mismatch() {
    let mut net = ElasticNet::new(0.1, 0.5);

    // Three samples but only two targets.
    let x = Matrix::from_vec(3, 2, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[1.0, 2.0]);

    assert!(net.fit(&x, &y).is_err());
}
2748
/// ElasticNet `fit` must refuse a dataset that contains no samples.
#[test]
fn test_elastic_net_empty_data() {
    let mut net = ElasticNet::new(0.1, 0.5);
    let x = Matrix::from_vec(0, 2, Vec::new()).expect("Valid matrix dimensions for test");
    let y = Vector::from_vec(Vec::new());

    assert!(net.fit(&x, &y).is_err());
}
2758
/// Calling `predict` before `fit` must panic with a "Model not fitted" message.
#[test]
#[should_panic(expected = "Model not fitted")]
fn test_elastic_net_predict_not_fitted() {
    let x = Matrix::from_vec(1, 1, vec![1.0]).expect("Valid matrix dimensions for test");
    let unfitted = ElasticNet::new(0.1, 0.5);
    let _ = unfitted.predict(&x);
}
2766
/// Light ElasticNet regularization on `y = 2x + 1` must still score R² > 0.9.
#[test]
fn test_elastic_net_score() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut net = ElasticNet::new(0.01, 0.5);
    net.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(net.score(&x, &y) > 0.9);
}
2781
/// Cloning a fitted ElasticNet keeps it fitted and preserves `alpha`, `l1_ratio`
/// and the intercept.
#[test]
fn test_elastic_net_clone() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut original = ElasticNet::new(0.5, 0.5);
    original
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let copy = original.clone();
    let alpha_drift = (copy.alpha() - original.alpha()).abs();
    let ratio_drift = (copy.l1_ratio() - original.l1_ratio()).abs();
    let intercept_drift = (copy.intercept() - original.intercept()).abs();
    assert!(copy.is_fitted());
    assert!(alpha_drift < 1e-6);
    assert!(ratio_drift < 1e-6);
    assert!(intercept_drift < 1e-6);
}
2799
/// Round-trips a fitted `ElasticNet` through `save`/`load`.
///
/// The restored model must keep `alpha` and `l1_ratio` and reproduce the original
/// predictions.
#[test]
fn test_elastic_net_save_load() {
    use std::fs;

    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);

    let mut model = ElasticNet::new(0.1, 0.5);
    model
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    // Use the platform temp dir instead of a hard-coded `/tmp` (which does not exist
    // on Windows). A process-unique name stops concurrent test runs from overwriting
    // each other's file.
    let path = std::env::temp_dir().join(format!(
        "test_elastic_net_{}.bin",
        std::process::id()
    ));
    model.save(&path).expect("Failed to save model");
    let loaded = ElasticNet::load(&path);
    // Remove the file before any assertion so a failing check does not leak it.
    fs::remove_file(&path).ok();
    let loaded = loaded.expect("Failed to load model");

    assert!((loaded.alpha() - model.alpha()).abs() < 1e-6);
    assert!((loaded.l1_ratio() - model.l1_ratio()).abs() < 1e-6);
    let original_pred = model.predict(&x);
    let loaded_pred = loaded.predict(&x);

    for i in 0..original_pred.len() {
        assert!((original_pred[i] - loaded_pred[i]).abs() < 1e-6);
    }
}
2830
/// An ElasticNet built with `with_intercept(false)` must predict 0 at `x = 0`.
#[test]
fn test_elastic_net_with_intercept_builder() {
    let x = Matrix::from_vec(3, 1, vec![1.0, 2.0, 3.0]).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[2.0, 4.0, 6.0]);

    let mut through_origin = ElasticNet::new(1.0, 0.5).with_intercept(false);
    through_origin
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    let origin = Matrix::from_vec(1, 1, vec![0.0]).expect("Valid matrix dimensions for test");
    let at_zero = through_origin.predict(&origin)[0];
    assert!(at_zero.abs() < 1e-6);
}
2848
/// On a 3×3 design, ElasticNet learns exactly one coefficient per column.
#[test]
fn test_elastic_net_multivariate_coefficients() {
    let n_features = 3;
    let grid = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let x = Matrix::from_vec(3, n_features, grid).expect("Valid matrix dimensions for test");
    let y = Vector::from_slice(&[6.0, 15.0, 24.0]);

    let mut net = ElasticNet::new(0.1, 0.5);
    net.fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert_eq!(net.coefficients().len(), n_features);
}
2862
/// An ElasticNet configured with `with_max_iter(100)` still fits successfully.
#[test]
fn test_elastic_net_with_max_iter() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut capped = ElasticNet::new(0.1, 0.5).with_max_iter(100);
    capped
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(capped.is_fitted());
}
2876
/// An ElasticNet configured with a tight tolerance (`1e-6`) still fits successfully.
#[test]
fn test_elastic_net_with_tol() {
    let y = Vector::from_slice(&[3.0, 5.0, 7.0, 9.0]);
    let x = Matrix::from_vec(4, 1, vec![1.0, 2.0, 3.0, 4.0])
        .expect("Valid matrix dimensions for test");

    let mut precise = ElasticNet::new(0.1, 0.5).with_tol(1e-6);
    precise
        .fit(&x, &y)
        .expect("Fit should succeed with valid test data");

    assert!(precise.is_fitted());
}
2890}