use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;
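/// A differentiable penalty on model parameters.
///
/// Implementors return the scalar penalty added to the training loss and
/// its gradient with respect to every named parameter matrix.
///
/// # Example (sketch, not compiled)
///
/// ```ignore
/// let reg = L2Regularization::new(1e-4);
/// let penalty = reg.compute_penalty(&params)?; // add to the loss
/// let reg_grads = reg.compute_gradient(&params)?; // add to the weight gradients
/// ```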
pub trait Regularizer {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64>;
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
}
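/// L1 (lasso) regularization: `penalty = lambda * sum(|w|)`.
///
/// Drives small weights toward exactly zero, encouraging sparsity.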
#[derive(Debug, Clone)]
pub struct L1Regularization {
pub lambda: f64,
}
impl L1Regularization {
pub fn new(lambda: f64) -> Self {
Self { lambda }
}
}
impl Default for L1Regularization {
fn default() -> Self {
Self { lambda: 0.01 }
}
}
impl Regularizer for L1Regularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
for &value in param.iter() {
penalty += value.abs();
}
}
Ok(self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
// Subgradient of |w|: use 0 at w == 0 (f64::signum maps ±0.0 to ±1.0).
let grad = param.mapv(|w| if w == 0.0 { 0.0 } else { self.lambda * w.signum() });
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
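/// L2 (ridge / weight decay) regularization: `penalty = 0.5 * lambda * sum(w^2)`.
///
/// The 0.5 factor makes the gradient simply `lambda * w`.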
#[derive(Debug, Clone)]
pub struct L2Regularization {
pub lambda: f64,
}
impl L2Regularization {
pub fn new(lambda: f64) -> Self {
Self { lambda }
}
}
impl Default for L2Regularization {
fn default() -> Self {
Self { lambda: 0.01 }
}
}
impl Regularizer for L2Regularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
for &value in param.iter() {
penalty += value * value;
}
}
Ok(0.5 * self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let grad = param.mapv(|w| self.lambda * w);
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
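/// Elastic net: a convex mix of the L1 and L2 penalties,
/// `lambda * (l1_ratio * sum(|w|) + (1 - l1_ratio) * 0.5 * sum(w^2))`.
///
/// `l1_ratio = 1.0` recovers pure L1; `l1_ratio = 0.0` recovers pure L2.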
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
pub lambda: f64,
pub l1_ratio: f64,
}
impl ElasticNetRegularization {
pub fn new(lambda: f64, l1_ratio: f64) -> TrainResult<Self> {
if !(0.0..=1.0).contains(&l1_ratio) {
return Err(TrainError::InvalidParameter(
"l1_ratio must be between 0.0 and 1.0".to_string(),
));
}
Ok(Self { lambda, l1_ratio })
}
}
impl Default for ElasticNetRegularization {
fn default() -> Self {
Self {
lambda: 0.01,
l1_ratio: 0.5,
}
}
}
impl Regularizer for ElasticNetRegularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut l1_penalty = 0.0;
let mut l2_penalty = 0.0;
for param in parameters.values() {
for &value in param.iter() {
l1_penalty += value.abs();
l2_penalty += value * value;
}
}
let penalty =
self.lambda * (self.l1_ratio * l1_penalty + (1.0 - self.l1_ratio) * 0.5 * l2_penalty);
Ok(penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let grad = param.mapv(|w| {
// Subgradient of |w| is taken as 0 at w == 0.
let sign = if w == 0.0 { 0.0 } else { w.signum() };
self.lambda * (self.l1_ratio * sign + (1.0 - self.l1_ratio) * w)
});
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
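/// Applies several regularizers at once, summing their penalties and
/// elementwise gradients.
///
/// # Example (sketch, not compiled)
///
/// ```ignore
/// let mut reg = CompositeRegularization::new();
/// reg.add(L1Regularization::new(0.01));
/// reg.add(SpectralNormalization::default());
/// ```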
#[derive(Clone)]
pub struct CompositeRegularization {
regularizers: Vec<Box<dyn RegularizerClone>>,
}
impl std::fmt::Debug for CompositeRegularization {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CompositeRegularization")
.field("num_regularizers", &self.regularizers.len())
.finish()
}
}
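/// Object-safe helper trait so boxed regularizers can be cloned.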
trait RegularizerClone: Regularizer {
fn clone_box(&self) -> Box<dyn RegularizerClone>;
}
impl<T: Regularizer + Clone + 'static> RegularizerClone for T {
fn clone_box(&self) -> Box<dyn RegularizerClone> {
Box::new(self.clone())
}
}
impl Clone for Box<dyn RegularizerClone> {
fn clone(&self) -> Self {
self.clone_box()
}
}
impl Regularizer for Box<dyn RegularizerClone> {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
(**self).compute_penalty(parameters)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
(**self).compute_gradient(parameters)
}
}
impl CompositeRegularization {
pub fn new() -> Self {
Self {
regularizers: Vec::new(),
}
}
pub fn add<R: Regularizer + Clone + 'static>(&mut self, regularizer: R) {
self.regularizers.push(Box::new(regularizer));
}
pub fn len(&self) -> usize {
self.regularizers.len()
}
pub fn is_empty(&self) -> bool {
self.regularizers.is_empty()
}
}
impl Default for CompositeRegularization {
fn default() -> Self {
Self::new()
}
}
impl Regularizer for CompositeRegularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut total_penalty = 0.0;
for regularizer in &self.regularizers {
total_penalty += regularizer.compute_penalty(parameters)?;
}
Ok(total_penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut total_gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();
for (name, param) in parameters {
total_gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
}
for regularizer in &self.regularizers {
let grads = regularizer.compute_gradient(parameters)?;
for (name, grad) in grads {
if let Some(total_grad) = total_gradients.get_mut(&name) {
*total_grad = &*total_grad + &grad;
}
}
}
Ok(total_gradients)
}
}
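/// Penalizes deviation of each matrix's largest singular value from
/// `target_norm`: `penalty = lambda * (sigma_max - target_norm)^2`.
///
/// The spectral norm is estimated with a few rounds of power iteration.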
#[derive(Debug, Clone)]
pub struct SpectralNormalization {
pub target_norm: f64,
pub lambda: f64,
pub power_iterations: usize,
}
impl SpectralNormalization {
pub fn new(lambda: f64, target_norm: f64, power_iterations: usize) -> Self {
Self {
lambda,
target_norm,
power_iterations,
}
}
fn estimate_spectral_norm(&self, matrix: &Array<f64, Ix2>) -> f64 {
let (nrows, ncols) = matrix.dim();
if nrows == 0 || ncols == 0 {
return 0.0;
}
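// Power iteration: alternately apply W and W^T to a unit vector so it
// converges toward the top right-singular vector; ||W v|| then
// approximates the largest singular value.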
let mut v = Array::from_elem((ncols,), 1.0 / (ncols as f64).sqrt());
for _ in 0..self.power_iterations {
let u = matrix.dot(&v);
let u_norm = u.iter().map(|&x| x * x).sum::<f64>().sqrt();
if u_norm < 1e-10 {
break;
}
let u = u / u_norm;
v = matrix.t().dot(&u);
let v_norm = v.iter().map(|&x| x * x).sum::<f64>().sqrt();
if v_norm < 1e-10 {
break;
}
v /= v_norm;
}
let final_u = matrix.dot(&v);
final_u.iter().map(|&x| x * x).sum::<f64>().sqrt()
}
}
impl Default for SpectralNormalization {
fn default() -> Self {
Self {
target_norm: 1.0,
lambda: 0.01,
power_iterations: 1,
}
}
}
impl Regularizer for SpectralNormalization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
let spectral_norm = self.estimate_spectral_norm(param);
penalty += (spectral_norm - self.target_norm).powi(2);
}
Ok(self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let spectral_norm = self.estimate_spectral_norm(param);
if spectral_norm < 1e-10 {
gradients.insert(name.clone(), Array::zeros(param.dim()));
continue;
}
let frobenius_norm = param.iter().map(|&x| x * x).sum::<f64>().sqrt();
if frobenius_norm < 1e-10 {
gradients.insert(name.clone(), Array::zeros(param.dim()));
continue;
}
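// Heuristic gradient: the exact gradient of (sigma_max - target)^2 is
// 2 * (sigma_max - target) * u * v^T (top singular vectors); here that
// direction is approximated by W / ||W||_F.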
let scale = 2.0 * self.lambda * (spectral_norm - self.target_norm) / frobenius_norm;
let grad = param.mapv(|w| scale * w);
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
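/// Soft max-norm constraint: rows (`axis == 0`) or columns (`axis == 1`)
/// whose L2 norm exceeds `max_norm` incur `lambda * (norm - max_norm)^2`.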
#[derive(Debug, Clone)]
pub struct MaxNormRegularization {
pub max_norm: f64,
pub lambda: f64,
pub axis: usize,
}
impl MaxNormRegularization {
pub fn new(max_norm: f64, lambda: f64, axis: usize) -> Self {
Self {
max_norm,
lambda,
axis,
}
}
}
impl Default for MaxNormRegularization {
fn default() -> Self {
Self {
max_norm: 2.0,
lambda: 0.01,
axis: 0,
}
}
}
impl Regularizer for MaxNormRegularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
let axis_len = if self.axis == 0 {
param.nrows()
} else {
param.ncols()
};
for i in 0..axis_len {
let row_or_col = if self.axis == 0 {
param.row(i)
} else {
param.column(i)
};
let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm > self.max_norm {
penalty += (norm - self.max_norm).powi(2);
}
}
}
Ok(self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let mut grad = Array::zeros(param.dim());
let axis_len = if self.axis == 0 {
param.nrows()
} else {
param.ncols()
};
for i in 0..axis_len {
let row_or_col = if self.axis == 0 {
param.row(i)
} else {
param.column(i)
};
let norm = row_or_col.iter().map(|&x| x * x).sum::<f64>().sqrt();
if norm > self.max_norm {
let scale = 2.0 * self.lambda * (norm - self.max_norm) / (norm + 1e-10);
for (j, &val) in row_or_col.iter().enumerate() {
if self.axis == 0 {
grad[[i, j]] = scale * val;
} else {
grad[[j, i]] = scale * val;
}
}
}
}
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
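/// Encourages orthonormal columns by penalizing
/// `lambda * ||W^T W - I||_F^2`.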
#[derive(Debug, Clone)]
pub struct OrthogonalRegularization {
pub lambda: f64,
}
impl OrthogonalRegularization {
pub fn new(lambda: f64) -> Self {
Self { lambda }
}
}
impl Default for OrthogonalRegularization {
fn default() -> Self {
Self { lambda: 0.01 }
}
}
impl Regularizer for OrthogonalRegularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
let wt_w = param.t().dot(param);
let (n, _) = wt_w.dim();
for i in 0..n {
for j in 0..n {
let target = if i == j { 1.0 } else { 0.0 };
let diff = wt_w[[i, j]] - target;
penalty += diff * diff;
}
}
}
Ok(self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let wt_w = param.t().dot(param);
let (n, _) = wt_w.dim();
let mut identity = Array::zeros((n, n));
for i in 0..n {
identity[[i, i]] = 1.0;
}
let diff = &wt_w - &identity;
// d/dW [lambda * ||W^T W - I||_F^2] = 4 * lambda * W * (W^T W - I).
let grad = param.dot(&diff) * (4.0 * self.lambda);
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
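/// Group lasso: the flattened (row-major) weights are split into
/// contiguous chunks of `group_size`, and the penalty is
/// `lambda * sum(||group||_2)`, which zeroes out entire groups at once.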
#[derive(Debug, Clone)]
pub struct GroupLassoRegularization {
pub lambda: f64,
pub group_size: usize,
}
impl GroupLassoRegularization {
pub fn new(lambda: f64, group_size: usize) -> Self {
Self { lambda, group_size }
}
}
impl Default for GroupLassoRegularization {
fn default() -> Self {
Self {
lambda: 0.01,
group_size: 10,
}
}
}
impl Regularizer for GroupLassoRegularization {
fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
let mut penalty = 0.0;
for param in parameters.values() {
let flat: Vec<f64> = param.iter().copied().collect();
for group in flat.chunks(self.group_size) {
let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
penalty += group_norm;
}
}
Ok(self.lambda * penalty)
}
fn compute_gradient(
&self,
parameters: &HashMap<String, Array<f64, Ix2>>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
let mut gradients = HashMap::new();
for (name, param) in parameters {
let mut grad_flat = Vec::new();
let flat: Vec<f64> = param.iter().copied().collect();
for group in flat.chunks(self.group_size) {
let group_norm = group.iter().map(|&x| x * x).sum::<f64>().sqrt();
if group_norm > 1e-10 {
let scale = self.lambda / group_norm;
grad_flat.extend(group.iter().map(|&x| scale * x));
} else {
grad_flat.extend(vec![0.0; group.len()]);
}
}
let grad = Array::from_shape_vec(param.dim(), grad_flat).map_err(|e| {
TrainError::ModelError(format!("Failed to reshape gradient: {}", e))
})?;
gradients.insert(name.clone(), grad);
}
Ok(gradients)
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_l1_regularization() {
let regularizer = L1Regularization::new(0.1);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, -2.0], [3.0, -4.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 1.0).abs() < 1e-6);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w[[0, 0]], 0.1);
assert_eq!(grad_w[[0, 1]], -0.1);
assert_eq!(grad_w[[1, 0]], 0.1);
assert_eq!(grad_w[[1, 1]], -0.1);
}
#[test]
fn test_l2_regularization() {
let regularizer = L2Regularization::new(0.1);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 1.5).abs() < 1e-6);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert!((grad_w[[0, 0]] - 0.1).abs() < 1e-10);
assert!((grad_w[[0, 1]] - 0.2).abs() < 1e-10);
assert!((grad_w[[1, 0]] - 0.3).abs() < 1e-10);
assert!((grad_w[[1, 1]] - 0.4).abs() < 1e-10);
}
#[test]
fn test_elastic_net_regularization() {
let regularizer = ElasticNetRegularization::new(0.1, 0.5).expect("unwrap");
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!(penalty > 0.0);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w.shape(), &[1, 2]);
}
#[test]
fn test_elastic_net_invalid_ratio() {
let result = ElasticNetRegularization::new(0.1, 1.5);
assert!(result.is_err());
let result = ElasticNetRegularization::new(0.1, -0.1);
assert!(result.is_err());
}
#[test]
fn test_composite_regularization() {
let mut composite = CompositeRegularization::new();
composite.add(L1Regularization::new(0.1));
composite.add(L2Regularization::new(0.1));
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0]]);
let penalty = composite.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.55).abs() < 1e-6);
let gradients = composite.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w.shape(), &[1, 2]);
assert!((grad_w[[0, 0]] - 0.2).abs() < 1e-6);
}
#[test]
fn test_composite_empty() {
let composite = CompositeRegularization::new();
assert!(composite.is_empty());
assert_eq!(composite.len(), 0);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0]]);
let penalty = composite.compute_penalty(&params).expect("unwrap");
assert_eq!(penalty, 0.0);
}
#[test]
fn test_multiple_parameters() {
let regularizer = L2Regularization::new(0.1);
let mut params = HashMap::new();
params.insert("w1".to_string(), array![[1.0, 2.0]]);
params.insert("w2".to_string(), array![[3.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.7).abs() < 1e-6);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
assert_eq!(gradients.len(), 2);
assert!(gradients.contains_key("w1"));
assert!(gradients.contains_key("w2"));
}
#[test]
fn test_zero_lambda() {
let regularizer = L1Regularization::new(0.0);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[100.0, 200.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert_eq!(penalty, 0.0);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w[[0, 0]], 0.0);
assert_eq!(grad_w[[0, 1]], 0.0);
}
#[test]
fn test_spectral_normalization() {
let regularizer = SpectralNormalization::new(0.1, 1.0, 5);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[2.0, 0.0], [0.0, 1.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.1).abs() < 0.01);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
assert!(gradients.contains_key("w"));
}
#[test]
fn test_max_norm_regularization() {
let regularizer = MaxNormRegularization::new(1.0, 0.1, 0);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[3.0, 4.0], [0.1, 0.1]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 1.6).abs() < 0.1);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert!(grad_w[[0, 0]].abs() > 0.0);
assert!(grad_w[[1, 0]].abs() < 1e-10);
}
#[test]
fn test_orthogonal_regularization() {
let regularizer = OrthogonalRegularization::new(0.1);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 0.0], [0.0, 1.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!(penalty.abs() < 1e-10);
params.insert("w".to_string(), array![[1.0, 1.0], [1.0, 1.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!(penalty > 0.0);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
assert!(gradients.contains_key("w"));
}
#[test]
fn test_group_lasso_regularization() {
let regularizer = GroupLassoRegularization::new(0.1, 2);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.7236).abs() < 0.01);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w.dim(), (2, 2));
}
#[test]
fn test_spectral_normalization_zero_matrix() {
let regularizer = SpectralNormalization::new(0.1, 1.0, 5);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[0.0, 0.0], [0.0, 0.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.1).abs() < 0.01);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
}
#[test]
fn test_max_norm_no_violation() {
let regularizer = MaxNormRegularization::new(10.0, 0.1, 0);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!(penalty.abs() < 1e-10);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert!(grad_w.iter().all(|&x| x.abs() < 1e-10));
}
#[test]
fn test_orthogonal_non_square() {
let regularizer = OrthogonalRegularization::new(0.1);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!(penalty > 0.0);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
assert!(gradients.contains_key("w"));
}
#[test]
fn test_group_lasso_single_group() {
let regularizer = GroupLassoRegularization::new(0.1, 4);
let mut params = HashMap::new();
params.insert("w".to_string(), array![[3.0, 4.0]]);
let penalty = regularizer.compute_penalty(&params).expect("unwrap");
assert!((penalty - 0.5).abs() < 0.01);
let gradients = regularizer.compute_gradient(&params).expect("unwrap");
let grad_w = gradients.get("w").expect("unwrap");
assert_eq!(grad_w.dim(), (1, 2));
}
}