use burn_core as burn;

use alloc::vec::Vec;

use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

use crate::activation::ActivationConfig;
use crate::cache::TensorCache;
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
};

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

/// Configuration to create a [transformer decoder](TransformerDecoder) layer using the [init function](TransformerDecoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Layer norm will be applied first instead of after the other modules.
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax.
    ///
    /// - Usage may improve performance by allowing attention heads to deposit no information (if the sequence contains no information relevant to that head).
    /// - Usage may reduce the entropy of weights in the model, enhancing quantization and compression.
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The activation function used in the feed-forward network. Default: GELU
    #[config(default = "ActivationConfig::Gelu")]
    pub activation: ActivationConfig,
    /// Epsilon value used by the layer norms for numerical stability. Default: 1e-5
    #[config(default = 1e-5)]
    pub layer_norm_eps: f64,
}

/// The transformer decoder module as described in the paper
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [TransformerDecoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// The transformer decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate.
    pub dropout: f64,

    /// Whether layer norm is applied first.
    pub norm_first: bool,

    /// Whether "quiet softmax" is used.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
    /// Initialize a new [transformer decoder](TransformerDecoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}
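
// Example: building a decoder from its configuration. A minimal sketch,
// assuming some backend type `B: Backend` and a matching `device` are in
// scope; the hyperparameter values are illustrative only.
//
// ```ignore
// let config = TransformerDecoderConfig::new(512, 2048, 8, 6)
//     .with_norm_first(true)
//     .with_dropout(0.1);
// let decoder: TransformerDecoder<B> = config.init(&device);
// ```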

/// [Transformer decoder](TransformerDecoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}
141
142impl<B: Backend> TransformerDecoderInput<B> {
143 pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
145 Self {
146 target,
147 target_mask_pad: None,
148 target_mask_attn: None,
149 memory,
150 memory_mask_pad: None,
151 memory_mask_attn: None,
152 }
153 }
154
155 pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
157 self.memory_mask_pad = Some(mask_pad);
158 self
159 }
160
161 pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
163 self.memory_mask_attn = Some(mask_attn);
164 self
165 }
166
167 pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
169 self.target_mask_pad = Some(mask_pad);
170 self
171 }
172
173 pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
175 self.target_mask_attn = Some(mask_attn);
176 self
177 }
178}
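
// Example: assembling a forward-pass input. A sketch, assuming `target` and
// `memory` are `[batch_size, seq_length, d_model]` tensors already in scope;
// `generate_autoregressive_mask` comes from this crate's attention module.
//
// ```ignore
// let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &device);
// let input = TransformerDecoderInput::new(target, memory)
//     .target_mask_attn(mask_attn);
// ```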

/// [Transformer decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    /// Cross-attention module over the encoder memory.
    pub cross_attn: MultiHeadAttention<B>,
    /// Self-attention module over the target sequence.
    pub self_attn: MultiHeadAttention<B>,
    /// Position-wise feed-forward module.
    pub pwff: PositionWiseFeedForward<B>,
    /// First layer norm.
    pub norm_1: LayerNorm<B>,
    /// Second layer norm.
    pub norm_2: LayerNorm<B>,
    /// Third layer norm.
    pub norm_3: LayerNorm<B>,
    /// Dropout module.
    pub dropout: Dropout,
    /// Whether layer norm is applied first.
    pub norm_first: bool,
}
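
// Data flow of one layer in the default post-norm configuration
// (`norm_first = false`); with `norm_first = true` each layer norm is instead
// applied to the residual branch before its sub-module:
//
//     x = norm_1(x + dropout(self_attn(x)))
//     x = norm_2(x + dropout(cross_attn(x, memory, memory)))
//     x = norm_3(x + dropout(pwff(x)))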

/// Autoregressive cache for the [transformer decoder layer](TransformerDecoderLayer) module.
///
/// To be used during inference when decoding tokens.
pub struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    /// Cross-attention cache.
    pub cross_attn: MhaCache<B>,
    /// Self-attention cache.
    pub self_attn: MhaCache<B>,
    /// Position-wise feed-forward cache.
    pub pwff: TensorCache<B, 3>,
    /// First layer norm cache.
    pub norm_1: TensorCache<B, 3>,
    /// Second layer norm cache.
    pub norm_2: TensorCache<B, 3>,
    /// Third layer norm cache.
    pub norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    /// Create an empty cache.
    pub fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}
230
231pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
235 layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
236}
237
238impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
239 fn empty(num_layers: usize) -> Self {
240 Self {
241 layers: (0..num_layers)
242 .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
243 .collect(),
244 }
245 }
246}

impl<B: Backend> TransformerDecoderLayer<B> {
    /// Create a new [transformer decoder layer](TransformerDecoderLayer) from the given configuration.
    pub fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_2 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_3 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_activation(config.activation.clone())
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Apply the forward pass on the full target sequence.
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        // Self-attention over the target sequence.
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross-attention over the encoder memory.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Position-wise feed-forward.
        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = self.norm_3.forward(x);
        }

        input.target = x;
        input
    }

    /// Apply the forward pass using an autoregressive cache (inference mode).
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        // Self-attention over the target sequence.
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Cross-attention over the encoder memory.
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Position-wise feed-forward.
        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x));
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Apply the forward pass on the input tensor.
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

    /// Apply the forward pass using an autoregressive cache (inference mode).
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for (layer, cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}
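
// Example: token-by-token decoding with the autoregressive cache. A sketch,
// assuming `decoder`, `target`, `memory`, and the dimensions are in scope;
// each step feeds the sequence decoded so far and keeps only the newest
// position, mirroring the test below.
//
// ```ignore
// let mut cache = decoder.new_autoregressive_cache();
// let mut outputs = Vec::new();
// for i in 1..seq_length + 1 {
//     let step_target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
//     let mask_attn = generate_autoregressive_mask(batch_size, i, &device);
//     let input = TransformerDecoderInput::new(step_target, memory.clone())
//         .target_mask_attn(mask_attn);
//     let out = decoder.forward_autoregressive_inference(input, &mut cache);
//     outputs.push(out.slice([0..batch_size, i - 1..i, 0..d_model]));
// }
// let decoded = Tensor::cat(outputs, 1);
// ```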

#[cfg(test)]
mod tests {
    use burn::tensor::Device;

    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};

    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Forward the full sequence at once.
        let output_1 = transformer.forward(input);

        // Decode one token at a time with the autoregressive cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        // Both paths should produce the same output up to numerical tolerance.
        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}