1use crate as burn;
2
3use crate::module::{Content, DisplaySettings, ModuleDisplay};
4use crate::tensor::activation::log_softmax;
5use crate::tensor::{Bool, Int, Tensor, backend::Backend};
6use crate::{config::Config, module::Module};
7use alloc::string::ToString;
8use alloc::vec;
9use alloc::vec::Vec;
10
/// Configuration for building a [`CrossEntropyLoss`] module via
/// [`CrossEntropyLossConfig::init`].
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    // Target values equal to any of these tokens are masked out of the loss
    // (their contribution is set to zero).
    pub pad_tokens: Option<Vec<usize>>,

    // Per-class rescaling weights; all must be strictly positive (checked in
    // `init`). Length must match the number of classes, since the weights are
    // reshaped to `[1, num_classes]` in the smoothed forward pass.
    pub weights: Option<Vec<f32>>,

    // Label-smoothing factor in `[0, 1]` (checked in `init`);
    // `None` disables smoothing.
    pub smoothing: Option<f32>,

    // When `true` (default), inputs are treated as unnormalized logits and
    // passed through `log_softmax`; otherwise they are treated as
    // probabilities and only `log` is applied.
    #[config(default = true)]
    pub logits: bool,
}
39
40impl CrossEntropyLossConfig {
41 pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
43 self.assertions();
44 CrossEntropyLoss {
45 pad_tokens: self.pad_tokens.clone(),
46 weights: self
47 .weights
48 .as_ref()
49 .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
50 smoothing: self.smoothing,
51 logits: self.logits,
52 }
53 }
54
55 fn assertions(&self) {
56 if let Some(alpha) = self.smoothing {
57 assert!(
58 (0.0..=1.).contains(&alpha),
59 "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
60 );
61 };
62 if let Some(weights) = self.weights.as_ref() {
63 assert!(
64 weights.iter().all(|e| e > &0.),
65 "Weights of cross-entropy have to be positive."
66 );
67 }
68 }
69}
70
/// Cross-entropy loss module. Construct with [`CrossEntropyLossConfig::init`]
/// (or [`CrossEntropyLoss::new`] for the single-pad-token case).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    // Tokens masked out of the loss when they appear as targets.
    pub pad_tokens: Option<Vec<usize>>,
    // Per-class weights as a rank-1 tensor (one entry per class).
    pub weights: Option<Tensor<B, 1>>,
    // Label-smoothing alpha in `[0, 1]`; `None` disables smoothing.
    pub smoothing: Option<f32>,
    // Whether forward inputs are raw logits (log-softmax applied internally).
    pub logits: bool,
}
86
87impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
88 fn custom_settings(&self) -> Option<DisplaySettings> {
89 DisplaySettings::new()
90 .with_new_line_after_attribute(false)
91 .optional()
92 }
93
94 fn custom_content(&self, content: Content) -> Option<Content> {
95 let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
96 alloc::format!("Vec<0..{}>", pad_tokens.len())
97 } else {
98 "None".to_string()
99 };
100
101 content
102 .add("pad_tokens", &pad_tokens)
103 .add("weights", &self.weights)
104 .add("smoothing", &self.smoothing)
105 .add("logits", &self.logits)
106 .optional()
107 }
108}
109
110impl<B: Backend> CrossEntropyLoss<B> {
111 pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
113 CrossEntropyLossConfig::new()
114 .with_pad_tokens(pad_index.map(|e| vec![e]))
115 .init(device)
116 }
117
118 pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
125 Self::assertions(logits.clone(), targets.clone());
126 match self.smoothing {
127 Some(alpha) => self.forward_smoothed(logits, targets, alpha),
128 _ => self.forward_default(logits, targets),
129 }
130 }
131
132 fn forward_smoothed(
133 &self,
134 logits: Tensor<B, 2>,
135 targets: Tensor<B, 1, Int>,
136 alpha: f32,
137 ) -> Tensor<B, 1> {
138 let mask = self.padding_mask(&targets);
139 let tensor = if self.logits {
140 log_softmax(logits, 1)
141 } else {
142 logits.log()
143 };
144 let [batch_size, nr_classes] = tensor.dims();
145 let tensor = tensor
146 * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
147
148 match &self.weights {
149 Some(weights) => {
150 let tensor = tensor
151 * weights
152 .clone()
153 .reshape([1, nr_classes])
154 .repeat_dim(0, batch_size);
155 let weights = weights.clone().gather(0, targets);
156 let tensor = Self::apply_mask_2d(tensor, mask);
157 tensor.sum().neg() / weights.sum()
158 }
159 None => {
160 let tensor = Self::apply_mask_2d(tensor, mask);
161 tensor.sum_dim(1).mean().neg()
162 }
163 }
164 }
165
166 fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
167 let [batch_size] = targets.dims();
168
169 let mask = self.padding_mask(&targets);
170 let tensor = log_softmax(logits, 1);
171 let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
172
173 match &self.weights {
174 Some(weights) => {
175 let weights = weights.clone().gather(0, targets);
176 let tensor = tensor.reshape([batch_size]) * weights.clone();
177 let tensor = Self::apply_mask_1d(tensor, mask);
178 tensor.sum().neg() / weights.sum()
179 }
180 None => {
181 let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
182 tensor.mean().neg()
183 }
184 }
185 }
186
187 fn compute_smoothed_targets(
188 shape: [usize; 2],
189 targets: Tensor<B, 1, Int>,
190 alpha: f32,
191 ) -> Tensor<B, 2> {
192 let [batch_size, nr_classes] = shape;
193 let device = &targets.device();
194 let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
195 1,
196 targets.reshape([batch_size, 1]),
197 Tensor::ones([batch_size, 1], device),
198 );
199 targets_matrix * (1. - alpha) + alpha / nr_classes as f32
200 }
201
202 fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
203 let mut mask = None;
204 if let Some(pad_tokens) = &self.pad_tokens {
205 let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
206 for x in pad_tokens {
207 res = res + targets.clone().equal_elem(*x as i64).int();
208 }
209 mask = Some(res.greater_elem(0));
210 }
211
212 mask
213 }
214
215 fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
216 if let Some(mask) = mask {
217 tensor = tensor.mask_fill(mask, 0);
218 }
219
220 tensor
221 }
222
223 fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
224 if let Some(mask) = mask {
225 let [batch_size, nr_classes] = tensor.dims();
226 tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
227 }
228
229 tensor
230 }
231
232 fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
233 let [logits_height, _] = logits.dims();
234 let [targets_height] = targets.dims();
235 assert!(
236 logits_height == targets_height,
237 "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
238 );
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Shared fixture: random [4, 5] logits, hard targets [2, 0, 4, 1], and the
    // equivalent one-hot target distribution for reference computations.
    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Like `setup!`, but the last target is the pad token (index 1); its row
    // in the reference distribution is zeroed so it contributes no loss. The
    // first target (2) is also used as a pad token in the tests below.
    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Weighted loss must equal the manual weighted NLL normalized by the sum
    // of the weights gathered at the target indices (1 + 2 + 3 + 5 here).
    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0, the smoothed weighted path must match the plain
    // weighted path exactly.
    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Default configuration must agree with the functional reference
    // `cross_entropy_with_logits` on one-hot targets.
    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0, the smoothed path must match the default path exactly.
    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padded samples (targets 1 and 2) must be masked out; the reference
    // computation uses zeroed target rows for those samples.
    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padding with alpha = 0 smoothing must match padding without smoothing.
    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0.05 over 5 classes, the smoothed targets are
    // 0.95 * one_hot + 0.05 / 5, i.e. 0.96 at the target and 0.01 elsewhere.
    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    // Full smoothed loss must equal the manual mean of -sum(q * log_softmax)
    // computed from the smoothed target distribution.
    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // The custom ModuleDisplay implementation renders all attributes on a
    // single line in this exact format.
    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}