#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::error::{Result, TuneError};
/// Regularization settings bundled into a single value.
///
/// Instances can be compared with `==`, which makes it possible to check a
/// whole configuration at once (for example against one of the presets).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RegularizationConfig {
    /// Dropout setting (presumably the drop probability; confirm against the
    /// consumer). [`RegularizationConfig::validate`] requires it to lie in
    /// `[0, 1]`.
    pub dropout: f32,
    /// Label-smoothing setting. [`RegularizationConfig::validate`] requires
    /// it to lie in `[0, 1]`.
    pub label_smoothing: f32,
    /// Maximum gradient L2 norm handed to
    /// [`RegularizationConfig::apply_gradient_clip`]; `None` disables
    /// clipping. [`RegularizationConfig::validate`] rejects values `<= 0`.
    pub gradient_clip: Option<f32>,
    /// Mixup `alpha` setting, `None` when mixup is not used.
    /// [`RegularizationConfig::validate`] rejects values `<= 0`.
    /// NOTE(review): nothing in this file consumes the value; its exact
    /// meaning should be confirmed against the training code.
    pub mixup_alpha: Option<f32>,
}
impl Default for RegularizationConfig {
fn default() -> Self {
Self {
dropout: 0.1,
label_smoothing: 0.1,
gradient_clip: Some(1.0),
mixup_alpha: None,
}
}
}
impl RegularizationConfig {
    /// Returns a configuration with every setting switched off: zero dropout,
    /// zero label smoothing, no gradient clipping and no mixup.
    pub fn none() -> Self {
        Self {
            dropout: 0.0,
            label_smoothing: 0.0,
            gradient_clip: None,
            mixup_alpha: None,
        }
    }

    /// Clips `gradients` in place by their global L2 norm.
    ///
    /// When the norm exceeds `max_norm`, every element is multiplied by
    /// `max_norm / norm`; otherwise the slice is left untouched. An empty or
    /// all-zero slice is never modified.
    ///
    /// The sum of squares and the rescaling are carried out in `f64`, so
    /// large finite gradients whose squares do not fit in `f32` are still
    /// clipped correctly instead of being zeroed out.
    ///
    /// Returns the norm measured *before* clipping, converted to `f32`
    /// (infinity if it does not fit).
    ///
    /// `max_norm` is expected to be positive, which is what
    /// [`RegularizationConfig::validate`] enforces for `gradient_clip`.
    /// If `max_norm` or the computed norm is NaN the comparison below is
    /// false and the gradients are left as they are.
    pub fn clip_grad_norm(gradients: &mut [f32], max_norm: f32) -> f32 {
        // Widen before squaring: an f32 sum of squares overflows to infinity
        // long before the norm itself stops being representable.
        let norm_sq: f64 = gradients
            .iter()
            .map(|&g| {
                let g = f64::from(g);
                g * g
            })
            .sum();
        let norm = norm_sq.sqrt();
        let limit = f64::from(max_norm);

        if norm > limit && norm > 0.0 {
            let scale = limit / norm;
            for g in gradients.iter_mut() {
                // Narrowing is intentional: the scaled value is no larger in
                // magnitude than the f32 it was computed from.
                *g = (f64::from(*g) * scale) as f32;
            }
        }

        // Narrowing to the public return type; out-of-range becomes infinity.
        norm as f32
    }

    /// Applies [`RegularizationConfig::clip_grad_norm`] using this
    /// configuration's `gradient_clip` as the maximum norm.
    ///
    /// Returns `Some(pre_clip_norm)` when clipping is configured, or `None`
    /// (leaving `gradients` untouched) when `gradient_clip` is `None`.
    pub fn apply_gradient_clip(&self, gradients: &mut [f32]) -> Option<f32> {
        self.gradient_clip
            .map(|max_norm| Self::clip_grad_norm(gradients, max_norm))
    }

    /// Returns a mild preset: dropout and label smoothing of `0.05`,
    /// gradient clipping at a norm of `1.0`, and no mixup.
    pub fn light() -> Self {
        Self {
            dropout: 0.05,
            label_smoothing: 0.05,
            gradient_clip: Some(1.0),
            mixup_alpha: None,
        }
    }

    /// Returns an aggressive preset: dropout `0.3`, label smoothing `0.2`,
    /// gradient clipping at a norm of `0.5`, and mixup with alpha `0.2`.
    pub fn strong() -> Self {
        Self {
            dropout: 0.3,
            label_smoothing: 0.2,
            gradient_clip: Some(0.5),
            mixup_alpha: Some(0.2),
        }
    }

