use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, PartialFit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::seq::SliceRandom;

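/// A pointwise loss with its gradient, used by the SGD training loops.
///
/// `loss` returns the per-sample loss value and `gradient` returns the
/// derivative of that loss with respect to the raw prediction `y_pred`.
///
/// A minimal sketch of evaluating a loss (illustrative only, not compiled
/// as a doctest; assumes this module's items are in scope):
///
/// ```ignore
/// let h = Hinge;
/// // Correctly classified with margin >= 1: zero loss, zero gradient.
/// assert_eq!(Loss::<f64>::loss(&h, 1.0, 2.0), 0.0);
/// assert_eq!(Loss::<f64>::gradient(&h, 1.0, 2.0), 0.0);
/// ```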
pub trait Loss<F: Float>: Clone + Send + Sync {
    /// Per-sample loss for target `y_true` and raw prediction `y_pred`.
    fn loss(&self, y_true: F, y_pred: F) -> F;

    /// Derivative of the loss with respect to `y_pred`.
    fn gradient(&self, y_true: F, y_pred: F) -> F;
}

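/// Hinge loss, `max(0, 1 - y * f(x))`, the standard linear SVM loss.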
#[derive(Debug, Clone, Copy)]
pub struct Hinge;

impl<F: Float> Loss<F> for Hinge {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            F::one() - margin
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            -y_true
        } else {
            F::zero()
        }
    }
}

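/// Logistic loss, `ln(1 + exp(-y * f(x)))`, as used by logistic regression.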
#[derive(Debug, Clone, Copy)]
pub struct LogLoss;

impl<F: Float> Loss<F> for LogLoss {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // For large |z|, use the asymptotic forms of ln(1 + exp(-z)) to
        // avoid overflow in exp().
        if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            -z
        } else {
            (F::one() + (-z).exp()).ln()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // Same clamping as `loss`: cap exp(-z) at a large finite value so
        // the ratio below stays well-defined.
        let exp_nz = if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            F::from(1e18).unwrap()
        } else {
            (-z).exp()
        };
        -y_true * exp_nz / (F::one() + exp_nz)
    }
}

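/// Squared error loss, `0.5 * (y - f(x))^2`.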
#[derive(Debug, Clone, Copy)]
pub struct SquaredError;

impl<F: Float> Loss<F> for SquaredError {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        F::from(0.5).unwrap() * diff * diff
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        y_pred - y_true
    }
}

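/// Modified Huber loss: `max(0, 1 - z)^2` for `z = y * f(x) >= -1`, and the
/// linear continuation `-4z` for `z < -1`, so badly misclassified outliers
/// are penalized only linearly.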
#[derive(Debug, Clone, Copy)]
pub struct ModifiedHuber;

impl<F: Float> Loss<F> for ModifiedHuber {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            let margin = F::one() - z;
            if margin > F::zero() {
                margin * margin
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * z
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            if z < F::one() {
                F::from(-2.0).unwrap() * y_true * (F::one() - z)
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * y_true
        }
    }
}

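/// Huber loss for regression: quadratic for residuals within `epsilon`,
/// linear beyond it.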
#[derive(Debug, Clone, Copy)]
pub struct Huber<F> {
    /// Residual threshold at which the loss switches from quadratic to linear.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for Huber<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            F::from(0.5).unwrap() * diff * diff
        } else {
            self.epsilon * (abs_diff - F::from(0.5).unwrap() * self.epsilon)
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            diff
        } else if diff > F::zero() {
            self.epsilon
        } else {
            -self.epsilon
        }
    }
}

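/// Epsilon-insensitive loss (as in support vector regression): zero for
/// residuals within `epsilon`, linear beyond it.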
#[derive(Debug, Clone, Copy)]
pub struct EpsilonInsensitive<F> {
    /// Width of the zero-loss tube around the target.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for EpsilonInsensitive<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = (y_true - y_pred).abs();
        if diff > self.epsilon {
            diff - self.epsilon
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        if diff > self.epsilon {
            F::one()
        } else if diff < -self.epsilon {
            -F::one()
        } else {
            F::zero()
        }
    }
}

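/// Learning-rate schedule used during SGD training.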
#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule<F> {
    /// `eta = eta0` for every update.
    Constant,
    /// `eta = 1 / (alpha * t)`.
    Optimal,
    /// `eta = eta0 / t^power_t`.
    InvScaling,
    /// Start at `eta0`; the training loop halves the rate when the epoch
    /// loss stops improving.
    Adaptive,
    #[doc(hidden)]
    _Phantom(std::marker::PhantomData<F>),
}

/// Compute the learning rate for update counter `t` under `schedule`.
fn compute_lr<F: Float>(
    schedule: &LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    power_t: F,
    t: usize,
) -> F {
    let t_f = F::from(t.max(1)).unwrap();
    match schedule {
        LearningRateSchedule::Constant => eta0,
        LearningRateSchedule::Optimal => F::one() / (alpha * t_f),
        LearningRateSchedule::InvScaling => eta0 / t_f.powf(power_t),
        LearningRateSchedule::Adaptive => eta0,
        LearningRateSchedule::_Phantom(_) => unreachable!(),
    }
}

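/// Loss function selector for [`SGDClassifier`].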
#[derive(Debug, Clone, Copy)]
pub enum ClassifierLoss {
    /// Hinge loss (linear SVM).
    Hinge,
    /// Logistic loss (logistic regression).
    Log,
    /// Squared error.
    SquaredError,
    /// Modified Huber loss.
    ModifiedHuber,
}

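/// Loss function selector for [`SGDRegressor`]; the payloads carry the
/// `epsilon` parameter where one is needed.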
#[derive(Debug, Clone, Copy)]
pub enum RegressorLoss<F> {
    /// Squared error (ordinary least squares).
    SquaredError,
    /// Huber loss with the given `epsilon`.
    Huber(F),
    /// Epsilon-insensitive loss with the given `epsilon`.
    EpsilonInsensitive(F),
}

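/// Linear classifier trained with stochastic gradient descent.
///
/// Multiclass problems are handled one-vs-all: one binary linear model is
/// trained per class, and prediction picks the class with the highest score.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// assumes this module's items and the `Fit`/`Predict` traits are in scope):
///
/// ```ignore
/// use ndarray::array;
///
/// let x = array![[-2.0, -2.0], [-1.5, -1.5], [2.0, 2.0], [1.5, 1.5]];
/// let y = array![0_usize, 0, 1, 1];
///
/// let clf = SGDClassifier::<f64>::new()
///     .with_loss(ClassifierLoss::Log)
///     .with_random_state(42);
/// let fitted = clf.fit(&x, &y)?;
/// let preds = fitted.predict(&x)?;
/// ```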
#[derive(Debug, Clone)]
pub struct SGDClassifier<F> {
    /// Loss function to optimize.
    pub loss: ClassifierLoss,
    /// Learning-rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of epochs.
    pub max_iter: usize,
    /// Convergence tolerance on the change in mean epoch loss.
    pub tol: F,
    /// Optional RNG seed for reproducible shuffling.
    pub random_state: Option<u64>,
    /// Exponent for the inverse-scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDClassifier<F> {
    /// Create a classifier with default hyperparameters (hinge loss,
    /// inverse-scaling learning rate, `eta0 = 0.01`, `alpha = 1e-4`,
    /// `max_iter = 1000`, `tol = 1e-3`, `power_t = 0.25`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: ClassifierLoss::Hinge,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: ClassifierLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning-rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the RNG seed.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the inverse-scaling exponent.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Collect the shared hyperparameters of a classifier into an [`SGDHyper`].
fn clf_hyper<F: Float>(clf: &SGDClassifier<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: clf.learning_rate,
        eta0: clf.eta0,
        alpha: clf.alpha,
        max_iter: clf.max_iter,
        tol: clf.tol,
        random_state: clf.random_state,
        power_t: clf.power_t,
    }
}

/// Hyperparameters shared by the classifier and regressor training loops.
#[derive(Debug, Clone)]
struct SGDHyper<F> {
    learning_rate: LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    max_iter: usize,
    tol: F,
    random_state: Option<u64>,
    power_t: F,
}

/// Run the SGD training loop for one binary problem, updating `weights` and
/// `intercept` in place. Returns the final mean epoch loss and the global
/// update counter.
fn train_binary_sgd<F, L>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    // Seeded RNG for reproducible shuffling when `random_state` is set.
    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            // Raw linear prediction for sample i.
            let mut y_pred = *intercept;
            let xi = x.row(i);
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y_binary[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y_binary[i], y_pred);

            // Gradient step with L2 penalty on the weights (not the intercept).
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        // Converged: mean epoch loss changed by less than `tol`.
        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        // Adaptive schedule: halve the rate after 5 epochs without improvement.
        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

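/// A fitted [`SGDClassifier`], holding one weight vector and intercept per
/// one-vs-all binary problem (a single pair in the binary case).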
#[derive(Debug, Clone)]
pub struct FittedSGDClassifier<F> {
    /// One weight vector per one-vs-all problem.
    weight_matrix: Vec<Array1<F>>,
    /// One intercept per one-vs-all problem.
    intercepts: Vec<F>,
    /// Distinct class labels, sorted ascending.
    classes: Vec<usize>,
    /// Number of features seen at fit time.
    n_features: usize,
    /// Loss used during training (kept for `partial_fit`).
    loss: ClassifierLoss,
    /// Hyperparameters used during training (kept for `partial_fit`).
    hyper: SGDHyper<F>,
    /// Global SGD update counter, carried across `partial_fit` calls.
    t: usize,
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type Fitted = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Fit the classifier on `x` (samples x features) and integer labels `y`.
    ///
    /// Returns an error if the shapes disagree, if `eta0`/`alpha` are
    /// invalid, or if `y` contains fewer than two distinct classes.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let hyper = clf_hyper(self);
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper,
            t,
        })
    }
}

/// Validate shared classifier inputs: matching shapes, non-empty data,
/// positive `eta0`, non-negative `alpha`.
fn validate_clf_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<usize>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDClassifier requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

/// Weight vectors, intercepts, and final update counter from one-vs-all training.
type OvaResult<F> = (Vec<Array1<F>>, Vec<F>, usize);

/// Train one-vs-all binary problems from scratch. A binary problem trains a
/// single model, with `classes[1]` mapped to +1.
fn fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    n_features: usize,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> Result<OvaResult<F>, FerroError> {
    let n_classes = classes.len();
    let mut weight_matrix: Vec<Array1<F>> = Vec::with_capacity(n_classes);
    let mut intercepts: Vec<F> = Vec::with_capacity(n_classes);
    let mut global_t = initial_t;

    if n_classes == 2 {
        // Binary case: one model separating classes[1] (+1) from classes[0] (-1).
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();
        let (_, t) =
            dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
        global_t = t;
        weight_matrix.push(w);
        intercepts.push(b);
    } else {
        // One binary model per class, each separating that class from the rest.
        for &cls in classes {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let mut w = Array1::<F>::zeros(n_features);
            let mut b = F::zero();
            let (_, t) =
                dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
            global_t = t;
            weight_matrix.push(w);
            intercepts.push(b);
        }
    }

    Ok((weight_matrix, intercepts, global_t))
}

/// Continue one-vs-all training on existing weights (used by `partial_fit`).
#[allow(clippy::too_many_arguments)]
fn partial_fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    weight_matrix: &mut [Array1<F>],
    intercepts: &mut [F],
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> usize {
    let n_classes = classes.len();
    let mut global_t = initial_t;

    if n_classes == 2 {
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let (_, t) = dispatch_train_binary(
            x,
            &y_binary,
            &mut weight_matrix[0],
            &mut intercepts[0],
            loss_enum,
            hyper,
            global_t,
        );
        global_t = t;
    } else {
        for (idx, &cls) in classes.iter().enumerate() {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let (_, t) = dispatch_train_binary(
                x,
                &y_binary,
                &mut weight_matrix[idx],
                &mut intercepts[idx],
                loss_enum,
                hyper,
                global_t,
            );
            global_t = t;
        }
    }

    global_t
}

/// Monomorphize the training loop on the selected classifier loss.
fn dispatch_train_binary<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        ClassifierLoss::Hinge => train_binary_sgd(x, y_binary, w, b, &Hinge, hyper, initial_t),
        ClassifierLoss::Log => train_binary_sgd(x, y_binary, w, b, &LogLoss, hyper, initial_t),
        ClassifierLoss::SquaredError => {
            train_binary_sgd(x, y_binary, w, b, &SquaredError, hyper, initial_t)
        }
        ClassifierLoss::ModifiedHuber => {
            train_binary_sgd(x, y_binary, w, b, &ModifiedHuber, hyper, initial_t)
        }
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDClassifier<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels for `x`. Binary problems use the sign of the
    /// decision function; multiclass problems pick the highest-scoring
    /// one-vs-all model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::<usize>::zeros(n_samples);

        if self.classes.len() == 2 {
            // Binary: classes[1] on the non-negative side of the hyperplane.
            let scores = x.dot(&self.weight_matrix[0]) + self.intercepts[0];
            for i in 0..n_samples {
                predictions[i] = if scores[i] >= F::zero() {
                    self.classes[1]
                } else {
                    self.classes[0]
                };
            }
        } else {
            // Multiclass: argmax over one-vs-all scores.
            for i in 0..n_samples {
                let xi = x.row(i);
                let mut best_class = 0;
                let mut best_score = F::neg_infinity();
                for (c, w) in self.weight_matrix.iter().enumerate() {
                    let score = xi.dot(w) + self.intercepts[c];
                    if score > best_score {
                        best_score = score;
                        best_class = c;
                    }
                }
                predictions[i] = self.classes[best_class];
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for FittedSGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Run one additional epoch of SGD on `(x, y)`, continuing from the
    /// current weights and update counter.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // A partial fit is a single epoch over the given batch.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let t = partial_fit_ova(
            x,
            y,
            &self.classes,
            &mut self.weight_matrix,
            &mut self.intercepts,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Initialize a fitted model from a single epoch of SGD on `(x, y)`.
    /// The class set is inferred from this first batch; subsequent calls go
    /// through [`FittedSGDClassifier::partial_fit`].
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let mut hyper = clf_hyper(&self);
        hyper.max_iter = 1;
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper: clf_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDClassifier<F>
{
    /// Weight vector of the first one-vs-all model (the only one in the
    /// binary case).
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_matrix[0]
    }

    /// Intercept of the first one-vs-all model.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}

impl<F> PipelineEstimator<F> for SGDClassifier<F>
where
    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        // Pipeline targets arrive as floats; non-convertible values fall back to 0.
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedSGDClassifierPipeline(fitted)))
    }
}

/// Newtype wrapper so a fitted classifier (with `usize` outputs) can serve
/// as a float-output pipeline estimator.
struct FittedSGDClassifierPipeline<F>(FittedSGDClassifier<F>)
where
    F: Float + Send + Sync + 'static;

// Safety: the wrapper only contains a `FittedSGDClassifier<F>`, whose fields
// are all `Send + Sync` under the `F: Send + Sync` bound.
unsafe impl<F> Send for FittedSGDClassifierPipeline<F> where F: Float + Send + Sync + 'static {}
unsafe impl<F> Sync for FittedSGDClassifierPipeline<F> where F: Float + Send + Sync + 'static {}

impl<F> FittedPipelineEstimator<F> for FittedSGDClassifierPipeline<F>
where
    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}

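/// Linear regressor trained with stochastic gradient descent.
///
/// A minimal usage sketch (illustrative only, not compiled as a doctest;
/// assumes this module's items and the `Fit`/`Predict` traits are in scope):
///
/// ```ignore
/// use ndarray::array;
///
/// let x = array![[1.0], [2.0], [3.0], [4.0]];
/// let y = array![2.0, 4.0, 6.0, 8.0];
///
/// let model = SGDRegressor::<f64>::new()
///     .with_loss(RegressorLoss::Huber(1.35))
///     .with_random_state(42);
/// let fitted = model.fit(&x, &y)?;
/// let preds = fitted.predict(&x)?;
/// ```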
#[derive(Debug, Clone)]
pub struct SGDRegressor<F> {
    /// Loss function to optimize.
    pub loss: RegressorLoss<F>,
    /// Learning-rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of epochs.
    pub max_iter: usize,
    /// Convergence tolerance on the change in mean epoch loss.
    pub tol: F,
    /// Optional RNG seed for reproducible shuffling.
    pub random_state: Option<u64>,
    /// Exponent for the inverse-scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDRegressor<F> {
    /// Create a regressor with default hyperparameters (squared error loss,
    /// inverse-scaling learning rate, `eta0 = 0.01`, `alpha = 1e-4`,
    /// `max_iter = 1000`, `tol = 1e-3`, `power_t = 0.25`).
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: RegressorLoss::SquaredError,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: RegressorLoss<F>) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning-rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the RNG seed.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the inverse-scaling exponent.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Collect the shared hyperparameters of a regressor into an [`SGDHyper`].
fn reg_hyper<F: Float>(reg: &SGDRegressor<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: reg.learning_rate,
        eta0: reg.eta0,
        alpha: reg.alpha,
        max_iter: reg.max_iter,
        tol: reg.tol,
        random_state: reg.random_state,
        power_t: reg.power_t,
    }
}

/// Run the SGD training loop for a regression target, updating `weights` and
/// `intercept` in place; mirrors `train_binary_sgd` but reads targets from
/// `y` directly.
fn train_regressor_sgd<F, L>(
    x: &Array2<F>,
    y: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            let xi = x.row(i);
            let mut y_pred = *intercept;
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y[i], y_pred);

            // Gradient step with L2 penalty on the weights (not the intercept).
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Monomorphize the training loop on the selected regressor loss.
fn dispatch_train_regressor<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &RegressorLoss<F>,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        RegressorLoss::SquaredError => {
            train_regressor_sgd(x, y, w, b, &SquaredError, hyper, initial_t)
        }
        RegressorLoss::Huber(eps) => {
            train_regressor_sgd(x, y, w, b, &Huber { epsilon: *eps }, hyper, initial_t)
        }
        RegressorLoss::EpsilonInsensitive(eps) => train_regressor_sgd(
            x,
            y,
            w,
            b,
            &EpsilonInsensitive { epsilon: *eps },
            hyper,
            initial_t,
        ),
    }
}

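/// A fitted [`SGDRegressor`], exposing the learned weight vector and intercept.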
#[derive(Debug, Clone)]
pub struct FittedSGDRegressor<F> {
    /// Learned weight vector.
    weights: Array1<F>,
    /// Learned intercept.
    intercept: F,
    /// Number of features seen at fit time.
    n_features: usize,
    /// Loss used during training (kept for `partial_fit`).
    loss: RegressorLoss<F>,
    /// Hyperparameters used during training (kept for `partial_fit`).
    hyper: SGDHyper<F>,
    /// Global SGD update counter, carried across `partial_fit` calls.
    t: usize,
}

/// Validate shared regressor inputs: matching shapes, non-empty data,
/// positive `eta0`, non-negative `alpha`.
fn validate_reg_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDRegressor requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type Fitted = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Fit the regressor on `x` (samples x features) and targets `y`.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let hyper = reg_hyper(self);
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper,
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict targets for `x` as the linear function `x . w + b`.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let preds = x.dot(&self.weights) + self.intercept;
        Ok(preds)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for FittedSGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Run one additional epoch of SGD on `(x, y)`, continuing from the
    /// current weights and update counter.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // A partial fit is a single epoch over the given batch.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let (_, t) = dispatch_train_regressor(
            x,
            y,
            &mut self.weights,
            &mut self.intercept,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Initialize a fitted model from a single epoch of SGD on `(x, y)`;
    /// subsequent calls go through [`FittedSGDRegressor::partial_fit`].
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut hyper = reg_hyper(&self);
        hyper.max_iter = 1;
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper: reg_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDRegressor<F>
{
    fn coefficients(&self) -> &Array1<F> {
        &self.weights
    }

    fn intercept(&self) -> F {
        self.intercept
    }
}

impl<F> PipelineEstimator<F> for SGDRegressor<F>
where
    F: Float + ScalarOperand + Send + Sync + 'static,
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl<F> FittedPipelineEstimator<F> for FittedSGDRegressor<F>
where
    F: Float + ScalarOperand + Send + Sync + 'static,
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // ---- Loss functions ----

    #[test]
    fn test_hinge_loss_correct_side() {
        let h = Hinge;
        // Margin >= 1: zero loss and zero gradient.
        assert!((Loss::<f64>::loss(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_hinge_loss_wrong_side() {
        let h = Hinge;
        // Margin = -0.5: loss = 1 - (-0.5) = 1.5, gradient = -y = -1.
        assert!((Loss::<f64>::loss(&h, 1.0, -0.5) - 1.5).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, -0.5) - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_log_loss_zero_pred() {
        let l = LogLoss;
        // At z = 0 the logistic loss is ln(2).
        let loss = Loss::<f64>::loss(&l, 1.0, 0.0);
        assert!((loss - 2.0_f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_log_loss_large_correct() {
        let l = LogLoss;
        // Confidently correct predictions give near-zero loss.
        let loss = Loss::<f64>::loss(&l, 1.0, 20.0);
        assert!(loss < 1e-5);
    }

    #[test]
    fn test_squared_error_loss() {
        let s = SquaredError;
        assert!((Loss::<f64>::loss(&s, 3.0, 1.0) - 2.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&s, 3.0, 1.0) - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_modified_huber_loss() {
        let mh = ModifiedHuber;
        // z = 2: zero loss; z = 0.5: (1 - 0.5)^2 = 0.25; z = -2: -4z = 8.
        assert!((Loss::<f64>::loss(&mh, 1.0, 2.0)).abs() < 1e-10);
        assert!((Loss::<f64>::loss(&mh, 1.0, 0.5) - 0.25).abs() < 1e-10);
        assert!((Loss::<f64>::loss(&mh, 1.0, -2.0) - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss_quadratic_region() {
        let h = Huber { epsilon: 1.0_f64 };
        assert!((Loss::<f64>::loss(&h, 1.0, 0.5) - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss_linear_region() {
        let h = Huber { epsilon: 1.0_f64 };
        assert!((Loss::<f64>::loss(&h, 3.0, 0.0) - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_epsilon_insensitive_inside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.95)).abs() < 1e-10);
    }

    #[test]
    fn test_epsilon_insensitive_outside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.5) - 0.4).abs() < 1e-10);
    }

    // ---- Learning-rate schedules ----

    #[test]
    fn test_constant_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Constant;
        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 100) - 0.1).abs() < 1e-10);
    }

    #[test]
    fn test_optimal_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Optimal;
        // 1 / (alpha * t) = 1 / (0.01 * 10) = 10.
        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 10) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_invscaling_lr() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::InvScaling;
        let result = compute_lr(&lr, 0.1, 0.01, 0.5, 10);
        let expected = 0.1 / 10.0_f64.sqrt();
        assert!((result - expected).abs() < 1e-10);
    }

    #[test]
    fn test_adaptive_lr_returns_eta0() {
        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Adaptive;
        assert!((compute_lr(&lr, 0.05, 0.01, 0.25, 100) - 0.05).abs() < 1e-10);
    }

    // ---- SGDClassifier ----

    #[test]
    fn test_sgd_classifier_binary() {
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                -2.0, -2.0, -1.5, -2.0, -2.0, -1.5, -1.5, -1.5, 2.0, 2.0, 1.5, 2.0, 2.0, 1.5, 1.5,
                1.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected >= 6 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_log_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected >= 4 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_multiclass() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let clf = SGDClassifier::<f64>::new()
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(
            correct >= 6,
            "expected >= 6 correct for multiclass, got {correct}"
        );
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // 3 samples in x but only 2 labels
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_eta0() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_eta0(0.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_alpha() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_alpha(-1.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_classifier_partial_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_chain() {
        let x1 =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y1 = array![0, 0, 1, 1];
        let x2 =
            Array2::from_shape_vec((4, 2), vec![0.5, 0.5, 1.5, 1.5, 7.5, 7.5, 8.5, 8.5]).unwrap();
        let y2 = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let preds = clf
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![0, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_modified_huber() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::ModifiedHuber)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_squared_error_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::SquaredError)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_pipeline() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_constant_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Constant)
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_classifier_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0f32, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0_usize, 0, 1, 1];

        let clf = SGDClassifier::<f32>::new()
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    // ---- SGDRegressor ----

    #[test]
    fn test_sgd_regressor_basic() {
        // Targets follow y = 2x + 1.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert!(
                (*p - actual).abs() < 2.0,
                "prediction {p} too far from {actual}"
            );
        }
    }

    #[test]
    fn test_sgd_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0]; // 3 samples in x but only 2 targets
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_invalid_eta0() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_eta0(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_regressor_partial_fit() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_chain() {
        let x1 = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y1 = array![2.0, 4.0, 6.0];
        let x2 = Array2::from_shape_vec((3, 1), vec![4.0, 5.0, 6.0]).unwrap();
        let y2 = array![8.0, 10.0, 12.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let preds = model
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![1.0, 2.0];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_huber_loss() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::Huber(1.35))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_epsilon_insensitive() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::EpsilonInsensitive(0.1))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_default() {
        let clf = SGDClassifier::<f64>::default();
        assert!(clf.eta0 > 0.0);
        assert!(clf.alpha >= 0.0);
    }

    #[test]
    fn test_sgd_regressor_default() {
        let model = SGDRegressor::<f64>::default();
        assert!(model.eta0 > 0.0);
        assert!(model.alpha >= 0.0);
    }
}