use super::helpers::{all_dims, batch_size, prepare_targets};
use crate::error::{Error, Result};
use numr::autograd::{
Var, var_add, var_gather, var_log_softmax, var_mean, var_mul_scalar, var_neg, var_reshape,
};
use numr::dtype::DType;
use numr::ops::{ActivationOps, BinaryOps, IndexingOps, ReduceOps, ScalarOps, UnaryOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
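/// Mean cross-entropy loss between `logits` and integer class `targets`.
///
/// `logits` must have at least 2 dims with the vocabulary on the last axis;
/// all leading dims are flattened into a single batch of `n` rows. The result
/// is the scalar `mean(-log_softmax(logits)[i, targets[i]])` over those rows.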
pub fn cross_entropy_loss<R, C>(client: &C, logits: &Var<R>, targets: &Tensor<R>) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ IndexingOps<R>,
R::Client: ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ IndexingOps<R>,
{
let ndim = logits.shape().len();
if ndim < 2 {
return Err(Error::InvalidArgument {
arg: "logits",
reason: format!("expected at least 2 dims, got {ndim}"),
});
}
let vocab_size = logits.shape()[ndim - 1];
let n = batch_size(logits.shape());
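    // Log-softmax over the vocab dim, then flatten to [n, vocab_size] so a
    // single gather along dim 1 selects each row's target log-probability.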
let log_probs = var_log_softmax(logits, -1, client).map_err(Error::Numr)?;
let log_probs_flat = var_reshape(&log_probs, &[n, vocab_size]).map_err(Error::Numr)?;
let targets_expanded = prepare_targets(targets, n)?;
let selected =
var_gather(&log_probs_flat, 1, &targets_expanded, client).map_err(Error::Numr)?;
let neg_selected = var_neg(&selected, client).map_err(Error::Numr)?;
let loss = var_mean(
&neg_selected,
&all_dims(neg_selected.shape().len()),
false,
client,
)
.map_err(Error::Numr)?;
Ok(loss)
}
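/// Cross-entropy loss with label smoothing.
///
/// Blends the standard negative log-likelihood with a uniform term over the
/// vocabulary: `loss = (1 - smoothing) * nll + smoothing * uniform`, where
/// `uniform = -mean(log_softmax(logits))` over every (row, class) entry.
/// `smoothing` is typically in `[0, 1)`; `smoothing == 0.0` delegates to
/// [`cross_entropy_loss`].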
pub fn cross_entropy_loss_smooth<R, C>(
client: &C,
logits: &Var<R>,
targets: &Tensor<R>,
smoothing: f64,
) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R>
+ ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ IndexingOps<R>,
R::Client: ActivationOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ ReduceOps<R>
+ ScalarOps<R>
+ IndexingOps<R>,
{
if smoothing == 0.0 {
return cross_entropy_loss(client, logits, targets);
}
let ndim = logits.shape().len();
if ndim < 2 {
return Err(Error::InvalidArgument {
arg: "logits",
reason: format!("expected at least 2 dims, got {ndim}"),
});
}
let vocab_size = logits.shape()[ndim - 1];
let n = batch_size(logits.shape());
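    // Same log-softmax + flatten + gather pipeline as `cross_entropy_loss`.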
let log_probs = var_log_softmax(logits, -1, client).map_err(Error::Numr)?;
let log_probs_flat = var_reshape(&log_probs, &[n, vocab_size]).map_err(Error::Numr)?;
let targets_expanded = prepare_targets(targets, n)?;
let selected =
var_gather(&log_probs_flat, 1, &targets_expanded, client).map_err(Error::Numr)?;
    // nll = -mean of the gathered target log-probabilities.
    let mean_selected = var_mean(&selected, &all_dims(selected.shape().len()), false, client)
        .map_err(Error::Numr)?;
    let nll = var_neg(&mean_selected, client).map_err(Error::Numr)?;
    // uniform = -mean over all (row, class) log-probabilities, i.e. the
    // cross-entropy against the uniform distribution over the vocabulary.
    let mean_all = var_mean(
        &log_probs_flat,
        &all_dims(log_probs_flat.shape().len()),
        false,
        client,
    )
    .map_err(Error::Numr)?;
    let uniform_loss = var_neg(&mean_all, client).map_err(Error::Numr)?;
let nll_scaled = var_mul_scalar(&nll, 1.0 - smoothing, client).map_err(Error::Numr)?;
let uni_scaled = var_mul_scalar(&uniform_loss, smoothing, client).map_err(Error::Numr)?;
let loss = var_add(&nll_scaled, &uni_scaled, client).map_err(Error::Numr)?;
Ok(loss)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
use numr::runtime::cpu::CpuRuntime;
#[test]
fn test_cross_entropy_basic() {
let (client, device) = cpu_setup();
        let logits = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f32, 1.0, 0.1, 0.1, 2.0, 1.0], &[2, 3], &device),
            true,
        );
let targets = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &device);
let loss = cross_entropy_loss(&client, &logits, &targets).unwrap();
assert_eq!(loss.shape(), &[] as &[usize]);
let val: Vec<f32> = loss.tensor().to_vec();
assert!(
val[0] < 1.0,
"loss={} should be < 1.0 for correct predictions",
val[0]
);
}
#[test]
fn test_cross_entropy_wrong_predictions() {
let (client, device) = cpu_setup();
        let logits = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[0.1f32, 0.1, 2.0, 2.0, 0.1, 0.1], &[2, 3], &device),
            false,
        );
let targets = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &device);
let loss = cross_entropy_loss(&client, &logits, &targets).unwrap();
let val: Vec<f32> = loss.tensor().to_vec();
assert!(
val[0] > 1.0,
"loss={} should be > 1.0 for wrong predictions",
val[0]
);
}
#[test]
fn test_label_smoothing_reduces_confidence() {
let (client, device) = cpu_setup();
let logits = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32, 1.0, 0.1, 0.1, 2.0, 1.0], &[2, 3], &device),
false,
);
let targets = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &device);
let loss_no_smooth = cross_entropy_loss(&client, &logits, &targets).unwrap();
let loss_smooth = cross_entropy_loss_smooth(&client, &logits, &targets, 0.1).unwrap();
let v0: Vec<f32> = loss_no_smooth.tensor().to_vec();
let vs: Vec<f32> = loss_smooth.tensor().to_vec();
assert!(
vs[0] > v0[0],
"smoothed loss {} should be > unsmoothed {}",
vs[0],
v0[0]
);
}
#[test]
fn test_label_smoothing_zero_is_ce() {
let (client, device) = cpu_setup();
let logits = Var::new(
Tensor::<CpuRuntime>::from_slice(&[2.0f32, 1.0, 0.1, 0.1, 2.0, 1.0], &[2, 3], &device),
false,
);
let targets = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &device);
let loss_ce = cross_entropy_loss(&client, &logits, &targets).unwrap();
let loss_smooth = cross_entropy_loss_smooth(&client, &logits, &targets, 0.0).unwrap();
let v0: Vec<f32> = loss_ce.tensor().to_vec();
let vs: Vec<f32> = loss_smooth.tensor().to_vec();
assert!(
(v0[0] - vs[0]).abs() < 1e-6,
"smoothing=0 should match CE: {} vs {}",
v0[0],
vs[0]
);
}
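
    // Additional sketches, using only the APIs exercised above.

    #[test]
    fn test_rejects_1d_logits() {
        // 1-D logits must hit the `ndim < 2` validation before any compute.
        let (client, device) = cpu_setup();
        let logits = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f32, 1.0, 0.1], &[3], &device),
            false,
        );
        let targets = Tensor::<CpuRuntime>::from_slice(&[0i64], &[1], &device);
        assert!(cross_entropy_loss(&client, &logits, &targets).is_err());
    }

    #[test]
    fn test_more_smoothing_more_loss() {
        // The smoothed loss is linear in `smoothing`:
        //   (1 - s) * nll + s * uniform.
        // For confidently-correct logits the uniform term exceeds the nll
        // term, so a larger `smoothing` must yield a larger loss.
        let (client, device) = cpu_setup();
        let logits = Var::new(
            Tensor::<CpuRuntime>::from_slice(&[2.0f32, 1.0, 0.1, 0.1, 2.0, 1.0], &[2, 3], &device),
            false,
        );
        let targets = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &device);
        let lo = cross_entropy_loss_smooth(&client, &logits, &targets, 0.1).unwrap();
        let hi = cross_entropy_loss_smooth(&client, &logits, &targets, 0.3).unwrap();
        let v_lo: Vec<f32> = lo.tensor().to_vec();
        let v_hi: Vec<f32> = hi.tensor().to_vec();
        assert!(
            v_hi[0] > v_lo[0],
            "smoothing=0.3 loss {} should exceed smoothing=0.1 loss {}",
            v_hi[0],
            v_lo[0]
        );
    }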
}