use crate::Tensor;
use ndarray::Array1;
use super::LossFn;
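/// Binary cross-entropy loss that operates on raw logits rather than
/// probabilities. Fusing the sigmoid into the loss permits the numerically
/// stable per-element form `max(z, 0) - z*t + ln(1 + e^(-|z|))`, which stays
/// finite for arbitrarily large |z| (see `stable_bce`).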
pub struct BCEWithLogitsLoss;
impl BCEWithLogitsLoss {
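    /// Numerically stable sigmoid. For x >= 0 it evaluates 1 / (1 + e^(-x));
    /// for x < 0 it uses the algebraically equivalent e^x / (1 + e^x). Either
    /// way `exp` only ever sees a non-positive argument, so it cannot overflow.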
pub(crate) fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
x.mapv(|v| {
if v >= 0.0 {
let exp_neg = (-v).exp();
1.0 / (1.0 + exp_neg)
} else {
let exp_v = v.exp();
exp_v / (1.0 + exp_v)
}
})
}
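    /// Single-element BCE-with-logits via the standard stable rewrite:
    /// -[t*ln(sigmoid(z)) + (1-t)*ln(1 - sigmoid(z))]
    ///   = max(z, 0) - z*t + ln(1 + e^(-|z|)).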
fn stable_bce(logit: f32, target: f32) -> f32 {
let relu = logit.max(0.0);
let abs_x = logit.abs();
relu - logit * target + (1.0 + (-abs_x).exp()).ln()
}
}
impl LossFn for BCEWithLogitsLoss {
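    /// Mean BCE-with-logits over the batch. When `predictions` requires a
    /// gradient, attaches a backward op that accumulates
    /// (sigmoid(z_i) - t_i) / n into the predictions' gradient.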
fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
assert_eq!(
predictions.len(),
targets.len(),
"Predictions and targets must have same length"
);
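        // Average the numerically stable per-element losses over the batch.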
let total_loss: f32 = predictions
.data()
.iter()
.zip(targets.data().iter())
.map(|(&logit, &target)| Self::stable_bce(logit, target))
.sum::<f32>()
/ predictions.len() as f32;
        let mut loss = Tensor::from_vec(vec![total_loss], predictions.requires_grad());
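        // d(loss)/d(z_i) for the mean reduction: (sigmoid(z_i) - t_i) / n.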
let sigmoid_vals = Self::sigmoid(predictions.data());
let n = predictions.len() as f32;
let grad = (&sigmoid_vals - targets.data()) / n;
use crate::autograd::BackwardOp;
use std::rc::Rc;
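        // Local backward op: captures the precomputed gradient and, on
        // backward(), accumulates it into the predictions' gradient cell.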
struct BCEBackward {
pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
grad: Array1<f32>,
}
impl BackwardOp for BCEBackward {
fn backward(&self) {
let mut pred_grad = self.pred_grad_cell.borrow_mut();
if let Some(existing) = pred_grad.as_mut() {
*existing = &*existing + &self.grad;
} else {
*pred_grad = Some(self.grad.clone());
}
}
}
if predictions.requires_grad() {
loss.set_backward_op(Rc::new(BCEBackward {
pred_grad_cell: predictions.grad_cell(),
grad,
}));
}
loss
}
fn name(&self) -> &'static str {
"BCEWithLogits"
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)]
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_bce_with_logits_loss_basic() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0] > 0.0);
assert!(loss.data()[0].is_finite());
}
#[test]
fn test_sigmoid_basic() {
let x = Array1::from(vec![0.0, 100.0, -100.0]);
let s = BCEWithLogitsLoss::sigmoid(&x);
assert_relative_eq!(s[0], 0.5, epsilon = 1e-5);
assert_relative_eq!(s[1], 1.0, epsilon = 1e-5);
assert_relative_eq!(s[2], 0.0, epsilon = 1e-5);
}
#[test]
fn test_sigmoid_symmetry() {
let x = Array1::from(vec![1.0, 2.0, -3.0, 0.5]);
let neg_x = x.mapv(|v| -v);
let s_x = BCEWithLogitsLoss::sigmoid(&x);
let s_neg_x = BCEWithLogitsLoss::sigmoid(&neg_x);
for i in 0..x.len() {
assert_relative_eq!(s_x[i] + s_neg_x[i], 1.0, epsilon = 1e-6);
}
}
#[test]
fn test_bce_perfect_prediction() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![100.0, -100.0, 100.0, -100.0, 100.0], true);
let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0] < 0.01, "Perfect prediction should have near-zero loss");
}
#[test]
fn test_bce_wrong_prediction() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![-100.0, 100.0, -100.0], true);
let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0] > 10.0, "Wrong prediction should have high loss");
}
#[test]
fn test_bce_gradient_direction() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);
let loss = loss_fn.forward(&logits, &targets);
if let Some(backward_op) = loss.backward_op() {
backward_op.backward();
}
let grad = logits.grad().expect("gradient should be available");
assert!(grad[0] < 0.0, "grad[0] should be negative (target=1, logit=2.0)");
assert!(grad[1] > 0.0, "grad[1] should be positive (target=0, logit=-1.0)");
for g in &grad {
assert!(g.is_finite());
}
}
#[test]
fn test_bce_gradient_at_zero() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![0.0], true);
let targets = Tensor::from_vec(vec![1.0], false);
let loss = loss_fn.forward(&logits, &targets);
if let Some(op) = loss.backward_op() {
op.backward();
}
let grad = logits.grad().expect("gradient should be available");
assert_relative_eq!(grad[0], -0.5, epsilon = 1e-5);
}
#[test]
fn test_bce_all_zeros_target() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![0.0; 5], true);
let targets = Tensor::from_vec(vec![0.0; 5], false);
let loss = loss_fn.forward(&logits, &targets);
assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
}
#[test]
fn test_bce_all_ones_target() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![0.0; 5], true);
let targets = Tensor::from_vec(vec![1.0; 5], false);
let loss = loss_fn.forward(&logits, &targets);
assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
}
#[test]
fn test_bce_numerical_stability_large_positive() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![1000.0, 500.0, 100.0], true);
let targets = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0].is_finite(), "Must be stable for large positive logits");
assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
}
#[test]
fn test_bce_numerical_stability_large_negative() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![-1000.0, -500.0, -100.0], true);
let targets = Tensor::from_vec(vec![0.0, 0.0, 0.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0].is_finite(), "Must be stable for large negative logits");
assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
}
#[test]
#[should_panic(expected = "must have same length")]
fn test_bce_mismatched_lengths() {
let loss_fn = BCEWithLogitsLoss;
let pred = Tensor::from_vec(vec![1.0, 2.0], true);
let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
loss_fn.forward(&pred, &target);
}
#[test]
fn test_bce_no_grad() {
let loss_fn = BCEWithLogitsLoss;
let pred = Tensor::from_vec(vec![2.0, -1.0], false);
let target = Tensor::from_vec(vec![1.0, 0.0], false);
let loss = loss_fn.forward(&pred, &target);
assert!(loss.data()[0] > 0.0);
}
#[test]
fn test_bce_gradient_accumulation() {
let logits = Tensor::from_vec(vec![1.0, -1.0], true);
let targets = Tensor::from_vec(vec![1.0, 0.0], false);
let loss1 = BCEWithLogitsLoss.forward(&logits, &targets);
if let Some(op) = loss1.backward_op() {
op.backward();
}
let loss2 = BCEWithLogitsLoss.forward(&logits, &targets);
if let Some(op) = loss2.backward_op() {
op.backward();
}
        let grad = logits.grad().expect("gradient should be available");
        // Each backward pass adds (sigmoid(z_i) - t_i) / n, so two passes give
        // grad[0] = 2 * (sigmoid(1.0) - 1.0) / 2 ≈ -0.2689 and
        // grad[1] = 2 * (sigmoid(-1.0) - 0.0) / 2 ≈ 0.2689.
        assert_relative_eq!(grad[0], -0.268_94, epsilon = 1e-4);
        assert_relative_eq!(grad[1], 0.268_94, epsilon = 1e-4);
}
#[test]
fn test_bce_name() {
assert_eq!(BCEWithLogitsLoss.name(), "BCEWithLogits");
}
#[test]
fn test_stable_bce_formula() {
let logit = 1.5f32;
let target = 0.7f32;
let stable = BCEWithLogitsLoss::stable_bce(logit, target);
let sigma = 1.0 / (1.0 + (-logit).exp());
let naive = -(target * sigma.ln() + (1.0 - target) * (1.0 - sigma).ln());
assert_relative_eq!(stable, naive, epsilon = 1e-5);
}
#[test]
fn test_multi_label_scenario() {
let loss_fn = BCEWithLogitsLoss;
let logits = Tensor::from_vec(vec![-2.0, 3.0, 4.0, -1.0, -3.0], true);
let targets = Tensor::from_vec(vec![0.0, 1.0, 1.0, 0.0, 0.0], false);
let loss = loss_fn.forward(&logits, &targets);
assert!(loss.data()[0].is_finite());
assert!(loss.data()[0] > 0.0);
}
}