use burn_core as burn;

use alloc::vec::Vec;

use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

use crate::cache::TensorCache;
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
};

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};

/// Configuration to create a [transformer decoder](TransformerDecoder) using the
/// [init function](TransformerDecoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerDecoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Apply layer normalization before each sub-layer (pre-norm) instead of after it (post-norm). Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax, which allows an attention head to
    /// produce (near) zero attention when nothing in the sequence is relevant to it. Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize the module parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

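/// The transformer decoder module, as described in the paper
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [`TransformerDecoderConfig`]. A minimal usage sketch (not
/// compiled here; `MyBackend`, `target` and `memory` are placeholder names for a
/// [`Backend`] implementation and two `[batch_size, seq_length, d_model]` tensors):
///
/// ```rust,ignore
/// let device = Default::default();
/// let decoder = TransformerDecoderConfig::new(512, 2048, 8, 6).init::<MyBackend>(&device);
/// let input = TransformerDecoderInput::new(target, memory);
/// let output = decoder.forward(input); // [batch_size, seq_length, d_model]
/// ```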
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerDecoder<B: Backend> {
    /// The transformer decoder layers.
    pub layers: Vec<TransformerDecoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate.
    pub dropout: f64,

    /// Whether layer normalization is applied before each sub-layer.
    pub norm_first: bool,

    /// Whether "quiet softmax" is used instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerDecoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

impl TransformerDecoderConfig {
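    /// Initialize a new [`TransformerDecoder`] module on the given device.
    ///
    /// A construction sketch (hedged: `MyBackend` stands in for any [`Backend`]
    /// implementation; the `with_*` setters are generated by `#[derive(Config)]`):
    ///
    /// ```rust,ignore
    /// // d_model = 64, d_ff = 256, n_heads = 4, n_layers = 2, with pre-norm enabled.
    /// let device = Default::default();
    /// let decoder = TransformerDecoderConfig::new(64, 256, 4, 2)
    ///     .with_norm_first(true)
    ///     .init::<MyBackend>(&device);
    /// ```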
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerDecoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

/// The input of the transformer decoder: the target sequence, the memory
/// (encoder output), and their optional padding and attention masks.
#[derive(Debug)]
pub struct TransformerDecoderInput<B: Backend> {
    target: Tensor<B, 3>,
    target_mask_pad: Option<Tensor<B, 2, Bool>>,
    target_mask_attn: Option<Tensor<B, 3, Bool>>,
    memory: Tensor<B, 3>,
    memory_mask_pad: Option<Tensor<B, 2, Bool>>,
    memory_mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerDecoderInput<B> {
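    /// Create a decoder input from the target tensor `[batch_size, target_length, d_model]`
    /// and the memory (encoder output) tensor `[batch_size, memory_length, d_model]`,
    /// with no masks attached.
    ///
    /// Masks can be attached afterwards with the builder-style methods below. A sketch
    /// (hedged: `target`, `memory`, `pad_mask` and the shape variables are assumed to
    /// already exist in the caller):
    ///
    /// ```rust,ignore
    /// let mask_attn = generate_autoregressive_mask(batch_size, target_length, &device);
    /// let input = TransformerDecoderInput::new(target, memory)
    ///     .target_mask_attn(mask_attn)
    ///     .memory_mask_pad(pad_mask);
    /// ```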
    pub fn new(target: Tensor<B, 3>, memory: Tensor<B, 3>) -> Self {
        Self {
            target,
            target_mask_pad: None,
            target_mask_attn: None,
            memory,
            memory_mask_pad: None,
            memory_mask_attn: None,
        }
    }

    /// Register the memory padding mask.
    pub fn memory_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.memory_mask_pad = Some(mask_pad);
        self
    }

    /// Register the memory attention mask.
    pub fn memory_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.memory_mask_attn = Some(mask_attn);
        self
    }

    /// Register the target padding mask.
    pub fn target_mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.target_mask_pad = Some(mask_pad);
        self
    }

    /// Register the target attention mask.
    pub fn target_mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.target_mask_attn = Some(mask_attn);
        self
    }
}

/// A single transformer decoder layer: self-attention over the target, cross-attention
/// over the memory, and a position-wise feed-forward network, each wrapped with dropout,
/// a residual connection, and layer normalization.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    cross_attn: MultiHeadAttention<B>,
    self_attn: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    norm_3: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

/// Autoregressive cache for a single transformer decoder layer.
struct TransformerDecoderLayerAutoregressiveCache<B: Backend> {
    cross_attn: MhaCache<B>,
    self_attn: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
    norm_3: TensorCache<B, 3>,
}

impl<B: Backend> TransformerDecoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            cross_attn: MhaCache::autoregressive_cross_attention(),
            self_attn: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
            norm_3: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the transformer decoder, to be used during inference
/// when decoding tokens one position at a time.
pub struct TransformerDecoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerDecoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerDecoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerDecoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

impl<B: Backend> TransformerDecoderLayer<B> {
    fn new(config: &TransformerDecoderConfig, device: &B::Device) -> Self {
        let self_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);

        let cross_attn = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let norm_3 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_dropout(config.dropout)
            .init(device);

        Self {
            cross_attn,
            self_attn,
            norm_1,
            norm_2,
            norm_3,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    fn forward(&self, mut input: TransformerDecoderInput<B>) -> TransformerDecoderInput<B> {
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = self.norm_3.forward(residual_path);
        }

        // Self-attention over the target.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.self_attn.forward(self_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Cross-attention over the memory (encoder output).
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self.cross_attn.forward(cross_attn_input).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            self.norm_2.forward(x.clone())
        } else {
            x = self.norm_2.forward(x);
            x.clone()
        };

        // Position-wise feed-forward.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = self.norm_3.forward(x);
        }

        input.target = x;
        input
    }

    fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderLayerAutoregressiveCache<B>,
    ) -> TransformerDecoderInput<B> {
        let x = input.target;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = cache
                .norm_3
                .forward_autoregressive(residual_path, 1, |x| self.norm_3.forward(x));
        }

        // Self-attention over the target, using the autoregressive attention cache.
        let mut self_attn_input = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = &input.target_mask_pad {
            self_attn_input = self_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.target_mask_attn {
            self_attn_input = self_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .self_attn
            .forward_cache(self_attn_input, &mut cache.self_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Cross-attention over the memory, using the autoregressive attention cache.
        let mut cross_attn_input =
            MhaInput::new(residual_path, input.memory.clone(), input.memory.clone());
        if let Some(mask_pad) = &input.memory_mask_pad {
            cross_attn_input = cross_attn_input.mask_pad(mask_pad.clone());
        }
        if let Some(mask_attn) = &input.memory_mask_attn {
            cross_attn_input = cross_attn_input.mask_attn(mask_attn.clone());
        }
        let residual_path = self
            .cross_attn
            .forward_cache(cross_attn_input, &mut cache.cross_attn)
            .context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_2
                .forward_autoregressive(x.clone(), 1, |x| self.norm_2.forward(x))
        } else {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x));
            x.clone()
        };

        // Position-wise feed-forward, cached along the sequence dimension.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = cache
                .norm_3
                .forward_autoregressive(x, 1, |x| self.norm_3.forward(x));
        }

        input.target = x;
        input
    }
}

impl<B: Backend> TransformerDecoder<B> {
    /// Applies the forward pass on the input, returning the decoded target
    /// `[batch_size, seq_length, d_model]`.
    pub fn forward(&self, mut input: TransformerDecoderInput<B>) -> Tensor<B, 3> {
        for layer in self.layers.iter() {
            input = layer.forward(input);
        }

        input.target
    }

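    /// Applies the forward pass using an autoregressive cache, so that previously decoded
    /// positions are not recomputed at every step.
    ///
    /// A decoding-loop sketch mirroring the unit test in this file (hedged: `decoder`,
    /// `target`, `memory`, `batch_size`, `seq_length` and `d_model` are assumed to exist
    /// in the caller):
    ///
    /// ```rust,ignore
    /// let mut cache = decoder.new_autoregressive_cache();
    /// let mut outputs = Vec::new();
    ///
    /// for i in 1..=seq_length {
    ///     // Feed the first `i` target positions and keep only the newest output position.
    ///     let target_i = target.clone().slice([0..batch_size, 0..i, 0..d_model]);
    ///     let mask_attn = generate_autoregressive_mask(batch_size, i, &target_i.device());
    ///     let input = TransformerDecoderInput::new(target_i, memory.clone())
    ///         .target_mask_attn(mask_attn);
    ///
    ///     let next = decoder
    ///         .forward_autoregressive_inference(input, &mut cache)
    ///         .slice([0..batch_size, i - 1..i, 0..d_model]);
    ///     outputs.push(next);
    /// }
    /// let output = Tensor::cat(outputs, 1);
    /// ```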
    pub fn forward_autoregressive_inference(
        &self,
        mut input: TransformerDecoderInput<B>,
        cache: &mut TransformerDecoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        for i in 0..self.layers.len() {
            let layer = self.layers.get(i).unwrap();
            let cache = cache.layers.get_mut(i).unwrap();

            input = layer.forward_autoregressive_inference(input, cache);
        }

        input.target
    }

    /// Create an empty autoregressive cache with one entry per decoder layer.
    pub fn new_autoregressive_cache(&self) -> TransformerDecoderAutoregressiveCache<B> {
        TransformerDecoderAutoregressiveCache::empty(self.layers.len())
    }
}

#[cfg(test)]
mod tests {
    use burn::tensor::Device;

    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};

    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        let device = Default::default();
        TestBackend::seed(&device, 0);

        test_autoregressive(
            TransformerDecoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerDecoderConfig) {
        let device: Device<TestBackend> = Default::default();
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let transformer = config.init::<TestBackend>(&device);

        let memory = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let target = Tensor::arange(0..(batch_size * seq_length * d_model) as i64, &device)
            .float()
            .reshape([batch_size, seq_length, d_model]);
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &target.device());
        let input = TransformerDecoderInput::new(target.clone(), memory.clone())
            .target_mask_attn(mask_attn);

        // Full-sequence forward pass.
        let output_1 = transformer.forward(input);

        // Token-by-token forward pass using the autoregressive cache.
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let target = target.clone().slice([0..batch_size, 0..i, 0..d_model]);

            let mask_attn = generate_autoregressive_mask(batch_size, i, &target.device());
            let input = TransformerDecoderInput::new(target.clone(), memory.clone())
                .target_mask_attn(mask_attn);
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        // Both decoding strategies should produce the same output.
        let tolerance = Tolerance::rel_abs(5e-3, 1e-4);
        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), tolerance);
    }

    #[test]
    fn display() {
        let config = TransformerDecoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerDecoder {d_model: 2, d_ff: 4, n_heads: 2, n_layers: 3, \
            dropout: 0.1, norm_first: false, quiet_softmax: false, params: 246}"
        );
    }
}