    /// Checks that every field holds a usable value.
    ///
    /// # Errors
    ///
    /// Returns [`TuneError::InvalidConfig`] describing the first violation
    /// found, checking in this order:
    ///
    /// 1. `dropout` must lie in `[0, 1]`;
    /// 2. `label_smoothing` must lie in `[0, 1]`;
    /// 3. `gradient_clip`, when set, must be greater than zero;
    /// 4. `mixup_alpha`, when set, must be greater than zero.
    ///
    /// NaN is rejected for every field.
    pub fn validate(&self) -> Result<()> {
        Self::check_unit_interval("dropout", self.dropout)?;
        Self::check_unit_interval("label_smoothing", self.label_smoothing)?;
        Self::check_positive("gradient_clip", self.gradient_clip)?;
        Self::check_positive("mixup_alpha", self.mixup_alpha)?;
        Ok(())
    }

    /// Accepts `value` only if it lies in `[0, 1]`; NaN fails the range test.
    fn check_unit_interval(name: &str, value: f32) -> Result<()> {
        if (0.0..=1.0).contains(&value) {
            Ok(())
        } else {
            Err(TuneError::InvalidConfig(format!(
                "{} must be between 0 and 1, got {}",
                name, value
            )))
        }
    }

    /// Accepts `None` and strictly positive values; rejects zero, negatives
    /// and NaN.
    fn check_positive(name: &str, value: Option<f32>) -> Result<()> {
        match value {
            // `v <= 0.0` alone is false for NaN, so NaN needs its own test.
            Some(v) if v.is_nan() || v <= 0.0 => Err(TuneError::InvalidConfig(format!(
                "{} must be > 0",
                name
            ))),
            _ => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for approximate float comparisons.
    const EPS: f32 = 1e-6;

    /// Euclidean norm of `values`, computed independently of the code under test.
    fn l2(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// The `none` and `strong` presets expose the expected dropout and validate.
    #[test]
    fn test_regularization_config() {
        let disabled = RegularizationConfig::none();
        assert_eq!(disabled.dropout, 0.0);
        assert!(disabled.validate().is_ok());

        let heavy = RegularizationConfig::strong();
        assert!(heavy.dropout > 0.0);
        assert!(heavy.validate().is_ok());
    }

    /// A norm above the limit is reported and the vector is rescaled to the
    /// limit without changing its direction.
    #[test]
    fn test_clip_grad_norm_exceeds_max() {
        let mut g = vec![3.0_f32, 4.0];
        let before = RegularizationConfig::clip_grad_norm(&mut g, 1.0);

        assert!((before - 5.0).abs() < EPS);
        assert!((l2(&g) - 1.0).abs() < EPS);
        assert!((g[0] / g[1] - 0.75).abs() < EPS);
    }

    /// A norm below the limit is reported and the vector is left alone.
    #[test]
    fn test_clip_grad_norm_under_max() {
        let mut g = vec![1.0_f32; 4];
        let before = RegularizationConfig::clip_grad_norm(&mut g, 5.0);

        assert!((before - 2.0).abs() < EPS);
        assert_eq!(g, [1.0_f32; 4]);
    }

    /// An empty slice has norm zero.
    #[test]
    fn test_clip_grad_norm_empty() {
        let mut g: Vec<f32> = Vec::new();
        assert_eq!(RegularizationConfig::clip_grad_norm(&mut g, 1.0), 0.0);
    }

    /// All-zero gradients have norm zero and stay zero.
    #[test]
    fn test_clip_grad_norm_zero_grads() {
        let mut g = vec![0.0_f32; 3];
        let before = RegularizationConfig::clip_grad_norm(&mut g, 1.0);

        assert_eq!(before, 0.0);
        assert_eq!(g, [0.0_f32; 3]);
    }

    /// With clipping configured, the pre-clip norm is returned and the
    /// gradients are rescaled to the configured limit.
    #[test]
    fn test_apply_gradient_clip_enabled() {
        let reg = RegularizationConfig {
            gradient_clip: Some(1.0),
            ..RegularizationConfig::none()
        };
        let mut g = vec![3.0_f32, 4.0];

        let before = reg
            .apply_gradient_clip(&mut g)
            .expect("clipping is configured, so a norm must be reported");

        assert!((before - 5.0).abs() < EPS);
        assert!((l2(&g) - 1.0).abs() < EPS);
    }

    /// Without clipping configured, nothing is returned and nothing changes.
    #[test]
    fn test_apply_gradient_clip_disabled() {
        let reg = RegularizationConfig::none();
        let mut g = vec![3.0_f32, 4.0];

        assert!(reg.apply_gradient_clip(&mut g).is_none());
        assert_eq!(g, [3.0_f32, 4.0]);
    }
}