1use crate::primitives::Vector;
9use rand::Rng;
10
11#[derive(Debug, Clone)]
14pub struct Mixup {
15 alpha: f32,
16}
17
18impl Mixup {
19 #[must_use]
21 pub fn new(alpha: f32) -> Self {
22 Self { alpha }
23 }
24
25 #[must_use]
27 pub fn sample_lambda(&self) -> f32 {
28 if self.alpha <= 0.0 {
29 return 1.0;
30 }
31 sample_beta(self.alpha, self.alpha)
32 }
33
34 #[must_use]
36 pub fn mix_samples(&self, x1: &Vector<f32>, x2: &Vector<f32>, lambda: f32) -> Vector<f32> {
37 let mixed: Vec<f32> = x1
38 .as_slice()
39 .iter()
40 .zip(x2.as_slice().iter())
41 .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
42 .collect();
43 Vector::from_slice(&mixed)
44 }
45
46 #[must_use]
48 pub fn mix_labels(&self, y1: &Vector<f32>, y2: &Vector<f32>, lambda: f32) -> Vector<f32> {
49 self.mix_samples(y1, y2, lambda)
50 }
51
52 #[must_use]
53 pub fn alpha(&self) -> f32 {
54 self.alpha
55 }
56}
57
58#[derive(Debug, Clone)]
61pub struct LabelSmoothing {
62 epsilon: f32,
63}
64
65impl LabelSmoothing {
66 #[must_use]
68 pub fn new(epsilon: f32) -> Self {
69 assert!((0.0..1.0).contains(&epsilon));
70 Self { epsilon }
71 }
72
73 #[must_use]
75 pub fn smooth(&self, label: &Vector<f32>) -> Vector<f32> {
76 let n_classes = label.len();
77 let smoothed: Vec<f32> = label
78 .as_slice()
79 .iter()
80 .map(|&y| (1.0 - self.epsilon) * y + self.epsilon / n_classes as f32)
81 .collect();
82 Vector::from_slice(&smoothed)
83 }
84
85 #[must_use]
87 pub fn smooth_index(&self, class_idx: usize, n_classes: usize) -> Vector<f32> {
88 let mut result = vec![self.epsilon / n_classes as f32; n_classes];
89 result[class_idx] = 1.0 - self.epsilon + self.epsilon / n_classes as f32;
90 Vector::from_slice(&result)
91 }
92
93 #[must_use]
94 pub fn epsilon(&self) -> f32 {
95 self.epsilon
96 }
97}
98
99#[must_use]
101pub fn cross_entropy_with_smoothing(logits: &Vector<f32>, target_idx: usize, epsilon: f32) -> f32 {
102 let n_classes = logits.len();
103 let probs = softmax(logits.as_slice());
104
105 let mut loss = 0.0;
106 for (i, &p) in probs.iter().enumerate() {
107 let target = if i == target_idx {
108 1.0 - epsilon + epsilon / n_classes as f32
109 } else {
110 epsilon / n_classes as f32
111 };
112 loss -= target * p.max(1e-10).ln();
113 }
114 loss
115}
116
117fn sample_beta(alpha: f32, beta: f32) -> f32 {
119 let mut rng = rand::thread_rng();
120 let x = sample_gamma(alpha, &mut rng);
121 let y = sample_gamma(beta, &mut rng);
122 let sum = x + y;
123 if sum <= 0.0 {
127 return 0.5;
128 }
129 (x / sum).clamp(0.0, 1.0)
130}
131
132fn sample_gamma(shape: f32, rng: &mut impl Rng) -> f32 {
133 if shape < 1.0 {
135 return sample_gamma(1.0 + shape, rng) * rng.gen::<f32>().powf(1.0 / shape);
136 }
137 let d = shape - 1.0 / 3.0;
138 let c = 1.0 / (9.0 * d).sqrt();
139 loop {
140 let x: f32 = sample_normal(rng);
141 let v = (1.0 + c * x).powi(3);
142 if v > 0.0 {
143 let u: f32 = rng.gen();
144 if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
145 return d * v;
146 }
147 }
148 }
149}
150
151fn sample_normal(rng: &mut impl Rng) -> f32 {
152 let u1: f32 = rng.gen::<f32>().max(1e-10);
153 let u2: f32 = rng.gen();
154 (-2.0 * u1.ln()).sqrt() * (2.0 * std::f32::consts::PI * u2).cos()
155}
156
/// Softmax over `logits`.
///
/// The maximum logit is subtracted before exponentiating so large inputs do not
/// overflow. An empty input produces an empty output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut probs: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for prob in &mut probs {
        *prob /= total;
    }
    probs
}
163
164#[derive(Debug, Clone)]
168pub struct CutMix {
169 alpha: f32,
170}
171
172impl CutMix {
173 #[must_use]
174 pub fn new(alpha: f32) -> Self {
175 Self { alpha }
176 }
177
178 #[must_use]
180 pub fn sample(&self, height: usize, width: usize) -> CutMixParams {
181 if self.alpha <= 0.0 {
183 return CutMixParams {
184 lambda: 1.0,
185 x1: 0,
186 y1: 0,
187 x2: 0,
188 y2: 0,
189 };
190 }
191 let lambda = sample_beta(self.alpha, self.alpha);
192
193 let ratio = (1.0 - lambda).sqrt();
195 let rh = (height as f32 * ratio) as usize;
196 let rw = (width as f32 * ratio) as usize;
197
198 let mut rng = rand::thread_rng();
199 let cx = rng.gen_range(0..width);
200 let cy = rng.gen_range(0..height);
201
202 let x1 = cx.saturating_sub(rw / 2);
203 let y1 = cy.saturating_sub(rh / 2);
204 let x2 = (cx + rw / 2).min(width);
205 let y2 = (cy + rh / 2).min(height);
206
207 let actual_lambda = 1.0 - ((x2 - x1) * (y2 - y1)) as f32 / (height * width) as f32;
209
210 CutMixParams {
211 lambda: actual_lambda,
212 x1,
213 y1,
214 x2,
215 y2,
216 }
217 }
218
219 #[must_use]
220 pub fn alpha(&self) -> f32 {
221 self.alpha
222 }
223}
224
/// Patch description used to paste a rectangle of one image into another.
///
/// The rectangle is half-open: columns `x1..x2` and rows `y1..y2`.
#[derive(Debug, Clone)]
pub struct CutMixParams {
    /// Mixing weight that accompanies the patch; [`CutMixParams::apply`] does not read it.
    pub lambda: f32,
    /// First column of the patch (inclusive).
    pub x1: usize,
    /// First row of the patch (inclusive).
    pub y1: usize,
    /// End column of the patch (exclusive).
    pub x2: usize,
    /// End row of the patch (exclusive).
    pub y2: usize,
}

