use crate::loss::common::ReductionType;
use crate::utils::{function_context, validate_positive};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Simplified CTC-style loss.
///
/// `log_probs` must be `(T, N, C)` and `targets` must be `(N, S)`.
/// `input_lengths` and `target_lengths` hold one valid length per batch
/// element.
///
/// NOTE(review): this is a naive approximation of CTC, not the full
/// forward-backward algorithm — it aligns target position `t` with time
/// step `t` and ignores the `_blank` symbol entirely; confirm this is the
/// intended semantics before relying on it for training.
///
/// `zero_infinity` replaces an infinite per-batch loss with 0.
///
/// # Errors
/// Returns a configuration error when tensor ranks are wrong or the length
/// tensors do not provide one entry per batch element.
pub fn ctc_loss(
    log_probs: &Tensor,
    targets: &Tensor,
    input_lengths: &Tensor,
    target_lengths: &Tensor,
    _blank: i64,
    reduction: ReductionType,
    zero_infinity: bool,
) -> TorshResult<Tensor> {
    let context = function_context("ctc_loss");
    if log_probs.ndim() != 3 {
        return Err(TorshError::config_error_with_context(
            "log_probs must be 3D tensor with shape (T, N, C)",
            &context,
        ));
    }
    if targets.ndim() != 2 {
        return Err(TorshError::config_error_with_context(
            "targets must be 2D tensor with shape (N, S)",
            &context,
        ));
    }
    let shape = log_probs.shape();
    let dims = shape.dims();
    let seq_len = dims[0];
    let batch_size = dims[1];
    let num_classes = dims[2];
    // Robustness fix: validate the length tensors up front so a shape
    // mismatch fails with a clear message instead of an opaque `get` error
    // inside the batch loop.
    if input_lengths.ndim() != 1 || input_lengths.shape().dims()[0] != batch_size {
        return Err(TorshError::config_error_with_context(
            "input_lengths must be 1D tensor with batch_size elements",
            &context,
        ));
    }
    if target_lengths.ndim() != 1 || target_lengths.shape().dims()[0] != batch_size {
        return Err(TorshError::config_error_with_context(
            "target_lengths must be 1D tensor with batch_size elements",
            &context,
        ));
    }
    let mut total_loss = 0.0;
    for batch_idx in 0..batch_size {
        let input_len = input_lengths.get(&[batch_idx])? as usize;
        let target_len = target_lengths.get(&[batch_idx])? as usize;
        // Empty sequences contribute nothing to the loss.
        if input_len == 0 || target_len == 0 {
            continue;
        }
        let mut batch_loss = 0.0;
        // Naive alignment: target position t <-> time step t.
        for t in 0..input_len.min(seq_len).min(target_len) {
            let target_class = targets.get(&[batch_idx, t])? as usize;
            // Out-of-range class indices are skipped rather than treated
            // as an error.
            if target_class < num_classes {
                let log_prob = log_probs.get(&[t, batch_idx, target_class])?;
                batch_loss -= log_prob;
            }
        }
        if zero_infinity && batch_loss.is_infinite() {
            batch_loss = 0.0;
        }
        total_loss += batch_loss;
    }
    match reduction {
        // NOTE(review): `None` returns the summed scalar — per-element
        // losses are not materialized by this simplified implementation.
        ReductionType::None | ReductionType::Sum => Tensor::from_vec(vec![total_loss], &[1]),
        // Bug fix: guard against division by zero for an empty batch,
        // which previously produced NaN.
        ReductionType::Mean => Tensor::from_vec(vec![total_loss / batch_size.max(1) as f32], &[1]),
    }
}
/// Sequence-to-sequence loss: token-level cross-entropy over the predicted
/// sequence, optionally augmented with an L2-style penalty on the attention
/// weights.
///
/// `attention_reg` scales the attention penalty; it has no effect when
/// `attention_weights` is `None`.
pub fn seq2seq_loss_with_attention(
    predictions: &Tensor,
    targets: &Tensor,
    attention_weights: Option<&Tensor>,
    attention_reg: f32,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    let base_loss = compute_sequence_cross_entropy(predictions, targets)?;
    let combined = match attention_weights {
        Some(weights) => {
            // Stack the scaled attention penalty on top of the CE term.
            let penalty = attention_regularization_loss(weights)?;
            base_loss.add(&penalty.mul_scalar(attention_reg)?)?
        }
        None => base_loss,
    };
    reduction.apply(combined)
}
/// Penalizes frame-to-frame change in a prediction sequence.
///
/// Expects `predictions` shaped `(N, T, ...)` with `T >= 2`; computes the
/// mean squared difference between consecutive time steps, scaled by
/// `smoothness_weight`.
pub fn temporal_consistency_loss(
    predictions: &Tensor,
    smoothness_weight: f32,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    if predictions.ndim() < 3 {
        return Err(TorshError::InvalidArgument(
            "predictions must be at least 3D (N, T, ...)".to_string(),
        ));
    }
    let shape = predictions.shape();
    let dims = shape.dims();
    let time_steps = dims[1];
    if time_steps < 2 {
        return Err(TorshError::InvalidArgument(
            "Sequence length must be at least 2".to_string(),
        ));
    }
    // Pair each frame with its successor along the time axis (dim 1).
    let current = predictions.slice(1, 0, time_steps - 1)?.to_tensor()?;
    let next = predictions.slice(1, 1, time_steps)?.to_tensor()?;
    let delta = current.sub(&next)?;
    // Mean squared temporal difference, scaled by the smoothness weight.
    let smoothness = delta.pow_scalar(2.0)?.mean(None, false)?;
    reduction.apply(smoothness.mul_scalar(smoothness_weight)?)
}
/// WGAN critic loss: the negated Wasserstein-distance estimate
/// `-(E[real] - E[fake])`, so minimizing it pushes real scores up and fake
/// scores down.
pub fn wasserstein_loss(
    real_scores: &Tensor,
    fake_scores: &Tensor,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    // Estimate the distance from the batch means of the critic scores.
    let mean_real = real_scores.mean(None, false)?;
    let mean_fake = fake_scores.mean(None, false)?;
    let distance = mean_real.sub(&mean_fake)?;
    // Negate so gradient descent maximizes the distance estimate.
    reduction.apply(distance.neg()?)
}
/// WGAN-GP gradient penalty: `penalty_weight * (||grad|| - 1)^2`, pushing
/// the gradient norm toward 1 to enforce the Lipschitz constraint.
///
/// # Errors
/// Fails when `penalty_weight` is not strictly positive.
pub fn gradient_penalty_loss(
    gradients: &Tensor,
    penalty_weight: f32,
    reduction: ReductionType,
) -> TorshResult<Tensor> {
    validate_positive(penalty_weight, "penalty_weight", "gradient_penalty_loss")?;
    // Squared deviation of the gradient norm from the unit target.
    let norm = gradients.norm()?;
    let deviation = norm.sub_scalar(1.0)?;
    let penalty = deviation.pow_scalar(2.0)?;
    reduction.apply(penalty.mul_scalar(penalty_weight)?)
}
/// Token-level cross-entropy averaged over an `(N, S)` target grid.
///
/// `predictions` is expected to be `(N, S, C)` class scores; log-softmax is
/// taken over the last axis and the log-probability of each target class is
/// summed, then averaged over all tokens.
///
/// Returns a single-element tensor; an empty target grid yields a zero loss.
fn compute_sequence_cross_entropy(predictions: &Tensor, targets: &Tensor) -> TorshResult<Tensor> {
    // Normalize scores over the class axis (last dimension).
    let class_dim = (predictions.shape().ndim() - 1) as i32;
    let log_probs = predictions.log_softmax(class_dim)?;
    let shape = targets.shape();
    let dims = shape.dims();
    let batch_size = dims[0];
    let seq_len = dims[1];
    let token_count = batch_size * seq_len;
    // Bug fix: guard against an empty target grid — dividing by zero tokens
    // previously produced NaN.
    if token_count == 0 {
        return Tensor::from_vec(vec![0.0_f32], &[1]);
    }
    let mut total_loss = 0.0;
    for i in 0..batch_size {
        for j in 0..seq_len {
            // Targets store class indices as floats; truncate to an index.
            let target_class = targets.get(&[i, j])? as usize;
            let log_prob = log_probs.get(&[i, j, target_class])?;
            total_loss -= log_prob;
        }
    }
    Tensor::from_vec(vec![total_loss / token_count as f32], &[1])
}
/// Mean squared magnitude of the attention weights — an L2-style penalty
/// used by `seq2seq_loss_with_attention`.
fn attention_regularization_loss(attention_weights: &Tensor) -> TorshResult<Tensor> {
    let squared = attention_weights.pow_scalar(2.0)?;
    squared.mean(None, false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::{from_vec, zeros};

    /// Uniform (all-zero) log-probs still yield a non-negative CTC loss.
    #[test]
    fn test_ctc_loss_basic() -> TorshResult<()> {
        // Shapes: log_probs (T=10, N=2, C=5), targets (N=2, S=2).
        let log_probs = zeros(&[10, 2, 5])?;
        let targets = from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], DeviceType::Cpu)?;
        let input_lengths = from_vec(vec![10.0, 10.0], &[2], DeviceType::Cpu)?;
        let target_lengths = from_vec(vec![2.0, 2.0], &[2], DeviceType::Cpu)?;
        let result = ctc_loss(
            &log_probs,
            &targets,
            &input_lengths,
            &target_lengths,
            0,
            ReductionType::Mean,
            false,
        )?;
        assert!(result.item()? >= 0.0);
        Ok(())
    }

    /// A slowly drifting sequence produces a small, non-negative penalty.
    #[test]
    fn test_temporal_consistency_loss_basic() -> TorshResult<()> {
        let predictions = from_vec(
            vec![1.0, 2.0, 3.0, 1.1, 2.1, 3.1, 1.2, 2.2, 3.2],
            &[1, 3, 3],
            DeviceType::Cpu,
        )?;
        let result = temporal_consistency_loss(&predictions, 1.0, ReductionType::Mean)?;
        let value = result.item()?;
        assert!(value >= 0.0 && value < 1.0);
        Ok(())
    }

    /// The Wasserstein estimate on finite scores must itself be finite.
    #[test]
    fn test_wasserstein_loss_basic() -> TorshResult<()> {
        let real_scores = from_vec(vec![1.0, 2.0, 1.5], &[3], DeviceType::Cpu)?;
        let fake_scores = from_vec(vec![0.5, 0.8, 0.6], &[3], DeviceType::Cpu)?;
        let result = wasserstein_loss(&real_scores, &fake_scores, ReductionType::Mean)?;
        assert!(result.item()?.is_finite());
        Ok(())
    }

    /// The squared-deviation penalty is non-negative by construction.
    #[test]
    fn test_gradient_penalty_loss_basic() -> TorshResult<()> {
        let gradients = from_vec(vec![1.5, 0.8, 1.2], &[3], DeviceType::Cpu)?;
        let result = gradient_penalty_loss(&gradients, 10.0, ReductionType::Mean)?;
        assert!(result.item()? >= 0.0);
        Ok(())
    }

    /// Uniform predictions give strictly positive cross-entropy.
    #[test]
    fn test_seq2seq_loss_basic() -> TorshResult<()> {
        // Shapes: predictions (N=2, S=5, C=10), targets (N=2, S=5).
        let predictions = zeros(&[2, 5, 10])?;
        let targets = from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 1.0, 2.0, 3.0, 4.0],
            &[2, 5],
            DeviceType::Cpu,
        )?;
        let result =
            seq2seq_loss_with_attention(&predictions, &targets, None, 0.0, ReductionType::Mean)?;
        assert!(result.item()? > 0.0);
        Ok(())
    }
}