1use crate as burn;
2
3use crate::module::{Content, DisplaySettings, ModuleDisplay};
4use crate::tensor::activation::log_softmax;
5use crate::tensor::{Bool, Int, Tensor, backend::Backend};
6use crate::{config::Config, module::Module};
7use alloc::string::ToString;
8use alloc::vec;
9use alloc::vec::Vec;
10
11#[derive(Config, Debug)]
13pub struct CrossEntropyLossConfig {
14 pub pad_tokens: Option<Vec<usize>>,
18
19 pub weights: Option<Vec<f32>>,
27
28 pub smoothing: Option<f32>,
33
34 #[config(default = true)]
37 pub logits: bool,
38}
39
40impl CrossEntropyLossConfig {
41 pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
43 self.assertions();
44 CrossEntropyLoss {
45 pad_tokens: self.pad_tokens.clone(),
46 weights: self
47 .weights
48 .as_ref()
49 .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
50 smoothing: self.smoothing,
51 logits: self.logits,
52 }
53 }
54
55 fn assertions(&self) {
56 if let Some(alpha) = self.smoothing {
57 assert!(
58 (0.0..=1.).contains(&alpha),
59 "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
60 alpha
61 );
62 };
63 if let Some(weights) = self.weights.as_ref() {
64 assert!(
65 weights.iter().all(|e| e > &0.),
66 "Weights of cross-entropy have to be positive."
67 );
68 }
69 }
70}
71
72#[derive(Module, Debug)]
76#[module(custom_display)]
77pub struct CrossEntropyLoss<B: Backend> {
78 pub pad_tokens: Option<Vec<usize>>,
80 pub weights: Option<Tensor<B, 1>>,
82 pub smoothing: Option<f32>,
84 pub logits: bool,
86}
87
88impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
89 fn custom_settings(&self) -> Option<DisplaySettings> {
90 DisplaySettings::new()
91 .with_new_line_after_attribute(false)
92 .optional()
93 }
94
95 fn custom_content(&self, content: Content) -> Option<Content> {
96 let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
97 alloc::format!("Vec<0..{}>", pad_tokens.len())
98 } else {
99 "None".to_string()
100 };
101
102 content
103 .add("pad_tokens", &pad_tokens)
104 .add("weights", &self.weights)
105 .add("smoothing", &self.smoothing)
106 .add("logits", &self.logits)
107 .optional()
108 }
109}
110
111impl<B: Backend> CrossEntropyLoss<B> {
112 pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
114 CrossEntropyLossConfig::new()
115 .with_pad_tokens(pad_index.map(|e| vec![e]))
116 .init(device)
117 }
118
119 pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
126 Self::assertions(logits.clone(), targets.clone());
127 match self.smoothing {
128 Some(alpha) => self.forward_smoothed(logits, targets, alpha),
129 _ => self.forward_default(logits, targets),
130 }
131 }
132
133 fn forward_smoothed(
134 &self,
135 logits: Tensor<B, 2>,
136 targets: Tensor<B, 1, Int>,
137 alpha: f32,
138 ) -> Tensor<B, 1> {
139 let mask = self.padding_mask(&targets);
140 let tensor = if self.logits {
141 log_softmax(logits, 1)
142 } else {
143 logits.log()
144 };
145 let [batch_size, nr_classes] = tensor.dims();
146 let tensor = tensor
147 * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
148
149 match &self.weights {
150 Some(weights) => {
151 let tensor = tensor
152 * weights
153 .clone()
154 .reshape([1, nr_classes])
155 .repeat_dim(0, batch_size);
156 let weights = weights.clone().gather(0, targets);
157 let tensor = Self::apply_mask_2d(tensor, mask);
158 tensor.sum().neg() / weights.sum()
159 }
160 None => {
161 let tensor = Self::apply_mask_2d(tensor, mask);
162 tensor.sum_dim(1).mean().neg()
163 }
164 }
165 }
166
167 fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
168 let [batch_size] = targets.dims();
169
170 let mask = self.padding_mask(&targets);
171 let tensor = log_softmax(logits, 1);
172 let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
173
174 match &self.weights {
175 Some(weights) => {
176 let weights = weights.clone().gather(0, targets);
177 let tensor = tensor.reshape([batch_size]) * weights.clone();
178 let tensor = Self::apply_mask_1d(tensor, mask);
179 tensor.sum().neg() / weights.sum()
180 }
181 None => {
182 let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
183 tensor.mean().neg()
184 }
185 }
186 }
187
188 fn compute_smoothed_targets(
189 shape: [usize; 2],
190 targets: Tensor<B, 1, Int>,
191 alpha: f32,
192 ) -> Tensor<B, 2> {
193 let [batch_size, nr_classes] = shape;
194 let device = &targets.device();
195 let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
196 1,
197 targets.reshape([batch_size, 1]),
198 Tensor::ones([batch_size, 1], device),
199 );
200 targets_matrix * (1. - alpha) + alpha / nr_classes as f32
201 }
202
203 fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
204 let mut mask = None;
205 if let Some(pad_tokens) = &self.pad_tokens {
206 let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
207 for x in pad_tokens {
208 res = res + targets.clone().equal_elem(*x as i64).int();
209 }
210 mask = Some(res.greater_elem(0));
211 }
212
213 mask
214 }
215
216 fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
217 if let Some(mask) = mask {
218 tensor = tensor.mask_fill(mask, 0);
219 }
220
221 tensor
222 }
223
224 fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
225 if let Some(mask) = mask {
226 let [batch_size, nr_classes] = tensor.dims();
227 tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
228 }
229
230 tensor
231 }
232
233 fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
234 let [logits_height, _] = logits.dims();
235 let [targets_height] = targets.dims();
236 assert!(
237 logits_height == targets_height,
238 "Shape of targets ({}) should correspond to outer shape of logits ({}).",
239 targets_height,
240 logits_height
241 );
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn_tensor::{Tolerance, ops::FloatElem};

    type FT = FloatElem<TestBackend>;

    /// `(logits, targets, one-hot targets)` shared by the tests.
    type Fixture = (
        Tensor<TestBackend, 2>,
        Tensor<TestBackend, 1, Int>,
        Tensor<TestBackend, 2>,
    );

    const BATCH_SIZE: usize = 4;
    const NUM_TARGETS: usize = 5;
    const PAD_INDEX: usize = 1;

    /// Random logits with targets `[2, 0, 4, 1]` and their one-hot encoding.
    fn setup() -> Fixture {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [BATCH_SIZE, NUM_TARGETS],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Same as [`setup`], but the last target is the padding index and the one-hot rows of the
    /// samples masked by the padding tests (targets `2` and `PAD_INDEX`) are all zero.
    fn setup_padded() -> Fixture {
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [BATCH_SIZE, NUM_TARGETS],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::from([2, 0, 4, PAD_INDEX as i64]).convert::<IntElem<TestBackend>>(),
            &device,
        );
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Expected smoothed targets for targets `[2, 0, 4, 1]` with `alpha = 0.05`.
    fn smoothed_one_hot() -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        )
    }

    /// Asserts that two tensors are approximately equal with the default tolerance.
    fn assert_close<const D: usize>(actual: Tensor<TestBackend, D>, expected: Tensor<TestBackend, D>) {
        actual
            .into_data()
            .assert_approx_eq::<FT>(&expected.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();

        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);

        let weight_matrix = Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
            .unsqueeze()
            .repeat_dim(0, 4);
        let weighted = log_softmax(logits, 1) * targets_logits * weight_matrix;
        let loss_2 = weighted.sum().neg() / (1. + 2. + 3. + 5.);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];

        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup();
        let device = Default::default();

        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();

        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded();

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![PAD_INDEX, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded();

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![PAD_INDEX, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![PAD_INDEX, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup();

        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);

        assert_close(smoothed_targets, smoothed_one_hot());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup();
        let device = Default::default();

        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);

        let log_probs = log_softmax(logits, 1);
        let loss_2 = (log_probs * smoothed_one_hot()).sum_dim(1).mean().neg();

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn display() {
        let loss = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5))
            .init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}