1use crate::Tensor;
23use ndarray::Array1;
24
25use super::LossFn;
26
/// Binary cross-entropy loss that takes raw logits rather than probabilities.
///
/// This is a field-less unit type: it carries no configuration or state. The
/// behavior lives in its inherent helpers and in its `LossFn` implementation.
pub struct BCEWithLogitsLoss;
46
47impl BCEWithLogitsLoss {
48 pub(crate) fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
50 contract_pre_sigmoid!();
51 let result = x.mapv(|v| {
52 if v >= 0.0 {
54 let exp_neg = (-v).exp();
55 1.0 / (1.0 + exp_neg)
56 } else {
57 let exp_v = v.exp();
58 exp_v / (1.0 + exp_v)
59 }
60 });
61 contract_post_silu!(result);
62 result
63 }
64
65 fn stable_bce(logit: f32, target: f32) -> f32 {
67 let relu = logit.max(0.0);
68 let abs_x = logit.abs();
69 relu - logit * target + (1.0 + (-abs_x).exp()).ln()
70 }
71}
72
73impl LossFn for BCEWithLogitsLoss {
74 fn forward(&self, predictions: &Tensor, targets: &Tensor) -> Tensor {
75 assert_eq!(
76 predictions.len(),
77 targets.len(),
78 "Predictions and targets must have same length"
79 );
80
81 let total_loss: f32 = predictions
83 .data()
84 .iter()
85 .zip(targets.data().iter())
86 .map(|(&logit, &target)| Self::stable_bce(logit, target))
87 .sum::<f32>()
88 / predictions.len() as f32;
89
90 let mut loss = Tensor::from_vec(vec![total_loss], true);
91
92 let sigmoid_vals = Self::sigmoid(predictions.data());
94 let n = predictions.len() as f32;
95 let grad = (&sigmoid_vals - targets.data()) / n;
96
97 use crate::autograd::BackwardOp;
98 use std::rc::Rc;
99
100 struct BCEBackward {
101 pred_grad_cell: Rc<std::cell::RefCell<Option<Array1<f32>>>>,
102 grad: Array1<f32>,
103 }
104
105 impl BackwardOp for BCEBackward {
106 fn backward(&self) {
107 let mut pred_grad = self.pred_grad_cell.borrow_mut();
108 if let Some(existing) = pred_grad.as_mut() {
109 *existing = &*existing + &self.grad;
110 } else {
111 *pred_grad = Some(self.grad.clone());
112 }
113 }
114 }
115
116 if predictions.requires_grad() {
117 loss.set_backward_op(Rc::new(BCEBackward {
118 pred_grad_cell: predictions.grad_cell(),
119 grad,
120 }));
121 }
122
123 loss
124 }
125
126 fn name(&self) -> &'static str {
127 "BCEWithLogits"
128 }
129}
130
#[cfg(test)]
mod tests {
    // Test-only relaxation of the crate's `unwrap_used` lint.
    #![allow(clippy::unwrap_used)]
    use super::*;
    use approx::assert_relative_eq;

