//! Binary cross-entropy loss (`burn_nn/loss/binary_cross_entropy.rs`).

1use burn_core as burn;
2
3use alloc::vec::Vec;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::activation::log_sigmoid;
6use burn::tensor::{Int, Tensor, backend::Backend};
7use burn::{config::Config, module::Module};
8
/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
    /// Create weighted binary cross-entropy with a weight for each class.
    ///
    /// The loss of a specific sample will simply be multiplied by its label weight.
    /// All weights must be strictly positive (validated in [init](BinaryCrossEntropyLossConfig::init)).
    pub weights: Option<Vec<f32>>,

    /// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
    ///
    /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
    /// Alpha = 0 would be the same as default. Must lie in `[0, 1]`
    /// (validated in [init](BinaryCrossEntropyLossConfig::init)).
    pub smoothing: Option<f32>,

    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
    #[config(default = false)]
    pub logits: bool,
}
27
28impl BinaryCrossEntropyLossConfig {
29    /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
30    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
31        self.assertions();
32        BinaryCrossEntropyLoss {
33            weights: self
34                .weights
35                .as_ref()
36                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
37            smoothing: self.smoothing,
38            logits: self.logits,
39        }
40    }
41
42    fn assertions(&self) {
43        if let Some(alpha) = self.smoothing {
44            assert!(
45                (0.0..=1.).contains(&alpha),
46                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
47            );
48        };
49        if let Some(weights) = self.weights.as_ref() {
50            assert!(
51                weights.iter().all(|e| e > &0.),
52                "Weights of cross-entropy have to be positive."
53            );
54        }
55    }
56}
57
/// Calculate the binary cross entropy loss from the input logits and the targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
    /// Weights for cross-entropy (one positive weight per class).
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing alpha in `[0, 1]`.
    pub smoothing: Option<f32>,
    /// Treat the inputs as logits
    pub logits: bool,
}
71
72impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
73    fn custom_settings(&self) -> Option<DisplaySettings> {
74        DisplaySettings::new()
75            .with_new_line_after_attribute(false)
76            .optional()
77    }
78
79    fn custom_content(&self, content: Content) -> Option<Content> {
80        content
81            .add("weights", &self.weights)
82            .add("smoothing", &self.smoothing)
83            .add("logits", &self.logits)
84            .optional()
85    }
86}
87
88impl<B: Backend> BinaryCrossEntropyLoss<B> {
89    /// Compute the criterion on the input tensor.
90    ///
91    /// # Shapes
92    ///
93    /// Binary:
94    /// - logits: `[batch_size]`
95    /// - targets: `[batch_size]`
96    ///
97    /// Multi-label:
98    /// - logits: `[batch_size, num_classes]`
99    /// - targets: `[batch_size, num_classes]`
100    pub fn forward<const D: usize>(
101        &self,
102        logits: Tensor<B, D>,
103        targets: Tensor<B, D, Int>,
104    ) -> Tensor<B, 1> {
105        self.assertions(&logits, &targets);
106
107        let mut targets_float = targets.clone().float();
108        let shape = targets.dims();
109
110        if let Some(alpha) = self.smoothing {
111            let num_classes = if D > 1 { shape[D - 1] } else { 2 };
112            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
113        }
114
115        let mut loss = if self.logits {
116            // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
117            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
118        } else {
119            // - (target * log(input) + (1 - target) * log(1 - input))
120            // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
121            (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
122                - targets_float * logits.log().clamp_min(-100.0)
123        };
124
125        if let Some(weights) = &self.weights {
126            let weights = if D > 1 {
127                weights.clone().expand(shape)
128            } else {
129                // Flatten targets and expand resulting weights to make it compatible with
130                // Tensor<B, D> for binary 1-D case
131                weights
132                    .clone()
133                    .gather(0, targets.flatten(0, 0))
134                    .expand(shape)
135            };
136            loss = loss * weights;
137        }
138
139        loss.mean()
140    }
141
142    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
143        let logits_dims = logits.dims();
144        let targets_dims = targets.dims();
145        assert!(
146            logits_dims == targets_dims,
147            "Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?})."
148        );
149
150        if let Some(weights) = &self.weights
151            && D > 1
152        {
153            let targets_classes = targets_dims[D - 1];
154            let weights_classes = weights.dims()[0];
155            assert!(
156                weights_classes == targets_classes,
157                "The number of classes ({weights_classes}) does not match the weights provided ({targets_classes})."
158            );
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{TensorData, activation::sigmoid};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_binary_cross_entropy_preds_all_correct() {
        // Perfect probability predictions should yield (near) zero loss.
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([0.000]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_incorrect() {
        // Completely wrong predictions hit the log clamp at -100.0
        // (see https://github.com/tracel-ai/burn/issues/2739).
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([100.000]); // clamped value
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target) # tensor(0.7491)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        // Same inputs as test_binary_cross_entropy, but letting the loss apply
        // the sigmoid itself via `with_logits(true)`; result must match.
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0, 1, 0, 1])
        // weights = torch.tensor([3., 7.]).gather(0, target)
        // loss = nn.BCELoss(weights)
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target.float()) # tensor(3.1531)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // target_smooth = target * (1 - 0.1) + (0.1 / 2)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7490)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // loss = nn.BCEWithLogitsLoss()
        // out = loss(input, target) # tensor(0.7112)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // weights = torch.tensor([3., 7., 0.9])
        // loss = nn.BCEWithLogitsLoss(weight=weights)
        // out = loss(input, target) # tensor(3.1708)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        // Reference (PyTorch):
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // target_smooth = target * (1 - 0.1) + (0.1 / 3)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7228)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        // Two class weights for three target classes must trigger the
        // class-count assertion in `BinaryCrossEntropyLoss::assertions`.
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        // Pins the single-line ModuleDisplay format of the loss module.
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}