1use crate as burn;
2
3use crate::module::{Content, DisplaySettings, ModuleDisplay};
4use crate::tensor::activation::log_softmax;
5use crate::tensor::{backend::Backend, Bool, Int, Tensor};
6use crate::{config::Config, module::Module};
7use alloc::string::ToString;
8use alloc::vec;
9use alloc::vec::Vec;
10
11#[derive(Config, Debug)]
13pub struct CrossEntropyLossConfig {
14 pub pad_tokens: Option<Vec<usize>>,
18
19 pub weights: Option<Vec<f32>>,
27
28 pub smoothing: Option<f32>,
33
34 #[config(default = true)]
37 pub logits: bool,
38}
39
40impl CrossEntropyLossConfig {
41 pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
43 self.assertions();
44 CrossEntropyLoss {
45 pad_tokens: self.pad_tokens.clone(),
46 weights: self
47 .weights
48 .as_ref()
49 .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
50 smoothing: self.smoothing,
51 logits: self.logits,
52 }
53 }
54
55 fn assertions(&self) {
56 if let Some(alpha) = self.smoothing {
57 assert!(
58 (0.0..=1.).contains(&alpha),
59 "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
60 alpha
61 );
62 };
63 if let Some(weights) = self.weights.as_ref() {
64 assert!(
65 weights.iter().all(|e| e > &0.),
66 "Weights of cross-entropy have to be positive."
67 );
68 }
69 }
70}
71
72#[derive(Module, Debug)]
76#[module(custom_display)]
77pub struct CrossEntropyLoss<B: Backend> {
78 pub pad_tokens: Option<Vec<usize>>,
80 pub weights: Option<Tensor<B, 1>>,
82 pub smoothing: Option<f32>,
84 pub logits: bool,
86}
87
88impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
89 fn custom_settings(&self) -> Option<DisplaySettings> {
90 DisplaySettings::new()
91 .with_new_line_after_attribute(false)
92 .optional()
93 }
94
95 fn custom_content(&self, content: Content) -> Option<Content> {
96 let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
97 alloc::format!("Vec<0..{}>", pad_tokens.len())
98 } else {
99 "None".to_string()
100 };
101
102 content
103 .add("pad_tokens", &pad_tokens)
104 .add("weights", &self.weights)
105 .add("smoothing", &self.smoothing)
106 .add("logits", &self.logits)
107 .optional()
108 }
109}
110
111impl<B: Backend> CrossEntropyLoss<B> {
112 pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
114 CrossEntropyLossConfig::new()
115 .with_pad_tokens(pad_index.map(|e| vec![e]))
116 .init(device)
117 }
118
119 pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
126 Self::assertions(logits.clone(), targets.clone());
127 match self.smoothing {
128 Some(alpha) => self.forward_smoothed(logits, targets, alpha),
129 _ => self.forward_default(logits, targets),
130 }
131 }
132
133 fn forward_smoothed(
134 &self,
135 logits: Tensor<B, 2>,
136 targets: Tensor<B, 1, Int>,
137 alpha: f32,
138 ) -> Tensor<B, 1> {
139 let mask = self.padding_mask(&targets);
140 let tensor = if self.logits {
141 log_softmax(logits, 1)
142 } else {
143 logits.log()
144 };
145 let [batch_size, nr_classes] = tensor.dims();
146 let tensor = tensor
147 * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
148
149 match &self.weights {
150 Some(weights) => {
151 let tensor = tensor
152 * weights
153 .clone()
154 .reshape([1, nr_classes])
155 .repeat_dim(0, batch_size);
156 let weights = weights.clone().gather(0, targets);
157 let tensor = Self::apply_mask_2d(tensor, mask);
158 tensor.sum().neg() / weights.sum()
159 }
160 None => {
161 let tensor = Self::apply_mask_2d(tensor, mask);
162 tensor.sum_dim(1).mean().neg()
163 }
164 }
165 }
166
167 fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
168 let [batch_size] = targets.dims();
169
170 let mask = self.padding_mask(&targets);
171 let tensor = log_softmax(logits, 1);
172 let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
173
174 match &self.weights {
175 Some(weights) => {
176 let weights = weights.clone().gather(0, targets);
177 let tensor = tensor.reshape([batch_size]) * weights.clone();
178 let tensor = Self::apply_mask_1d(tensor, mask);
179 tensor.sum().neg() / weights.sum()
180 }
181 None => {
182 let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
183 tensor.mean().neg()
184 }
185 }
186 }
187
188 fn compute_smoothed_targets(
189 shape: [usize; 2],
190 targets: Tensor<B, 1, Int>,
191 alpha: f32,
192 ) -> Tensor<B, 2> {
193 let [batch_size, nr_classes] = shape;
194 let device = &targets.device();
195 let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
196 1,
197 targets.reshape([batch_size, 1]),
198 Tensor::ones([batch_size, 1], device),
199 );
200 targets_matrix * (1. - alpha) + alpha / nr_classes as f32
201 }
202
203 fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
204 let mut mask = None;
205 if let Some(pad_tokens) = &self.pad_tokens {
206 let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
207 for x in pad_tokens {
208 res = res + targets.clone().equal_elem(*x as i64).int();
209 }
210 mask = Some(res.greater_elem(0));
211 }
212
213 mask
214 }
215
216 fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
217 if let Some(mask) = mask {
218 tensor = tensor.mask_fill(mask, 0);
219 }
220
221 tensor
222 }
223
224 fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
225 if let Some(mask) = mask {
226 let [batch_size, nr_classes] = tensor.dims();
227 tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
228 }
229
230 tensor
231 }
232
233 fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
234 let [logits_height, _] = logits.dims();
235 let [targets_height] = targets.dims();
236 assert!(
237 logits_height == targets_height,
238 "Shape of targets ({}) should correspond to outer shape of logits ({}).",
239 targets_height,
240 logits_height
241 );
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{loss::cross_entropy_with_logits, ops::IntElem, Distribution, TensorData};
    use crate::TestBackend;

    /// `(logits, targets, targets as a dense matrix)` shared by the tests.
    type Fixture = (
        Tensor<TestBackend, 2>,
        Tensor<TestBackend, 1, Int>,
        Tensor<TestBackend, 2>,
    );

    /// Random logits for four samples and five classes, with the matching
    /// one-hot encoding of the targets.
    fn setup() -> Fixture {
        let [batch_size, num_targets] = [4, 5];
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 1.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 1.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Same as [`setup`], but the rows of the samples the padded tests mask
    /// out (targets `2` and `1`) are all zero in the dense matrix.
    fn setup_padded() -> Fixture {
        let [batch_size, num_targets, pad_index] = [4, 5, 1];
        let device = Default::default();
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
            &device,
        );
        let targets = Tensor::<TestBackend, 1, Int>::from_data(
            TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
            &device,
        );
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.0, 0.0, 0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0, 1.0],
                [0.0, 0.0, 0.0, 0.0, 0.0],
            ]),
            &device,
        );
        (logits, targets, targets_logits)
    }

    /// Expected smoothed targets for [`setup`] with `alpha = 0.05`.
    fn smoothed_one_hot() -> Tensor<TestBackend, 2> {
        Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        )
    }

    /// Compares two tensors with a precision of three decimals.
    fn assert_close<const D: usize>(
        actual: Tensor<TestBackend, D>,
        expected: Tensor<TestBackend, D>,
    ) {
        actual
            .into_data()
            .assert_approx_eq(&expected.into_data(), 3);
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        // The weights of the targets [2, 0, 4, 1] are 3, 1, 5 and 2.
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);

        assert_close(smoothed_targets, smoothed_one_hot());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = smoothed_one_hot();

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        assert_close(loss_1, loss_2);
    }

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}