burn_core/nn/loss/cross_entropy.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::activation::log_softmax;
use crate::tensor::{backend::Backend, Bool, Int, Tensor};
use crate::{config::Config, module::Module};
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;

/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
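///
/// # Example (sketch)
///
/// A minimal, hypothetical usage sketch, not from the original file; `MyBackend`,
/// `device`, `logits`, and `targets` are assumed to be in scope:
///
/// ```ignore
/// let loss = CrossEntropyLossConfig::new()
///     .with_smoothing(Some(0.1))
///     .init::<MyBackend>(&device);
/// let value = loss.forward(logits, targets);
/// ```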
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross-entropy.
    ///
    /// Prevents pad tokens from impacting the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a sample is scaled by the weight of its target class,
    /// i.e. `weight[y] * log(p(y))`, and the batch loss is normalized by the
    /// sum of the weights of the targets in the batch.
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector should correspond to the label integer assignment.
    ///   - Negative integer targets are not allowed.
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} will be changed to `y_smoothed = y * (1 - alpha) + alpha / nr_classes`.
    /// Setting `alpha = 0` is equivalent to no smoothing. For example, with 5 classes
    /// and `alpha = 0.05`, a hard label becomes 0.96 for the target class and 0.01 elsewhere.
    pub smoothing: Option<f32>,

    /// Treat the input as logits.
    ///
    /// When `true` (the default), `log_softmax` is applied to the input; set to
    /// `false` to pass probabilities instead of logits.
    #[config(default = true)]
    pub logits: bool,
}

impl CrossEntropyLossConfig {
    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
        self.assertions();
        CrossEntropyLoss {
            pad_tokens: self.pad_tokens.clone(),
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {}",
                alpha
            );
        };
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| e > &0.),
                "Weights of cross-entropy have to be positive."
            );
        }
    }
}

/// Calculate the cross-entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor.
    pub smoothing: Option<f32>,
    /// Use logits as input.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
            alloc::format!("Vec<0..{}>", pad_tokens.len())
        } else {
            "None".to_string()
        };

        content
            .add("pad_tokens", &pad_tokens)
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> CrossEntropyLoss<B> {
    /// Create the criterion from a single optional pad index.
    ///
    /// Kept for backward compatibility; prefer [CrossEntropyLossConfig].
    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
        CrossEntropyLossConfig::new()
            .with_pad_tokens(pad_index.map(|e| vec![e]))
            .init(device)
    }

    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// - logits: `[batch_size, num_targets]`
    /// - targets: `[batch_size]`
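    ///
    /// # Example (sketch)
    ///
    /// A hypothetical shape sketch, not from the original file; `B`, `device`,
    /// and `criterion` are assumed to be in scope. For a batch of 2 samples
    /// over 3 classes:
    ///
    /// ```ignore
    /// let logits = Tensor::<B, 2>::from_floats([[1.0, 2.0, 0.5], [0.1, 0.2, 0.7]], &device);
    /// let targets = Tensor::<B, 1, Int>::from_ints([1, 2], &device);
    /// let loss = criterion.forward(logits, targets); // rank-1 tensor holding the mean loss
    /// ```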
    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        Self::assertions(logits.clone(), targets.clone());
        match self.smoothing {
            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
            _ => self.forward_default(logits, targets),
        }
    }

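    /// Cross-entropy with label smoothing: targets are converted to a smoothed
    /// distribution before being multiplied with the log-probabilities.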
    fn forward_smoothed(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 1> {
        let mask = self.padding_mask(&targets);
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let [batch_size, nr_classes] = tensor.dims();
        let tensor = tensor
            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);

        match &self.weights {
            Some(weights) => {
                let tensor = tensor
                    * weights
                        .clone()
                        .reshape([1, nr_classes])
                        .repeat_dim(0, batch_size);
                let weights = weights.clone().gather(0, targets);
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum_dim(1).mean().neg()
            }
        }
    }

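    /// Standard cross-entropy on hard labels, with optional class weights and
    /// padding mask.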
    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        let [batch_size] = targets.dims();

        let mask = self.padding_mask(&targets);
        // Honor the `logits` flag, as in the smoothed path: apply `log_softmax`
        // to logits, or take the log of probabilities directly.
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

        match &self.weights {
            Some(weights) => {
                let weights = weights.clone().gather(0, targets);
                let tensor = tensor.reshape([batch_size]) * weights.clone();
                let tensor = Self::apply_mask_1d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
                tensor.mean().neg()
            }
        }
    }

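    /// One-hot encode the targets, then interpolate toward the uniform
    /// distribution: `y * (1 - alpha) + alpha / nr_classes`.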
    fn compute_smoothed_targets(
        shape: [usize; 2],
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 2> {
        let [batch_size, nr_classes] = shape;
        let device = &targets.device();
        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
            1,
            targets.reshape([batch_size, 1]),
            Tensor::ones([batch_size, 1], device),
        );
        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
    }

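    /// Build a boolean mask flagging the targets equal to any configured pad token.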
    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
        let mut mask = None;
        if let Some(pad_tokens) = &self.pad_tokens {
            // Count, per element, how many pad tokens the target matches, then
            // keep the entries that matched at least one.
            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
            for x in pad_tokens.iter().skip(1) {
                res = res + targets.clone().equal_elem(*x as i64).int();
            }
            mask = Some(res.greater_elem(0));
        }

        mask
    }

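    /// Zero out the loss entries of padded samples.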
    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
        if let Some(mask) = mask {
            tensor = tensor.mask_fill(mask, 0);
        }

        tensor
    }

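    /// Zero out entire rows of the loss matrix for padded samples.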
    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
        if let Some(mask) = mask {
            let [batch_size, nr_classes] = tensor.dims();
            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
        }

        tensor
    }

    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
        let [logits_height, _] = logits.dims();
        let [targets_height] = targets.dims();
        assert!(
            logits_height == targets_height,
            "Batch size of targets ({}) should match the batch size of logits ({}).",
            targets_height,
            logits_height
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{loss::cross_entropy_with_logits, ops::IntElem, Distribution, TensorData};
    use crate::TestBackend;

    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
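        // Denominator: sum of the weights gathered at the targets [2, 0, 4, 1],
        // i.e. 3 + 1 + 5 + 2 = 11.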
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

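    // A sketch, not in the original suite: with `logits = false` the criterion
    // expects probabilities, so feeding `softmax(logits)` should match the
    // default criterion on raw logits.
    #[test]
    fn test_probability_input_matches_logits_input() {
        use crate::tensor::activation::softmax;

        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_logits(false)
            .init(&device)
            .forward(softmax(logits, 1), targets);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }
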
    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq(&targets_logits.into_data(), 3);
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", loss),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}
451}