use crate::primitives::Vector;
use rand::Rng;
/// Mixup augmentation: blends two samples (and their labels) using a weight
/// drawn from a symmetric Beta(`alpha`, `alpha`) distribution.
#[derive(Debug, Clone)]
pub struct Mixup {
// Concentration passed as both Beta parameters by `sample_lambda`;
// values `<= 0.0` make `sample_lambda` return 1.0 without sampling.
alpha: f32,
}
impl Mixup {
    /// Creates a Mixup configuration with Beta concentration `alpha`.
    ///
    /// No validation is performed here; a non-positive or NaN `alpha`
    /// disables mixing (see [`Mixup::sample_lambda`]).
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        Self { alpha }
    }

    /// Draws a mixing coefficient from Beta(`alpha`, `alpha`).
    ///
    /// Returns `1.0` (keep the first sample unchanged) when `alpha` is
    /// non-positive or NaN.
    #[must_use]
    pub fn sample_lambda(&self) -> f32 {
        // NaN fails the `<= 0.0` comparison, so it is tested explicitly rather
        // than handing an undefined concentration to the Beta sampler.
        if self.alpha.is_nan() || self.alpha <= 0.0 {
            return 1.0;
        }
        sample_beta(self.alpha, self.alpha)
    }

    /// Returns the element-wise blend `lambda * x1 + (1 - lambda) * x2`.
    ///
    /// The inputs are zipped, so when their lengths differ the result has the
    /// length of the shorter one and the extra elements are ignored.
    #[must_use]
    pub fn mix_samples(&self, x1: &Vector<f32>, x2: &Vector<f32>, lambda: f32) -> Vector<f32> {
        let mixed: Vec<f32> = x1
            .as_slice()
            .iter()
            .zip(x2.as_slice().iter())
            .map(|(&a, &b)| lambda * a + (1.0 - lambda) * b)
            .collect();
        Vector::from_slice(&mixed)
    }

    /// Blends two label vectors with the same rule as [`Mixup::mix_samples`].
    #[must_use]
    pub fn mix_labels(&self, y1: &Vector<f32>, y2: &Vector<f32>, lambda: f32) -> Vector<f32> {
        self.mix_samples(y1, y2, lambda)
    }

    /// Returns the configured Beta concentration.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}
/// Label smoothing: moves a fraction `epsilon` of the target mass onto a
/// uniform distribution over all classes.
#[derive(Debug, Clone)]
pub struct LabelSmoothing {
// Smoothing strength; the constructor asserts it lies in `[0.0, 1.0)`.
epsilon: f32,
}
impl LabelSmoothing {
    /// Builds a smoother with strength `epsilon`.
    ///
    /// # Panics
    ///
    /// Panics unless `0.0 <= epsilon < 1.0` (a NaN `epsilon` also panics).
    #[must_use]
    pub fn new(epsilon: f32) -> Self {
        assert!((0.0..1.0).contains(&epsilon));
        Self { epsilon }
    }

    /// Smooths a label vector element-wise:
    /// `(1 - epsilon) * y + epsilon / n`, where `n` is `label.len()`.
    #[must_use]
    pub fn smooth(&self, label: &Vector<f32>) -> Vector<f32> {
        let keep = 1.0 - self.epsilon;
        let uniform = self.epsilon / label.len() as f32;
        let mut smoothed = Vec::with_capacity(label.len());
        for &y in label.as_slice().iter() {
            smoothed.push(keep * y + uniform);
        }
        Vector::from_slice(&smoothed)
    }

    /// Builds the smoothed distribution for the hard label `class_idx`:
    /// every entry is `epsilon / n_classes`, and the target entry additionally
    /// receives `1 - epsilon`.
    ///
    /// # Panics
    ///
    /// Panics if `class_idx >= n_classes` (out-of-bounds index).
    #[must_use]
    pub fn smooth_index(&self, class_idx: usize, n_classes: usize) -> Vector<f32> {
        let off_target = self.epsilon / n_classes as f32;
        let on_target = 1.0 - self.epsilon + off_target;
        let mut distribution = vec![off_target; n_classes];
        distribution[class_idx] = on_target;
        Vector::from_slice(&distribution)
    }

