use burn::nn::loss::{CrossEntropyLossConfig, MseLoss};
use burn::prelude::*;
use burn::tensor::activation::{log_softmax, softmax};
use burn::tensor::TensorData;
/// Stateless wrapper around burn's cross-entropy loss with its default configuration.
#[derive(Debug, Default)]
pub struct CrossEntropyLoss;
impl CrossEntropyLoss {
    /// Creates the loss; it carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Cross-entropy between `logits` (`[batch, classes]`) and integer class `targets`
    /// (`[batch]`), reduced as burn's `CrossEntropyLoss` does by default.
    ///
    /// A fresh burn loss module is built on the logits' device for every call.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits, targets)
    }
}
/// Stateless mean-squared-error loss delegating to burn's `MseLoss`.
#[derive(Debug, Default)]
pub struct MSELoss;
impl MSELoss {
    /// Creates the loss; it carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Mean of `(preds - targets)^2` over all elements (burn `Reduction::Mean`).
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        MseLoss::new().forward(preds, targets, burn::nn::loss::Reduction::Mean)
    }
}
/// Huber (smooth-L1) loss averaged over every element.
///
/// Per element, with `d = pred - target`: `0.5 * d^2` when `|d| <= delta`,
/// otherwise `delta * |d| - 0.5 * delta^2`.
#[derive(Debug)]
pub struct HuberLoss {
    /// Threshold between the quadratic and the linear regime.
    pub delta: f32,
}
impl HuberLoss {
    /// Creates a Huber loss with the given `delta`.
    pub fn new(delta: f32) -> Self {
        Self { delta }
    }

    /// Mean Huber loss between `preds` and `targets`, returned as a one-element tensor.
    ///
    /// Computed entirely with tensor ops, so the result stays on the autodiff graph
    /// and works for any float element type of the backend.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let abs_diff = (preds - targets).abs();
        // Split |d| into a part capped at delta (q) and the excess above it (r).
        // 0.5 * q^2 + delta * r equals 0.5 * d^2 below the threshold and
        // delta * |d| - 0.5 * delta^2 above it, without any branching.
        let quadratic = abs_diff.clone().clamp_max(self.delta);
        let linear = abs_diff - quadratic.clone();
        let per_element =
            (quadratic.clone() * quadratic).mul_scalar(0.5) + linear.mul_scalar(self.delta);
        per_element.mean()
    }
}
impl Default for HuberLoss {
    /// `delta = 1.0`.
    fn default() -> Self {
        Self::new(1.0)
    }
}
#[derive(Debug)]
pub struct FocalLoss {
pub gamma: f32,
pub alpha: Option<Vec<f32>>,
epsilon: f32,
}
impl FocalLoss {
pub fn new(gamma: f32) -> Self {
Self {
gamma,
alpha: None,
epsilon: 1e-8,
}
}
#[must_use]
pub fn with_alpha(mut self, alpha: Vec<f32>) -> Self {
self.alpha = Some(alpha);
self
}
pub fn forward<B: Backend>(
&self,
logits: Tensor<B, 2>,
targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
let [batch_size, n_classes] = logits.dims();
let device = logits.device();
let probs = softmax(logits.clone(), 1);
let log_probs = log_softmax(logits, 1);
let probs_data: Vec<f32> = probs.into_data().to_vec().unwrap();
let log_probs_data: Vec<f32> = log_probs.into_data().to_vec().unwrap();
let targets_data: Vec<i32> = targets.into_data().to_vec().unwrap();
let mut focal_losses: Vec<f32> = Vec::with_capacity(batch_size);
for i in 0..batch_size {
let target_class = targets_data[i] as usize;
let p_t = probs_data[i * n_classes + target_class].max(self.epsilon);
let log_p_t = log_probs_data[i * n_classes + target_class];
let focal_weight = (1.0 - p_t).powf(self.gamma);
let alpha_weight = self
.alpha
.as_ref()
.map(|a| a.get(target_class).copied().unwrap_or(1.0))
.unwrap_or(1.0);
let loss = -alpha_weight * focal_weight * log_p_t;
focal_losses.push(loss);
}
let mean_loss: f32 = focal_losses.iter().sum::<f32>() / batch_size as f32;
Tensor::<B, 1>::from_floats([mean_loss], &device)
}
}
impl Default for FocalLoss {
fn default() -> Self {
Self::new(2.0)
}
}
/// Cross-entropy against label-smoothed one-hot targets, averaged over the batch.
///
/// The target class receives probability `1 - smoothing`; the remaining `smoothing`
/// mass is split evenly over the other `n_classes - 1` classes.
#[derive(Debug)]
pub struct LabelSmoothingLoss {
    /// Probability mass moved away from the target class.
    pub smoothing: f32,
}
impl LabelSmoothingLoss {
    /// Creates a label-smoothing loss with the given `smoothing`.
    pub fn new(smoothing: f32) -> Self {
        Self { smoothing }
    }

