use std::any::Any;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_autograd::{GradFn, GradientFunction, Variable};
use axonml_tensor::Tensor;
use crate::module::Module;
/// Specifies how a per-element loss is reduced to the final loss value.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Reduction {
    /// No reduction: return the per-element loss tensor unchanged.
    None,
    /// Average the loss over all elements (the default).
    #[default]
    Mean,
    /// Sum the loss over all elements.
    Sum,
}
/// Mean squared error loss: `(input - target)^2`, reduced per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct MSELoss {
    /// How the per-element squared errors are reduced.
    reduction: Reduction,
}
impl MSELoss {
    /// Creates an MSE loss with the default mean reduction.
    pub fn new() -> Self {
        Self::with_reduction(Reduction::Mean)
    }

    /// Creates an MSE loss using the given reduction mode.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes the squared error between `input` and `target` through the
    /// autograd-aware `Variable` ops, then applies the configured reduction.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let squared_error = input.sub_var(target).pow(2.0);
        match self.reduction {
            Reduction::Sum => squared_error.sum(),
            Reduction::Mean => squared_error.mean(),
            Reduction::None => squared_error,
        }
    }
}
impl Default for MSELoss {
fn default() -> Self {
Self::new()
}
}
impl Module for MSELoss {
    /// Identity pass-through: a loss needs a target, which the single-input
    /// `Module` interface cannot supply, so `forward` returns the input
    /// unchanged. Use [`MSELoss::compute`] to evaluate the actual loss.
    fn forward(&self, input: &Variable) -> Variable {
        input.clone()
    }

    fn name(&self) -> &'static str {
        "MSELoss"
    }
}
/// L1 (mean absolute error) loss: `|input - target|`, reduced per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct L1Loss {
    /// How the per-element absolute errors are reduced.
    reduction: Reduction,
}
impl L1Loss {
    /// Creates an L1 loss with the default mean reduction.
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates an L1 loss using the given reduction mode.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes `|input - target|`, reduced per `self.reduction`.
    ///
    /// The absolute value is built as `relu(d) + relu(-d)` so only clamp
    /// ops are required, which works uniformly across devices.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
        let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
        let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
        let abs_tensor = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
        // Build a graph node only when some operand tracks gradients and
        // gradient recording is globally enabled (no-grad mode).
        let requires_grad = (input.requires_grad() || target.requires_grad()) && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(L1LossBackward {
                // Operand order (input, target) must match the order of the
                // gradients returned by L1LossBackward::apply.
                next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
                diff_tensor,
            });
            Variable::from_operation(abs_tensor, grad_fn, true)
        } else {
            Variable::new(abs_tensor, false)
        };
        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}
