optirs_core/regularizers/label_smoothing.rs

// Label Smoothing regularization
//
// Label smoothing is a regularization technique that prevents the model from
// becoming over-confident by replacing hard one-hot encoded targets with
// soft targets that include some probability for incorrect classes.
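//
// Given a one-hot target y, a smoothing factor alpha, and K classes, the
// smoothed target is y' = (1 - alpha) * y + alpha / K, so the entries still
// sum to 1.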

use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
use crate::regularizers::Regularizer;

/// Label Smoothing regularization
///
/// Implements label smoothing by replacing one-hot encoded target vectors with
/// "smoother" target distributions, where some probability mass is assigned to
/// non-target classes.
///
/// # Example
///
/// ```
/// use scirs2_core::ndarray::array;
/// use optirs_core::regularizers::LabelSmoothing;
///
/// let label_smooth = LabelSmoothing::new(0.1, 3).unwrap();
/// let one_hot_target = array![0.0, 1.0, 0.0];
///
/// // Apply label smoothing to one-hot targets
/// let smoothed_target = label_smooth.smooth_labels(&one_hot_target).unwrap();
/// // Result will be [0.033..., 0.933..., 0.033...]
/// ```
#[derive(Debug, Clone)]
pub struct LabelSmoothing<A: Float> {
    /// Smoothing factor (between 0 and 1)
    alpha: A,
    /// Number of classes
    num_classes: usize,
}

impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> LabelSmoothing<A> {
    /// Create a new label smoothing regularizer
    ///
    /// # Arguments
    ///
    /// * `alpha` - Smoothing factor, where 0 gives one-hot encoding and 1 gives a uniform distribution
    /// * `num_classes` - Number of classes in the classification task
    ///
    /// # Errors
    ///
    /// Returns an error if `alpha` is not between 0 and 1
    pub fn new(alpha: A, num_classes: usize) -> Result<Self> {
        if alpha < A::zero() || alpha > A::one() {
            return Err(OptimError::InvalidConfig(
                "Alpha must be between 0 and 1".to_string(),
            ));
        }

        Ok(Self { alpha, num_classes })
    }

    /// Smooth the one-hot encoded target labels
    ///
    /// # Arguments
    ///
    /// * `labels` - One-hot encoded target labels
    ///
    /// # Returns
    ///
    /// The smoothed labels
    ///
    /// # Example
    ///
    /// For a 3-class problem with smoothing factor 0.1:
    /// `[0, 1, 0] -> [0.033..., 0.933..., 0.033...]`
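    ///
    /// A minimal doctest sketch of the same computation:
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::LabelSmoothing;
    ///
    /// let ls = LabelSmoothing::new(0.1, 3).unwrap();
    /// let smoothed = ls.smooth_labels(&array![0.0, 1.0, 0.0]).unwrap();
    /// // The smoothed target remains a valid probability distribution.
    /// assert!((smoothed.sum() - 1.0_f64).abs() < 1e-9);
    /// assert!((smoothed[1] - (0.9 + 0.1 / 3.0)).abs() < 1e-9);
    /// ```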
    pub fn smooth_labels(&self, labels: &Array1<A>) -> Result<Array1<A>> {
        if labels.len() != self.num_classes {
            return Err(OptimError::InvalidConfig(format!(
                "Expected {} classes, got {} in label vector",
                self.num_classes,
                labels.len()
            )));
        }

        let uniform_val = A::one() / A::from_usize(self.num_classes).unwrap();
        let smooth_coef = self.alpha;
        let one_minus_alpha = A::one() - smooth_coef;

        // Compute (1 - alpha) * y + alpha * uniform
        let smoothed = labels.map(|&y| one_minus_alpha * y + smooth_coef * uniform_val);

        Ok(smoothed)
    }

    /// Apply label smoothing to a batch of one-hot encoded targets
    ///
    /// # Arguments
    ///
    /// * `labels` - Batch of one-hot encoded target labels
    ///
    /// # Returns
    ///
    /// The smoothed labels for the batch
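    ///
    /// A minimal doctest sketch, mirroring the batch unit test below:
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::LabelSmoothing;
    ///
    /// let ls = LabelSmoothing::new(0.2, 2).unwrap();
    /// let batch = array![[1.0, 0.0], [0.0, 1.0]];
    /// let smoothed = ls.smooth_batch(&batch).unwrap();
    /// // (1 - 0.2) * 1.0 + 0.2 * 0.5 = 0.9 for the target-class entries
    /// assert!((smoothed[[0, 0]] - 0.9_f64).abs() < 1e-9);
    /// assert!((smoothed[[0, 1]] - 0.1_f64).abs() < 1e-9);
    /// ```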
    pub fn smooth_batch<D>(&self, labels: &Array<A, D>) -> Result<Array<A, D>>
    where
        D: Dimension,
    {
        // Ensure the last dimension is the class dimension
        if labels.shape().last().copied().unwrap_or(0) != self.num_classes {
            return Err(OptimError::InvalidConfig(
                "Last dimension must match number of classes".to_string(),
            ));
        }

        // Apply smoothing to each label vector
        let uniform_val = A::one() / A::from_usize(self.num_classes).unwrap();
        let smooth_coef = self.alpha;
        let one_minus_alpha = A::one() - smooth_coef;

        // Compute (1 - alpha) * y + alpha * uniform for each element
        let smoothed = labels.map(|&y| one_minus_alpha * y + smooth_coef * uniform_val);

        Ok(smoothed)
    }

    /// Compute cross-entropy loss with label smoothing
    ///
    /// # Arguments
    ///
    /// * `logits` - Raw model outputs (unnormalized)
    /// * `labels` - One-hot encoded target labels
    /// * `eps` - Small value for numerical stability
    ///
    /// # Returns
    ///
    /// The smoothed cross-entropy loss
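    ///
    /// A minimal doctest sketch (same values as the unit test below):
    ///
    /// ```
    /// use scirs2_core::ndarray::array;
    /// use optirs_core::regularizers::LabelSmoothing;
    ///
    /// let ls = LabelSmoothing::new(0.1, 3).unwrap();
    /// let loss = ls
    ///     .cross_entropy_loss(&array![1.0, 2.0, 0.5], &array![0.0, 1.0, 0.0], 1e-8)
    ///     .unwrap();
    /// // Cross-entropy against a proper target distribution is positive and finite.
    /// assert!(loss > 0.0 && loss.is_finite());
    /// ```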
    pub fn cross_entropy_loss(&self, logits: &Array1<A>, labels: &Array1<A>, eps: A) -> Result<A> {
        if logits.len() != self.num_classes || labels.len() != self.num_classes {
            return Err(OptimError::InvalidConfig(
                "Logits and labels must match number of classes".to_string(),
            ));
        }

        // Compute softmax probabilities, subtracting the max logit for numerical stability
        let max_logit = logits.fold(A::neg_infinity(), |max, &v| max.max(v));
        let exp_logits = logits.map(|&l| (l - max_logit).exp());
        let sum_exp = exp_logits.sum();
        let probs = exp_logits.map(|&e| e / (sum_exp + eps));

        // Smooth the labels
        let smoothed_labels = self.smooth_labels(labels)?;

        // Compute cross-entropy with smoothed labels
        let mut loss = A::zero();
        for (p, y) in probs.iter().zip(smoothed_labels.iter()) {
            loss = loss - *y * (*p + eps).ln();
        }

        Ok(loss)
    }
}

// Implement Regularizer trait (though it's not the primary interface for label smoothing)
impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
    for LabelSmoothing<A>
{
    fn apply(&self, _params: &Array<A, D>, _gradients: &mut Array<A, D>) -> Result<A> {
        // Label smoothing is not applied to model parameters directly;
        // it is applied to the target labels during loss computation.
        Ok(A::zero())
    }

    fn penalty(&self, _params: &Array<A, D>) -> Result<A> {
        // Label smoothing doesn't add a parameter penalty term
        Ok(A::zero())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_label_smoothing_creation() {
        let ls = LabelSmoothing::<f64>::new(0.1, 3).unwrap();
        assert_eq!(ls.alpha, 0.1);
        assert_eq!(ls.num_classes, 3);

        // Alpha out of range should fail
        assert!(LabelSmoothing::<f64>::new(-0.1, 3).is_err());
        assert!(LabelSmoothing::<f64>::new(1.1, 3).is_err());
    }

    #[test]
    fn test_smooth_labels() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let one_hot = array![0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        // Expected: [0.033..., 0.933..., 0.033...]
        let uniform_val = 1.0 / 3.0;
        let expected_1 = 0.9 * 1.0 + 0.1 * uniform_val;
        let expected_0 = 0.9 * 0.0 + 0.1 * uniform_val;

        assert_relative_eq!(smoothed[0], expected_0, epsilon = 1e-5);
        assert_relative_eq!(smoothed[1], expected_1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[2], expected_0, epsilon = 1e-5);

        // Sum should still be 1
        assert_relative_eq!(smoothed.sum(), 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_full_smoothing() {
        let ls = LabelSmoothing::new(1.0, 4).unwrap();
        let one_hot = array![0.0, 0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        // With alpha=1, should be uniform distribution [0.25, 0.25, 0.25, 0.25]
        for i in 0..4 {
            assert_relative_eq!(smoothed[i], 0.25, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_no_smoothing() {
        let ls = LabelSmoothing::new(0.0, 3).unwrap();
        let one_hot = array![0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        // With alpha=0, should be identical to input
        for i in 0..3 {
            assert_relative_eq!(smoothed[i], one_hot[i], epsilon = 1e-5);
        }
    }

    #[test]
    fn test_smooth_batch() {
        let ls = LabelSmoothing::new(0.2, 2).unwrap();
        let batch = array![[1.0, 0.0], [0.0, 1.0]];

        let smoothed = ls.smooth_batch(&batch).unwrap();

        // With alpha=0.2 and 2 classes, uniform_val = 0.5
        // For label 1.0: (1 - 0.2) * 1.0 + 0.2 * 0.5 = 0.8 + 0.1 = 0.9
        // For label 0.0: (1 - 0.2) * 0.0 + 0.2 * 0.5 = 0.0 + 0.1 = 0.1
        assert_relative_eq!(smoothed[[0, 0]], 0.9, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[0, 1]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 0]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 1]], 0.9, epsilon = 1e-5);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let labels = array![0.0, 1.0, 0.0];
        let logits = array![1.0, 2.0, 0.5];

        let loss = ls.cross_entropy_loss(&logits, &labels, 1e-8).unwrap();

        // Loss should be positive and finite
        assert!(loss > 0.0 && loss.is_finite());
    }

    #[test]
    fn test_regularizer_trait() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        let penalty = ls.apply(&params, &mut gradients).unwrap();

        // Penalty should be zero
        assert_eq!(penalty, 0.0);

        // Gradients should be unchanged
        assert_eq!(gradients, original_gradients);
    }
}