Skip to main content

entrenar/train/loss/
bce_with_logits.rs

//! Binary Cross-Entropy with Logits Loss for multi-label classification
//!
//! Combines a sigmoid activation with binary cross-entropy loss.
//! Each output is treated as an independent binary classification,
//! allowing multiple labels to be active simultaneously.
//!
//! # Formula
//!
//! Numerically stable computation:
//! ```text
//! L_i = max(x_i, 0) - x_i * t_i + log(1 + exp(-|x_i|))
//! L = mean(L_i) over all i
//! ```
//!
//! Gradient of the mean loss over `N` elements: `∂L/∂x_i = (σ(x_i) - t_i) / N`
//!
//! # Multi-label vs single-label
//!
//! - **CrossEntropyLoss**: softmax → mutual exclusion (single label)
//! - **BCEWithLogitsLoss**: sigmoid → independent per-class (multi-label)

22use crate::Tensor;
23use ndarray::Array1;
24
25use super::LossFn;
26
/// Binary Cross-Entropy with Logits Loss.
///
/// For multi-label classification where each class is an independent binary decision.
/// Targets are multi-hot vectors (e.g., `[1.0, 0.0, 1.0, 0.0, 1.0]` for classes 0, 2, 4).
///
/// The loss takes raw logits (no sigmoid applied beforehand). It is a stateless unit
/// struct, so it is freely `Copy` and can be constructed with `BCEWithLogitsLoss` or
/// `BCEWithLogitsLoss::default()`.
///
/// # Example
///
/// ```
/// use entrenar::train::{BCEWithLogitsLoss, LossFn};
/// use entrenar::Tensor;
///
/// let loss_fn = BCEWithLogitsLoss;
/// let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
/// let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false); // multi-hot
///
/// let loss = loss_fn.forward(&logits, &targets);
/// assert!(loss.data()[0] > 0.0);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct BCEWithLogitsLoss;

