use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};

use crate::tensor::activation::log_sigmoid;
use crate::tensor::{backend::Backend, Int, Tensor};
use crate::{config::Config, module::Module};
use alloc::vec::Vec;

/// Configuration for creating a [binary cross-entropy loss](BinaryCrossEntropyLoss).
#[derive(Config, Debug)]
pub struct BinaryCrossEntropyLossConfig {
    /// Optional per-class rescaling weights. When set, the loss of each sample is
    /// multiplied by the weight of its class.
    pub weights: Option<Vec<f32>>,

    /// Label smoothing factor `alpha` in `[0, 1]`. Hard labels {0, 1} are replaced by
    /// `y * (1 - alpha) + alpha / num_classes`.
    pub smoothing: Option<f32>,

    /// Treat the inputs as raw logits and apply the sigmoid internally when computing the loss.
    #[config(default = false)]
    pub logits: bool,
}

impl BinaryCrossEntropyLossConfig {
    /// Initialize a [binary cross-entropy loss](BinaryCrossEntropyLoss) module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BinaryCrossEntropyLoss<B> {
        self.assertions();
        BinaryCrossEntropyLoss {
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of binary cross-entropy loss with smoothed labels should be in the interval [0, 1]. Got {}",
                alpha
            );
        };
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| e > &0.),
                "Weights of binary cross-entropy loss have to be positive."
            );
        }
    }
}
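
// A minimal usage sketch, assuming a backend type alias `B`, a `device: B::Device`,
// and rank-matched `scores`/`targets` tensors are in scope (see the tests below for
// concrete values):
//
//     let loss_fn = BinaryCrossEntropyLossConfig::new()
//         .with_logits(true)
//         .init::<B>(&device);
//     let loss = loss_fn.forward(scores, targets); // scalar Tensor<B, 1>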

/// Binary cross-entropy loss, supporting both binary and multi-label targets.
///
/// Should be created using [BinaryCrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BinaryCrossEntropyLoss<B: Backend> {
    /// Optional per-class rescaling weights.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor `alpha`.
    pub smoothing: Option<f32>,
    /// Treat the inputs as raw logits.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for BinaryCrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> BinaryCrossEntropyLoss<B> {
    /// Compute the loss and reduce it to a scalar with the mean.
    ///
    /// When `logits` is disabled in the config, the predictions are expected to be
    /// probabilities in `[0, 1]`; otherwise they are treated as raw scores.
    ///
    /// # Shapes
    ///
    /// - Binary: predictions `[batch_size]`, targets `[batch_size]`
    /// - Multi-label: predictions `[batch_size, num_classes]`, targets `[batch_size, num_classes]`
    pub fn forward<const D: usize>(
        &self,
        logits: Tensor<B, D>,
        targets: Tensor<B, D, Int>,
    ) -> Tensor<B, 1> {
        self.assertions(&logits, &targets);

        let mut targets_float = targets.clone().float();
        let shape = targets.dims();

        if let Some(alpha) = self.smoothing {
            // Smooth hard labels: y_smoothed = y * (1 - alpha) + alpha / num_classes.
            let num_classes = if D > 1 { shape[D - 1] } else { 2 };
            targets_float = targets_float * (1. - alpha) + alpha / num_classes as f32;
        }

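        // Sketch of the identity used for the `logits` case below: with target y and raw
        // score x,
        //   bce = -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
        // and since log(1 - sigmoid(x)) = -x + log(sigmoid(x)), this reduces to
        //   bce = (1 - y) * x - log(sigmoid(x)),
        // avoiding an explicit sigmoid on the raw scores.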
        let mut loss = if self.logits {
            // Numerically stable form: (1 - y) * x - log(sigmoid(x)); see the sketch above.
            (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
        } else {
            // Inputs are probabilities: -[y * log(p) + (1 - y) * log(1 - p)].
            (targets_float.clone() * logits.clone().log()
                + (targets_float.neg() + 1.) * (logits.neg() + 1.).log())
            .neg()
        };

        if let Some(weights) = &self.weights {
            let weights = if D > 1 {
                // Multi-label: broadcast the per-class weights over the batch dimension.
                weights.clone().expand(shape)
            } else {
                // Binary: select the weight of each sample's target class.
                weights
                    .clone()
                    .gather(0, targets.flatten(0, 0))
                    .expand(shape)
            };
            loss = loss * weights;
        }

        loss.mean()
    }

    fn assertions<const D: usize>(&self, logits: &Tensor<B, D>, targets: &Tensor<B, D, Int>) {
        let logits_dims = logits.dims();
        let targets_dims = targets.dims();
        assert!(
            logits_dims == targets_dims,
            "Shape of targets ({:?}) should correspond to outer shape of logits ({:?}).",
            targets_dims,
            logits_dims
        );

        if let Some(weights) = &self.weights {
            if D > 1 {
                let targets_classes = targets_dims[D - 1];
                let weights_classes = weights.dims()[0];
                assert!(
                    weights_classes == targets_classes,
                    "The number of classes ({}) does not match the number of weights provided ({}).",
                    targets_classes,
                    weights_classes
                );
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{activation::sigmoid, TensorData};
    use crate::TestBackend;

    #[test]
    fn test_binary_cross_entropy() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_logits() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7491]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_weights() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);
        let weights = [3., 7.];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([3.1531]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_with_smoothing() {
        let device = Default::default();
        let logits =
            Tensor::<TestBackend, 1>::from_floats([0.8271, 0.9626, 0.3796, 0.2355], &device);
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([0, 1, 0, 1]), &device);

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7490]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([0.7112]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_weights() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7., 0.9];

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets)
            .into_data();

        let loss_expected = TensorData::from([3.1708]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    fn test_binary_cross_entropy_multilabel_with_smoothing() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );

        let loss_actual = BinaryCrossEntropyLossConfig::new()
            .with_smoothing(Some(0.1))
            .init(&device)
            .forward(sigmoid(logits), targets)
            .into_data();

        let loss_expected = TensorData::from([0.7228]);
        loss_actual.assert_approx_eq(&loss_expected, 3);
    }

    #[test]
    #[should_panic = "The number of classes"]
    fn multilabel_weights_should_match_target() {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::from_floats(
            [[0.5150, 0.3097, 0.7556], [0.4974, 0.9879, 0.1564]],
            &device,
        );
        let targets = Tensor::<TestBackend, 2, Int>::from_data(
            TensorData::from([[1, 0, 1], [1, 0, 0]]),
            &device,
        );
        let weights = [3., 7.];

        let _loss = BinaryCrossEntropyLossConfig::new()
            .with_logits(true)
            .with_weights(Some(weights.to_vec()))
            .init(&device)
            .forward(logits, targets);
    }

    #[test]
    fn display() {
        let config =
            BinaryCrossEntropyLossConfig::new().with_weights(Some(alloc::vec![3., 7., 0.9]));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "BinaryCrossEntropyLoss {weights: Tensor {rank: 1, shape: [3]}, smoothing: None, logits: false}"
        );
    }
}