use syntaxdot_tch_ext::tensor::SumDim;
use tch::{Kind, Reduction, Tensor};

use crate::TransformerError;

/// Reduce a tensor of per-element losses to a single loss value.
trait Reduce {
    type Error;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error>;
}
impl Reduce for Reduction {
    type Error = TransformerError;

    fn reduce(&self, t: &Tensor) -> Result<Tensor, Self::Error> {
        match self {
            Reduction::None => Ok(t.shallow_clone()),
            Reduction::Mean => Ok(t.f_mean(t.kind())?),
            Reduction::Sum => Ok(t.f_sum(t.kind())?),
            // `Other` carries a raw libtorch reduction value and is not
            // constructed by this crate.
            Reduction::Other(_) => unimplemented!(),
        }
    }
}
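
// A small illustration (hypothetical values, not part of the original source):
// `Reduce` collapses a tensor of per-token losses according to a
// `tch::Reduction`. With `Reduction::None` the tensor is returned unchanged
// as a shallow clone.
#[allow(dead_code)]
fn reduce_sketch() -> Result<Tensor, TransformerError> {
    let per_token_losses = Tensor::of_slice(&[0.5f32, 1.5, 1.0]);
    // Yields a scalar tensor containing the mean loss (here 1.0).
    Reduction::Mean.reduce(&per_token_losses)
}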
/// Cross-entropy loss function.
pub struct CrossEntropyLoss {
    ignore_index: i64,
    label_smoothing: Option<f64>,
    reduction: Reduction,
}

impl CrossEntropyLoss {
    /// Construct the cross-entropy loss function.
    ///
    /// Targets with the value `ignore_index` are excluded from the loss.
    /// If `label_smoothing` is set to *p*, the target class receives
    /// probability *1 - p* and the remaining probability mass *p* is
    /// distributed evenly over the other classes.
    pub fn new(ignore_index: i64, label_smoothing: Option<f64>, reduction: Reduction) -> Self {
        CrossEntropyLoss {
            ignore_index,
            label_smoothing,
            reduction,
        }
    }
    /// Compute the cross-entropy loss.
    ///
    /// `logits` must have the shape `[batch_size, n_classes]` and `targets`
    /// the shape `[batch_size]`. If `target_mask` is given, smoothed
    /// probability mass is only distributed over the classes permitted by
    /// the mask.
    pub fn forward(
        &self,
        logits: &Tensor,
        targets: &Tensor,
        target_mask: Option<&Tensor>,
    ) -> Result<Tensor, TransformerError> {
        let (_, n_classes) = logits.size2()?;
        let log_probs = logits.f_log_softmax(-1, logits.kind())?;

        match self.label_smoothing {
            Some(label_smoothing) => {
                // Tokens with the ignore index do not contribute to the loss.
                let token_mask = targets.f_ne(self.ignore_index)?;

                // Replace ignored targets by class 0, so that they are valid
                // scatter indices. Their losses are masked out below.
                let targets_non_negative = targets.f_where_scalarother(&token_mask, 0)?;

                let smoothed_targets = tch::no_grad(|| match target_mask {
                    None => {
                        // Distribute the smoothed probability mass evenly
                        // over all classes except the target class.
                        Tensor::f_full_like(&log_probs, label_smoothing / (n_classes - 1) as f64)?
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                    Some(target_mask) => {
                        // Distribute the smoothed probability mass evenly
                        // over the permitted classes except the target class.
                        let batch_probs = label_smoothing
                            / target_mask
                                .f_sum_dim(-1, false, Kind::Float)?
                                .f_sub_scalar(1)?;
                        Tensor::f_zeros_like(&log_probs)?
                            .f_add_(&batch_probs.f_unsqueeze(-1)?)?
                            .f_mul(&target_mask.to_kind(Kind::Float))?
                            .f_scatter_value(
                                1,
                                &targets_non_negative.f_unsqueeze(1)?,
                                1. - label_smoothing,
                            )
                    }
                })?;

                // Per-token loss is the cross-entropy between the smoothed
                // target distribution and the predicted log-probabilities.
                let losses = (smoothed_targets.f_neg()?.f_mul(&log_probs)?).f_sum_dim(
                    -1,
                    false,
                    log_probs.kind(),
                )?;

                Ok(self
                    .reduction
                    .reduce(&losses.f_masked_select(&token_mask)?)?)
            }
            None => Ok(log_probs.f_nll_loss::<&Tensor>(
                targets,
                None,
                self.reduction,
                self.ignore_index,
            )?),
        }
    }
}
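
// A minimal usage sketch (the tensor values are illustrative and not part of
// the original source): two tokens, three classes, 10% label smoothing, and
// ignore index -1.
#[allow(dead_code)]
fn cross_entropy_usage_sketch() -> Result<Tensor, TransformerError> {
    // Logits have the shape [batch_size, n_classes], targets [batch_size].
    let logits = Tensor::of_slice(&[0.2f32, 0.5, -0.1, 1.0, -0.3, 0.4]).view([2, 3]);
    let targets = Tensor::of_slice(&[1i64, 2]).view([2]);
    let loss_fn = CrossEntropyLoss::new(-1, Some(0.1), Reduction::Mean);
    // Averages the smoothed cross-entropy over both tokens.
    loss_fn.forward(&logits, &targets, None)
}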
/// Normalization applied to mean squared error losses.
pub enum MSELossNormalization {
    /// Mean of the squared errors.
    Mean,

    /// Squared error of each batch item, normalized by the squared L2 norm
    /// of its target, averaged over the batch.
    SquaredL2Norm,
}

/// Mean squared error loss.
pub struct MSELoss {
    normalization: MSELossNormalization,
}

impl MSELoss {
    pub fn new(normalization: MSELossNormalization) -> Self {
        MSELoss { normalization }
    }

    pub fn forward(&self, prediction: &Tensor, target: &Tensor) -> Result<Tensor, tch::TchError> {
        let reduction = match self.normalization {
            MSELossNormalization::Mean => Reduction::Mean,
            MSELossNormalization::SquaredL2Norm => Reduction::None,
        };

        let loss = prediction.f_mse_loss(target, reduction);

        match self.normalization {
            MSELossNormalization::Mean => loss,
            MSELossNormalization::SquaredL2Norm => {
                // Normalize the loss of each batch item by the squared L2
                // norm of its target, then average over the batch.
                let norm = target.f_frobenius_norm(&[1], true)?.f_square()?;
                let (batch_size, _) = target.size2()?;
                loss?
                    .f_div(&norm)?
                    .f_sum(Kind::Float)?
                    .f_div_scalar(batch_size)
            }
        }
    }
}
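
// A minimal usage sketch (illustrative values, not from the original source):
// each row's squared error is divided by the squared L2 norm of its target
// row, and the result is averaged over the batch.
#[allow(dead_code)]
fn mse_usage_sketch() -> Result<Tensor, tch::TchError> {
    let prediction = Tensor::of_slice(&[0.1f32, 0.9, 0.8, 0.2]).view([2, 2]);
    let target = Tensor::of_slice(&[0.0f32, 1.0, 1.0, 0.0]).view([2, 2]);
    MSELoss::new(MSELossNormalization::SquaredL2Norm).forward(&prediction, &target)
}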
#[cfg(test)]
mod tests {
    use std::convert::TryInto;

    use approx::assert_abs_diff_eq;
    use ndarray::{array, ArrayD};
    use tch::{Reduction, Tensor};

    use super::MSELoss;
    use crate::loss::CrossEntropyLoss;

    #[test]
    fn cross_entropy_loss_without_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, None, Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.432653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss.forward(&logits, &targets, None).unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn cross_entropy_with_label_smoothing_and_mask() {
        let logits = Tensor::of_slice(&[-1., -1., 1., -1., -1.]).view([1, 5]);
        let target_mask = Tensor::of_slice(&[true, false, true, false, true]).view([1, 5]);
        let targets = Tensor::of_slice(&[2i64]).view([1]);
        let cross_entropy_loss = CrossEntropyLoss::new(-1, Some(0.1), Reduction::None);
        let loss: ArrayD<f32> = (&cross_entropy_loss
            .forward(&logits, &targets, Some(&target_mask))
            .unwrap())
            .try_into()
            .unwrap();
        assert_abs_diff_eq!(loss, array![0.632653].into_dyn(), epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_averaging() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([1, 4]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([1, 4]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::Mean);
        let loss = &mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.375f32, epsilon = 1e-6);
    }

    #[test]
    fn mse_loss_with_squared_l2_norm() {
        let prediction = Tensor::of_slice(&[-0.5, -0.5, 0.0, 1.0]).view([2, 2]);
        let target = Tensor::of_slice(&[-1.0, 0.0, 1.0, 1.0]).view([2, 2]);
        let mse_loss = MSELoss::new(super::MSELossNormalization::SquaredL2Norm);
        let loss = mse_loss.forward(&prediction, &target).unwrap();
        assert_abs_diff_eq!(f32::from(loss), 0.5, epsilon = 1e-6);
    }
}