    /// Mean smoothed cross-entropy for `logits` (`[batch_size, n_classes]`) and
    /// class-index `targets` (`[batch_size]`, each expected in `0..n_classes`).
    ///
    /// Computed with tensor ops, so the result stays on the autodiff graph and does
    /// not depend on the backend's int/float element types.
    pub fn forward<B: Backend>(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let [batch_size, n_classes] = logits.dims();
        let log_probs = log_softmax(logits, 1);
        let smooth_positive = 1.0 - self.smoothing;
        // With at most one class there are no other classes to receive mass
        // (and `n_classes - 1` would underflow for zero classes).
        let smooth_negative = if n_classes > 1 {
            self.smoothing / (n_classes - 1) as f32
        } else {
            0.0
        };

        let log_p_target = log_probs
            .clone()
            .gather(1, targets.reshape([batch_size, 1]));
        let log_p_sum = log_probs.sum_dim(1);
        // -sum_c q_c * log p_c with q_c = smooth_negative except at the target,
        // rewritten as a full row sum plus a correction for the target column.
        let per_sample = log_p_sum.mul_scalar(-smooth_negative)
            - log_p_target.mul_scalar(smooth_positive - smooth_negative);
        per_sample.mean()
    }
}
impl Default for LabelSmoothingLoss {
    /// `smoothing = 0.1`.
    fn default() -> Self {
        Self::new(0.1)
    }
}
/// Log-cosh loss averaged over every element: `mean(log(cosh(pred - target)))`.
#[derive(Debug, Default)]
pub struct LogCoshLoss;
impl LogCoshLoss {
    /// Creates the loss; it carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Mean log-cosh loss between `preds` and `targets`, as a one-element tensor.
    ///
    /// Computed with tensor ops, so the result stays on the autodiff graph.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let abs_diff = (preds - targets).abs();
        // log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - ln 2; exp(-2|x|) never
        // overflows, unlike evaluating cosh directly for large residuals.
        let correction = abs_diff
            .clone()
            .mul_scalar(-2.0)
            .exp()
            .add_scalar(1.0)
            .log();
        (abs_diff + correction)
            .sub_scalar(std::f32::consts::LN_2)
            .mean()
    }
}
#[derive(Debug)]
pub struct CenterLoss {
num_classes: usize,
feature_dim: usize,
alpha: f32,
centers: Vec<Vec<f32>>,
}
impl CenterLoss {
pub fn new(num_classes: usize, feature_dim: usize) -> Self {
let centers = vec![vec![0.0f32; feature_dim]; num_classes];
Self {
num_classes,
feature_dim,
alpha: 0.5,
centers,
}
}
#[must_use]
pub fn with_alpha(mut self, alpha: f32) -> Self {
self.alpha = alpha;
self
}
pub fn forward<B: Backend>(
&mut self,
features: Tensor<B, 2>,
targets: Tensor<B, 1, Int>,
) -> Tensor<B, 1> {
let device = features.device();
let [batch_size, feat_dim] = features.dims();
assert_eq!(
feat_dim, self.feature_dim,
"Feature dimension mismatch: expected {}, got {}",
self.feature_dim, feat_dim
);
let features_data: Vec<f32> = features.clone().into_data().to_vec().unwrap();
let targets_data: Vec<i64> = targets.into_data().to_vec().unwrap();
let mut total_loss = 0.0f32;
let mut center_updates: Vec<Vec<f32>> = vec![vec![0.0f32; self.feature_dim]; self.num_classes];
let mut class_counts: Vec<usize> = vec![0; self.num_classes];
for i in 0..batch_size {
let class_idx = targets_data[i] as usize;
if class_idx >= self.num_classes {
continue;
}
class_counts[class_idx] += 1;
let mut dist_sq = 0.0f32;
for j in 0..self.feature_dim {
let diff = features_data[i * self.feature_dim + j] - self.centers[class_idx][j];
dist_sq += diff * diff;
center_updates[class_idx][j] += diff;
}
total_loss += 0.5 * dist_sq;
}
for c in 0..self.num_classes {
if class_counts[c] > 0 {
let count = class_counts[c] as f32;
for j in 0..self.feature_dim {
self.centers[c][j] += self.alpha * center_updates[c][j] / count;
}
}
}
let mean_loss = total_loss / batch_size as f32;
Tensor::<B, 1>::from_floats([mean_loss], &device)
}
pub fn get_centers(&self) -> &Vec<Vec<f32>> {
&self.centers
}
pub fn reset_centers(&mut self) {
self.centers = vec![vec![0.0f32; self.feature_dim]; self.num_classes];
}
}
/// Element-wise regression loss applied by `MaskedLossWrapper`.
#[derive(Clone, Copy, Debug, Default)]
pub enum BaseLossType {
    /// Squared error; the default.
    #[default]
    MSE,
    /// Huber loss using the wrapper's `huber_delta`.
    Huber,
    /// Log-cosh loss.
    LogCosh,
}
/// Regression loss that ignores element pairs where the prediction or target is NaN.
#[derive(Debug)]
pub struct MaskedLossWrapper {
    loss_type: BaseLossType,
    huber_delta: f32,
}
impl MaskedLossWrapper {
    /// Creates a wrapper using MSE with a Huber delta of 1.0.
    pub fn new() -> Self {
        Self {
            loss_type: BaseLossType::MSE,
            huber_delta: 1.0,
        }
    }