47impl BCEWithLogitsLoss {
48    /// Compute element-wise sigmoid: σ(x) = 1 / (1 + exp(-x))
49    pub(crate) fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
50        contract_pre_sigmoid!();
51        let result = x.mapv(|v| {
52            // Numerically stable sigmoid
53            if v >= 0.0 {
54                let exp_neg = (-v).exp();
55                1.0 / (1.0 + exp_neg)
56            } else {
57                let exp_v = v.exp();
58                exp_v / (1.0 + exp_v)
59            }
60        });
61        contract_post_silu!(result);
62        result
63    }
64
65    /// Numerically stable BCE: max(x, 0) - x*t + log(1 + exp(-|x|))
66    fn stable_bce(logit: f32, target: f32) -> f32 {
67        let relu = logit.max(0.0);
68        let abs_x = logit.abs();
69        relu - logit * target + (1.0 + (-abs_x).exp()).ln()
70    }
71}
72
73impl LossFn for BCEWithLogitsLoss {
74    fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
75        assert_eq!(
76            predictions.len(),
77            targets.len(),
78            "Predictions and targets must have same length"
79        );
80
81        // Compute per-element BCE loss
82        let total_loss: f32 = predictions
83            .data()
84            .iter()
85            .zip(targets.data().iter())
86            .map(|(&logit, &target)| Self::stable_bce(logit, target))
87            .sum::<f32>()
88            / predictions.len() as f32;
89
90        let mut loss = Tensor::from_vec(vec![total_loss], true);
91
92        // Gradient: ∂L/∂x_i = (σ(x_i) - t_i) / N
93        let sigmoid_vals = Self::sigmoid(predictions.data());
94        let n = predictions.len() as f32;
95        let grad = (&sigmoid_vals - targets.data()) / n;
96
97        use crate::autograd::BackwardOp;
98        use std::rc::Rc;
99
100        struct BCEBackward {
101            pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
102            grad: Array1<f32>,
103        }
104
105        impl BackwardOp for BCEBackward {
106            fn backward(&self) {
107                let mut pred_grad = self.pred_grad_cell.borrow_mut();
108                if let Some(existing) = pred_grad.as_mut() {
109                    *existing = &*existing + &self.grad;
110                } else {
111                    *pred_grad = Some(self.grad.clone());
112                }
113            }
114        }
115
116        if predictions.requires_grad() {
117            loss.set_backward_op(Rc::new(BCEBackward {
118                pred_grad_cell: predictions.grad_cell(),
119                grad,
120            }));
121        }
122
123        loss
124    }
125
126    fn name(&self) -> &'static str {
127        "BCEWithLogits"
128    }
129}
130
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_bce_with_logits_loss_basic() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    #[test]
    fn test_sigmoid_basic() {
        let x = Array1::from(vec![0.0, 100.0, -100.0]);
        let s = BCEWithLogitsLoss::sigmoid(&x);

        assert_relative_eq!(s[0], 0.5, epsilon = 1e-5);
        assert_relative_eq!(s[1], 1.0, epsilon = 1e-5);
        assert_relative_eq!(s[2], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sigmoid_symmetry() {
        // σ(x) + σ(-x) = 1
        let x = Array1::from(vec![1.0, 2.0, -3.0, 0.5]);
        let neg_x = x.mapv(|v| -v);
        let s_x = BCEWithLogitsLoss::sigmoid(&x);
        let s_neg_x = BCEWithLogitsLoss::sigmoid(&neg_x);

        for i in 0..x.len() {
            assert_relative_eq!(s_x[i] + s_neg_x[i], 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_bce_perfect_prediction() {
        let loss_fn = BCEWithLogitsLoss;
        // Logits that strongly match targets → low loss
        let logits = Tensor::from_vec(vec![100.0, -100.0, 100.0, -100.0, 100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] < 0.01, "Perfect prediction should have near-zero loss");
    }

    #[test]
    fn test_bce_wrong_prediction() {
        let loss_fn = BCEWithLogitsLoss;
        // Logits that strongly disagree with targets → high loss
        let logits = Tensor::from_vec(vec![-100.0, 100.0, -100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 10.0, "Wrong prediction should have high loss");
    }

    #[test]
    fn test_bce_gradient_direction() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        // For target=1 with positive logit: grad should be negative (push logit higher)
        assert!(grad[0] < 0.0, "grad[0] should be negative (target=1, logit=2.0)");
        // For target=0 with negative logit: grad should be positive (push logit lower)
        assert!(grad[1] > 0.0, "grad[1] should be positive (target=0, logit=-1.0)");
        // All gradients finite
        for g in &grad {
            assert!(g.is_finite());
        }
    }

    #[test]
    fn test_bce_gradient_at_zero() {
        let loss_fn = BCEWithLogitsLoss;
        // At logit=0, sigmoid=0.5
        let logits = Tensor::from_vec(vec![0.0], true);
        let targets = Tensor::from_vec(vec![1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        if let Some(op) = loss.backward_op() {
            op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        // ∂L/∂x = (σ(0) - 1) / 1 = (0.5 - 1) / 1 = -0.5
        assert_relative_eq!(grad[0], -0.5, epsilon = 1e-5);
    }

    #[test]
    fn test_bce_all_zeros_target() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![0.0; 5], true);
        let targets = Tensor::from_vec(vec![0.0; 5], false);

        let loss = loss_fn.forward(&logits, &targets);
        // log(1 + exp(0)) = log(2) ≈ 0.693 per element
        assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
    }

    #[test]
    fn test_bce_all_ones_target() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![0.0; 5], true);
        let targets = Tensor::from_vec(vec![1.0; 5], false);

        let loss = loss_fn.forward(&logits, &targets);
        // Same: log(2) per element (symmetric when logit=0)
        assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
    }

    #[test]
    fn test_bce_numerical_stability_large_positive() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![1000.0, 500.0, 100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite(), "Must be stable for large positive logits");
        assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
    }

    #[test]
    fn test_bce_numerical_stability_large_negative() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![-1000.0, -500.0, -100.0], true);
        let targets = Tensor::from_vec(vec![0.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite(), "Must be stable for large negative logits");
        assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
    }

    #[test]
    #[should_panic(expected = "must have same length")]
    fn test_bce_mismatched_lengths() {
        let loss_fn = BCEWithLogitsLoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        loss_fn.forward(&pred, &target);
    }

    #[test]
    fn test_bce_no_grad() {
        let loss_fn = BCEWithLogitsLoss;
        let pred = Tensor::from_vec(vec![2.0, -1.0], false);
        let target = Tensor::from_vec(vec![1.0, 0.0], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
        // No gradient is requested, so no backward op should be attached.
        assert!(loss.backward_op().is_none());
    }

    #[test]
    fn test_bce_gradient_accumulation() {
        // Two backward passes must add up: 2 * (σ(x) - t) / N with N = 2,
        // i.e. exactly σ(x) - t per element.
        let logits = Tensor::from_vec(vec![1.0, -1.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0], false);

        for _ in 0..2 {
            let loss = BCEWithLogitsLoss.forward(&logits, &targets);
            if let Some(op) = loss.backward_op() {
                op.backward();
            }
        }

        let grad = logits.grad().expect("gradient should be available");
        let sigma_one = 1.0 / (1.0 + (-1.0f32).exp());
        // target=1, logit=1: σ(1) - 1
        assert_relative_eq!(grad[0], sigma_one - 1.0, epsilon = 1e-6);
        // target=0, logit=-1: σ(-1) = 1 - σ(1)
        assert_relative_eq!(grad[1], 1.0 - sigma_one, epsilon = 1e-6);
    }

    #[test]
    fn test_bce_name() {
        assert_eq!(BCEWithLogitsLoss.name(), "BCEWithLogits");
    }

    #[test]
    fn test_stable_bce_formula() {
        // Verify against naive (potentially unstable) formula
        // For moderate values, both should agree
        let logit = 1.5f32;
        let target = 0.7f32;

        let stable = BCEWithLogitsLoss::stable_bce(logit, target);

        // Naive: -[t * log(σ(x)) + (1-t) * log(1 - σ(x))]
        let sigma = 1.0 / (1.0 + (-logit).exp());
        let naive = -(target * sigma.ln() + (1.0 - target) * (1.0 - sigma).ln());

        assert_relative_eq!(stable, naive, epsilon = 1e-5);
    }

    #[test]
    fn test_multi_label_scenario() {
        // Real multi-label: script is both non-deterministic AND needs-quoting
        let loss_fn = BCEWithLogitsLoss;
        // 5 classes: safe, needs-quoting, non-det, non-idem, unsafe
        let logits = Tensor::from_vec(vec![-2.0, 3.0, 4.0, -1.0, -3.0], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite());
        assert!(loss.data()[0] > 0.0);
    }
}