    /// Returns the smoothing strength.
    #[must_use]
    pub fn epsilon(&self) -> f32 {
        self.epsilon
    }
}
/// Cross-entropy between the log-softmax of `logits` and a label-smoothed
/// one-hot target.
///
/// The target puts `1 - epsilon + epsilon / n` on `target_idx` and
/// `epsilon / n` on every other class, where `n` is `logits.len()`. If
/// `target_idx` is out of range, every class receives the off-target weight.
#[must_use]
pub fn cross_entropy_with_smoothing(logits: &Vector<f32>, target_idx: usize, epsilon: f32) -> f32 {
    let uniform = epsilon / logits.len() as f32;
    let on_target = 1.0 - epsilon + uniform;
    crate::nn::functional::log_softmax_1d(logits.as_slice())
        .iter()
        .enumerate()
        .fold(0.0, |loss, (i, &lp)| {
            let weight = if i == target_idx { on_target } else { uniform };
            loss - weight * lp
        })
}
/// Draws one Beta(`alpha`, `beta`) sample as `X / (X + Y)`, where `X` and `Y`
/// are gamma draws with shapes `alpha` and `beta`.
///
/// Returns `0.5` when the two gamma draws sum to zero or less; otherwise the
/// ratio is clamped to `[0, 1]`.
fn sample_beta(alpha: f32, beta: f32) -> f32 {
    let mut rng = rand::rng();
    // The alpha draw comes first so the RNG stream is consumed in a fixed order.
    let from_alpha = sample_gamma(alpha, &mut rng);
    let from_beta = sample_gamma(beta, &mut rng);
    let total = from_alpha + from_beta;
    if total <= 0.0 {
        0.5
    } else {
        (from_alpha / total).clamp(0.0, 1.0)
    }
}
/// Draws one sample from a Gamma(`shape`, 1) distribution using a
/// cube-transformed normal with squeeze/log acceptance tests.
///
/// Shapes below 1 are handled by sampling at `shape + 1` and scaling by
/// `U^(1 / shape)` with `U` uniform on `[0, 1)`.
///
/// Returns NaN when `shape` is NaN or not strictly positive, since the
/// distribution is undefined there.
fn sample_gamma(shape: f32, rng: &mut impl Rng) -> f32 {
    // Guard the domain up front: a NaN shape would make every acceptance test
    // below false (the loop would never exit), and a large-magnitude negative
    // shape would recurse without bound because `1.0 + shape == shape` in f32.
    if shape.is_nan() || shape <= 0.0 {
        return f32::NAN;
    }
    if shape < 1.0 {
        // Gamma draw first, then the uniform, so RNG consumption order is fixed.
        return sample_gamma(1.0 + shape, rng) * rng.random::<f32>().powf(1.0 / shape);
    }
    let d = shape - 1.0 / 3.0;
    let c = 1.0 / (9.0 * d).sqrt();
    loop {
        let x: f32 = sample_normal(rng);
        let v = (1.0 + c * x).powi(3);
        // Only positive `v` is a valid candidate; otherwise redraw.
        if v > 0.0 {
            let u: f32 = rng.random();
            // Cheap squeeze test first; fall back to the exact log test.
            if u < 1.0 - 0.0331 * x.powi(4) || u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
                return d * v;
            }
        }
    }
}
/// Draws one standard-normal sample from two uniform draws
/// (`sqrt(-2 ln u1) * cos(2π u2)`).
fn sample_normal(rng: &mut impl Rng) -> f32 {
    // Floor the first uniform so the logarithm never sees zero.
    let radial_u = rng.random::<f32>().max(1e-10);
    let angular_u = rng.random::<f32>();
    let radius = (-2.0 * radial_u.ln()).sqrt();
    let angle = 2.0 * std::f32::consts::PI * angular_u;
    radius * angle.cos()
}
/// CutMix augmentation sampler: picks a rectangular region whose area is
/// driven by a Beta(`alpha`, `alpha`) draw.
#[derive(Debug, Clone)]
pub struct CutMix {
// Concentration passed as both Beta parameters by `sample`;
// values `<= 0.0` make `sample` return an empty box with lambda 1.0.
alpha: f32,
}
impl CutMix {
    /// Creates a CutMix sampler with Beta concentration `alpha`.
    ///
    /// No validation is performed here; a non-positive or NaN `alpha`
    /// disables cutting (see [`CutMix::sample`]).
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        Self { alpha }
    }

    /// Samples a cut box for an image of `height` × `width` pixels.
    ///
    /// A Beta(`alpha`, `alpha`) draw sets the target area ratio, a uniformly
    /// random centre is chosen, and the box is clipped to the image. The
    /// returned `lambda` is recomputed from the clipped box as
    /// `1 - box_area / image_area`.
    ///
    /// Returns an empty box with `lambda == 1.0` when `alpha` is non-positive
    /// or NaN, or when either dimension is zero.
    #[must_use]
    pub fn sample(&self, height: usize, width: usize) -> CutMixParams {
        // NaN fails the `<= 0.0` comparison, so it is tested explicitly.
        let mixing_disabled = self.alpha.is_nan() || self.alpha <= 0.0;
        // A zero-sized image has no valid centre (the uniform range would be
        // empty) and no area to divide by, so it is treated as "nothing to cut".
        if mixing_disabled || height == 0 || width == 0 {
            return CutMixParams {
                lambda: 1.0,
                x1: 0,
                y1: 0,
                x2: 0,
                y2: 0,
            };
        }
        let lambda = sample_beta(self.alpha, self.alpha);
        // Side lengths scale with sqrt(1 - lambda) so the area scales with 1 - lambda.
        let ratio = (1.0 - lambda).sqrt();
        // Truncating casts are intentional: box sides are whole pixels.
        let rh = (height as f32 * ratio) as usize;
        let rw = (width as f32 * ratio) as usize;
        let mut rng = rand::rng();
        let cx = rng.random_range(0..width);
        let cy = rng.random_range(0..height);
        // Clip the box to the image on all four sides.
        let x1 = cx.saturating_sub(rw / 2);
        let y1 = cy.saturating_sub(rh / 2);
        let x2 = (cx + rw / 2).min(width);
        let y2 = (cy + rh / 2).min(height);
        // Clipping can shrink the box, so report the ratio that actually applies.
        let actual_lambda = 1.0 - ((x2 - x1) * (y2 - y1)) as f32 / (height * width) as f32;
        CutMixParams {
            lambda: actual_lambda,
            x1,
            y1,
            x2,
            y2,
        }
    }

    /// Returns the configured Beta concentration.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}
