//! Cross-entropy loss (`burn_nn/loss/cross_entropy.rs`): supports padding
//! tokens, per-class weights, label smoothing, and probability inputs.

1use burn_core as burn;
2use burn_core::tensor::IndexingUpdateOp;
3
4use alloc::string::ToString;
5use alloc::vec;
6use alloc::vec::Vec;
7use burn::module::{Content, DisplaySettings, ModuleDisplay};
8use burn::tensor::activation::log_softmax;
9use burn::tensor::{Bool, Int, Tensor, backend::Backend};
10use burn::{config::Config, module::Module};
11
12#[cfg(not(feature = "std"))]
13#[allow(unused_imports)]
14use num_traits::Float;
15
/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross entropy.
    ///
    /// Prevents pad tokens from impacting loss calculation: samples whose
    /// target is one of these token ids are masked out of the loss.
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a specific sample will simply be given by: weight * log(p(x)) * 1,
    /// and the total loss is normalized by the sum of the target weights.
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector should correspond to the label integer assignment.
    ///   - Targets assigned negative Int's will not be allowed.
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} will be changed to y_smoothed = y(1 - a) + a / nr_classes.
    /// Alpha = 0 would be the same as default. Must lie in [0, 1]
    /// (checked at init time).
    pub smoothing: Option<f32>,

    /// When `true` (the default), inputs are treated as unnormalized logits
    /// and passed through log-softmax; when `false`, inputs are treated as
    /// probabilities and only their log is taken.
    #[config(default = true)]
    pub logits: bool,
}
44
45impl CrossEntropyLossConfig {
46    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
47    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
48        self.assertions();
49        CrossEntropyLoss {
50            pad_tokens: self.pad_tokens.clone(),
51            weights: self
52                .weights
53                .as_ref()
54                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
55            smoothing: self.smoothing,
56            logits: self.logits,
57        }
58    }
59
60    fn assertions(&self) {
61        if let Some(alpha) = self.smoothing {
62            assert!(
63                (0.0..=1.).contains(&alpha),
64                "Alpha of Cross-entropy loss with smoothed labels should be in interval [0, 1]. Got {alpha}"
65            );
66        };
67        if let Some(weights) = self.weights.as_ref() {
68            assert!(
69                weights.iter().all(|e| e > &0.),
70                "Weights of cross-entropy have to be positive."
71            );
72        }
73    }
74}
75
/// Calculate the cross entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig]
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy (rank-1 tensor, one weight per class).
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor (`None` disables smoothing).
    pub smoothing: Option<f32>,
    /// Use logits as input (`false` means probabilities are passed directly).
    pub logits: bool,
}
91
92impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
93    fn custom_settings(&self) -> Option<DisplaySettings> {
94        DisplaySettings::new()
95            .with_new_line_after_attribute(false)
96            .optional()
97    }
98
99    fn custom_content(&self, content: Content) -> Option<Content> {
100        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
101            alloc::format!("Vec<0..{}>", pad_tokens.len())
102        } else {
103            "None".to_string()
104        };
105
106        content
107            .add("pad_tokens", &pad_tokens)
108            .add("weights", &self.weights)
109            .add("smoothing", &self.smoothing)
110            .add("logits", &self.logits)
111            .optional()
112    }
113}
114
115impl<B: Backend> CrossEntropyLoss<B> {
116    /// For backward compatibility.
117    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
118        CrossEntropyLossConfig::new()
119            .with_pad_tokens(pad_index.map(|e| vec![e]))
120            .init(device)
121    }
122
123    /// Compute the criterion on the input tensor.
124    ///
125    /// # Shapes
126    ///
127    /// - logits: `[batch_size, num_targets]`
128    /// - targets: `[batch_size]`
129    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
130        Self::assertions(logits.clone(), targets.clone());
131        match self.smoothing {
132            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
133            _ => self.forward_default(logits, targets),
134        }
135    }
136
137    fn forward_smoothed(
138        &self,
139        logits: Tensor<B, 2>,
140        targets: Tensor<B, 1, Int>,
141        alpha: f32,
142    ) -> Tensor<B, 1> {
143        let mask = self.padding_mask(&targets);
144        let tensor = if self.logits {
145            log_softmax(logits, 1)
146        } else {
147            logits.log()
148        };
149        let [batch_size, nr_classes] = tensor.dims();
150        let tensor = tensor
151            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);
152
153        match &self.weights {
154            Some(weights) => {
155                let tensor = tensor
156                    * weights
157                        .clone()
158                        .reshape([1, nr_classes])
159                        .repeat_dim(0, batch_size);
160                let weights = weights.clone().gather(0, targets);
161                let tensor = Self::apply_mask_2d(tensor, mask);
162                tensor.sum().neg() / weights.sum()
163            }
164            None => {
165                let tensor = Self::apply_mask_2d(tensor, mask);
166                tensor.sum_dim(1).mean().neg()
167            }
168        }
169    }
170
171    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
172        let [batch_size] = targets.dims();
173
174        let mask = self.padding_mask(&targets);
175        let target_indices = targets.clone().reshape([batch_size, 1]);
176        let tensor = if self.logits {
177            log_softmax(logits, 1).gather(1, target_indices)
178        } else {
179            // TODO: finfo stable eps
180            let finfo = logits.dtype().finfo().unwrap();
181            let eps = finfo.min_positive.sqrt();
182            logits.clamp_min(eps).gather(1, target_indices).log()
183        };
184
185        match &self.weights {
186            Some(weights) => {
187                let weights = weights.clone().gather(0, targets);
188                let tensor = tensor.reshape([batch_size]) * weights.clone();
189                let tensor = Self::apply_mask_1d(tensor, mask);
190                tensor.sum().neg() / weights.sum()
191            }
192            None => {
193                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
194                tensor.mean().neg()
195            }
196        }
197    }
198
199    fn compute_smoothed_targets(
200        shape: [usize; 2],
201        targets: Tensor<B, 1, Int>,
202        alpha: f32,
203    ) -> Tensor<B, 2> {
204        let [batch_size, nr_classes] = shape;
205        let device = &targets.device();
206        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
207            1,
208            targets.reshape([batch_size, 1]),
209            Tensor::ones([batch_size, 1], device),
210            IndexingUpdateOp::Add,
211        );
212        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
213    }
214
215    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
216        let mut mask = None;
217        if let Some(pad_tokens) = &self.pad_tokens {
218            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
219            for x in pad_tokens {
220                res = res + targets.clone().equal_elem(*x as i64).int();
221            }
222            mask = Some(res.greater_elem(0));
223        }
224
225        mask
226    }
227
228    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
229        if let Some(mask) = mask {
230            tensor = tensor.mask_fill(mask, 0);
231        }
232
233        tensor
234    }
235
236    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
237        if let Some(mask) = mask {
238            let [batch_size, nr_classes] = tensor.dims();
239            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
240        }
241
242        tensor
243    }
244
245    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
246        let [logits_height, _] = logits.dims();
247        let [targets_height] = targets.dims();
248        assert!(
249            logits_height == targets_height,
250            "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
251        );
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Random [4, 5] logits with hard targets [2, 0, 4, 1] and the matching
    // one-hot matrix, used as the reference for manual loss computations.
    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Like setup!, but the first and last samples target pad tokens (2 and 1),
    // so their rows are all-zero in the reference one-hot matrix.
    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    // Weighted loss must equal the manual weighted log-softmax sum divided by
    // the summed target weights (targets [2,0,4,1] -> weights 3+1+5+2 = 11).
    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must reduce to the plain weighted loss.
    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Default config must match the reference cross_entropy_with_logits helper.
    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Smoothing with alpha = 0 must reduce to the unweighted default loss.
    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padded samples must be excluded; the reference matrix already zeroes
    // the padded rows, so both computations should agree.
    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // Padding plus zero smoothing must equal padding without smoothing.
    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // With alpha = 0.05 and 5 classes: off-target entries become 0.05/5 = 0.01
    // and the target entry becomes 0.95 + 0.01 = 0.96.
    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    // End-to-end smoothing: loss must equal the mean negative dot product of
    // the log-softmax with the manually smoothed target matrix.
    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    // The `logits` flag must switch between log_softmax (logits) and plain
    // log (probabilities); the two paths produce different pinned values.
    #[test]
    fn test_logits_flag_affects_output() {
        let device = Default::default();

        let probs = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.1, 0.2, 0.7, 0.0, 0.0],
                [0.7, 0.1, 0.1, 0.1, 0.0],
                [0.2, 0.2, 0.2, 0.2, 0.2],
                [0.0, 0.3, 0.3, 0.2, 0.2],
            ]),
            &device,
        );

        let targets =
            Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);

        let loss_logits = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(probs.clone(), targets.clone());

        let loss_probs = CrossEntropyLossConfig::new()
            .with_logits(false)
            .init(&device)
            .forward(probs, targets);

        // They must differ if logits flag is implemented correctly
        let loss_logits = loss_logits.into_data();
        let loss_probs = loss_probs.into_data();

        loss_logits.assert_approx_eq::<f32>(&TensorData::from([1.354197]), Tolerance::default());
        loss_probs.assert_approx_eq::<f32>(&TensorData::from([0.88169014]), Tolerance::default());

        assert_ne!(
            loss_logits.as_slice::<f32>().unwrap(),
            loss_probs.as_slice::<f32>().unwrap(),
            "logits flag should change computation (log_softmax vs log)"
        );
    }

    // ModuleDisplay output: verifies the single-line custom format.
    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }

    // TODO: backward tests
}