//! Stochastic gradient descent (SGD) linear models.
//!
//! Provides `SGDClassifier` (one-vs-all linear classification) and
//! `SGDRegressor` (linear regression) trained one sample at a time, with
//! pluggable loss functions, several learning-rate schedules, and
//! `partial_fit` support for incremental learning.

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, PartialFit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::Float;
use rand::SeedableRng;
use rand::seq::SliceRandom;

/// A per-sample loss function with a (sub)gradient, used by the SGD trainers.
pub trait Loss<F: Float>: Clone + Send + Sync {
    /// Loss incurred on a single sample.
    fn loss(&self, y_true: F, y_pred: F) -> F;

    /// Derivative of the loss with respect to the prediction `y_pred`.
    fn gradient(&self, y_true: F, y_pred: F) -> F;
}

/// Hinge loss `max(0, 1 - y * f(x))`, the loss of a linear SVM.
/// Labels are expected in `{-1, +1}`.
#[derive(Debug, Clone, Copy)]
pub struct Hinge;

impl<F: Float> Loss<F> for Hinge {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            F::one() - margin
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            -y_true
        } else {
            F::zero()
        }
    }
}

/// Logistic loss `ln(1 + e^(-y * f(x)))`, yielding logistic regression.
/// Labels are expected in `{-1, +1}`.
#[derive(Debug, Clone, Copy)]
pub struct LogLoss;

impl<F: Float> Loss<F> for LogLoss {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // Guard both tails against overflow in exp(): for large z,
        // ln(1 + e^(-z)) ~ e^(-z); for very negative z it is ~ -z.
        if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            -z
        } else {
            (F::one() + (-z).exp()).ln()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // Cap e^(-z) for very negative z so the ratio below saturates at 1
        // instead of overflowing.
        let exp_nz = if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            F::from(1e18).unwrap()
        } else {
            (-z).exp()
        };
        -y_true * exp_nz / (F::one() + exp_nz)
    }
}

/// Squared error loss `0.5 * (y - f(x))^2`.
#[derive(Debug, Clone, Copy)]
pub struct SquaredError;

impl<F: Float> Loss<F> for SquaredError {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        F::from(0.5).unwrap() * diff * diff
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        y_pred - y_true
    }
}

/// Smoothed hinge ("modified Huber") loss for classification:
/// `max(0, 1 - z)^2` for `z >= -1` and `-4z` otherwise, where
/// `z = y * f(x)`. Quadratic near the margin, linear for gross outliers.
#[derive(Debug, Clone, Copy)]
pub struct ModifiedHuber;

impl<F: Float> Loss<F> for ModifiedHuber {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            let margin = F::one() - z;
            if margin > F::zero() {
                margin * margin
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * z
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            if z < F::one() {
                F::from(-2.0).unwrap() * y_true * (F::one() - z)
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * y_true
        }
    }
}

/// Huber loss for regression: quadratic for residuals with absolute value up
/// to `epsilon`, linear beyond, making it robust to outliers.
#[derive(Debug, Clone, Copy)]
pub struct Huber<F> {
    /// Threshold at which the loss switches from quadratic to linear.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for Huber<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            F::from(0.5).unwrap() * diff * diff
        } else {
            self.epsilon * (abs_diff - F::from(0.5).unwrap() * self.epsilon)
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            diff
        } else if diff > F::zero() {
            self.epsilon
        } else {
            -self.epsilon
        }
    }
}

/// Epsilon-insensitive loss for regression (the SVR loss): residuals with
/// absolute value below `epsilon` incur no loss, larger ones grow linearly.
#[derive(Debug, Clone, Copy)]
pub struct EpsilonInsensitive<F> {
    /// Width of the zero-loss tube around the target.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for EpsilonInsensitive<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = (y_true - y_pred).abs();
        if diff > self.epsilon {
            diff - self.epsilon
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        if diff > self.epsilon {
            F::one()
        } else if diff < -self.epsilon {
            -F::one()
        } else {
            F::zero()
        }
    }
}

/// Learning-rate schedule for SGD updates.
#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule<F> {
    /// Fixed rate `eta0` for every update.
    Constant,
    /// `1 / (alpha * t)`, where `t` is the update count.
    Optimal,
    /// `eta0 / t^power_t`.
    InvScaling,
    /// Keep `eta0` until the epoch loss stops improving, then halve it.
    Adaptive,
    #[doc(hidden)]
    _Phantom(std::marker::PhantomData<F>),
}

/// Per-update learning rate for update number `t` under `schedule`.
fn compute_lr<F: Float>(
    schedule: &LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    power_t: F,
    t: usize,
) -> F {
    let t_f = F::from(t.max(1)).unwrap();
    match schedule {
        LearningRateSchedule::Constant => eta0,
        LearningRateSchedule::Optimal => F::one() / (alpha * t_f),
        LearningRateSchedule::InvScaling => eta0 / t_f.powf(power_t),
        // The adaptive decay is tracked by the training loop itself.
        LearningRateSchedule::Adaptive => eta0,
        LearningRateSchedule::_Phantom(_) => unreachable!(),
    }
}

/// Loss functions available to [`SGDClassifier`].
#[derive(Debug, Clone, Copy)]
pub enum ClassifierLoss {
    /// Hinge loss (linear SVM).
    Hinge,
    /// Logistic loss (logistic regression).
    Log,
    /// Squared error on the `{-1, +1}` encoding.
    SquaredError,
    /// Modified Huber (smoothed hinge) loss.
    ModifiedHuber,
}

/// Loss functions available to [`SGDRegressor`].
#[derive(Debug, Clone, Copy)]
pub enum RegressorLoss<F> {
    /// Squared error loss.
    SquaredError,
    /// Huber loss with the given `epsilon` threshold.
    Huber(F),
    /// Epsilon-insensitive (SVR) loss with the given tube width.
    EpsilonInsensitive(F),
}

/// Linear classifier trained with stochastic gradient descent, using a
/// one-vs-all scheme for more than two classes.
#[derive(Debug, Clone)]
pub struct SGDClassifier<F> {
    /// Loss function minimized during training.
    pub loss: ClassifierLoss,
    /// Learning-rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Stop early once the epoch loss changes by less than this amount.
    pub tol: F,
    /// Seed for the per-epoch sample shuffle; `None` draws from the OS RNG.
    pub random_state: Option<u64>,
    /// Exponent for the inverse-scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDClassifier<F> {
    /// Creates a classifier with hinge loss, inverse-scaling learning rate,
    /// `eta0 = 0.01`, `alpha = 1e-4`, `max_iter = 1000`, and `tol = 1e-3`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: ClassifierLoss::Hinge,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Sets the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: ClassifierLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Sets the learning-rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Sets the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the early-stopping tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the shuffle seed for reproducible training.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Sets the inverse-scaling exponent.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Collects the classifier's hyperparameters into the shared training config.
fn clf_hyper<F: Float>(clf: &SGDClassifier<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: clf.learning_rate,
        eta0: clf.eta0,
        alpha: clf.alpha,
        max_iter: clf.max_iter,
        tol: clf.tol,
        random_state: clf.random_state,
        power_t: clf.power_t,
    }
}

/// Hyperparameters shared by the classifier and regressor training loops.
#[derive(Debug, Clone)]
struct SGDHyper<F> {
    learning_rate: LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    max_iter: usize,
    tol: F,
    random_state: Option<u64>,
    power_t: F,
}

/// Runs SGD for a single binary problem with `{-1, +1}` targets, updating
/// `weights` and `intercept` in place. Returns the final mean epoch loss and
/// the cumulative update count (used to continue schedules across calls).
fn train_binary_sgd<F, L>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            // Decision function for sample i: w . x_i + b.
            let mut y_pred = *intercept;
            let xi = x.row(i);
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y_binary[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y_binary[i], y_pred);

            // Gradient step with L2 weight decay; the intercept is not
            // regularized.
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        // Early stopping on loss plateau.
        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        // Adaptive schedule: halve the rate after five epochs without
        // improvement, and stop once it becomes negligible.
        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// A trained [`SGDClassifier`]: one weight vector and intercept per class
/// (a single pair for binary problems).
#[derive(Debug, Clone)]
pub struct FittedSGDClassifier<F> {
    /// One weight vector per one-vs-all problem.
    weight_matrix: Vec<Array1<F>>,
    /// One intercept per one-vs-all problem.
    intercepts: Vec<F>,
    /// Sorted distinct class labels seen during fitting.
    classes: Vec<usize>,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Loss function used during training.
    loss: ClassifierLoss,
    /// Hyperparameters, retained for `partial_fit`.
    hyper: SGDHyper<F>,
    /// Cumulative SGD update count across fit calls.
    t: usize,
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type Fitted = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Fits one binary SGD problem per class (one-vs-all). Requires at least
    /// two distinct class labels in `y`.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let hyper = clf_hyper(self);
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper,
            t,
        })
    }
}

/// Validates shapes and hyperparameters shared by classifier entry points.
fn validate_clf_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<usize>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDClassifier requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

/// Per-class weight vectors, per-class intercepts, and the update count.
type OvaResult<F> = (Vec<Array1<F>>, Vec<F>, usize);

/// Trains fresh one-vs-all binary problems. A two-class problem is handled
/// with a single binary model (positive class = `classes[1]`).
fn fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    n_features: usize,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> Result<OvaResult<F>, FerroError> {
    let n_classes = classes.len();
    let mut weight_matrix: Vec<Array1<F>> = Vec::with_capacity(n_classes);
    let mut intercepts: Vec<F> = Vec::with_capacity(n_classes);
    let mut global_t = initial_t;

    if n_classes == 2 {
        // Binary case: encode classes[1] as +1, classes[0] as -1.
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();
        let (_, t) =
            dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
        global_t = t;
        weight_matrix.push(w);
        intercepts.push(b);
    } else {
        // One binary problem per class: that class vs. the rest.
        for &cls in classes {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let mut w = Array1::<F>::zeros(n_features);
            let mut b = F::zero();
            let (_, t) =
                dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
            global_t = t;
            weight_matrix.push(w);
            intercepts.push(b);
        }
    }

    Ok((weight_matrix, intercepts, global_t))
}

/// Continues training existing one-vs-all models in place on a new batch.
#[allow(clippy::too_many_arguments)]
fn partial_fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    weight_matrix: &mut [Array1<F>],
    intercepts: &mut [F],
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> usize {
    let n_classes = classes.len();
    let mut global_t = initial_t;

    if n_classes == 2 {
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let (_, t) = dispatch_train_binary(
            x,
            &y_binary,
            &mut weight_matrix[0],
            &mut intercepts[0],
            loss_enum,
            hyper,
            global_t,
        );
        global_t = t;
    } else {
        for (idx, &cls) in classes.iter().enumerate() {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let (_, t) = dispatch_train_binary(
                x,
                &y_binary,
                &mut weight_matrix[idx],
                &mut intercepts[idx],
                loss_enum,
                hyper,
                global_t,
            );
            global_t = t;
        }
    }

    global_t
}

/// Monomorphizes the training loop over the selected classifier loss.
fn dispatch_train_binary<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        ClassifierLoss::Hinge => train_binary_sgd(x, y_binary, w, b, &Hinge, hyper, initial_t),
        ClassifierLoss::Log => train_binary_sgd(x, y_binary, w, b, &LogLoss, hyper, initial_t),
        ClassifierLoss::SquaredError => {
            train_binary_sgd(x, y_binary, w, b, &SquaredError, hyper, initial_t)
        }
        ClassifierLoss::ModifiedHuber => {
            train_binary_sgd(x, y_binary, w, b, &ModifiedHuber, hyper, initial_t)
        }
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDClassifier<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::<usize>::zeros(n_samples);

        if self.classes.len() == 2 {
            // Binary: the sign of the decision function picks the class.
            let scores = x.dot(&self.weight_matrix[0]) + self.intercepts[0];
            for i in 0..n_samples {
                predictions[i] = if scores[i] >= F::zero() {
                    self.classes[1]
                } else {
                    self.classes[0]
                };
            }
        } else {
            // Multiclass: the one-vs-all model with the highest score wins.
            for i in 0..n_samples {
                let xi = x.row(i);
                let mut best_class = 0;
                let mut best_score = F::neg_infinity();
                for (c, w) in self.weight_matrix.iter().enumerate() {
                    let score = xi.dot(w) + self.intercepts[c];
                    if score > best_score {
                        best_score = score;
                        best_class = c;
                    }
                }
                predictions[i] = self.classes[best_class];
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for FittedSGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Performs one additional epoch over the batch, warm-starting from the
    /// current weights.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // A partial fit is a single pass over the batch.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let t = partial_fit_ova(
            x,
            y,
            &self.classes,
            &mut self.weight_matrix,
            &mut self.intercepts,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Starts incremental training with a single epoch over the first batch.
    /// The class set is inferred from this batch, so it must already contain
    /// every label that later batches will use.
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let mut hyper = clf_hyper(&self);
        hyper.max_iter = 1;
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper: clf_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDClassifier<F>
{
    /// Weights of the first one-vs-all problem (the only one in the binary
    /// case).
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_matrix[0]
    }

    /// Intercept of the first one-vs-all problem.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}

impl PipelineEstimator for SGDClassifier<f64> {
    fn fit_pipeline(
        &self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
        // Pipelines carry f64 targets; truncate them to usize class labels.
        let y_usize: Array1<usize> = y.mapv(|v| v as usize);
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedSGDClassifierPipeline(fitted)))
    }
}

/// Adapter exposing a fitted classifier through the pipeline interface with
/// f64 labels. `Send` and `Sync` are derived automatically from the wrapped
/// type, so no unsafe impls are needed.
struct FittedSGDClassifierPipeline(FittedSGDClassifier<f64>);

impl FittedPipelineEstimator for FittedSGDClassifierPipeline {
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| v as f64))
    }
}

/// Linear regressor trained with stochastic gradient descent.
#[derive(Debug, Clone)]
pub struct SGDRegressor<F> {
    /// Loss function minimized during training.
    pub loss: RegressorLoss<F>,
    /// Learning-rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Stop early once the epoch loss changes by less than this amount.
    pub tol: F,
    /// Seed for the per-epoch sample shuffle; `None` draws from the OS RNG.
    pub random_state: Option<u64>,
    /// Exponent for the inverse-scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDRegressor<F> {
    /// Creates a regressor with squared-error loss, inverse-scaling learning
    /// rate, `eta0 = 0.01`, `alpha = 1e-4`, `max_iter = 1000`, and `tol = 1e-3`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: RegressorLoss::SquaredError,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Sets the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: RegressorLoss<F>) -> Self {
        self.loss = loss;
        self
    }

    /// Sets the learning-rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Sets the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the early-stopping tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Sets the shuffle seed for reproducible training.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Sets the inverse-scaling exponent.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Collects the regressor's hyperparameters into the shared training config.
fn reg_hyper<F: Float>(reg: &SGDRegressor<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: reg.learning_rate,
        eta0: reg.eta0,
        alpha: reg.alpha,
        max_iter: reg.max_iter,
        tol: reg.tol,
        random_state: reg.random_state,
        power_t: reg.power_t,
    }
}

/// Runs SGD on real-valued targets, updating `weights` and `intercept` in
/// place. Mirrors `train_binary_sgd`, but the targets are raw values rather
/// than `{-1, +1}` labels. Returns the final mean epoch loss and the
/// cumulative update count.
fn train_regressor_sgd<F, L>(
    x: &Array2<F>,
    y: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            let xi = x.row(i);
            let mut y_pred = *intercept;
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y[i], y_pred);

            // Gradient step with L2 weight decay; the intercept is not
            // regularized.
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Monomorphizes the training loop over the selected regressor loss.
fn dispatch_train_regressor<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &RegressorLoss<F>,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        RegressorLoss::SquaredError => {
            train_regressor_sgd(x, y, w, b, &SquaredError, hyper, initial_t)
        }
        RegressorLoss::Huber(eps) => {
            train_regressor_sgd(x, y, w, b, &Huber { epsilon: *eps }, hyper, initial_t)
        }
        RegressorLoss::EpsilonInsensitive(eps) => train_regressor_sgd(
            x,
            y,
            w,
            b,
            &EpsilonInsensitive { epsilon: *eps },
            hyper,
            initial_t,
        ),
    }
}

/// A trained [`SGDRegressor`].
#[derive(Debug, Clone)]
pub struct FittedSGDRegressor<F> {
    /// Learned weight vector.
    weights: Array1<F>,
    /// Learned intercept.
    intercept: F,
    /// Number of features the model was fitted on.
    n_features: usize,
    /// Loss function used during training.
    loss: RegressorLoss<F>,
    /// Hyperparameters, retained for `partial_fit`.
    hyper: SGDHyper<F>,
    /// Cumulative SGD update count across fit calls.
    t: usize,
}

/// Validates shapes and hyperparameters shared by regressor entry points.
fn validate_reg_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDRegressor requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type Fitted = FittedSGDRegressor<F>;
    type Error = FerroError;

    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let hyper = reg_hyper(self);
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper,
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let preds = x.dot(&self.weights) + self.intercept;
        Ok(preds)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for FittedSGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Performs one additional epoch over the batch, warm-starting from the
    /// current weights.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // A partial fit is a single pass over the batch.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let (_, t) = dispatch_train_regressor(
            x,
            y,
            &mut self.weights,
            &mut self.intercept,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Starts incremental training from zero weights with a single epoch over
    /// the first batch.
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut hyper = reg_hyper(&self);
        hyper.max_iter = 1;
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper: reg_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDRegressor<F>
{
    fn coefficients(&self) -> &Array1<F> {
        &self.weights
    }

    fn intercept(&self) -> F {
        self.intercept
    }
}

impl PipelineEstimator for SGDRegressor<f64> {
    fn fit_pipeline(
        &self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl FittedPipelineEstimator for FittedSGDRegressor<f64> {
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        self.predict(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ----- loss functions -----

    #[test]
    fn test_hinge_loss_correct_side() {
        let h = Hinge;
        // Margin 2.0 >= 1: no loss, no gradient.
        assert!((Loss::<f64>::loss(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
    }

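    // At a margin of exactly 1 the hinge sits at its kink: the loss is zero,
    // and this implementation (which uses a strict `<` comparison) also
    // reports a zero subgradient there.
    #[test]
    fn test_hinge_loss_margin_boundary() {
        let h = Hinge;
        assert!((Loss::<f64>::loss(&h, 1.0, 1.0)).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, 1.0)).abs() < 1e-10);
    }
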
    #[test]
    fn test_hinge_loss_wrong_side() {
        let h = Hinge;
        // Margin -0.5 < 1: loss 1 - (-0.5) = 1.5, gradient -y = -1.
        assert!((Loss::<f64>::loss(&h, 1.0, -0.5) - 1.5).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, -0.5) - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_log_loss_zero_pred() {
        let l = LogLoss;
        let loss = Loss::<f64>::loss(&l, 1.0, 0.0);
        // ln(1 + e^0) = ln 2.
        assert!((loss - 2.0_f64.ln()).abs() < 1e-10);
    }

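    // At z = 0 the log-loss gradient reduces to -y * e^0 / (1 + e^0) = -y / 2.
    #[test]
    fn test_log_loss_gradient_zero_pred() {
        let l = LogLoss;
        let grad = Loss::<f64>::gradient(&l, 1.0, 0.0);
        assert!((grad - (-0.5)).abs() < 1e-10);
    }
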
    #[test]
    fn test_log_loss_large_correct() {
        let l = LogLoss;
        // z = 20 takes the overflow-guarded branch and is nearly zero.
        let loss = Loss::<f64>::loss(&l, 1.0, 20.0);
        assert!(loss < 1e-5);
    }

    #[test]
    fn test_squared_error_loss() {
        let s = SquaredError;
        assert!((Loss::<f64>::loss(&s, 3.0, 1.0) - 2.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&s, 3.0, 1.0) - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_modified_huber_loss() {
        let mh = ModifiedHuber;
        // Correct side of the margin: zero loss.
        assert!((Loss::<f64>::loss(&mh, 1.0, 2.0)).abs() < 1e-10);
        // Quadratic region: (1 - 0.5)^2 = 0.25.
        assert!((Loss::<f64>::loss(&mh, 1.0, 0.5) - 0.25).abs() < 1e-10);
        // Linear region (z < -1): -4z = 8.
        assert!((Loss::<f64>::loss(&mh, 1.0, -2.0) - 8.0).abs() < 1e-10);
    }

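    // Gradient counterparts of the three modified-Huber regions above:
    // -2y(1 - z) in the quadratic region, 0 on the correct side, and -4y in
    // the linear region.
    #[test]
    fn test_modified_huber_gradient() {
        let mh = ModifiedHuber;
        assert!((Loss::<f64>::gradient(&mh, 1.0, 0.5) - (-1.0)).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&mh, 1.0, 2.0)).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&mh, 1.0, -2.0) - (-4.0)).abs() < 1e-10);
    }
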
    #[test]
    fn test_huber_loss_quadratic_region() {
        let h = Huber { epsilon: 1.0_f64 };
        // |diff| = 0.5 <= epsilon: 0.5 * 0.5^2 = 0.125.
        assert!((Loss::<f64>::loss(&h, 1.0, 0.5) - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss_linear_region() {
        let h = Huber { epsilon: 1.0_f64 };
        // |diff| = 3 > epsilon: 1 * (3 - 0.5) = 2.5.
        assert!((Loss::<f64>::loss(&h, 3.0, 0.0) - 2.5).abs() < 1e-10);
    }

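    // The Huber gradient equals the raw residual inside the quadratic region
    // and is clipped to +/- epsilon outside it.
    #[test]
    fn test_huber_gradient_clipping() {
        let h = Huber { epsilon: 1.0_f64 };
        // diff = y_pred - y_true = -0.5, inside the quadratic region.
        assert!((Loss::<f64>::gradient(&h, 1.0, 0.5) - (-0.5)).abs() < 1e-10);
        // diff = -3.0, clipped to -epsilon.
        assert!((Loss::<f64>::gradient(&h, 3.0, 0.0) - (-1.0)).abs() < 1e-10);
    }
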
    #[test]
    fn test_epsilon_insensitive_inside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        // |diff| = 0.05 is inside the tube: no loss.
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.95)).abs() < 1e-10);
    }

    #[test]
    fn test_epsilon_insensitive_outside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        // |diff| = 0.5 is outside the tube: 0.5 - 0.1 = 0.4.
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.5) - 0.4).abs() < 1e-10);
    }

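    // The epsilon-insensitive gradient is the sign of the residual outside
    // the tube and zero inside it.
    #[test]
    fn test_epsilon_insensitive_gradient() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        assert!((Loss::<f64>::gradient(&ei, 1.0, 2.0) - 1.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&ei, 1.0, 0.5) - (-1.0)).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&ei, 1.0, 0.95)).abs() < 1e-10);
    }
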
    // ----- learning-rate schedules -----

    #[test]
    fn test_constant_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Constant;
        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 100) - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_optimal_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Optimal;
        // 1 / (alpha * t) = 1 / (0.01 * 10) = 10.
        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 10) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_invscaling_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::InvScaling;
        let result = compute_lr(&lr, 0.1, 0.01, 0.5, 10);
        // eta0 / t^power_t = 0.1 / 10^0.5.
        let expected = 0.1 / 10.0_f64.sqrt();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_adaptive_lr_returns_eta0() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Adaptive;
        assert!((compute_lr(&lr, 0.05, 0.01, 0.25, 100) - 0.05).abs() < 1e-10);
    }

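    // End-to-end run with the Adaptive schedule, whose decay is handled by
    // the training loop rather than `compute_lr`; mirrors the
    // Constant-schedule classifier test below on the same separable toy data.
    #[test]
    fn test_sgd_classifier_adaptive_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Adaptive)
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }
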
    // ----- SGDClassifier -----

    #[test]
    fn test_sgd_classifier_binary() {
        // Two well-separated clusters around (-1.75, -1.75) and (1.75, 1.75).
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                -2.0, -2.0, -1.5, -2.0, -2.0, -1.5, -1.5, -1.5, 2.0, 2.0, 1.5, 2.0, 2.0, 1.5, 1.5,
                1.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected >= 6 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_log_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected >= 4 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_multiclass() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let clf = SGDClassifier::<f64>::new()
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(
            correct >= 6,
            "expected >= 6 correct for multiclass, got {correct}"
        );
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // two labels for three samples
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_eta0() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_eta0(0.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_alpha() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_alpha(-1.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_classifier_partial_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_chain() {
        let x1 =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y1 = array![0, 0, 1, 1];
        let x2 =
            Array2::from_shape_vec((4, 2), vec![0.5, 0.5, 1.5, 1.5, 7.5, 7.5, 8.5, 8.5]).unwrap();
        let y2 = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let preds = clf
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![0, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_modified_huber() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::ModifiedHuber)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_squared_error_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::SquaredError)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_pipeline() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_constant_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Constant)
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_classifier_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0f32, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0_usize, 0, 1, 1];

        let clf = SGDClassifier::<f32>::new()
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    // ----- SGDRegressor -----

    #[test]
    fn test_sgd_regressor_basic() {
        // Targets follow y = 2x + 1 exactly.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // SGD is noisy; only require predictions in the right neighborhood.
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert!(
                (*p - actual).abs() < 2.0,
                "prediction {p} too far from {actual}"
            );
        }
    }

    #[test]
    fn test_sgd_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0]; // two targets for three samples
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_invalid_eta0() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_eta0(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_regressor_partial_fit() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_chain() {
        let x1 = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y1 = array![2.0, 4.0, 6.0];
        let x2 = Array2::from_shape_vec((3, 1), vec![4.0, 5.0, 6.0]).unwrap();
        let y2 = array![8.0, 10.0, 12.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let preds = model
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![1.0, 2.0];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_huber_loss() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::Huber(1.35))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_epsilon_insensitive() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::EpsilonInsensitive(0.1))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

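    // End-to-end run with the Adaptive schedule on the regressor side; like
    // the other loss-variant tests here, it only checks that training and
    // prediction complete with the expected output shape.
    #[test]
    fn test_sgd_regressor_adaptive_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Adaptive)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }
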
    #[test]
    fn test_sgd_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_default() {
        let clf = SGDClassifier::<f64>::default();
        assert!(clf.eta0 > 0.0);
        assert!(clf.alpha >= 0.0);
    }

    #[test]
    fn test_sgd_regressor_default() {
        let model = SGDRegressor::<f64>::default();
        assert!(model.eta0 > 0.0);
        assert!(model.alpha >= 0.0);
    }
}