1use burn_core as burn;
2
3use alloc::vec::Vec;
4use burn::module::{Content, DisplaySettings, ModuleDisplay};
5use burn::tensor::activation::log_sigmoid;
6use burn::tensor::{Int, Tensor, backend::Backend};
7use burn::{config::Config, module::Module};
8
/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss)
/// using the [init function](BinaryCrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
    /// Create weighted binary cross-entropy with a weight for each class.
    ///
    /// The loss of a specific sample will simply be multiplied by its class weight.
    /// All weights must be positive (checked in `init`).
    pub weights: Option<Vec<f32>>,

    /// Create binary cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} will be changed to `y_smoothed = y(1 - a) + a / num_classes`.
    /// Alpha `a` must be in the interval [0, 1] (checked in `init`).
    pub smoothing: Option<f32>,

    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
    #[config(default = false)]
    pub logits: bool,
}
27
28impl BinaryCrossEntropyLossConfig {
29 pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
31 self.assertions();
32 BinaryCrossEntropyLoss {
33 weights: self
34 .weights
35 .as_ref()
36 .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
37 smoothing: self.smoothing,
38 logits: self.logits,
39 }
40 }
41
42 fn assertions(&self) {
43 if let Some(alpha) = self.smoothing {
44 assert!(
45 (0.0..=1.).contains(&alpha),
46 "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
47 );
48 };
49 if let Some(weights) = self.weights.as_ref() {
50 assert!(
51 weights.iter().all(|e| e > &0.),
52 "Weights of cross-entropy have to be positive."
53 );
54 }
55 }
56}
57
/// Calculate the binary cross entropy loss from the input logits and the targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
    /// Weights for the cross-entropy of each class, `[num_classes]`, if configured.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing alpha in [0, 1], if configured.
    pub smoothing: Option<f32>,
    /// Treat the inputs as logits (apply sigmoid inside the loss).
    pub logits: bool,
}
71
72impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
73 fn custom_settings(&self) -> Option<DisplaySettings> {
74 DisplaySettings::new()
75 .with_new_line_after_attribute(false)
76 .optional()
77 }
78
79 fn custom_content(&self, content: Content) -> Option<Content> {
80 content
81 .add("weights", &self.weights)
82 .add("smoothing", &self.smoothing)
83 .add("logits", &self.logits)
84 .optional()
85 }
86}
87
88impl<B: Backend> BinaryCrossEntropyLoss<B> {
89 pub fn forward<const D: usize>(
101 &self,
102 logits: Tensor<B, D>,
103 targets: Tensor<B, D, Int>,
104 ) -> Tensor<B, 1> {
105 self.assertions(&logits, &targets);
106
107 let mut targets_float = targets.clone().float();
108 let shape = targets.dims();
109
110 if let Some(alpha) = self.smoothing {
111 let num_classes = if D > 1 { shape[D - 1] } else { 2 };
112 targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
113 }
114
115 let mut loss = if self.logits {
116 (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
118 } else {
119 (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
122 - targets_float * logits.log().clamp_min(-100.0)
123 };
124
125 if let Some(weights) = &self.weights {
126 let weights = if D > 1 {
127 weights.clone().expand(shape)
128 } else {
129 weights
132 .clone()
133 .gather(0, targets.flatten(0, 0))
134 .expand(shape)
135 };
136 loss = loss * weights;
137 }
138
139 loss.mean()
140 }
141
142 fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
143 let logits_dims = logits.dims();
144 let targets_dims = targets.dims();
145 assert!(
146 logits_dims == targets_dims,
147 "Shape of targets ({targets_dims:?}) should correspond to outer shape of logits ({logits_dims:?})."
148 );
149
150 if let Some(weights) = &self.weights
151 && D > 1
152 {
153 let targets_classes = targets_dims[D - 1];
154 let weights_classes = weights.dims()[0];
155 assert!(
156 weights_classes == targets_classes,
157 "The number of classes ({weights_classes}) does not match the weights provided ({targets_classes})."
158 );
159 }
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{TensorData, activation::sigmoid};
    use burn::tensor::{Tolerance, ops::FloatElem};
    // Float element type of the test backend, used for approximate comparisons.
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_binary_cross_entropy_preds_all_correct() {
        // Perfectly confident, correct probabilities should give (clamped) zero loss.
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([0.000]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_preds_all_incorrect() {
        // Maximally wrong probabilities hit the log clamp at -100, so the mean loss is 100.
        let device = Default::default();
        let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(preds, targets)
            .into_data();

        let loss_expected = TensorData::from([100.000]); loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy() {
        // Expected value presumably computed with a reference implementation
        // (e.g. PyTorch's binary_cross_entropy) — TODO confirm provenance.
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        // Same inputs and expected value as `test_binary_cross_entropy`: the
        // logits path must agree with sigmoid + plain BCE.
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        // Binary case: each element's loss is scaled by the weight of its target class.
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        // 1-D case: smoothing uses num_classes = 2, so targets become y*0.9 + 0.05.
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        // 2-D (multi-label) case exercising the logits path.
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::relative(1e-4));
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        // Multi-label weights: one weight per class, broadcast over the batch.
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        // Multi-label smoothing: num_classes is taken from the last dimension (3 here).
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq::<FT>(&loss_expected, Tolerance::default());
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        // 2 weights vs 3 target classes must trigger the assertion in `assertions`.
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        // Pins the single-line rendering configured by the ModuleDisplay impl.
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}