burn_core/nn/loss/
binary_cross_entropy.rs

1use crate as burn;
2use crate::module::{Content, DisplaySettings, ModuleDisplay};
3
4use crate::tensor::activation::log_sigmoid;
5use crate::tensor::{Int, Tensor, backend::Backend};
6use crate::{config::Config, module::Module};
7use alloc::vec::Vec;
8
9/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
10#[derive(Config, Debug)]
11pub struct BinaryCrossEntropyLossConfig {
12    /// Create weighted binary cross-entropy with a weight for each class.
13    ///
14    /// The loss of a specific sample will simply be multiplied by its label weight.
15    pub weights: Option<Vec<f32>>,
16
17    /// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
18    ///
19    /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
20    /// Alpha = 0 would be the same as default.
21    pub smoothing: Option<f32>,
22
23    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
24    #[config(default = false)]
25    pub logits: bool,
26}
27
28impl BinaryCrossEntropyLossConfig {
29    /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
30    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
31        self.assertions();
32        BinaryCrossEntropyLoss {
33            weights: self
34                .weights
35                .as_ref()
36                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
37            smoothing: self.smoothing,
38            logits: self.logits,
39        }
40    }
41
42    fn assertions(&self) {
43        if let Some(alpha) = self.smoothing {
44            assert!(
45                (0.0..=1.).contains(&alpha),
46                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
47                alpha
48            );
49        };
50        if let Some(weights) = self.weights.as_ref() {
51            assert!(
52                weights.iter().all(|e| e > &0.),
53                "Weights of cross-entropy have to be positive."
54            );
55        }
56    }
57}
58
59/// Calculate the binary cross entropy loss from the input logits and the targets.
60///
61/// Should be created using [BinaryCrossEntropyLossConfig]
62#[derive(Module, Debug)]
63#[module(custom_display)]
64pub struct BinaryCrossEntropyLoss<B: Backend> {
65    /// Weights for cross-entropy.
66    pub weights: Option<Tensor<B, 1>>,
67    /// Label smoothing alpha.
68    pub smoothing: Option<f32>,
69    /// Treat the inputs as logits
70    pub logits: bool,
71}
72
73impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
74    fn custom_settings(&self) -> Option<DisplaySettings> {
75        DisplaySettings::new()
76            .with_new_line_after_attribute(false)
77            .optional()
78    }
79
80    fn custom_content(&self, content: Content) -> Option<Content> {
81        content
82            .add("weights", &self.weights)
83            .add("smoothing", &self.smoothing)
84            .add("logits", &self.logits)
85            .optional()
86    }
87}
88
89impl<B: Backend> BinaryCrossEntropyLoss<B> {
90    /// Compute the criterion on the input tensor.
91    ///
92    /// # Shapes
93    ///
94    /// Binary:
95    /// - logits: `[batch_size]`
96    /// - targets: `[batch_size]`
97    ///
98    /// Multi-label:
99    /// - logits: `[batch_size, num_classes]`
100    /// - targets: `[batch_size, num_classes]`
101    pub fn forward<const D: usize>(
102        &self,
103        logits: Tensor<B, D>,
104        targets: Tensor<B, D, Int>,
105    ) -> Tensor<B, 1> {
106        self.assertions(&logits, &targets);
107
108        let mut targets_float = targets.clone().float();
109        let shape = targets.dims();
110
111        if let Some(alpha) = self.smoothing {
112            let num_classes = if D > 1 { shape[D - 1] } else { 2 };
113            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
114        }
115
116        let mut loss = if self.logits {
117            // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
118            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
119        } else {
120            // - (target * log(input) + (1 - target) * log(1 - input))
121            // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
122            (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
123                - targets_float * logits.log().clamp_min(-100.0)
124        };
125
126        if let Some(weights) = &self.weights {
127            let weights = if D > 1 {
128                weights.clone().expand(shape)
129            } else {
130                // Flatten targets and expand resulting weights to make it compatible with
131                // Tensor<B, D> for binary 1-D case
132                weights
133                    .clone()
134                    .gather(0, targets.flatten(0, 0))
135                    .expand(shape)
136            };
137            loss = loss * weights;
138        }
139
140        loss.mean()
141    }
142
143    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
144        let logits_dims = logits.dims();
145        let targets_dims = targets.dims();
146        assert!(
147            logits_dims == targets_dims,
148            "Shape of targets ({:?}) should correspond to outer shape of logits ({:?}).",
149            targets_dims,
150            logits_dims
151        );
152
153        if let Some(weights) = &self.weights {
154            if D > 1 {
155                let targets_classes = targets_dims[D - 1];
156                let weights_classes = weights.dims()[0];
157                assert!(
158                    weights_classes == targets_classes,
159                    "The number of classes ({}) does not match the weights provided ({}).",
160                    weights_classes,
161                    targets_classes
162                );
163            }
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    //! Unit tests for the binary cross-entropy loss.
    //!
    //! Expected values come from the PyTorch snippets quoted in each test.

    use super::*;
    use crate::TestBackend;
    use crate::tensor::{TensorData, activation::sigmoid};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_binary_cross_entropy_preds_all_correct() {
        // Predictions equal to the targets give a zero loss.
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([0.000]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_incorrect() {
        // Predictions opposite to the targets hit the log clamp on every element.
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([100.000]); // clamped value
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target) # tensor(0.7491)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        // Same data as `test_binary_cross_entropy`, but the raw logits are passed and the
        // loss is configured with `logits = true`, so the same value is expected.

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0, 1, 0, 1])
        // weights = torch.tensor([3., 7.]).gather(0, target)
        // loss = nn.BCELoss(weights)
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target.float()) # tensor(3.1531)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // target_smooth = target * (1 - 0.1) + (0.1 / 2)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7490)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // loss = nn.BCEWithLogitsLoss()
        // out = loss(input, target) # tensor(0.7112)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // weights = torch.tensor([3., 7., 0.9])
        // loss = nn.BCEWithLogitsLoss(weight=weights)
        // out = loss(input, target) # tensor(3.1708)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // target_smooth = target * (1 - 0.1) + (0.1 / 3)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7228)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        // The targets have three classes but only two weights are configured, so `forward`
        // must panic on the class-count assertion instead of producing a value.

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        // The custom display prints every attribute on a single line.
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}