use crate::loss::common::ReductionType;
use crate::utils::{
function_context, validate_elementwise_shapes, validate_non_empty, validate_positive,
};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Mean squared error loss: reduces `(input - target)^2` with `reduction`.
///
/// Both tensors must be non-empty and elementwise-compatible in shape.
/// Any tensor-op failure is wrapped with this function's context string.
pub fn mse_loss(input: &Tensor, target: &Tensor, reduction: ReductionType) -> TorshResult<Tensor> {
    let context = function_context("mse_loss");
    validate_non_empty(input, &context)?;
    validate_non_empty(target, &context)?;
    validate_elementwise_shapes(input, target)?;

    // Single wrapper so each failing step carries the same context.
    let wrap = |msg: String| TorshError::config_error_with_context(&msg, &context);

    let residual = input
        .sub(target)
        .map_err(|e| wrap(format!("Failed to compute input - target: {}", e)))?;
    let squared_error = residual
        .pow_scalar(2.0)
        .map_err(|e| wrap(format!("Failed to square differences: {}", e)))?;
    reduction
        .apply(squared_error)
        .map_err(|e| wrap(format!("Failed to apply reduction: {}", e)))
}
pub fn l1_loss(input: &Tensor, target: &Tensor, reduction: ReductionType) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
let diff = input.sub(target)?;
let abs_diff = diff.abs()?;
reduction.apply(abs_diff)
}
pub fn smooth_l1_loss(
input: &Tensor,
target: &Tensor,
reduction: ReductionType,
beta: f32,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
validate_positive(beta, "beta", "smooth_l1_loss")?;
let diff = input.sub(target)?;
let abs_diff = diff.abs()?;
let mask = abs_diff.lt_scalar(beta)?;
let l2_component = diff.pow_scalar(2.0)?.div_scalar(2.0 * beta)?;
let l1_component = abs_diff.sub_scalar(0.5 * beta)?;
let smooth_l1 = l2_component.where_tensor(&mask, &l1_component)?;
reduction.apply(smooth_l1)
}
/// Negative log-likelihood of a Poisson distribution.
///
/// With rate `lambda`, the per-element loss (up to a target-only constant) is
/// `lambda - target * ln(lambda)`.
///
/// * `log_input_is_log` — when true, `log_input` already holds `ln(lambda)`;
///   when false, `log_input` holds the raw rate `lambda`.
/// * `full` — also add the Stirling approximation of `ln(target!)`:
///   `target*ln(target) - target + 0.5*ln(2*pi*target)` (with `eps` stabilizing
///   the logs).
/// * `eps` — additive fudge preventing `ln(0) = -inf` on zero rates/targets.
/// * `_size_average` — accepted for API compatibility; unused (use `reduction`).
pub fn poisson_nll_loss(
    log_input: &Tensor,
    target: &Tensor,
    log_input_is_log: bool,
    full: bool,
    _size_average: Option<bool>,
    eps: f32,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    validate_elementwise_shapes(log_input, target)?;
    // Derive both the log-rate and the raw rate from whichever form we were
    // given. Fix over the previous version: when the raw rate is supplied,
    // take ln(rate + eps) so a zero rate no longer produces -inf, and use the
    // raw rate directly instead of the lossy exp(ln(rate)) round-trip.
    let (log_rate, rate) = if log_input_is_log {
        (log_input.clone(), log_input.exp()?)
    } else {
        (log_input.add_scalar(eps)?.log()?, log_input.clone())
    };
    let mut loss = rate.sub(&target.mul(&log_rate)?)?;
    if full {
        // Stirling term: target*ln(target+eps) - target + 0.5*ln(2*pi*(target+eps)).
        // NOTE(review): applied to every element; PyTorch masks it to
        // target > 1 — confirm whether that masking is wanted here.
        let target_safe = target.add_scalar(eps)?;
        let log_target = target_safe.log()?;
        let stirling = target
            .mul(&log_target)?
            .sub(target)?
            .add_scalar(0.5 * (2.0 * std::f32::consts::PI).ln())?
            .add(&log_target.mul_scalar(0.5)?)?;
        loss = loss.add(&stirling)?;
    }
    reduction.apply(loss)
}
/// Negative log-likelihood of a Gaussian with per-element mean and variance.
///
/// Per element: `0.5 * ((input - target)^2 / (var + eps) + ln(var + eps))`,
/// plus the constant `0.5 * ln(2*pi)` when `full` is set. `eps` keeps the
/// division and logarithm finite for (near-)zero variances.
pub fn gaussian_nll_loss(
    input: &Tensor,
    target: &Tensor,
    var: &Tensor,
    full: bool,
    eps: f32,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    validate_elementwise_shapes(input, target)?;
    validate_elementwise_shapes(input, var)?;
    let stabilized_var = var.add_scalar(eps)?;
    let residual = input.sub(target)?;
    // Squared residual normalized by the (stabilized) variance.
    let normalized_sq = residual.pow_scalar(2.0)?.div(&stabilized_var)?;
    let base = normalized_sq.add(&stabilized_var.log()?)?.mul_scalar(0.5)?;
    let loss = if full {
        // Constant normalization term of the Gaussian density.
        base.add_scalar(0.5 * (2.0 * std::f32::consts::PI).ln())?
    } else {
        base
    };
    reduction.apply(loss)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    #[test]
    fn test_mse_loss_basic() -> TorshResult<()> {
        // Residuals are (-0.5, -0.5, 0.5) → mean of squares = 0.25.
        let predictions = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
        let labels = from_vec(vec![1.5, 2.5, 2.5], &[3], DeviceType::Cpu)?;
        let result = mse_loss(&predictions, &labels, ReductionType::Mean)?.item()?;
        assert!((result - 0.25).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_l1_loss_basic() -> TorshResult<()> {
        // |residual| is 0.5 everywhere → mean absolute error = 0.5.
        let predictions = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
        let labels = from_vec(vec![1.5, 2.5, 2.5], &[3], DeviceType::Cpu)?;
        let result = l1_loss(&predictions, &labels, ReductionType::Mean)?.item()?;
        assert!((result - 0.5).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_smooth_l1_loss_basic() -> TorshResult<()> {
        // With beta = 1: two quadratic elements (0.005 each) and one linear
        // element (1.0 - 0.5 = 0.5) → mean = 0.17.
        let predictions = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
        let labels = from_vec(vec![1.1, 2.1, 4.0], &[3], DeviceType::Cpu)?;
        let result = smooth_l1_loss(&predictions, &labels, ReductionType::Mean, 1.0)?.item()?;
        assert!((result - 0.17).abs() < 1e-2);
        Ok(())
    }

    #[test]
    fn test_mse_loss_zero_when_equal() -> TorshResult<()> {
        // Identical tensors must yield zero loss.
        let predictions = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
        let labels = predictions.clone();
        let result = mse_loss(&predictions, &labels, ReductionType::Mean)?.item()?;
        assert!(result.abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_l1_loss_zero_when_equal() -> TorshResult<()> {
        // Identical tensors must yield zero loss.
        let predictions = from_vec(vec![1.0, 2.0, 3.0], &[3], DeviceType::Cpu)?;
        let labels = predictions.clone();
        let result = l1_loss(&predictions, &labels, ReductionType::Mean)?.item()?;
        assert!(result.abs() < 1e-6);
        Ok(())
    }
}