use scirs2_core::ndarray::{Array, Dimension, ScalarOperand, Zip};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use crate::error::Result;
use crate::regularizers::Regularizer;
/// L1 regularizer: adds the penalty `alpha * Σ|w|` and the gradient term
/// `alpha * sign(w)` to the parameters it is applied to.
///
/// Trait bounds are placed on the `impl` blocks that need them rather than on
/// the struct itself, so merely naming `L1<A>` imposes no requirements on `A`.
#[derive(Debug, Clone, Copy)]
pub struct L1<A> {
    /// Regularization strength multiplying the absolute-value penalty.
    alpha: A,
}
impl<A: Float + Debug + Send + Sync> L1<A> {
    /// Creates an L1 regularizer with regularization strength `alpha`.
    ///
    /// `alpha` is stored as-is; no range validation is performed here.
    #[must_use]
    pub fn new(alpha: A) -> Self {
        Self { alpha }
    }

    /// Returns the current regularization strength.
    #[must_use]
    pub fn alpha(&self) -> A {
        self.alpha
    }

    /// Replaces the regularization strength and returns `self` so calls can be chained.
    ///
    /// Like [`L1::new`], this does not validate `alpha`.
    pub fn set_alpha(&mut self, alpha: A) -> &mut Self {
        self.alpha = alpha;
        self
    }
}
impl<A, D> Regularizer<A, D> for L1<A>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    /// Adds the L1 subgradient `alpha * sign(param)` to `gradients` in place,
    /// then returns the penalty for `params` (see `penalty`).
    ///
    /// The subgradient term is `+alpha` for positive parameters and `-alpha`
    /// for negative ones. It is zero for parameters equal to zero, including
    /// `-0.0`, and for NaN parameters.
    ///
    /// # Panics
    ///
    /// Panics if `params` and `gradients` do not have the same shape; this is
    /// a requirement of ndarray's `Zip`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `penalty`. The implementation in this impl
    /// always returns `Ok`.
    fn apply(&self, params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
        // Hoisted so the per-element closure only does the comparison and update.
        let zero = A::zero();
        let alpha = self.alpha;
        Zip::from(params).and(gradients).for_each(|&param, grad| {
            // Explicit sign instead of `Float::signum`: `signum` returns ±1 for
            // ±0.0 and NaN for NaN, whereas here zero (and NaN, which compares
            // false both ways) must contribute nothing.
            let sign = if param > zero {
                A::one()
            } else if param < zero {
                -A::one()
            } else {
                zero
            };
            *grad = *grad + alpha * sign;
        });
        self.penalty(params)
    }

    /// Returns `alpha * Σ|param|` over every element of `params`.
    ///
    /// For an empty array the sum is zero, so the result is `alpha * 0`.
    fn penalty(&self, params: &Array<A, D>) -> Result<A> {
        let sum_abs = params.iter().fold(A::zero(), |acc, &x| acc + x.abs());
        Ok(self.alpha * sum_abs)
    }
}