use crate::autograd::grad_fn::CrossEntropyBackward;
use crate::autograd::{is_grad_enabled, with_graph, Tensor};
use std::sync::Arc;
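/// How a loss collapses its per-element values: keep them all (`None`),
/// average them (`Mean`, the default), or add them up (`Sum`).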
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Reduction {
None,
#[default]
Mean,
Sum,
}
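/// Mean squared error: `(pred - target)^2`, reduced according to
/// [`Reduction`] (mean by default).
///
/// Illustrative sketch, assuming `Tensor::new(&data, &shape)` behaves as it
/// is used elsewhere in this module:
///
/// ```ignore
/// let pred = Tensor::new(&[0.0, 2.0], &[2]);
/// let target = Tensor::new(&[1.0, 1.0], &[2]);
/// // Default Mean reduction: ((0 - 1)^2 + (2 - 1)^2) / 2 = 1.0
/// let loss = MSELoss::new().forward(&pred, &target);
/// ```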
#[derive(Debug, Clone, Copy, Default)]
pub struct MSELoss {
reduction: Reduction,
}
impl MSELoss {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_reduction(reduction: Reduction) -> Self {
Self { reduction }
}
#[must_use]
pub fn forward(&self, pred: &Tensor, target: &Tensor) -> Tensor {
assert_eq!(
pred.shape(),
target.shape(),
"Prediction and target shapes must match"
);
let diff = pred.sub(target);
let squared = diff.pow(2.0);
match self.reduction {
Reduction::None => squared,
Reduction::Mean => squared.mean(),
Reduction::Sum => squared.sum(),
}
}
}
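/// Mean absolute error: `|pred - target|`, reduced according to
/// [`Reduction`] (mean by default).
///
/// Illustrative sketch (same `Tensor::new` assumption as above):
///
/// ```ignore
/// // Sum reduction: |0 - 1| + |3 - 1| = 3.0
/// let loss = L1Loss::with_reduction(Reduction::Sum).forward(
///     &Tensor::new(&[0.0, 3.0], &[2]),
///     &Tensor::new(&[1.0, 1.0], &[2]),
/// );
/// ```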
#[derive(Debug, Clone, Copy, Default)]
pub struct L1Loss {
reduction: Reduction,
}
impl L1Loss {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_reduction(reduction: Reduction) -> Self {
Self { reduction }
}
#[must_use]
pub fn forward(&self, pred: &Tensor, target: &Tensor) -> Tensor {
assert_eq!(pred.shape(), target.shape());
let diff = pred.sub(target);
let abs_diff = abs(&diff);
match self.reduction {
Reduction::None => abs_diff,
Reduction::Mean => abs_diff.mean(),
Reduction::Sum => abs_diff.sum(),
}
}
}
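/// Smooth L1 (Huber-style) loss with transition point `beta` (1.0 by
/// default): `0.5 * x^2 / beta` when `|x| < beta`, `|x| - 0.5 * beta`
/// otherwise, where `x = pred - target`. Reduced with the mean by default.
///
/// Illustrative sketch (same `Tensor::new` assumption as above):
///
/// ```ignore
/// // |0.5| < 1.0 takes the quadratic branch: 0.5 * 0.25 / 1.0 = 0.125
/// // |2.0| >= 1.0 takes the linear branch:   2.0 - 0.5       = 1.5
/// let loss = SmoothL1Loss::new().forward(
///     &Tensor::new(&[0.5, 3.0], &[2]),
///     &Tensor::new(&[0.0, 1.0], &[2]),
/// );
/// // Mean reduction: (0.125 + 1.5) / 2 = 0.8125
/// ```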
#[derive(Debug, Clone, Copy)]
pub struct SmoothL1Loss {
beta: f32,
reduction: Reduction,
}
impl SmoothL1Loss {
#[must_use]
pub fn new() -> Self {
Self {
beta: 1.0,
reduction: Reduction::Mean,
}
}
#[must_use]
pub fn with_beta(beta: f32) -> Self {
Self {
beta,
reduction: Reduction::Mean,
}
}
#[must_use]
pub fn forward(&self, pred: &Tensor, target: &Tensor) -> Tensor {
assert_eq!(pred.shape(), target.shape());
let diff = pred.sub(target);
let loss_data: Vec<f32> = diff
.data()
.iter()
.map(|&x| {
let abs_x = x.abs();
if abs_x < self.beta {
0.5 * x * x / self.beta
} else {
abs_x - 0.5 * self.beta
}
})
.collect();
let loss = Tensor::new(&loss_data, pred.shape());
match self.reduction {
Reduction::None => loss,
Reduction::Mean => loss.mean(),
Reduction::Sum => loss.sum(),
}
}
}
impl Default for SmoothL1Loss {
fn default() -> Self {
Self::new()
}
}
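/// Cross-entropy over raw logits: applies log-softmax internally and takes
/// `-log p[target]` per sample, so `targets` holds class indices rather than
/// one-hot vectors. With label smoothing, `label_smoothing / num_classes` of
/// the target mass is spread uniformly over all classes. Gradients are only
/// recorded on the no-smoothing path.
///
/// Illustrative sketch (class indices are stored as `f32` values):
///
/// ```ignore
/// // Two samples, three classes; targets are classes 0 and 1.
/// let logits = Tensor::new(&[2.0, 0.5, 0.1, 0.2, 1.5, 0.3], &[2, 3]);
/// let targets = Tensor::new(&[0.0, 1.0], &[2]);
/// let loss = CrossEntropyLoss::new().forward(&logits, &targets);
/// ```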
#[derive(Debug, Clone, Default)]
pub struct CrossEntropyLoss {
reduction: Reduction,
label_smoothing: f32,
}
impl CrossEntropyLoss {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_reduction(reduction: Reduction) -> Self {
Self {
reduction,
label_smoothing: 0.0,
}
}
#[must_use]
pub fn with_label_smoothing(label_smoothing: f32) -> Self {
assert!(
(0.0..1.0).contains(&label_smoothing),
"Label smoothing must be in [0, 1)"
);
Self {
reduction: Reduction::Mean,
label_smoothing,
}
}
#[provable_contracts_macros::contract("cross-entropy-kernel-v1", equation = "cross_entropy")]
#[must_use]
pub fn forward(&self, logits: &Tensor, targets: &Tensor) -> Tensor {
assert_eq!(logits.ndim(), 2, "Logits must be 2D [batch, classes]");
assert_eq!(targets.ndim(), 1, "Targets must be 1D [batch]");
assert_eq!(
logits.shape()[0],
targets.shape()[0],
"Batch sizes must match"
);
let batch_size = logits.shape()[0];
let num_classes = logits.shape()[1];
let softmax_output = softmax_2d(logits);
let log_probs = log_softmax(logits);
let target_indices: Vec<usize> = targets
.data()
.iter()
.map(|&t| {
let idx = t as usize;
assert!(
idx < num_classes,
"Target class {idx} out of bounds for {num_classes} classes"
);
idx
})
.collect();
let mut losses = Vec::with_capacity(batch_size);
for (b, &target_class) in target_indices.iter().enumerate() {
if self.label_smoothing > 0.0 {
// Label smoothing spreads `label_smoothing` uniformly over the classes;
// the true class then gets the remaining `1 - label_smoothing` added below.
let smooth_target = self.label_smoothing / num_classes as f32;
let mut loss = 0.0;
for c in 0..num_classes {
let target_prob = if c == target_class {
1.0 - self.label_smoothing + smooth_target
} else {
smooth_target
};
loss -= target_prob * log_probs.data()[b * num_classes + c];
}
losses.push(loss);
} else {
losses.push(-log_probs.data()[b * num_classes + target_class]);
}
}
let per_sample_loss = Tensor::new(&losses, &[batch_size]);
let mut loss = match self.reduction {
Reduction::None => per_sample_loss,
Reduction::Mean => {
let mean_val = losses.iter().sum::<f32>() / batch_size as f32;
Tensor::from_slice(&[mean_val])
}
Reduction::Sum => {
let sum_val = losses.iter().sum::<f32>();
Tensor::from_slice(&[sum_val])
}
};
if is_grad_enabled() && logits.requires_grad_enabled() && self.label_smoothing == 0.0 {
loss.requires_grad_(true);
let grad_fn = Arc::new(CrossEntropyBackward {
softmax_output: softmax_output.clone(),
targets: target_indices,
});
loss.set_grad_fn(grad_fn.clone());
with_graph(|graph| {
graph.register_tensor(logits.clone());
graph.record(loss.id(), grad_fn, vec![logits.id()]);
});
}
loss
}
}
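/// Binary cross-entropy on raw logits, evaluated in the numerically stable
/// form `max(x, 0) - x * y + ln(1 + exp(-|x|))`. The optional `pos_weight`
/// rescales positive examples, which helps with imbalanced labels.
///
/// Illustrative sketch (same `Tensor::new` assumption as above):
///
/// ```ignore
/// let logits = Tensor::new(&[0.0, 2.0], &[2]);
/// let targets = Tensor::new(&[1.0, 0.0], &[2]);
/// // Count each positive example three times as heavily as a negative one.
/// let loss = BCEWithLogitsLoss::with_pos_weight(3.0).forward(&logits, &targets);
/// ```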
#[derive(Debug, Clone, Default)]
pub struct BCEWithLogitsLoss {
reduction: Reduction,
pos_weight: Option<f32>,
}
impl BCEWithLogitsLoss {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_reduction(reduction: Reduction) -> Self {
Self {
reduction,
pos_weight: None,
}
}
#[must_use]
pub fn with_pos_weight(pos_weight: f32) -> Self {
Self {
reduction: Reduction::Mean,
pos_weight: Some(pos_weight),
}
}
#[must_use]
pub fn forward(&self, logits: &Tensor, targets: &Tensor) -> Tensor {
assert_eq!(logits.shape(), targets.shape());
let loss_data: Vec<f32> = logits
.data()
.iter()
.zip(targets.data().iter())
.map(|(&x, &y)| {
let max_val = x.max(0.0);
let base_loss = max_val - x * y + (1.0 + (-x.abs()).exp()).ln();
match self.pos_weight {
Some(w) => {
let weight = y * (w - 1.0) + 1.0;
base_loss * weight
}
None => base_loss,
}
})
.collect();
let loss = Tensor::new(&loss_data, logits.shape());
match self.reduction {
Reduction::None => loss,
Reduction::Mean => loss.mean(),
Reduction::Sum => loss.sum(),
}
}
}
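/// Negative log-likelihood loss configuration; only the reduction mode is
/// declared in this file.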
#[derive(Debug, Clone, Copy, Default)]
pub struct NLLLoss {
reduction: Reduction,
}
#[path = "ucbd.rs"]
mod ucbd;
#[allow(clippy::wildcard_imports)]
use ucbd::*;
#[path = "loss_tests.rs"]
mod loss_tests;
#[path = "loss_tests_ce_contract.rs"]
mod loss_tests_ce_contract;
#[path = "loss_tests_lf_contract.rs"]
mod loss_tests_lf_contract;