burn_core/nn/loss/cross_entropy.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::activation::log_softmax;
use crate::tensor::{Bool, Int, Tensor, backend::Backend};
use crate::{config::Config, module::Module};
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;

/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
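///
/// A minimal configuration sketch (assumes a backend type `B` and a matching
/// `device` are already set up; marked `ignore` because it depends on the
/// caller's environment):
///
/// ```rust,ignore
/// let loss = CrossEntropyLossConfig::new()
///     .with_smoothing(Some(0.1))
///     .init::<B>(&device);
/// ```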
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross-entropy.
    ///
    /// Prevents pad tokens from impacting the loss calculation.
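    ///
    /// For example, `pad_tokens: Some(vec![1])` masks out the contribution of
    /// every sample whose target equals `1`.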
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a sample with target class `y` is scaled by `weights[y]`,
    /// i.e. it contributes `weights[y] * log(p(x))` before normalization.
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector must correspond to the label integer assignment.
    ///   - Negative target values are not allowed.
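    ///
    /// For example, with five classes, `weights = vec![1.0, 2.0, 3.0, 4.0, 5.0]`
    /// makes a class-`4` sample weigh five times as much as a class-`0` sample.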
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} are changed to `y_smoothed = y(1 - alpha) + alpha / nr_classes`.
    /// `alpha = 0` is equivalent to no smoothing (the default behavior).
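    ///
    /// For example, with `nr_classes = 5` and `alpha = 0.05`, a hard `1` becomes
    /// `1 * (1 - 0.05) + 0.05 / 5 = 0.96` and a hard `0` becomes `0.01`.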
    pub smoothing: Option<f32>,

    /// Treat the input as logits (the default) and apply `log_softmax` internally.
    ///
    /// Set to `false` to pass probabilities as input instead of logits.
    #[config(default = true)]
    pub logits: bool,
}

impl CrossEntropyLossConfig {
    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
        self.assertions();
        CrossEntropyLoss {
            pad_tokens: self.pad_tokens.clone(),
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
                alpha
            );
        }
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| *e > 0.),
                "Weights of cross-entropy have to be positive."
            );
        }
    }
}

/// Calculate the cross entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor.
    pub smoothing: Option<f32>,
    /// Use logits as input.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
            alloc::format!("Vec<0..{}>", pad_tokens.len())
        } else {
            "None".to_string()
        };

        content
            .add("pad_tokens", &pad_tokens)
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> CrossEntropyLoss<B> {
    /// Create the criterion from an optional pad index.
    ///
    /// Kept for backward compatibility; prefer [CrossEntropyLossConfig].
    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
        CrossEntropyLossConfig::new()
            .with_pad_tokens(pad_index.map(|e| vec![e]))
            .init(device)
    }

    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// - logits: `[batch_size, num_targets]`
    /// - targets: `[batch_size]`
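    ///
    /// A minimal usage sketch (assumes a backend type `B`, a `device`, and
    /// tensors of the shapes above; marked `ignore` because it depends on the
    /// caller's environment):
    ///
    /// ```rust,ignore
    /// let loss = CrossEntropyLossConfig::new().init::<B>(&device);
    /// let value = loss.forward(logits, targets); // single-element Tensor<B, 1>
    /// ```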
    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        Self::assertions(logits.clone(), targets.clone());
        match self.smoothing {
            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
            _ => self.forward_default(logits, targets),
        }
    }

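    // Smoothed path: hard targets are expanded to a soft distribution
    // `y(1 - alpha) + alpha / nr_classes` and multiplied with the
    // (log-)probabilities; class weights rescale columns before the reduction.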
    fn forward_smoothed(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 1> {
        let mask = self.padding_mask(&targets);
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let [batch_size, nr_classes] = tensor.dims();
        let tensor = tensor
            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);

        match &self.weights {
            Some(weights) => {
                let tensor = tensor
                    * weights
                        .clone()
                        .reshape([1, nr_classes])
                        .repeat_dim(0, batch_size);
                let weights = weights.clone().gather(0, targets);
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum_dim(1).mean().neg()
            }
        }
    }

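    // Default path: gather the log-probability of each sample's target class,
    // then average over the batch, or normalize by the summed per-sample
    // weights when class weights are configured.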
    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        let [batch_size] = targets.dims();

        let mask = self.padding_mask(&targets);
        // Honor the `logits` flag, as in `forward_smoothed`: inputs that are
        // already probabilities only need a `log`.
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

        match &self.weights {
            Some(weights) => {
                let weights = weights.clone().gather(0, targets);
                let tensor = tensor.reshape([batch_size]) * weights.clone();
                let tensor = Self::apply_mask_1d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
                tensor.mean().neg()
            }
        }
    }

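    // One-hot encode the targets via `scatter`, then blend with the uniform
    // distribution: `y(1 - alpha) + alpha / nr_classes`.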
    fn compute_smoothed_targets(
        shape: [usize; 2],
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 2> {
        let [batch_size, nr_classes] = shape;
        let device = &targets.device();
        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
            1,
            targets.reshape([batch_size, 1]),
            Tensor::ones([batch_size, 1], device),
        );
        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
    }

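    // `Some(mask)` marks positions whose target is any configured pad token;
    // those positions are later zeroed out of the loss.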
    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
        self.pad_tokens.as_ref().map(|pad_tokens| {
            // Accumulate one indicator per pad token, then collapse to a boolean mask.
            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
            for x in &pad_tokens[1..] {
                res = res + targets.clone().equal_elem(*x as i64).int();
            }
            res.greater_elem(0)
        })
    }

    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
        if let Some(mask) = mask {
            tensor = tensor.mask_fill(mask, 0);
        }

        tensor
    }

    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
        if let Some(mask) = mask {
            let [batch_size, nr_classes] = tensor.dims();
            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
        }

        tensor
    }

    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
        let [logits_height, _] = logits.dims();
        let [targets_height] = targets.dims();
        assert!(
            logits_height == targets_height,
            "Shape of targets ({}) should correspond to the batch dimension of logits ({}).",
            targets_height,
            logits_height
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        // Denominator: the weights gathered at targets [2, 0, 4, 1] are
        // [3, 1, 5, 2], summing to 1. + 2. + 3. + 5. = 11.
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}
467}