impl CutMixParams {
    /// Returns a copy of `img1` in which the patch has been overwritten with the
    /// corresponding pixels of `img2`, for each of `channels` channels.
    ///
    /// Both buffers are indexed as `c * height * width + y * width + x`.
    ///
    /// The patch is clipped to `width` x `height`, so a box wider than the image cannot
    /// spill into the following row. Only indices present in **both** buffers are copied:
    /// a shorter `img2` leaves the uncovered part of `img1` untouched instead of
    /// panicking.
    #[must_use]
    pub fn apply(
        &self,
        img1: &[f32],
        img2: &[f32],
        channels: usize,
        height: usize,
        width: usize,
    ) -> Vec<f32> {
        let mut result = img1.to_vec();

        // Clip the box to the image bounds.
        let x_end = self.x2.min(width);
        let y_end = self.y2.min(height);
        let x_start = self.x1.min(x_end);
        // Highest index (exclusive) that is valid in both source and destination.
        let limit = result.len().min(img2.len());

        for c in 0..channels {
            for y in self.y1..y_end {
                let row = c * height * width + y * width;
                // `min` is monotone, so `start <= end` always holds.
                let start = (row + x_start).min(limit);
                let end = (row + x_end).min(limit);
                result[start..end].copy_from_slice(&img2[start..end]);
            }
        }
        result
    }
}
261
262#[derive(Debug, Clone)]
266pub struct StochasticDepth {
267 drop_prob: f32,
268 mode: DropMode,
269}
270
271#[derive(Debug, Clone, Copy, PartialEq)]
272pub enum DropMode {
273 Batch,
275 Row,
277}
278
279impl StochasticDepth {
280 #[must_use]
281 pub fn new(drop_prob: f32, mode: DropMode) -> Self {
282 assert!((0.0..1.0).contains(&drop_prob));
283 Self { drop_prob, mode }
284 }
285
286 #[must_use]
288 pub fn should_keep(&self, training: bool) -> bool {
289 if !training || self.drop_prob == 0.0 {
290 return true;
291 }
292 rand::thread_rng().gen::<f32>() >= self.drop_prob
293 }
294
295 #[must_use]
297 pub fn linear_decay(depth: usize, total_depth: usize, max_drop: f32) -> f32 {
298 1.0 - (depth as f32 / total_depth as f32) * max_drop
299 }
300
301 #[must_use]
302 pub fn drop_prob(&self) -> f32 {
303 self.drop_prob
304 }
305
306 #[must_use]
307 pub fn mode(&self) -> DropMode {
308 self.mode
309 }
310}
311
/// R-Drop style consistency regulariser: penalises the divergence between two
/// predictive distributions, scaled by `alpha`.
#[derive(Debug, Clone)]
pub struct RDrop {
    /// Non-negative weight applied to the symmetric KL term.
    alpha: f32,
}

