use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::regularizers::{Regularizer, L1, L2};
/// Elastic-net regularizer: a combination of an [`L1`] and an [`L2`] regularizer
/// whose strengths are derived from a single `alpha` and a mixing ratio.
///
/// The L1 component is configured with `alpha * l1_ratio` and the L2 component
/// with `alpha * (1 - l1_ratio)`.
#[derive(Debug, Clone, Copy)]
pub struct ElasticNet<A>
where
    A: Float + Debug,
{
    /// Overall regularization strength, stored exactly as supplied by the caller.
    alpha: A,
    /// Share of `alpha` given to the L1 component; always clamped into `[0, 1]`.
    l1_ratio: A,
    /// L1 component, configured with `alpha * l1_ratio`.
    l1: L1<A>,
    /// L2 component, configured with `alpha * (1 - l1_ratio)`.
    l2: L2<A>,
}
impl<A> ElasticNet<A>
where
    A: Float + Debug + Send + Sync,
{
    /// Clamps `l1ratio` into `[0, 1]` and splits `alpha` between the two components.
    ///
    /// Returns `(clamped_ratio, l1_strength, l2_strength)` where
    /// `l1_strength = alpha * clamped_ratio` and
    /// `l2_strength = alpha * (1 - clamped_ratio)`.
    fn split_strength(alpha: A, l1ratio: A) -> (A, A, A) {
        let ratio = l1ratio.max(A::zero()).min(A::one());
        (ratio, alpha * ratio, alpha * (A::one() - ratio))
    }

    /// Creates an elastic-net regularizer.
    ///
    /// # Arguments
    ///
    /// * `alpha` - overall regularization strength.
    /// * `l1ratio` - share of `alpha` assigned to the L1 component; values
    ///   outside `[0, 1]` are clamped into that range.
    pub fn new(alpha: A, l1ratio: A) -> Self {
        let (l1_ratio, l1_strength, l2_strength) = Self::split_strength(alpha, l1ratio);
        Self {
            alpha,
            l1_ratio,
            l1: L1::new(l1_strength),
            l2: L2::new(l2_strength),
        }
    }

    /// Returns the overall regularization strength.
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Returns the (clamped) L1 mixing ratio.
    pub fn l1_ratio(&self) -> A {
        self.l1_ratio
    }

    /// Replaces both the overall strength and the mixing ratio, updating the
    /// strengths of the L1 and L2 components to match.
    ///
    /// `l1ratio` is clamped into `[0, 1]` exactly as in [`ElasticNet::new`].
    /// Returns `&mut self` so calls can be chained.
    pub fn set_params(&mut self, alpha: A, l1ratio: A) -> &mut Self {
        let (ratio, l1_strength, l2_strength) = Self::split_strength(alpha, l1ratio);
        self.l1.set_alpha(l1_strength);
        self.l2.set_alpha(l2_strength);
        self.l1_ratio = ratio;
        self.alpha = alpha;
        self
    }
}
impl<A, D> Regularizer<A, D> for ElasticNet<A>
where
A: Float + ScalarOperand + Debug,
D: Dimension,
{
fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
let mut l2_grads = gradients.clone();
let l1_penalty = self.l1.apply(params, gradients)?;
let l2_penalty = self.l2.apply(params, &mut l2_grads)?;
Zip::from(gradients)
.and(&l2_grads)
.for_each(|grad, &l2_grad| {
*grad = self.l1_ratio * *grad + (A::one() - self.l1_ratio) * l2_grad;
});
Ok(l1_penalty + l2_penalty)
}
fn penalty(&self, params: &Array<A, D>) -> Result<A> {
let l1_penalty = self.l1.penalty(params)?;
let l2_penalty = self.l2.penalty(params)?;
Ok(l1_penalty + l2_penalty)
}
}