// burn_nn/loss/cross_entropy.rs

1use burn_core as burn;
2use burn_core::tensor::IndexingUpdateOp;
3
4use alloc::string::ToString;
5use alloc::vec;
6use alloc::vec::Vec;
7use burn::module::{Content, DisplaySettings, ModuleDisplay};
8use burn::tensor::activation::log_softmax;
9use burn::tensor::{Bool, Int, Tensor, backend::Backend};
10use burn::{config::Config, module::Module};
11
/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross entropy.
    ///
    /// Prevents pad tokens from impacting loss calculation. Samples whose
    /// target equals any of these token ids are masked out of the loss.
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
    /// and the total loss is normalized by the sum of the per-sample weights.
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector should correspond to the label integer assignment.
    ///   - Targets assigned negative Int's will not be allowed.
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
    /// Alpha = 0 would be the same as default.
    /// Must lie in [0, 1]; validated in [CrossEntropyLossConfig::init].
    pub smoothing: Option<f32>,

    /// Treat the forward input as logits (`true`, the default: log-softmax is
    /// applied) or as probabilities (`false`: only the log is taken).
    #[config(default = true)]
    pub logits: bool,
}
40
41impl CrossEntropyLossConfig {
42    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
43    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
44        self.assertions();
45        CrossEntropyLoss {
46            pad_tokens: self.pad_tokens.clone(),
47            weights: self
48                .weights
49                .as_ref()
50                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
51            smoothing: self.smoothing,
52            logits: self.logits,
53        }
54    }
55
56    fn assertions(&self) {
57        if let Some(alpha) = self.smoothing {
58            assert!(
59                (0.0..=1.).contains(&alpha),
60                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
61            );
62        };
63        if let Some(weights) = self.weights.as_ref() {
64            assert!(
65                weights.iter().all(|e| e > &0.),
66                "Weights of cross-entropy have to be positive."
67            );
68        }
69    }
70}
71
/// Calculate the cross entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy. Rank-1 tensor, one weight per class,
    /// indexed by the target label.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor. `None` means no smoothing.
    pub smoothing: Option<f32>,
    /// Use logits as input (`true`) instead of probabilities (`false`).
    pub logits: bool,
}
87
88impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
89    fn custom_settings(&self) -> Option<DisplaySettings> {
90        DisplaySettings::new()
91            .with_new_line_after_attribute(false)
92            .optional()
93    }
94
95    fn custom_content(&self, content: Content) -> Option<Content> {
96        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
97            alloc::format!("Vec<0..{}>", pad_tokens.len())
98        } else {
99            "None".to_string()
100        };
101
102        content
103            .add("pad_tokens", &pad_tokens)
104            .add("weights", &self.weights)
105            .add("smoothing", &self.smoothing)
106            .add("logits", &self.logits)
107            .optional()
108    }
109}
110
111impl<B: Backend> CrossEntropyLoss<B> {
112    /// For backward compatibility.
113    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
114        CrossEntropyLossConfig::new()
115            .with_pad_tokens(pad_index.map(|e| vec![e]))
116            .init(device)
117    }
118
119    /// Compute the criterion on the input tensor.
120    ///
121    /// # Shapes
122    ///
123    /// - logits: `[batch_size, num_targets]`
124    /// - targets: `[batch_size]`
125    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
126        Self::assertions(logits.clone(), targets.clone());
127        match self.smoothing {
128            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
129            _ => self.forward_default(logits, targets),
130        }
131    }
132
133    fn forward_smoothed(
134        &self,
135        logits: Tensor<B, 2>,
136        targets: Tensor<B, 1, Int>,
137        alpha: f32,
138    ) -> Tensor<B, 1> {
139        let mask = self.padding_mask(&targets);
140        let tensor = if self.logits {
141            log_softmax(logits, 1)
142        } else {
143            logits.log()
144        };
145        let [batch_size, nr_classes] = tensor.dims();
146        let tensor = tensor
147            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
148
149        match &self.weights {
150            Some(weights) => {
151                let tensor = tensor
152                    * weights
153                        .clone()
154                        .reshape([1, nr_classes])
155                        .repeat_dim(0, batch_size);
156                let weights = weights.clone().gather(0, targets);
157                let tensor = Self::apply_mask_2d(tensor, mask);
158                tensor.sum().neg() / weights.sum()
159            }
160            None => {
161                let tensor = Self::apply_mask_2d(tensor, mask);
162                tensor.sum_dim(1).mean().neg()
163            }
164        }
165    }
166
167    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
168        let [batch_size] = targets.dims();
169
170        let mask = self.padding_mask(&targets);
171        let tensor = log_softmax(logits, 1);
172        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));
173
174        match &self.weights {
175            Some(weights) => {
176                let weights = weights.clone().gather(0, targets);
177                let tensor = tensor.reshape([batch_size]) * weights.clone();
178                let tensor = Self::apply_mask_1d(tensor, mask);
179                tensor.sum().neg() / weights.sum()
180            }
181            None => {
182                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
183                tensor.mean().neg()
184            }
185        }
186    }
187
188    fn compute_smoothed_targets(
189        shape: [usize; 2],
190        targets: Tensor<B, 1, Int>,
191        alpha: f32,
192    ) -> Tensor<B, 2> {
193        let [batch_size, nr_classes] = shape;
194        let device = &targets.device();
195        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
196            1,
197            targets.reshape([batch_size, 1]),
198            Tensor::ones([batch_size, 1], device),
199            IndexingUpdateOp::Add,
200        );
201        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
202    }
203
204    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
205        let mut mask = None;
206        if let Some(pad_tokens) = &self.pad_tokens {
207            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
208            for x in pad_tokens {
209                res = res + targets.clone().equal_elem(*x as i64).int();
210            }
211            mask = Some(res.greater_elem(0));
212        }
213
214        mask
215    }
216
217    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
218        if let Some(mask) = mask {
219            tensor = tensor.mask_fill(mask, 0);
220        }
221
222        tensor
223    }
224
225    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
226        if let Some(mask) = mask {
227            let [batch_size, nr_classes] = tensor.dims();
228            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
229        }
230
231        tensor
232    }
233
234    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
235        let [logits_height, _] = logits.dims();
236        let [targets_height] = targets.dims();
237        assert!(
238            logits_height == targets_height,
239            "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
240        );
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Common fixture: random [4, 5] logits, hard targets [2, 0, 4, 1], and the
    // matching one-hot target matrix for reference-loss computation.
    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Fixture with the last target set to the pad index (1); the reference
    // one-hot matrix has an all-zero row for the padded sample so it
    // contributes nothing to the expected loss.
    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Weighted loss must equal the manual weighted sum normalized by the sum
    // of the gathered per-sample weights (classes 2, 0, 4, 1 -> 3+1+5+2... the
    // divisor below reflects the weights gathered for the actual targets).
    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must reduce to the plain weighted loss.
    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Default loss must match the tensor-level reference implementation.
    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must reduce to the plain (unweighted) loss.
    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padded samples are masked out, so the loss must equal the reference loss
    // computed against the target matrix whose padded rows are all zeros.
    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padding must behave identically on the default and smoothed (alpha = 0)
    // code paths.
    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0.05 and 5 classes, smoothed targets are
    // one-hot * 0.95 + 0.05 / 5, i.e. 0.96 on-target and 0.01 off-target.
    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    // Smoothed loss must equal a manual contraction of log-softmax with the
    // smoothed target matrix.
    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // The ModuleDisplay implementation must render the compact single-line
    // attribute summary.
    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}