    /// Selects the element-wise loss.
    #[must_use]
    pub fn with_loss_type(mut self, loss_type: BaseLossType) -> Self {
        self.loss_type = loss_type;
        self
    }

    /// Sets the threshold used by [`BaseLossType::Huber`].
    #[must_use]
    pub fn with_huber_delta(mut self, delta: f32) -> Self {
        self.huber_delta = delta;
        self
    }

    /// Mean of the selected loss over element pairs where neither value is NaN.
    ///
    /// Returns `[0.0]` when there are no elements or no valid pairs. Computed with
    /// tensor ops, so the result stays on the autodiff graph; masked positions
    /// receive zero gradient.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let device = preds.device();
        if preds.shape().num_elements() == 0 {
            return Tensor::<B, 1>::from_floats([0.0], &device);
        }

        // A pair is dropped when either side is NaN.
        let invalid = (preds.clone().is_nan().int() + targets.clone().is_nan().int())
            .greater_elem(0);
        // Clamped to 1 so that an all-NaN input yields 0 instead of 0 / 0.
        let valid_count = invalid.clone().bool_not().float().sum().clamp_min(1.0);

        // Zero both operands at masked positions so NaNs never reach the loss
        // formula or its backward pass.
        let diff = preds.mask_fill(invalid.clone(), 0.0) - targets.mask_fill(invalid.clone(), 0.0);
        let per_element = match self.loss_type {
            BaseLossType::MSE => diff.clone() * diff,
            BaseLossType::Huber => {
                let abs_diff = diff.abs();
                // 0.5 * q^2 + delta * (|d| - q) with q = min(|d|, delta) is the
                // piecewise Huber formula without branching.
                let quadratic = abs_diff.clone().clamp_max(self.huber_delta);
                let linear = abs_diff - quadratic.clone();
                (quadratic.clone() * quadratic).mul_scalar(0.5)
                    + linear.mul_scalar(self.huber_delta)
            }
            BaseLossType::LogCosh => {
                let abs_diff = diff.abs();
                // log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - ln 2 (overflow-free).
                let correction = abs_diff
                    .clone()
                    .mul_scalar(-2.0)
                    .exp()
                    .add_scalar(1.0)
                    .log();
                (abs_diff + correction).sub_scalar(std::f32::consts::LN_2)
            }
        };
        per_element.mask_fill(invalid, 0.0).sum() / valid_count
    }

    /// Fraction of pairs (over the shorter of the two slices) where neither value
    /// is NaN; `0.0` when either slice is empty.
    pub fn get_valid_fraction(preds: &[f32], targets: &[f32]) -> f32 {
        let total = preds.len().min(targets.len());
        if total == 0 {
            return 0.0;
        }
        let valid_count = preds
            .iter()
            .zip(targets)
            .filter(|(&p, &t)| !p.is_nan() && !t.is_nan())
            .count();
        valid_count as f32 / total as f32
    }
}
impl Default for MaskedLossWrapper {
    fn default() -> Self {
        Self::new()
    }
}
/// Cross-entropy plus `lambda` times a [`CenterLoss`] on an embedding.
#[derive(Debug)]
pub struct CenterPlusLoss {
    center_loss: CenterLoss,
    // Weight of the center-loss term.
    lambda: f32,
}
impl CenterPlusLoss {
    /// Creates the combined loss with `lambda = 0.003` and default center settings.
    pub fn new(num_classes: usize, feature_dim: usize) -> Self {
        Self {
            center_loss: CenterLoss::new(num_classes, feature_dim),
            lambda: 0.003,
        }
    }

