use burn::nn::loss::{CrossEntropyLossConfig, MseLoss};
use burn::prelude::*;
use burn::tensor::activation::{log_softmax, softmax};
/// Multi-class cross-entropy loss computed from raw (unnormalized) logits.
///
/// Stateless wrapper around burn's built-in cross-entropy criterion using its
/// default configuration. A fresh criterion is built on the logits' device on
/// every call to [`CrossEntropyLoss::forward`].
#[derive(Debug, Default)]
pub struct CrossEntropyLoss;

impl CrossEntropyLoss {
    /// Creates the (stateless) loss.
    pub fn new() -> Self {
        CrossEntropyLoss
    }

    /// Computes cross-entropy between `logits` (`[batch, classes]`) and the
    /// integer class indices in `targets` (`[batch]`).
    ///
    /// The reduction is whatever burn's default `CrossEntropyLoss::forward`
    /// applies; the result is a rank-1 tensor.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits, targets)
    }
}
/// Mean-squared-error loss between two tensors of identical shape.
///
/// Stateless wrapper around burn's `MseLoss` using mean reduction.
#[derive(Debug, Default)]
pub struct MSELoss;

impl MSELoss {
    /// Creates the (stateless) loss.
    pub fn new() -> Self {
        MSELoss
    }

    /// Returns the mean over all elements of `(preds - targets)^2`, as computed
    /// by burn's `MseLoss` with `Reduction::Mean`.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        MseLoss::new().forward(preds, targets, burn::nn::loss::Reduction::Mean)
    }
}
/// Huber loss between predictions and targets of identical shape.
///
/// For every element with residual `d = pred - target` the loss is
/// - `0.5 * d^2` when `|d| <= delta` (quadratic region), and
/// - `delta * |d| - 0.5 * delta^2` otherwise (linear region).
///
/// The per-element values are averaged over all elements.
#[derive(Debug)]
pub struct HuberLoss {
    /// Threshold on `|pred - target|` at which the loss switches from
    /// quadratic to linear.
    pub delta: f32,
}

impl HuberLoss {
    /// Creates a Huber loss with the given `delta` threshold.
    pub fn new(delta: f32) -> Self {
        Self { delta }
    }

    /// Computes the mean Huber loss between `preds` and `targets`.
    ///
    /// Everything is expressed with tensor operations (no host round-trip), so
    /// the result stays on the device and remains connected to the autodiff
    /// graph of `preds`. It also works for any float element type of the backend.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let delta = self.delta;
        let diff = preds - targets;
        let abs_diff = diff.clone().abs();

        // 0.5 * d^2 — used where |d| <= delta.
        let quadratic = (diff.clone() * diff).mul_scalar(0.5);
        // delta * |d| - 0.5 * delta^2 — used elsewhere (including NaN residuals,
        // since `NaN <= delta` is false, matching a scalar `if` comparison).
        let linear = abs_diff
            .clone()
            .mul_scalar(delta)
            .sub_scalar(0.5 * delta * delta);

        let in_quadratic_region = abs_diff.lower_equal_elem(delta);
        linear.mask_where(in_quadratic_region, quadratic).mean()
    }
}

impl Default for HuberLoss {
    /// Huber loss with `delta = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
/// Focal loss for multi-class classification from raw logits.
///
/// For each sample with target class `t`:
/// `loss = -alpha_t * (1 - p_t)^gamma * log(p_t)`,
/// where `p_t` is the softmax probability of the target class. The per-sample
/// losses are averaged over the batch.
#[derive(Debug)]
pub struct FocalLoss {
    /// Focusing exponent; `0.0` reduces to (optionally alpha-weighted) cross-entropy.
    pub gamma: f32,
    /// Optional per-class weights indexed by class id. Classes beyond the end of
    /// the vector get weight `1.0`.
    pub alpha: Option<Vec<f32>>,
    /// Lower clamp applied to `p_t` before computing the focal weight.
    epsilon: f32,
}

impl FocalLoss {
    /// Creates a focal loss with the given `gamma` and no class weighting.
    pub fn new(gamma: f32) -> Self {
        Self {
            gamma,
            alpha: None,
            epsilon: 1e-8,
        }
    }

    /// Sets per-class weights (`alpha[c]` multiplies the loss of samples whose
    /// target is class `c`).
    #[must_use]
    pub fn with_alpha(mut self, alpha: Vec<f32>) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Computes the mean focal loss for `logits` (`[batch, classes]`) and the
    /// integer class indices in `targets` (`[batch]`).
    ///
    /// Implemented with tensor ops (`gather`/`select`) rather than a host
    /// round-trip. The result therefore stays connected to the autodiff graph,
    /// and the backend's Int element type (e.g. `i32` or `i64`) does not matter.
    /// Out-of-range target indices are not validated here; their handling is
    /// left to the backend's `gather`/`select`.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let [batch_size, n_classes] = logits.dims();
        let device = logits.device();

