//! Advanced training utilities: data augmentation (mixup, cutmix,
//! differentiable augmentation), neural architecture search helpers,
//! normalization (spectral norm, weight standardization), and classification
//! tricks (label smoothing, temperature scaling, knowledge distillation).

pub mod augmentation;
pub mod nas;
pub mod normalization;

pub use augmentation::{cutmix, differentiable_augment, mixup};
pub use nas::{
    darts_operation, decode_architecture, encode_architecture, mutate_architecture,
    predict_architecture_performance,
};
pub use normalization::{spectral_norm, weight_standardization};

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
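/// Applies label smoothing to one-hot targets:
/// `smoothed = targets * (1 - smoothing) + smoothing / num_classes`.
///
/// Expects 2D `[batch_size, num_classes]` targets and `smoothing` in `[0, 1)`.
/// A minimal usage sketch, reusing the `ones` constructor from this module's
/// tests (marked `ignore` since it needs the full crate to compile):
///
/// ```ignore
/// use torsh_tensor::creation::ones;
///
/// // With smoothing = 0.1 over 5 classes, every 1.0 entry becomes
/// // 1.0 * 0.9 + 0.1 / 5 = 0.92, and every 0.0 entry becomes 0.02.
/// let targets = ones(&[2, 5])?;
/// let smoothed = label_smoothing(&targets, 0.1)?;
/// ```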
pub fn label_smoothing(targets: &Tensor, smoothing: f32) -> TorshResult<Tensor> {
    if !(0.0..1.0).contains(&smoothing) {
        return Err(TorshError::invalid_argument_with_context(
            "Smoothing factor must be in [0, 1)",
            "label_smoothing",
        ));
    }
    let shape_binding = targets.shape();
    let shape = shape_binding.dims();
    if shape.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Targets must be 2D [batch_size, num_classes]",
            "label_smoothing",
        ));
    }
    let num_classes = shape[1] as f32;
    let uniform_prob = smoothing / num_classes;
    // smoothed = targets * (1 - eps) + eps / num_classes
    let smoothed = targets
        .mul_scalar(1.0 - smoothing)?
        .add_scalar(uniform_prob)?;
    Ok(smoothed)
}
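/// Scales logits by a positive temperature: `scaled = logits / temperature`.
/// Temperatures above 1.0 soften the post-softmax distribution; values below
/// 1.0 sharpen it. A minimal sketch (`ignore`d; needs the full crate to
/// compile):
///
/// ```ignore
/// use torsh_tensor::creation::randn;
///
/// let logits = randn(&[3, 4])?;
/// // Halve every logit, flattening the resulting softmax distribution.
/// let scaled = temperature_scale(&logits, 2.0)?;
/// ```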
pub fn temperature_scale(logits: &Tensor, temperature: f32) -> TorshResult<Tensor> {
    if temperature <= 0.0 {
        return Err(TorshError::invalid_argument_with_context(
            "Temperature must be positive",
            "temperature_scale",
        ));
    }
    logits.div_scalar(temperature)
}
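/// Knowledge distillation loss (Hinton et al., 2015):
/// `L = (1 - alpha) * T^2 * CE(teacher_soft, student_soft) + alpha * CE(hard, student)`,
/// where both soft distributions use temperature `T` and the hard term is
/// included only when `hard_targets` is provided. Both cross-entropy terms are
/// summed over the batch, so the result is a scalar tensor. A minimal sketch
/// (`ignore`d; needs the full crate to compile):
///
/// ```ignore
/// use torsh_tensor::creation::{ones, randn};
///
/// let student = randn(&[2, 3])?;
/// let teacher = randn(&[2, 3])?;
/// let hard = ones(&[2, 3])?;
/// // T = 3.0 softens both distributions; alpha = 0.5 balances the two terms.
/// let loss = knowledge_distillation_loss(&student, &teacher, 3.0, 0.5, Some(&hard))?;
/// ```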
pub fn knowledge_distillation_loss(
    student_logits: &Tensor,
    teacher_logits: &Tensor,
    temperature: f32,
    alpha: f32,
    hard_targets: Option<&Tensor>,
) -> TorshResult<Tensor> {
    // Soften both distributions with the same temperature.
    let student_soft = temperature_scale(student_logits, temperature)?.softmax(-1)?;
    let teacher_soft = temperature_scale(teacher_logits, temperature)?.softmax(-1)?;
    // Soft loss: cross-entropy of the student under the teacher distribution,
    // summed over the batch. This matches KL divergence up to the teacher's
    // entropy, which is constant with respect to the student.
    let soft_loss = teacher_soft
        .mul(&student_soft.log()?)?
        .sum()?
        .mul_scalar(-1.0)?;
    // Scaling by T^2 keeps the soft-loss gradients on the same order of
    // magnitude as the hard loss.
    let weighted_soft_loss = soft_loss.mul_scalar((1.0 - alpha) * temperature * temperature)?;
    // Optional hard loss: cross-entropy against the ground-truth targets,
    // weighted by `alpha`.
    let total_loss = if let Some(targets) = hard_targets {
        let student_probs = student_logits.softmax(-1)?;
        let hard_loss = targets
            .mul(&student_probs.log()?)?
            .sum()?
            .mul_scalar(-1.0)?;
        let weighted_hard_loss = hard_loss.mul_scalar(alpha)?;
        weighted_soft_loss.add(&weighted_hard_loss)?
    } else {
        weighted_soft_loss
    };
    Ok(total_loss)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::{ones, randn};

    #[test]
    fn test_advanced_nn_integration() -> TorshResult<()> {
        // Spectral normalization preserves the weight shape.
        let weight = randn(&[16, 32])?;
        let normalized = spectral_norm(&weight, 3, 1e-12)?;
        assert_eq!(weight.shape().dims(), normalized.shape().dims());

        // Mixup preserves the input shape.
        let x1 = randn(&[4, 3, 8, 8])?;
        let x2 = randn(&[4, 3, 8, 8])?;
        let y1 = randn(&[4, 10])?;
        let y2 = randn(&[4, 10])?;
        let (mixed_x, _mixed_y) = mixup(&x1, &x2, &y1, &y2, 0.3)?;
        assert_eq!(x1.shape().dims(), mixed_x.shape().dims());

        // NAS encoding produces a non-empty tensor.
        let operations = vec![0, 1, 2];
        let connections = ones(&[3, 3])?;
        let encoding = encode_architecture(&operations, &connections, 4)?;
        assert!(!encoding.data()?.is_empty());

        // Label smoothing preserves the target shape.
        let targets = ones(&[4, 10])?;
        let smoothed = label_smoothing(&targets, 0.1)?;
        assert_eq!(targets.shape().dims(), smoothed.shape().dims());
        Ok(())
    }
    #[test]
    fn test_label_smoothing() -> TorshResult<()> {
        let targets = ones(&[2, 5])?;
        let smoothed = label_smoothing(&targets, 0.1)?;
        assert_eq!(targets.shape().dims(), smoothed.shape().dims());
        // With smoothing = 0.1 over 5 classes, all-ones targets become
        // 0.9 + 0.1 / 5 = 0.92, strictly inside (0, 1).
        let smoothed_data = smoothed.data()?;
        for &val in smoothed_data.iter() {
            assert!(val < 1.0);
            assert!(val > 0.0);
        }
        Ok(())
    }
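    // Hedged error-path sketch for the validations in `label_smoothing`;
    // assumes `TorshResult` is a std `Result` alias, so `is_err` is available.
    #[test]
    fn test_label_smoothing_rejects_invalid_input() -> TorshResult<()> {
        // Non-2D targets are rejected.
        let targets_1d = ones(&[5])?;
        assert!(label_smoothing(&targets_1d, 0.1).is_err());
        // Smoothing factors outside [0, 1) are rejected.
        let targets = ones(&[2, 5])?;
        assert!(label_smoothing(&targets, 1.5).is_err());
        Ok(())
    }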
    #[test]
    fn test_temperature_scaling() -> TorshResult<()> {
        let logits = randn(&[3, 4])?;
        let scaled = temperature_scale(&logits, 2.0)?;
        assert_eq!(logits.shape().dims(), scaled.shape().dims());
        // Multiplying the scaled values back by T should recover the originals.
        let original_data = logits.data()?;
        let scaled_data = scaled.data()?;
        for (orig, scaled_val) in original_data.iter().zip(scaled_data.iter()) {
            assert!((scaled_val * 2.0 - orig).abs() < 1e-6);
        }
        Ok(())
    }
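    // Hedged error-path sketch for the temperature validation above; assumes
    // `TorshResult` is a std `Result` alias, so `is_err` is available.
    #[test]
    fn test_temperature_scale_rejects_non_positive() -> TorshResult<()> {
        let logits = randn(&[2, 2])?;
        assert!(temperature_scale(&logits, 0.0).is_err());
        assert!(temperature_scale(&logits, -1.0).is_err());
        Ok(())
    }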
    #[test]
    fn test_knowledge_distillation_loss() -> TorshResult<()> {
        let student_logits = randn(&[2, 3])?;
        let teacher_logits = randn(&[2, 3])?;
        let hard_targets = ones(&[2, 3])?;
        let loss = knowledge_distillation_loss(
            &student_logits,
            &teacher_logits,
            3.0,
            0.5,
            Some(&hard_targets),
        )?;
        // The loss reduces to a scalar tensor (empty shape).
        assert_eq!(loss.shape().dims(), &[] as &[usize]);
        Ok(())
    }
}