/// A sampled CutMix box together with the mixing coefficient it implies.
///
/// The box covers columns `x1..x2` and rows `y1..y2` (end-exclusive).
#[derive(Debug, Clone)]
pub struct CutMixParams {
    /// Mixing coefficient associated with the box.
    pub lambda: f32,
    /// First column of the box (inclusive).
    pub x1: usize,
    /// First row of the box (inclusive).
    pub y1: usize,
    /// End column of the box (exclusive).
    pub x2: usize,
    /// End row of the box (exclusive).
    pub y2: usize,
}

impl CutMixParams {
    /// Returns a copy of `img1` in which the box region of every channel is
    /// replaced by the corresponding pixels of `img2`.
    ///
    /// Both images are indexed as `c * height * width + y * width + x`.
    ///
    /// The box is clamped to `width` × `height`, so a box sampled for a larger
    /// image cannot spill into neighbouring rows or channels. Positions that
    /// do not exist in both `img1` and `img2` are left as in `img1`, so a
    /// shorter `img2` never causes a panic.
    #[must_use]
    pub fn apply(
        &self,
        img1: &[f32],
        img2: &[f32],
        channels: usize,
        height: usize,
        width: usize,
    ) -> Vec<f32> {
        let mut result = img1.to_vec();
        let x_end = self.x2.min(width);
        let y_end = self.y2.min(height);
        // Only indices valid in both buffers may be copied.
        let limit = result.len().min(img2.len());
        for c in 0..channels {
            for y in self.y1..y_end {
                let row = c * height * width + y * width;
                let start = (row + self.x1).min(limit);
                let end = (row + x_end).min(limit);
                // `start >= end` covers both an empty box and a fully clipped row.
                if start < end {
                    result[start..end].copy_from_slice(&img2[start..end]);
                }
            }
        }
        result
    }
}
/// Stochastic-depth configuration: a drop probability plus a `DropMode`.
#[derive(Debug, Clone)]
pub struct StochasticDepth {
// Probability of dropping; the constructor asserts it lies in `[0.0, 1.0)`.
drop_prob: f32,
// Stored and exposed via `mode()`; `should_keep` does not consult it.
mode: DropMode,
}
/// Granularity selector carried by `StochasticDepth`.
///
/// NOTE(review): the variants are not interpreted anywhere in this section of
/// the file; presumably `Batch` means one drop decision for the whole batch
/// and `Row` one per sample — confirm against the code that consumes `mode()`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DropMode {
Batch,
Row,
}
impl StochasticDepth {
    /// Creates a stochastic-depth configuration.
    ///
    /// # Panics
    ///
    /// Panics unless `0.0 <= drop_prob < 1.0` (a NaN `drop_prob` also panics).
    #[must_use]
    pub fn new(drop_prob: f32, mode: DropMode) -> Self {
        assert!((0.0..1.0).contains(&drop_prob));
        Self { drop_prob, mode }
    }