impl Default for L1Loss {
fn default() -> Self {
Self::new()
}
}
/// Backward node for [`L1Loss`]; stores the forward difference so the sign
/// can be recomputed during the backward pass.
#[derive(Debug)]
struct L1LossBackward {
    /// Gradient functions of the input and target operands, in that order.
    next_fns: Vec<Option<GradFn>>,
    /// `input - target` captured during the forward pass.
    diff_tensor: Tensor<f32>,
}
impl GradientFunction for L1LossBackward {
    /// Computes d|d|/dd = sign(d) smoothly: |d| is approximated as
    /// sqrt(d^2 + eps), evaluated as exp(0.5 * ln(d^2 + eps)) using only
    /// ln/exp kernels, so `sign(d) = d / |d|` never divides by zero at
    /// d == 0. The input receives `sign * grad_output`; the target receives
    /// its negation.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps_tensor = Tensor::full(self.diff_tensor.shape(), 1e-12);
        // Keep the epsilon tensor on the same device as the saved diff.
        let eps_on_device = if self.diff_tensor.device().is_gpu() {
            eps_tensor.to_device(self.diff_tensor.device()).unwrap()
        } else {
            eps_tensor
        };
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add(&eps_on_device).expect("tensor add failed");
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
        let gi = sign_diff.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "L1LossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Backward node for cross-entropy; stores the softmax probabilities and the
/// target class indices captured during the forward pass.
#[derive(Debug)]
struct CrossEntropyBackward {
    /// Gradient function of the logits operand (single entry).
    next_fns: Vec<Option<GradFn>>,
    /// Softmax probabilities, shape `[batch_size, num_classes]`.
    softmax_probs: Tensor<f32>,
    /// Target class indices stored as f32, shape `[batch_size]`.
    targets: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}
impl GradientFunction for CrossEntropyBackward {
    /// Gradient of fused softmax + NLL w.r.t. the logits:
    /// `softmax - one_hot(target)`, scaled per sample by the incoming
    /// gradient.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let softmax_vec = self.softmax_probs.to_vec();
        let target_vec = self.targets.to_vec();
        let grad_vec = grad_output.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
        // A scalar upstream gradient (e.g. after sum()/mean()) applies to
        // every sample uniformly.
        let is_scalar_grad = grad_vec.len() == 1;
        for b in 0..self.batch_size {
            let grad_scale = if is_scalar_grad {
                grad_vec[0]
            } else if b < grad_vec.len() {
                grad_vec[b]
            } else {
                // NOTE(review): fallback assumes a mean-style 1/batch scale
                // when the upstream gradient is shorter than the batch —
                // confirm this path is reachable and intended.
                1.0 / self.batch_size as f32
            };
            let offset = b * self.num_classes;
            let tc = target_vec[b] as usize;
            for c in 0..self.num_classes {
                let mut g = softmax_vec[offset + c];
                if c == tc {
                    g -= 1.0;
                }
                grad_input[offset + c] = g * grad_scale;
            }
        }
        let mut grad_tensor = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        // Gradients must live on the same device as the forward activations.
        if self.softmax_probs.device().is_gpu() {
            grad_tensor = grad_tensor.to_device(self.softmax_probs.device()).unwrap();
        }
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "CrossEntropyBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Cross-entropy loss over raw logits (fused softmax + negative
/// log-likelihood), reduced per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct CrossEntropyLoss {
    /// How the per-sample losses are reduced.
    reduction: Reduction,
}
impl CrossEntropyLoss {
pub fn new() -> Self {
Self {
reduction: Reduction::Mean,
}
}
pub fn with_reduction(reduction: Reduction) -> Self {
Self { reduction }
}
pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
let input_data = input.data();
let target_data = target.data();
let shape = input_data.shape().to_vec();
let batch_size = shape[0];
let num_classes = shape[1];
#[cfg(feature = "cuda")]
if input_data.device().is_gpu() {
let targets_gpu = if target_data.device().is_gpu() {
target_data.clone()
} else {
target_data.to_device(input_data.device()).unwrap()
};
let (loss_tensor, softmax_tensor) = input_data.cross_entropy_fwd_cuda(&targets_gpu);
let loss_var = if input.requires_grad() {
let grad_fn = GradFn::new(CrossEntropyBackward {
next_fns: vec![input.grad_fn().cloned()],
softmax_probs: softmax_tensor,
targets: targets_gpu,
batch_size,
num_classes,
});
Variable::from_operation(loss_tensor, grad_fn, true)
} else {
Variable::new(loss_tensor, false)
};
return match self.reduction {
Reduction::None => loss_var,
Reduction::Mean => loss_var.mean(),
Reduction::Sum => loss_var.sum(),
};
}
let input_vec = input_data.to_vec();
let target_vec = target_data.to_vec();
let mut losses = vec![0.0f32; batch_size];
let mut softmax_probs_vec = vec![0.0f32; batch_size * num_classes];
let mut target_classes = vec![0usize; batch_size];
for b in 0..batch_size {
let offset = b * num_classes;
let max_val = (0..num_classes)
.map(|c| input_vec[offset + c])
.fold(f32::NEG_INFINITY, f32::max);
let mut sum_exp = 0.0f32;
for c in 0..num_classes {
let exp_val = (input_vec[offset + c] - max_val).exp();
softmax_probs_vec[offset + c] = exp_val;
sum_exp += exp_val;
}
for c in 0..num_classes {
softmax_probs_vec[offset + c] /= sum_exp;
}
let log_sum_exp = max_val + sum_exp.ln();
let tc = target_vec[b] as usize;
target_classes[b] = tc;
losses[b] = log_sum_exp - input_vec[offset + tc];
}
let loss_tensor = Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
let softmax_tensor = Tensor::from_vec(softmax_probs_vec, &[batch_size, num_classes])
.expect("tensor creation failed");
let targets_f32: Vec<f32> = target_classes.iter().map(|&tc| tc as f32).collect();
let targets_tensor =
Tensor::from_vec(targets_f32, &[batch_size]).expect("tensor creation failed");
let loss_var = if input.requires_grad() {
let grad_fn = GradFn::new(CrossEntropyBackward {
next_fns: vec![input.grad_fn().cloned()],
softmax_probs: softmax_tensor,
targets: targets_tensor,
batch_size,
num_classes,
});
Variable::from_operation(loss_tensor, grad_fn, true)
} else {
Variable::new(loss_tensor, false)
};
match self.reduction {
Reduction::None => loss_var,
Reduction::Mean => loss_var.mean(),
Reduction::Sum => loss_var.sum(),
}
}
}
impl Default for CrossEntropyLoss {
fn default() -> Self {
Self::new()
}
}
/// Negative log-likelihood loss: `-input[b, target[b]]` per sample, reduced
/// per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct NLLLoss {
    /// How the per-sample losses are reduced.
    reduction: Reduction,
}
impl NLLLoss {
    /// Creates an NLL loss with the default mean reduction.
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates an NLL loss using the given reduction mode.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes `-input[b, target[b]]` per sample, reduced per
    /// `self.reduction`.
    ///
    /// `input` is `[batch, classes]` — typically log-probabilities from a
    /// log-softmax layer — and `target` is `[batch]` class indices as f32.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_classes = shape[1];
        let target_vec = target_data.to_vec();
        let input_vec = input_data.to_vec();
        let mut losses = vec![0.0f32; batch_size];
        for b in 0..batch_size {
            let tc = target_vec[b] as usize;
            losses[b] = -input_vec[b * num_classes + tc];
        }
        let mut loss_tensor =
            Tensor::from_vec(losses, &[batch_size]).expect("tensor creation failed");
        // Keep the loss on the same device as the input.
        if input_data.device().is_gpu() {
            loss_tensor = loss_tensor.to_device(input_data.device()).unwrap();
        }
        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(NLLLossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                target_tensor: target_data.clone(),
                batch_size,
                num_classes,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };
        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}