impl RDrop {
    /// Creates a regulariser with weight `alpha`.
    ///
    /// # Panics
    ///
    /// Panics when `alpha` is negative or NaN.
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        assert!(alpha >= 0.0, "Alpha must be non-negative");
        RDrop { alpha }
    }

    /// Returns the weight applied to the symmetric KL term.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// KL divergence `sum(p_i * ln(p_i / q_i))`.
    ///
    /// Every probability is floored at `1e-10` first so neither the ratio nor the
    /// logarithm can blow up on zeros.
    ///
    /// # Panics
    ///
    /// Panics when `p` and `q` differ in length.
    #[must_use]
    pub fn kl_divergence(&self, p: &[f32], q: &[f32]) -> f32 {
        assert_eq!(p.len(), q.len());
        const FLOOR: f32 = 1e-10;

        let term = |(&p_raw, &q_raw): (&f32, &f32)| -> f32 {
            let p_i = p_raw.max(FLOOR);
            let q_i = q_raw.max(FLOOR);
            p_i * (p_i / q_i).ln()
        };
        p.iter().zip(q.iter()).map(term).sum()
    }

    /// Mean of `KL(p || q)` and `KL(q || p)`.
    #[must_use]
    pub fn symmetric_kl(&self, p: &[f32], q: &[f32]) -> f32 {
        let forward = self.kl_divergence(p, q);
        let backward = self.kl_divergence(q, p);
        (forward + backward) / 2.0
    }

    /// Regularisation loss for two logit vectors: `alpha` times the symmetric KL
    /// divergence of their softmax distributions.
    #[must_use]
    pub fn compute_loss(&self, logits1: &[f32], logits2: &[f32]) -> f32 {
        let first = softmax_slice(logits1);
        let second = softmax_slice(logits2);
        let divergence = self.symmetric_kl(&first, &second);
        self.alpha * divergence
    }
}

