use crate::error::{VisionError, VisionResult};
/// Strategy used to combine per-element losses into the returned vector.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Reduction {
    /// Return a single element: the arithmetic mean of the per-element losses.
    Mean,
    /// Return a single element: the sum of the per-element losses.
    Sum,
    /// Return the per-element losses unchanged, one per input element.
    None,
}
/// Numerically stable `ln(sigmoid(x))`.
///
/// Uses the identity `ln(sigmoid(x)) = -softplus(-x)`, with softplus evaluated as
/// `max(z, 0) + ln(1 + e^{-|z|})` so the exponent argument is never positive.
#[inline]
fn log_sigmoid(x: f32) -> f32 {
    let neg_x = -x;
    let tail = (-neg_x.abs()).exp().ln_1p();
    let softplus = neg_x.max(0.0) + tail;
    -softplus
}
/// Numerically stable logistic function `1 / (1 + e^{-x})`.
///
/// Only `e^{-|x|}` is ever evaluated, so the exponential cannot overflow for large `|x|`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x >= 0.0 {
        1.0 / denom
    } else {
        decay / denom
    }
}
/// Sigmoid focal loss for a single binary prediction.
///
/// The loss is `-alpha_t * (1 - p_t)^gamma * ln(p_t)`, where:
///
/// * `p_t = sigmoid(logit)` when `target == 1.0`, and `1 - sigmoid(logit)` when `target == 0.0`;
/// * `alpha_t` is `alpha` for positives and `1 - alpha` for negatives.
///
/// With `gamma == 0` this reduces to `alpha_t`-weighted binary cross-entropy.
///
/// # Errors
///
/// * [`VisionError::NonFinite`] if `logit` is NaN or infinite.
/// * [`VisionError::Internal`] if `target` is not exactly `0.0` or `1.0`, if `alpha` lies
///   outside `[0, 1]` (NaN included), or if `gamma` is negative or NaN.
pub fn binary_focal_loss_one(
    logit: f32,
    target: f32,
    alpha: f32,
    gamma: f32,
) -> VisionResult<f32> {
    if !logit.is_finite() {
        return Err(VisionError::NonFinite("focal logit"));
    }
    if !(target == 0.0 || target == 1.0) {
        return Err(VisionError::Internal(format!(
            "binary focal target must be 0 or 1, got {target}"
        )));
    }
    if !(0.0..=1.0).contains(&alpha) {
        return Err(VisionError::Internal(format!(
            "focal alpha must be in [0,1], got {alpha}"
        )));
    }
    // `gamma < 0.0` alone is false for NaN, which would otherwise produce a NaN loss.
    if gamma.is_nan() || gamma < 0.0 {
        return Err(VisionError::Internal(format!(
            "focal gamma must be >= 0, got {gamma}"
        )));
    }
    // Sign-adjusted logit `z` chosen so that p_t = sigmoid(z) for either target.
    let (z, alpha_t) = if target == 1.0 {
        (logit, alpha)
    } else {
        (-logit, 1.0 - alpha)
    };
    let log_pt = log_sigmoid(z);
    // 1 - p_t == sigmoid(-z) exactly. Evaluating it directly avoids the cancellation in
    // `1.0 - p_t`, which rounds to 0 in f32 as soon as p_t itself rounds to 1.
    let modulating = sigmoid(-z).powf(gamma);
    Ok(-alpha_t * modulating * log_pt)
}
/// Element-wise binary focal loss over paired `logits` and `targets`, followed by a reduction.
///
/// Each pair is scored with [`binary_focal_loss_one`] using the shared `alpha` and `gamma`.
/// The per-element losses are then combined according to `reduction`.
///
/// # Errors
///
/// * [`VisionError::EmptyInput`] if `logits` is empty. This is checked before the length check.
/// * [`VisionError::ShapeMismatch`] if `logits` and `targets` have different lengths.
/// * The first error reported by [`binary_focal_loss_one`] for any element, in input order.
pub fn binary_focal_loss(
    logits: &[f32],
    targets: &[f32],
    alpha: f32,
    gamma: f32,
    reduction: Reduction,
) -> VisionResult<Vec<f32>> {
    if logits.is_empty() {
        return Err(VisionError::EmptyInput("binary_focal_loss logits"));
    }
    if logits.len() != targets.len() {
        return Err(VisionError::ShapeMismatch {
            lhs: vec![logits.len()],
            rhs: vec![targets.len()],
        });
    }
    // Collecting into a Result stops at the first failing element, like a `?` inside a loop.
    let per_elem = logits
        .iter()
        .zip(targets)
        .map(|(&logit, &target)| binary_focal_loss_one(logit, target, alpha, gamma))
        .collect::<VisionResult<Vec<f32>>>()?;
    Ok(reduce(per_elem, reduction))
}
/// Softmax focal loss for a single row of class logits.
///
/// The loss is `-alpha * (1 - p_t)^gamma * ln(p_t)`, where `p_t` is the softmax probability
/// of class `target`. `alpha` is one scalar weight applied regardless of class. Unlike the
/// binary variant, it is only required to be finite and is not restricted to `[0, 1]`.
/// With `gamma == 0` this is `alpha`-scaled softmax cross-entropy.
///
/// # Errors
///
/// * [`VisionError::EmptyInput`] if `logits` is empty.
/// * [`VisionError::InvalidNumClasses`] (carrying `logits.len()`) if `target` is out of range.
/// * [`VisionError::Internal`] if `gamma` is negative or NaN, or if `alpha` is not finite.
/// * [`VisionError::NonFinite`] if any logit is NaN or infinite.
pub fn multiclass_focal_loss_one(
    logits: &[f32],
    target: usize,
    alpha: f32,
    gamma: f32,
) -> VisionResult<f32> {
    if logits.is_empty() {
        return Err(VisionError::EmptyInput("multiclass_focal_loss logits"));
    }
    if target >= logits.len() {
        return Err(VisionError::InvalidNumClasses(logits.len()));
    }
    // `gamma < 0.0` alone is false for NaN, which would otherwise produce a NaN loss.
    if gamma.is_nan() || gamma < 0.0 {
        return Err(VisionError::Internal(format!(
            "focal gamma must be >= 0, got {gamma}"
        )));
    }
    // A NaN/inf weight would silently turn every loss into NaN/inf.
    if !alpha.is_finite() {
        return Err(VisionError::Internal(format!(
            "focal alpha must be finite, got {alpha}"
        )));
    }
    if logits.iter().any(|l| !l.is_finite()) {
        return Err(VisionError::NonFinite("focal logits"));
    }
    // Log-sum-exp with max subtraction so `exp` never overflows.
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&l| (l - max_logit).exp()).sum();
    let log_pt = logits[target] - (max_logit + sum_exp.ln());
    // 1 - p_t == -expm1(ln p_t). This stays accurate when p_t is close to 1, where
    // `1.0 - exp(log_pt)` cancels. The clamp absorbs rounding that nudges log_pt above 0.
    let one_minus_pt = (-log_pt.exp_m1()).clamp(0.0, 1.0);
    let modulating = one_minus_pt.powf(gamma);
    Ok(-alpha * modulating * log_pt)
}
/// Batched softmax focal loss, followed by a reduction.
///
/// `logits` is read row-major as `targets.len()` rows of `num_classes` logits each. Row `i`
/// is scored against `targets[i]` with [`multiclass_focal_loss_one`], and the per-row
/// losses are combined according to `reduction`.
///
/// # Errors
///
/// * [`VisionError::InvalidNumClasses`] (carrying `0`) if `num_classes == 0`.
/// * [`VisionError::EmptyInput`] if `targets` is empty.
/// * [`VisionError::ShapeMismatch`] if `logits.len() != targets.len() * num_classes`,
///   including the case where that product would overflow `usize`.
/// * The first error reported by [`multiclass_focal_loss_one`] for any row, in row order.
pub fn multiclass_focal_loss(
    logits: &[f32],
    targets: &[usize],
    num_classes: usize,
    alpha: f32,
    gamma: f32,
    reduction: Reduction,
) -> VisionResult<Vec<f32>> {
    if num_classes == 0 {
        return Err(VisionError::InvalidNumClasses(0));
    }
    if targets.is_empty() {
        return Err(VisionError::EmptyInput("multiclass_focal_loss targets"));
    }
    let n = targets.len();
    // `checked_mul` matters here. A plain multiply panics in debug builds. In release it
    // wraps and can spuriously match `logits.len()`, which then panics when slicing rows.
    if n.checked_mul(num_classes) != Some(logits.len()) {
        return Err(VisionError::ShapeMismatch {
            lhs: vec![logits.len()],
            rhs: vec![n, num_classes],
        });
    }
    // The length check guarantees exactly `n` full rows, so `chunks_exact` pairs 1:1 with targets.
    let per_elem = logits
        .chunks_exact(num_classes)
        .zip(targets)
        .map(|(row, &target)| multiclass_focal_loss_one(row, target, alpha, gamma))
        .collect::<VisionResult<Vec<f32>>>()?;
    Ok(reduce(per_elem, reduction))
}
/// Combines per-element losses according to `reduction`.
///
/// `Reduction::None` hands the input back unchanged. `Sum` and `Mean` each produce a
/// one-element vector. `Mean` divides by `max(len, 1)`, so an empty input never divides by zero.
fn reduce(per_elem: Vec<f32>, reduction: Reduction) -> Vec<f32> {
    match reduction {
        Reduction::None => per_elem,
        Reduction::Sum | Reduction::Mean => {
            let total: f32 = per_elem.iter().sum();
            let value = if reduction == Reduction::Mean {
                let count = per_elem.len().max(1) as f32;
                total / count
            } else {
                total
            };
            vec![value]
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by several comparisons below.
    const TOL: f32 = 1e-5;

    #[test]
    fn gamma_zero_recovers_weighted_cross_entropy() {
        let (logit, alpha) = (0.7f32, 0.25);
        // With gamma = 0 the focal term is 1, leaving alpha-weighted BCE.
        let bce = -alpha * sigmoid(logit).ln();
        let fl = binary_focal_loss_one(logit, 1.0, alpha, 0.0).expect("fl");
        assert!((fl - bce).abs() < 1e-4, "fl={fl} bce={bce}");
    }

    #[test]
    fn perfect_positive_prediction_near_zero_loss() {
        let fl = binary_focal_loss_one(20.0, 1.0, 0.5, 2.0).expect("fl");
        assert!(fl < 1e-6, "fl={fl}");
    }

    #[test]
    fn confident_wrong_prediction_has_large_loss() {
        let fl = binary_focal_loss_one(10.0, 0.0, 0.5, 2.0).expect("fl");
        assert!(fl > 1.0, "fl={fl}");
    }

    #[test]
    fn focal_downweights_easy_examples_relative_to_ce() {
        let logit = 3.0f32;
        let p = sigmoid(logit);
        let ce = -p.ln();
        let fl = binary_focal_loss_one(logit, 1.0, 1.0, 2.0).expect("fl");
        // With alpha = 1 the focal/CE ratio should be exactly the modulating factor.
        let ratio = fl / ce;
        let expected = (1.0 - p).powi(2);
        assert!(
            (ratio - expected).abs() < 1e-4,
            "ratio={ratio} expected={expected}"
        );
        assert!(ratio < 0.01, "easy example not downweighted: {ratio}");
    }

    #[test]
    fn binary_focal_loss_is_nonnegative() {
        let logits = [-2.0, 0.0, 1.5, 4.0];
        let targets = [0.0, 1.0, 0.0, 1.0];
        let per = binary_focal_loss(&logits, &targets, 0.25, 2.0, Reduction::None).expect("ok");
        for &v in &per {
            assert!(v >= -TOL, "negative loss {v}");
        }
    }

    #[test]
    fn binary_mean_matches_manual_average() {
        let logits = [-1.0, 0.5, 2.0];
        let targets = [0.0, 1.0, 1.0];
        let run = |reduction| binary_focal_loss(&logits, &targets, 0.3, 1.5, reduction);
        let mean = run(Reduction::Mean).expect("mean");
        let per = run(Reduction::None).expect("per");
        let manual = per.iter().sum::<f32>() / per.len() as f32;
        assert!((mean[0] - manual).abs() < TOL);
    }

    #[test]
    fn binary_shape_mismatch_errors() {
        let result = binary_focal_loss(&[0.0, 1.0], &[1.0], 0.5, 2.0, Reduction::Sum);
        assert!(result.is_err());
    }

    #[test]
    fn binary_invalid_target_errors() {
        assert!(binary_focal_loss_one(0.0, 0.5, 0.5, 2.0).is_err());
    }

    #[test]
    fn binary_invalid_alpha_gamma_errors() {
        let bad_alpha = binary_focal_loss_one(0.0, 1.0, 1.5, 2.0);
        let bad_gamma = binary_focal_loss_one(0.0, 1.0, 0.5, -1.0);
        assert!(bad_alpha.is_err());
        assert!(bad_gamma.is_err());
    }

    #[test]
    fn multiclass_correct_class_low_loss() {
        let fl = multiclass_focal_loss_one(&[10.0, 0.0, 0.0], 0, 1.0, 2.0).expect("fl");
        assert!(fl < 1e-3, "fl={fl}");
    }

    #[test]
    fn multiclass_gamma_zero_is_cross_entropy() {
        let logits = [1.0f32, 2.0, 0.5];
        let target = 1;
        // Reference log-sum-exp, shifted by the known maximum logit.
        let m = 2.0f32;
        let denom = logits.iter().map(|&l| (l - m).exp()).sum::<f32>().ln() + m;
        let ce = -(logits[target] - denom);
        let fl = multiclass_focal_loss_one(&logits, target, 1.0, 0.0).expect("fl");
        assert!((fl - ce).abs() < 1e-5, "fl={fl} ce={ce}");
    }

    #[test]
    fn multiclass_batched_sum_and_shape() {
        // Two rows of three class logits each.
        let logits = [1.0, 0.0, 0.0, 0.0, 3.0, 0.0];
        let targets = [0usize, 1];
        let s = multiclass_focal_loss(&logits, &targets, 3, 1.0, 2.0, Reduction::Sum).expect("sum");
        assert_eq!(s.len(), 1);
        assert!(s[0] >= 0.0);
    }

    #[test]
    fn multiclass_target_out_of_range_errors() {
        assert!(multiclass_focal_loss_one(&[1.0, 2.0], 5, 1.0, 2.0).is_err());
    }

    #[test]
    fn multiclass_shape_mismatch_errors() {
        // Three logits cannot form two rows of two classes.
        let logits = [1.0, 2.0, 3.0];
        let targets = [0usize, 1];
        assert!(multiclass_focal_loss(&logits, &targets, 2, 1.0, 2.0, Reduction::Mean).is_err());
    }

    #[test]
    fn nonfinite_logit_errors() {
        assert!(binary_focal_loss_one(f32::INFINITY, 1.0, 0.5, 2.0).is_err());
        assert!(multiclass_focal_loss_one(&[1.0, f32::NAN], 0, 1.0, 2.0).is_err());
    }
}