impl Default for NLLLoss {
fn default() -> Self {
Self::new()
}
}
/// Backward node for [`NLLLoss`]; stores the target indices captured during
/// the forward pass.
#[derive(Debug)]
struct NLLLossBackward {
    /// Gradient function of the input operand (single entry).
    next_fns: Vec<Option<GradFn>>,
    /// Target class indices stored as f32, shape `[batch_size]`.
    target_tensor: Tensor<f32>,
    batch_size: usize,
    num_classes: usize,
}
impl GradientFunction for NLLLossBackward {
    /// Gradient w.r.t. the input: `-g` at each sample's target class, zero
    /// elsewhere. A scalar upstream gradient (after sum()/mean()) is
    /// broadcast across the batch.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let grad_out_vec = grad_output.to_vec();
        let target_vec = self.target_tensor.to_vec();
        let mut grad_input = vec![0.0f32; self.batch_size * self.num_classes];
        for b in 0..self.batch_size {
            let g = if grad_out_vec.len() == 1 {
                grad_out_vec[0]
            } else {
                grad_out_vec[b]
            };
            let tc = target_vec[b] as usize;
            grad_input[b * self.num_classes + tc] = -g;
        }
        let mut gi = Tensor::from_vec(grad_input, &[self.batch_size, self.num_classes])
            .expect("tensor creation failed");
        // Match the device of the upstream gradient.
        if grad_output.device().is_gpu() {
            gi = gi.to_device(grad_output.device()).unwrap();
        }
        vec![Some(gi)]
    }

    fn name(&self) -> &'static str {
        "NLLLossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Binary cross-entropy loss over probabilities in (0, 1), reduced per
/// [`Reduction`]. For raw logits use [`BCEWithLogitsLoss`] instead.
#[derive(Debug, Clone, Copy)]
pub struct BCELoss {
    /// How the per-element losses are reduced.
    reduction: Reduction,
}
impl BCELoss {
    /// Creates a BCE loss with the default mean reduction.
    pub fn new() -> Self {
        Self {
            reduction: Reduction::Mean,
        }
    }

