use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};

use crate::tensor::activation::log_sigmoid;
use crate::tensor::{Int, Tensor, backend::Backend};
use crate::{config::Config, module::Module};
use alloc::vec::Vec;

/// Configuration for creating a [binary cross-entropy loss](BinaryCrossEntropyLoss) with the
/// [init function](BinaryCrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
    /// Optional per-class rescaling weights. Each element's loss is multiplied by the weight
    /// of its class.
    pub weights: Option<Vec<f32>>,

    /// Label smoothing factor `alpha` between 0 and 1. Hard labels {0, 1} are replaced by
    /// `y * (1 - alpha) + alpha / num_classes`.
    pub smoothing: Option<f32>,

    /// Treat the inputs as raw logits instead of probabilities, folding the sigmoid activation
    /// into the loss computation.
    #[config(default = false)]
    pub logits: bool,
}

impl BinaryCrossEntropyLossConfig {
    /// Initialize the [binary cross-entropy loss](BinaryCrossEntropyLoss) on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
        self.assertions();
        BinaryCrossEntropyLoss {
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of binary cross-entropy loss with smoothed labels should be in the interval [0, 1]. Got {}",
                alpha
            );
        };
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| e > &0.),
                "Weights of binary cross-entropy must be positive."
            );
        }
    }
}
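
// Usage sketch (illustrative only): configure the loss, initialize it for a backend, then call
// `forward`. `MyBackend` stands in for any `Backend` implementation and is not defined in this
// file.
//
//     let device = Default::default();
//     let loss = BinaryCrossEntropyLossConfig::new()
//         .with_logits(true)
//         .init::<MyBackend>(&device);
//     let value = loss.forward(logits, targets); // scalar loss: Tensor<MyBackend, 1>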

/// Binary cross-entropy loss between input probabilities (or logits) and binary targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
    /// Optional per-class rescaling weights.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor `alpha` between 0 and 1.
    pub smoothing: Option<f32>,
    /// Treat the inputs as raw logits.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> BinaryCrossEntropyLoss<B> {
    /// Compute the mean binary cross-entropy over all elements.
    ///
    /// # Shapes
    ///
    /// - logits: `[batch_size]` or `[batch_size, num_classes]` (multi-label)
    /// - targets: same shape as `logits`, with values in {0, 1}
    pub fn forward<const D: usize>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D, Int>,
    ) -> Tensor<B, 1> {
        self.assertions(&logits, &targets);

        let mut targets_float = targets.clone().float();
        let shape = targets.dims();

        // Label smoothing: replace hard labels with y * (1 - alpha) + alpha / num_classes.
        if let Some(alpha) = self.smoothing {
            let num_classes = if D > 1 { shape[D - 1] } else { 2 };
            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
        }
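        // Worked example of the smoothing step above (values follow directly from the formula):
        // with `alpha = 0.1` and `D == 1` (so `num_classes = 2`), a hard target of 1 becomes
        // 1.0 * 0.9 + 0.05 = 0.95 and a hard target of 0 becomes 0.05. In the multi-label case
        // with 3 classes, the same alpha maps 1 -> ~0.9333 and 0 -> ~0.0333.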

        let mut loss = if self.logits {
            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
        } else {
            (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
                - targets_float * logits.log().clamp_min(-100.0)
        };
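        // Both branches compute the same per-element binary cross-entropy,
        //     -[t * log(p) + (1 - t) * log(1 - p)]    with p = sigmoid(x) when `logits` is set.
        // Using log(1 - sigmoid(x)) = -x + log(sigmoid(x)), the logits form simplifies to
        // (1 - t) * x - log_sigmoid(x), which avoids computing the sigmoid explicitly and stays
        // numerically stable for large |x|. In the probability branch, the logs are clamped at
        // -100.0 so inputs of exactly 0 or 1 produce a large finite loss instead of infinity.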

        if let Some(weights) = &self.weights {
            let weights = if D > 1 {
                weights.clone().expand(shape)
            } else {
                weights
                    .clone()
                    .gather(0, targets.flatten(0, 0))
                    .expand(shape)
            };
            loss = loss * weights;
        }
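        // Example of how the weights are applied (mirrors the unit tests below): for 1-D targets
        // [0, 1, 0, 1] with `weights = [3., 7.]`, each element's loss is scaled by its class
        // weight, i.e. by [3., 7., 3., 7.]. For multi-label inputs of shape
        // [batch_size, num_classes], the weight vector is broadcast along the last dimension.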

        loss.mean()
    }

    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
        let logits_dims = logits.dims();
        let targets_dims = targets.dims();
        assert!(
            logits_dims == targets_dims,
            "Shape of targets ({:?}) should match the shape of logits ({:?}).",
            targets_dims,
            logits_dims
        );

        if let Some(weights) = &self.weights {
            if D > 1 {
                let targets_classes = targets_dims[D - 1];
                let weights_classes = weights.dims()[0];
                assert!(
                    weights_classes == targets_classes,
                    "The number of classes ({}) does not match the number of weights provided ({}).",
                    targets_classes,
                    weights_classes
                );
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{TensorData, activation::sigmoid};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_binary_cross_entropy_preds_all_correct() {
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([0.000]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_incorrect() {
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        // With the log terms clamped at -100.0, fully incorrect hard predictions yield a
        // per-sample loss of 100, hence a mean loss of 100.
        let loss_expected = TensorData::from([100.000]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }
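
    // Added consistency check (a sketch, not part of the original reference values): the
    // `logits` path and the explicit `sigmoid` + probability path compute the same loss, so
    // they should agree up to floating-point tolerance.
    #[test]
    fn test_binary_cross_entropy_logits_matches_sigmoid_path() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_from_logits = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits.clone(), targets.clone())
            .into_data();
        let loss_from_probs = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        loss_from_logits.assert_approx_eq::<FT>(&loss_from_probs, Tolerance::relative(1e-4));
    }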

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}