anofox_ml_linear/
sgd_common.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
7pub enum Penalty {
8 None,
10 L2,
12 L1,
14 ElasticNet,
16}
17
18impl Default for Penalty {
19 fn default() -> Self {
20 Penalty::L2
21 }
22}
23
24#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
26pub enum LearningRate {
27 Constant,
29 Optimal,
31 InvScaling,
33}
34
35impl Default for LearningRate {
36 fn default() -> Self {
37 LearningRate::InvScaling
38 }
39}
40
41#[inline]
43pub fn penalty_gradient(w: f64, alpha: f64, penalty: Penalty, l1_ratio: f64) -> f64 {
44 match penalty {
45 Penalty::None => 0.0,
46 Penalty::L2 => alpha * w,
47 Penalty::L1 => alpha * w.signum(),
48 Penalty::ElasticNet => alpha * (l1_ratio * w.signum() + (1.0 - l1_ratio) * w),
49 }
50}
51
52#[inline]
54pub fn compute_lr(schedule: LearningRate, eta0: f64, alpha: f64, t: usize, power_t: f64) -> f64 {
55 match schedule {
56 LearningRate::Constant => eta0,
57 LearningRate::Optimal => {
58 let t0 = 1.0 / (eta0 * alpha);
59 1.0 / (alpha * (t as f64 + t0))
60 }
61 LearningRate::InvScaling => eta0 / (t as f64 + 1.0).powf(power_t),
62 }
63}