    /// Creates a BCE loss using the given reduction mode.
    pub fn with_reduction(reduction: Reduction) -> Self {
        Self { reduction }
    }

    /// Computes `-(t * ln(p) + (1 - t) * ln(1 - p))` element-wise, with the
    /// input probabilities clamped to `[eps, 1 - eps]` so `ln` never sees 0.
    ///
    /// Only the input joins the backward graph; the target is treated as a
    /// constant.
    pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
        let input_data = input.data();
        let target_data = target.data();
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&input_data, eps, 1.0 - eps);
        let ln_p = p_clamped.ln();
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let ln_one_minus_p = one_minus_p.ln();
        let one_minus_t = target_data.neg().add_scalar(1.0);
        let term1 = target_data.mul(&ln_p).expect("tensor mul failed");
        let term2 = one_minus_t.mul(&ln_one_minus_p).expect("tensor mul failed");
        let loss_tensor = term1.add(&term2).expect("tensor add failed").neg();
        let requires_grad = input.requires_grad() && is_grad_enabled();
        let loss_var = if requires_grad {
            let grad_fn = GradFn::new(BCELossBackward {
                next_fns: vec![input.grad_fn().cloned()],
                // The unclamped input is saved; the backward pass re-clamps.
                input_tensor: input_data,
                target_tensor: target_data,
            });
            Variable::from_operation(loss_tensor, grad_fn, true)
        } else {
            Variable::new(loss_tensor, false)
        };
        match self.reduction {
            Reduction::None => loss_var,
            Reduction::Mean => loss_var.mean(),
            Reduction::Sum => loss_var.sum(),
        }
    }
}
impl Default for BCELoss {
fn default() -> Self {
Self::new()
}
}
/// Backward node for [`BCELoss`]; stores the forward operands so the
/// gradient can be recomputed.
#[derive(Debug)]
struct BCELossBackward {
    /// Gradient function of the input operand (single entry).
    next_fns: Vec<Option<GradFn>>,
    /// Unclamped input probabilities from the forward pass.
    input_tensor: Tensor<f32>,
    /// Target values from the forward pass.
    target_tensor: Tensor<f32>,
}
impl GradientFunction for BCELossBackward {
    /// dL/dp = (p - t) / (p * (1 - p)), with `p` clamped exactly as in the
    /// forward pass so the denominator never reaches zero.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps = 1e-7f32;
        let p_clamped = axonml_tensor::ops::clamp(&self.input_tensor, eps, 1.0 - eps);
        let p_minus_y = p_clamped
            .sub(&self.target_tensor)
            .expect("tensor sub failed");
        let one_minus_p = p_clamped.neg().add_scalar(1.0);
        let denom = p_clamped.mul(&one_minus_p).expect("tensor mul failed");
        let ratio = p_minus_y.div(&denom).unwrap();
        let grad_tensor = grad_output.mul(&ratio).expect("tensor mul failed");
        vec![Some(grad_tensor)]
    }

    fn name(&self) -> &'static str {
        "BCELossBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Backward node for [`BCEWithLogitsLoss`]; stores the forward operands so
/// the sigmoid can be recomputed.
#[derive(Debug)]
struct BCEWithLogitsBackward {
    /// Gradient function of the input operand (single entry).
    next_fns: Vec<Option<GradFn>>,
    /// Raw logits from the forward pass.
    input_tensor: Tensor<f32>,
    /// Target values from the forward pass.
    target_tensor: Tensor<f32>,
}
impl GradientFunction for BCEWithLogitsBackward {
    /// Gradient of BCE-with-logits w.r.t. the logits: `sigmoid(x) - t`,
    /// scaled by the incoming gradient.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let residual = self
            .input_tensor
            .sigmoid()
            .sub(&self.target_tensor)
            .expect("tensor sub failed");
        let grad_input = grad_output.mul(&residual).expect("tensor mul failed");
        vec![Some(grad_input)]
    }

    fn name(&self) -> &'static str {
        "BCEWithLogitsBackward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Numerically stable sigmoid + binary cross-entropy over raw logits,
