//! Stochastic Gradient Descent (SGD) linear models.
//!
//! This module provides [`SGDClassifier`] and [`SGDRegressor`], two linear
//! models trained using stochastic gradient descent. Both support online /
//! streaming learning via the [`PartialFit`] trait and a range of configurable
//! loss functions and learning-rate schedules.
//!
//! # Classifier
//!
//! ```
//! use ferrolearn_linear::sgd::{SGDClassifier, ClassifierLoss};
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((6, 2), vec![
//!     1.0, 2.0, 2.0, 3.0, 3.0, 1.0,
//!     8.0, 7.0, 9.0, 8.0, 7.0, 9.0,
//! ]).unwrap();
//! let y = array![0, 0, 0, 1, 1, 1];
//!
//! let model = SGDClassifier::<f64>::new().with_loss(ClassifierLoss::Log);
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 6);
//! ```
//!
//! # Regressor
//!
//! ```
//! use ferrolearn_linear::sgd::{SGDRegressor, RegressorLoss};
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
//! let y = array![2.0, 4.0, 6.0, 8.0];
//!
//! let model = SGDRegressor::<f64>::new().with_loss(RegressorLoss::Huber(1.0));
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! assert_eq!(preds.len(), 4);
//! ```
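//!
//! # Online learning
//!
//! Both estimators also support incremental training through [`PartialFit`].
//! A minimal sketch (note that the first batch must contain every class the
//! classifier will ever see, since the class set is inferred from it):
//!
//! ```
//! use ferrolearn_linear::sgd::SGDClassifier;
//! use ferrolearn_core::{PartialFit, Predict};
//! use ndarray::{array, Array2};
//!
//! let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
//! let y = array![0, 0, 1, 1];
//!
//! // The first call fits one epoch; later calls continue from the same weights.
//! let fitted = SGDClassifier::<f64>::new().partial_fit(&x, &y).unwrap();
//! let fitted = fitted.partial_fit(&x, &y).unwrap();
//! assert_eq!(fitted.predict(&x).unwrap().len(), 4);
//! ```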

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, PartialFit, Predict};
use ndarray::{Array1, Array2, ScalarOperand};
use num_traits::Float;
use rand::SeedableRng;
use rand::seq::SliceRandom;

// ---------------------------------------------------------------------------
// Loss functions
// ---------------------------------------------------------------------------

/// A loss function for SGD optimization.
///
/// Provides the loss value and its gradient with respect to the prediction.
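///
/// # Examples
///
/// A custom absolute-error loss (an illustrative sketch, not part of this
/// crate's API) only needs the value and its derivative with respect to the
/// prediction:
///
/// ```
/// use ferrolearn_linear::sgd::Loss;
/// use num_traits::Float;
///
/// #[derive(Debug, Clone, Copy)]
/// struct AbsoluteError;
///
/// impl<F: Float> Loss<F> for AbsoluteError {
///     fn loss(&self, y_true: F, y_pred: F) -> F {
///         (y_true - y_pred).abs()
///     }
///
///     fn gradient(&self, y_true: F, y_pred: F) -> F {
///         // d|y - p|/dp = sign(p - y); use 0 at the kink.
///         let diff = y_pred - y_true;
///         if diff > F::zero() {
///             F::one()
///         } else if diff < F::zero() {
///             -F::one()
///         } else {
///             F::zero()
///         }
///     }
/// }
///
/// assert_eq!(Loss::<f64>::gradient(&AbsoluteError, 1.0, 3.0), 1.0);
/// ```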
pub trait Loss<F: Float>: Clone + Send + Sync {
    /// Compute the loss for a single sample.
    fn loss(&self, y_true: F, y_pred: F) -> F;

    /// Compute the gradient of the loss with respect to `y_pred`.
    fn gradient(&self, y_true: F, y_pred: F) -> F;
}

/// Hinge loss for linear SVM-style classification.
///
/// `L(y, p) = max(0, 1 - y * p)` where `y in {-1, +1}`.
#[derive(Debug, Clone, Copy)]
pub struct Hinge;

impl<F: Float> Loss<F> for Hinge {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            F::one() - margin
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let margin = y_true * y_pred;
        if margin < F::one() {
            -y_true
        } else {
            F::zero()
        }
    }
}

/// Log loss (logistic regression / cross-entropy).
///
/// `L(y, p) = log(1 + exp(-y * p))` where `y in {-1, +1}`.
#[derive(Debug, Clone, Copy)]
pub struct LogLoss;

impl<F: Float> Loss<F> for LogLoss {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // Guard against overflow in exp(): for large |z| use the asymptotic
        // forms log(1 + exp(-z)) ~ exp(-z) (z large positive) and ~ -z
        // (z very negative).
        if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            -z
        } else {
            (F::one() + (-z).exp()).ln()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        // Saturate exp(-z) for very negative z so the gradient tends to -y
        // without overflowing.
        let exp_nz = if z > F::from(18.0).unwrap() {
            (-z).exp()
        } else if z < F::from(-18.0).unwrap() {
            F::from(1e18).unwrap()
        } else {
            (-z).exp()
        };
        -y_true * exp_nz / (F::one() + exp_nz)
    }
}

/// Squared error loss for regression.
///
/// `L(y, p) = 0.5 * (y - p)^2`.
#[derive(Debug, Clone, Copy)]
pub struct SquaredError;

impl<F: Float> Loss<F> for SquaredError {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        F::from(0.5).unwrap() * diff * diff
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        y_pred - y_true
    }
}

/// Modified Huber loss for classification.
///
/// Smooth approximation to hinge with quadratic behaviour near the margin:
///
/// ```text
/// L(y, p) = max(0, 1 - y*p)^2   if y*p >= -1
///         = -4 * y * p          otherwise
/// ```
#[derive(Debug, Clone, Copy)]
pub struct ModifiedHuber;

impl<F: Float> Loss<F> for ModifiedHuber {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            let margin = F::one() - z;
            if margin > F::zero() {
                margin * margin
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * z
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let z = y_true * y_pred;
        if z >= -F::one() {
            if z < F::one() {
                F::from(-2.0).unwrap() * y_true * (F::one() - z)
            } else {
                F::zero()
            }
        } else {
            -F::from(4.0).unwrap() * y_true
        }
    }
}

/// Huber loss for robust regression.
///
/// `L(y, p) = 0.5 * (y - p)^2` if `|y - p| <= epsilon`, else
/// `epsilon * (|y - p| - 0.5 * epsilon)`.
#[derive(Debug, Clone, Copy)]
pub struct Huber<F> {
    /// Threshold parameter for switching from quadratic to linear loss.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for Huber<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = y_true - y_pred;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            F::from(0.5).unwrap() * diff * diff
        } else {
            self.epsilon * (abs_diff - F::from(0.5).unwrap() * self.epsilon)
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        let abs_diff = diff.abs();
        if abs_diff <= self.epsilon {
            diff
        } else if diff > F::zero() {
            self.epsilon
        } else {
            -self.epsilon
        }
    }
}

/// Epsilon-insensitive loss for support vector regression.
///
/// `L(y, p) = max(0, |y - p| - epsilon)`.
#[derive(Debug, Clone, Copy)]
pub struct EpsilonInsensitive<F> {
    /// Insensitivity margin.
    pub epsilon: F,
}

impl<F: Float + Send + Sync> Loss<F> for EpsilonInsensitive<F> {
    fn loss(&self, y_true: F, y_pred: F) -> F {
        let diff = (y_true - y_pred).abs();
        if diff > self.epsilon {
            diff - self.epsilon
        } else {
            F::zero()
        }
    }

    fn gradient(&self, y_true: F, y_pred: F) -> F {
        let diff = y_pred - y_true;
        if diff > self.epsilon {
            F::one()
        } else if diff < -self.epsilon {
            -F::one()
        } else {
            F::zero()
        }
    }
}

// ---------------------------------------------------------------------------
// Learning rate schedules
// ---------------------------------------------------------------------------

/// Learning rate schedule for SGD.
#[derive(Debug, Clone, Copy)]
pub enum LearningRateSchedule<F> {
    /// Fixed learning rate `eta0` throughout training.
    Constant,
    /// Optimal schedule: `eta = 1 / (alpha * t)`.
    Optimal,
    /// Inverse scaling: `eta = eta0 / t^power_t`.
    InvScaling,
    /// Adaptive: starts at `eta0`, halved when loss fails to decrease for
    /// 5 consecutive epochs. Stops when `eta < 1e-6`.
    Adaptive,
    #[doc(hidden)]
    _Phantom(std::marker::PhantomData<F>),
}

/// Compute the learning rate for a given step.
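///
/// For example, with `eta0 = 0.1`, `alpha = 0.01` and `power_t = 0.25`, step
/// `t = 100` yields `0.1` under `Constant`, `1 / (0.01 * 100) = 1.0` under
/// `Optimal`, and `0.1 / 100^0.25 ≈ 0.0316` under `InvScaling`.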
fn compute_lr<F: Float>(
    schedule: &LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    power_t: F,
    t: usize,
) -> F {
    let t_f = F::from(t.max(1)).unwrap();
    match schedule {
        LearningRateSchedule::Constant => eta0,
        LearningRateSchedule::Optimal => F::one() / (alpha * t_f),
        LearningRateSchedule::InvScaling => eta0 / t_f.powf(power_t),
        LearningRateSchedule::Adaptive => eta0,
        LearningRateSchedule::_Phantom(_) => unreachable!(),
    }
}

// ---------------------------------------------------------------------------
// Classifier loss enum
// ---------------------------------------------------------------------------

/// Available loss functions for [`SGDClassifier`].
#[derive(Debug, Clone, Copy)]
pub enum ClassifierLoss {
    /// Hinge loss (linear SVM).
    Hinge,
    /// Log loss (logistic regression).
    Log,
    /// Squared error loss.
    SquaredError,
    /// Modified Huber loss.
    ModifiedHuber,
}

/// Available loss functions for [`SGDRegressor`].
#[derive(Debug, Clone, Copy)]
pub enum RegressorLoss<F> {
    /// Squared error loss (default).
    SquaredError,
    /// Huber loss with the given epsilon.
    Huber(F),
    /// Epsilon-insensitive loss with the given epsilon.
    EpsilonInsensitive(F),
}

// ---------------------------------------------------------------------------
// SGDClassifier
// ---------------------------------------------------------------------------

/// Stochastic Gradient Descent classifier.
///
/// Supports binary classification via a decision boundary and multiclass
/// classification via one-vs-all decomposition.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::SGDClassifier;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((6, 2), vec![
///     1.0, 2.0, 2.0, 3.0, 3.0, 1.0,
///     8.0, 7.0, 9.0, 8.0, 7.0, 9.0,
/// ]).unwrap();
/// let y = array![0, 0, 0, 1, 1, 1];
///
/// let clf = SGDClassifier::<f64>::new();
/// let fitted = clf.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SGDClassifier<F> {
    /// The loss function to use.
    pub loss: ClassifierLoss,
    /// The learning rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Convergence tolerance. Training stops when the loss improvement
    /// is below this threshold.
    pub tol: F,
    /// Optional random seed for sample shuffling.
    pub random_state: Option<u64>,
    /// Power parameter for inverse scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDClassifier<F> {
    /// Create a new `SGDClassifier` with default settings.
    ///
    /// Defaults: `loss = Hinge`, `learning_rate = InvScaling`,
    /// `eta0 = 0.01`, `alpha = 0.0001`, `max_iter = 1000`,
    /// `tol = 1e-3`, `power_t = 0.25`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: ClassifierLoss::Hinge,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: ClassifierLoss) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the power parameter for inverse scaling.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDClassifier<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Extract hyperparameter bundle from an `SGDClassifier`.
fn clf_hyper<F: Float>(clf: &SGDClassifier<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: clf.learning_rate,
        eta0: clf.eta0,
        alpha: clf.alpha,
        max_iter: clf.max_iter,
        tol: clf.tol,
        random_state: clf.random_state,
        power_t: clf.power_t,
    }
}

/// Internal hyperparameter bundle shared between Fit and PartialFit paths.
#[derive(Debug, Clone)]
struct SGDHyper<F> {
    learning_rate: LearningRateSchedule<F>,
    eta0: F,
    alpha: F,
    max_iter: usize,
    tol: F,
    random_state: Option<u64>,
    power_t: F,
}

/// Train a single binary classifier via SGD, updating `weights` and
/// `intercept` in place. `y_binary` must be in `{-1, +1}`.
///
/// Returns the mean loss of the final epoch and the step counter after
/// training.
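///
/// Each sample applies the update `w <- w - eta * (grad * x_i + alpha * w)`
/// and `b <- b - eta * grad`, where `grad` is the loss gradient with respect
/// to the prediction; the L2 penalty is folded directly into the step.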
fn train_binary_sgd<F, L>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    // Build the RNG for shuffling.
    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            // Compute prediction: w^T x_i + b.
            let mut y_pred = *intercept;
            let xi = x.row(i);
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y_binary[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y_binary[i], y_pred);

            // Update weights with gradient + L2 regularization.
            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        // Convergence check.
        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        // Adaptive learning rate adjustment.
        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Fitted SGD classifier.
///
/// Holds the learned weight vectors and intercepts. For binary problems
/// there is a single weight vector; for multiclass problems there is one
/// per class (one-vs-all).
///
/// Implements [`Predict`] and [`PartialFit`] to support both inference and
/// online learning.
#[derive(Debug, Clone)]
pub struct FittedSGDClassifier<F> {
    /// Weight vectors, one per binary sub-problem: a single vector of length
    /// `n_features` for binary problems, one per class for multiclass
    /// (one-vs-all).
    weight_matrix: Vec<Array1<F>>,
    /// Intercept vector, one per sub-problem.
    intercepts: Vec<F>,
    /// Sorted unique class labels.
    classes: Vec<usize>,
    /// Number of features the model was trained on.
    n_features: usize,
    /// The loss function used during training.
    loss: ClassifierLoss,
    /// Hyperparameters for continued training via `partial_fit`.
    hyper: SGDHyper<F>,
    /// Global step counter across all partial_fit calls.
    t: usize,
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type Fitted = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Fit the SGD classifier on the given data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sample counts.
    /// Returns [`FerroError::InsufficientSamples`] if fewer than 2 classes
    /// are present.
    /// Returns [`FerroError::InvalidParameter`] if `eta0` is not positive
    /// or `alpha` is negative.
    fn fit(&self, x: &Array2<F>, y: &Array1<usize>) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let hyper = clf_hyper(self);
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper,
            t,
        })
    }
}

/// Validate classifier input shapes and parameters.
fn validate_clf_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<usize>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDClassifier requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

/// Result type for one-vs-all training: (weight_matrix, intercepts, step_counter).
type OvaResult<F> = (Vec<Array1<F>>, Vec<F>, usize);

/// Train one-vs-all binary classifiers, returning per-class weights, intercepts,
/// and the cumulative step counter.
fn fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    n_features: usize,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> Result<OvaResult<F>, FerroError> {
    let n_classes = classes.len();
    let mut weight_matrix: Vec<Array1<F>> = Vec::with_capacity(n_classes);
    let mut intercepts: Vec<F> = Vec::with_capacity(n_classes);
    let mut global_t = initial_t;

    if n_classes == 2 {
        // Single binary problem: class[0] -> -1, class[1] -> +1.
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();
        let (_, t) =
            dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
        global_t = t;
        weight_matrix.push(w);
        intercepts.push(b);
    } else {
        // One-vs-all: one binary problem per class.
        for &cls in classes {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let mut w = Array1::<F>::zeros(n_features);
            let mut b = F::zero();
            let (_, t) =
                dispatch_train_binary(x, &y_binary, &mut w, &mut b, loss_enum, hyper, global_t);
            global_t = t;
            weight_matrix.push(w);
            intercepts.push(b);
        }
    }

    Ok((weight_matrix, intercepts, global_t))
}

/// Train one-vs-all using existing weight vectors (for partial_fit).
#[allow(clippy::too_many_arguments)]
fn partial_fit_ova<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<usize>,
    classes: &[usize],
    weight_matrix: &mut [Array1<F>],
    intercepts: &mut [F],
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> usize {
    let n_classes = classes.len();
    let mut global_t = initial_t;

    if n_classes == 2 {
        let y_binary: Array1<F> = y.mapv(|label| {
            if label == classes[1] {
                F::one()
            } else {
                -F::one()
            }
        });
        let (_, t) = dispatch_train_binary(
            x,
            &y_binary,
            &mut weight_matrix[0],
            &mut intercepts[0],
            loss_enum,
            hyper,
            global_t,
        );
        global_t = t;
    } else {
        for (idx, &cls) in classes.iter().enumerate() {
            let y_binary: Array1<F> =
                y.mapv(|label| if label == cls { F::one() } else { -F::one() });
            let (_, t) = dispatch_train_binary(
                x,
                &y_binary,
                &mut weight_matrix[idx],
                &mut intercepts[idx],
                loss_enum,
                hyper,
                global_t,
            );
            global_t = t;
        }
    }

    global_t
}

/// Dispatch to the appropriate typed loss training function.
fn dispatch_train_binary<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y_binary: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &ClassifierLoss,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        ClassifierLoss::Hinge => train_binary_sgd(x, y_binary, w, b, &Hinge, hyper, initial_t),
        ClassifierLoss::Log => train_binary_sgd(x, y_binary, w, b, &LogLoss, hyper, initial_t),
        ClassifierLoss::SquaredError => {
            train_binary_sgd(x, y_binary, w, b, &SquaredError, hyper, initial_t)
        }
        ClassifierLoss::ModifiedHuber => {
            train_binary_sgd(x, y_binary, w, b, &ModifiedHuber, hyper, initial_t)
        }
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDClassifier<F>
{
    type Output = Array1<usize>;
    type Error = FerroError;

    /// Predict class labels for the given feature matrix.
    ///
    /// For binary classification, uses `sign(w^T x + b)`.
    /// For multiclass, returns the class whose one-vs-all score is highest.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<usize>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let n_samples = x.nrows();
        let mut predictions = Array1::<usize>::zeros(n_samples);

        if self.classes.len() == 2 {
            // Binary: single weight vector.
            let scores = x.dot(&self.weight_matrix[0]) + self.intercepts[0];
            for i in 0..n_samples {
                predictions[i] = if scores[i] >= F::zero() {
                    self.classes[1]
                } else {
                    self.classes[0]
                };
            }
        } else {
            // Multiclass: one-vs-all, pick highest score.
            for i in 0..n_samples {
                let xi = x.row(i);
                let mut best_class = 0;
                let mut best_score = F::neg_infinity();
                for (c, w) in self.weight_matrix.iter().enumerate() {
                    let score = xi.dot(w) + self.intercepts[c];
                    if score > best_score {
                        best_score = score;
                        best_class = c;
                    }
                }
                predictions[i] = self.classes[best_class];
            }
        }

        Ok(predictions)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for FittedSGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Incrementally train the classifier on a new batch of data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sizes or `x` has the wrong number of features.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        // Use a single-epoch hyper for partial_fit.
        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let t = partial_fit_ova(
            x,
            y,
            &self.classes,
            &mut self.weight_matrix,
            &mut self.intercepts,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<usize>>
    for SGDClassifier<F>
{
    type FitResult = FittedSGDClassifier<F>;
    type Error = FerroError;

    /// Initial call to `partial_fit` on an unfitted classifier.
    ///
    /// Equivalent to `fit` but with a single epoch, enabling subsequent
    /// incremental calls. The class set is inferred from this first batch,
    /// so it must contain every class that later batches will use.
    ///
    /// # Errors
    ///
    /// Same as [`Fit::fit`].
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<usize>,
    ) -> Result<FittedSGDClassifier<F>, FerroError> {
        validate_clf_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut classes: Vec<usize> = y.to_vec();
        classes.sort_unstable();
        classes.dedup();

        if classes.len() < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: classes.len(),
                context: "SGDClassifier requires at least 2 distinct classes".into(),
            });
        }

        let mut hyper = clf_hyper(&self);
        hyper.max_iter = 1;
        let loss_enum = self.loss;

        let (weight_matrix, intercepts, t) =
            fit_ova(x, y, &classes, n_features, &loss_enum, &hyper, 0)?;

        Ok(FittedSGDClassifier {
            weight_matrix,
            intercepts,
            classes,
            n_features,
            loss: loss_enum,
            hyper: clf_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDClassifier<F>
{
    /// Returns the coefficient vector for the first (or only) binary classifier.
    fn coefficients(&self) -> &Array1<F> {
        &self.weight_matrix[0]
    }

    /// Returns the intercept for the first (or only) binary classifier.
    fn intercept(&self) -> F {
        self.intercepts[0]
    }
}

// Pipeline integration for f64.
impl PipelineEstimator for SGDClassifier<f64> {
    fn fit_pipeline(
        &self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
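        // Labels arrive as f64 from the pipeline; `as usize` truncates, so
        // they are assumed to be non-negative integer-valued class ids.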
        let y_usize: Array1<usize> = y.mapv(|v| v as usize);
        let fitted = self.fit(x, &y_usize)?;
        Ok(Box::new(FittedSGDClassifierPipeline(fitted)))
    }
}

/// Wrapper for pipeline integration that converts predictions to f64.
struct FittedSGDClassifierPipeline(FittedSGDClassifier<f64>);

impl FittedPipelineEstimator for FittedSGDClassifierPipeline {
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        let preds = self.0.predict(x)?;
        Ok(preds.mapv(|v| v as f64))
    }
}

// ---------------------------------------------------------------------------
// SGDRegressor
// ---------------------------------------------------------------------------

/// Stochastic Gradient Descent regressor.
///
/// Supports several loss functions for regression, trained using stochastic
/// gradient descent with configurable learning rate schedules.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
///
/// # Examples
///
/// ```
/// use ferrolearn_linear::sgd::SGDRegressor;
/// use ferrolearn_core::{Fit, Predict};
/// use ndarray::{array, Array2};
///
/// let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
/// let y = array![2.0, 4.0, 6.0, 8.0];
///
/// let model = SGDRegressor::<f64>::new();
/// let fitted = model.fit(&x, &y).unwrap();
/// let preds = fitted.predict(&x).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct SGDRegressor<F> {
    /// The loss function to use.
    pub loss: RegressorLoss<F>,
    /// The learning rate schedule.
    pub learning_rate: LearningRateSchedule<F>,
    /// Initial learning rate.
    pub eta0: F,
    /// L2 regularization strength.
    pub alpha: F,
    /// Maximum number of passes over the training data.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tol: F,
    /// Optional random seed for sample shuffling.
    pub random_state: Option<u64>,
    /// Power parameter for inverse scaling schedule.
    pub power_t: F,
}

impl<F: Float> SGDRegressor<F> {
    /// Create a new `SGDRegressor` with default settings.
    ///
    /// Defaults: `loss = SquaredError`, `learning_rate = InvScaling`,
    /// `eta0 = 0.01`, `alpha = 0.0001`, `max_iter = 1000`,
    /// `tol = 1e-3`, `power_t = 0.25`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            loss: RegressorLoss::SquaredError,
            learning_rate: LearningRateSchedule::InvScaling,
            eta0: F::from(0.01).unwrap(),
            alpha: F::from(0.0001).unwrap(),
            max_iter: 1000,
            tol: F::from(1e-3).unwrap(),
            random_state: None,
            power_t: F::from(0.25).unwrap(),
        }
    }

    /// Set the loss function.
    #[must_use]
    pub fn with_loss(mut self, loss: RegressorLoss<F>) -> Self {
        self.loss = loss;
        self
    }

    /// Set the learning rate schedule.
    #[must_use]
    pub fn with_learning_rate(mut self, lr: LearningRateSchedule<F>) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the initial learning rate.
    #[must_use]
    pub fn with_eta0(mut self, eta0: F) -> Self {
        self.eta0 = eta0;
        self
    }

    /// Set the L2 regularization strength.
    #[must_use]
    pub fn with_alpha(mut self, alpha: F) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the maximum number of epochs.
    #[must_use]
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance.
    #[must_use]
    pub fn with_tol(mut self, tol: F) -> Self {
        self.tol = tol;
        self
    }

    /// Set the random seed for reproducibility.
    #[must_use]
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set the power parameter for inverse scaling.
    #[must_use]
    pub fn with_power_t(mut self, power_t: F) -> Self {
        self.power_t = power_t;
        self
    }
}

impl<F: Float> Default for SGDRegressor<F> {
    fn default() -> Self {
        Self::new()
    }
}

/// Extract hyperparameter bundle from an `SGDRegressor`.
fn reg_hyper<F: Float>(reg: &SGDRegressor<F>) -> SGDHyper<F> {
    SGDHyper {
        learning_rate: reg.learning_rate,
        eta0: reg.eta0,
        alpha: reg.alpha,
        max_iter: reg.max_iter,
        tol: reg.tol,
        random_state: reg.random_state,
        power_t: reg.power_t,
    }
}

/// Train a single regressor via SGD, updating `weights` and `intercept`
/// in place. Returns the mean loss of the final epoch and the step counter;
/// the per-sample update matches [`train_binary_sgd`].
fn train_regressor_sgd<F, L>(
    x: &Array2<F>,
    y: &Array1<F>,
    weights: &mut Array1<F>,
    intercept: &mut F,
    loss_fn: &L,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize)
where
    F: Float + ScalarOperand + Send + Sync + 'static,
    L: Loss<F>,
{
    let n_samples = x.nrows();
    let n_features = x.ncols();
    let mut t = initial_t;
    let mut prev_loss = F::infinity();
    let mut current_eta = hyper.eta0;
    let mut no_improve_count: usize = 0;
    let mut indices: Vec<usize> = (0..n_samples).collect();

    let mut rng = match hyper.random_state {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_os_rng(),
    };

    let mut total_loss = F::zero();

    for _epoch in 0..hyper.max_iter {
        indices.shuffle(&mut rng);
        let mut epoch_loss = F::zero();

        for &i in &indices {
            t += 1;

            let eta = match hyper.learning_rate {
                LearningRateSchedule::Adaptive => current_eta,
                _ => compute_lr(
                    &hyper.learning_rate,
                    hyper.eta0,
                    hyper.alpha,
                    hyper.power_t,
                    t,
                ),
            };

            let xi = x.row(i);
            let mut y_pred = *intercept;
            for j in 0..n_features {
                y_pred = y_pred + weights[j] * xi[j];
            }

            let grad = loss_fn.gradient(y[i], y_pred);
            epoch_loss = epoch_loss + loss_fn.loss(y[i], y_pred);

            for j in 0..n_features {
                weights[j] = weights[j] - eta * (grad * xi[j] + hyper.alpha * weights[j]);
            }
            *intercept = *intercept - eta * grad;
        }

        epoch_loss = epoch_loss / F::from(n_samples).unwrap();
        total_loss = epoch_loss;

        if (prev_loss - epoch_loss).abs() < hyper.tol {
            break;
        }

        if let LearningRateSchedule::Adaptive = hyper.learning_rate {
            if epoch_loss >= prev_loss {
                no_improve_count += 1;
                if no_improve_count >= 5 {
                    current_eta = current_eta / F::from(2.0).unwrap();
                    no_improve_count = 0;
                    if current_eta < F::from(1e-6).unwrap() {
                        break;
                    }
                }
            } else {
                no_improve_count = 0;
            }
        }

        prev_loss = epoch_loss;
    }

    (total_loss, t)
}

/// Dispatch regressor training to the appropriate typed loss function.
fn dispatch_train_regressor<F: Float + Send + Sync + ScalarOperand + 'static>(
    x: &Array2<F>,
    y: &Array1<F>,
    w: &mut Array1<F>,
    b: &mut F,
    loss_enum: &RegressorLoss<F>,
    hyper: &SGDHyper<F>,
    initial_t: usize,
) -> (F, usize) {
    match loss_enum {
        RegressorLoss::SquaredError => {
            train_regressor_sgd(x, y, w, b, &SquaredError, hyper, initial_t)
        }
        RegressorLoss::Huber(eps) => {
            train_regressor_sgd(x, y, w, b, &Huber { epsilon: *eps }, hyper, initial_t)
        }
        RegressorLoss::EpsilonInsensitive(eps) => train_regressor_sgd(
            x,
            y,
            w,
            b,
            &EpsilonInsensitive { epsilon: *eps },
            hyper,
            initial_t,
        ),
    }
}

/// Fitted SGD regressor.
///
/// Holds the learned weight vector and intercept. Implements [`Predict`]
/// and [`PartialFit`] to support both inference and online learning.
#[derive(Debug, Clone)]
pub struct FittedSGDRegressor<F> {
    /// Learned weight vector (one per feature).
    weights: Array1<F>,
    /// Learned intercept (bias) term.
    intercept: F,
    /// Number of features the model was trained on.
    n_features: usize,
    /// The loss function used during training.
    loss: RegressorLoss<F>,
    /// Hyperparameters for continued training.
    hyper: SGDHyper<F>,
    /// Global step counter.
    t: usize,
}

/// Validate regressor input shapes and parameters.
fn validate_reg_params<F: Float>(
    x: &Array2<F>,
    y: &Array1<F>,
    eta0: F,
    alpha: F,
) -> Result<(), FerroError> {
    let n_samples = x.nrows();
    if n_samples != y.len() {
        return Err(FerroError::ShapeMismatch {
            expected: vec![n_samples],
            actual: vec![y.len()],
            context: "y length must match number of samples in X".into(),
        });
    }
    if n_samples == 0 {
        return Err(FerroError::InsufficientSamples {
            required: 1,
            actual: 0,
            context: "SGDRegressor requires at least one sample".into(),
        });
    }
    if eta0 <= F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "eta0".into(),
            reason: "must be positive".into(),
        });
    }
    if alpha < F::zero() {
        return Err(FerroError::InvalidParameter {
            name: "alpha".into(),
            reason: "must be non-negative".into(),
        });
    }
    Ok(())
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Fit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type Fitted = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Fit the SGD regressor on the given data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sample counts.
    /// Returns [`FerroError::InvalidParameter`] if `eta0` is not positive
    /// or `alpha` is negative.
    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let hyper = reg_hyper(self);
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper,
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
    for FittedSGDRegressor<F>
{
    type Output = Array1<F>;
    type Error = FerroError;

    /// Predict target values for the given feature matrix.
    ///
    /// Computes `X @ weights + intercept`.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if the number of features
    /// does not match the fitted model.
    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        let n_features = x.ncols();
        if n_features != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![n_features],
                context: "number of features must match fitted model".into(),
            });
        }

        let preds = x.dot(&self.weights) + self.intercept;
        Ok(preds)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for FittedSGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Incrementally train the regressor on a new batch of data.
    ///
    /// # Errors
    ///
    /// Returns [`FerroError::ShapeMismatch`] if `x` and `y` have mismatched
    /// sizes or `x` has the wrong number of features.
    fn partial_fit(
        mut self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        let n_samples = x.nrows();
        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }
        if x.ncols() != self.n_features {
            return Err(FerroError::ShapeMismatch {
                expected: vec![self.n_features],
                actual: vec![x.ncols()],
                context: "number of features must match fitted model".into(),
            });
        }

        let mut hyper = self.hyper.clone();
        hyper.max_iter = 1;

        let (_, t) = dispatch_train_regressor(
            x,
            y,
            &mut self.weights,
            &mut self.intercept,
            &self.loss,
            &hyper,
            self.t,
        );
        self.t = t;

        Ok(self)
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> PartialFit<Array2<F>, Array1<F>>
    for SGDRegressor<F>
{
    type FitResult = FittedSGDRegressor<F>;
    type Error = FerroError;

    /// Initial call to `partial_fit` on an unfitted regressor.
    ///
    /// Equivalent to `fit` but with a single epoch.
    ///
    /// # Errors
    ///
    /// Same as [`Fit::fit`].
    fn partial_fit(
        self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedSGDRegressor<F>, FerroError> {
        validate_reg_params(x, y, self.eta0, self.alpha)?;

        let n_features = x.ncols();
        let mut hyper = reg_hyper(&self);
        hyper.max_iter = 1;
        let mut w = Array1::<F>::zeros(n_features);
        let mut b = F::zero();

        let (_, t) = dispatch_train_regressor(x, y, &mut w, &mut b, &self.loss, &hyper, 0);

        Ok(FittedSGDRegressor {
            weights: w,
            intercept: b,
            n_features,
            loss: self.loss,
            hyper: reg_hyper(&self),
            t,
        })
    }
}

impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedSGDRegressor<F>
{
    fn coefficients(&self) -> &Array1<F> {
        &self.weights
    }

    fn intercept(&self) -> F {
        self.intercept
    }
}

// Pipeline integration for f64.
impl PipelineEstimator for SGDRegressor<f64> {
    fn fit_pipeline(
        &self,
        x: &Array2<f64>,
        y: &Array1<f64>,
    ) -> Result<Box<dyn FittedPipelineEstimator>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}

impl FittedPipelineEstimator for FittedSGDRegressor<f64> {
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        self.predict(x)
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    // -----------------------------------------------------------------------
    // Loss function tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_hinge_loss_correct_side() {
        let h = Hinge;
        // y=1, pred=2 => margin=2 >= 1 => loss=0
        assert!((Loss::<f64>::loss(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, 2.0) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_hinge_loss_wrong_side() {
        let h = Hinge;
        // y=1, pred=-0.5 => margin=-0.5 < 1 => loss=1.5
        assert!((Loss::<f64>::loss(&h, 1.0, -0.5) - 1.5).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&h, 1.0, -0.5) - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_log_loss_zero_pred() {
        let l = LogLoss;
        // y=1, pred=0 => loss=log(1+exp(0))=log(2)
        let loss = Loss::<f64>::loss(&l, 1.0, 0.0);
        assert!((loss - 2.0_f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_log_loss_large_correct() {
        let l = LogLoss;
        // y=1, pred=20 => very small loss
        let loss = Loss::<f64>::loss(&l, 1.0, 20.0);
        assert!(loss < 1e-5);
    }

    #[test]
    fn test_squared_error_loss() {
        let s = SquaredError;
        assert!((Loss::<f64>::loss(&s, 3.0, 1.0) - 2.0).abs() < 1e-10);
        assert!((Loss::<f64>::gradient(&s, 3.0, 1.0) - (-2.0)).abs() < 1e-10);
    }

    #[test]
    fn test_modified_huber_loss() {
        let mh = ModifiedHuber;
        // y=1, pred=2 => z=2 >= 1 => loss=0
        assert!((Loss::<f64>::loss(&mh, 1.0, 2.0)).abs() < 1e-10);
        // y=1, pred=0.5 => z=0.5 => loss=(1-0.5)^2=0.25
        assert!((Loss::<f64>::loss(&mh, 1.0, 0.5) - 0.25).abs() < 1e-10);
        // y=1, pred=-2 => z=-2 < -1 => loss=-4*(-2)=8
        assert!((Loss::<f64>::loss(&mh, 1.0, -2.0) - 8.0).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss_quadratic_region() {
        let h = Huber { epsilon: 1.0_f64 };
        // |y - p| = 0.5 <= 1.0 => quadratic
        assert!((Loss::<f64>::loss(&h, 1.0, 0.5) - 0.125).abs() < 1e-10);
    }

    #[test]
    fn test_huber_loss_linear_region() {
        let h = Huber { epsilon: 1.0_f64 };
        // |y - p| = 3 > 1 => linear: 1*(3 - 0.5) = 2.5
        assert!((Loss::<f64>::loss(&h, 3.0, 0.0) - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_epsilon_insensitive_inside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        // |y - p| = 0.05 <= 0.1 => loss=0
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.95)).abs() < 1e-10);
    }

    #[test]
    fn test_epsilon_insensitive_outside() {
        let ei = EpsilonInsensitive { epsilon: 0.1_f64 };
        // |y - p| = 0.5 > 0.1 => loss=0.4
        assert!((Loss::<f64>::loss(&ei, 1.0, 0.5) - 0.4).abs() < 1e-10);
    }
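
    // Cross-check: the analytic LogLoss gradient should agree with a central
    // finite difference of the loss (both branches are exact for |z| <= 18).
    #[test]
    fn test_log_loss_gradient_matches_finite_difference() {
        let l = LogLoss;
        let (y, p, h) = (1.0_f64, 0.3_f64, 1e-6);
        let numeric =
            (Loss::<f64>::loss(&l, y, p + h) - Loss::<f64>::loss(&l, y, p - h)) / (2.0 * h);
        let analytic = Loss::<f64>::gradient(&l, y, p);
        assert!((numeric - analytic).abs() < 1e-6);
    }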
1612    // -----------------------------------------------------------------------
1613    // Learning rate tests
1614    // -----------------------------------------------------------------------
1615
1616    #[test]
1617    fn test_constant_lr() {
1618        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Constant;
1619        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 100) - 0.1).abs() < 1e-10);
1620    }
1621
1622    #[test]
1623    fn test_optimal_lr() {
1624        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Optimal;
1625        // eta = 1 / (alpha * t) = 1 / (0.01 * 10) = 10.0
1626        assert!((compute_lr(&lr, 0.1, 0.01, 0.25, 10) - 10.0).abs() < 1e-10);
1627    }
1628
1629    #[test]
1630    fn test_invscaling_lr() {
1631        let lr: LearningRateSchedule<f64> = LearningRateSchedule::InvScaling;
1632        // eta = 0.1 / 10^0.5 = 0.1 / 3.162... ~= 0.0316...
1633        let result = compute_lr(&lr, 0.1, 0.01, 0.5, 10);
1634        let expected = 0.1 / 10.0_f64.sqrt();
1635        assert!((result - expected).abs() < 1e-10);
1636    }
1637
1638    #[test]
1639    fn test_adaptive_lr_returns_eta0() {
1640        let lr: LearningRateSchedule<f64> = LearningRateSchedule::Adaptive;
1641        assert!((compute_lr(&lr, 0.05, 0.01, 0.25, 100) - 0.05).abs() < 1e-10);
1642    }

    // -----------------------------------------------------------------------
    // SGDClassifier tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sgd_classifier_binary() {
        // Well-separated clusters centered near the origin for SGD stability.
        let x = Array2::from_shape_vec(
            (8, 2),
            vec![
                -2.0, -2.0, -1.5, -2.0, -2.0, -1.5, -1.5, -1.5, 2.0, 2.0, 1.5, 2.0, 2.0, 1.5, 1.5,
                1.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 0, 1, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 6, "expected >= 6 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_log_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Log)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(correct >= 4, "expected >= 4 correct, got {correct}");
    }

    #[test]
    fn test_sgd_classifier_multiclass() {
        let x = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 5.0, 0.0, 5.5, 0.0, 5.0, 0.5, 0.0, 5.0, 0.5, 5.0,
                0.0, 5.5,
            ],
        )
        .unwrap();
        let y = array![0, 0, 0, 1, 1, 1, 2, 2, 2];

        let clf = SGDClassifier::<f64>::new()
            .with_random_state(42)
            .with_max_iter(1000)
            .with_eta0(0.01);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        let correct: usize = preds.iter().zip(y.iter()).filter(|(p, a)| p == a).count();
        assert!(
            correct >= 6,
            "expected >= 6 correct for multiclass, got {correct}"
        );
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_fit() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 1]; // Wrong length
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_shape_mismatch_predict() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_single_class_error() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0];
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_eta0() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_eta0(0.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_invalid_alpha() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_alpha(-1.0);
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_classifier_partial_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_classifier_partial_fit_chain() {
        // Test the chaining pattern:
        // model.partial_fit(&b1, &y1)?.partial_fit(&b2, &y2)?.predict(&x)?
        let x1 =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y1 = array![0, 0, 1, 1];
        let x2 =
            Array2::from_shape_vec((4, 2), vec![0.5, 0.5, 1.5, 1.5, 7.5, 7.5, 8.5, 8.5]).unwrap();
        let y2 = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let preds = clf
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 4);
    }
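
    // A streaming sketch, added alongside the chain test: feed batches in a
    // loop instead of a fixed chain. This assumes, as the chain test implies,
    // that partial_fit on a fitted model returns the same fitted type, so the
    // binding can simply be reassigned each round.
    #[test]
    fn test_sgd_classifier_partial_fit_loop() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        // The first batch moves the model from the unfitted to the fitted state.
        let mut fitted = clf.partial_fit(&x, &y).unwrap();
        // Later batches keep refining the same fitted model.
        for _ in 0..3 {
            fitted = fitted.partial_fit(&x, &y).unwrap();
        }
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }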

    #[test]
    fn test_sgd_classifier_partial_fit_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 8.0, 8.0, 9.0, 9.0]).unwrap();
        let y = array![0, 0, 1, 1];
        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![0, 1];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_classifier_modified_huber() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::ModifiedHuber)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_squared_error_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::SquaredError)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }
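
    // A sketch exercising hinge loss explicitly. The `Hinge` loss struct is
    // defined in this module; `ClassifierLoss::Hinge` as the matching enum
    // variant is an assumption here, mirroring the Log / ModifiedHuber /
    // SquaredError variants used above.
    #[test]
    fn test_sgd_classifier_hinge_loss() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = array![0, 0, 0, 1, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_loss(ClassifierLoss::Hinge)
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 6);
    }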

    #[test]
    fn test_sgd_classifier_pipeline() {
        let x = Array2::from_shape_vec((6, 1), vec![-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]).unwrap();
        let y = Array1::from_vec(vec![0.0, 0.0, 0.0, 1.0, 1.0, 1.0]);

        let clf = SGDClassifier::<f64>::new().with_random_state(42);
        let fitted = clf.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_sgd_classifier_constant_lr() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let clf = SGDClassifier::<f64>::new()
            .with_learning_rate(LearningRateSchedule::Constant)
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_classifier_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![-2.0f32, -1.0, 1.0, 2.0]).unwrap();
        let y = array![0_usize, 0, 1, 1];

        let clf = SGDClassifier::<f32>::new()
            .with_random_state(42)
            .with_max_iter(200);
        let fitted = clf.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    // -----------------------------------------------------------------------
    // SGDRegressor tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_sgd_regressor_basic() {
        // y = 2*x + 1
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();

        // Check rough accuracy.
        for (p, &actual) in preds.iter().zip(y.iter()) {
            assert!(
                (*p - actual).abs() < 2.0,
                "prediction {p} too far from {actual}"
            );
        }
    }
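
    // A follow-on sketch for the same y = 2*x + 1 problem: with alpha = 0.0
    // (no regularization) the learned slope should land near 2.0. Indexing
    // into coefficients() and the loose 0.5 tolerance are assumptions; SGD
    // only approximates the exact least-squares solution.
    #[test]
    fn test_sgd_regressor_recovers_slope() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let model = SGDRegressor::<f64>::new()
            .with_random_state(42)
            .with_max_iter(2000)
            .with_eta0(0.01)
            .with_alpha(0.0);
        let fitted = model.fit(&x, &y).unwrap();
        let slope = fitted.coefficients()[0];
        assert!((slope - 2.0).abs() < 0.5, "slope {slope} too far from 2.0");
    }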

    #[test]
    fn test_sgd_regressor_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0]; // Wrong length
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_predict_shape_mismatch() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_invalid_eta0() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_eta0(-0.1);
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_regressor_has_coefficients() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 2);
    }

    #[test]
    fn test_sgd_regressor_partial_fit() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();
        let fitted = fitted.partial_fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_chain() {
        let x1 = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y1 = array![2.0, 4.0, 6.0];
        let x2 = Array2::from_shape_vec((3, 1), vec![4.0, 5.0, 6.0]).unwrap();
        let y2 = array![8.0, 10.0, 12.0];

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let preds = model
            .partial_fit(&x1, &y1)
            .unwrap()
            .partial_fit(&x2, &y2)
            .unwrap()
            .predict(&x1)
            .unwrap();
        assert_eq!(preds.len(), 3);
    }

    #[test]
    fn test_sgd_regressor_partial_fit_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.partial_fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap();
        let y_bad = array![1.0, 2.0];
        assert!(fitted.partial_fit(&x_bad, &y_bad).is_err());
    }

    #[test]
    fn test_sgd_regressor_huber_loss() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::Huber(1.35))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_epsilon_insensitive() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::EpsilonInsensitive(0.1))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }
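
    // A robustness sketch, added here as an assumption about Huber behavior
    // rather than a numeric guarantee: the Huber loss caps the gradient of
    // large residuals, so a single gross outlier should not derail training.
    // We only assert that fitting succeeds and predictions have the right
    // length.
    #[test]
    fn test_sgd_regressor_huber_with_outlier() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        // y = 2*x, except the last target is a gross outlier.
        let y = array![2.0, 4.0, 6.0, 8.0, 100.0];

        let model = SGDRegressor::<f64>::new()
            .with_loss(RegressorLoss::Huber(1.35))
            .with_random_state(42)
            .with_max_iter(500);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 5);
    }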

    #[test]
    fn test_sgd_regressor_pipeline() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f64>::new().with_random_state(42);
        let fitted = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_sgd_regressor_f32() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
        let y = Array1::from_vec(vec![2.0f32, 4.0, 6.0, 8.0]);

        let model = SGDRegressor::<f32>::new().with_random_state(42);
        let fitted = model.fit(&x, &y).unwrap();
        assert_eq!(fitted.predict(&x).unwrap().len(), 4);
    }

    #[test]
    fn test_sgd_regressor_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        let model = SGDRegressor::<f64>::new();
        assert!(model.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_empty_data() {
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<usize>::zeros(0);
        let clf = SGDClassifier::<f64>::new();
        assert!(clf.fit(&x, &y).is_err());
    }

    #[test]
    fn test_sgd_classifier_default() {
        let clf = SGDClassifier::<f64>::default();
        assert!(clf.eta0 > 0.0);
        assert!(clf.alpha >= 0.0);
    }

    #[test]
    fn test_sgd_regressor_default() {
        let model = SGDRegressor::<f64>::default();
        assert!(model.eta0 > 0.0);
        assert!(model.alpha >= 0.0);
    }
}