// burn_core/nn/loss/cross_entropy.rs

use crate as burn;

use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::tensor::activation::log_softmax;
use crate::tensor::{Bool, Int, Tensor, backend::Backend};
use crate::{config::Config, module::Module};
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;

/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
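///
/// # Example
///
/// A minimal sketch of building the criterion (a backend type `B` and a
/// `device` are assumed to be in scope; the `with_*` builders are generated
/// by the [Config] derive):
///
/// ```ignore
/// let loss = CrossEntropyLossConfig::new()
///     .with_weights(Some(vec![1.0, 2.0, 3.0]))
///     .with_smoothing(Some(0.1))
///     .init::<B>(&device);
/// ```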
#[derive(Config, Debug)]
pub struct CrossEntropyLossConfig {
    /// Create padded cross-entropy.
    ///
    /// Prevents pad tokens from impacting the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,

    /// Create weighted cross-entropy.
    ///
    /// The loss of a specific sample is scaled by the weight of its target class,
    /// i.e. its contribution is `weights[target] * log(p(target))`.
    ///
    /// # Pre-conditions
    ///   - The order of the weight vector should correspond to the label integer assignment.
    ///   - Negative targets are not allowed.
    pub weights: Option<Vec<f32>>,

    /// Create cross-entropy with label smoothing.
    ///
    /// Hard labels {0, 1} are changed to `y_smoothed = y * (1 - a) + a / nr_classes`.
    /// `a = 0` is equivalent to the default (no smoothing).
    pub smoothing: Option<f32>,

    /// Treat the input as logits (the default). Set to `false` to pass
    /// probabilities instead.
    #[config(default = true)]
    pub logits: bool,
}

impl CrossEntropyLossConfig {
    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
    pub fn init<B: Backend>(&self, device: &B::Device) -> CrossEntropyLoss<B> {
        self.assertions();
        CrossEntropyLoss {
            pad_tokens: self.pad_tokens.clone(),
            weights: self
                .weights
                .as_ref()
                .map(|e| Tensor::<B, 1>::from_floats(e.as_slice(), device)),
            smoothing: self.smoothing,
            logits: self.logits,
        }
    }

    fn assertions(&self) {
        if let Some(alpha) = self.smoothing {
            assert!(
                (0.0..=1.).contains(&alpha),
                "Alpha of Cross-entropy loss with smoothed labels should be in the interval [0, 1]. Got {alpha}"
            );
        }
        if let Some(weights) = self.weights.as_ref() {
            assert!(
                weights.iter().all(|e| e > &0.),
                "Weights of cross-entropy have to be positive."
            );
        }
    }
}

/// Calculate the cross-entropy loss from the input logits and the targets.
///
/// Should be created using [CrossEntropyLossConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct CrossEntropyLoss<B: Backend> {
    /// Pad tokens to ignore in the loss calculation.
    pub pad_tokens: Option<Vec<usize>>,
    /// Weights for cross-entropy.
    pub weights: Option<Tensor<B, 1>>,
    /// Label smoothing factor.
    pub smoothing: Option<f32>,
    /// Use logits as input.
    pub logits: bool,
}

impl<B: Backend> ModuleDisplay for CrossEntropyLoss<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        let pad_tokens = if let Some(pad_tokens) = &self.pad_tokens {
            alloc::format!("Vec<0..{}>", pad_tokens.len())
        } else {
            "None".to_string()
        };

        content
            .add("pad_tokens", &pad_tokens)
            .add("weights", &self.weights)
            .add("smoothing", &self.smoothing)
            .add("logits", &self.logits)
            .optional()
    }
}

impl<B: Backend> CrossEntropyLoss<B> {
    /// Create the criterion from a single optional pad index.
    ///
    /// Kept for backward compatibility; prefer [CrossEntropyLossConfig].
    pub fn new(pad_index: Option<usize>, device: &B::Device) -> Self {
        CrossEntropyLossConfig::new()
            .with_pad_tokens(pad_index.map(|e| vec![e]))
            .init(device)
    }

    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// - logits: `[batch_size, num_targets]`
    /// - targets: `[batch_size]`
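    ///
    /// # Example
    ///
    /// A minimal sketch; the backend `B`, the `device`, and the two input
    /// tensors are assumed to be in scope:
    ///
    /// ```ignore
    /// let loss = CrossEntropyLossConfig::new().init::<B>(&device);
    /// // logits: [batch_size, num_targets], targets: [batch_size]
    /// let value: Tensor<B, 1> = loss.forward(logits, targets);
    /// ```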
    pub fn forward(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        Self::assertions(logits.clone(), targets.clone());
        match self.smoothing {
            Some(alpha) => self.forward_smoothed(logits, targets, alpha),
            _ => self.forward_default(logits, targets),
        }
    }

    fn forward_smoothed(
        &self,
        logits: Tensor<B, 2>,
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 1> {
        let mask = self.padding_mask(&targets);
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let [batch_size, nr_classes] = tensor.dims();
        let tensor = tensor
            * Self::compute_smoothed_targets([batch_size, nr_classes], targets.clone(), alpha);

        match &self.weights {
            Some(weights) => {
                let tensor = tensor
                    * weights
                        .clone()
                        .reshape([1, nr_classes])
                        .repeat_dim(0, batch_size);
                // Normalize by the total weight of the targets in the batch.
                let weights = weights.clone().gather(0, targets);
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_2d(tensor, mask);
                tensor.sum_dim(1).mean().neg()
            }
        }
    }

    fn forward_default(&self, logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) -> Tensor<B, 1> {
        let [batch_size] = targets.dims();

        let mask = self.padding_mask(&targets);
        // Respect the `logits` flag, mirroring `forward_smoothed`: the input
        // is either logits or probabilities.
        let tensor = if self.logits {
            log_softmax(logits, 1)
        } else {
            logits.log()
        };
        let tensor = tensor.gather(1, targets.clone().reshape([batch_size, 1]));

        match &self.weights {
            Some(weights) => {
                let weights = weights.clone().gather(0, targets);
                let tensor = tensor.reshape([batch_size]) * weights.clone();
                let tensor = Self::apply_mask_1d(tensor, mask);
                tensor.sum().neg() / weights.sum()
            }
            None => {
                let tensor = Self::apply_mask_1d(tensor.reshape([batch_size]), mask);
                tensor.mean().neg()
            }
        }
    }

    fn compute_smoothed_targets(
        shape: [usize; 2],
        targets: Tensor<B, 1, Int>,
        alpha: f32,
    ) -> Tensor<B, 2> {
        let [batch_size, nr_classes] = shape;
        let device = &targets.device();
        // One-hot encode the targets, then smooth: with alpha = 0.05 and 5
        // classes, a one-hot row [0, 1, 0, 0, 0] becomes
        // [0.01, 0.96, 0.01, 0.01, 0.01].
        let targets_matrix = Tensor::<B, 2>::zeros(shape, device).scatter(
            1,
            targets.reshape([batch_size, 1]),
            Tensor::ones([batch_size, 1], device),
        );
        targets_matrix * (1. - alpha) + alpha / nr_classes as f32
    }

    fn padding_mask(&self, targets: &Tensor<B, 1, Int>) -> Option<Tensor<B, 1, Bool>> {
        let mut mask = None;
        if let Some(pad_tokens) = &self.pad_tokens {
            // Count matches against every pad token, then mark any element
            // that matched at least once.
            let mut res = targets.clone().equal_elem(pad_tokens[0] as i64).int();
            for x in pad_tokens.iter().skip(1) {
                res = res + targets.clone().equal_elem(*x as i64).int();
            }
            mask = Some(res.greater_elem(0));
        }

        mask
    }

    fn apply_mask_1d(mut tensor: Tensor<B, 1>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 1> {
        if let Some(mask) = mask {
            tensor = tensor.mask_fill(mask, 0);
        }

        tensor
    }

    fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
        if let Some(mask) = mask {
            let [batch_size, nr_classes] = tensor.dims();
            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
        }

        tensor
    }

    fn assertions(logits: Tensor<B, 2>, targets: Tensor<B, 1, Int>) {
        let [logits_height, _] = logits.dims();
        let [targets_height] = targets.dims();
        assert!(
            logits_height == targets_height,
            "Shape of targets ({targets_height}) should correspond to outer shape of logits ({logits_height})."
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::{Distribution, TensorData, loss::cross_entropy_with_logits, ops::IntElem};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    macro_rules! setup {
        () => {{
            let [batch_size, num_targets] = [4, 5];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets =
                Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([2, 0, 4, 1]), &device);
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 1.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 1.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    macro_rules! setup_padded {
        () => {{
            let [batch_size, num_targets, pad_index] = [4, 5, 1];
            let device = Default::default();
            let logits = Tensor::<TestBackend, 2>::random(
                [batch_size, num_targets],
                Distribution::Normal(0., 1.0),
                &device,
            );
            let targets = Tensor::<TestBackend, 1, Int>::from_data(
                TensorData::from([2, 0, 4, pad_index as i64]).convert::<IntElem<TestBackend>>(),
                &device,
            );
            let targets_logits = Tensor::<TestBackend, 2>::from_data(
                TensorData::from([
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [1.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 1.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                ]),
                &device,
            );
            (logits, targets, targets_logits)
        }};
    }

    #[test]
    fn test_cross_entropy_loss_with_weights() {
        let (logits, targets, targets_logits) = setup!();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets);
        let tensor = log_softmax(logits, 1);
        let loss_2 = tensor
            * targets_logits
            * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                .unsqueeze()
                .repeat_dim(0, 4);
        // Normalized by the summed weights of the target classes [2, 0, 4, 1],
        // i.e. 3 + 1 + 5 + 2 = 1 + 2 + 3 + 5.
        let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_weights_and_alpha_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let weights = vec![1.0, 2., 3., 4., 5.];
        let loss_1 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_weights(Some(weights.clone()))
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits.clone(), targets);
        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss() {
        let (logits, targets, targets_logits) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_alpha_equal_zero() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.))
            .init(&device)
            .forward(logits, targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_cross_entropy_loss_with_pad_token() {
        let (logits, targets, targets_logits) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets);
        let loss_2 = cross_entropy_with_logits(logits, targets_logits);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_with_zero_alpha_and_pad_token() {
        let (logits, targets, _) = setup_padded!();
        let pad_index = 1;

        let loss_1 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .init(&logits.device())
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_pad_tokens(Some(vec![pad_index, 2]))
            .with_smoothing(Some(0.))
            .init(&logits.device())
            .forward(logits.clone(), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing_target_conversion() {
        let (logits, targets, _) = setup!();
        let smoothed_targets =
            CrossEntropyLoss::compute_smoothed_targets(logits.dims(), targets, 0.05);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &Default::default(),
        );
        smoothed_targets
            .into_data()
            .assert_approx_eq::<FT>(&targets_logits.into_data(), Tolerance::default());
    }

    #[test]
    fn test_label_smoothing() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .with_smoothing(Some(0.05))
            .init(&device)
            .forward(logits.clone(), targets);
        let targets_logits = Tensor::<TestBackend, 2>::from_data(
            TensorData::from([
                [0.01, 0.01, 0.96, 0.01, 0.01],
                [0.96, 0.01, 0.01, 0.01, 0.01],
                [0.01, 0.01, 0.01, 0.01, 0.96],
                [0.01, 0.96, 0.01, 0.01, 0.01],
            ]),
            &device,
        );

        let x = log_softmax(logits, 1);
        let loss_2 = (x * targets_logits).sum_dim(1).mean().neg();

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }
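
    // A hypothetical regression check, not part of the original suite: with
    // the `logits` flag disabled, feeding softmax probabilities should match
    // the default logits path.
    #[test]
    fn test_cross_entropy_loss_with_probabilities() {
        let (logits, targets, _) = setup!();
        let device = Default::default();
        let loss_1 = CrossEntropyLossConfig::new()
            .init(&device)
            .forward(logits.clone(), targets.clone());
        let loss_2 = CrossEntropyLossConfig::new()
            .with_logits(false)
            .init(&device)
            .forward(crate::tensor::activation::softmax(logits, 1), targets);

        loss_1
            .into_data()
            .assert_approx_eq::<FT>(&loss_2.into_data(), Tolerance::default());
    }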

    #[test]
    fn display() {
        let config = CrossEntropyLossConfig::new()
            .with_weights(Some(alloc::vec![3., 7., 0.9]))
            .with_smoothing(Some(0.5));
        let loss = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{loss}"),
            "CrossEntropyLoss {pad_tokens: None, weights: Tensor {rank: 1, shape: [3]}, smoothing: 0.5, logits: true}"
        );
    }
}