// optirs_core/regularizers/label_smoothing.rs
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::fmt::Debug;
10
11use crate::error::{OptimError, Result};
12use crate::regularizers::Regularizer;
13
/// Label smoothing regularizer.
///
/// Mixes one-hot (or soft) label vectors with the uniform distribution:
/// `smoothed = (1 - alpha) * labels + alpha / num_classes`. This discourages
/// over-confident predictions by preventing the target distribution from
/// placing all mass on a single class.
#[derive(Debug, Clone)]
pub struct LabelSmoothing<A: Float> {
    // Smoothing strength in [0, 1]; 0 = no smoothing, 1 = fully uniform.
    alpha: A,
    // Expected number of classes (length of the last label axis).
    num_classes: usize,
}
40
41impl<A: Float + Debug + ScalarOperand + FromPrimitive + Send + Sync> LabelSmoothing<A> {
42 pub fn new(alpha: A, numclasses: usize) -> Result<Self> {
53 if alpha < A::zero() || alpha > A::one() {
54 return Err(OptimError::InvalidConfig(
55 "Alpha must be between 0 and 1".to_string(),
56 ));
57 }
58
59 Ok(Self {
60 alpha,
61 num_classes: numclasses,
62 })
63 }
64
65 pub fn smooth_labels(&self, labels: &Array1<A>) -> Result<Array1<A>> {
80 if labels.len() != self.num_classes {
81 return Err(OptimError::InvalidConfig(format!(
82 "Expected {} classes, got {} in label vector",
83 self.num_classes,
84 labels.len()
85 )));
86 }
87
88 let uniform_val = A::one() / A::from_usize(self.num_classes).unwrap();
89 let smooth_coef = self.alpha;
90 let one_minus_alpha = A::one() - smooth_coef;
91
92 let smoothed = labels.map(|&y| one_minus_alpha * y + smooth_coef * uniform_val);
94
95 Ok(smoothed)
96 }
97
98 pub fn smooth_batch<D>(&self, labels: &Array<A, D>) -> Result<Array<A, D>>
108 where
109 D: Dimension,
110 {
111 if labels.shape().last().unwrap_or(&0) != &self.num_classes {
113 return Err(OptimError::InvalidConfig(
114 "Last dimension must match number of classes".to_string(),
115 ));
116 }
117
118 let uniform_val = A::one() / A::from_usize(self.num_classes).unwrap();
120 let smooth_coef = self.alpha;
121 let one_minus_alpha = A::one() - smooth_coef;
122
123 let smoothed = labels.map(|&y| one_minus_alpha * y + smooth_coef * uniform_val);
125
126 Ok(smoothed)
127 }
128
129 pub fn cross_entropy_loss(&self, logits: &Array1<A>, labels: &Array1<A>, eps: A) -> Result<A> {
141 if logits.len() != self.num_classes || labels.len() != self.num_classes {
142 return Err(OptimError::InvalidConfig(
143 "Logits and labels must match number of classes".to_string(),
144 ));
145 }
146
147 let max_logit = logits.fold(A::neg_infinity(), |max, &v| if v > max { v } else { max });
149 let exp_logits = logits.map(|&l| (l - max_logit).exp());
150 let sum_exp = exp_logits.sum();
151 let probs = exp_logits.map(|&e| e / (sum_exp + eps));
152
153 let smoothed_labels = self.smooth_labels(labels)?;
155
156 let mut loss = A::zero();
158 for (p, y) in probs.iter().zip(smoothed_labels.iter()) {
159 loss = loss - *y * (*p + eps).ln();
160 }
161
162 Ok(loss)
163 }
164}
165
166impl<A: Float + Debug + ScalarOperand + FromPrimitive, D: Dimension + Send + Sync> Regularizer<A, D>
168 for LabelSmoothing<A>
169{
170 fn apply(&self, _params: &Array<A, D>, gradients: &mut Array<A, D>) -> Result<A> {
171 Ok(A::zero())
174 }
175
176 fn penalty(&self, params: &Array<A, D>) -> Result<A> {
177 Ok(A::zero())
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Valid alpha values construct; out-of-range alpha is rejected.
    #[test]
    fn test_label_smoothing_creation() {
        let ls = LabelSmoothing::<f64>::new(0.1, 3).unwrap();
        assert_eq!(ls.alpha, 0.1);
        assert_eq!(ls.num_classes, 3);

        assert!(LabelSmoothing::<f64>::new(-0.1, 3).is_err());
        assert!(LabelSmoothing::<f64>::new(1.1, 3).is_err());
    }

    /// Smoothing a one-hot vector mixes in alpha/num_classes of uniform mass
    /// and preserves the total probability mass of 1.
    #[test]
    fn test_smooth_labels() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let one_hot = array![0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        let uniform_val = 1.0 / 3.0;
        let expected_1 = 0.9 * 1.0 + 0.1 * uniform_val;
        let expected_0 = 0.9 * 0.0 + 0.1 * uniform_val;

        assert_relative_eq!(smoothed[0], expected_0, epsilon = 1e-5);
        assert_relative_eq!(smoothed[1], expected_1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[2], expected_0, epsilon = 1e-5);

        assert_relative_eq!(smoothed.sum(), 1.0, epsilon = 1e-5);
    }

    /// alpha = 1 collapses any label vector to the uniform distribution.
    #[test]
    fn test_full_smoothing() {
        let ls = LabelSmoothing::new(1.0, 4).unwrap();
        let one_hot = array![0.0, 0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        for i in 0..4 {
            assert_relative_eq!(smoothed[i], 0.25, epsilon = 1e-5);
        }
    }

    /// alpha = 0 leaves labels unchanged.
    #[test]
    fn test_no_smoothing() {
        let ls = LabelSmoothing::new(0.0, 3).unwrap();
        let one_hot = array![0.0, 1.0, 0.0];

        let smoothed = ls.smooth_labels(&one_hot).unwrap();

        for i in 0..3 {
            assert_relative_eq!(smoothed[i], one_hot[i], epsilon = 1e-5);
        }
    }

    /// Batch smoothing applies the same transform along the last axis.
    #[test]
    fn test_smooth_batch() {
        let ls = LabelSmoothing::new(0.2, 2).unwrap();
        let batch = array![[1.0, 0.0], [0.0, 1.0]];

        let smoothed = ls.smooth_batch(&batch).unwrap();

        assert_relative_eq!(smoothed[[0, 0]], 0.9, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[0, 1]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 0]], 0.1, epsilon = 1e-5);
        assert_relative_eq!(smoothed[[1, 1]], 0.9, epsilon = 1e-5);
    }

    /// The smoothed cross-entropy is positive and finite on ordinary logits.
    #[test]
    fn test_cross_entropy_loss() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let labels = array![0.0, 1.0, 0.0];
        let logits = array![1.0, 2.0, 0.5];

        let loss = ls.cross_entropy_loss(&logits, &labels, 1e-8).unwrap();

        assert!(loss > 0.0 && loss.is_finite());
    }

    /// The Regularizer impl is a no-op: zero penalty, gradients untouched.
    #[test]
    fn test_regularizer_trait() {
        let ls = LabelSmoothing::new(0.1, 3).unwrap();
        let params = array![[1.0, 2.0], [3.0, 4.0]];
        let mut gradients = array![[0.1, 0.2], [0.3, 0.4]];
        let original_gradients = gradients.clone();

        // Original had mis-encoded `¶ms` here; restored to `&params`.
        let penalty = ls.apply(&params, &mut gradients).unwrap();

        assert_eq!(penalty, 0.0);

        assert_eq!(gradients, original_gradients);
    }
}