use crate::loss::common::ReductionType;
use crate::utils::{function_context, safe_for_log, safe_log, validate_elementwise_shapes};
use torsh_core::Result as TorshResult;
use torsh_tensor::Tensor;
/// Kullback-Leibler divergence: `KL(target ‖ input) = Σ p_t · (log p_t − input)`.
///
/// `input` is expected to hold log-probabilities (the usual `kl_div`
/// convention — confirmed by the tests below, which pass `probs.log()` as
/// `input`). When `log_target` is true, `target` is also in log-space; it is
/// exponentiated back to probability space so the pointwise term is
/// `exp(target) * (target - input)`, matching the standard definition.
///
/// # Errors
/// Returns an error when `input` and `target` are not element-wise shape
/// compatible, or when any underlying tensor operation fails.
pub fn kl_div(
    input: &Tensor,
    target: &Tensor,
    reduction: ReductionType,
    log_target: bool,
) -> TorshResult<Tensor> {
    let _context = function_context("kl_div");
    validate_elementwise_shapes(input, target)?;
    let kl = if log_target {
        // target holds log-probabilities: recover probabilities via exp so the
        // pointwise term is p * (log p − log q), not log p * (log p − log q).
        target.exp()?.mul(&target.sub(input)?)?
    } else {
        // target holds probabilities: take a numerically safe log first.
        let log_target = safe_log(target, None, None)?;
        let log_ratio = log_target.sub(input)?;
        target.mul(&log_ratio)?
    };
    reduction.apply(kl)
}
pub fn js_divergence(
input: &Tensor,
target: &Tensor,
reduction: ReductionType,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
let mixture = input.add(target)?.mul_scalar(0.5)?;
let input_safe = safe_for_log(input, None, None)?;
let target_safe = safe_for_log(target, None, None)?;
let mixture_safe = safe_for_log(&mixture, None, None)?;
let log_input = input_safe.log()?;
let log_target = target_safe.log()?;
let log_mixture = mixture_safe.log()?;
let kl_input_mixture = input.mul(&log_input.sub(&log_mixture)?)?;
let kl_target_mixture = target.mul(&log_target.sub(&log_mixture)?)?;
let js = kl_input_mixture.add(&kl_target_mixture)?.mul_scalar(0.5)?;
reduction.apply(js)
}
pub fn cross_entropy_continuous(
input: &Tensor,
target: &Tensor,
reduction: ReductionType,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(input, target)?;
let log_input = safe_log(input, None, None)?;
let cross_entropy = target.mul(&log_input)?.neg()?;
reduction.apply(cross_entropy)
}
pub fn mutual_information_loss(
joint_samples: &Tensor,
marginal_samples: &Tensor,
reduction: ReductionType,
) -> TorshResult<Tensor> {
validate_elementwise_shapes(joint_samples, marginal_samples)?;
let joint_mean = joint_samples.mean(None, false)?;
let exp_marginal = marginal_samples.exp()?;
let marginal_mean_exp = exp_marginal.mean(None, false)?;
let log_marginal_mean_exp = marginal_mean_exp.log()?;
let mi_estimate = joint_mean.sub(&log_marginal_mean_exp)?;
let loss = mi_estimate.neg()?;
reduction.apply(loss)
}
pub fn entropy_loss(input: &Tensor, reduction: ReductionType) -> TorshResult<Tensor> {
let log_input = safe_log(input, None, None)?;
let entropy = input.mul(&log_input)?.neg()?;
reduction.apply(entropy)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::from_vec;

    /// KL(p ‖ p) vanishes when `input` is exactly `log(target)`.
    #[test]
    fn test_kl_div_identical_distributions() -> TorshResult<()> {
        let probs = from_vec(vec![0.5, 0.3, 0.2], &[3], DeviceType::Cpu)?;
        let log_probs = probs.log()?;
        let divergence = kl_div(&log_probs, &probs, ReductionType::Sum, false)?;
        assert!(divergence.item()?.abs() < 1e-6);
        Ok(())
    }

    /// Distinct distributions must yield a strictly positive divergence.
    #[test]
    fn test_kl_div_different_distributions() -> TorshResult<()> {
        let log_input = from_vec(vec![-1.0, -2.0, -1.5], &[3], DeviceType::Cpu)?;
        let target = from_vec(vec![0.8, 0.1, 0.1], &[3], DeviceType::Cpu)?;
        let divergence = kl_div(&log_input, &target, ReductionType::Sum, false)?;
        assert!(divergence.item()? > 0.0);
        Ok(())
    }

    /// JS(p, p) is zero for identical distributions.
    #[test]
    fn test_js_divergence_identical_distributions() -> TorshResult<()> {
        let p = from_vec(vec![0.5, 0.3, 0.2], &[3], DeviceType::Cpu)?;
        let q = p.clone();
        let result = js_divergence(&p, &q, ReductionType::Sum)?;
        assert!(result.item()?.abs() < 1e-6);
        Ok(())
    }

    /// JS divergence is non-negative and bounded (≤ ln 2 ≈ 0.693 in nats).
    #[test]
    fn test_js_divergence_properties() -> TorshResult<()> {
        let p = from_vec(vec![0.7, 0.2, 0.1], &[3], DeviceType::Cpu)?;
        let q = from_vec(vec![0.1, 0.2, 0.7], &[3], DeviceType::Cpu)?;
        let result = js_divergence(&p, &q, ReductionType::Sum)?;
        let value = result.item()?;
        assert!(value >= 0.0);
        assert!(value <= 0.7);
        Ok(())
    }

    /// A uniform distribution over n outcomes has entropy ln(n).
    #[test]
    fn test_entropy_loss_uniform_distribution() -> TorshResult<()> {
        let uniform = from_vec(vec![0.25, 0.25, 0.25, 0.25], &[4], DeviceType::Cpu)?;
        let result = entropy_loss(&uniform, ReductionType::Sum)?;
        let expected = 4.0f32.ln();
        assert!((result.item()? - expected).abs() < 1e-3);
        Ok(())
    }

    /// Cross-entropy of valid distributions is strictly positive.
    #[test]
    fn test_cross_entropy_continuous_basic() -> TorshResult<()> {
        let p = from_vec(vec![0.5, 0.3, 0.2], &[3], DeviceType::Cpu)?;
        let q = from_vec(vec![0.4, 0.4, 0.2], &[3], DeviceType::Cpu)?;
        let result = cross_entropy_continuous(&q, &p, ReductionType::Sum)?;
        assert!(result.item()? > 0.0);
        Ok(())
    }
}