/// reduced per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct BCEWithLogitsLoss {
    /// How the per-element losses are reduced.
    reduction: Reduction,
}
impl BCEWithLogitsLoss {
pub fn new() -> Self {
Self {
reduction: Reduction::Mean,
}
}
pub fn with_reduction(reduction: Reduction) -> Self {
Self { reduction }
}
pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
let input_data = input.data();
let target_data = target.data();
let relu_x = axonml_tensor::ops::clamp_min(&input_data, 0.0);
let x_times_t = input_data.mul(&target_data).expect("tensor mul failed");
let neg_x = input_data.neg();
let relu_neg_x = axonml_tensor::ops::clamp_min(&neg_x, 0.0);
let abs_x = relu_x.add(&relu_neg_x).expect("tensor add failed");
let exp_neg_abs = abs_x.neg().exp();
let log_term = exp_neg_abs.add_scalar(1.0).ln();
let loss_tensor = relu_x
.sub(&x_times_t)
.expect("tensor sub failed")
.add(&log_term)
.expect("tensor add failed");
let loss_var = if input.requires_grad() {
let grad_fn = GradFn::new(BCEWithLogitsBackward {
next_fns: vec![input.grad_fn().cloned()],
input_tensor: input_data,
target_tensor: target_data,
});
Variable::from_operation(loss_tensor, grad_fn, true)
} else {
Variable::new(loss_tensor, false)
};
match self.reduction {
Reduction::None => loss_var,
Reduction::Mean => loss_var.mean(),
Reduction::Sum => loss_var.sum(),
}
}
}
impl Default for BCEWithLogitsLoss {
fn default() -> Self {
Self::new()
}
}
/// Backward node for [`SmoothL1Loss`]; stores the forward difference and
/// threshold so the piecewise gradient can be recomputed.
#[derive(Debug)]
struct SmoothL1Backward {
    /// Gradient functions of the input and target operands, in that order.
    next_fns: Vec<Option<GradFn>>,
    /// `input - target` captured during the forward pass.
    diff_tensor: Tensor<f32>,
    /// Threshold where the loss switches from quadratic to linear.
    beta: f32,
    /// Shape of the difference tensor, used to rebuild the zone mask.
    shape: Vec<usize>,
}
impl GradientFunction for SmoothL1Backward {
    /// Piecewise gradient of smooth-L1: `d / beta` inside the quadratic
    /// zone (|d| < beta), `sign(d)` in the linear zone; the target receives
    /// the negated input gradient.
    fn apply(&self, grad_output: &Tensor<f32>) -> Vec<Option<Tensor<f32>>> {
        let eps = 1e-12f32;
        // |d| reconstructed as sqrt(d^2 + eps), evaluated as
        // exp(0.5 * ln(d^2 + eps)) using only ln/exp kernels, so
        // sign(d) = d / |d| never divides by zero.
        let diff_sq = self
            .diff_tensor
            .mul(&self.diff_tensor)
            .expect("tensor mul failed");
        let diff_sq_eps = diff_sq.add_scalar(eps);
        let abs_diff = diff_sq_eps.ln().mul_scalar(0.5).exp();
        let sign_diff = self.diff_tensor.div(&abs_diff).unwrap();
        let grad_l2 = self.diff_tensor.mul_scalar(1.0 / self.beta);
        let grad_l1 = sign_diff;
        // Per-element mask: 1.0 in the quadratic zone, 0.0 in the linear zone.
        let abs_vec = abs_diff.to_vec();
        let beta = self.beta;
        let mask_vec: Vec<f32> = abs_vec
            .iter()
            .map(|&a| if a < beta { 1.0 } else { 0.0 })
            .collect();
        let mut mask = Tensor::from_vec(mask_vec, &self.shape).expect("tensor creation failed");
        if self.diff_tensor.device().is_gpu() {
            mask = mask.to_device(self.diff_tensor.device()).unwrap();
        }
        let inv_mask = mask.neg().add_scalar(1.0);
        // Blend: mask * grad_l2 + (1 - mask) * grad_l1.
        let blended = mask
            .mul(&grad_l2)
            .unwrap()
            .add(&inv_mask.mul(&grad_l1).expect("tensor mul failed"))
            .unwrap();
        let gi = blended.mul(grad_output).unwrap();
        let gt = gi.neg();
        vec![Some(gi), Some(gt)]
    }