    /// Decides whether to keep (not drop) on this call.
    ///
    /// Always returns `true` when `training` is false or `drop_prob` is zero.
    /// Otherwise it draws a uniform value in `[0, 1)` and keeps when the draw
    /// is at least `drop_prob`.
    #[must_use]
    pub fn should_keep(&self, training: bool) -> bool {
        if !training || self.drop_prob == 0.0 {
            return true;
        }
        rand::rng().random::<f32>() >= self.drop_prob
    }

    /// Returns `1 - (depth / total_depth) * max_drop`: a value that falls
    /// linearly from `1.0` at depth 0 to `1 - max_drop` at
    /// `depth == total_depth`.
    ///
    /// Returns `1.0` when `total_depth` is zero, instead of the NaN or
    /// infinity the division would otherwise produce.
    ///
    /// NOTE(review): the result reads as a keep/survival probability rather
    /// than a drop probability — confirm how callers feed it to
    /// [`StochasticDepth::new`].
    #[must_use]
    pub fn linear_decay(depth: usize, total_depth: usize, max_drop: f32) -> f32 {
        if total_depth == 0 {
            return 1.0;
        }
        1.0 - (depth as f32 / total_depth as f32) * max_drop
    }

    /// Returns the configured drop probability.
    #[must_use]
    pub fn drop_prob(&self) -> f32 {
        self.drop_prob
    }

    /// Returns the configured drop mode.
    #[must_use]
    pub fn mode(&self) -> DropMode {
        self.mode
    }
}
/// R-Drop regularizer: penalizes disagreement between two predictions via a
/// symmetric KL divergence scaled by `alpha`.
#[derive(Debug, Clone)]
pub struct RDrop {
// Weight applied to the symmetric KL term in `compute_loss`;
// the constructor asserts it is non-negative.
alpha: f32,
}
impl RDrop {
    /// Creates an R-Drop regularizer with weight `alpha`.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is negative or NaN.
    #[must_use]
    pub fn new(alpha: f32) -> Self {
        assert!(alpha >= 0.0, "Alpha must be non-negative");
        Self { alpha }
    }

    /// Returns the regularization weight.
    #[must_use]
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Computes `KL(p ‖ q) = Σ p_i · ln(p_i / q_i)`, with every probability
    /// floored at `1e-10` so neither the ratio nor the logarithm degenerates.
    ///
    /// # Panics
    ///
    /// Panics if `p` and `q` have different lengths.
    #[must_use]
    pub fn kl_divergence(&self, p: &[f32], q: &[f32]) -> f32 {
        const FLOOR: f32 = 1e-10;
        assert_eq!(p.len(), q.len());
        p.iter()
            .zip(q)
            .map(|(&pi, &qi)| (pi.max(FLOOR), qi.max(FLOOR)))
            .map(|(pi, qi)| pi * (pi / qi).ln())
            .sum()
    }

    /// Returns the midpoint of `KL(p ‖ q)` and `KL(q ‖ p)`.
    ///
    /// # Panics
    ///
    /// Panics if `p` and `q` have different lengths.
    #[must_use]
    pub fn symmetric_kl(&self, p: &[f32], q: &[f32]) -> f32 {
        let forward = self.kl_divergence(p, q);
        let backward = self.kl_divergence(q, p);
        f32::midpoint(forward, backward)
    }

    /// Returns `alpha` times the symmetric KL divergence between the softmax
    /// distributions of the two logit slices.
    #[must_use]
    pub fn compute_loss(&self, logits1: &[f32], logits2: &[f32]) -> f32 {
        let first = softmax_slice(logits1);
        let second = softmax_slice(logits2);
        let divergence = self.symmetric_kl(&first, &second);
        self.alpha * divergence
    }
}
/// Private convenience wrapper that forwards `logits` unchanged to
/// `crate::nn::functional::softmax_1d` and returns its result.
fn softmax_slice(logits: &[f32]) -> Vec<f32> {
crate::nn::functional::softmax_1d(logits)
}
/// Configuration for spectrogram masking augmentation.
///
/// NOTE(review): the constructor and masking logic are not in this section of
/// the file (see the `specaugment` submodule); the field notes below are
/// inferred from the names only — confirm against that implementation.
#[derive(Debug, Clone)]
pub struct SpecAugment {
// Presumably the number of frequency masks applied per call.
num_freq_masks: usize,
// Presumably the maximum width of a frequency mask.
freq_mask_param: usize,
// Presumably the number of time masks applied per call.
num_time_masks: usize,
// Presumably the maximum width of a time mask.
time_mask_param: usize,
// Presumably the value written into masked cells.
mask_value: f32,
}
/// `Default` delegates to `SpecAugment::new()`, which is defined outside this
/// section of the file.
impl Default for SpecAugment {
fn default() -> Self {
Self::new()
}
}
mod specaugment;
pub use specaugment::*;