Skip to main content

anofox_ml_linear/
sgd_common.rs

1//! Common types and utilities for SGD-based linear models.
2
3use serde::{Deserialize, Serialize};
4
/// Regularization penalty applied to the model weights.
///
/// [`penalty_gradient`] converts the selected variant into a per-weight
/// gradient term scaled by `alpha`. The [`Default`] implementation selects
/// [`Penalty::L2`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum Penalty {
    /// No regularization: the penalty gradient is always `0.0`.
    None,
    /// L2 (Ridge) regularization: the penalty gradient is `alpha * w`.
    L2,
    /// L1 (Lasso) regularization: the penalty gradient is `alpha` times the
    /// sign of `w`.
    L1,
    /// Elastic Net: `l1_ratio * L1 + (1 - l1_ratio) * L2`, i.e. a blend of the
    /// two penalties above weighted by `l1_ratio`.
    ElasticNet,
}
17
18impl Default for Penalty {
19    fn default() -> Self {
20        Penalty::L2
21    }
22}
23
/// Learning rate schedule evaluated by [`compute_lr`].
///
/// `eta0`, `alpha`, `power_t` and the iteration counter `t` in the formulas
/// below are the arguments of [`compute_lr`]. The [`Default`] implementation
/// selects [`LearningRate::InvScaling`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum LearningRate {
    /// Fixed learning rate: `eta = eta0`.
    Constant,
    /// Optimal: `eta = 1 / (alpha * (t + t0))` with `t0 = 1 / (eta0 * alpha)`.
    Optimal,
    /// Inverse scaling: `eta = eta0 / (t + 1)^power_t` (the implementation
    /// offsets `t` by one, so `t = 0` yields `eta0`).
    InvScaling,
}
34
35impl Default for LearningRate {
36    fn default() -> Self {
37        LearningRate::InvScaling
38    }
39}
40
41/// Apply L1/L2/ElasticNet penalty to a weight, returning the gradient contribution.
42#[inline]
43pub fn penalty_gradient(w: f64, alpha: f64, penalty: Penalty, l1_ratio: f64) -> f64 {
44    match penalty {
45        Penalty::None => 0.0,
46        Penalty::L2 => alpha * w,
47        Penalty::L1 => alpha * w.signum(),
48        Penalty::ElasticNet => alpha * (l1_ratio * w.signum() + (1.0 - l1_ratio) * w),
49    }
50}
51
52/// Compute the learning rate at iteration t.
53#[inline]
54pub fn compute_lr(schedule: LearningRate, eta0: f64, alpha: f64, t: usize, power_t: f64) -> f64 {
55    match schedule {
56        LearningRate::Constant => eta0,
57        LearningRate::Optimal => {
58            let t0 = 1.0 / (eta0 * alpha);
59            1.0 / (alpha * (t as f64 + t0))
60        }
61        LearningRate::InvScaling => eta0 / (t as f64 + 1.0).powf(power_t),
62    }
63}