Skip to main content

burn_train/learner/
sequence.rs

1use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput};
2use crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy};
3use burn_core::tensor::backend::Backend;
4use burn_core::tensor::{Int, Tensor, Transaction};
5use burn_flex::Flex;
6
7/// Sequence prediction output adapted for multiple metrics.
8///
9/// Supported metrics:
10/// - Accuracy
11/// - TopKAccuracy
12/// - Perplexity
13/// - Loss
14/// - CER
15/// - WER
16#[derive(new)]
17pub struct SequenceOutput<B: Backend> {
18    /// The loss.
19    pub loss: Tensor<B, 1>,
20
21    /// Raw logits. Shape: `[batch_size, seq_len, vocab_size]`
22    pub logits: Tensor<B, 3>,
23
24    /// Optional predicted token indices. Shape: `[batch_size, seq_length]`.
25    /// If not provided, predictions default to argmax of `logits` along the last dimension.
26    pub predictions: Option<Tensor<B, 2, Int>>,
27
28    /// The target token indices. Shape: `[batch_size, seq_length]`
29    pub targets: Tensor<B, 2, Int>,
30}
31
32impl<B: Backend> SequenceOutput<B> {
33    fn predicted_tokens(&self) -> Tensor<B, 2, Int> {
34        match &self.predictions {
35            Some(preds) => preds.clone(),
36            None => self.logits.clone().argmax(2).squeeze_dim::<2>(2),
37        }
38    }
39
40    fn flat_logits(&self) -> Tensor<B, 2> {
41        let [batch_size, seq_len, vocab_size] = self.logits.dims();
42        self.logits
43            .clone()
44            .reshape([batch_size * seq_len, vocab_size])
45    }
46
47    fn flat_targets(&self) -> Tensor<B, 1, Int> {
48        let [batch_size, seq_len] = self.targets.dims();
49        self.targets.clone().reshape([batch_size * seq_len])
50    }
51}
52
53impl<B: Backend> ItemLazy for SequenceOutput<B> {
54    // Flex's IntElem is i32; token IDs > i32::MAX would truncate on sync.
55    type ItemSync = SequenceOutput<Flex>;
56
57    fn sync(self) -> Self::ItemSync {
58        let device = &Default::default();
59
60        match self.predictions {
61            Some(preds) => {
62                let [logits, loss, targets, predictions] = Transaction::default()
63                    .register(self.logits)
64                    .register(self.loss)
65                    .register(self.targets)
66                    .register(preds)
67                    .execute()
68                    .try_into()
69                    .expect("Correct amount of tensor data");
70
71                SequenceOutput {
72                    logits: Tensor::from_data(logits, device),
73                    loss: Tensor::from_data(loss, device),
74                    targets: Tensor::from_data(targets, device),
75                    predictions: Some(Tensor::from_data(predictions, device)),
76                }
77            }
78            None => {
79                let [logits, loss, targets] = Transaction::default()
80                    .register(self.logits)
81                    .register(self.loss)
82                    .register(self.targets)
83                    .execute()
84                    .try_into()
85                    .expect("Correct amount of tensor data");
86
87                SequenceOutput {
88                    logits: Tensor::from_data(logits, device),
89                    loss: Tensor::from_data(loss, device),
90                    targets: Tensor::from_data(targets, device),
91                    predictions: None,
92                }
93            }
94        }
95    }
96}
97
98impl<B: Backend> Adaptor<LossInput<B>> for SequenceOutput<B> {
99    fn adapt(&self) -> LossInput<B> {
100        LossInput::new(self.loss.clone())
101    }
102}
103
104impl<B: Backend> Adaptor<CerInput<B>> for SequenceOutput<B> {
105    fn adapt(&self) -> CerInput<B> {
106        CerInput::new(self.predicted_tokens(), self.targets.clone())
107    }
108}
109
110impl<B: Backend> Adaptor<WerInput<B>> for SequenceOutput<B> {
111    fn adapt(&self) -> WerInput<B> {
112        WerInput::new(self.predicted_tokens(), self.targets.clone())
113    }
114}
115
116impl<B: Backend> Adaptor<AccuracyInput<B>> for SequenceOutput<B> {
117    fn adapt(&self) -> AccuracyInput<B> {
118        AccuracyInput::new(self.flat_logits(), self.flat_targets())
119    }
120}
121
122impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for SequenceOutput<B> {
123    fn adapt(&self) -> TopKAccuracyInput<B> {
124        TopKAccuracyInput::new(self.flat_logits(), self.flat_targets())
125    }
126}
127
128impl<B: Backend> Adaptor<PerplexityInput<B>> for SequenceOutput<B> {
129    fn adapt(&self) -> PerplexityInput<B> {
130        PerplexityInput::new(self.flat_logits(), self.flat_targets())
131    }
132}