use candle::{Result, Tensor};
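
/// The negative log likelihood loss.
///
/// Arguments
///
/// * `inp`: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
///   of categories. This is expected to contain log probabilities.
/// * `target`: The ground truth labels as an index tensor of dimension `N` (one class index per
///   batch element).
///
/// The resulting tensor is a scalar containing the average value over the batch.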
pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    let b_sz = match target.dims() {
        &[b_sz] => b_sz,
        dims => candle::bail!("the target tensor should have a single dimension ({dims:?})"),
    };
    match inp.dims() {
        &[inp_b_sz, _] => {
            if inp_b_sz != b_sz {
                candle::bail!("batch size mismatch between inp ({inp_b_sz}) and target ({b_sz})")
            }
        }
        dims => candle::bail!("the input tensor should have two dimensions ({dims:?})"),
    }
    inp.gather(&target.unsqueeze(1)?, 1)?
        .sum_all()?
        .affine(-1f64 / b_sz as f64, 0.)
}
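
/// The cross-entropy loss.
///
/// Arguments
///
/// * `inp`: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
///   of categories. This is expected to contain raw, unnormalized logits; `log_softmax` is applied
///   internally before calling [`nll`].
/// * `target`: The ground truth labels as an index tensor of dimension `N`.
///
/// The resulting tensor is a scalar containing the average value over the batch.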
pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    if inp.rank() != 2 {
        candle::bail!("cross_entropy expects an input tensor of rank 2")
    }
    let inp = crate::ops::log_softmax(inp, 1)?;
    nll(&inp, target)
}
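
/// The mean squared error loss.
///
/// Both `inp` and `target` are expected to have the same shape. The resulting tensor is a scalar
/// containing the average of the element-wise squared differences.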
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    (inp - target)?.sqr()?.mean_all()
}
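
/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * `inp`: The input tensor containing raw, unnormalized logits; a sigmoid is applied internally.
/// * `target`: The ground truth labels as a tensor of the same shape as `inp`, with values in
///   `[0, 1]`.
///
/// The resulting tensor is a scalar containing the average value over all elements.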
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    let inp = crate::ops::sigmoid(inp)?;
    // target * log(p), with p = sigmoid(inp)
    let left_side = target * inp.log()?;
    // (1 - target) * log(1 - p)
    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
    let loss = left_side? + right_side?;
    let loss = loss?.neg()?.mean_all()?;
    Ok(loss)
}
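
/// The Huber loss.
///
/// Arguments
///
/// * `inp`: The predicted values.
/// * `target`: The ground truth values, a tensor of the same shape as `inp`.
/// * `delta`: The threshold at which the per-element loss switches from quadratic to linear.
///
/// For each element, the loss is `0.5 * diff^2` when `|diff| <= delta` and
/// `delta * |diff| - 0.5 * delta^2` otherwise, where `diff = inp - target`. The resulting tensor
/// is a scalar containing the average value over all elements.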
pub fn huber(inp: &Tensor, target: &Tensor, delta: f64) -> Result<Tensor> {
    if inp.dims() != target.dims() {
        candle::bail!(
            "input and target must have the same shape, got inp: {:?}, target: {:?}",
            inp.dims(),
            target.dims()
        );
    }
    let diff = (inp - target)?;
    let abs_diff = diff.abs()?;
    // Select the quadratic branch where |diff| <= delta, the linear branch elsewhere.
    let mask = abs_diff.le(delta)?;
    let squared_loss = ((&diff * &diff)? * 0.5)?;
    let linear_loss = ((abs_diff * delta)? - 0.5 * delta.powi(2))?;
    let loss = mask.where_cond(&squared_loss, &linear_loss)?;
    loss.mean_all()
}
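
// A minimal usage sketch: it checks that `cross_entropy` agrees with `nll` applied to
// log-softmax outputs, since `cross_entropy` is defined as exactly that composition above.
// It assumes a CPU device is available; the tensor values and the test name are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;
    use candle::Device;

    #[test]
    fn cross_entropy_matches_nll_of_log_softmax() -> Result<()> {
        let dev = Device::Cpu;
        // Two samples, three classes, raw unnormalized logits.
        let logits = Tensor::new(&[[1.0f32, 2.0, 0.5], [0.1, 0.2, 3.0]], &dev)?;
        // One class index per sample.
        let target = Tensor::new(&[1u32, 2], &dev)?;
        let ce = cross_entropy(&logits, &target)?;
        let reference = nll(&crate::ops::log_softmax(&logits, 1)?, &target)?;
        let diff = (ce - reference)?.abs()?.to_scalar::<f32>()?;
        assert!(diff < 1e-6);
        Ok(())
    }
}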