use crate::{common::matrix::DMat, error::NetworkError};
use serde::{Deserialize, Serialize};
use typetag;
use super::{Regularization, RegularizationClone};
/// Serializable L2 regularizer; instances are produced by the `L2` builder.
#[derive(Clone, Serialize, Deserialize)]
pub(crate) struct L2Regularization {
    /// Regularization strength; scales the term that `apply` adds to each
    /// gradient entry.
    lambda: f32,
}
// Trait-object implementation of the L2 penalty; registered with `typetag`
// so that a `Box<dyn Regularization>` can be (de)serialized.
#[typetag::serde]
impl Regularization for L2Regularization {
    /// Adds the derivative of the L2 penalty to every gradient entry.
    ///
    /// The penalty contributed by a parameter `p` is `lambda * p²`, whose
    /// derivative with respect to `p` is `2 * lambda * p`. The sign of `p` is
    /// therefore preserved, so a gradient-descent step always shrinks the
    /// parameter towards zero.
    ///
    /// `params` and `grads` are paired positionally with `zip`; if the slices
    /// differ in length the surplus entries are left untouched.
    fn apply(&self, params: &mut [&mut DMat], grads: &mut [&mut DMat]) {
        // Loop-invariant factor of the derivative.
        let scale = 2.0 * self.lambda;
        for (param, grad) in params.iter().zip(grads.iter_mut()) {
            grad.apply_with_indices(|i, j, v| {
                *v += scale * param.at(i, j);
            });
        }
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

impl RegularizationClone for L2Regularization {
    /// Clones this regularizer into a fresh boxed trait object.
    fn clone_box(&self) -> Box<dyn Regularization> {
        Box::new(self.clone())
    }
}

/// Builder for an L2 regularizer.
///
/// Start from [`L2::default`], optionally override the strength with
/// [`L2::lambda`], then call [`L2::build`].
pub struct L2 {
    /// Regularization strength handed to the built regularizer.
    lambda: f32,
}

impl L2 {
    /// Creates a builder with the default strength of `0.01`.
    fn new() -> Self {
        Self { lambda: 0.01 }
    }

    /// Sets the regularization strength.
    ///
    /// The value is checked in [`L2::build`], not here.
    pub fn lambda(mut self, lambda: f32) -> Self {
        self.lambda = lambda;
        self
    }

    /// Checks that `lambda` is strictly positive.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` when `lambda` is zero, negative or
    /// `NaN`.
    fn validate(&self) -> Result<(), NetworkError> {
        // `NaN <= 0.0` evaluates to false, so NaN needs its own check.
        if self.lambda.is_nan() || self.lambda <= 0.0 {
            return Err(NetworkError::ConfigError(format!(
                "Lambda for L2 regularization must be positive, but was {}",
                self.lambda
            )));
        }
        Ok(())
    }

    /// Validates the configuration and produces the boxed regularizer.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConfigError` when `lambda` is not strictly
    /// positive.
    pub fn build(self) -> Result<Box<dyn Regularization>, NetworkError> {
        self.validate()?;
        Ok(Box::new(L2Regularization {
            lambda: self.lambda,
        }))
    }
}

impl Default for L2 {
    /// Same as the private `new`: a builder with `lambda = 0.01`.
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::matrix::DMat;
    use crate::util::equal_approx;

    #[test]
    fn test_l2_regularization() {
        let mut params = [DMat::new(2, 2, &[1.0, -2.0, 3.0, -4.0])];
        let mut grads = [DMat::new(2, 2, &[0.1, 0.1, 0.1, 0.1])];
        let l2 = L2::new().lambda(0.01).build().unwrap();

        let mut params_refs: Vec<&mut DMat> = params.iter_mut().collect();
        let mut grads_refs: Vec<&mut DMat> = grads.iter_mut().collect();
        l2.apply(&mut params_refs, &mut grads_refs);

        // Each entry is grad + 2 * lambda * param = 0.1 + 0.02 * param.
        let expected_grads = DMat::new(2, 2, &[0.12, 0.06, 0.16, 0.02]);
        assert!(equal_approx(&grads[0], &expected_grads, 1e-6));
    }

    #[test]
    fn test_l2_regularization_invalid_lambda() {
        let result = L2::new().lambda(-0.01).build();
        match result {
            Err(e) => assert_eq!(
                e.to_string(),
                "Configuration error: Lambda for L2 regularization must be positive, but was -0.01"
            ),
            Ok(_) => panic!("a negative lambda must be rejected"),
        }
    }

    #[test]
    fn test_l2_regularization_rejects_zero_and_nan_lambda() {
        assert!(L2::new().lambda(0.0).build().is_err());
        assert!(L2::new().lambda(f32::NAN).build().is_err());
    }

    #[test]
    fn test_l2_regularization_default_is_valid() {
        assert!(L2::default().build().is_ok());
    }
}