// aprender/regularization/mod.rs
1//! Regularization techniques for neural network training.
2//!
3//! # Techniques
4//! - Mixup: Interpolate between training samples
5//! - Label Smoothing: Soft targets instead of hard labels
6//! - `CutMix`: Cut and paste patches between images
7
8use crate::primitives::Vector;
9use rand::Rng;
10
/// Mixup data augmentation (Zhang et al., 2018).
/// Creates virtual training examples: x' = `λx_i` + (1-λ)x_j
#[derive(Debug, Clone)]
pub struct Mixup {
    /// Beta-distribution shape parameter; values <= 0.0 disable mixing
    /// (`sample_lambda` then always returns 1.0).
    alpha: f32,
}
17
18impl Mixup {
19    /// Create new Mixup with alpha parameter for Beta distribution.
20    #[must_use]
21    pub fn new(alpha: f32) -> Self {
22        Self { alpha }
23    }
24
25    /// Sample mixing coefficient from Beta(alpha, alpha).
26    #[must_use]
27    pub fn sample_lambda(&self) -> f32 {
28        if self.alpha <= 0.0 {
29            return 1.0;
30        }
31        sample_beta(self.alpha, self.alpha)
32    }
33
34    /// Mix two samples: x' = λ*x1 + (1-λ)*x2
35    #[must_use]
36    pub fn mix_samples(&self, x1: &Vector<f32>, x2: &Vector<f32>, lambda: f32) -> Vector<f32> {
37        let mixed: Vec<f32> = x1
38            .as_slice()
39            .iter()
40            .zip(x2.as_slice().iter())
41            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
42            .collect();
43        Vector::from_slice(&mixed)
44    }
45
46    /// Mix labels: y' = λ*y1 + (1-λ)*y2
47    #[must_use]
48    pub fn mix_labels(&self, y1: &Vector<f32>, y2: &Vector<f32>, lambda: f32) -> Vector<f32> {
49        self.mix_samples(y1, y2, lambda)
50    }
51
52    #[must_use]
53    pub fn alpha(&self) -> f32 {
54        self.alpha
55    }
56}
57
/// Label smoothing for soft targets.
/// Converts hard labels to: (1-ε)y + ε/K
#[derive(Debug, Clone)]
pub struct LabelSmoothing {
    /// Smoothing factor ε, restricted to [0, 1) by `new`.
    epsilon: f32,
}
64
65impl LabelSmoothing {
66    /// Create label smoothing with smoothing factor ε.
67    #[must_use]
68    pub fn new(epsilon: f32) -> Self {
69        assert!((0.0..1.0).contains(&epsilon));
70        Self { epsilon }
71    }
72
73    /// Smooth a one-hot label vector.
74    #[must_use]
75    pub fn smooth(&self, label: &Vector<f32>) -> Vector<f32> {
76        let n_classes = label.len();
77        let smoothed: Vec<f32> = label
78            .as_slice()
79            .iter()
80            .map(|&y| (1.0 - self.epsilon) * y + self.epsilon / n_classes as f32)
81            .collect();
82        Vector::from_slice(&smoothed)
83    }
84
85    /// Create smoothed one-hot from class index.
86    #[must_use]
87    pub fn smooth_index(&self, class_idx: usize, n_classes: usize) -> Vector<f32> {
88        let mut result = vec![self.epsilon / n_classes as f32; n_classes];
89        result[class_idx] = 1.0 - self.epsilon + self.epsilon / n_classes as f32;
90        Vector::from_slice(&result)
91    }
92
93    #[must_use]
94    pub fn epsilon(&self) -> f32 {
95        self.epsilon
96    }
97}
98
99/// Cross-entropy loss with label smoothing.
100#[must_use]
101pub fn cross_entropy_with_smoothing(logits: &Vector<f32>, target_idx: usize, epsilon: f32) -> f32 {
102    let n_classes = logits.len();
103    let probs = softmax(logits.as_slice());
104
105    let mut loss = 0.0;
106    for (i, &p) in probs.iter().enumerate() {
107        let target = if i == target_idx {
108            1.0 - epsilon + epsilon / n_classes as f32
109        } else {
110            epsilon / n_classes as f32
111        };
112        loss -= target * p.max(1e-10).ln();
113    }
114    loss
115}
116
117/// Sample from Beta distribution using Gamma samples.
118fn sample_beta(alpha: f32, beta: f32) -> f32 {
119    let mut rng = rand::thread_rng();
120    let x = sample_gamma(alpha, &mut rng);
121    let y = sample_gamma(beta, &mut rng);
122    let sum = x + y;
123    // With extreme shape parameters (e.g. 0.01), f32 gamma samples can
124    // underflow to 0.0, producing 0/0 = NaN. Return 0.5 in that case
125    // (unbiased midpoint, correct for the symmetric alpha==beta case).
126    if sum <= 0.0 {
127        return 0.5;
128    }
129    (x / sum).clamp(0.0, 1.0)
130}
131
/// Sample from Gamma(shape, 1) using Marsaglia and Tsang's (2000)
/// squeeze-and-rejection method.
///
/// For `shape < 1.0`, a Gamma(shape + 1) draw is boosted down via
/// `X * U^(1/shape)`; for `shape >= 1.0`, candidates `d * v` are generated
/// from a transformed normal and accepted by a cheap squeeze test or the
/// exact log test.
fn sample_gamma(shape: f32, rng: &mut impl Rng) -> f32 {
    // Marsaglia and Tsang's method
    if shape < 1.0 {
        // Boost identity: Gamma(a) = Gamma(a + 1) * U^(1/a).
        // NOTE(review): for very small `shape`, U^(1/shape) can underflow to
        // 0.0 in f32 — the caller (`sample_beta`) guards against the
        // resulting 0/0.
        return sample_gamma(1.0 + shape, rng) * rng.gen::<f32>().powf(1.0 / shape);
    }
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let x: f32 = sample_normal(rng);
        // Candidate v = (1 + c*x)^3; only positive values are valid.
        let v = (1.0 + c * x).powi(3);
        if v > 0.0 {
            let u: f32 = rng.gen();
            // First clause: fast squeeze acceptance; second clause: exact
            // log-density test ln(u) < 0.5*x^2 + d*(1 - v + ln v).
            if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
                return d * v;
            }
        }
    }
}
150
151fn sample_normal(rng: &mut impl Rng) -> f32 {
152    let u1: f32 = rng.gen::<f32>().max(1e-10);
153    let u2: f32 = rng.gen();
154    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
155}
156
/// Numerically stable softmax: subtracts the max logit before
/// exponentiating so large inputs cannot overflow.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut out: Vec<f32> = Vec::with_capacity(logits.len());
    let mut total = 0.0_f32;
    for &z in logits {
        let e = (z - peak).exp();
        total += e;
        out.push(e);
    }
    for e in &mut out {
        *e /= total;
    }
    out
}
163
/// `CutMix` data augmentation (Yun et al., 2019).
///
/// Cuts a rectangular region from one image and pastes onto another.
#[derive(Debug, Clone)]
pub struct CutMix {
    /// Beta-distribution shape parameter; <= 0.0 disables mixing.
    alpha: f32,
}
171
172impl CutMix {
173    #[must_use]
174    pub fn new(alpha: f32) -> Self {
175        Self { alpha }
176    }
177
178    /// Sample lambda and bounding box for cutmix.
179    #[must_use]
180    pub fn sample(&self, height: usize, width: usize) -> CutMixParams {
181        // Alpha <= 0 means no mixing: lambda = 1.0, empty box
182        if self.alpha <= 0.0 {
183            return CutMixParams {
184                lambda: 1.0,
185                x1: 0,
186                y1: 0,
187                x2: 0,
188                y2: 0,
189            };
190        }
191        let lambda = sample_beta(self.alpha, self.alpha);
192
193        // Sample bounding box
194        let ratio = (1.0 - lambda).sqrt();
195        let rh = (height as f32 * ratio) as usize;
196        let rw = (width as f32 * ratio) as usize;
197
198        let mut rng = rand::thread_rng();
199        let cx = rng.gen_range(0..width);
200        let cy = rng.gen_range(0..height);
201
202        let x1 = cx.saturating_sub(rw / 2);
203        let y1 = cy.saturating_sub(rh / 2);
204        let x2 = (cx + rw / 2).min(width);
205        let y2 = (cy + rh / 2).min(height);
206
207        // Actual lambda based on box area
208        let actual_lambda = 1.0 - ((x2 - x1) * (y2 - y1)) as f32 / (height * width) as f32;
209
210        CutMixParams {
211            lambda: actual_lambda,
212            x1,
213            y1,
214            x2,
215            y2,
216        }
217    }
218
219    #[must_use]
220    pub fn alpha(&self) -> f32 {
221        self.alpha
222    }
223}
224
/// Parameters for a `CutMix` operation.
#[derive(Debug, Clone)]
pub struct CutMixParams {
    /// Label mixing weight, recomputed from the clipped box area.
    pub lambda: f32,
    /// Left edge of the pasted box (inclusive).
    pub x1: usize,
    /// Top edge of the pasted box (inclusive).
    pub y1: usize,
    /// Right edge of the pasted box (exclusive).
    pub x2: usize,
    /// Bottom edge of the pasted box (exclusive).
    pub y2: usize,
}
234
235impl CutMixParams {
236    /// Apply cutmix to two flat image vectors `[C*H*W]`.
237    #[must_use]
238    pub fn apply(
239        &self,
240        img1: &[f32],
241        img2: &[f32],
242        channels: usize,
243        height: usize,
244        width: usize,
245    ) -> Vec<f32> {
246        let mut result = img1.to_vec();
247
248        for c in 0..channels {
249            for y in self.y1..self.y2 {
250                for x in self.x1..self.x2 {
251                    let idx = c * height * width + y * width + x;
252                    if idx < result.len() {
253                        result[idx] = img2[idx];
254                    }
255                }
256            }
257        }
258        result
259    }
260}
261
/// Stochastic Depth (Huang et al., 2016).
///
/// Randomly drops entire residual blocks during training.
#[derive(Debug, Clone)]
pub struct StochasticDepth {
    /// Probability of dropping the block, restricted to [0, 1) by `new`.
    drop_prob: f32,
    /// How the drop decision is meant to be applied across a batch.
    mode: DropMode,
}
270
/// How a stochastic-depth drop decision applies across a batch.
///
/// NOTE(review): `StochasticDepth::should_keep` draws a single Bernoulli
/// sample per call and does not read this mode — callers are expected to
/// interpret it themselves; confirm against call sites.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DropMode {
    /// All samples in batch are dropped together
    Batch,
    /// Each sample dropped independently
    Row,
}
278
279impl StochasticDepth {
280    #[must_use]
281    pub fn new(drop_prob: f32, mode: DropMode) -> Self {
282        assert!((0.0..1.0).contains(&drop_prob));
283        Self { drop_prob, mode }
284    }
285
286    /// Apply stochastic depth: returns true if should keep (not drop).
287    #[must_use]
288    pub fn should_keep(&self, training: bool) -> bool {
289        if !training || self.drop_prob == 0.0 {
290            return true;
291        }
292        rand::thread_rng().gen::<f32>() >= self.drop_prob
293    }
294
295    /// Compute survival probability for linear decay schedule.
296    #[must_use]
297    pub fn linear_decay(depth: usize, total_depth: usize, max_drop: f32) -> f32 {
298        1.0 - (depth as f32 / total_depth as f32) * max_drop
299    }
300
301    #[must_use]
302    pub fn drop_prob(&self) -> f32 {
303        self.drop_prob
304    }
305
306    #[must_use]
307    pub fn mode(&self) -> DropMode {
308        self.mode
309    }
310}
311
/// R-Drop regularization (Liang et al., 2021).
///
/// Forces consistency between two forward passes with different dropout masks.
/// Adds bidirectional KL divergence loss between the two outputs.
///
/// Loss = CE(p1, y) + CE(p2, y) + α * (KL(p1||p2) + KL(p2||p1)) / 2
///
/// # Reference
/// Liang, X., et al. (2021). R-Drop: Regularized Dropout for Neural Networks.
#[derive(Debug, Clone)]
pub struct RDrop {
    /// Weight α of the symmetric-KL regularization term (non-negative).
    alpha: f32,
}
325
326impl RDrop {
327    /// Create R-Drop with regularization weight alpha.
328    #[must_use]
329    pub fn new(alpha: f32) -> Self {
330        assert!(alpha >= 0.0, "Alpha must be non-negative");
331        Self { alpha }
332    }
333
334    #[must_use]
335    pub fn alpha(&self) -> f32 {
336        self.alpha
337    }
338
339    /// Compute KL divergence: KL(p || q) = Σ p * log(p / q)
340    #[must_use]
341    pub fn kl_divergence(&self, p: &[f32], q: &[f32]) -> f32 {
342        assert_eq!(p.len(), q.len());
343        let eps = 1e-10;
344        p.iter()
345            .zip(q.iter())
346            .map(|(&pi, &qi)| {
347                let pi = pi.max(eps);
348                let qi = qi.max(eps);
349                pi * (pi / qi).ln()
350            })
351            .sum()
352    }
353
354    /// Compute bidirectional KL divergence (symmetric).
355    #[must_use]
356    pub fn symmetric_kl(&self, p: &[f32], q: &[f32]) -> f32 {
357        (self.kl_divergence(p, q) + self.kl_divergence(q, p)) / 2.0
358    }
359
360    /// Compute R-Drop regularization loss between two forward passes.
361    #[must_use]
362    pub fn compute_loss(&self, logits1: &[f32], logits2: &[f32]) -> f32 {
363        let p1 = softmax_slice(logits1);
364        let p2 = softmax_slice(logits2);
365        self.alpha * self.symmetric_kl(&p1, &p2)
366    }
367}
368
/// Softmax over a slice, with max-subtraction for numerical stability.
fn softmax_slice(logits: &[f32]) -> Vec<f32> {
    let max = logits
        .iter()
        .fold(f32::NEG_INFINITY, |acc, &v| if v > acc { v } else { acc });
    let exps: Vec<f32> = logits.iter().map(|&v| (v - max).exp()).collect();
    let norm: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / norm).collect()
}
375
/// `SpecAugment`: Data augmentation for speech recognition (Park et al., 2019).
///
/// Applies frequency masking and time masking to spectrograms, which are
/// passed as flat `[freq_bins * time_steps]` slices indexed
/// `freq * time_steps + t`. (The paper's third method, time warping, is not
/// implemented in this module.)
///
/// # Methods
///
/// - **Frequency Masking**: Masks F consecutive frequency channels
/// - **Time Masking**: Masks T consecutive time steps
///
/// # Reference
///
/// - Park, D., et al. (2019). `SpecAugment`: A Simple Data Augmentation Method
///   for Automatic Speech Recognition. Interspeech.
#[derive(Debug, Clone)]
pub struct SpecAugment {
    /// Number of frequency masks to apply
    num_freq_masks: usize,
    /// Maximum size of frequency mask (in bins)
    freq_mask_param: usize,
    /// Number of time masks to apply
    num_time_masks: usize,
    /// Maximum size of time mask (in steps)
    time_mask_param: usize,
    /// Mask value (usually 0 or mean)
    mask_value: f32,
}
403
// Defaults delegate to `SpecAugment::new` (2 freq masks F=27, 2 time masks
// T=100, mask value 0.0).
impl Default for SpecAugment {
    fn default() -> Self {
        Self::new()
    }
}
409
410impl SpecAugment {
411    /// Create `SpecAugment` with default parameters.
412    ///
413    /// Default: 2 frequency masks (F=27), 2 time masks (T=100)
414    #[must_use]
415    pub fn new() -> Self {
416        Self {
417            num_freq_masks: 2,
418            freq_mask_param: 27,
419            num_time_masks: 2,
420            time_mask_param: 100,
421            mask_value: 0.0,
422        }
423    }
424
425    /// Create with custom parameters.
426    #[must_use]
427    pub fn with_params(
428        num_freq_masks: usize,
429        freq_mask_param: usize,
430        num_time_masks: usize,
431        time_mask_param: usize,
432    ) -> Self {
433        Self {
434            num_freq_masks,
435            freq_mask_param,
436            num_time_masks,
437            time_mask_param,
438            mask_value: 0.0,
439        }
440    }
441
442    /// Set the mask value.
443    #[must_use]
444    pub fn with_mask_value(mut self, value: f32) -> Self {
445        self.mask_value = value;
446        self
447    }
448
449    /// Apply `SpecAugment` to a spectrogram.
450    ///
451    /// # Arguments
452    ///
453    /// * `spec` - Spectrogram as flat vector [`freq_bins` * `time_steps`]
454    /// * `freq_bins` - Number of frequency bins
455    /// * `time_steps` - Number of time steps
456    ///
457    /// # Returns
458    ///
459    /// Augmented spectrogram.
460    #[must_use]
461    pub fn apply(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
462        let mut result = spec.to_vec();
463        let mut rng = rand::thread_rng();
464
465        // Apply frequency masks
466        for _ in 0..self.num_freq_masks {
467            let f = rng.gen_range(0..=self.freq_mask_param.min(freq_bins));
468            let f0 = rng.gen_range(0..freq_bins.saturating_sub(f).max(1));
469
470            for freq in f0..f0 + f {
471                if freq < freq_bins {
472                    for t in 0..time_steps {
473                        let idx = freq * time_steps + t;
474                        if idx < result.len() {
475                            result[idx] = self.mask_value;
476                        }
477                    }
478                }
479            }
480        }
481
482        // Apply time masks
483        for _ in 0..self.num_time_masks {
484            let t = rng.gen_range(0..=self.time_mask_param.min(time_steps));
485            let t0 = rng.gen_range(0..time_steps.saturating_sub(t).max(1));
486
487            for time in t0..t0 + t {
488                if time < time_steps {
489                    for freq in 0..freq_bins {
490                        let idx = freq * time_steps + time;
491                        if idx < result.len() {
492                            result[idx] = self.mask_value;
493                        }
494                    }
495                }
496            }
497        }
498
499        result
500    }
501
502    /// Apply only frequency masking.
503    #[must_use]
504    pub fn freq_mask(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
505        let mut result = spec.to_vec();
506        let mut rng = rand::thread_rng();
507
508        for _ in 0..self.num_freq_masks {
509            let f = rng.gen_range(0..=self.freq_mask_param.min(freq_bins));
510            let f0 = rng.gen_range(0..freq_bins.saturating_sub(f).max(1));
511
512            for freq in f0..f0 + f {
513                if freq < freq_bins {
514                    for t in 0..time_steps {
515                        let idx = freq * time_steps + t;
516                        if idx < result.len() {
517                            result[idx] = self.mask_value;
518                        }
519                    }
520                }
521            }
522        }
523
524        result
525    }
526
527    /// Apply only time masking.
528    #[must_use]
529    pub fn time_mask(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
530        let mut result = spec.to_vec();
531        let mut rng = rand::thread_rng();
532
533        for _ in 0..self.num_time_masks {
534            let t = rng.gen_range(0..=self.time_mask_param.min(time_steps));
535            let t0 = rng.gen_range(0..time_steps.saturating_sub(t).max(1));
536
537            for time in t0..t0 + t {
538                if time < time_steps {
539                    for freq in 0..freq_bins {
540                        let idx = freq * time_steps + time;
541                        if idx < result.len() {
542                            result[idx] = self.mask_value;
543                        }
544                    }
545                }
546            }
547        }
548
549        result
550    }
551
552    #[must_use]
553    pub fn num_freq_masks(&self) -> usize {
554        self.num_freq_masks
555    }
556
557    #[must_use]
558    pub fn num_time_masks(&self) -> usize {
559        self.num_time_masks
560    }
561}
562
/// `RandAugment`: Automated data augmentation policy (Cubuk et al., 2020).
///
/// Applies N random augmentations from a set, each with magnitude M.
/// Simpler than `AutoAugment` with fewer hyperparameters.
///
/// # Reference
///
/// - Cubuk, E., et al. (2020). Randaugment: Practical automated data
///   augmentation with a reduced search space. CVPR.
#[derive(Debug, Clone)]
pub struct RandAugment {
    /// Number of augmentations to apply
    n: usize,
    /// Magnitude of augmentations (0-30 scale; clamped to 30 in `new`)
    m: usize,
    /// Available augmentation types sampled from (with replacement)
    augmentations: Vec<AugmentationType>,
}
581
/// Types of image augmentations.
///
/// NOTE(review): `RandAugment::apply_single` currently implements
/// `Brightness`, `Contrast`, `Rotate` (simplified to a flip), `TranslateX`
/// and `TranslateY`; the remaining variants are no-ops there.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AugmentationType {
    Identity,
    Rotate,
    TranslateX,
    TranslateY,
    ShearX,
    ShearY,
    Brightness,
    Contrast,
    Sharpness,
    Posterize,
    Solarize,
    Equalize,
}
598
// Default policy: apply 2 augmentations at magnitude 9.
impl Default for RandAugment {
    fn default() -> Self {
        Self::new(2, 9)
    }
}
604
605impl RandAugment {
606    /// Create `RandAugment` with N operations and magnitude M.
607    ///
608    /// # Arguments
609    ///
610    /// * `n` - Number of augmentations to apply
611    /// * `m` - Magnitude (0-30)
612    #[must_use]
613    pub fn new(n: usize, m: usize) -> Self {
614        Self {
615            n,
616            m: m.min(30),
617            augmentations: vec![
618                AugmentationType::Identity,
619                AugmentationType::Rotate,
620                AugmentationType::TranslateX,
621                AugmentationType::TranslateY,
622                AugmentationType::Brightness,
623                AugmentationType::Contrast,
624                AugmentationType::Sharpness,
625            ],
626        }
627    }
628
629    /// Set custom augmentation types.
630    #[must_use]
631    pub fn with_augmentations(mut self, augs: Vec<AugmentationType>) -> Self {
632        self.augmentations = augs;
633        self
634    }
635
636    /// Get N random augmentation types.
637    #[must_use]
638    pub fn sample_augmentations(&self) -> Vec<AugmentationType> {
639        use rand::seq::SliceRandom;
640        let mut rng = rand::thread_rng();
641        let mut selected = Vec::with_capacity(self.n);
642
643        for _ in 0..self.n {
644            if let Some(&aug) = self.augmentations.choose(&mut rng) {
645                selected.push(aug);
646            }
647        }
648
649        selected
650    }
651
652    /// Get magnitude as normalized value [0, 1].
653    #[must_use]
654    pub fn normalized_magnitude(&self) -> f32 {
655        self.m as f32 / 30.0
656    }
657
658    /// Apply a single augmentation to image data (simplified).
659    ///
660    /// # Arguments
661    ///
662    /// * `image` - Flat image vector `[C*H*W]`
663    /// * `aug` - Augmentation type
664    /// * `h` - Image height
665    /// * `w` - Image width
666    #[must_use]
667    pub fn apply_single(
668        &self,
669        image: &[f32],
670        aug: AugmentationType,
671        h: usize,
672        w: usize,
673    ) -> Vec<f32> {
674        let mag = self.normalized_magnitude();
675        let mut result = image.to_vec();
676
677        match aug {
678            AugmentationType::Brightness => {
679                let factor = 1.0 + (mag - 0.5) * 2.0; // [0, 2]
680                for v in &mut result {
681                    *v = (*v * factor).clamp(0.0, 1.0);
682                }
683            }
684            AugmentationType::Contrast => {
685                let mean: f32 = result.iter().sum::<f32>() / result.len() as f32;
686                let factor = 1.0 + (mag - 0.5) * 2.0;
687                for v in &mut result {
688                    *v = ((*v - mean) * factor + mean).clamp(0.0, 1.0);
689                }
690            }
691            AugmentationType::Rotate => {
692                // Simplified: just flip for magnitude > 0.5
693                if mag > 0.5 {
694                    result.reverse();
695                }
696            }
697            AugmentationType::TranslateX => {
698                let shift = ((mag - 0.5) * w as f32 * 0.3) as i32;
699                Self::shift_horizontal(&mut result, h, w, shift);
700            }
701            AugmentationType::TranslateY => {
702                let shift = ((mag - 0.5) * h as f32 * 0.3) as i32;
703                Self::shift_vertical(&mut result, h, w, shift);
704            }
705            // Identity and others: no operation
706            AugmentationType::Identity
707            | AugmentationType::ShearX
708            | AugmentationType::ShearY
709            | AugmentationType::Sharpness
710            | AugmentationType::Posterize
711            | AugmentationType::Solarize
712            | AugmentationType::Equalize => {}
713        }
714
715        result
716    }
717
718    fn shift_horizontal(data: &mut [f32], h: usize, w: usize, shift: i32) {
719        if shift == 0 {
720            return;
721        }
722        let channels = data.len() / (h * w);
723        for c in 0..channels {
724            for y in 0..h {
725                let row_start = c * h * w + y * w;
726                let row: Vec<f32> = (0..w)
727                    .map(|x| {
728                        let src_x = (x as i32 - shift).rem_euclid(w as i32) as usize;
729                        data[row_start + src_x]
730                    })
731                    .collect();
732                data[row_start..row_start + w].copy_from_slice(&row);
733            }
734        }
735    }
736
737    fn shift_vertical(data: &mut [f32], h: usize, w: usize, shift: i32) {
738        if shift == 0 {
739            return;
740        }
741        let channels = data.len() / (h * w);
742        for c in 0..channels {
743            for x in 0..w {
744                let col: Vec<f32> = (0..h)
745                    .map(|y| {
746                        let src_y = (y as i32 - shift).rem_euclid(h as i32) as usize;
747                        data[c * h * w + src_y * w + x]
748                    })
749                    .collect();
750                for (y, &val) in col.iter().enumerate() {
751                    data[c * h * w + y * w + x] = val;
752                }
753            }
754        }
755    }
756
757    #[must_use]
758    pub fn n(&self) -> usize {
759        self.n
760    }
761
762    #[must_use]
763    pub fn m(&self) -> usize {
764        self.m
765    }
766}
767
// Unit tests for this module; compiled only for test builds. The module
// body lives in the adjacent `tests.rs` file (or `tests/mod.rs`).
#[cfg(test)]
mod tests;