    /// Smoke test: a mixed batch yields a positive, finite loss.
    #[test]
    fn test_bce_with_logits_loss_basic() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 0.0);
        assert!(loss.data()[0].is_finite());
    }

    /// sigmoid(0) = 0.5 and the function saturates cleanly at both extremes.
    #[test]
    fn test_sigmoid_basic() {
        let x = Array1::from(vec![0.0, 100.0, -100.0]);
        let s = BCEWithLogitsLoss::sigmoid(&x);

        assert_relative_eq!(s[0], 0.5, epsilon = 1e-5);
        assert_relative_eq!(s[1], 1.0, epsilon = 1e-5);
        assert_relative_eq!(s[2], 0.0, epsilon = 1e-5);
    }

    /// sigmoid(x) + sigmoid(-x) = 1 for every x.
    #[test]
    fn test_sigmoid_symmetry() {
        let x = Array1::from(vec![1.0, 2.0, -3.0, 0.5]);
        let neg_x = x.mapv(|v| -v);
        let s_x = BCEWithLogitsLoss::sigmoid(&x);
        let s_neg_x = BCEWithLogitsLoss::sigmoid(&neg_x);

        for i in 0..x.len() {
            assert_relative_eq!(s_x[i] + s_neg_x[i], 1.0, epsilon = 1e-6);
        }
    }

    /// Confident logits that agree with the targets give (near) zero loss.
    #[test]
    fn test_bce_perfect_prediction() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![100.0, -100.0, 100.0, -100.0, 100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] < 0.01, "Perfect prediction should have near-zero loss");
    }

    /// Confident logits that contradict the targets give a large loss.
    #[test]
    fn test_bce_wrong_prediction() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![-100.0, 100.0, -100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0] > 10.0, "Wrong prediction should have high loss");
    }

    /// The gradient sign is that of `sigmoid(x) - t` for every element.
    #[test]
    fn test_bce_gradient_direction() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![2.0, -1.0, 0.5], true);
        let targets = Tensor::from_vec(vec![1.0, 0.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        if let Some(backward_op) = loss.backward_op() {
            backward_op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        assert!(grad[0] < 0.0, "grad[0] should be negative (target=1, logit=2.0)");
        assert!(grad[1] > 0.0, "grad[1] should be positive (target=0, logit=-1.0)");
        assert!(grad[2] < 0.0, "grad[2] should be negative (target=1, logit=0.5)");
        for g in &grad {
            assert!(g.is_finite());
        }
    }

    /// At x = 0 with t = 1 and n = 1 the gradient is sigmoid(0) - 1 = -0.5.
    #[test]
    fn test_bce_gradient_at_zero() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![0.0], true);
        let targets = Tensor::from_vec(vec![1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        if let Some(op) = loss.backward_op() {
            op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        assert_relative_eq!(grad[0], -0.5, epsilon = 1e-5);
    }

    /// Zero logits give a loss of ln(2) when every target is 0.
    #[test]
    fn test_bce_all_zeros_target() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![0.0; 5], true);
        let targets = Tensor::from_vec(vec![0.0; 5], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
    }

    /// Zero logits give a loss of ln(2) when every target is 1.
    #[test]
    fn test_bce_all_ones_target() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![0.0; 5], true);
        let targets = Tensor::from_vec(vec![1.0; 5], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert_relative_eq!(loss.data()[0], 2.0_f32.ln(), epsilon = 1e-5);
    }

    /// Very large positive logits must not overflow to inf/NaN.
    #[test]
    fn test_bce_numerical_stability_large_positive() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![1000.0, 500.0, 100.0], true);
        let targets = Tensor::from_vec(vec![1.0, 1.0, 1.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite(), "Must be stable for large positive logits");
        assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
    }

    /// Very large negative logits must not overflow to inf/NaN.
    #[test]
    fn test_bce_numerical_stability_large_negative() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![-1000.0, -500.0, -100.0], true);
        let targets = Tensor::from_vec(vec![0.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite(), "Must be stable for large negative logits");
        assert!(loss.data()[0] < 0.01, "Loss should be near-zero for correct large logits");
    }

    /// Length mismatch between predictions and targets is a panic.
    #[test]
    #[should_panic(expected = "must have same length")]
    fn test_bce_mismatched_lengths() {
        let loss_fn = BCEWithLogitsLoss;
        let pred = Tensor::from_vec(vec![1.0, 2.0], true);
        let target = Tensor::from_vec(vec![1.0, 2.0, 3.0], false);
        loss_fn.forward(&pred, &target);
    }

    /// The forward value is still produced when predictions are built with
    /// the second `from_vec` argument set to `false`.
    #[test]
    fn test_bce_no_grad() {
        let loss_fn = BCEWithLogitsLoss;
        let pred = Tensor::from_vec(vec![2.0, -1.0], false);
        let target = Tensor::from_vec(vec![1.0, 0.0], false);
        let loss = loss_fn.forward(&pred, &target);
        assert!(loss.data()[0] > 0.0);
    }

    /// Two backward passes over the same logits must sum their gradients,
    /// not overwrite them.
    #[test]
    fn test_bce_gradient_accumulation() {
        let targets = Tensor::from_vec(vec![1.0, 0.0], false);

        // Reference: a single backward pass on a fresh tensor with equal data.
        let reference = Tensor::from_vec(vec![1.0, -1.0], true);
        let loss_ref = BCEWithLogitsLoss.forward(&reference, &targets);
        if let Some(op) = loss_ref.backward_op() {
            op.backward();
        }
        let single = reference.grad().expect("gradient should be available");

        let logits = Tensor::from_vec(vec![1.0, -1.0], true);

        let loss1 = BCEWithLogitsLoss.forward(&logits, &targets);
        if let Some(op) = loss1.backward_op() {
            op.backward();
        }

        let loss2 = BCEWithLogitsLoss.forward(&logits, &targets);
        if let Some(op) = loss2.backward_op() {
            op.backward();
        }

        let grad = logits.grad().expect("gradient should be available");
        assert!(grad[0].is_finite());
        assert!(grad[1].is_finite());
        // Identical passes, so the accumulated gradient is exactly doubled.
        assert_relative_eq!(grad[0], 2.0 * single[0], epsilon = 1e-6);
        assert_relative_eq!(grad[1], 2.0 * single[1], epsilon = 1e-6);
    }

    /// The loss reports its fixed identifier.
    #[test]
    fn test_bce_name() {
        assert_eq!(BCEWithLogitsLoss.name(), "BCEWithLogits");
    }

    /// The stable formulation agrees with the textbook definition where the
    /// textbook form is itself well-conditioned (moderate logit, soft label).
    #[test]
    fn test_stable_bce_formula() {
        let logit = 1.5f32;
        let target = 0.7f32;

        let stable = BCEWithLogitsLoss::stable_bce(logit, target);

        let sigma = 1.0 / (1.0 + (-logit).exp());
        let naive = -(target * sigma.ln() + (1.0 - target) * (1.0 - sigma).ln());

        assert_relative_eq!(stable, naive, epsilon = 1e-5);
    }

    /// Several independent binary labels in one batch: loss stays finite
    /// and positive.
    #[test]
    fn test_multi_label_scenario() {
        let loss_fn = BCEWithLogitsLoss;
        let logits = Tensor::from_vec(vec![-2.0, 3.0, 4.0, -1.0, -3.0], true);
        let targets = Tensor::from_vec(vec![0.0, 1.0, 1.0, 0.0, 0.0], false);

        let loss = loss_fn.forward(&logits, &targets);
        assert!(loss.data()[0].is_finite());
        assert!(loss.data()[0] > 0.0);
    }
}