        // [batch, 1] column of target indices, used to pick the target class per row.
        let index = targets.clone().reshape([batch_size, 1]);
        let probs = softmax(logits.clone(), 1);
        let log_probs = log_softmax(logits, 1);

        let p_t = probs.gather(1, index.clone()).clamp_min(self.epsilon);
        let log_p_t = log_probs.gather(1, index);

        // (1 - p_t)^gamma down-weights well-classified samples.
        let focal_weight = p_t.neg().add_scalar(1.0).powf_scalar(self.gamma);
        let per_sample = focal_weight * log_p_t.neg();

        let per_sample = match &self.alpha {
            Some(alpha) => {
                // Pad missing classes with 1.0, mirroring `alpha.get(c).unwrap_or(1.0)`.
                let class_weights: Vec<f32> = (0..n_classes)
                    .map(|c| alpha.get(c).copied().unwrap_or(1.0))
                    .collect();
                let alpha_t = Tensor::<B, 1>::from_floats(class_weights.as_slice(), &device)
                    .select(0, targets)
                    .reshape([batch_size, 1]);
                per_sample * alpha_t
            }
            None => per_sample,
        };

        per_sample.mean()
    }
}

impl Default for FocalLoss {
    /// Focal loss with `gamma = 2.0` and no class weighting.
    fn default() -> Self {
        Self::new(2.0)
    }
}
/// Cross-entropy against a label-smoothed target distribution.
///
/// The target class receives probability `1 - smoothing`, and every other class
/// receives `smoothing / (n_classes - 1)`. The per-sample cross-entropies are
/// averaged over the batch.
#[derive(Debug)]
pub struct LabelSmoothingLoss {
    /// Total probability mass moved from the target class to the other classes.
    pub smoothing: f32,
}

impl LabelSmoothingLoss {
    /// Creates a label-smoothing loss with the given `smoothing` amount.
    pub fn new(smoothing: f32) -> Self {
        Self { smoothing }
    }

    /// Computes the mean smoothed cross-entropy for `logits` (`[batch, classes]`)
    /// and the integer class indices in `targets` (`[batch]`).
    ///
    /// The per-sample sum `Σ_c q_c * log p_c` is rewritten as
    /// `neg * Σ_c log p_c + (pos - neg) * log p_t`, so it can be computed with
    /// `gather`/`sum_dim`. This keeps the result on the device and attached to
    /// the autodiff graph, whatever the backend's Int element type.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let [batch_size, n_classes] = logits.dims();
        let log_probs = log_softmax(logits, 1);

        let smooth_positive = 1.0 - self.smoothing;
        // With a single class there are no "other" classes to receive mass; use 0
        // instead of dividing by zero (or underflowing `n_classes - 1` when empty).
        let smooth_negative = if n_classes > 1 {
            self.smoothing / (n_classes - 1) as f32
        } else {
            0.0
        };

        let log_p_t = log_probs
            .clone()
            .gather(1, targets.reshape([batch_size, 1]));
        let sum_log_p = log_probs.sum_dim(1);

        let per_sample = (log_p_t.mul_scalar(smooth_positive - smooth_negative)
            + sum_log_p.mul_scalar(smooth_negative))
        .neg();
        per_sample.mean()
    }
}

impl Default for LabelSmoothingLoss {
    /// Label smoothing with `smoothing = 0.1`.
    fn default() -> Self {
        Self::new(0.1)
    }
}
#[cfg(test)]
mod tests {
    //! Construction / configuration checks for the loss wrappers.
    use super::*;

    #[test]
    fn test_cross_entropy_loss_creation() {
        // Constructing the stateless loss must simply not panic.
        let _ = CrossEntropyLoss::new();
    }

    #[test]
    fn test_huber_loss_creation() {
        let explicit = HuberLoss::new(0.5);
        let fallback = HuberLoss::default();
        assert_eq!(explicit.delta, 0.5);
        assert_eq!(fallback.delta, 1.0);
    }

    #[test]
    fn test_focal_loss_creation() {
        let plain = FocalLoss::new(2.0);
        assert_eq!(plain.gamma, 2.0);
        assert_eq!(plain.alpha, None);

        let weighted = FocalLoss::new(2.0).with_alpha(vec![0.25, 0.75]);
        assert_eq!(weighted.alpha.as_deref(), Some(&[0.25_f32, 0.75][..]));
    }

    #[test]
    fn test_focal_loss_default() {
        assert_eq!(FocalLoss::default().gamma, 2.0);
    }

    #[test]
    fn test_label_smoothing_loss_creation() {
        let explicit = LabelSmoothingLoss::new(0.1);
        let fallback = LabelSmoothingLoss::default();
        assert_eq!(explicit.smoothing, 0.1);
        assert_eq!(fallback.smoothing, 0.1);
    }
}