    fn name(&self) -> &'static str {
        "SmoothL1Backward"
    }

    fn next_functions(&self) -> &[Option<GradFn>] {
        &self.next_fns
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Smooth L1 (Huber-style) loss: quadratic for `|diff| < beta`, linear
/// beyond, reduced per [`Reduction`].
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
    /// How the per-element losses are reduced.
    reduction: Reduction,
    /// Threshold where the loss switches from quadratic to linear.
    beta: f32,
}
impl SmoothL1Loss {
pub fn new() -> Self {
Self {
reduction: Reduction::Mean,
beta: 1.0,
}
}
pub fn with_beta(beta: f32) -> Self {
Self {
reduction: Reduction::Mean,
beta,
}
}
pub fn compute(&self, input: &Variable, target: &Variable) -> Variable {
let input_data = input.data();
let target_data = target.data();
let diff_tensor = input_data.sub(&target_data).expect("tensor sub failed");
let shape = diff_tensor.shape().to_vec();
let relu_diff = axonml_tensor::ops::clamp_min(&diff_tensor, 0.0);
let relu_neg_diff = axonml_tensor::ops::clamp_min(&diff_tensor.neg(), 0.0);
let abs_diff = relu_diff.add(&relu_neg_diff).expect("tensor add failed");
let diff_sq = diff_tensor.mul(&diff_tensor).expect("tensor mul failed");
let l2_loss = diff_sq.mul_scalar(0.5 / self.beta);
let l1_loss = abs_diff.add_scalar(-0.5 * self.beta);
let abs_vec = abs_diff.to_vec();
let beta = self.beta;
let mask_vec: Vec<f32> = abs_vec
.iter()
.map(|&a| if a < beta { 1.0 } else { 0.0 })
.collect();
let mut mask = Tensor::from_vec(mask_vec, &shape).expect("tensor creation failed");
if diff_tensor.device().is_gpu() {
mask = mask.to_device(diff_tensor.device()).unwrap();
}
let inv_mask = mask.neg().add_scalar(1.0);
let loss_tensor = mask
.mul(&l2_loss)
.unwrap()
.add(&inv_mask.mul(&l1_loss).expect("tensor add failed"))
.unwrap();
let loss_var = if input.requires_grad() || target.requires_grad() {
let grad_fn = GradFn::new(SmoothL1Backward {
next_fns: vec![input.grad_fn().cloned(), target.grad_fn().cloned()],
diff_tensor,
beta: self.beta,
shape,
});
Variable::from_operation(loss_tensor, grad_fn, true)
} else {
Variable::new(loss_tensor, false)
};
match self.reduction {
Reduction::None => loss_var,
Reduction::Mean => loss_var.mean(),
Reduction::Sum => loss_var.sum(),
}
}
}
impl Default for SmoothL1Loss {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // MSE of identical tensors is exactly zero.
    #[test]
    fn test_mse_loss() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.0).abs() < 1e-6);
    }

    // Uniform offset of 1.0 gives mean squared error of 1.0.
    #[test]
    fn test_mse_loss_nonzero() {
        let loss_fn = MSELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 3.0, 4.0], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 1.0).abs() < 1e-6);
    }

    // Cross-entropy of non-degenerate logits is strictly positive.
    #[test]
    fn test_cross_entropy_loss() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![2.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
    }

    // BCE at p = 0.5 equals ln(2) ≈ 0.693 for either label.
    #[test]
    fn test_bce_loss() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, 0.5], &[2]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    // Backward through cross-entropy: correct classes get negative
    // gradients, wrong classes positive (softmax - one_hot sign pattern).
    #[test]
    fn test_cross_entropy_gradient_flow() {
        use axonml_autograd::backward;
        let input = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 0.5, 2.5, 0.3], &[2, 3])
                .expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0, 1.0], &[2]).expect("tensor creation failed"),
            false,
        );
        let loss_fn = CrossEntropyLoss::new();
        let loss = loss_fn.compute(&input, &target);
        let loss_val = loss.data().to_vec()[0];
        assert!(loss_val > 0.0, "Loss should be positive, got {}", loss_val);
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);
        let grad = input
            .grad()
            .expect("Input should have gradient after backward");
        let grad_vec = grad.to_vec();
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(
            grad_norm > 1e-10,
            "Gradient should be non-zero, got norm {}",
            grad_norm
        );
        assert_eq!(grad.shape(), &[2, 3]);
        assert!(
            grad_vec[0] < 0.0,
            "Gradient for correct class should be negative"
        );
        assert!(
            grad_vec[4] < 0.0,
            "Gradient for correct class should be negative"
        );
        assert!(
            grad_vec[1] > 0.0,
            "Gradient for wrong class should be positive"
        );
        assert!(
            grad_vec[2] > 0.0,
            "Gradient for wrong class should be positive"
        );
    }

    // A confidently-correct prediction drives the loss toward zero.
    #[test]
    fn test_cross_entropy_perfect_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![10.0, -10.0, -10.0], &[1, 3]).expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.001,
            "Perfect prediction should have near-zero loss"
        );
    }

    // Equal logits over C classes give a loss of exactly ln(C).
    #[test]
    fn test_cross_entropy_uniform_prediction() {
        let loss_fn = CrossEntropyLoss::new();
        let num_classes = 16;
        let input = Variable::new(
            Tensor::from_vec(vec![0.0; num_classes], &[1, num_classes])
                .expect("tensor creation failed"),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![0.0], &[1]).expect("tensor creation failed"),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let expected = (num_classes as f32).ln();
        let actual = loss.data().to_vec()[0];
        assert!(
            (actual - expected).abs() < 0.01,
            "Uniform logits should give ln(C)={}, got {}",
            expected,
            actual,
        );
    }

    // Backward through BCE-with-logits: gradient is sigmoid(x) - t, so
    // positive labels pull negative and vice versa.
    #[test]
    fn test_bce_with_logits_gradient_flow() {
        use axonml_autograd::backward;
        let input = Variable::new(
            Tensor::from_vec(vec![0.5, -0.5, 1.0, -1.0], &[4]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).expect("tensor creation failed"),
            false,
        );
        let loss_fn = BCEWithLogitsLoss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);
        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 4);
        assert!(grad_vec[0] < 0.0);
        assert!(grad_vec[1] > 0.0);
    }

    // Backward through smooth-L1 produces a non-zero, correctly shaped grad.
    #[test]
    fn test_smooth_l1_gradient_flow() {
        use axonml_autograd::backward;
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 5.0], &[3]).expect("tensor creation failed"),
            true,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.5, 1.5, 1.5], &[3]).expect("tensor creation failed"),
            false,
        );
        let loss_fn = SmoothL1Loss::new();
        let loss = loss_fn.compute(&input, &target);
        assert!(loss.data().to_vec()[0] > 0.0);
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);
        let grad = input.grad().expect("Input should have gradient");
        let grad_vec = grad.to_vec();
        assert_eq!(grad_vec.len(), 3);
        let grad_norm: f32 = grad_vec.iter().map(|g| g * g).sum();
        assert!(grad_norm > 1e-10);
    }

    // MSE mean over [3,1] vs [1,1]: loss = (4+0)/2 = 2; dL/dx = 2(x-t)/n,
    // so grad = [2, 0].
    #[test]
    fn test_mse_loss_gradient_correctness() {
        use axonml_autograd::backward;
        let input = Variable::new(Tensor::from_vec(vec![3.0, 1.0], &[2]).unwrap(), true);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::new().compute(&input, &target);
        assert!(
            (loss.data().to_vec()[0] - 2.0).abs() < 1e-5,
            "MSE should be 2.0"
        );
        let ones = Tensor::from_vec(vec![1.0], &loss.shape()).unwrap();
        backward(&loss, &ones);
        let grad = input.grad().expect("Should have gradient");
        let gv = grad.to_vec();
        assert!(
            (gv[0] - 2.0).abs() < 0.1,
            "Grad[0] should be ~2.0, got {}",
            gv[0]
        );
        assert!(gv[1].abs() < 0.1, "Grad[1] should be ~0.0, got {}", gv[1]);
    }

    // Sum reduction: 1^2 + 3^2 = 10.
    #[test]
    fn test_mse_loss_reduction_sum() {
        let input = Variable::new(Tensor::from_vec(vec![2.0, 4.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 1.0], &[2]).unwrap(), false);
        let loss = MSELoss::with_reduction(Reduction::Sum).compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 10.0).abs() < 1e-5);
    }

    // L1 mean of |0|, |3|, |-1| = 4/3.
    #[test]
    fn test_l1_loss_basic() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 5.0, 3.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 4.0], &[3]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 4.0 / 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_l1_loss_zero() {
        let input = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 2.0], &[2]).unwrap(), false);
        let loss = L1Loss::new().compute(&input, &target);
        assert!(
            loss.data().to_vec()[0].abs() < 1e-6,
            "Identical inputs should give 0 loss"
        );
    }

    #[test]
    fn test_bce_loss_perfect_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.999, 0.001], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] < 0.01,
            "Perfect prediction should have near-zero loss"
        );
    }

    #[test]
    fn test_bce_loss_worst_prediction() {
        let loss_fn = BCELoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.001, 0.999], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!(
            loss.data().to_vec()[0] > 3.0,
            "Worst prediction should have high loss"
        );
    }

    // Extreme logits must not overflow thanks to the log-sum-exp form.
    #[test]
    fn test_bce_with_logits_numerical_stability() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(
            Tensor::from_vec(vec![100.0, -100.0, 50.0, -50.0], &[4]).unwrap(),
            false,
        );
        let target = Variable::new(
            Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0], &[4]).unwrap(),
            false,
        );
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(
            val.is_finite(),
            "Loss should be finite for large logits, got {}",
            val
        );
        assert!(val >= 0.0, "BCE loss should be non-negative");
    }

    // Zero logits correspond to p = 0.5, so loss = ln(2) ≈ 0.693.
    #[test]
    fn test_bce_with_logits_zero_logits() {
        let loss_fn = BCEWithLogitsLoss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0], &[2]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.693).abs() < 0.01);
    }

    // Reduction::None preserves the per-element loss shape.
    #[test]
    fn test_bce_with_logits_reduction_none() {
        let loss_fn = BCEWithLogitsLoss::with_reduction(Reduction::None);
        let input = Variable::new(Tensor::from_vec(vec![0.0, 0.0, 0.0], &[3]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.0, 0.0, 1.0], &[3]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert_eq!(loss.shape().len(), 1);
        assert_eq!(loss.shape()[0], 3);
    }

    // |d| = 0.3 < beta: quadratic zone, loss = 0.5 * 0.09 = 0.045.
    #[test]
    fn test_smooth_l1_small_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![1.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![1.3], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 0.045).abs() < 0.01);
    }

    // |d| = 5 >= beta: linear zone, loss = 5 - 0.5 = 4.5.
    #[test]
    fn test_smooth_l1_large_error() {
        let loss_fn = SmoothL1Loss::new();
        let input = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![5.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        assert!((loss.data().to_vec()[0] - 4.5).abs() < 0.1);
    }

    // Mean reduction: duplicating a sample must not change the mean loss.
    #[test]
    fn test_cross_entropy_batch_independence() {
        let loss_fn = CrossEntropyLoss::new();
        let input1 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1], &[1, 3]).unwrap(),
            false,
        );
        let target1 = Variable::new(Tensor::from_vec(vec![0.0], &[1]).unwrap(), false);
        let loss1 = loss_fn.compute(&input1, &target1).data().to_vec()[0];
        let input2 = Variable::new(
            Tensor::from_vec(vec![2.0, 1.0, 0.1, 2.0, 1.0, 0.1], &[2, 3]).unwrap(),
            false,
        );
        let target2 = Variable::new(Tensor::from_vec(vec![0.0, 0.0], &[2]).unwrap(), false);
        let loss2 = loss_fn.compute(&input2, &target2).data().to_vec()[0];
        assert!(
            (loss1 - loss2).abs() < 1e-5,
            "Duplicated batch should give same loss: {} vs {}",
            loss1,
            loss2
        );
    }

    #[test]
    fn test_cross_entropy_high_class_count() {
        let n_classes = 100;
        let mut logits = vec![0.0f32; n_classes];
        logits[42] = 5.0;
        let loss_fn = CrossEntropyLoss::new();
        let input = Variable::new(Tensor::from_vec(logits, &[1, n_classes]).unwrap(), false);
        let target = Variable::new(Tensor::from_vec(vec![42.0], &[1]).unwrap(), false);
        let loss = loss_fn.compute(&input, &target);
        let val = loss.data().to_vec()[0];
        assert!(val.is_finite(), "Should handle 100 classes");
        assert!(val < 1.0, "Correct class should have low loss, got {}", val);
    }
}