use crate::autograd::Variable;
use crate::tensor::{Result, Tensor, TensorError};
pub fn mse_loss(pred: &Variable, target: &Variable) -> Result<Variable> {
let result = pred.data().mse_loss(&target.data(), 1)?; Ok(Variable::wrap(result))
}
pub fn cross_entropy_loss(pred: &Variable, target: &Variable) -> Result<Variable> {
let pred_shape = pred.shape();
if pred_shape.len() != 2 {
return Err(TensorError::new("cross_entropy_loss: pred must be 2D [batch, classes]"));
}
let result = pred.data().cross_entropy_loss(
&target.data(), 1, -100, 0.0, )?;
Ok(Variable::wrap(result))
}
pub fn bce_loss(pred: &Variable, target: &Variable) -> Result<Variable> {
let result = pred.data().bce_loss(&target.data(), 1)?; Ok(Variable::wrap(result))
}
pub fn bce_with_logits_loss(pred: &Variable, target: &Variable) -> Result<Variable> {
let result = pred.data().bce_with_logits_loss(&target.data(), 1)?; Ok(Variable::wrap(result))
}
pub fn l1_loss(pred: &Variable, target: &Variable) -> Result<Variable> {
let result = pred.data().l1_loss(&target.data(), 1)?; Ok(Variable::wrap(result))
}
pub fn smooth_l1_loss(pred: &Variable, target: &Variable, beta: f64) -> Result<Variable> {
if beta <= 0.0 {
return Err(TensorError::new("smooth_l1_loss: beta must be positive"));
}
let result = pred.data().smooth_l1_loss(&target.data(), 1, beta)?; Ok(Variable::wrap(result))
}
pub fn kl_div_loss(input: &Variable, target: &Variable) -> Result<Variable> {
let shape = input.shape();
if shape.is_empty() {
return Err(TensorError::new("kl_div_loss: input must be at least 1D"));
}
let batch = shape[0] as f64;
let sum = input.data().kl_div_loss(&target.data(), 2, false)?; let result = sum.mul_scalar(1.0 / batch)?;
Ok(Variable::wrap(result))
}
pub fn nll_loss(input: &Variable, target: &Variable) -> Result<Variable> {
let result = input.data().nll_loss(&target.data(), 1, -100)?; Ok(Variable::wrap(result))
}
pub fn ctc_loss(
log_probs: &Variable,
targets: &Variable,
input_lengths: &Variable,
target_lengths: &Variable,
blank: i64,
) -> Result<Variable> {
let result = log_probs.data().ctc_loss(
&targets.data(), &input_lengths.data(), &target_lengths.data(),
blank, 1, )?;
Ok(Variable::wrap(result))
}
pub fn focal_loss(
pred: &Variable,
target: &Variable,
alpha: f64,
gamma: f64,
) -> Result<Variable> {
let log_p = pred.log_softmax(-1)?;
let p = log_p.exp()?;
let target_expanded = target.unsqueeze(-1)?;
let idx = target_expanded.data();
let log_pt = log_p.gather(1, &idx)?.squeeze(-1)?;
let pt = p.gather(1, &idx)?.squeeze(-1)?;
let one_minus_pt = pt.neg()?.add_scalar(1.0)?;
let focal_weight = one_minus_pt.pow_scalar(gamma)?.mul_scalar(alpha)?;
let loss = focal_weight.mul(&log_pt)?.neg()?;
loss.mean()
}
pub fn triplet_margin_loss(
anchor: &Variable,
positive: &Variable,
negative: &Variable,
margin: f64,
) -> Result<Variable> {
let d_pos = anchor.sub(positive)?.pow_scalar(2.0)?.sum_dim(1, false)?.sqrt()?;
let d_neg = anchor.sub(negative)?.pow_scalar(2.0)?.sum_dim(1, false)?.sqrt()?;
let diff = d_pos.sub(&d_neg)?.add_scalar(margin)?;
let zero = Variable::wrap(Tensor::zeros(&diff.shape(), crate::tensor::TensorOptions {
dtype: crate::tensor::DType::Float32,
device: diff.data().device(),
})?);
diff.maximum(&zero)?.mean()
}
pub fn cosine_embedding_loss(
x1: &Variable,
x2: &Variable,
label: &Variable,
margin: f64,
) -> Result<Variable> {
let cos = x1.cosine_similarity(x2, 1, 1e-8)?;
let ones = Variable::wrap(Tensor::ones(&cos.shape(), crate::tensor::TensorOptions {
dtype: crate::tensor::DType::Float32,
device: cos.data().device(),
})?);
let loss_pos = ones.sub(&cos)?;
let cos_minus_margin = cos.add_scalar(-margin)?;
let zero = Variable::wrap(Tensor::zeros(&cos.shape(), crate::tensor::TensorOptions {
dtype: crate::tensor::DType::Float32,
device: cos.data().device(),
})?);
let loss_neg = cos_minus_margin.maximum(&zero)?;
let label_pos = label.add_scalar(1.0)?.mul_scalar(0.5)?;
let label_neg = label.neg()?.add_scalar(1.0)?.mul_scalar(0.5)?;
let loss = label_pos.mul(&loss_pos)?.add(&label_neg.mul(&loss_neg)?)?;
loss.mean()
}
pub fn hinge_embedding_loss(
input: &Variable,
label: &Variable,
margin: f64,
) -> Result<Variable> {
let margin_minus_x = input.neg()?.add_scalar(margin)?;
let zero = Variable::wrap(Tensor::zeros(&input.shape(), crate::tensor::TensorOptions {
dtype: crate::tensor::DType::Float32,
device: input.data().device(),
})?);
let loss_neg = margin_minus_x.maximum(&zero)?;
let label_pos = label.add_scalar(1.0)?.mul_scalar(0.5)?;
let label_neg = label.neg()?.add_scalar(1.0)?.mul_scalar(0.5)?;
let loss = label_pos.mul(input)?.add(&label_neg.mul(&loss_neg)?)?;
loss.mean()
}
pub fn margin_ranking_loss(
x1: &Variable,
x2: &Variable,
label: &Variable,
margin: f64,
) -> Result<Variable> {
let diff = x1.sub(x2)?;
let neg_label_diff = label.neg()?.mul(&diff)?;
let shifted = neg_label_diff.add_scalar(margin)?;
let zero = Variable::wrap(Tensor::zeros(&shifted.shape(), crate::tensor::TensorOptions {
dtype: crate::tensor::DType::Float32,
device: shifted.data().device(),
})?);
shifted.maximum(&zero)?.mean()
}
pub fn poisson_nll_loss(input: &Variable, target: &Variable, log_input: bool) -> Result<Variable> {
let loss = if log_input {
input.exp()?.sub(&target.mul(input)?)?
} else {
let log_in = input.add_scalar(1e-8)?.log()?;
input.sub(&target.mul(&log_in)?)?
};
loss.mean()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::{Tensor, test_device, test_opts};
#[test]
fn test_nll_loss() {
let log_probs = Tensor::from_f32(
&[-0.5, -1.5, -2.5, -1.0, -0.5, -2.0],
&[2, 3], test_device(),
).unwrap();
let targets = Tensor::from_i64(&[0, 1], &[2], test_device()).unwrap();
let input = Variable::new(log_probs, false);
let target = Variable::new(targets, false);
let loss = nll_loss(&input, &target).unwrap();
assert!(loss.item().unwrap() > 0.0);
}
#[test]
fn test_focal_loss() {
let logits = Tensor::from_f32(
&[2.0, 0.5, -1.0, -0.5, 1.5, 0.3],
&[2, 3], test_device(),
).unwrap();
let targets = Tensor::from_i64(&[0, 1], &[2], test_device()).unwrap();
let pred = Variable::new(logits, false);
let target = Variable::new(targets, false);
let loss = focal_loss(&pred, &target, 0.25, 2.0).unwrap();
assert!(loss.item().unwrap() > 0.0);
}
#[test]
fn test_triplet_margin_loss() {
let a = Variable::new(Tensor::from_f32(&[1.0, 0.0], &[1, 2], test_device()).unwrap(), false);
let p = Variable::new(Tensor::from_f32(&[0.9, 0.1], &[1, 2], test_device()).unwrap(), false);
let n = Variable::new(Tensor::from_f32(&[0.0, 1.0], &[1, 2], test_device()).unwrap(), false);
let loss = triplet_margin_loss(&a, &p, &n, 2.0).unwrap();
assert!(loss.item().unwrap() > 0.0);
}
#[test]
fn test_poisson_nll_loss() {
let input = Variable::new(Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap(), false);
let target = Variable::new(Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap(), false);
let loss = poisson_nll_loss(&input, &target, true).unwrap();
assert!(loss.item().unwrap() > 0.0);
}
#[test]
fn test_mse_loss() {
let pred = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap(), true,
);
let target = Variable::new(
Tensor::from_f32(&[1.5, 2.5, 3.5], &[3], test_device()).unwrap(), false,
);
let loss = mse_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap() - 0.25).abs() < 1e-5);
}
#[test]
fn test_mse_loss_zero() {
let pred = Variable::new(
Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 2.0], &[2], test_device()).unwrap(), false,
);
let loss = mse_loss(&pred, &target).unwrap();
assert!(loss.item().unwrap().abs() < 1e-6);
}
#[test]
fn test_cross_entropy_loss_class_indices() {
let logits = Variable::new(
Tensor::from_f32(&[10.0, 0.0, 0.0, 0.0, 0.0, 10.0], &[2, 3], test_device()).unwrap(),
true,
);
let targets = Variable::new(
Tensor::from_i64(&[0, 2], &[2], test_device()).unwrap(), false,
);
let loss = cross_entropy_loss(&logits, &targets).unwrap();
assert!(loss.item().unwrap() < 0.1);
}
#[test]
fn test_cross_entropy_loss_wrong_predictions() {
let logits = Variable::new(
Tensor::from_f32(&[0.0, 0.0, 10.0, 10.0, 0.0, 0.0], &[2, 3], test_device()).unwrap(),
false,
);
let targets = Variable::new(
Tensor::from_i64(&[0, 2], &[2], test_device()).unwrap(), false,
);
let loss = cross_entropy_loss(&logits, &targets).unwrap();
assert!(loss.item().unwrap() > 5.0);
}
#[test]
fn test_cross_entropy_loss_gradient() {
let logits = Variable::new(
Tensor::from_f32(&[2.0, 1.0, 0.1, 0.5, 1.5, 0.3], &[2, 3], test_device()).unwrap(),
true,
);
let targets = Variable::new(
Tensor::from_i64(&[0, 1], &[2], test_device()).unwrap(), false,
);
let loss = cross_entropy_loss(&logits, &targets).unwrap();
loss.backward().unwrap();
assert!(logits.grad().is_some());
}
#[test]
fn test_bce_loss() {
let pred = Variable::new(
Tensor::from_f32(&[0.9, 0.1, 0.8, 0.2], &[4], test_device()).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 0.0, 1.0, 0.0], &[4], test_device()).unwrap(), false,
);
let loss = bce_loss(&pred, &target).unwrap();
assert!(loss.item().unwrap() < 0.3);
}
#[test]
fn test_bce_with_logits_loss() {
let pred = Variable::new(
Tensor::from_f32(&[5.0, -5.0, 5.0, -5.0], &[4], test_device()).unwrap(), true,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 0.0, 1.0, 0.0], &[4], test_device()).unwrap(), false,
);
let loss = bce_with_logits_loss(&pred, &target).unwrap();
assert!(loss.item().unwrap() < 0.1);
}
#[test]
fn test_bce_with_logits_gradient() {
let pred = Variable::new(
Tensor::from_f32(&[0.5, -0.5], &[2], test_device()).unwrap(), true,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 0.0], &[2], test_device()).unwrap(), false,
);
let loss = bce_with_logits_loss(&pred, &target).unwrap();
loss.backward().unwrap();
assert!(pred.grad().is_some());
}
#[test]
fn test_l1_loss() {
let pred = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], test_device()).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[1.5, 2.5, 3.5], &[3], test_device()).unwrap(), false,
);
let loss = l1_loss(&pred, &target).unwrap();
assert!((loss.item().unwrap() - 0.5).abs() < 1e-5);
}
#[test]
fn test_smooth_l1_loss() {
let pred = Variable::new(
Tensor::from_f32(&[1.0, 5.0], &[2], test_device()).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[1.5, 2.0], &[2], test_device()).unwrap(), false,
);
let loss = smooth_l1_loss(&pred, &target, 1.0).unwrap();
assert!((loss.item().unwrap() - 1.3125).abs() < 1e-4);
}
#[test]
fn test_smooth_l1_loss_negative_beta() {
let pred = Variable::new(
Tensor::from_f32(&[1.0], &[1], test_device()).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[2.0], &[1], test_device()).unwrap(), false,
);
assert!(smooth_l1_loss(&pred, &target, -1.0).is_err());
}
#[test]
fn test_kl_div_loss() {
let logits = Tensor::from_f32(&[2.0, 1.0, 0.1, 0.5, 1.5, 0.3], &[2, 3], test_device()).unwrap();
let log_probs = Variable::new(logits.log_softmax(1).unwrap(), false);
let target = Variable::new(
Tensor::from_f32(&[1.0/3.0; 6], &[2, 3], test_device()).unwrap(), false,
);
let loss = kl_div_loss(&log_probs, &target).unwrap();
assert!(loss.item().unwrap() >= -1e-5);
}
#[test]
fn test_ctc_loss() {
let dev = test_device();
let log_probs = Variable::new(
Tensor::randn(&[5, 1, 4], test_opts()).unwrap()
.log_softmax(2).unwrap(),
false,
);
let targets = Variable::new(
Tensor::from_i64(&[1, 2, 3], &[1, 3], dev).unwrap(), false,
);
let input_lengths = Variable::new(
Tensor::from_i64(&[5], &[1], dev).unwrap(), false,
);
let target_lengths = Variable::new(
Tensor::from_i64(&[3], &[1], dev).unwrap(), false,
);
let loss = ctc_loss(&log_probs, &targets, &input_lengths, &target_lengths, 0).unwrap();
assert!(loss.item().unwrap() > 0.0);
}
#[test]
fn test_cosine_embedding_loss_similar() {
let dev = test_device();
let x1 = Variable::new(Tensor::from_f32(&[1.0, 0.0, 0.0], &[1, 3], dev).unwrap(), false);
let x2 = Variable::new(Tensor::from_f32(&[1.0, 0.0, 0.0], &[1, 3], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[1.0], &[1], dev).unwrap(), false);
let loss = cosine_embedding_loss(&x1, &x2, &label, 0.0).unwrap();
assert!(loss.item().unwrap().abs() < 1e-4);
}
#[test]
fn test_cosine_embedding_loss_dissimilar() {
let dev = test_device();
let x1 = Variable::new(Tensor::from_f32(&[1.0, 0.0, 0.0], &[1, 3], dev).unwrap(), false);
let x2 = Variable::new(Tensor::from_f32(&[-1.0, 0.0, 0.0], &[1, 3], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[-1.0], &[1], dev).unwrap(), false);
let loss = cosine_embedding_loss(&x1, &x2, &label, 0.0).unwrap();
assert!(loss.item().unwrap().abs() < 1e-4);
}
#[test]
fn test_hinge_embedding_loss() {
let dev = test_device();
let input = Variable::new(Tensor::from_f32(&[0.5, 0.3], &[2], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[1.0, 1.0], &[2], dev).unwrap(), false);
let loss = hinge_embedding_loss(&input, &label, 1.0).unwrap();
assert!((loss.item().unwrap() - 0.4).abs() < 1e-4);
}
#[test]
fn test_hinge_embedding_loss_negative() {
let dev = test_device();
let input = Variable::new(Tensor::from_f32(&[0.5, 2.0], &[2], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[-1.0, -1.0], &[2], dev).unwrap(), false);
let loss = hinge_embedding_loss(&input, &label, 1.0).unwrap();
assert!((loss.item().unwrap() - 0.25).abs() < 1e-4);
}
#[test]
fn test_margin_ranking_loss() {
let dev = test_device();
let x1 = Variable::new(Tensor::from_f32(&[1.0, 3.0], &[2], dev).unwrap(), false);
let x2 = Variable::new(Tensor::from_f32(&[2.0, 1.0], &[2], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[1.0, 1.0], &[2], dev).unwrap(), false);
let loss = margin_ranking_loss(&x1, &x2, &label, 0.0).unwrap();
assert!((loss.item().unwrap() - 0.5).abs() < 1e-4);
}
#[test]
fn test_margin_ranking_loss_with_margin() {
let dev = test_device();
let x1 = Variable::new(Tensor::from_f32(&[3.0], &[1], dev).unwrap(), false);
let x2 = Variable::new(Tensor::from_f32(&[1.0], &[1], dev).unwrap(), false);
let label = Variable::new(Tensor::from_f32(&[1.0], &[1], dev).unwrap(), false);
let loss = margin_ranking_loss(&x1, &x2, &label, 3.0).unwrap();
assert!((loss.item().unwrap() - 1.0).abs() < 1e-4);
}
#[test]
fn test_poisson_nll_loss_no_log() {
let dev = test_device();
let input = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], dev).unwrap(), false,
);
let target = Variable::new(
Tensor::from_f32(&[1.0, 2.0, 3.0], &[3], dev).unwrap(), false,
);
let loss = poisson_nll_loss(&input, &target, false).unwrap();
assert!(loss.item().unwrap().is_finite());
}
#[test]
fn test_focal_loss_reduces_to_ce_at_gamma_zero() {
let dev = test_device();
let logits = Variable::new(
Tensor::from_f32(&[2.0, 0.5, -1.0, -0.5, 1.5, 0.3], &[2, 3], dev).unwrap(),
false,
);
let targets = Variable::new(
Tensor::from_i64(&[0, 1], &[2], dev).unwrap(), false,
);
let fl = focal_loss(&logits, &targets, 1.0, 0.0).unwrap().item().unwrap();
let ce = cross_entropy_loss(&logits, &targets).unwrap().item().unwrap();
assert!((fl - ce).abs() < 1e-4, "focal(gamma=0, alpha=1) = {fl} != ce = {ce}");
}
#[test]
fn test_triplet_margin_loss_zero_when_far() {
let dev = test_device();
let a = Variable::new(Tensor::from_f32(&[1.0, 0.0], &[1, 2], dev).unwrap(), false);
let p = Variable::new(Tensor::from_f32(&[1.0, 0.1], &[1, 2], dev).unwrap(), false);
let n = Variable::new(Tensor::from_f32(&[0.0, 10.0], &[1, 2], dev).unwrap(), false);
let loss = triplet_margin_loss(&a, &p, &n, 0.1).unwrap();
assert!(loss.item().unwrap() < 1e-4);
}
}