pub mod classification;
pub mod common;
pub mod information;
pub mod regression;
pub mod similarity;
pub mod specialized;
pub use common::ReductionType;
pub use regression::{gaussian_nll_loss, l1_loss, mse_loss, poisson_nll_loss, smooth_l1_loss};
pub use classification::{
binary_cross_entropy, binary_cross_entropy_with_logits, cross_entropy,
cross_entropy_with_label_smoothing, focal_loss, multi_margin_loss, nll_loss,
};
pub use similarity::{
contrastive_loss, cosine_embedding_loss, hinge_embedding_loss, margin_ranking_loss,
triplet_margin_loss, triplet_margin_with_distance_loss,
};
pub use information::{
cross_entropy_continuous, entropy_loss, js_divergence, kl_div, mutual_information_loss,
};
pub use specialized::{
ctc_loss, gradient_penalty_loss, seq2seq_loss_with_attention, temporal_consistency_loss,
wasserstein_loss,
};
#[cfg(test)]
mod integration_tests {
    //! Cross-module smoke tests for the crate-level loss-function re-exports.
    //!
    //! These verify that one representative loss from each submodule can be
    //! called through the `pub use` surface and yields finite, correctly
    //! shaped results — they do not pin exact numeric values.
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// Calls one loss per category (regression, classification, similarity,
    /// information-theoretic) and asserts every scalar result is finite
    /// instead of silently discarding it. A NaN/Inf output now fails here.
    #[test]
    fn test_loss_functions_integration() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;

        // Regression losses on a small 1-D tensor pair.
        let input = from_vec(vec![1.0, 2.0, 3.0], &[3], device)?;
        let target = from_vec(vec![1.5, 2.5, 2.5], &[3], device)?;
        assert!(mse_loss(&input, &target, ReductionType::Mean)?.item()?.is_finite());
        assert!(l1_loss(&input, &target, ReductionType::Mean)?.item()?.is_finite());
        assert!(smooth_l1_loss(&input, &target, ReductionType::Mean, 1.0)?
            .item()?
            .is_finite());

        // Classification losses: [1, 3] logits against class index 0.
        let logits = from_vec(vec![1.0, 2.0, 0.5], &[1, 3], device)?;
        let class_target = from_vec(vec![0.0], &[1], device)?;
        assert!(cross_entropy(&logits, &class_target, None, "mean", None, 0.0)?
            .item()?
            .is_finite());
        assert!(focal_loss(&logits, &class_target, 0.25, 2.0, ReductionType::Mean)?
            .item()?
            .is_finite());

        // Similarity loss on nearly identical [1, 2] embeddings (target = +1).
        let emb1 = from_vec(vec![1.0, 2.0], &[1, 2], device)?;
        let emb2 = from_vec(vec![1.1, 2.1], &[1, 2], device)?;
        let sim_target = from_vec(vec![1.0], &[1], device)?;
        assert!(
            cosine_embedding_loss(&emb1, &emb2, &sim_target, 0.0, ReductionType::Mean)?
                .item()?
                .is_finite()
        );

        // Information-theoretic losses on two normalized distributions.
        let p = from_vec(vec![0.5, 0.3, 0.2], &[3], device)?;
        let q = from_vec(vec![0.4, 0.4, 0.2], &[3], device)?;
        assert!(entropy_loss(&p, ReductionType::Sum)?.item()?.is_finite());
        assert!(js_divergence(&p, &q, ReductionType::Sum)?.item()?.is_finite());
        Ok(())
    }

    /// Checks that losses stay finite at the extremes of the input range:
    /// tiny values (~1e-8, near underflow when squared) and large values
    /// (~1e5). `is_finite()` alone suffices — it is false for NaN as well
    /// as for +/-Inf, so the old `&& !is_nan()` check was redundant.
    #[test]
    fn test_loss_functions_numerical_stability() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;

        // Identical tiny inputs: MSE should be ~0, never NaN/Inf.
        let small_input = from_vec(vec![1e-6, 1e-7, 1e-8], &[3], device)?;
        let small_target = from_vec(vec![1e-6, 1e-7, 1e-8], &[3], device)?;
        let mse = mse_loss(&small_input, &small_target, ReductionType::Mean)?;
        assert!(mse.item()?.is_finite());

        // Identical large inputs: L1 should be ~0, never NaN/Inf.
        let large_input = from_vec(vec![1e3, 1e4, 1e5], &[3], device)?;
        let large_target = from_vec(vec![1e3, 1e4, 1e5], &[3], device)?;
        let l1 = l1_loss(&large_input, &large_target, ReductionType::Mean)?;
        assert!(l1.item()?.is_finite());
        Ok(())
    }

    /// Verifies the `ReductionType` contract on one loss:
    /// - `None` preserves the element shape,
    /// - `Sum` and `Mean` reduce to scalars (empty dims),
    /// - and the reduced values agree with a manual sum / sum-over-count
    ///   of the unreduced tensor to within float tolerance.
    #[test]
    fn test_reduction_consistency_across_losses() -> torsh_core::Result<()> {
        let device = DeviceType::Cpu;
        let input = from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4], device)?;
        let target = from_vec(vec![1.1, 2.1, 3.1, 4.1], &[4], device)?;

        let mse_none = mse_loss(&input, &target, ReductionType::None)?;
        let mse_sum = mse_loss(&input, &target, ReductionType::Sum)?;
        let mse_mean = mse_loss(&input, &target, ReductionType::Mean)?;

        // Shape contract: element-wise vs. scalar reductions.
        assert_eq!(mse_none.shape().dims(), &[4]);
        assert_eq!(mse_sum.shape().dims(), &[] as &[usize]);
        assert_eq!(mse_mean.shape().dims(), &[] as &[usize]);

        // Value contract: reductions must equal the manual equivalents
        // computed from the unreduced loss (4 = element count).
        let manual_sum = mse_none.sum()?;
        let manual_mean = manual_sum.div_scalar(4.0)?;
        assert!((mse_sum.item()? - manual_sum.item()?).abs() < 1e-6);
        assert!((mse_mean.item()? - manual_mean.item()?).abs() < 1e-6);
        Ok(())
    }
}