    /// Sets the weight of the center-loss term.
    #[must_use]
    pub fn with_lambda(mut self, lambda: f32) -> Self {
        self.lambda = lambda;
        self
    }

    /// Sets the center update rate of the inner [`CenterLoss`].
    #[must_use]
    pub fn with_center_alpha(mut self, alpha: f32) -> Self {
        self.center_loss = self.center_loss.with_alpha(alpha);
        self
    }

    /// Returns `CE(logits, targets) + lambda * center_loss(features, targets)`.
    ///
    /// The two terms are combined as tensors rather than copied to the host, so
    /// gradients from both terms reach the model. Also updates the inner centers
    /// (see [`CenterLoss::forward`]).
    pub fn forward<B: Backend>(
        &mut self,
        logits: Tensor<B, 2>,
        features: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        let ce_loss = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits, targets.clone());
        let center_loss = self.center_loss.forward(features, targets);
        ce_loss + center_loss.mul_scalar(self.lambda)
    }

    /// The inner center loss.
    pub fn get_center_loss(&self) -> &CenterLoss {
        &self.center_loss
    }

    /// Current class centers of the inner center loss.
    pub fn get_centers(&self) -> &Vec<Vec<f32>> {
        self.center_loss.get_centers()
    }

    /// Moves every center of the inner center loss back to the origin.
    pub fn reset_centers(&mut self) {
        self.center_loss.reset_centers();
    }
}
/// Tweedie negative log-likelihood (log link) for non-negative targets.
///
/// `preds` are interpreted as `log(mu)`. Targets that are negative, NaN or not
/// greater than the internal epsilon are treated as zero. `power` of exactly 1.0
/// or 2.0 divides by zero in the formulas below.
#[derive(Debug)]
pub struct TweedieLoss {
    /// Tweedie variance power `p`.
    pub power: f32,
    // Lower bound for mu and threshold below which a target counts as zero.
    epsilon: f32,
}
impl TweedieLoss {
    /// Creates a Tweedie loss with the given variance power.
    pub fn new(power: f32) -> Self {
        Self {
            power,
            epsilon: 1e-8,
        }
    }

    /// Mean of `-y * mu^(1-p) / (1-p) + mu^(2-p) / (2-p)` with `mu = exp(preds)`.
    ///
    /// Computed with tensor ops, so the result stays on the autodiff graph.
    pub fn forward<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let one_minus_p = 1.0 - self.power;
        let two_minus_p = 2.0 - self.power;
        // `!(y > eps)` also catches NaN and negative targets.
        let negligible = targets.clone().greater_elem(self.epsilon).bool_not();
        let y = targets.mask_fill(negligible.clone(), 0.0);
        let mu = preds.exp().clamp_min(self.epsilon);

