//! Stochastic Gradient Descent (SGD) linear models.
//!
//! This module provides [`SGDClassifier`] and [`SGDRegressor`], two linear
//! models trained using stochastic gradient descent. Both support online /
//! streaming learning via the [`PartialFit`] trait and a range of configurable
//! loss functions and learning-rate schedules.
//!
//! # Classifier
//!
//! ```
//! use ferrolearn_linear::sgd::SGDClassifier;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0, 2.0, 3.0, 3.0, 1.0,
//!     8.0, 7.0, 9.0, 8.0, 7.0, 9.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = SGDClassifier::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
//!
//! # Regressor
//!
//! ```
//! use ferrolearn_linear::sgd::SGDRegressor;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0];
//!
//! let model = SGDRegressor::<f64>::new();
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 4);
//! ```
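//!
//! # Online learning
//!
//! Both models implement [`PartialFit`], so training can continue one batch
//! at a time. A minimal sketch with the regressor:
//!
//! ```
//! use ferrolearn_linear::sgd::SGDRegressor;
//! use ferrolearn_core::traits::{PartialFit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x1 = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
//! let y1 = array![2.0, 4.0];
//! let x2 = Array2::from_shape_vec((2, 1), vec![3.0, 4.0]).unwrap();
//! let y2 = array![6.0, 8.0];
//!
//! // The first call fits one epoch; later calls keep updating the same weights.
//! let fitted = SGDRegressor::<f64>::new().partial_fit(&x1, &y1).unwrap();
//! let fitted = fitted.partial_fit(&x2, &y2).unwrap();
//! assert_eq!(fitted.predict(&x2).unwrap().len(), 2);
//! ```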

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, PartialFit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::{Float, FromPrimitive, ToPrimitive};
use rand::SeedableRng;
use rand::seq::SliceRandom;

// ---------------------------------------------------------------------------
// Loss functions
// ---------------------------------------------------------------------------

/// A loss function for SGD optimization.
///
/// Provides the loss value and its gradient with respect to the prediction.
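///
/// # Examples
///
/// The built-in [`Hinge`] loss illustrates the contract:
///
/// ```
/// use ferrolearn_linear::sgd::{Hinge, Loss};
///
/// let h = Hinge;
/// // On the correct side of the margin the hinge is flat: zero loss, zero gradient.
/// assert_eq!(Loss::<f64>::loss(&h, 1.0, 2.0), 0.0);
/// assert_eq!(Loss::<f64>::gradient(&h, 1.0, 2.0), 0.0);
/// // Inside the margin the gradient pushes the prediction toward the label.
/// assert_eq!(Loss::<f64>::gradient(&h, 1.0, 0.5), -1.0);
/// ```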
pub trait Loss<F: Float>: Clone + Send + Sync {
    /// Compute the loss for a single sample.
    fn loss(&self, y_true: F, y_pred: F) -> F;

    /// Compute the gradient of the loss with respect to `y_pred`.
    fn gradient(&self, y_true: F, y_pred: F) -> F;
}

/// Hinge loss for linear SVM-style classification.
///
/// `L(y, p) = max(0, 1 - y * p)` where `y in {-1, +1}`.
#[derive(Debug, Clone, Copy)]
pub struct Hinge;

impl<F: Float> Loss<F> for Hinge {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            F::one() - margin
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            -y_true
        } else {
            F::zero()
        }
    }
}

/// Log loss (logistic regression / cross-entropy).
///
/// `L(y, p) = log(1 + exp(-y * p))` where `y in {-1, +1}`.
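///
/// For numerical stability the implementation clamps `z = y * p` at `±18`:
/// beyond that range `exp(-z)` either vanishes next to `1` or dominates it,
/// so the cheaper asymptotic forms are used instead.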
#[derive(Debug, Clone, Copy)]
pub struct LogLoss;

impl<F: Float> Loss<F> for LogLoss {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            -z
        } else {
            (F::one() + (-z).exp()).ln()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        let exp_nz = if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            F::from(1e18).unwrap()
        } else {
            (-z).exp()
        };
        -y_true * exp_nz / (F::one() + exp_nz)
    }
}

/// Squared error loss for regression.
///
/// `L(y, p) = 0.5 * (y - p)^2`.
#[derive(Debug, Clone, Copy)]
pub struct SquaredError;

impl<F: Float> Loss<F> for SquaredError {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        F::from(0.5).unwrap() * diff * diff
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        y_pred - y_true
    }
}

/// Modified Huber loss for classification.
///
/// Smooth approximation to hinge with quadratic behaviour near the margin:
///
/// ```text
/// L(y, p) = max(0, 1 - y*p)^2   if y*p >= -1
///         = -4 * y * p          otherwise
/// ```
#[derive(Debug, Clone, Copy)]
pub struct ModifiedHuber;

impl<F: Float> Loss<F> for ModifiedHuber {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            let margin = F::one() - z;
            if margin > F::zero() {
                margin * margin
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * z
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            if z < F::one() {
                F::from(-2.0).unwrap() * y_true * (F::one() - z)
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * y_true
        }
    }
}

/// Huber loss for robust regression.
///
/// `L(y, p) = 0.5 * (y - p)^2` if `|y - p| <= epsilon`, else
/// `epsilon * (|y - p| - 0.5 * epsilon)`.
#[derive(Debug, Clone, Copy)]
pub struct Huber<F> {
    /// Threshold parameter for switching from quadratic to linear loss.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for Huber<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            F::from(0.5).unwrap() * diff * diff
        } else {
            self.epsilon * (abs_diff - F::from(0.5).unwrap() * self.epsilon)
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            diff
        } else if diff > F::zero() {
            self.epsilon
        } else {
            -self.epsilon
        }
    }
}

/// Epsilon-insensitive loss for support vector regression.
///
/// `L(y, p) = max(0, |y - p| - epsilon)`.
#[derive(Debug, Clone, Copy)]
pub struct EpsilonInsensitive<F> {
    /// Insensitivity margin.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for EpsilonInsensitive<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = (y_true - y_pred).abs();
        if diff > self.epsilon {
            diff - self.epsilon
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        if diff > self.epsilon {
            F::one()
        } else if diff < -self.epsilon {
            -F::one()
        } else {
            F::zero()
        }
    }
}

// ---------------------------------------------------------------------------
// Learning rate schedules
// ---------------------------------------------------------------------------

/// Learning rate schedule for SGD.
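///
/// # Examples
///
/// A schedule is selected through the estimator builders:
///
/// ```
/// use ferrolearn_linear::sgd::{LearningRateSchedule, SGDClassifier};
///
/// // Adaptive: hold eta0 until the loss stalls, then halve it.
/// let clf = SGDClassifier::<f64>::new()
///     .with_learning_rate(LearningRateSchedule::Adaptive)
///     .with_eta0(0.1);
/// assert!(clf.eta0 > 0.0);
/// ```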
#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule<F> {
    /// Fixed learning rate `eta0` throughout training.
    Constant,
    /// Optimal schedule: `eta = 1 / (alpha * t)`.
    Optimal,
    /// Inverse scaling: `eta = eta0 / t^power_t`.
    InvScaling,
    /// Adaptive: starts at `eta0`, halved when loss fails to decrease for
    /// 5 consecutive epochs. Stops when `eta < 1e-6`.
    Adaptive,
    #[doc(hidden)]
    _Phantom(std::marker::PhantomData<F>),
}

/// Compute the learning rate for a given step.
fn compute_lr<F: Float>(
    schedule: &LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    power_t: F,
    t: usize,
) -> F {
    let t_f = F::from(t.max(1)).unwrap();
    match schedule {
        LearningRateSchedule::Constant => eta0,
        LearningRateSchedule::Optimal => F::one() / (alpha * t_f),
        LearningRateSchedule::InvScaling => eta0 / t_f.powf(power_t),
        LearningRateSchedule::Adaptive => eta0,
        LearningRateSchedule::_Phantom(_) => unreachable!(),
    }
}

// ---------------------------------------------------------------------------
// Classifier loss enum
// ---------------------------------------------------------------------------

/// Available loss functions for [`SGDClassifier`].
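///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::{ClassifierLoss, SGDClassifier};
///
/// // Log loss turns the classifier into a logistic-regression-style model.
/// let clf = SGDClassifier::<f64>::new().with_loss(ClassifierLoss::Log);
/// assert!(matches!(clf.loss, ClassifierLoss::Log));
/// ```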
#[derive(Debug, Clone, Copy)]
pub enum ClassifierLoss {
    /// Hinge loss (linear SVM).
    Hinge,
    /// Log loss (logistic regression).
    Log,
    /// Squared error loss.
    SquaredError,
    /// Modified Huber loss.
    ModifiedHuber,
}

/// Available loss functions for [`SGDRegressor`].
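///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::{RegressorLoss, SGDRegressor};
///
/// // Huber with epsilon = 1.0: quadratic near zero error, linear in the tails.
/// let reg = SGDRegressor::<f64>::new().with_loss(RegressorLoss::Huber(1.0));
/// assert!(matches!(reg.loss, RegressorLoss::Huber(_)));
/// ```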
#[derive(Debug, Clone, Copy)]
pub enum RegressorLoss<F> {
    /// Squared error loss (default).
    SquaredError,
    /// Huber loss with the given epsilon.
    Huber(F),
    /// Epsilon-insensitive loss with the given epsilon.
    EpsilonInsensitive(F),
}

// ---------------------------------------------------------------------------
// SGDClassifier
// ---------------------------------------------------------------------------

/// Stochastic Gradient Descent classifier.
///
/// Supports binary classification via a decision boundary and multiclass
/// classification via one-vs-all decomposition.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::SGDClassifier;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0, 2.0, 3.0, 3.0, 1.0,
///     8.0, 7.0, 9.0, 8.0, 7.0, 9.0,
/// ]).unwrap();
/// let y = array![0, 0, 0, 1, 1, 1];
///
/// let clf = SGDClassifier::<f64>::new();
/// let fitted = clf.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SGDClassifier<F> {
    /// The loss function to use.
    pub loss: ClassifierLoss,
    /// The learning rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Convergence tolerance. Training stops when the loss improvement
    /// is below this threshold.
    pub tol: F,
    /// Optional random seed for sample shuffling.
    pub random_state: Option<u64>,
    /// Power parameter for inverse scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDClassifier<F> {
    /// Create a new `SGDClassifier` with default settings.
    ///
    /// Defaults: `loss = Hinge`, `learning_rate = InvScaling`,
    /// `eta0 = 0.01`, `alpha = 0.0001`, `max_iter = 1000`,
    /// `tol = 1e-3`, `power_t = 0.25`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: ClassifierLoss::Hinge,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: ClassifierLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the power parameter for inverse scaling.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Extract hyperparameter bundle from an `SGDClassifier`.
fn clf_hyper<F: Float>(clf: &SGDClassifier<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: clf.learning_rate,
        eta0: clf.eta0,
        alpha: clf.alpha,
        max_iter: clf.max_iter,
        tol: clf.tol,
        random_state: clf.random_state,
        power_t: clf.power_t,
    }
}

/// Internal hyperparameter bundle shared between Fit and PartialFit paths.
#[derive(Debug, Clone)]
struct SGDHyper<F> {
    learning_rate: LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    max_iter: usize,
    tol: F,
    random_state: Option<u64>,
    power_t: F,
}

/// Train a single binary classifier via SGD, updating `weights` and
/// `intercept` in place. `y_binary` must be in `{-1, +1}`.
///
/// Returns the final epoch's mean loss and the step counter after training.
fn train_binary_sgd<F, L>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    // Build the RNG for shuffling.
    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            // Compute prediction: w^T x_i + b.
            let mut y_pred = *intercept;
            let xi = x.row(i);
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y_binary[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y_binary[i], y_pred);

            // Update weights with gradient + L2 regularization.
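            // w[j] <- w[j] - eta * (dL/dp * x[i,j] + alpha * w[j]); the bias
            // update below takes the plain gradient step with no L2 term.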
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        // Convergence check.
        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        // Adaptive learning rate adjustment.
        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Fitted SGD classifier.
///
/// Holds the learned weight vectors and intercepts. For binary problems
/// there is a single weight vector; for multiclass problems there is one
/// per class (one-vs-all).
///
/// Implements [`Predict`] and [`PartialFit`] to support both inference and
/// online learning.
#[derive(Debug, Clone)]
pub struct FittedSGDClassifier<F> {
    /// One weight vector of length `n_features` per binary sub-problem:
    /// a single vector for binary problems, one per class for multiclass.
    weight_matrix: Vec<Array1<F>>,
    /// Intercept vector, one per sub-problem.
    intercepts: Vec<F>,
    /// Sorted unique class labels.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// The loss function used during training.
    loss: ClassifierLoss,
    /// Hyperparameters for continued training via `partial_fit`.
    hyper: SGDHyper<F>,
    /// Global step counter across all partial_fit calls.
    t: usize,
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type Fitted = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Fit the SGD classifier on the given data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sample counts.
    /// Returns [`FerroError::InsufficientSamples`] if fewer than 2 classes
    /// are present.
    /// Returns [`FerroError::InvalidParameter`] if `eta0` is not positive
    /// or `alpha` is negative.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let hyper = clf_hyper(self);
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper,
            t,
        })
    }
}

/// Validate classifier input shapes and parameters.
fn validate_clf_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<usize>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDClassifier requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

/// Result type for one-vs-all training: (weight_matrix, intercepts, step_counter).
type OvaResult<F> = (Vec<Array1<F>>, Vec<F>, usize);

/// Train one-vs-all binary classifiers, returning per-class weights, intercepts,
/// and the cumulative step counter.
fn fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    n_features: usize,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> Result<OvaResult<F>, FerroError> {
    let n_classes = classes.len();
    let mut weight_matrix: Vec<Array1<F>> = Vec::with_capacity(n_classes);
    let mut intercepts: Vec<F> = Vec::with_capacity(n_classes);
    let mut global_t = initial_t;

    if n_classes == 2 {
        // Single binary problem: class[0] -> -1, class[1] -> +1.
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();
        let (_, t) =
            dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
        global_t = t;
        weight_matrix.push(w);
        intercepts.push(b);
    } else {
        // One-vs-all: one binary problem per class.
        for &cls in classes {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let mut w = Array1::<F>::zeros(n_features);
            let mut b = F::zero();
            let (_, t) =
                dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
            global_t = t;
            weight_matrix.push(w);
            intercepts.push(b);
        }
    }

    Ok((weight_matrix, intercepts, global_t))
}

/// Train one-vs-all using existing weight vectors (for partial_fit).
#[allow(clippy::too_many_arguments)]
fn partial_fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    weight_matrix: &mut [Array1<F>],
    intercepts: &mut [F],
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> usize {
    let n_classes = classes.len();
    let mut global_t = initial_t;

    if n_classes == 2 {
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let (_, t) = dispatch_train_binary(
            x,
            &y_binary,
            &mut weight_matrix[0],
            &mut intercepts[0],
            loss_enum,
            hyper,
            global_t,
        );
        global_t = t;
    } else {
        for (idx, &cls) in classes.iter().enumerate() {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let (_, t) = dispatch_train_binary(
                x,
                &y_binary,
                &mut weight_matrix[idx],
                &mut intercepts[idx],
                loss_enum,
                hyper,
                global_t,
            );
            global_t = t;
        }
    }

    global_t
}

/// Dispatch to the appropriate typed loss training function.
fn dispatch_train_binary<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        ClassifierLoss::Hinge => train_binary_sgd(x, y_binary, w, b, &Hinge, hyper, initial_t),
        ClassifierLoss::Log => train_binary_sgd(x, y_binary, w, b, &LogLoss, hyper, initial_t),
        ClassifierLoss::SquaredError => {
            train_binary_sgd(x, y_binary, w, b, &SquaredError, hyper, initial_t)
        }
        ClassifierLoss::ModifiedHuber => {
            train_binary_sgd(x, y_binary, w, b, &ModifiedHuber, hyper, initial_t)
        }
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDClassifier<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels for the given feature matrix.
    ///
    /// For binary classification, uses `sign(w^T x + b)`.
    /// For multiclass, returns the class whose one-vs-all score is highest.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::<usize>::zeros(n_samples);

        if self.classes.len() == 2 {
            // Binary: single weight vector.
            let scores = x.dot(&self.weight_matrix[0]) + self.intercepts[0];
            for i in 0..n_samples {
                predictions[i] = if scores[i] >= F::zero() {
                    self.classes[1]
                } else {
                    self.classes[0]
                };
            }
        } else {
            // Multiclass: one-vs-all, pick highest score.
            for i in 0..n_samples {
                let xi = x.row(i);
                let mut best_class = 0;
                let mut best_score = F::neg_infinity();
                for (c, w) in self.weight_matrix.iter().enumerate() {
                    let score = xi.dot(w) + self.intercepts[c];
                    if score > best_score {
                        best_score = score;
                        best_class = c;
                    }
                }
                predictions[i] = self.classes[best_class];
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for FittedSGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Incrementally train the classifier on a new batch of data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sizes or `x` has the wrong number of features.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // Use a single-epoch hyper for partial_fit.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let t = partial_fit_ova(
            x,
            y,
            &self.classes,
            &mut self.weight_matrix,
            &mut self.intercepts,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Initial call to `partial_fit` on an unfitted classifier.
    ///
    /// Equivalent to `fit` but with a single epoch, enabling subsequent
    /// incremental calls.
    ///
    /// # Errors
    ///
    /// Same as [`Fit::fit`].
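    ///
    /// # Examples
    ///
    /// ```
    /// use ferrolearn_linear::sgd::SGDClassifier;
    /// use ferrolearn_core::traits::{PartialFit, Predict};
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 2), vec![
    ///     1.0, 1.0, 2.0, 2.0,
    ///     8.0, 8.0, 9.0, 9.0,
    /// ]).unwrap();
    /// let y = array![0, 0, 1, 1];
    ///
    /// let fitted = SGDClassifier::<f64>::new()
    ///     .with_random_state(42)
    ///     .partial_fit(&x, &y)
    ///     .unwrap();
    /// // Further batches continue training the same weights.
    /// let fitted = fitted.partial_fit(&x, &y).unwrap();
    /// assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    /// ```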
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let mut hyper = clf_hyper(&self);
        hyper.max_iter = 1;
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper: clf_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDClassifier<F>
{
    /// Returns the coefficient vector for the first (or only) binary classifier.
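    ///
    /// ```
    /// use ferrolearn_linear::sgd::SGDClassifier;
    /// use ferrolearn_core::Fit;
    /// use ferrolearn_core::introspection::HasCoefficients;
    /// use ndarray::{array, Array2};
    ///
    /// let x = Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
    /// let y = array![0, 0, 1, 1];
    /// let fitted = SGDClassifier::<f64>::new().with_random_state(42).fit(&x, &y).unwrap();
    /// // One coefficient per input feature.
    /// assert_eq!(fitted.coefficients().len(), 2);
    /// ```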
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_matrix[0]
    }

    /// Returns the intercept for the first (or only) binary classifier.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}

// Pipeline integration.
impl<F> PipelineEstimator<F> for SGDClassifier<F>
where
    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let y_usize: Array1<usize> = y.mapv(|v| v.to_usize().unwrap_or(0));
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedSGDClassifierPipeline(fitted)))
    }
}

/// Wrapper for pipeline integration that converts predictions to float.
struct FittedSGDClassifierPipeline<F>(FittedSGDClassifier<F>)
where
    F: Float + Send + Sync + 'static;
// `FittedSGDClassifier<F>` is built entirely from `Send + Sync` fields, so the
// wrapper receives both auto traits automatically; no `unsafe impl` is needed.

impl<F> FittedPipelineEstimator<F> for FittedSGDClassifierPipeline<F>
where
    F: Float + ToPrimitive + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| F::from_usize(v).unwrap_or(F::nan())))
    }
}

// ---------------------------------------------------------------------------
// SGDRegressor
// ---------------------------------------------------------------------------

/// Stochastic Gradient Descent regressor.
///
/// Supports several loss functions for regression, trained using stochastic
/// gradient descent with configurable learning rate schedules.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::SGDRegressor;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let y = array![2.0, 4.0, 6.0, 8.0];
///
/// let model = SGDRegressor::<f64>::new();
/// let fitted = model.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SGDRegressor<F> {
    /// The loss function to use.
    pub loss: RegressorLoss<F>,
    /// The learning rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: F,
    /// Optional random seed for sample shuffling.
    pub random_state: Option<u64>,
    /// Power parameter for inverse scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDRegressor<F> {
    /// Create a new `SGDRegressor` with default settings.
    ///
    /// Defaults: `loss = SquaredError`, `learning_rate = InvScaling`,
    /// `eta0 = 0.01`, `alpha = 0.0001`, `max_iter = 1000`,
    /// `tol = 1e-3`, `power_t = 0.25`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: RegressorLoss::SquaredError,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: RegressorLoss<F>) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the power parameter for inverse scaling.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Extract hyperparameter bundle from an `SGDRegressor`.
fn reg_hyper<F: Float>(reg: &SGDRegressor<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: reg.learning_rate,
        eta0: reg.eta0,
        alpha: reg.alpha,
        max_iter: reg.max_iter,
        tol: reg.tol,
        random_state: reg.random_state,
        power_t: reg.power_t,
    }
}

/// Train a single regressor via SGD, updating `weights` and `intercept`
/// in place. Returns the final loss and step counter.
fn train_regressor_sgd<F, L>(
    x: &Array2<F>,
    y: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            let xi = x.row(i);
            let mut y_pred = *intercept;
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y[i], y_pred);

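            // Same update as the classifier: a gradient step on the loss plus
            // L2 shrinkage on the weights (the bias is not regularized).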
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Dispatch regressor training to the appropriate typed loss function.
fn dispatch_train_regressor<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &RegressorLoss<F>,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        RegressorLoss::SquaredError => {
            train_regressor_sgd(x, y, w, b, &SquaredError, hyper, initial_t)
        }
        RegressorLoss::Huber(eps) => {
            train_regressor_sgd(x, y, w, b, &Huber { epsilon: *eps }, hyper, initial_t)
        }
        RegressorLoss::EpsilonInsensitive(eps) => train_regressor_sgd(
            x,
            y,
            w,
            b,
            &EpsilonInsensitive { epsilon: *eps },
            hyper,
            initial_t,
        ),
    }
}

/// Fitted SGD regressor.
///
/// Holds the learned weight vector and intercept. Implements [`Predict`]
/// and [`PartialFit`] to support both inference and online learning.
#[derive(Debug, Clone)]
pub struct FittedSGDRegressor<F> {
    /// Learned weight vector (one per feature).
    weights: Array1<F>,
    /// Learned intercept (bias) term.
    intercept: F,
    /// Number of features the model was trained on.
    n_features: usize,
    /// The loss function used during training.
    loss: RegressorLoss<F>,
    /// Hyperparameters for continued training.
    hyper: SGDHyper<F>,
    /// Global step counter.
    t: usize,
}

/// Validate regressor input shapes and parameters.
fn validate_reg_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDRegressor requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type Fitted = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Fit the SGD regressor on the given data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sample counts.
    /// Returns [`FerroError::InvalidParameter`] if `eta0` or `alpha` are
    /// invalid.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let hyper = reg_hyper(self);
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper,
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values for the given feature matrix.
    ///
    /// Computes `X @ weights + intercept`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let preds = x.dot(&self.weights) + self.intercept;
        Ok(preds)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for FittedSGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Incrementally train the regressor on a new batch of data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sizes or `x` has the wrong number of features.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let (_, t) = dispatch_train_regressor(
            x,
            y,
            &mut self.weights,
            &mut self.intercept,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Initial call to `partial_fit` on an unfitted regressor.
    ///
    /// Equivalent to `fit` but with a single epoch.
    ///
    /// # Errors
    ///
    /// Same as [`Fit::fit`].
1473    fn partial_fit(
1474        self,
1475        x: &Array2<F>,
1476        y: &Array1<F>,
1477    ) -> Result<FittedSGDRegressor<F>, FerroError> {
1478        validate_reg_params(x, y, self.eta0, self.alpha)?;
1479
1480        let n_features = x.ncols();
1481        let mut hyper = reg_hyper(&self);
1482        hyper.max_iter = 1;
1483        let mut w = Array1::<F>::zeros(n_features);
1484        let mut b = F::zero();
1485
1486        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);
1487
1488        Ok(FittedSGDRegressor {
1489            weights: w,
1490            intercept: b,
1491            n_features,
1492            loss: self.loss,
1493            hyper: reg_hyper(&self),
1494            t,
1495        })
1496    }
1497}
1498
1499impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
1500    for FittedSGDRegressor<F>
1501{
1502    fn coefficients(&self) -> &Array1<F> {
1503        &self.weights
1504    }
1505
1506    fn intercept(&self) -> F {
1507        self.intercept
1508    }
1509}
1510
1511// Pipeline integration.
1512impl<F> PipelineEstimator<F> for SGDRegressor<F>
1513where
1514    F: Float + ScalarOperand + Send + Sync + 'static,
1515{
1516    fn fit_pipeline(
1517        &self,
1518        x: &Array2<F>,
1519        y: &Array1<F>,
1520    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
1521        let fitted = self.fit(x, y)?;
1522        Ok(Box::new(fitted))
1523    }
1524}
1525
1526impl<F> FittedPipelineEstimator<F> for FittedSGDRegressor<F>
1527where
1528    F: Float + ScalarOperand + Send + Sync + 'static,
1529{
1530    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
1531        self.predict(x)
1532    }
1533}
1534
1535// ---------------------------------------------------------------------------
1536// Tests
1537// ---------------------------------------------------------------------------
1538
1539#[cfg(test)]
1540mod tests {
1541    use super::*;
1542    use ndarray::array;
1543
1544    // -----------------------------------------------------------------------
1545    // Loss function tests
1546    // -----------------------------------------------------------------------
1547
1548    #[test]
1549    fn test_hinge_loss_correct_side() {
1550        let h = Hinge;
1551        // y=1, pred=2 => margin=2 >= 1 => loss=0
1552        assert!((Loss::<f64>::loss(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
1553        assert!((Loss::<f64>::gradient(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
1554    }
1555
1556    #[test]
1557    fn test_hinge_loss_wrong_side() {
1558        let h = Hinge;
1559        // y=1, pred=-0.5 => margin=-0.5 < 1 => loss=1.5
1560        assert!((Loss::<f64>::loss(&h, 1.0, -0.5) - 1.5).abs() < 1e-10);
1561        assert!((Loss::<f64>::gradient(&h, 1.0, -0.5) - (-1.0)).abs() < 1e-10);
1562    }
1563
1564    #[test]
1565    fn test_log_loss_zero_pred() {
1566        let l = LogLoss;
1567        // y=1, pred=0 => loss=log(1+exp(0))=log(2)
1568        let loss = Loss::<f64>::loss(&l, 1.0, 0.0);
1569        assert!((loss - 2.0_f64.ln()).abs() < 1e-10);
1570    }
1571
1572    #[test]
1573    fn test_log_loss_large_correct() {
1574        let l = LogLoss;
1575        // y=1, pred=20 => very small loss
1576        let loss = Loss::<f64>::loss(&l, 1.0, 20.0);
1577        assert!(loss < 1e-5);
1578    }
1579
1580    #[test]
1581    fn test_squared_error_loss() {
1582        let s = SquaredError;
1583        assert!((Loss::<f64>::loss(&s, 3.0, 1.0) - 2.0).abs() < 1e-10);
1584        assert!((Loss::<f64>::gradient(&s, 3.0, 1.0) - (-2.0)).abs() < 1e-10);
1585    }
1586
1587    #[test]
1588    fn test_modified_huber_loss() {
1589        let mh = ModifiedHuber;
1590        // y=1, pred=2 => z=2 >= 1 => loss=0
1591        assert!((Loss::<f64>::loss(&mh, 1.0, 2.0)).abs() < 1e-10);
1592        // y=1, pred=0.5 => z=0.5 => loss=(1-0.5)^2=0.25
1593        assert!((Loss::<f64>::loss(&mh, 1.0, 0.5) - 0.25).abs() < 1e-10);
1594        // y=1, pred=-2 => z=-2 < -1 => loss=-4*(-2)=8
1595        assert!((Loss::<f64>::loss(&mh, 1.0, -2.0) - 8.0).abs() < 1e-10);
1596    }
1597
1598    #[test]
1599    fn test_huber_loss_quadratic_region() {
1600        let h = Huber { epsilon: 1.0_f64 };
1601        // |y - p| = 0.5 <= 1.0 => quadratic
1602        assert!((Loss::<f64>::loss(&h, 1.0, 0.5) - 0.125).abs() < 1e-10);
1603    }
1604
1605    #[test]
1606    fn test_huber_loss_linear_region() {
1607        let h = Huber { epsilon: 1.0_f64 };
1608        // |y - p| = 3 > 1 => linear: 1*(3 - 0.5) = 2.5
1609        assert!((Loss::<f64>::loss(&h, 3.0, 0.0) - 2.5).abs() < 1e-10);
1610    }
1611
1612    #[test]
1613    fn test_epsilon_insensitive_inside() {
1614        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
1615        // |y - p| = 0.05 <= 0.1 => loss=0
1616        assert!((Loss::<f64>::loss(&ei, 1.0, 0.95)).abs() < 1e-10);
1617    }
1618
1619    #[test]
1620    fn test_epsilon_insensitive_outside() {
1621        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
1622        // |y - p| = 0.5 > 0.1 => loss=0.4
1623        assert!((Loss::<f64>::loss(&ei, 1.0, 0.5) - 0.4).abs() < 1e-10);
1624    }
1625
1626    // -----------------------------------------------------------------------
1627    // Learning rate tests
1628    // -----------------------------------------------------------------------
1629
1630    #[test]
1631    fn test_constant_lr() {
1632        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Constant;
1633        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 100) - 0.1).abs() < 1e-10);
1634    }
1635
1636    #[test]
1637    fn test_optimal_lr() {
1638        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Optimal;
1639        // eta = 1 / (alpha * t) = 1 / (0.01 * 10) = 10.0
1640        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 10) - 10.0).abs() < 1e-10);
1641    }
1642
1643    #[test]
1644    fn test_invscaling_lr() {
1645        let lr: LearningRateSchedule<f64> = LearningRateSchedule::InvScaling;
1646        // eta = 0.1 / 10^0.5 = 0.1 / 3.162... ~= 0.0316...
1647        let result = compute_lr(&lr, 0.1, 0.01, 0.5, 10);
1648        let expected = 0.1 / 10.0_f64.sqrt();
1649        assert!((result - expected).abs() < 1e-10);
1650    }
1651
1652    #[test]
1653    fn test_adaptive_lr_returns_eta0() {
1654        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Adaptive;
1655        assert!((compute_lr(&lr, 0.05, 0.01, 0.25, 100) - 0.05).abs() < 1e-10);
1656    }

    // -----------------------------------------------------------------------
    // SGDClassifier tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sgd_classifier_binary() {
        // Well-separated clusters centered near the origin for SGD stability.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                -2.0, -2.0, -1.5, -2.0, -2.0, -1.5, -1.5, -1.5, 2.0, 2.0, 1.5, 2.0, 2.0, 1.5, 1.5,
                1.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected >= 6 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_log_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected >= 4 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_multiclass() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let clf = SGDClassifier::<f64>::new()
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(
            correct >= 6,
            "expected >= 6 correct for multiclass, got {correct}"
        );
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // Wrong length
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_eta0() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_eta0(0.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_alpha() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_alpha(-1.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_classifier_partial_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_chain() {
        // Test the chaining pattern:
        // model.partial_fit(&b1, &y1)?.partial_fit(&b2, &y2)?.predict(&x)?
        let x1 =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y1 = array![0, 0, 1, 1];
        let x2 =
            Array2::from_shape_vec((4, 2), vec![0.5, 0.5, 1.5, 1.5, 7.5, 7.5, 8.5, 8.5]).unwrap();
        let y2 = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let preds = clf
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![0, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }
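
    // A hedged streaming sketch (not part of the original suite): feed the
    // model successive mini-batches through `partial_fit`, reusing only the
    // API exercised by the chaining test above. Batch values are illustrative.
    #[test]
    fn test_sgd_classifier_streaming_sketch() {
        let x1 = Array2::from_shape_vec((2, 1), vec![-2.0, 2.0]).unwrap();
        let y1 = array![0_usize, 1];
        let x2 = Array2::from_shape_vec((2, 1), vec![-1.5, 1.5]).unwrap();
        let y2 = array![0_usize, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        // The first call fits from scratch; later calls refine the same model.
        let mut fitted = clf.partial_fit(&x1, &y1).unwrap();
        fitted = fitted.partial_fit(&x2, &y2).unwrap();
        assert_eq!(fitted.predict(&x1).unwrap().len(), 2);
    }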

    #[test]
    fn test_sgd_classifier_modified_huber() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::ModifiedHuber)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_squared_error_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::SquaredError)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_pipeline() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_constant_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Constant)
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_classifier_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0f32, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0_usize, 0, 1, 1];

        let clf = SGDClassifier::<f32>::new()
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    // -----------------------------------------------------------------------
    // SGDRegressor tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sgd_regressor_basic() {
        // y = 2*x + 1
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // SGD on five samples is only roughly accurate, so allow a wide
        // tolerance around the true line.
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert!(
                (*p - actual).abs() < 2.0,
                "prediction {p} too far from {actual}"
            );
        }
    }
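
    // A hedged introspection sketch (supplementary): read back the learned
    // slope via `coefficients()` and assert only its sign, since this suite
    // does not pin down SGD's exact convergence tolerance. Indexing the
    // returned coefficients with `[0]` is assumed to work as for an ndarray
    // `Array1`.
    #[test]
    fn test_sgd_regressor_coefficient_sign_sketch() {
        // Same y = 2*x + 1 data as the basic test above.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        // y increases with x, so the single coefficient should be positive.
        assert!(fitted.coefficients()[0] > 0.0);
    }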

    #[test]
    fn test_sgd_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0]; // Wrong length
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_invalid_eta0() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_eta0(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_regressor_partial_fit() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_chain() {
        let x1 = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y1 = array![2.0, 4.0, 6.0];
        let x2 = Array2::from_shape_vec((3, 1), vec![4.0, 5.0, 6.0]).unwrap();
        let y2 = array![8.0, 10.0, 12.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let preds = model
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![1.0, 2.0];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_huber_loss() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::Huber(1.35))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_epsilon_insensitive() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::EpsilonInsensitive(0.1))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_default() {
        let clf = SGDClassifier::<f64>::default();
        assert!(clf.eta0 > 0.0);
        assert!(clf.alpha >= 0.0);
    }

    #[test]
    fn test_sgd_regressor_default() {
        let model = SGDRegressor::<f64>::default();
        assert!(model.eta0 > 0.0);
        assert!(model.alpha >= 0.0);
    }
}