use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;
/// Label-smoothing configuration: a smoothing strength plus a class count.
///
/// The smoothing methods on this type map every label value `y` to
/// `(1 - alpha) * y + alpha / num_classes`.
#[derive(Debug, Clone)]
pub struct LabelSmoothing<A>
where
    A: Float,
{
    /// Smoothing strength; `new` checks it against the range 0 to 1.
    alpha: A,
    /// Number of classes the label vectors are expected to have.
    num_classes: usize,
}
impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> LabelSmoothing<A> {
    /// Creates a label smoother with strength `alpha` over `numclasses` classes.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] when `alpha` is NaN or lies
    /// outside the closed interval `[0, 1]`.
    pub fn new(alpha: A, numclasses: usize) -> Result<Self> {
        // Containment is false for NaN, so NaN is rejected here; the form
        // `alpha < 0 || alpha > 1` would silently accept it.
        if !(A::zero()..=A::one()).contains(&alpha) {
            return Err(OptimError::InvalidConfig(
                "Alpha must be between 0 and 1".to_string(),
            ));
        }
        Ok(Self {
            alpha,
            num_classes: numclasses,
        })
    }

    /// Returns `1 / num_classes` in the float type `A`.
    ///
    /// Fails with [`OptimError::InvalidConfig`] instead of panicking when the
    /// class count cannot be converted to `A`.
    fn uniform_probability(&self) -> Result<A> {
        let class_count = A::from_usize(self.num_classes).ok_or_else(|| {
            OptimError::InvalidConfig(format!(
                "Number of classes {} cannot be represented in the float type",
                self.num_classes
            ))
        })?;
        Ok(A::one() / class_count)
    }

    /// Applies `(1 - alpha) * y + alpha / num_classes` to every element of `labels`.
    ///
    /// Shape validation is the caller's job; this helper only does the arithmetic.
    fn smooth_elements<D>(&self, labels: &Array<A, D>) -> Result<Array<A, D>>
    where
        D: Dimension,
    {
        let keep = A::one() - self.alpha;
        // Loop-invariant share of probability mass handed to every class.
        let spread = self.alpha * self.uniform_probability()?;
        Ok(labels.map(|&y| keep * y + spread))
    }

    /// Smooths a single label vector.
    ///
    /// Each entry `y` becomes `(1 - alpha) * y + alpha / num_classes`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] when `labels.len()` differs from
    /// the configured number of classes, or when the class count cannot be
    /// converted to `A`.
    pub fn smooth_labels(&self, labels: &Array1<A>) -> Result<Array1<A>> {
        if labels.len() != self.num_classes {
            return Err(OptimError::InvalidConfig(format!(
                "Expected {} classes, got {} in label vector",
                self.num_classes,
                labels.len()
            )));
        }
        self.smooth_elements(labels)
    }

    /// Smooths an array of labels of any dimensionality, element by element.
    ///
    /// The length of the last axis must equal the configured number of
    /// classes; an array with no axes is treated as having a last axis of
    /// length 0.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] when the last axis does not match
    /// the number of classes, or when the class count cannot be converted to
    /// `A`.
    pub fn smooth_batch<D>(&self, labels: &Array<A, D>) -> Result<Array<A, D>>
    where
        D: Dimension,
    {
        let last_axis_len = labels.shape().last().copied().unwrap_or(0);
        if last_axis_len != self.num_classes {
            return Err(OptimError::InvalidConfig(
                "Last dimension must match number of classes".to_string(),
            ));
        }
        self.smooth_elements(labels)
    }

    /// Computes the cross-entropy between softmax(`logits`) and the smoothed `labels`.
    ///
    /// The softmax subtracts the largest logit before exponentiating. `eps`
    /// is added both to the softmax denominator and inside the logarithm, so
    /// the result is `-sum(y_smooth * ln(p + eps))`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] when `logits` or `labels` does
    /// not have exactly `num_classes` entries, or when smoothing the labels
    /// fails.
    pub fn cross_entropy_loss(&self, logits: &Array1<A>, labels: &Array1<A>, eps: A) -> Result<A> {
        if logits.len() != self.num_classes || labels.len() != self.num_classes {
            return Err(OptimError::InvalidConfig(
                "Logits and labels must match number of classes".to_string(),
            ));
        }

        // Shift by the maximum logit so that `exp` cannot overflow.
        let max_logit = logits.fold(A::neg_infinity(), |max, &v| if v > max { v } else { max });
        let exp_logits = logits.map(|&l| (l - max_logit).exp());
        let denominator = exp_logits.sum() + eps;

        let smoothed_labels = self.smooth_labels(labels)?;

        let loss = exp_logits
            .iter()
            .zip(smoothed_labels.iter())
            .fold(A::zero(), |acc, (&e, &y)| {
                acc - y * (e / denominator + eps).ln()
            });
        Ok(loss)
    }
}
/// `Regularizer` implementation that contributes nothing to parameters or
/// gradients: both methods return a zero penalty and `apply` leaves the
/// gradient array untouched. The smoothing itself is exposed through the
/// inherent methods of `LabelSmoothing`, which operate on label arrays.
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for LabelSmoothing<A>
{
    /// Returns a zero penalty without reading `_params` or modifying `_gradients`.
    // Underscore-prefixed names mark the arguments as intentionally unused and
    // keep the `unused_variables` lint quiet.
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }

    /// Returns a zero penalty regardless of `_params`.
    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        Ok(A::zero())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Construction succeeds for a valid alpha and fails outside `[0, 1]`.
    #[test]
    fn test_label_smoothing_creation() {
        let ls = LabelSmoothing::<f64>::new(0.1, 3).expect("unwrap failed");
        assert_eq!(ls.alpha, 0.1);
        assert_eq!(ls.num_classes, 3);
        assert!(LabelSmoothing::<f64>::new(-0.1, 3).is_err());
        assert!(LabelSmoothing::<f64>::new(1.1, 3).is_err());
    }

    /// A one-hot vector is blended with the uniform distribution and still sums to one.
    #[test]
    fn test_smooth_labels() {
        let ls = LabelSmoothing::new(0.1, 3).expect("unwrap failed");
        let one_hot = array![0.0, 1.0, 0.0];
        let smoothed = ls.smooth_labels(&one_hot).expect("unwrap failed");
        let uniform_val = 1.0 / 3.0;
        let expected_1 = 0.9 * 1.0 + 0.1 * uniform_val;
        let expected_0 = 0.9 * 0.0 + 0.1 * uniform_val;
        assert_relative_eq!(smoothed[0], expected_0, epsilon = 1e-5);
        assert_relative_eq!(smoothed[1], expected_1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[2], expected_0, epsilon = 1e-5);
        assert_relative_eq!(smoothed.sum(), 1.0, epsilon = 1e-5);
    }

    /// With alpha = 1 every class receives the uniform probability.
    #[test]
    fn test_full_smoothing() {
        let ls = LabelSmoothing::new(1.0, 4).expect("unwrap failed");
        let one_hot = array![0.0, 0.0, 1.0, 0.0];
        let smoothed = ls.smooth_labels(&one_hot).expect("unwrap failed");
        for i in 0..4 {
            assert_relative_eq!(smoothed[i], 0.25, epsilon = 1e-5);
        }
    }

    /// With alpha = 0 the labels come back unchanged.
    #[test]
    fn test_no_smoothing() {
        let ls = LabelSmoothing::new(0.0, 3).expect("unwrap failed");
        let one_hot = array![0.0, 1.0, 0.0];
        let smoothed = ls.smooth_labels(&one_hot).expect("unwrap failed");
        for i in 0..3 {
            assert_relative_eq!(smoothed[i], one_hot[i], epsilon = 1e-5);
        }
    }

    /// Batch smoothing applies the same formula to every row.
    #[test]
    fn test_smooth_batch() {
        let ls = LabelSmoothing::new(0.2, 2).expect("unwrap failed");
        let batch = array![[1.0, 0.0], [0.0, 1.0]];
        let smoothed = ls.smooth_batch(&batch).expect("unwrap failed");
        assert_relative_eq!(smoothed[[0, 0]], 0.9, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[0, 1]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 0]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 1]], 0.9, epsilon = 1e-5);
    }

    /// Inputs whose class dimension differs from the configured count are rejected.
    #[test]
    fn test_dimension_mismatch_is_rejected() {
        let ls = LabelSmoothing::new(0.1, 3).expect("unwrap failed");
        assert!(ls.smooth_labels(&array![1.0, 0.0]).is_err());
        assert!(ls.smooth_batch(&array![[1.0, 0.0], [0.0, 1.0]]).is_err());
        assert!(ls
            .cross_entropy_loss(&array![1.0, 2.0], &array![0.0, 1.0, 0.0], 1e-8)
            .is_err());
    }

    /// The loss is positive and finite for ordinary inputs.
    #[test]
    fn test_cross_entropy_loss() {
        let ls = LabelSmoothing::new(0.1, 3).expect("unwrap failed");
        let labels = array![0.0, 1.0, 0.0];
        let logits = array![1.0, 2.0, 0.5];
        let loss = ls
            .cross_entropy_loss(&logits, &labels, 1e-8)
            .expect("unwrap failed");
        assert!(loss > 0.0 && loss.is_finite());
    }

    /// The `Regularizer` impl reports a zero penalty and leaves gradients untouched.
    #[test]
    fn test_regularizer_trait() {
        let ls = LabelSmoothing::new(0.1, 3).expect("unwrap failed");
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();
        let penalty = ls.apply(&params, &mut gradients).expect("unwrap failed");
        assert_eq!(penalty, 0.0);
        assert_eq!(gradients, original_gradients);
        assert_eq!(ls.penalty(&params).expect("unwrap failed"), 0.0);
    }
}