1use burn_core as burn;
2use burn_core::tensor::IndexingUpdateOp;
3
4use alloc::string::ToString;
5use alloc::vec;
6use alloc::vec::Vec;
7use burn::module::{Content, DisplaySettings, ModuleDisplay};
8use burn::tensor::activation::log_softmax;
9use burn::tensor::{Bool, Int, Tensor, backend::Backend};
10use burn::{config::Config, module::Module};
11
12#[cfg(not(feature = "std"))]
13#[allow(unused_imports)]
14use num_traits::Float;
15
/// Configuration for creating a [cross-entropy loss](CrossEntropyLoss) via
/// [`CrossEntropyLossConfig::init`].
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Pad tokens: targets equal to any of these tokens have their loss
    /// contribution masked out (zeroed) during reduction.
    pub pad_tokens: Option<Vec<usize>>,

    /// Per-class weights for weighted cross-entropy; each sample's loss is
    /// scaled by the weight of its target class. All weights must be strictly
    /// positive (validated in `assertions`).
    pub weights: Option<Vec<f32>>,

    /// Label-smoothing factor `alpha` in `[0, 1]` (validated in `assertions`);
    /// hard targets become `(1 - alpha) * one_hot + alpha / nr_classes`.
    pub smoothing: Option<f32>,

    /// When `true` (default), inputs are raw logits and `log_softmax` is applied;
    /// when `false`, inputs are treated as probabilities and only `log` is taken.
    #[config(default = true)]
    pub logits: bool,
}
44
45impl CrossEntropyLossConfig {
46 pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
48 self.assertions();
49 CrossEntropyLoss {
50 pad_tokens: self.pad_tokens.clone(),
51 weights: self
52 .weights
53 .as_ref()
54 .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
55 smoothing: self.smoothing,
56 logits: self.logits,
57 }
58 }
59
60 fn assertions(&self) {
61 if let Some(alpha) = self.smoothing {
62 assert!(
63 (0.0..=1.).contains(&alpha),
64 "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
65 );
66 };
67 if let Some(weights) = self.weights.as_ref() {
68 assert!(
69 weights.iter().all(|e| e > &0.),
70 "Weights of cross-entropy have to be positive."
71 );
72 }
73 }
74}
75
/// Cross-entropy loss module, computed from input logits (or probabilities)
/// and integer class targets.
///
/// Should be created using [`CrossEntropyLossConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Target tokens whose loss contribution is masked out during reduction.
    pub pad_tokens: Option<Vec<usize>>,
    /// Per-class weights as a rank-1 tensor; `None` for unweighted loss.
    pub weights: Option<Tensor<B, 1>>,
    /// Label-smoothing factor in `[0, 1]`; `None` disables smoothing.
    pub smoothing: Option<f32>,
    /// If `true`, inputs are raw logits (`log_softmax` applied); if `false`,
    /// inputs are probabilities (only `log` applied).
    pub logits: bool,
}
91
92impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
93 fn custom_settings(&self) -> Option<DisplaySettings> {
94 DisplaySettings::new()
95 .with_new_line_after_attribute(false)
96 .optional()
97 }
98
99 fn custom_content(&self, content: Content) -> Option<Content> {
100 let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
101 alloc::format!("Vec<0..{}>", pad_tokens.len())
102 } else {
103 "None".to_string()
104 };
105
106 content
107 .add("pad_tokens", &pad_tokens)
108 .add("weights", &self.weights)
109 .add("smoothing", &self.smoothing)
110 .add("logits", &self.logits)
111 .optional()
112 }
113}
114
115impl<B: Backend> CrossEntropyLoss<B> {
116 pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
118 CrossEntropyLossConfig::new()
119 .with_pad_tokens(pad_index.map(|e| vec![e]))
120 .init(device)
121 }
122
123 pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
130 Self::assertions(logits.clone(), targets.clone());
131 match self.smoothing {
132 Some(alpha) => self.forward_smoothed(logits, targets, alpha),
133 _ => self.forward_default(logits, targets),
134 }
135 }
136
137 fn forward_smoothed(
138 &self,
139 logits: Tensor<B, 2>,
140 targets: Tensor<B, 1, Int>,
141 alpha: f32,
142 ) -> Tensor<B, 1> {
143 let mask = self.padding_mask(&targets);
144 let tensor = if self.logits {
145 log_softmax(logits, 1)
146 } else {
147 logits.log()
148 };
149 let [batch_size, nr_classes] = tensor.dims();
150 let tensor = tensor
151 * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
152
153 match &self.weights {
154 Some(weights) => {
155 let tensor = tensor
156 * weights
157 .clone()
158 .reshape([1, nr_classes])
159 .repeat_dim(0, batch_size);
160 let weights = weights.clone().gather(0, targets);
161 let tensor = Self::apply_mask_2d(tensor, mask);
162 tensor.sum().neg() / weights.sum()
163 }
164 None => {
165 let tensor = Self::apply_mask_2d(tensor, mask);
166 tensor.sum_dim(1).mean().neg()
167 }
168 }
169 }
170
171 fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
172 let [batch_size] = targets.dims();
173
174 let mask = self.padding_mask(&targets);
175 let target_indices = targets.clone().reshape([batch_size, 1]);
176 let tensor = if self.logits {
177 log_softmax(logits, 1).gather(1, target_indices)
178 } else {
179 let finfo = logits.dtype().finfo().unwrap();
181 let eps = finfo.min_positive.sqrt();
182 logits.clamp_min(eps).gather(1, target_indices).log()
183 };
184
185 match &self.weights {
186 Some(weights) => {
187 let weights = weights.clone().gather(0, targets);
188 let tensor = tensor.reshape([batch_size]) * weights.clone();
189 let tensor = Self::apply_mask_1d(tensor, mask);
190 tensor.sum().neg() / weights.sum()
191 }
192 None => {
193 let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
194 tensor.mean().neg()
195 }
196 }
197 }
198
199 fn compute_smoothed_targets(
200 shape: [usize; 2],
201 targets: Tensor<B, 1, Int>,
202 alpha: f32,
203 ) -> Tensor<B, 2> {
204 let [batch_size, nr_classes] = shape;
205 let device = &targets.device();
206 let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
207 1,
208 targets.reshape([batch_size, 1]),
209 Tensor::ones([batch_size, 1], device),
210 IndexingUpdateOp::Add,
211 );
212 targets_matrix * (1. - alpha) + alpha / nr_classes as f32
213 }
214
215 fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
216 let mut mask = None;
217 if let Some(pad_tokens) = &self.pad_tokens {
218 let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
219 for x in pad_tokens {
220 res = res + targets.clone().equal_elem(*x as i64).int();
221 }
222 mask = Some(res.greater_elem(0));
223 }
224
225 mask
226 }
227
228 fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
229 if let Some(mask) = mask {
230 tensor = tensor.mask_fill(mask, 0);
231 }
232
233 tensor
234 }
235
236 fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
237 if let Some(mask) = mask {
238 let [batch_size, nr_classes] = tensor.dims();
239 tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
240 }
241
242 tensor
243 }
244
245 fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
246 let [logits_height, _] = logits.dims();
247 let [targets_height] = targets.dims();
248 assert!(
249 logits_height == targets_height,
250 "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
251 );
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Random logits for 4 samples over 5 classes, integer targets, and the
    // equivalent one-hot target matrix.
    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Same as `setup!` but the first and last samples are pad tokens, so their
    // rows in the reference one-hot matrix are all zero (masked out).
    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Weighted loss must equal the manual weighted sum normalized by the summed
    // target weights (targets are [2, 0, 4, 1] -> weights 3, 1, 5, 2).
    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must be a no-op in the weighted path.
    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Default config must match the reference `cross_entropy_with_logits`.
    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must be a no-op in the unweighted path.
    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padded rows contribute zero, matching the zeroed reference rows; the
    // extra pad token 2 masks the first sample as well.
    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must be a no-op in the padded path too.
    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0.05 and 5 classes: on-target = 0.95 + 0.05/5 = 0.96,
    // off-target = 0.05/5 = 0.01.
    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    // Smoothed loss must equal the manual computation against the smoothed
    // target matrix.
    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // The `logits` flag must switch between log_softmax (on logits) and plain
    // log (on probabilities), producing different losses on the same input.
    #[test]
    fn test_logits_flag_affects_output() {
        let device = Default::default();

        let probs = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.1, 0.2, 0.7, 0.0, 0.0],
                [0.7, 0.1, 0.1, 0.1, 0.0],
                [0.2, 0.2, 0.2, 0.2, 0.2],
                [0.0, 0.3, 0.3, 0.2, 0.2],
            ]),
            &device,
        );

        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);

        let loss_logits = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(probs.clone(), targets.clone());

        let loss_probs = CrossEntropyLossConfig::new()
            .with_logits(false)
            .init(&device)
            .forward(probs, targets);

        let loss_logits = loss_logits.into_data();
        let loss_probs = loss_probs.into_data();

        loss_logits.assert_approx_eq::<f32>(&TensorData::from([1.354197]), Tolerance::default());
        loss_probs.assert_approx_eq::<f32>(&TensorData::from([0.88169014]), Tolerance::default());

        assert_ne!(
            loss_logits.as_slice::<f32>().unwrap(),
            loss_probs.as_slice::<f32>().unwrap(),
            "logits flag should change computation (log_softmax vs log)"
        );
    }

    // ModuleDisplay output: pad tokens summarized, weights shown as tensor metadata.
    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }

}