use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::tensor::Bool;
use crate::{
    self as burn,
    nn::{Initializer, attention::MhaCache, cache::TensorCache},
};
use crate::{
    config::Config,
    nn::{
        Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    },
    tensor::{Tensor, backend::Backend},
};

/// Configuration to create a [transformer decoder](TransformerDecoder) using the
/// [init function](TransformerDecoderConfig::init).
#[derive(Config)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Apply layer normalization before each sub-block instead of after. Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax: an extra 1 is added to the softmax
    /// denominator, which allows an attention head to attend to nothing. Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The initializer used for the module parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The transformer decoder module as described in the paper
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [TransformerDecoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// The decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate.
    pub dropout: f64,

    /// Whether layer normalization is applied before each sub-block instead of after.
    pub norm_first: bool,

    /// Whether "quiet softmax" is used instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
    /// Initialize a new [transformer decoder](TransformerDecoder) module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}
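
// A minimal usage sketch (not part of the original source). `MyBackend` stands in for any
// type implementing `Backend`; the `with_*` setters are generated by `#[derive(Config)]`,
// as used in the tests below, and tensor shapes follow `[batch, seq, d_model]`:
//
//     let device = Default::default();
//     let decoder = TransformerDecoderConfig::new(512, 2048, 8, 6)
//         .with_norm_first(true)
//         .init::<MyBackend>(&device);
//     let output = decoder.forward(TransformerDecoderInput::new(target, memory));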

/// Input arguments for the [transformer decoder](TransformerDecoder) forward pass.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerDecoderInput<B> {
    /// Create a new transformer decoder input.
    ///
    /// `target` is the sequence being decoded and `memory` is the encoder output attended
    /// to by the cross-attention sub-blocks.
    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
        Self {
            target,
            target_mask_pad: None,
            target_mask_attn: None,
            memory,
            memory_mask_pad: None,
            memory_mask_attn: None,
        }
    }

    /// Register the memory padding mask.
    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.memory_mask_pad = Some(mask_pad);
        self
    }

    /// Register the memory attention mask.
    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.memory_mask_attn = Some(mask_attn);
        self
    }

    /// Register the target padding mask.
    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.target_mask_pad = Some(mask_pad);
        self
    }

    /// Register the target attention mask.
    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.target_mask_attn = Some(mask_attn);
        self
    }
}
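
// Example of attaching masks before decoding (a sketch, not original code). `causal_mask`
// (`[batch, target_seq, target_seq]`) and `memory_pad_mask` (`[batch, memory_seq]`) are
// assumed to be built by the caller, e.g. with `generate_autoregressive_mask` as in the
// tests below:
//
//     let input = TransformerDecoderInput::new(target, memory)
//         .target_mask_attn(causal_mask)
//         .memory_mask_pad(memory_pad_mask);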

/// A single [transformer decoder](TransformerDecoder) layer: self-attention over the target,
/// cross-attention over the memory, and a position-wise feed-forward network, each wrapped
/// with dropout, a residual connection, and layer normalization.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    cross_attn: MultiHeadAttention<B>,
    self_attn: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    norm_3: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

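// The caches below support incremental (token-by-token) decoding: `MhaCache` keeps the
// attention state computed for earlier positions, and `TensorCache` keeps earlier outputs of
// the wrapped module so that `forward_autoregressive` only has to process the newly appended
// positions along the sequence dimension (dim 1).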
struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    cross_attn: MhaCache<B>,
    self_attn: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
    norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [transformer decoder](TransformerDecoder), holding one cache
/// entry per layer.
///
/// To be used during inference when decoding tokens one at a time.
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

impl<B: Backend> TransformerDecoderLayer<B> {
    fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let norm_3 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_dropout(config.dropout)
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

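    // Forward-pass structure (comment added for clarity): the layer applies, in order, masked
    // self-attention over the target, cross-attention over the memory, and the position-wise
    // feed-forward network. Each sub-block's output goes through dropout and is added back to
    // its input (residual connection). With `norm_first` enabled, layer norm is applied to the
    // residual branch before each sub-block (pre-norm); otherwise it is applied to the summed
    // output after each sub-block (post-norm).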
    fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        // Masked self-attention over the target.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Cross-attention over the encoder memory.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        // Position-wise feed-forward.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // In the post-norm configuration, apply the final layer norm.
        if !self.norm_first {
            x = self.norm_3.forward(x)
        }

        input.target = x;
        input
    }

    /// Same as [forward](Self::forward), but reuses the autoregressive cache so that
    /// previously decoded positions are not recomputed.
    fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        // Masked self-attention over the target, using the cache.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Cross-attention over the encoder memory, using the cache.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        // Position-wise feed-forward, using the cache.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x))
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Apply the forward pass on the input through all decoder layers.
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

    /// Apply the forward pass on the input using autoregressive caching.
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for i in 0..self.layers.len() {
            let layer = self.layers.get(i).unwrap();
            let cache = cache.layers.get_mut(i).unwrap();

            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache with one entry per layer.
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}
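
// A sketch of token-by-token decoding with the cache (not from the original source).
// `decoder`, `memory`, the running `generated` target tensor, and `max_new_tokens` are
// assumed to be provided by the caller; shapes follow `[batch, seq, d_model]`:
//
//     let mut cache = decoder.new_autoregressive_cache();
//     for _ in 0..max_new_tokens {
//         let input = TransformerDecoderInput::new(generated.clone(), memory.clone());
//         let output = decoder.forward_autoregressive_inference(input, &mut cache);
//         // Take the last position of `output`, derive the next token embedding from it,
//         // and append it to `generated` before the next iteration.
//     }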

#[cfg(test)]
mod tests {
    use burn_tensor::Device;

    use super::*;
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};

    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        TestBackend::seed(0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        TestBackend::seed(0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Output of the full-sequence forward pass.
        let output_1 = transformer.forward(input);

        // Output of decoding the same sequence one token at a time with the cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::default());
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}