use crate::loss::common::ReductionType;
use crate::utils::{function_context, safe_log_prob, validate_elementwise_shapes, validate_range};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
pub fn cross_entropy(
input: &Tensor,
target: &Tensor,
weight: Option<&Tensor>,
reduction: &str,
ignore_index: Option<i64>,
label_smoothing: f64,
) -> TorshResult<Tensor> {
if label_smoothing > 0.0 {
return cross_entropy_with_label_smoothing(
input,
target,
label_smoothing,
weight,
reduction,
ignore_index,
);
}
let dim = (input.shape().ndim() - 1) as i32;
let log_probs = input.log_softmax(dim)?;
nll_loss(&log_probs, target, weight, reduction, ignore_index)
}
pub fn nll_loss(
input: &Tensor,
target: &Tensor,
weight: Option<&Tensor>,
reduction: &str,
ignore_index: Option<i64>,
) -> TorshResult<Tensor> {
if input.ndim() != 2 || target.ndim() != 1 {
return Err(TorshError::UnsupportedOperation {
op: "nll_loss with >2D input".to_string(),
dtype: "tensor".to_string(),
});
}
if weight.is_some() || ignore_index.is_some() {
return Err(TorshError::UnsupportedOperation {
op: "nll_loss with weight or ignore_index".to_string(),
dtype: "tensor".to_string(),
});
}
let batch_size = input.shape().dims()[0];
let mut losses = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let target_class = target.get(&[i])? as usize;
let loss_val = -input.get(&[i, target_class])?;
losses.push(loss_val);
}
let loss_tensor = Tensor::from_vec(losses, &[batch_size])?;
match reduction {
"none" => Ok(loss_tensor),
"mean" => loss_tensor.mean(None, false),
"sum" => loss_tensor.sum(),
_ => Err(TorshError::InvalidArgument(format!(
"Unknown reduction: {}",
reduction
))),
}
}
pub fn binary_cross_entropy(
input: &Tensor,
target: &Tensor,
weight: Option<&Tensor>,
reduction: ReductionType,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
let log_input = safe_log_prob(input, None)?;
let one_minus_input = input.neg()?.add_scalar(1.0)?;
let log_one_minus_input = safe_log_prob(&one_minus_input, None)?;
let positive_loss = target.mul(&log_input)?;
let one_minus_target = target.neg()?.add_scalar(1.0)?;
let negative_loss = one_minus_target.mul(&log_one_minus_input)?;
let mut loss = positive_loss.add(&negative_loss)?.neg()?;
if let Some(w) = weight {
validate_elementwise_shapes(&loss, w)?;
loss = loss.mul(w)?;
}
reduction.apply(loss)
}
pub fn binary_cross_entropy_with_logits(
input: &Tensor,
target: &Tensor,
weight: Option<&Tensor>,
reduction: ReductionType,
pos_weight: Option<&Tensor>,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
let zero = Tensor::zeros_like(input)?;
let max_input = input.maximum(&zero)?;
let input_target = input.mul(target)?;
let abs_input = input.abs()?;
let log_term = abs_input.neg()?.exp()?.add_scalar(1.0)?.log()?;
let mut loss = max_input.sub(&input_target)?.add(&log_term)?;
if let Some(pos_w) = pos_weight {
let pos_weight_term = target.mul(pos_w)?.add_scalar(1.0)?.sub(target)?;
loss = loss.mul(&pos_weight_term)?;
}
if let Some(w) = weight {
validate_elementwise_shapes(&loss, w)?;
loss = loss.mul(w)?;
}
reduction.apply(loss)
}
pub fn multi_margin_loss(
input: &Tensor,
target: &Tensor,
p: i64,
margin: f32,
weight: Option<&Tensor>,
reduction: ReductionType,
) -> TorshResult<Tensor> {
let context = function_context("multi_margin_loss");
if input.ndim() != 2 || target.ndim() != 1 {
return Err(TorshError::config_error_with_context(
"multi_margin_loss expects 2D input and 1D target",
&context,
));
}
if p != 1 && p != 2 {
return Err(TorshError::config_error_with_context(
"multi_margin_loss only supports p=1 or p=2",
&context,
));
}
let batch_size = input.shape().dims()[0];
let num_classes = input.shape().dims()[1];
let mut losses = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let target_class = target.get(&[i])? as usize;
let target_score = input.get(&[i, target_class])?;
let mut sample_loss = 0.0;
for j in 0..num_classes {
if j != target_class {
let score_j = input.get(&[i, j])?;
let margin_violation = margin - target_score + score_j;
if margin_violation > 0.0 {
sample_loss += if p == 1 {
margin_violation
} else {
margin_violation.powi(2)
};
}
}
}
if let Some(w) = weight {
let class_weight = w.get(&[target_class])?;
sample_loss *= class_weight;
}
losses.push(sample_loss / (num_classes - 1) as f32);
}
let loss_tensor = Tensor::from_vec(losses, &[batch_size])?;
reduction.apply(loss_tensor)
}
pub fn focal_loss(
input: &Tensor,
target: &Tensor,
alpha: f32,
gamma: f32,
reduction: ReductionType,
) -> TorshResult<Tensor> {
validate_range(alpha, 0.0, 1.0, "alpha", "focal_loss")?;
validate_range(gamma, 0.0, 5.0, "gamma", "focal_loss")?;
let dim = (input.shape().ndim() - 1) as i32;
let probs = input.softmax(dim)?;
let log_probs = probs.log()?;
let batch_size = target.shape().dims()[0];
let mut focal_losses = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let target_class = target.get(&[i])? as usize;
let p_t = probs.get(&[i, target_class])?;
let log_p_t = log_probs.get(&[i, target_class])?;
let focal_weight = alpha * (1.0 - p_t).powf(gamma);
let focal_loss = -focal_weight * log_p_t;
focal_losses.push(focal_loss);
}
let loss_tensor = Tensor::from_vec(focal_losses, &[batch_size])?;
reduction.apply(loss_tensor)
}
pub fn cross_entropy_with_label_smoothing(
input: &Tensor,
target: &Tensor,
label_smoothing: f64,
weight: Option<&Tensor>,
reduction: &str,
_ignore_index: Option<i64>,
) -> TorshResult<Tensor> {
if label_smoothing < 0.0 || label_smoothing >= 1.0 {
return Err(TorshError::InvalidArgument(
"label_smoothing must be in [0.0, 1.0)".to_string(),
));
}
let num_classes = input.shape().dims()[input.shape().ndim() - 1];
let smoothing_value = label_smoothing as f32 / num_classes as f32;
let confidence = 1.0 - label_smoothing as f32;
let dim = (input.shape().ndim() - 1) as i32;
let log_probs = input.log_softmax(dim)?;
let batch_size = target.shape().dims()[0];
let mut smooth_targets = vec![smoothing_value; batch_size * num_classes];
for i in 0..batch_size {
let target_class = target.get(&[i])? as usize;
smooth_targets[i * num_classes + target_class] = confidence + smoothing_value;
}
let smooth_target_tensor = Tensor::from_vec(smooth_targets, &[batch_size, num_classes])?;
let loss = log_probs
.mul(&smooth_target_tensor)?
.neg()?
.sum_dim(&[1], false)?;
let loss = if let Some(w) = weight {
let mut weighted_losses = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let target_class = target.get(&[i])? as usize;
let class_weight = w.get(&[target_class])?;
let sample_loss = loss.get(&[i])?;
weighted_losses.push(sample_loss * class_weight);
}
Tensor::from_vec(weighted_losses, &[batch_size])?
} else {
loss.squeeze(1)?
};
match reduction {
"none" => Ok(loss),
"mean" => loss.mean(None, false),
"sum" => loss.sum(),
_ => Err(TorshError::InvalidArgument(format!(
"Unknown reduction: {}",
reduction
))),
}
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::device::DeviceType;
use torsh_tensor::creation::from_vec;
#[test]
fn test_binary_cross_entropy_basic() -> TorshResult<()> {
let input = from_vec(vec![0.8, 0.2, 0.9], &[3], DeviceType::Cpu)?;
let target = from_vec(vec![1.0, 0.0, 1.0], &[3], DeviceType::Cpu)?;
let loss = binary_cross_entropy(&input, &target, None, ReductionType::Mean)?;
let loss_value = loss.item()?;
assert!(loss_value > 0.0);
Ok(())
}
#[test]
fn test_focal_loss_basic() -> TorshResult<()> {
let input = from_vec(vec![1.0, 2.0, 0.5, 3.0, 1.5, 0.8], &[2, 3], DeviceType::Cpu)?;
let target = from_vec(vec![1.0, 2.0], &[2], DeviceType::Cpu)?;
let loss = focal_loss(&input, &target, 0.25, 2.0, ReductionType::Mean)?;
let loss_value = loss.item()?;
assert!(loss_value > 0.0);
Ok(())
}
#[test]
fn test_cross_entropy_simple() -> TorshResult<()> {
let input = from_vec(vec![1.0, 2.0, 0.5], &[1, 3], DeviceType::Cpu)?;
let target = from_vec(vec![1.0], &[1], DeviceType::Cpu)?;
let loss = cross_entropy(&input, &target, None, "mean", None, 0.0)?;
let loss_value = loss.item()?;
assert!(loss_value > 0.0);
Ok(())
}
}