burn_train/learner/
sequence.rs1use crate::metric::{AccuracyInput, PerplexityInput, TopKAccuracyInput};
2use crate::metric::{Adaptor, CerInput, LossInput, WerInput, processor::ItemLazy};
3use burn_core::tensor::backend::Backend;
4use burn_core::tensor::{Int, Tensor, Transaction};
5use burn_flex::Flex;
6
7#[derive(new)]
17pub struct SequenceOutput<B: Backend> {
18 pub loss: Tensor<B, 1>,
20
21 pub logits: Tensor<B, 3>,
23
24 pub predictions: Option<Tensor<B, 2, Int>>,
27
28 pub targets: Tensor<B, 2, Int>,
30}
31
32impl<B: Backend> SequenceOutput<B> {
33 fn predicted_tokens(&self) -> Tensor<B, 2, Int> {
34 match &self.predictions {
35 Some(preds) => preds.clone(),
36 None => self.logits.clone().argmax(2).squeeze_dim::<2>(2),
37 }
38 }
39
40 fn flat_logits(&self) -> Tensor<B, 2> {
41 let [batch_size, seq_len, vocab_size] = self.logits.dims();
42 self.logits
43 .clone()
44 .reshape([batch_size * seq_len, vocab_size])
45 }
46
47 fn flat_targets(&self) -> Tensor<B, 1, Int> {
48 let [batch_size, seq_len] = self.targets.dims();
49 self.targets.clone().reshape([batch_size * seq_len])
50 }
51}
52
53impl<B: Backend> ItemLazy for SequenceOutput<B> {
54 type ItemSync = SequenceOutput<Flex>;
56
57 fn sync(self) -> Self::ItemSync {
58 let device = &Default::default();
59
60 match self.predictions {
61 Some(preds) => {
62 let [logits, loss, targets, predictions] = Transaction::default()
63 .register(self.logits)
64 .register(self.loss)
65 .register(self.targets)
66 .register(preds)
67 .execute()
68 .try_into()
69 .expect("Correct amount of tensor data");
70
71 SequenceOutput {
72 logits: Tensor::from_data(logits, device),
73 loss: Tensor::from_data(loss, device),
74 targets: Tensor::from_data(targets, device),
75 predictions: Some(Tensor::from_data(predictions, device)),
76 }
77 }
78 None => {
79 let [logits, loss, targets] = Transaction::default()
80 .register(self.logits)
81 .register(self.loss)
82 .register(self.targets)
83 .execute()
84 .try_into()
85 .expect("Correct amount of tensor data");
86
87 SequenceOutput {
88 logits: Tensor::from_data(logits, device),
89 loss: Tensor::from_data(loss, device),
90 targets: Tensor::from_data(targets, device),
91 predictions: None,
92 }
93 }
94 }
95 }
96}
97
98impl<B: Backend> Adaptor<LossInput<B>> for SequenceOutput<B> {
99 fn adapt(&self) -> LossInput<B> {
100 LossInput::new(self.loss.clone())
101 }
102}
103
104impl<B: Backend> Adaptor<CerInput<B>> for SequenceOutput<B> {
105 fn adapt(&self) -> CerInput<B> {
106 CerInput::new(self.predicted_tokens(), self.targets.clone())
107 }
108}
109
110impl<B: Backend> Adaptor<WerInput<B>> for SequenceOutput<B> {
111 fn adapt(&self) -> WerInput<B> {
112 WerInput::new(self.predicted_tokens(), self.targets.clone())
113 }
114}
115
116impl<B: Backend> Adaptor<AccuracyInput<B>> for SequenceOutput<B> {
117 fn adapt(&self) -> AccuracyInput<B> {
118 AccuracyInput::new(self.flat_logits(), self.flat_targets())
119 }
120}
121
122impl<B: Backend> Adaptor<TopKAccuracyInput<B>> for SequenceOutput<B> {
123 fn adapt(&self) -> TopKAccuracyInput<B> {
124 TopKAccuracyInput::new(self.flat_logits(), self.flat_targets())
125 }
126}
127
128impl<B: Backend> Adaptor<PerplexityInput<B>> for SequenceOutput<B> {
129 fn adapt(&self) -> PerplexityInput<B> {
130 PerplexityInput::new(self.flat_logits(), self.flat_targets())
131 }
132}