//! Loss Calculations
//!
use candle::{Result, Tensor};

5/// The negative log likelihood loss.
6///
7/// Arguments
8///
9/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
10/// of categories. This is expected to contain log probabilities.
11/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
12///
13/// The resulting tensor is a scalar containing the average value over the batch.
14pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
15 let b_sz = match target.dims() {
16 &[b_sz] => b_sz,
17 dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"),
18 };
19 match inp.dims() {
20 &[inp_b_sz, _] => {
21 if inp_b_sz != b_sz {
22 candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})")
23 }
24 }
25 dims => candle::bail!("the target tensor should have two dimensions ({dims:?})"),
26 }
27 inp.gather(&target.unsqueeze(1)?, 1)?
28 .sum_all()?
29 .affine(-1f64 / b_sz as f64, 0.)
30}
31
32/// The cross-entropy loss.
33///
34/// Arguments
35///
36/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
37/// of categories. This is expected to raw logits.
38/// * [target]: The ground truth labels as a tensor of u32 of dimension `N`.
39///
40/// The resulting tensor is a scalar containing the average value over the batch.
41pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
42 if inp.rank() != 2 {
43 candle::bail!("cross_entropy expects an input tensor of rank 2")
44 }
45 let inp = crate::ops::log_softmax(inp, 1)?;
46 nll(&inp, target)
47}
48
49/// The mean squared error loss.
50pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
51 (inp - target)?.sqr()?.mean_all()
52}
53
54/// The binary cross-entropy with logit loss.
55///
56/// Arguments
57///
58/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
59/// of categories. This is expected to raw logits.
60/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
61/// of categories.
62///
63/// The resulting tensor is a scalar containing the average value over the batch.
64pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
65 let inp = crate::ops::sigmoid(inp)?;
66
67 let left_side = target * inp.log()?;
68 let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
69
70 let loss = left_side? + right_side?;
71 let loss = loss?.neg()?.mean_all()?;
72
73 Ok(loss)
74}
75
76/// HuberLoss
77///
78/// A robust loss function that combines `MAE` and `MSE` losses:
79///
80/// - When the absolute element-wise error is less than `delta`, it uses a squared term (MSE loss).
81/// - When the absolute element-wise error is greater than or equal to `delta`, it uses a linear term (MAE loss scaled by `delta`).
82/// # Formula
83///
84/// HuberLoss =
85/// ```tex
86/// 0.5(x_n - y_n)^2, & |x_n - y_n| < delta
87/// delta(|x_n - y_n| - 0.5delta), & |x_n - y_n| >= delta
88/// ```
89pub fn huber(inp: &Tensor, target: &Tensor, delta: f64) -> Result<Tensor> {
90 if inp.dims() != target.dims() {
91 candle::bail!(
92 "input and target must have the same shape, got inp: {:?}, target: {:?}",
93 inp.dims(),
94 target.dims()
95 );
96 }
97 let diff = (inp - target)?;
98 let abs_diff = diff.abs()?;
99 let mask = abs_diff.le(delta)?;
100 let squared_loss = ((&diff * &diff)? * 0.5)?;
101 let linear_loss = ((abs_diff * delta)? - 0.5 * delta.powi(2))?;
102 let loss = mask.where_cond(&squared_loss, &linear_loss)?;
103 loss.mean_all()
104}