/// Softmax over `logits`, shifting by the maximum logit before exponentiating so
/// large inputs do not overflow. An empty input produces an empty output.
fn softmax_slice(logits: &[f32]) -> Vec<f32> {
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut probs: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for prob in &mut probs {
        *prob /= total;
    }
    probs
}
375
376#[derive(Debug, Clone)]
391pub struct SpecAugment {
392 num_freq_masks: usize,
394 freq_mask_param: usize,
396 num_time_masks: usize,
398 time_mask_param: usize,
400 mask_value: f32,
402}
403
404impl Default for SpecAugment {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
410impl SpecAugment {
411 #[must_use]
415 pub fn new() -> Self {
416 Self {
417 num_freq_masks: 2,
418 freq_mask_param: 27,
419 num_time_masks: 2,
420 time_mask_param: 100,
421 mask_value: 0.0,
422 }
423 }
424
425 #[must_use]
427 pub fn with_params(
428 num_freq_masks: usize,
429 freq_mask_param: usize,
430 num_time_masks: usize,
431 time_mask_param: usize,
432 ) -> Self {
433 Self {
434 num_freq_masks,
435 freq_mask_param,
436 num_time_masks,
437 time_mask_param,
438 mask_value: 0.0,
439 }
440 }
441
442 #[must_use]
444 pub fn with_mask_value(mut self, value: f32) -> Self {
445 self.mask_value = value;
446 self
447 }
448
449 #[must_use]
461 pub fn apply(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
462 let mut result = spec.to_vec();
463 let mut rng = rand::thread_rng();
464
465 for _ in 0..self.num_freq_masks {
467 let f = rng.gen_range(0..=self.freq_mask_param.min(freq_bins));
468 let f0 = rng.gen_range(0..freq_bins.saturating_sub(f).max(1));
469
470 for freq in f0..f0 + f {
471 if freq < freq_bins {
472 for t in 0..time_steps {
473 let idx = freq * time_steps + t;
474 if idx < result.len() {
475 result[idx] = self.mask_value;
476 }
477 }
478 }
479 }
480 }
481
482 for _ in 0..self.num_time_masks {
484 let t = rng.gen_range(0..=self.time_mask_param.min(time_steps));
485 let t0 = rng.gen_range(0..time_steps.saturating_sub(t).max(1));
486
487 for time in t0..t0 + t {
488 if time < time_steps {
489 for freq in 0..freq_bins {
490 let idx = freq * time_steps + time;
491 if idx < result.len() {
492 result[idx] = self.mask_value;
493 }
494 }
495 }
496 }
497 }
498
499 result
500 }
501
502 #[must_use]
504 pub fn freq_mask(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
505 let mut result = spec.to_vec();
506 let mut rng = rand::thread_rng();
507
508 for _ in 0..self.num_freq_masks {
509 let f = rng.gen_range(0..=self.freq_mask_param.min(freq_bins));
510 let f0 = rng.gen_range(0..freq_bins.saturating_sub(f).max(1));
511
512 for freq in f0..f0 + f {
513 if freq < freq_bins {
514 for t in 0..time_steps {
515 let idx = freq * time_steps + t;
516 if idx < result.len() {
517 result[idx] = self.mask_value;
518 }
519 }
520 }
521 }
522 }
523
524 result
525 }
526
527 #[must_use]
529 pub fn time_mask(&self, spec: &[f32], freq_bins: usize, time_steps: usize) -> Vec<f32> {
530 let mut result = spec.to_vec();
531 let mut rng = rand::thread_rng();
532
533 for _ in 0..self.num_time_masks {
534 let t = rng.gen_range(0..=self.time_mask_param.min(time_steps));
535 let t0 = rng.gen_range(0..time_steps.saturating_sub(t).max(1));
536
537 for time in t0..t0 + t {
538 if time < time_steps {
539 for freq in 0..freq_bins {
540 let idx = freq * time_steps + time;
541 if idx < result.len() {
542 result[idx] = self.mask_value;
543 }
544 }
545 }
546 }
547 }
548
549 result
550 }
551
552 #[must_use]
553 pub fn num_freq_masks(&self) -> usize {
554 self.num_freq_masks
555 }
556
557 #[must_use]
558 pub fn num_time_masks(&self) -> usize {
559 self.num_time_masks
560 }
561}
562
563#[derive(Debug, Clone)]
573pub struct RandAugment {
574 n: usize,
576 m: usize,
578 augmentations: Vec<AugmentationType>,
580}
581
582#[derive(Debug, Clone, Copy, PartialEq)]
584pub enum AugmentationType {
585 Identity,
586 Rotate,
587 TranslateX,
588 TranslateY,
589 ShearX,
590 ShearY,
591 Brightness,
592 Contrast,
593 Sharpness,
594 Posterize,
595 Solarize,
596 Equalize,
597}
598
599impl Default for RandAugment {
600 fn default() -> Self {
601 Self::new(2, 9)
602 }
603}
604
605impl RandAugment {
606 #[must_use]
613 pub fn new(n: usize, m: usize) -> Self {
614 Self {
615 n,
616 m: m.min(30),
617 augmentations: vec![
618 AugmentationType::Identity,
619 AugmentationType::Rotate,
620 AugmentationType::TranslateX,
621 AugmentationType::TranslateY,
622 AugmentationType::Brightness,
623 AugmentationType::Contrast,
624 AugmentationType::Sharpness,
625 ],
626 }
627 }
628
629 #[must_use]
631 pub fn with_augmentations(mut self, augs: Vec<AugmentationType>) -> Self {
632 self.augmentations = augs;
633 self
634 }
635
636 #[must_use]
638 pub fn sample_augmentations(&self) -> Vec<AugmentationType> {
639 use rand::seq::SliceRandom;
640 let mut rng = rand::thread_rng();
641 let mut selected = Vec::with_capacity(self.n);
642
643 for _ in 0..self.n {
644 if let Some(&aug) = self.augmentations.choose(&mut rng) {
645 selected.push(aug);
646 }
647 }
648
649 selected
650 }
651
652 #[must_use]
654 pub fn normalized_magnitude(&self) -> f32 {
655 self.m as f32 / 30.0
656 }
657
658 #[must_use]
667 pub fn apply_single(
668 &self,
669 image: &[f32],
670 aug: AugmentationType,
671 h: usize,
672 w: usize,
673 ) -> Vec<f32> {
674 let mag = self.normalized_magnitude();
675 let mut result = image.to_vec();
676
677 match aug {
678 AugmentationType::Brightness => {
679 let factor = 1.0 + (mag - 0.5) * 2.0; for v in &mut result {
681 *v = (*v * factor).clamp(0.0, 1.0);
682 }
683 }
684 AugmentationType::Contrast => {
685 let mean: f32 = result.iter().sum::<f32>() / result.len() as f32;
686 let factor = 1.0 + (mag - 0.5) * 2.0;
687 for v in &mut result {
688 *v = ((*v - mean) * factor + mean).clamp(0.0, 1.0);
689 }
690 }
691 AugmentationType::Rotate => {
692 if mag > 0.5 {
694 result.reverse();
695 }
696 }
697 AugmentationType::TranslateX => {
698 let shift = ((mag - 0.5) * w as f32 * 0.3) as i32;
699 Self::shift_horizontal(&mut result, h, w, shift);
700 }
701 AugmentationType::TranslateY => {
702 let shift = ((mag - 0.5) * h as f32 * 0.3) as i32;
703 Self::shift_vertical(&mut result, h, w, shift);
704 }
705 AugmentationType::Identity
707 | AugmentationType::ShearX
708 | AugmentationType::ShearY
709 | AugmentationType::Sharpness
710 | AugmentationType::Posterize
711 | AugmentationType::Solarize
712 | AugmentationType::Equalize => {}
713 }
714
715 result
716 }
717
718 fn shift_horizontal(data: &mut [f32], h: usize, w: usize, shift: i32) {
719 if shift == 0 {
720 return;
721 }
722 let channels = data.len() / (h * w);
723 for c in 0..channels {
724 for y in 0..h {
725 let row_start = c * h * w + y * w;
726 let row: Vec<f32> = (0..w)
727 .map(|x| {
728 let src_x = (x as i32 - shift).rem_euclid(w as i32) as usize;
729 data[row_start + src_x]
730 })
731 .collect();
732 data[row_start..row_start + w].copy_from_slice(&row);
733 }
734 }
735 }
736
737 fn shift_vertical(data: &mut [f32], h: usize, w: usize, shift: i32) {
738 if shift == 0 {
739 return;
740 }
741 let channels = data.len() / (h * w);
742 for c in 0..channels {
743 for x in 0..w {
744 let col: Vec<f32> = (0..h)
745 .map(|y| {
746 let src_y = (y as i32 - shift).rem_euclid(h as i32) as usize;
747 data[c * h * w + src_y * w + x]
748 })
749 .collect();
750 for (y, &val) in col.iter().enumerate() {
751 data[c * h * w + y * w + x] = val;
752 }
753 }
754 }
755 }
756
757 #[must_use]
758 pub fn n(&self) -> usize {
759 self.n
760 }
761
762 #[must_use]
763 pub fn m(&self) -> usize {
764 self.m
765 }
766}
767
768#[cfg(test)]
769mod tests;