use crate::error::{NeuralError, Result};
use crate::losses::Loss;
use scirs2_core::ndarray::Array;
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
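
/// Focal loss for classification with imbalanced classes, after Lin et al.,
/// "Focal Loss for Dense Object Detection" (ICCV 2017).
///
/// For a true-class probability `p`, the per-example loss is
/// `FL(p) = -alpha * (1 - p)^gamma * ln(p)`, so confident correct
/// predictions are down-weighted by the factor `(1 - p)^gamma`. With the
/// defaults `gamma = 2.0` and `alpha = 0.25`, a prediction of `p = 0.9`
/// contributes about `2.6e-4` to the loss, while `p = 0.5` contributes
/// about `4.3e-2`, roughly 160x more.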
#[derive(Debug, Clone)]
pub struct FocalLoss {
    /// Focusing parameter; larger values down-weight well-classified examples more strongly.
    gamma: f64,
    /// Optional scalar class-balancing weight applied to every positive target.
    alpha: Option<f64>,
    /// Optional per-class balancing weights; mutually exclusive with `alpha`.
    alpha_per_class: Option<Vec<f64>>,
    /// Small constant used to clamp probabilities away from 0 and 1 for numerical stability.
    epsilon: f64,
}

impl FocalLoss {
    /// Creates a focal loss with focusing parameter `gamma`, an optional
    /// scalar class weight `alpha`, and clamping constant `epsilon`.
    pub fn new(gamma: f64, alpha: Option<f64>, epsilon: f64) -> Self {
        Self {
            gamma,
            alpha,
            alpha_per_class: None,
            epsilon,
        }
    }

    /// Creates a focal loss that applies a separate balancing weight to each
    /// class. The length of `alpha_per_class` must match the number of classes.
    pub fn with_class_weights(gamma: f64, alpha_per_class: Vec<f64>, epsilon: f64) -> Self {
        Self {
            gamma,
            alpha: None,
            alpha_per_class: Some(alpha_per_class),
            epsilon,
        }
    }
}

impl Default for FocalLoss {
    /// Defaults to `gamma = 2.0` and `alpha = 0.25`, the values recommended
    /// in the original paper.
    fn default() -> Self {
        Self::new(2.0, Some(0.25), 1e-10)
    }
}

impl<F: Float + Debug + NumAssign> Loss<F> for FocalLoss {
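    /// Computes the mean focal loss over the batch. `predictions` are expected
    /// to be probabilities (e.g. softmax outputs) and `targets` one-hot or
    /// soft labels, either 1-D (a single sample) or 2-D (`batch x classes`).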
    fn forward(
        &self,
        predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
        targets: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<F> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Shape mismatch in FocalLoss: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }
        let gamma = F::from(self.gamma).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert gamma to float".to_string())
        })?;
        let epsilon = F::from(self.epsilon).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert epsilon to float".to_string())
        })?;
        // Batch size used for mean reduction; a 1-D input is a single sample.
        let n = F::from(if predictions.ndim() > 1 {
            predictions.shape()[0]
        } else {
            1
        })
        .ok_or_else(|| {
            NeuralError::InferenceError("Could not convert batch size to float".to_string())
        })?;
        // Per-class weights must cover every class in the class dimension.
        if let Some(ref alpha_per_class) = self.alpha_per_class {
            let num_classes = if predictions.ndim() > 1 {
                predictions.shape()[1]
            } else {
                predictions.len()
            };
            if alpha_per_class.len() != num_classes {
                return Err(NeuralError::InferenceError(format!(
                    "Number of alpha values ({}) does not match number of classes ({})",
                    alpha_per_class.len(),
                    num_classes
                )));
            }
        }
        let mut loss = F::zero();
        if predictions.ndim() > 1 {
            for i in 0..predictions.shape()[0] {
                let mut sample_loss = F::zero();
                for j in 0..predictions.shape()[1] {
                    // Clamp to (epsilon, 1 - epsilon) so ln() stays finite.
                    let y_pred = predictions[[i, j]].max(epsilon).min(F::one() - epsilon);
                    let y_true = targets[[i, j]];
                    if y_true > F::zero() {
                        let alpha = if let Some(ref alpha_per_class) = self.alpha_per_class {
                            F::from(alpha_per_class[j]).ok_or_else(|| {
                                NeuralError::InferenceError(
                                    "Could not convert alpha to float".to_string(),
                                )
                            })?
                        } else if let Some(alpha) = self.alpha {
                            F::from(alpha).ok_or_else(|| {
                                NeuralError::InferenceError(
                                    "Could not convert alpha to float".to_string(),
                                )
                            })?
                        } else {
                            F::one()
                        };
                        // (1 - p_t)^gamma down-weights well-classified examples.
                        let p_t = y_pred;
                        let focal_weight = (F::one() - p_t).powf(gamma);
                        sample_loss -= alpha * focal_weight * y_true * p_t.ln();
                    }
                }
                loss += sample_loss;
            }
            // Mean over the batch.
            loss /= n;
        } else {
            for j in 0..predictions.len() {
                let p = predictions[j];
                let t = targets[j];
                if t > F::zero() {
                    // Use the same fallible conversion as the batched path
                    // rather than silently falling back to 1.0.
                    let alpha = if let Some(ref alpha_per_class) = self.alpha_per_class {
                        F::from(alpha_per_class[j]).ok_or_else(|| {
                            NeuralError::InferenceError(
                                "Could not convert alpha to float".to_string(),
                            )
                        })?
                    } else if let Some(a) = self.alpha {
                        F::from(a).ok_or_else(|| {
                            NeuralError::InferenceError(
                                "Could not convert alpha to float".to_string(),
                            )
                        })?
                    } else {
                        F::one()
                    };
                    let p_safe = p.max(epsilon).min(F::one() - epsilon);
                    let focal_weight = (F::one() - p_safe).powf(gamma);
                    loss -= alpha * focal_weight * t * p_safe.ln();
                }
            }
        }
        Ok(loss)
    }
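
    /// Computes the gradient of the mean focal loss with respect to
    /// `predictions`. For a positive target `t` at probability `p`, the
    /// per-element loss is `-alpha * (1 - p)^gamma * t * ln(p)`, so
    ///
    /// `d/dp = alpha * gamma * (1 - p)^(gamma - 1) * t * ln(p)
    ///         - alpha * (1 - p)^gamma * t / p`,
    ///
    /// which is negative for `p < 1`: raising the true-class probability
    /// always lowers the loss.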
    fn backward(
        &self,
        predictions: &Array<F, scirs2_core::ndarray::IxDyn>,
        targets: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        if predictions.shape() != targets.shape() {
            return Err(NeuralError::InferenceError(format!(
                "Shape mismatch in FocalLoss gradient: predictions {:?} vs targets {:?}",
                predictions.shape(),
                targets.shape()
            )));
        }
        let gamma = F::from(self.gamma).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert gamma to float".to_string())
        })?;
        let epsilon = F::from(self.epsilon).ok_or_else(|| {
            NeuralError::InferenceError("Could not convert epsilon to float".to_string())
        })?;
        let n = F::from(if predictions.ndim() > 1 {
            predictions.shape()[0]
        } else {
            1
        })
        .ok_or_else(|| {
            NeuralError::InferenceError("Could not convert batch size to float".to_string())
        })?;
        // Mirror the forward-pass validation so indexing below cannot panic.
        if let Some(ref alpha_per_class) = self.alpha_per_class {
            let num_classes = if predictions.ndim() > 1 {
                predictions.shape()[1]
            } else {
                predictions.len()
            };
            if alpha_per_class.len() != num_classes {
                return Err(NeuralError::InferenceError(format!(
                    "Number of alpha values ({}) does not match number of classes ({})",
                    alpha_per_class.len(),
                    num_classes
                )));
            }
        }
        let mut gradients = Array::zeros(predictions.raw_dim());
        if predictions.ndim() == 1 {
            for idx in 0..predictions.len() {
                let p = predictions[idx];
                let t = targets[idx];
                if t > F::zero() {
                    let p_safe = p.max(epsilon).min(F::one() - epsilon);
                    let alpha = if let Some(ref alpha_per_class) = self.alpha_per_class {
                        F::from(alpha_per_class[idx]).ok_or_else(|| {
                            NeuralError::InferenceError(
                                "Could not convert alpha to float".to_string(),
                            )
                        })?
                    } else if let Some(a) = self.alpha {
                        F::from(a).ok_or_else(|| {
                            NeuralError::InferenceError(
                                "Could not convert alpha to float".to_string(),
                            )
                        })?
                    } else {
                        F::one()
                    };
                    // Chain rule on -alpha * (1 - p)^gamma * t * ln(p);
                    // see the derivation in the doc comment above.
                    let term1 =
                        alpha * gamma * (F::one() - p_safe).powf(gamma - F::one()) * p_safe.ln();
                    let term2 = -alpha * (F::one() - p_safe).powf(gamma) * (F::one() / p_safe);
                    gradients[idx] = (term1 + term2) * t / n;
                } else {
                    gradients[idx] = F::zero();
                }
            }
        } else {
            let batch_size = predictions.shape()[0];
            for i in 0..batch_size {
                for j in 0..predictions.shape()[1] {
                    let p = predictions[[i, j]];
                    let t = targets[[i, j]];
                    if t > F::zero() {
                        let p_safe = p.max(epsilon).min(F::one() - epsilon);
                        let alpha = if let Some(ref alpha_per_class) = self.alpha_per_class {
                            F::from(alpha_per_class[j]).ok_or_else(|| {
                                NeuralError::InferenceError(
                                    "Could not convert alpha to float".to_string(),
                                )
                            })?
                        } else if let Some(a) = self.alpha {
                            F::from(a).ok_or_else(|| {
                                NeuralError::InferenceError(
                                    "Could not convert alpha to float".to_string(),
                                )
                            })?
                        } else {
                            F::one()
                        };
                        // Chain rule on -alpha * (1 - p)^gamma * t * ln(p);
                        // see the derivation in the doc comment above.
                        let term1 = alpha
                            * gamma
                            * (F::one() - p_safe).powf(gamma - F::one())
                            * p_safe.ln();
                        let term2 =
                            -alpha * (F::one() - p_safe).powf(gamma) * (F::one() / p_safe);
                        gradients[[i, j]] = (term1 + term2) * t / n;
                    } else {
                        gradients[[i, j]] = F::zero();
                    }
                }
            }
        }
        Ok(gradients)
    }
}
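
// A minimal sanity-check sketch, not part of the original file. It assumes
// `scirs2_core::ndarray` re-exports ndarray's `IxDyn` and
// `Array::from_shape_vec`; the expected value below is recomputed from the
// focal-loss formula rather than hardcoded.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::IxDyn;

    #[test]
    fn forward_matches_hand_computed_value() {
        let loss = FocalLoss::default(); // gamma = 2.0, alpha = 0.25
        let predictions =
            Array::from_shape_vec(IxDyn(&[2, 2]), vec![0.9f64, 0.1, 0.2, 0.8]).unwrap();
        let targets =
            Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0f64, 0.0, 0.0, 1.0]).unwrap();
        // Mean over the batch of -alpha * (1 - p)^gamma * ln(p) at the true class.
        let expected =
            (-0.25 * 0.1f64.powi(2) * 0.9f64.ln() + -0.25 * 0.2f64.powi(2) * 0.8f64.ln()) / 2.0;
        let value = loss.forward(&predictions, &targets).unwrap();
        assert!((value - expected).abs() < 1e-12);
    }

    #[test]
    fn backward_is_negative_at_true_classes() {
        let loss = FocalLoss::default();
        let predictions =
            Array::from_shape_vec(IxDyn(&[2, 2]), vec![0.9f64, 0.1, 0.2, 0.8]).unwrap();
        let targets =
            Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0f64, 0.0, 0.0, 1.0]).unwrap();
        let grads = loss.backward(&predictions, &targets).unwrap();
        // Raising the true-class probability must lower the loss.
        assert!(grads[[0, 0]] < 0.0);
        assert!(grads[[1, 1]] < 0.0);
        // Entries with zero targets contribute no gradient.
        assert!(grads[[0, 1]] == 0.0);
    }
}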