burn_core/nn/loss/binary_cross_entropy.rs

use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};

use crate::tensor::activation::log_sigmoid;
use crate::tensor::{backend::Backend, Int, Tensor};
use crate::{config::Config, module::Module};
use alloc::vec::Vec;

/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
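///
/// # Example
///
/// A minimal sketch; assumes a backend type `B` and a matching `device` are in scope:
///
/// ```ignore
/// let loss = BinaryCrossEntropyLossConfig::new()
///     .with_logits(true)
///     .init::<B>(&device);
/// ```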
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
    /// Create weighted binary cross-entropy with a weight for each class.
    ///
    /// The loss for each sample is scaled by the weight associated with its label.
    pub weights: Option<Vec<f32>>,

    /// Create binary cross-entropy with label smoothing according to [When Does Label Smoothing Help?](https://arxiv.org/abs/1906.02629).
    ///
    /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
    /// Alpha = 0 is equivalent to no smoothing (the default).
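    ///
    /// For example, with `a = 0.1` and two classes, a hard label of 1 becomes
    /// `1.0 * 0.9 + 0.1 / 2 = 0.95` and a hard label of 0 becomes `0.05`.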
    pub smoothing: Option<f32>,

    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
    #[config(default = false)]
    pub logits: bool,
}

impl BinaryCrossEntropyLossConfig {
    /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
        self.assertions();
        BinaryCrossEntropyLoss {
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of binary cross-entropy loss with smoothed labels must be in the interval [0, 1]. Got {}",
                alpha
            );
        }
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|&w| w > 0.),
                "Weights of binary cross-entropy must be positive."
            );
        }
    }
}

/// Calculate the binary cross-entropy loss from the input logits and the targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
    /// Weights for cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing alpha.
    pub smoothing: Option<f32>,
    /// Treat the inputs as logits.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> BinaryCrossEntropyLoss<B> {
    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// Binary:
    /// - logits: `[batch_size]`
    /// - targets: `[batch_size]`
    ///
    /// Multi-label:
    /// - logits: `[batch_size, num_classes]`
    /// - targets: `[batch_size, num_classes]`
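    ///
    /// # Example
    ///
    /// A minimal sketch; assumes a backend type `B` and a matching `device` are in scope:
    ///
    /// ```ignore
    /// let loss_fn = BinaryCrossEntropyLossConfig::new()
    ///     .with_logits(true)
    ///     .init::<B>(&device);
    /// let logits = Tensor::<B, 1>::from_floats([0.8271, 0.9626], &device);
    /// let targets = Tensor::<B, 1, Int>::from_data(TensorData::from([0, 1]), &device);
    /// let loss = loss_fn.forward(logits, targets); // scalar loss, Tensor<B, 1>
    /// ```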
    pub fn forward<const D: usize>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D, Int>,
    ) -> Tensor<B, 1> {
        self.assertions(&logits, &targets);

        let mut targets_float = targets.clone().float();
        let shape = targets.dims();

        if let Some(alpha) = self.smoothing {
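            // For 1-D binary targets there are implicitly two classes, {0, 1}.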
            let num_classes = if D > 1 { shape[D - 1] } else { 2 };
            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
        }

        let mut loss = if self.logits {
            // Numerically stable formulation that avoids computing `log(sigmoid(x))` directly:
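            // -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
            // simplifies to (1 - y) * x - log_sigmoid(x).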
            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
        } else {
            // -(target * log(input) + (1 - target) * log(1 - input))
            (targets_float.clone() * logits.clone().log()
                + (targets_float.neg() + 1.) * (logits.neg() + 1.).log())
            .neg()
        };

        if let Some(weights) = &self.weights {
            let weights = if D > 1 {
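                // Multi-label: one weight per class, broadcast across the batch dimension.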
                weights.clone().expand(shape)
            } else {
                // Flatten targets and gather per-sample weights, then expand the result
                // so it is compatible with Tensor<B, D> in the binary 1-D case.
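                // e.g. weights [3., 7.] with targets [0, 1, 0, 1] yields
                // per-sample weights [3., 7., 3., 7.].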
                weights
                    .clone()
                    .gather(0, targets.flatten(0, 0))
                    .expand(shape)
            };
            loss = loss * weights;
        }

        loss.mean()
    }

    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
        let logits_dims = logits.dims();
        let targets_dims = targets.dims();
        assert!(
            logits_dims == targets_dims,
            "Shape of targets ({:?}) should match the shape of logits ({:?}).",
            targets_dims,
            logits_dims
        );

        if let Some(weights) = &self.weights {
            if D > 1 {
                let targets_classes = targets_dims[D - 1];
                let weights_classes = weights.dims()[0];
                assert!(
                    weights_classes == targets_classes,
                    "The number of classes ({}) does not match the number of weights provided ({}).",
                    targets_classes,
                    weights_classes
                );
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{activation::sigmoid, TensorData};
    use crate::TestBackend;

    #[test]
    fn test_binary_cross_entropy() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target) # tensor(0.7491)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0, 1, 0, 1])
        // weights = torch.tensor([3., 7.]).gather(0, target)
        // loss = nn.BCELoss(weights)
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target.float()) # tensor(3.1531)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([0.8271, 0.9626, 0.3796, 0.2355])
        // target = torch.tensor([0., 1., 0., 1.])
        // target_smooth = target * (1 - 0.1) + (0.1 / 2)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7490)

        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // loss = nn.BCEWithLogitsLoss()
        // out = loss(input, target) # tensor(0.7112)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // weights = torch.tensor([3., 7., 0.9])
        // loss = nn.BCEWithLogitsLoss(weight=weights)
        // out = loss(input, target) # tensor(3.1708)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        // import torch
        // from torch import nn
        // input = torch.tensor([[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]])
        // target = torch.tensor([[1., 0., 1.], [1., 0., 0.]])
        // target_smooth = target * (1 - 0.1) + (0.1 / 3)
        // loss = nn.BCELoss()
        // sigmoid = nn.Sigmoid()
        // out = loss(sigmoid(input), target_smooth) # tensor(0.7228)

        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        // Only two weights for three target classes: forward should panic.
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}