use burn_core as burn;
use burn_core::tensor::IndexingUpdateOp;

use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::activation::log_softmax;
use burn::tensor::{Bool, Int, Tensor, backend::Backend};
use burn::{config::Config, module::Module};

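/// Configuration to create a [cross-entropy loss](CrossEntropyLoss) using the
/// [init function](CrossEntropyLossConfig::init).
///
/// A minimal usage sketch, assuming a backend type `B` and a `device` in scope
/// (the `logits` and `targets` tensors here are placeholders):
///
/// ```rust,ignore
/// let loss = CrossEntropyLossConfig::new()
///     .with_smoothing(Some(0.1))
///     .init::<B>(&device);
/// let value = loss.forward(logits, targets);
/// ```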
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Pad token indices to exclude from the loss.
    ///
    /// Targets equal to any of these tokens are masked out so they do not
    /// contribute to the loss.
    pub pad_tokens: Option<Vec<usize>>,

    /// Per-class weights for a weighted cross-entropy.
    ///
    /// Must contain one strictly positive weight per class.
    pub weights: Option<Vec<f32>>,

    /// Label-smoothing factor `alpha` in `[0, 1]`.
    ///
    /// Hard labels {0, 1} are replaced by
    /// `y_smoothed = y * (1 - alpha) + alpha / nr_classes`.
    /// `alpha = 0` is equivalent to no smoothing.
    pub smoothing: Option<f32>,

    /// Whether the inputs are logits (default) or probabilities.
    ///
    /// When `true`, a log-softmax is applied to the inputs; when `false`,
    /// the inputs are treated as probabilities and only their log is taken.
    #[config(default = true)]
    pub logits: bool,
}

impl CrossEntropyLossConfig {
    /// Initialize a [cross-entropy loss](CrossEntropyLoss) on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
        self.assertions();
        CrossEntropyLoss {
            pad_tokens: self.pad_tokens.clone(),
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Label-smoothing alpha must be in the interval [0, 1]. Got {alpha}"
            );
        }
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| e > &0.),
                "Weights of cross-entropy must be positive."
            );
        }
    }
}

/// Cross-entropy loss, optionally weighted, label-smoothed, and padded.
///
/// Should be created using [`CrossEntropyLossConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad token indices to exclude from the loss.
    pub pad_tokens: Option<Vec<usize>>,
    /// Per-class weights used for a weighted cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label-smoothing factor.
    pub smoothing: Option<f32>,
    /// Whether the inputs are logits (`true`) or probabilities (`false`).
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
            alloc::format!("Vec<0..{}>", pad_tokens.len())
        } else {
            "None".to_string()
        };

        content
            .add("pad_tokens", &pad_tokens)
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> CrossEntropyLoss<B> {
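    /// Create the criterion, masking at most one optional pad token.
    ///
    /// A minimal sketch; for several pad tokens, weights, or smoothing, build a
    /// [`CrossEntropyLossConfig`] instead (`B` and `pad_index` are placeholders):
    ///
    /// ```rust,ignore
    /// let criterion = CrossEntropyLoss::<B>::new(Some(pad_index), &device);
    /// ```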
    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
        CrossEntropyLossConfig::new()
            .with_pad_tokens(pad_index.map(|e| vec![e]))
            .init(device)
    }

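    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// - logits: `[batch_size, num_targets]`
    /// - targets: `[batch_size]`, class indices in `0..num_targets`
    ///
    /// An illustrative call with made-up values (two samples, three classes):
    ///
    /// ```rust,ignore
    /// let logits = Tensor::<B, 2>::from_floats([[1.0, 2.0, 0.5], [0.1, 0.2, 3.0]], &device);
    /// let targets = Tensor::<B, 1, Int>::from_ints([1, 2], &device);
    /// let loss = criterion.forward(logits, targets);
    /// ```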
    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        Self::assertions(logits.clone(), targets.clone());
        match self.smoothing {
            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
            _ => self.forward_default(logits, targets),
        }
    }

    fn forward_smoothed(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 1> {
        let mask = self.padding_mask(&targets);
        // Convert the inputs to log-probabilities.
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let [batch_size, nr_classes] = tensor.dims();
        let tensor = tensor
            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);

        match &self.weights {
            // Weighted: normalize by the sum of the per-sample class weights.
            Some(weights) => {
                let tensor = tensor
                    * weights
                        .clone()
                        .reshape([1, nr_classes])
                        .repeat_dim(0, batch_size);
                let weights = weights.clone().gather(0, targets);
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            // Unweighted: average the per-sample losses over the batch.
            None => {
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum_dim(1).mean().neg()
            }
        }
    }

    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        let [batch_size] = targets.dims();

        let mask = self.padding_mask(&targets);
        // Honor the `logits` flag as in the smoothed path: apply log-softmax to
        // logits, or take the log of inputs that are already probabilities.
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        // Pick out the log-probability of each sample's target class.
        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

        match &self.weights {
            Some(weights) => {
                let weights = weights.clone().gather(0, targets);
                let tensor = tensor.reshape([batch_size]) * weights.clone();
                let tensor = Self::apply_mask_1d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
                tensor.mean().neg()
            }
        }
    }

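    /// Relax one-hot targets into a smoothed distribution:
    /// `y_smoothed = y * (1 - alpha) + alpha / nr_classes`.
    ///
    /// For example, `alpha = 0.05` with 5 classes maps each 1.0 to 0.96 and
    /// each 0.0 to 0.01 (see `test_label_smoothing_target_conversion` below).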
    fn compute_smoothed_targets(
        shape: [usize; 2],
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 2> {
        let [batch_size, nr_classes] = shape;
        let device = &targets.device();
        // One-hot encode the targets by scattering ones into a zero matrix.
        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
            1,
            targets.reshape([batch_size, 1]),
            Tensor::ones([batch_size, 1], device),
            IndexingUpdateOp::Add,
        );
        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
    }

    /// Build a boolean mask of positions whose target is one of the pad tokens.
    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
        let mut mask = None;
        if let Some(pad_tokens) = &self.pad_tokens {
            // Accumulate matches per pad token, then threshold into a boolean mask.
            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
            for x in pad_tokens.iter().skip(1) {
                res = res + targets.clone().equal_elem(*x as i64).int();
            }
            mask = Some(res.greater_elem(0));
        }

        mask
    }

    /// Zero out masked (padded) positions of a rank-1 tensor.
    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
        if let Some(mask) = mask {
            tensor = tensor.mask_fill(mask, 0);
        }

        tensor
    }

    /// Zero out masked (padded) rows of a rank-2 tensor.
    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
        if let Some(mask) = mask {
            let [batch_size, nr_classes] = tensor.dims();
            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
        }

        tensor
    }

    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
        let [logits_height, _] = logits.dims();
        let [targets_height] = targets.dims();
        assert!(
            logits_height == targets_height,
            "Batch size of targets ({targets_height}) must match the outer dimension of logits ({logits_height})."
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Random logits with one-hot reference targets, shared across the tests.
    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Variant of `setup!` where reference rows 0 and 3 are zeroed: the pad-token
    // tests below mask both the pad token (1) and class 2.
    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        // Normalizer: sum of the per-sample class weights, w[2] + w[0] + w[4] + w[1].
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}