        let term1 = (y * mu.clone().powf_scalar(one_minus_p))
            .div_scalar(one_minus_p)
            .neg()
            .mask_fill(negligible, 0.0);
        let term2 = mu.powf_scalar(two_minus_p).div_scalar(two_minus_p);
        (term1 + term2).mean()
    }

    /// Mean unit deviance
    /// `2 * (y^(2-p) / ((1-p)(2-p)) - y * mu^(1-p) / (1-p) + mu^(2-p) / (2-p))`,
    /// reducing to `2 * mu^(2-p) / (2-p)` where the target counts as zero.
    pub fn deviance<B: Backend>(&self, preds: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
        let p = self.power;
        let negligible = targets.clone().greater_elem(self.epsilon).bool_not();
        let y = targets.mask_fill(negligible.clone(), 0.0);
        let mu = preds.exp().clamp_min(self.epsilon);

        // The y-dependent terms are masked after computing them so that
        // 0^(2-p) with p > 2 cannot leak inf/NaN into zero-target positions.
        let y_term = y
            .clone()
            .powf_scalar(2.0 - p)
            .div_scalar((1.0 - p) * (2.0 - p))
            .mask_fill(negligible.clone(), 0.0);
        let cross_term = (y * mu.clone().powf_scalar(1.0 - p))
            .div_scalar(1.0 - p)
            .mask_fill(negligible, 0.0);
        let mu_term = mu.powf_scalar(2.0 - p).div_scalar(2.0 - p);
        (y_term - cross_term + mu_term).mul_scalar(2.0).mean()
    }
}
impl Default for TweedieLoss {
    /// `power = 1.5`.
    fn default() -> Self {
        Self::new(1.5)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cross_entropy_loss_creation() {
        let _loss = CrossEntropyLoss::new();
    }

    #[test]
    fn test_huber_loss_creation() {
        let loss = HuberLoss::new(0.5);
        assert_eq!(loss.delta, 0.5);
        let default_loss = HuberLoss::default();
        assert_eq!(default_loss.delta, 1.0);
    }

    #[test]
    fn test_focal_loss_creation() {
        let loss = FocalLoss::new(2.0);
        assert_eq!(loss.gamma, 2.0);
        assert!(loss.alpha.is_none());
        let weighted_loss = FocalLoss::new(2.0).with_alpha(vec![0.25, 0.75]);
        assert_eq!(weighted_loss.alpha, Some(vec![0.25, 0.75]));
    }

    #[test]
    fn test_focal_loss_default() {
        let loss = FocalLoss::default();
        assert_eq!(loss.gamma, 2.0);
    }

    #[test]
    fn test_label_smoothing_loss_creation() {
        let loss = LabelSmoothingLoss::new(0.1);
        assert_eq!(loss.smoothing, 0.1);
        let default_loss = LabelSmoothingLoss::default();
        assert_eq!(default_loss.smoothing, 0.1);
    }

    #[test]
    fn test_log_cosh_loss_creation() {
        let _loss = LogCoshLoss::new();
        let _default_loss = LogCoshLoss::default();
    }

    #[test]
    fn test_base_loss_type_default_is_mse() {
        assert!(matches!(BaseLossType::default(), BaseLossType::MSE));
    }

    #[test]
    fn test_masked_valid_fraction_ignores_nan_pairs() {
        let preds = [1.0, f32::NAN, 3.0, 4.0];
        let targets = [1.0, 2.0, f32::NAN, 4.0];
        assert_eq!(MaskedLossWrapper::get_valid_fraction(&preds, &targets), 0.5);
    }

    #[test]
    fn test_masked_valid_fraction_edge_cases() {
        assert_eq!(MaskedLossWrapper::get_valid_fraction(&[], &[]), 0.0);
        // Only the overlapping prefix of the two slices is considered.
        assert_eq!(MaskedLossWrapper::get_valid_fraction(&[1.0, 2.0], &[1.0]), 1.0);
    }

    #[test]
    fn test_center_loss_initial_and_reset_centers() {
        let mut loss = CenterLoss::new(3, 4).with_alpha(0.1);
        assert_eq!(loss.get_centers(), &vec![vec![0.0f32; 4]; 3]);
        loss.reset_centers();
        assert_eq!(loss.get_centers(), &vec![vec![0.0f32; 4]; 3]);
    }

    #[test]
    fn test_center_plus_loss_exposes_centers() {
        let loss = CenterPlusLoss::new(2, 5)
            .with_lambda(0.01)
            .with_center_alpha(0.2);
        assert_eq!(loss.get_centers().len(), 2);
        assert!(loss.get_centers().iter().all(|center| center.len() == 5));
        assert_eq!(loss.get_center_loss().get_centers().len(), 2);
    }

    #[test]
    fn test_tweedie_default_power() {
        assert_eq!(TweedieLoss::default().power, 1.5);
        assert_eq!(TweedieLoss::new(1.2).power, 1.2);
    }
}