use crate::error::Result;
use crate::primitives::{Matrix, Vector};
use serde::{Deserialize, Serialize};
use std::path::Path;
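/// Binary logistic regression classifier trained with batch gradient descent.
///
/// Probabilities come from the logistic (sigmoid) function applied to a
/// linear combination of the features: `p = sigmoid(w . x + b)`.
///
/// A minimal usage sketch, mirroring the test cases at the bottom of this
/// file (`Matrix::from_vec` takes rows, columns, and row-major data):
///
/// ```ignore
/// let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])?;
/// let y = vec![0, 0, 1, 1];
/// let mut model = LogisticRegression::new().with_learning_rate(0.1);
/// model.fit(&x, &y)?;
/// let labels = model.predict(&x);
/// ```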
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogisticRegression {
    coefficients: Option<Vector<f32>>,
    intercept: f32,
    learning_rate: f32,
    max_iter: usize,
    tol: f32,
}

impl LogisticRegression {
    pub fn new() -> Self {
        Self {
            coefficients: None,
            intercept: 0.0,
            learning_rate: 0.01,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    pub fn with_learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn with_tolerance(mut self, tol: f32) -> Self {
        self.tol = tol;
        self
    }
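    /// The logistic function `sigmoid(z) = 1 / (1 + exp(-z))`, which maps any
    /// real input into the open interval (0, 1).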
    fn sigmoid(z: f32) -> f32 {
        1.0 / (1.0 + (-z).exp())
    }
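    /// Predicted probability of class 1 for each row of `x`, computed as
    /// `sigmoid(w . x + b)`.
    ///
    /// # Panics
    ///
    /// Panics if the model has not been fitted.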
    pub fn predict_proba(&self, x: &Matrix<f32>) -> Vector<f32> {
        let coef = self.coefficients.as_ref().expect("Model not fitted yet");
        let (n_samples, _) = x.shape();

        let mut probas = Vec::with_capacity(n_samples);
        for row in 0..n_samples {
            let mut z = self.intercept;
            for col in 0..coef.len() {
                z += coef[col] * x.get(row, col);
            }
            probas.push(Self::sigmoid(z));
        }

        Vector::from_vec(probas)
    }
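    /// Fits the model by batch gradient descent on the binary cross-entropy
    /// loss. Each iteration computes the averaged gradient
    /// `grad_w = (1/n) * sum_i (sigmoid(w . x_i + b) - y_i) * x_i` (and the
    /// analogous scalar for the intercept), then steps against it with step
    /// size `learning_rate`. Training stops early once every gradient
    /// component falls below `tol` in magnitude.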
    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
        let (n_samples, n_features) = x.shape();

        if n_samples != y.len() {
            return Err("Number of samples in X and y must match".into());
        }
        if n_samples == 0 {
            return Err("Cannot fit with zero samples".into());
        }

        for &label in y {
            if label != 0 && label != 1 {
                return Err("Labels must be 0 or 1 for binary classification".into());
            }
        }

        self.coefficients = Some(Vector::from_vec(vec![0.0; n_features]));
        self.intercept = 0.0;

        for _ in 0..self.max_iter {
            let probas = self.predict_proba(x);

            let mut coef_grad = vec![0.0; n_features];
            let mut intercept_grad = 0.0;

            for i in 0..n_samples {
                let error = probas[i] - y[i] as f32;
                intercept_grad += error;
                for (j, grad) in coef_grad.iter_mut().enumerate() {
                    *grad += error * x.get(i, j);
                }
            }

            let n = n_samples as f32;
            intercept_grad /= n;
            for grad in &mut coef_grad {
                *grad /= n;
            }

            self.intercept -= self.learning_rate * intercept_grad;
            if let Some(ref mut coef) = self.coefficients {
                for j in 0..n_features {
                    coef[j] -= self.learning_rate * coef_grad[j];
                }
            }

            if intercept_grad.abs() < self.tol && coef_grad.iter().all(|&g| g.abs() < self.tol) {
                break;
            }
        }

        Ok(())
    }

    pub fn predict(&self, x: &Matrix<f32>) -> Vec<usize> {
        let probas = self.predict_proba(x);
        probas
            .as_slice()
            .iter()
            .map(|&p| usize::from(p >= 0.5))
            .collect()
    }

    pub fn score(&self, x: &Matrix<f32>, y: &[usize]) -> f32 {
        let predictions = self.predict(x);
        let correct = predictions
            .iter()
            .zip(y.iter())
            .filter(|(pred, true_label)| pred == true_label)
            .count();
        correct as f32 / y.len() as f32
    }

    pub fn coefficients(&self) -> &Vector<f32> {
        self.coefficients.as_ref().expect("Model not fitted")
    }

    pub fn intercept(&self) -> f32 {
        self.intercept
    }
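    /// Serializes the fitted parameters to a SafeTensors file containing two
    /// tensors: `"coefficients"` with shape `[n_features]` and `"intercept"`
    /// with shape `[1]`. Hyperparameters (learning rate, iteration cap,
    /// tolerance) are not stored.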
    pub fn save_safetensors<P: AsRef<Path>>(&self, path: P) -> std::result::Result<(), String> {
        use crate::serialization::safetensors;
        use std::collections::BTreeMap;

        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or("Cannot save unfitted model. Call fit() first.")?;

        let mut tensors = BTreeMap::new();

        let coef_data: Vec<f32> = (0..coefficients.len()).map(|i| coefficients[i]).collect();
        let coef_shape = vec![coefficients.len()];
        tensors.insert("coefficients".to_string(), (coef_data, coef_shape));

        let intercept_data = vec![self.intercept];
        let intercept_shape = vec![1];
        tensors.insert("intercept".to_string(), (intercept_data, intercept_shape));

        safetensors::save_safetensors(path, &tensors)?;
        Ok(())
    }

    pub fn load_safetensors<P: AsRef<Path>>(path: P) -> std::result::Result<Self, String> {
        use crate::serialization::safetensors;

        let (metadata, raw_data) = safetensors::load_safetensors(path)?;

        let coef_meta = metadata
            .get("coefficients")
            .ok_or("Missing 'coefficients' tensor in SafeTensors file")?;
        let coef_data = safetensors::extract_tensor(&raw_data, coef_meta)?;

        let intercept_meta = metadata
            .get("intercept")
            .ok_or("Missing 'intercept' tensor in SafeTensors file")?;
        let intercept_data = safetensors::extract_tensor(&raw_data, intercept_meta)?;

        if intercept_data.len() != 1 {
            return Err(format!(
                "Invalid intercept tensor: expected 1 value, got {}",
                intercept_data.len()
            ));
        }

        Ok(Self {
            coefficients: Some(Vector::from_vec(coef_data)),
            intercept: intercept_data[0],
            // Hyperparameters are not serialized; restore the defaults.
            learning_rate: 0.01,
            max_iter: 1000,
            tol: 1e-4,
        })
    }
}

impl Default for LogisticRegression {
    fn default() -> Self {
        Self::new()
    }
}
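/// Distance metric used to compare feature vectors.
///
/// For two points `a` and `b` with per-feature difference `d_k = a_k - b_k`:
/// Euclidean is `sqrt(sum d_k^2)`, Manhattan is `sum |d_k|`, and
/// `Minkowski(p)` is `(sum |d_k|^p)^(1/p)`, which reduces to Manhattan at
/// `p = 1` and Euclidean at `p = 2`.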
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DistanceMetric {
    Euclidean,
    Manhattan,
    Minkowski(f32),
}
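/// k-nearest-neighbors classifier. A lazy learner: `fit` simply stores the
/// training data, and all distance computation happens at prediction time.
///
/// A minimal usage sketch, mirroring the test cases at the bottom of this
/// file:
///
/// ```ignore
/// let mut knn = KNearestNeighbors::new(3)
///     .with_metric(DistanceMetric::Manhattan)
///     .with_weights(true);
/// knn.fit(&x_train, &y_train)?;
/// let labels = knn.predict(&x_test)?;
/// ```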
#[derive(Debug, Clone)]
pub struct KNearestNeighbors {
    k: usize,
    metric: DistanceMetric,
    weights: bool,
    x_train: Option<Matrix<f32>>,
    y_train: Option<Vec<usize>>,
}

impl KNearestNeighbors {
    #[must_use]
    pub fn new(k: usize) -> Self {
        Self {
            k,
            metric: DistanceMetric::Euclidean,
            weights: false,
            x_train: None,
            y_train: None,
        }
    }

    #[must_use]
    pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
        self.metric = metric;
        self
    }

    #[must_use]
    pub fn with_weights(mut self, weights: bool) -> Self {
        self.weights = weights;
        self
    }

    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
        let (n_samples, _n_features) = x.shape();

        if n_samples == 0 {
            return Err("Cannot fit with zero samples".into());
        }

        if y.len() != n_samples {
            return Err("Number of samples in X and y must match".into());
        }

        if self.k > n_samples {
            return Err("k cannot be larger than number of training samples".into());
        }

        self.x_train = Some(x.clone());
        self.y_train = Some(y.to_vec());

        Ok(())
    }
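    /// Predicts a label for each row of `x`: compute the distance to every
    /// training sample, sort ascending, keep the `k` nearest, and resolve
    /// the label by majority vote (or inverse-distance weighted vote when
    /// `weights` is enabled).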
    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
        let x_train = self.x_train.as_ref().ok_or("Model not fitted")?;
        let y_train = self.y_train.as_ref().ok_or("Model not fitted")?;

        let (n_samples, n_features) = x.shape();
        let (_n_train, n_train_features) = x_train.shape();

        if n_features != n_train_features {
            return Err("Feature dimension mismatch".into());
        }

        let mut predictions = Vec::with_capacity(n_samples);

        for i in 0..n_samples {
            let mut distances: Vec<(f32, usize)> = Vec::with_capacity(y_train.len());

            for (j, &label) in y_train.iter().enumerate() {
                let dist = self.compute_distance(x, i, x_train, j, n_features);
                distances.push((dist, label));
            }

            distances.sort_by(|a, b| {
                a.0.partial_cmp(&b.0)
                    .expect("Distance values are valid f32 (not NaN)")
            });
            let k_nearest = &distances[..self.k];

            let predicted_class = if self.weights {
                self.weighted_vote(k_nearest)
            } else {
                self.majority_vote(k_nearest)
            };

            predictions.push(predicted_class);
        }

        Ok(predictions)
    }

    pub fn predict_proba(&self, x: &Matrix<f32>) -> Result<Vec<Vec<f32>>> {
        let x_train = self.x_train.as_ref().ok_or("Model not fitted")?;
        let y_train = self.y_train.as_ref().ok_or("Model not fitted")?;

        let (n_samples, n_features) = x.shape();
        let (_n_train, n_train_features) = x_train.shape();

        if n_features != n_train_features {
            return Err("Feature dimension mismatch".into());
        }

        let n_classes = *y_train
            .iter()
            .max()
            .expect("Training labels are non-empty (verified in fit())")
            + 1;

        let mut probabilities = Vec::with_capacity(n_samples);

        for i in 0..n_samples {
            let mut distances: Vec<(f32, usize)> = Vec::with_capacity(y_train.len());

            for (j, &label) in y_train.iter().enumerate() {
                let dist = self.compute_distance(x, i, x_train, j, n_features);
                distances.push((dist, label));
            }

            distances.sort_by(|a, b| {
                a.0.partial_cmp(&b.0)
                    .expect("Distance values are valid f32 (not NaN)")
            });
            let k_nearest = &distances[..self.k];

            let mut class_counts = vec![0.0; n_classes];

            if self.weights {
                for (dist, label) in k_nearest {
                    let weight = if *dist < 1e-10 { 1.0 } else { 1.0 / dist };
                    class_counts[*label] += weight;
                }
            } else {
                for (_dist, label) in k_nearest {
                    class_counts[*label] += 1.0;
                }
            }

            let total: f32 = class_counts.iter().sum();
            for count in &mut class_counts {
                *count /= total;
            }

            probabilities.push(class_counts);
        }

        Ok(probabilities)
    }

    fn compute_distance(
        &self,
        x1: &Matrix<f32>,
        i1: usize,
        x2: &Matrix<f32>,
        i2: usize,
        n_features: usize,
    ) -> f32 {
        match self.metric {
            DistanceMetric::Euclidean => {
                let mut sum = 0.0;
                for k in 0..n_features {
                    let diff = x1.get(i1, k) - x2.get(i2, k);
                    sum += diff * diff;
                }
                sum.sqrt()
            }
            DistanceMetric::Manhattan => {
                let mut sum = 0.0;
                for k in 0..n_features {
                    sum += (x1.get(i1, k) - x2.get(i2, k)).abs();
                }
                sum
            }
            DistanceMetric::Minkowski(p) => {
                let mut sum = 0.0;
                for k in 0..n_features {
                    sum += (x1.get(i1, k) - x2.get(i2, k)).abs().powf(p);
                }
                sum.powf(1.0 / p)
            }
        }
    }

    #[allow(clippy::unused_self)]
    fn majority_vote(&self, neighbors: &[(f32, usize)]) -> usize {
        let mut class_counts = std::collections::HashMap::new();

        for (_dist, label) in neighbors {
            *class_counts.entry(*label).or_insert(0) += 1;
        }

        *class_counts
            .iter()
            .max_by_key(|(_, count)| *count)
            .map(|(label, _)| label)
            .expect("Neighbors slice is non-empty (k >= 1)")
    }
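    /// Inverse-distance vote: each neighbor contributes weight `1 / d`, so
    /// closer neighbors count for more. Distances below `1e-10` are clamped
    /// to a weight of 1.0 to avoid dividing by (near) zero.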
    #[allow(clippy::unused_self)]
    fn weighted_vote(&self, neighbors: &[(f32, usize)]) -> usize {
        let mut class_weights = std::collections::HashMap::new();

        for (dist, label) in neighbors {
            let weight = if *dist < 1e-10 { 1.0 } else { 1.0 / dist };
            *class_weights.entry(*label).or_insert(0.0) += weight;
        }

        *class_weights
            .iter()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("Weights are valid f32 (not NaN)"))
            .map(|(label, _)| label)
            .expect("Neighbors slice is non-empty (k >= 1)")
    }
}
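/// Gaussian naive Bayes classifier. Assumes features are conditionally
/// independent given the class, with each feature following a per-class
/// normal distribution, so `p(class | x)` is proportional to
/// `p(class) * prod_k N(x_k; mean_{class,k}, var_{class,k})`.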
#[derive(Debug, Clone)]
pub struct GaussianNB {
    class_priors: Option<Vec<f32>>,
    means: Option<Vec<Vec<f32>>>,
    variances: Option<Vec<Vec<f32>>>,
    classes: Option<Vec<usize>>,
    var_smoothing: f32,
}

impl GaussianNB {
    pub fn new() -> Self {
        Self {
            class_priors: None,
            means: None,
            variances: None,
            classes: None,
            var_smoothing: 1e-9,
        }
    }

    pub fn with_var_smoothing(mut self, var_smoothing: f32) -> Self {
        self.var_smoothing = var_smoothing;
        self
    }
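    /// Estimates, for every class: its prior `n_class / n_samples`, and the
    /// per-feature mean and (biased, maximum-likelihood) variance of its
    /// samples. `var_smoothing` is added to every variance so the likelihood
    /// stays well-defined for constant features.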
    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
        let (n_samples, n_features) = x.shape();

        if n_samples == 0 {
            return Err("Cannot fit with empty data".into());
        }

        if y.len() != n_samples {
            return Err("Number of samples in X and y must match".into());
        }

        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err("Need at least 2 classes".into());
        }

        let n_classes = classes.len();

        let mut class_priors = vec![0.0; n_classes];
        let mut means = vec![vec![0.0; n_features]; n_classes];
        let mut variances = vec![vec![0.0; n_features]; n_classes];

        for (class_idx, &class_label) in classes.iter().enumerate() {
            let class_samples: Vec<usize> = y
                .iter()
                .enumerate()
                .filter_map(|(i, &label)| if label == class_label { Some(i) } else { None })
                .collect();

            let n_class_samples = class_samples.len() as f32;
            class_priors[class_idx] = n_class_samples / n_samples as f32;

            for (feature_idx, mean_val) in means[class_idx].iter_mut().enumerate() {
                let sum: f32 = class_samples
                    .iter()
                    .map(|&sample_idx| x.get(sample_idx, feature_idx))
                    .sum();
                *mean_val = sum / n_class_samples;
            }

            for (feature_idx, variance_val) in variances[class_idx].iter_mut().enumerate() {
                let mean = means[class_idx][feature_idx];
                let sum_sq_diff: f32 = class_samples
                    .iter()
                    .map(|&sample_idx| {
                        let diff = x.get(sample_idx, feature_idx) - mean;
                        diff * diff
                    })
                    .sum();
                *variance_val = sum_sq_diff / n_class_samples + self.var_smoothing;
            }
        }

        self.class_priors = Some(class_priors);
        self.means = Some(means);
        self.variances = Some(variances);
        self.classes = Some(classes);

        Ok(())
    }

    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
        let probabilities = self.predict_proba(x)?;
        let classes = self.classes.as_ref().ok_or("Model not fitted")?;

        let predictions: Vec<usize> = probabilities
            .iter()
            .map(|probs| {
                let max_idx = probs
                    .iter()
                    .enumerate()
                    .max_by(|(_, a), (_, b)| {
                        a.partial_cmp(b)
                            .expect("Probabilities are valid f32 (not NaN)")
                    })
                    .map(|(idx, _)| idx)
                    .expect("Probabilities vector is non-empty (n_classes >= 2)");
                classes[max_idx]
            })
            .collect();

        Ok(predictions)
    }
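    /// Class probabilities via Bayes' rule, computed in log space for
    /// numerical stability. Each feature contributes the Gaussian
    /// log-likelihood `-0.5 * ln(2 * pi * var) - (x - mean)^2 / (2 * var)`;
    /// these are summed with the log prior, shifted by the per-sample
    /// maximum (the log-sum-exp trick), exponentiated, and normalized to
    /// sum to 1.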
    pub fn predict_proba(&self, x: &Matrix<f32>) -> Result<Vec<Vec<f32>>> {
        let means = self.means.as_ref().ok_or("Model not fitted")?;
        let variances = self.variances.as_ref().ok_or("Model not fitted")?;
        let class_priors = self.class_priors.as_ref().ok_or("Model not fitted")?;

        let (n_samples, n_features) = x.shape();
        let n_classes = means.len();

        if n_features != means[0].len() {
            return Err("Feature dimension mismatch".into());
        }

        let mut probabilities = Vec::with_capacity(n_samples);

        for sample_idx in 0..n_samples {
            let mut log_probs = vec![0.0; n_classes];

            for class_idx in 0..n_classes {
                log_probs[class_idx] = class_priors[class_idx].ln();

                for feature_idx in 0..n_features {
                    let x_val = x.get(sample_idx, feature_idx);
                    let mean = means[class_idx][feature_idx];
                    let variance = variances[class_idx][feature_idx];

                    let diff = x_val - mean;
                    let log_likelihood = -0.5 * (2.0 * std::f32::consts::PI * variance).ln()
                        - (diff * diff) / (2.0 * variance);

                    log_probs[class_idx] += log_likelihood;
                }
            }

            let max_log_prob = log_probs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exp_probs: Vec<f32> = log_probs
                .iter()
                .map(|&log_p| (log_p - max_log_prob).exp())
                .collect();
            let sum: f32 = exp_probs.iter().sum();
            let normalized: Vec<f32> = exp_probs.iter().map(|p| p / sum).collect();

            probabilities.push(normalized);
        }

        Ok(probabilities)
    }
}

impl Default for GaussianNB {
    fn default() -> Self {
        Self::new()
    }
}
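/// Linear support vector machine trained with stochastic subgradient descent
/// on the L2-regularized hinge loss. The regularization strength is derived
/// from the penalty parameter as `lambda = 1 / (2 * n_samples * C)`, so a
/// larger `C` penalizes misclassification more and regularizes less.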
#[derive(Debug, Clone)]
pub struct LinearSVM {
    weights: Option<Vec<f32>>,
    bias: f32,
    c: f32,
    learning_rate: f32,
    max_iter: usize,
    tol: f32,
}

impl LinearSVM {
    pub fn new() -> Self {
        Self {
            weights: None,
            bias: 0.0,
            c: 1.0,
            learning_rate: 0.01,
            max_iter: 1000,
            tol: 1e-4,
        }
    }

    pub fn with_c(mut self, c: f32) -> Self {
        self.c = c;
        self
    }

    pub fn with_learning_rate(mut self, learning_rate: f32) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    pub fn with_tolerance(mut self, tol: f32) -> Self {
        self.tol = tol;
        self
    }
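    /// Fits the SVM with per-sample subgradient steps on the hinge loss.
    /// Labels are mapped to {-1, +1}; the learning rate decays as
    /// `eta = learning_rate / (1 + 0.01 * epoch)`. On a margin violation
    /// (`y * f(x) < 1`) the update is `w -= eta * (2 * lambda * w - y * x)`
    /// and `b += eta * y`; otherwise only the regularization term shrinks
    /// `w`. Training stops early when the combined parameter change within
    /// an epoch falls below `tol`.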
    pub fn fit(&mut self, x: &Matrix<f32>, y: &[usize]) -> Result<()> {
        if x.n_rows() != y.len() {
            return Err("x and y must have the same number of samples".into());
        }

        if x.n_rows() == 0 {
            return Err("Cannot fit with 0 samples".into());
        }

        let y_signed: Vec<f32> = y
            .iter()
            .map(|&label| if label == 0 { -1.0 } else { 1.0 })
            .collect();

        let n_samples = x.n_rows();
        let n_features = x.n_cols();

        let mut w = vec![0.0; n_features];
        let mut b = 0.0;

        let lambda = 1.0 / (2.0 * n_samples as f32 * self.c);

        for epoch in 0..self.max_iter {
            let eta = self.learning_rate / (1.0 + epoch as f32 * 0.01);
            let prev_w = w.clone();
            let prev_b = b;

            for (i, &y_i) in y_signed.iter().enumerate() {
                let mut decision = b;
                for (j, &w_j) in w.iter().enumerate() {
                    decision += w_j * x.get(i, j);
                }

                let margin = y_i * decision;

                if margin < 1.0 {
                    for (j, w_j) in w.iter_mut().enumerate() {
                        let gradient = 2.0 * lambda * *w_j - y_i * x.get(i, j);
                        *w_j -= eta * gradient;
                    }
                    b += eta * y_i;
                } else {
                    for w_j in &mut w {
                        let gradient = 2.0 * lambda * *w_j;
                        *w_j -= eta * gradient;
                    }
                }
            }

            let mut weight_change = 0.0;
            for j in 0..n_features {
                weight_change += (w[j] - prev_w[j]).powi(2);
            }
            weight_change += (b - prev_b).powi(2);
            weight_change = weight_change.sqrt();

            if weight_change < self.tol {
                break;
            }
        }

        self.weights = Some(w);
        self.bias = b;

        Ok(())
    }

    pub fn decision_function(&self, x: &Matrix<f32>) -> Result<Vec<f32>> {
        let weights = self.weights.as_ref().ok_or("Model not trained yet")?;

        if x.n_cols() != weights.len() {
            return Err("Feature dimension mismatch".into());
        }

        let mut decisions = Vec::with_capacity(x.n_rows());

        for i in 0..x.n_rows() {
            let mut decision = self.bias;
            for (j, &w_j) in weights.iter().enumerate() {
                decision += w_j * x.get(i, j);
            }
            decisions.push(decision);
        }

        Ok(decisions)
    }

    pub fn predict(&self, x: &Matrix<f32>) -> Result<Vec<usize>> {
        let decisions = self.decision_function(x)?;

        Ok(decisions.iter().map(|&d| usize::from(d >= 0.0)).collect())
    }
}

impl Default for LinearSVM {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sigmoid() {
        assert!((LogisticRegression::sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!(LogisticRegression::sigmoid(10.0) > 0.99);
        assert!(LogisticRegression::sigmoid(-10.0) < 0.01);
    }

    #[test]
    fn test_logistic_regression_new() {
        let model = LogisticRegression::new();
        assert!(model.coefficients.is_none());
        assert_eq!(model.intercept, 0.0);
    }

    #[test]
    fn test_logistic_regression_builder() {
        let model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(500)
            .with_tolerance(1e-3);

        assert_eq!(model.learning_rate, 0.1);
        assert_eq!(model.max_iter, 500);
        assert_eq!(model.tol, 1e-3);
    }

    #[test]
    fn test_logistic_regression_fit_simple() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);

        let result = model.fit(&x, &y);
        assert!(result.is_ok());
        assert!(model.coefficients.is_some());
    }

    #[test]
    fn test_logistic_regression_predict() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);

        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");
        let predictions = model.predict(&x);

        assert_eq!(predictions.len(), 4);
        for pred in predictions {
            assert!(pred == 0 || pred == 1);
        }
    }

    #[test]
    fn test_logistic_regression_score() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);

        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");
        let accuracy = model.score(&x, &y);

        assert!(accuracy >= 0.75);
    }

    #[test]
    fn test_logistic_regression_invalid_labels() {
        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
        let y = vec![0, 2];

        let mut model = LogisticRegression::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with invalid label value"),
            "Labels must be 0 or 1 for binary classification"
        );
    }

    #[test]
    fn test_logistic_regression_mismatched_samples() {
        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
        let y = vec![0];

        let mut model = LogisticRegression::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with mismatched sample counts"),
            "Number of samples in X and y must match"
        );
    }

    #[test]
    fn test_logistic_regression_zero_samples() {
        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
        let y = vec![];

        let mut model = LogisticRegression::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with zero samples"),
            "Cannot fit with zero samples"
        );
    }

    #[test]
    fn test_predict_proba() {
        let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
        let y = vec![0, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);

        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");
        let probas = model.predict_proba(&x);

        assert_eq!(probas.len(), 2);
        for &p in probas.as_slice() {
            assert!((0.0..=1.0).contains(&p));
        }
    }

    #[test]
    fn test_save_safetensors_unfitted_model() {
        let model = LogisticRegression::new();
        let result = model.save_safetensors("/tmp/test_unfitted_logistic.safetensors");

        assert!(result.is_err());
        assert!(result
            .expect_err("Should fail when saving unfitted model")
            .contains("unfitted"));
    }

    #[test]
    fn test_save_load_safetensors_roundtrip() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let path = "/tmp/test_logistic_roundtrip.safetensors";
        model
            .save_safetensors(path)
            .expect("Should save fitted model to valid path");

        let loaded =
            LogisticRegression::load_safetensors(path).expect("Should load valid SafeTensors file");

        assert_eq!(
            model
                .coefficients
                .as_ref()
                .expect("Model is fitted and has coefficients")
                .len(),
            loaded
                .coefficients
                .as_ref()
                .expect("Loaded model has coefficients")
                .len()
        );
        for i in 0..model
            .coefficients
            .as_ref()
            .expect("Model has coefficients")
            .len()
        {
            assert_eq!(
                model.coefficients.as_ref().expect("Model has coefficients")[i],
                loaded
                    .coefficients
                    .as_ref()
                    .expect("Loaded model has coefficients")[i]
            );
        }
        assert_eq!(model.intercept, loaded.intercept);

        let predictions_original = model.predict(&x);
        let predictions_loaded = loaded.predict(&x);
        assert_eq!(predictions_original, predictions_loaded);

        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_load_safetensors_corrupted_file() {
        let path = "/tmp/test_corrupted_logistic.safetensors";
        std::fs::write(path, b"CORRUPTED DATA").expect("Should write test file");

        let result = LogisticRegression::load_safetensors(path);
        assert!(result.is_err());

        std::fs::remove_file(path).ok();
    }

    #[test]
    fn test_load_safetensors_missing_file() {
        let result =
            LogisticRegression::load_safetensors("/tmp/nonexistent_logistic_xyz.safetensors");
        assert!(result.is_err());
        let err = result.expect_err("Should fail when loading nonexistent file");
        assert!(
            err.contains("No such file") || err.contains("not found"),
            "Error should mention file not found: {err}"
        );
    }

    #[test]
    fn test_safetensors_preserves_probabilities() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(1000);
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let probas_before = model.predict_proba(&x);

        let path = "/tmp/test_logistic_probas.safetensors";
        model
            .save_safetensors(path)
            .expect("Should save fitted model to valid path");
        let loaded =
            LogisticRegression::load_safetensors(path).expect("Should load valid SafeTensors file");

        let probas_after = loaded.predict_proba(&x);

        assert_eq!(probas_before.len(), probas_after.len());
        for i in 0..probas_before.len() {
            assert_eq!(probas_before[i], probas_after[i]);
        }

        std::fs::remove_file(path).ok();
    }
    #[test]
    fn test_knn_new() {
        let knn = KNearestNeighbors::new(3);
        assert_eq!(knn.k, 3);
        assert_eq!(knn.metric, DistanceMetric::Euclidean);
        assert!(!knn.weights);
    }

    #[test]
    fn test_knn_basic_fit_predict() {
        let x = Matrix::from_vec(
            6,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                5.0, 5.0,
                5.0, 6.0,
                6.0, 5.0,
            ],
        )
        .expect("6x2 matrix with 12 values");
        let y = vec![0, 0, 0, 1, 1, 1];

        let mut knn = KNearestNeighbors::new(3);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test1 = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
        let pred1 = knn.predict(&test1).expect("Prediction should succeed");
        assert_eq!(pred1[0], 0);

        let test2 = Matrix::from_vec(1, 2, vec![5.5, 5.5]).expect("1x2 test matrix");
        let pred2 = knn.predict(&test2).expect("Prediction should succeed");
        assert_eq!(pred2[0], 1);
    }

    #[test]
    fn test_knn_k_equals_one() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                1.0, 1.0,
                2.0, 2.0,
                3.0, 3.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 1, 0, 1];

        let mut knn = KNearestNeighbors::new(1);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let predictions = knn.predict(&x).expect("Prediction should succeed");
        assert_eq!(predictions, y);
    }

    #[test]
    fn test_knn_euclidean_distance() {
        let x = Matrix::from_vec(
            3,
            2,
            vec![
                0.0, 0.0,
                3.0, 4.0,
                1.0, 1.0,
            ],
        )
        .expect("3x2 matrix with 6 values");
        let y = vec![0, 1, 0];

        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Euclidean);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 2, vec![1.5, 2.0]).expect("1x2 test matrix");
        let pred = knn.predict(&test).expect("Prediction should succeed");
        assert_eq!(pred[0], 0);
    }

    #[test]
    fn test_knn_manhattan_distance() {
        let x = Matrix::from_vec(
            3,
            2,
            vec![
                0.0, 0.0,
                2.0, 2.0,
                1.0, 0.0,
            ],
        )
        .expect("3x2 matrix with 6 values");
        let y = vec![0, 1, 0];

        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Manhattan);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 2, vec![0.5, 0.0]).expect("1x2 test matrix");
        let pred = knn.predict(&test).expect("Prediction should succeed");
        assert_eq!(pred[0], 0);
    }

    #[test]
    fn test_knn_minkowski_distance() {
        let x = Matrix::from_vec(
            3,
            2,
            vec![
                0.0, 0.0,
                3.0, 4.0,
                1.0, 1.0,
            ],
        )
        .expect("3x2 matrix with 6 values");
        let y = vec![0, 1, 0];

        let mut knn = KNearestNeighbors::new(1).with_metric(DistanceMetric::Minkowski(3.0));
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
        let pred = knn.predict(&test).expect("Prediction should succeed");
        assert_eq!(pred[0], 0);
    }

    #[test]
    fn test_knn_weighted_voting() {
        let x = Matrix::from_vec(
            5,
            1,
            vec![
                0.0,
                0.1,
                5.0,
                5.5,
                6.0,
            ],
        )
        .expect("5x1 matrix with 5 values");
        let y = vec![0, 0, 1, 1, 1];

        let mut knn_weighted = KNearestNeighbors::new(3).with_weights(true);
        knn_weighted
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 1, vec![0.05]).expect("1x1 test matrix");
        let pred = knn_weighted
            .predict(&test)
            .expect("Prediction should succeed");
        assert_eq!(pred[0], 0);
    }

    #[test]
    fn test_knn_predict_proba() {
        let x = Matrix::from_vec(
            6,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                5.0, 5.0,
                5.0, 6.0,
                6.0, 5.0,
            ],
        )
        .expect("6x2 matrix with 12 values");
        let y = vec![0, 0, 0, 1, 1, 1];

        let mut knn = KNearestNeighbors::new(3);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
        let probas = knn
            .predict_proba(&test)
            .expect("Probability prediction should succeed");

        assert_eq!(probas.len(), 1);
        assert_eq!(probas[0].len(), 2);

        let sum: f32 = probas[0].iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);

        assert!(probas[0][0] > probas[0][1]);
    }

    #[test]
    fn test_knn_multiclass() {
        let x = Matrix::from_vec(
            9,
            2,
            vec![
                0.0, 0.0,
                0.0, 1.0,
                1.0, 0.0,
                5.0, 5.0,
                5.0, 6.0,
                6.0, 5.0,
                10.0, 10.0,
                10.0, 11.0,
                11.0, 10.0,
            ],
        )
        .expect("9x2 matrix with 18 values");
        let y = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let mut knn = KNearestNeighbors::new(3);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test1 = Matrix::from_vec(1, 2, vec![0.5, 0.5]).expect("1x2 test matrix");
        assert_eq!(
            knn.predict(&test1).expect("Prediction should succeed")[0],
            0
        );

        let test2 = Matrix::from_vec(1, 2, vec![5.5, 5.5]).expect("1x2 test matrix");
        assert_eq!(
            knn.predict(&test2).expect("Prediction should succeed")[0],
            1
        );

        let test3 = Matrix::from_vec(1, 2, vec![10.5, 10.5]).expect("1x2 test matrix");
        assert_eq!(
            knn.predict(&test3).expect("Prediction should succeed")[0],
            2
        );
    }

    #[test]
    fn test_knn_not_fitted_error() {
        let knn = KNearestNeighbors::new(3);
        let test = Matrix::from_vec(1, 2, vec![0.0, 0.0]).expect("1x2 test matrix");

        let result = knn.predict(&test);
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail when predicting with unfitted model"),
            "Model not fitted"
        );
    }

    #[test]
    fn test_knn_dimension_mismatch() {
        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("3x2 matrix with 6 values");
        let y = vec![0, 1, 0];

        let mut knn = KNearestNeighbors::new(1);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let test = Matrix::from_vec(1, 3, vec![0.0, 0.0, 0.0]).expect("1x3 test matrix");
        let result = knn.predict(&test);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with dimension mismatch"),
            "Feature dimension mismatch"
        );
    }

    #[test]
    fn test_knn_sample_mismatch() {
        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("3x2 matrix with 6 values");
        let y = vec![0, 1];

        let mut knn = KNearestNeighbors::new(1);
        let result = knn.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with sample mismatch"),
            "Number of samples in X and y must match"
        );
    }

    #[test]
    fn test_knn_k_too_large() {
        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("3x2 matrix with 6 values");
        let y = vec![0, 1, 0];

        let mut knn = KNearestNeighbors::new(5);
        let result = knn.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail when k exceeds sample count"),
            "k cannot be larger than number of training samples"
        );
    }

    #[test]
    fn test_knn_empty_data() {
        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
        let y = vec![];

        let mut knn = KNearestNeighbors::new(1);
        let result = knn.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with empty data"),
            "Cannot fit with zero samples"
        );
    }

    #[test]
    fn test_knn_builder_pattern() {
        let knn = KNearestNeighbors::new(5)
            .with_metric(DistanceMetric::Manhattan)
            .with_weights(true);

        assert_eq!(knn.k, 5);
        assert_eq!(knn.metric, DistanceMetric::Manhattan);
        assert!(knn.weights);
    }

    #[test]
    fn test_knn_distance_symmetry() {
        let x = Matrix::from_vec(
            2,
            2,
            vec![
                1.0, 2.0,
                3.0, 4.0,
            ],
        )
        .expect("2x2 matrix with 4 values");
        let y = vec![0, 1];

        let mut knn = KNearestNeighbors::new(1);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let dist_ab = knn.compute_distance(&x, 0, &x, 1, 2);
        let dist_ba = knn.compute_distance(&x, 1, &x, 0, 2);

        assert!((dist_ab - dist_ba).abs() < 1e-6);
    }

    #[test]
    fn test_knn_perfect_fit_with_k1() {
        let x = Matrix::from_vec(
            10,
            3,
            vec![
                1.0, 2.0, 3.0,
                2.0, 3.0, 4.0,
                3.0, 4.0, 5.0,
                4.0, 5.0, 6.0,
                5.0, 6.0, 7.0,
                6.0, 7.0, 8.0,
                7.0, 8.0, 9.0,
                8.0, 9.0, 10.0,
                9.0, 10.0, 11.0,
                10.0, 11.0, 12.0,
            ],
        )
        .expect("10x3 matrix with 30 values");
        let y = vec![0, 0, 1, 1, 0, 1, 0, 1, 0, 1];

        let mut knn = KNearestNeighbors::new(1);
        knn.fit(&x, &y)
            .expect("Training should succeed with valid data");

        let predictions = knn.predict(&x).expect("Prediction should succeed");
        assert_eq!(predictions, y);
    }
    #[test]
    fn test_gaussian_nb_new() {
        let model = GaussianNB::new();
        assert!(model.class_priors.is_none());
        assert!(model.means.is_none());
        assert!(model.variances.is_none());
        assert_eq!(model.var_smoothing, 1e-9);
    }

    #[test]
    fn test_gaussian_nb_builder() {
        let model = GaussianNB::new().with_var_smoothing(1e-8);
        assert_eq!(model.var_smoothing, 1e-8);
    }

    #[test]
    fn test_gaussian_nb_basic_fit_predict() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.1, 0.1,
                1.0, 1.0,
                0.9, 0.9,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let predictions = model.predict(&x).expect("Prediction should succeed");
        assert_eq!(predictions, y);
    }

    #[test]
    fn test_gaussian_nb_multiclass() {
        let x = Matrix::from_vec(
            9,
            2,
            vec![
                0.0, 0.0,
                0.1, 0.1,
                0.0, 0.1,
                5.0, 5.0,
                5.1, 5.1,
                5.0, 5.1,
                -5.0, -5.0,
                -5.1, -5.1,
                -5.0, -5.1,
            ],
        )
        .expect("9x2 matrix with 18 values");
        let y = vec![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let predictions = model.predict(&x).expect("Prediction should succeed");
        assert_eq!(predictions, y);
    }

    #[test]
    fn test_gaussian_nb_predict_proba() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.1, 0.1,
                1.0, 1.0,
                0.9, 0.9,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let probabilities = model
            .predict_proba(&x)
            .expect("Probability prediction should succeed");

        assert_eq!(probabilities.len(), 4);

        for probs in &probabilities {
            assert_eq!(probs.len(), 2);
            let sum: f32 = probs.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }

        assert!(probabilities[0][0] > 0.5);
        assert!(probabilities[3][1] > 0.5);
    }

    #[test]
    fn test_gaussian_nb_not_fitted_error() {
        let model = GaussianNB::new();
        let x_test = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 test matrix");

        let result = model.predict(&x_test);
        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail when predicting with unfitted model"),
            "Model not fitted"
        );
    }

    #[test]
    fn test_gaussian_nb_empty_data() {
        let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
        let y: Vec<usize> = vec![];

        let mut model = GaussianNB::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with empty data"),
            "Cannot fit with empty data"
        );
    }

    #[test]
    fn test_gaussian_nb_sample_mismatch() {
        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("3x2 matrix with 6 values");
        let y = vec![0, 1];

        let mut model = GaussianNB::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with sample mismatch"),
            "Number of samples in X and y must match"
        );
    }

    #[test]
    fn test_gaussian_nb_single_class() {
        let x = Matrix::from_vec(3, 2, vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0])
            .expect("3x2 matrix with 6 values");
        let y = vec![0, 0, 0];

        let mut model = GaussianNB::new();
        let result = model.fit(&x, &y);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with single class"),
            "Need at least 2 classes"
        );
    }

    #[test]
    fn test_gaussian_nb_dimension_mismatch() {
        let x_train = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.1, 0.1, 1.0, 1.0, 0.9, 0.9])
            .expect("4x2 training matrix");
        let y_train = vec![0, 0, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x_train, &y_train)
            .expect("Training should succeed with valid data");

        let x_test =
            Matrix::from_vec(2, 3, vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]).expect("2x3 test matrix");
        let result = model.predict(&x_test);

        assert!(result.is_err());
        assert_eq!(
            result.expect_err("Should fail with dimension mismatch"),
            "Feature dimension mismatch"
        );
    }

    #[test]
    fn test_gaussian_nb_balanced_classes() {
        let x = Matrix::from_vec(
            6,
            2,
            vec![
                0.0, 0.0,
                0.1, 0.1,
                0.2, 0.2,
                1.0, 1.0,
                1.1, 1.1,
                1.2, 1.2,
            ],
        )
        .expect("6x2 matrix with 12 values");
        let y = vec![0, 0, 0, 1, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let priors = model
            .class_priors
            .expect("Model is fitted and has class priors");
        assert!((priors[0] - 0.5).abs() < 1e-5);
        assert!((priors[1] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_gaussian_nb_imbalanced_classes() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                1.0, 1.0,
                1.1, 1.1,
                1.2, 1.2,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 1, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let priors = model
            .class_priors
            .expect("Model is fitted and has class priors");
        assert!((priors[0] - 0.25).abs() < 1e-5);
        assert!((priors[1] - 0.75).abs() < 1e-5);
    }

    #[test]
    fn test_gaussian_nb_var_smoothing() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.0, 0.0,
                1.0, 1.0,
                1.0, 1.0,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = GaussianNB::new().with_var_smoothing(1e-8);
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let predictions = model.predict(&x).expect("Prediction should succeed");
        assert_eq!(predictions, y);

        let probabilities = model
            .predict_proba(&x)
            .expect("Probability prediction should succeed");
        for probs in &probabilities {
            for &p in probs {
                assert!(p.is_finite());
                assert!((0.0..=1.0).contains(&p));
            }
        }
    }

    #[test]
    fn test_gaussian_nb_probabilities_sum_to_one() {
        let x = Matrix::from_vec(
            10,
            3,
            vec![
                0.0, 0.0, 0.0,
                0.1, 0.1, 0.1,
                0.2, 0.2, 0.2,
                0.3, 0.3, 0.3,
                1.0, 1.0, 1.0,
                1.1, 1.1, 1.1,
                1.2, 1.2, 1.2,
                1.3, 1.3, 1.3,
                2.0, 2.0, 2.0,
                2.1, 2.1, 2.1,
            ],
        )
        .expect("10x3 matrix with 30 values");
        let y = vec![0, 0, 0, 0, 1, 1, 1, 1, 2, 2];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let probabilities = model
            .predict_proba(&x)
            .expect("Probability prediction should succeed");

        for probs in &probabilities {
            let sum: f32 = probs.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5);
        }
    }

    #[test]
    fn test_gaussian_nb_default() {
        let model1 = GaussianNB::new();
        let model2 = GaussianNB::default();

        assert_eq!(model1.var_smoothing, model2.var_smoothing);
    }

    #[test]
    fn test_gaussian_nb_class_separation() {
        let x = Matrix::from_vec(
            4,
            2,
            vec![
                0.0, 0.0,
                0.1, 0.1,
                10.0, 10.0,
                10.1, 10.1,
            ],
        )
        .expect("4x2 matrix with 8 values");
        let y = vec![0, 0, 1, 1];

        let mut model = GaussianNB::new();
        model
            .fit(&x, &y)
            .expect("Training should succeed with valid data");

        let probabilities = model
            .predict_proba(&x)
            .expect("Probability prediction should succeed");

        assert!(probabilities[0][0] > 0.99);
        assert!(probabilities[3][1] > 0.99);
    }
2244
2245 #[test]
2248 fn test_linear_svm_new() {
2249 let svm = LinearSVM::new();
2250 assert!(svm.weights.is_none());
2251 assert_eq!(svm.bias, 0.0);
2252 assert_eq!(svm.c, 1.0);
2253 assert_eq!(svm.learning_rate, 0.01);
2254 assert_eq!(svm.max_iter, 1000);
2255 assert_eq!(svm.tol, 1e-4);
2256 }
2257
2258 #[test]
2259 fn test_linear_svm_builder() {
2260 let svm = LinearSVM::new()
2261 .with_c(0.5)
2262 .with_learning_rate(0.001)
2263 .with_max_iter(500)
2264 .with_tolerance(1e-5);
2265
2266 assert_eq!(svm.c, 0.5);
2267 assert_eq!(svm.learning_rate, 0.001);
2268 assert_eq!(svm.max_iter, 500);
2269 assert_eq!(svm.tol, 1e-5);
2270 }
2271
2272 #[test]
2273 fn test_linear_svm_fit_simple() {
2274 let x = Matrix::from_vec(
2276 4,
2277 2,
2278 vec![
2279 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ],
2284 )
2285 .expect("4x2 matrix with 8 values");
2286 let y = vec![0, 0, 1, 1];
2287
2288 let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2289
2290 let result = svm.fit(&x, &y);
2291 assert!(result.is_ok());
2292 assert!(svm.weights.is_some());
2293 }
2294
2295 #[test]
2296 fn test_linear_svm_predict_simple() {
2297 let x = Matrix::from_vec(
2299 4,
2300 2,
2301 vec![
2302 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ],
2307 )
2308 .expect("4x2 matrix with 8 values");
2309 let y = vec![0, 0, 1, 1];
2310
2311 let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2312 svm.fit(&x, &y)
2313 .expect("Training should succeed with valid data");
2314
2315 let predictions = svm.predict(&x).expect("Prediction should succeed");
2316 assert_eq!(predictions.len(), 4);
2317
2318 let correct = predictions
2320 .iter()
2321 .zip(y.iter())
2322 .filter(|(pred, true_label)| *pred == *true_label)
2323 .count();
2324
2325 assert!(correct >= 3);
2327 }
2328
2329 #[test]
2330 fn test_linear_svm_decision_function() {
2331 let x = Matrix::from_vec(
2332 4,
2333 2,
2334 vec![
2335 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, ],
2340 )
2341 .expect("4x2 matrix with 8 values");
2342 let y = vec![0, 0, 1, 1];
2343
2344 let mut svm = LinearSVM::new().with_max_iter(1000).with_learning_rate(0.1);
2345 svm.fit(&x, &y)
2346 .expect("Training should succeed with valid data");
2347
2348 let decisions = svm
2349 .decision_function(&x)
2350 .expect("Decision function should succeed");
2351 assert_eq!(decisions.len(), 4);
2352
2353 }
2357
2358 #[test]
2359 fn test_linear_svm_predict_untrained() {
2360 let svm = LinearSVM::new();
2361 let x = Matrix::from_vec(2, 2, vec![0.0, 0.0, 1.0, 1.0]).expect("2x2 matrix with 4 values");
2362
2363 let result = svm.predict(&x);
2364 assert!(result.is_err());
2365 assert_eq!(
2366 result.expect_err("Should fail when predicting with untrained model"),
2367 "Model not trained yet"
2368 );
2369 }
2370
2371 #[test]
2372 fn test_linear_svm_dimension_mismatch() {
2373 let x_train = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2374 .expect("4x2 training matrix");
2375 let y = vec![0, 0, 1, 1];
2376
2377 let mut svm = LinearSVM::new();
2378 svm.fit(&x_train, &y)
2379 .expect("Training should succeed with valid data");
2380
2381 let x_test =
2383 Matrix::from_vec(2, 3, vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]).expect("2x3 test matrix");
2384 let result = svm.predict(&x_test);
2385 assert!(result.is_err());
2386 assert_eq!(
2387 result.expect_err("Should fail with dimension mismatch"),
2388 "Feature dimension mismatch"
2389 );
2390 }
2391
2392 #[test]
2393 fn test_linear_svm_empty_data() {
2394 let x = Matrix::from_vec(0, 2, vec![]).expect("0x2 empty matrix");
2395 let y = vec![];
2396
2397 let mut svm = LinearSVM::new();
2398 let result = svm.fit(&x, &y);
2399 assert!(result.is_err());
2400 assert_eq!(
2401 result.expect_err("Should fail with empty data"),
2402 "Cannot fit with 0 samples"
2403 );
2404 }
2405
2406 #[test]
2407 fn test_linear_svm_mismatched_samples() {
2408 let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2409 .expect("4x2 matrix with 8 values");
2410 let y = vec![0, 0, 1]; let mut svm = LinearSVM::new();
2413 let result = svm.fit(&x, &y);
2414 assert!(result.is_err());
2415 assert_eq!(
2416 result.expect_err("Should fail with mismatched sample counts"),
2417 "x and y must have the same number of samples"
2418 );
2419 }
2420
2421 #[test]
2422 fn test_linear_svm_regularization_c() {
2423 let x = Matrix::from_vec(
2424 6,
2425 2,
2426 vec![
2427 0.0, 0.0, 0.1, 0.1, 0.0, 0.2, 1.0, 1.0, 0.9, 0.9, 1.0, 0.8, ],
2434 )
2435 .expect("6x2 matrix with 12 values");
2436 let y = vec![0, 0, 0, 1, 1, 1];
2437
2438 let mut svm_high_c = LinearSVM::new()
2440 .with_c(10.0)
2441 .with_max_iter(1000)
2442 .with_learning_rate(0.1);
2443 svm_high_c
2444 .fit(&x, &y)
2445 .expect("Training should succeed with valid data");
2446 let pred_high_c = svm_high_c.predict(&x).expect("Prediction should succeed");
2447
2448 let mut svm_low_c = LinearSVM::new()
2450 .with_c(0.1)
2451 .with_max_iter(1000)
2452 .with_learning_rate(0.1);
2453 svm_low_c
2454 .fit(&x, &y)
2455 .expect("Training should succeed with valid data");
2456 let pred_low_c = svm_low_c.predict(&x).expect("Prediction should succeed");
2457
2458 assert_eq!(pred_high_c.len(), 6);
2460 assert_eq!(pred_low_c.len(), 6);
2461 }
2462
2463 #[test]
2464 fn test_linear_svm_binary_classification() {
2465 let x = Matrix::from_vec(
2467 10,
2468 2,
2469 vec![
2470 0.0, 0.0, 0.1, 0.1, 0.0, 0.2, 0.2, 0.0, 0.1,
2472 0.2, 1.0, 1.0, 0.9, 0.9, 1.0, 0.8, 0.8, 1.0, 0.9, 1.1,
2474 ],
2475 )
2476 .expect("10x2 matrix with 20 values");
2477 let y = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1];
2478
2479 let mut svm = LinearSVM::new()
2480 .with_c(1.0)
2481 .with_max_iter(2000)
2482 .with_learning_rate(0.1);
2483
2484 svm.fit(&x, &y)
2485 .expect("Training should succeed with valid data");
2486 let predictions = svm.predict(&x).expect("Prediction should succeed");
2487
2488 let correct = predictions
2490 .iter()
2491 .zip(y.iter())
2492 .filter(|(pred, true_label)| *pred == *true_label)
2493 .count();
2494
2495 assert!(
2497 correct >= 8,
2498 "Expected at least 8/10 correct, got {correct}/10"
2499 );
2500 }
2501
2502 #[test]
2503 fn test_linear_svm_convergence() {
2504 let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
2505 .expect("4x2 matrix with 8 values");
2506 let y = vec![0, 0, 1, 1];
2507
2508 let mut svm_few_iter = LinearSVM::new().with_max_iter(10).with_learning_rate(0.01);
2510 svm_few_iter
2511 .fit(&x, &y)
2512 .expect("Training should succeed with valid data");
2513
2514 let mut svm_many_iter = LinearSVM::new().with_max_iter(2000).with_learning_rate(0.1);
2516 svm_many_iter
2517 .fit(&x, &y)
2518 .expect("Training should succeed with valid data");
2519
2520 assert!(svm_few_iter.weights.is_some());
2522 assert!(svm_many_iter.weights.is_some());
2523 }
2524
2525 #[test]
2526 fn test_linear_svm_default() {
2527 let svm1 = LinearSVM::new();
2528 let svm2 = LinearSVM::default();
2529
2530 assert_eq!(svm1.c, svm2.c);
2531 assert_eq!(svm1.learning_rate, svm2.learning_rate);
2532 assert_eq!(svm1.max_iter, svm2.max_iter);
2533 }
2534}