use serde::{Deserialize, Serialize};
/// Regularization term applied to the model weights.
///
/// The per-weight gradient contribution of each variant is computed by `penalty_gradient`.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
pub enum Penalty {
    /// No regularization; contributes a zero gradient.
    None,
    /// Squared-magnitude penalty; gradient contribution `alpha * w`.
    L2,
    /// Absolute-magnitude penalty; (sub)gradient contribution `alpha * sign(w)`.
    L1,
    /// Blend of the L1 and L2 terms, weighted by `l1_ratio`.
    ElasticNet,
}
impl Default for Penalty {
fn default() -> Self {
Penalty::L2
}
}
/// Step-size schedule; the per-step value is computed by `compute_lr`.
#[derive(Clone, Copy, Debug, PartialEq, Deserialize, Serialize)]
pub enum LearningRate {
    /// Fixed step size `eta0`.
    Constant,
    /// Step size `1 / (alpha * (t + t0))`, decaying with the update count.
    Optimal,
    /// Step size `eta0 / (t + 1)^power_t`.
    InvScaling,
}
impl Default for LearningRate {
fn default() -> Self {
LearningRate::InvScaling
}
}
/// Returns the (sub)gradient of the regularization penalty with respect to one weight `w`.
///
/// * `Penalty::None` — `0`.
/// * `Penalty::L2` — `alpha * w` (the gradient of `alpha / 2 * w²`).
/// * `Penalty::L1` — `alpha * sign(w)`, a subgradient of `alpha * |w|`.
/// * `Penalty::ElasticNet` — `alpha * (l1_ratio * sign(w) + (1 - l1_ratio) * w)`.
///
/// `l1_ratio` is only read for `ElasticNet`. NOTE(review): it is presumably expected to lie
/// in `[0, 1]`; this function does not validate it.
///
/// At `w == 0` (including `-0.0`) `sign(w)` is taken as `0`, the conventional subgradient at
/// the kink of `|w|`. `f64::signum` alone would return `1.0` / `-1.0` there, which would push a
/// weight that is exactly zero away from zero and defeat the sparsity L1 is meant to produce.
/// A NaN weight still produces a NaN gradient for the L1 and ElasticNet penalties.
#[inline]
pub fn penalty_gradient(w: f64, alpha: f64, penalty: Penalty, l1_ratio: f64) -> f64 {
    // Subgradient of |w|: 0 at w == 0 (covers -0.0 too), ±1 elsewhere; NaN passes through signum.
    let sign = if w == 0.0 { 0.0 } else { w.signum() };
    match penalty {
        Penalty::None => 0.0,
        Penalty::L2 => alpha * w,
        Penalty::L1 => alpha * sign,
        Penalty::ElasticNet => alpha * (l1_ratio * sign + (1.0 - l1_ratio) * w),
    }
}
/// Returns the step size to use for update number `t` under `schedule`.
///
/// * `LearningRate::Constant` — always `eta0`.
/// * `LearningRate::Optimal` — `1 / (alpha * (t + t0))` with `t0 = 1 / (eta0 * alpha)`, so the
///   step at `t == 0` equals `eta0`. Evaluated as the equivalent `eta0 / (1 + alpha * eta0 * t)`,
///   which stays finite when `alpha == 0` and then equals `eta0`.
/// * `LearningRate::InvScaling` — `eta0 / (t + 1)^power_t`; the `+ 1` makes the step at `t == 0`
///   equal `eta0`. NOTE(review): this suggests `t` is 0-based — confirm against callers.
///
/// `alpha` is read only by `Optimal`, and `power_t` only by `InvScaling`.
#[inline]
pub fn compute_lr(schedule: LearningRate, eta0: f64, alpha: f64, t: usize, power_t: f64) -> f64 {
    let step = t as f64;
    match schedule {
        LearningRate::Constant => eta0,
        // 1 / (alpha * (t + 1 / (eta0 * alpha))) == eta0 / (1 + alpha * eta0 * t).
        // The direct form builds t0 = 1 / (eta0 * alpha), which is infinite when alpha == 0 and
        // then evaluates 0 * inf = NaN; the rearranged form instead yields eta0, the limit of
        // the schedule as alpha -> 0.
        LearningRate::Optimal => eta0 / (1.0 + alpha * eta0 * step),
        LearningRate::InvScaling => eta0 / (step + 1.0).powf(power_t),
    }
}