use burn_core as burn;

use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
    Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
    activation::ActivationConfig,
    attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    cache::TensorCache,
};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Bool, Tensor, backend::Backend};

#[derive(Config, Debug)]
pub struct TransformerEncoderConfig {
    /// The size of the model (embedding dimension).
    pub d_model: usize,
    /// The size of the hidden layer in the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of stacked encoder layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// If true, layer normalization is applied before each sub-block (pre-norm)
    /// instead of after it. Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// If true, the attention layers use "quiet softmax", which adds one to the
    /// softmax denominator so a head can attend to nothing. Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize the module parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
    /// The activation function used in the position-wise feed-forward network. Default: GELU
    #[config(default = "ActivationConfig::Gelu")]
    pub activation: ActivationConfig,
    /// The epsilon value used by the layer normalization layers. Default: 1e-5
    #[config(default = 1e-5)]
    pub layer_norm_eps: f64,
}

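/// The Transformer Encoder module, as described in
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762): a stack of
/// `n_layers` [`TransformerEncoderLayer`]s, each combining multi-head self-attention
/// with a position-wise feed-forward network.
///
/// Should be created with [`TransformerEncoderConfig`]. A minimal usage sketch
/// (the backend type `MyBackend`, the device, and the `tokens` tensor below are
/// assumptions, not defined by this module):
///
/// ```ignore
/// let device = Default::default();
/// let encoder = TransformerEncoderConfig::new(512, 2048, 8, 6).init::<MyBackend>(&device);
/// // `tokens` has shape [batch_size, seq_length, d_model].
/// let output = encoder.forward(TransformerEncoderInput::new(tokens));
/// ```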
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The stack of encoder layers.
    pub layers: Vec<TransformerEncoderLayer<B>>,

    /// The size of the model.
    pub d_model: usize,

    /// The size of the position-wise feed-forward hidden layer.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of stacked layers.
    pub n_layers: usize,

    /// The dropout rate.
    pub dropout: f64,

    /// Whether layer normalization is applied before each sub-block.
    pub norm_first: bool,

    /// Whether quiet softmax is used in the attention layers.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

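/// Input for the [Transformer Encoder](TransformerEncoder) forward pass: the
/// `[batch_size, seq_length, d_model]` input tensor plus optional padding and
/// attention masks.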
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    tensor: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerEncoderInput<B> {
    /// Create a new input without padding or attention masks.
    pub fn new(tensor: Tensor<B, 3>) -> Self {
        Self {
            tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

impl TransformerEncoderConfig {
    /// Initialize a new [Transformer Encoder](TransformerEncoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerEncoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerEncoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> TransformerEncoder<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for layer in self.layers.iter() {
            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
        }

        x
    }

    /// Applies the forward pass on the input tensor using an autoregressive cache,
    /// so that state computed for earlier positions is reused across decoding steps.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward_autoregressive_inference(
        &self,
        input: TransformerEncoderInput<B>,
        cache: &mut TransformerEncoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for i in 0..self.layers.len() {
            let layer = self.layers.get(i).unwrap();
            let cache = cache.layers.get_mut(i).unwrap();

            x = layer.forward_autoregressive_inference(
                x,
                input.mask_pad.clone(),
                input.mask_attn.clone(),
                cache,
            );
        }

        x
    }

    /// Create an empty autoregressive cache with one entry per layer.
    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
    }
}

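/// A single layer of the [Transformer Encoder](TransformerEncoder): multi-head
/// self-attention followed by a position-wise feed-forward network, each wrapped
/// with a residual connection, dropout, and layer normalization.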
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    /// The multi-head attention module.
    pub mha: MultiHeadAttention<B>,
    /// The position-wise feed-forward module.
    pub pwff: PositionWiseFeedForward<B>,
    /// The first layer normalization.
    pub norm_1: LayerNorm<B>,
    /// The second layer normalization.
    pub norm_2: LayerNorm<B>,
    /// The dropout module.
    pub dropout: Dropout,
    /// Whether layer normalization is applied before each sub-block.
    pub norm_first: bool,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    /// Create a new layer from the given configuration.
    pub fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let norm_2 = LayerNormConfig::new(config.d_model)
            .with_epsilon(config.layer_norm_eps)
            .init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_activation(config.activation.clone())
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        // Self-attention block with a residual connection.
        let x = input;
        let mut residual_path = x.clone();

        // Pre-norm: normalize the input before the attention block.
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed-forward block: pre-norm normalizes the residual input,
        // post-norm normalizes the attention output in place.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Post-norm: apply the final layer normalization after the residual sum.
        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    /// Applies the forward pass on the input tensor using the given autoregressive cache.
    ///
    /// # Shapes
    ///
    /// - input: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let x = input;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}

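/// Autoregressive cache for a single [Transformer Encoder layer](TransformerEncoderLayer).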
pub struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    /// The multi-head attention cache.
    pub mha: MhaCache<B>,
    /// The position-wise feed-forward cache.
    pub pwff: TensorCache<B, 3>,
    /// The first layer normalization cache.
    pub norm_1: TensorCache<B, 3>,
    /// The second layer normalization cache.
    pub norm_2: TensorCache<B, 3>,
}

impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
    /// Create an empty cache.
    pub fn empty() -> Self {
        Self {
            mha: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
        }
    }
}

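/// Autoregressive cache for the [Transformer Encoder](TransformerEncoder), holding one
/// [`TransformerEncoderLayerAutoregressiveCache`] per layer.
///
/// Create it with [`TransformerEncoder::new_autoregressive_cache`] and pass it to
/// [`TransformerEncoder::forward_autoregressive_inference`] at every decoding step.
/// A minimal sketch of the call pattern, mirroring the `test_autoregressive` test below
/// (the `encoder`, `tensor`, `batch_size`, `seq_length`, and `d_model` bindings are
/// assumptions):
///
/// ```ignore
/// let mut cache = encoder.new_autoregressive_cache();
/// for i in 1..=seq_length {
///     // Feed the sequence up to position `i`, reusing the cache across calls.
///     let step = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
///     let output = encoder.forward_autoregressive_inference(
///         TransformerEncoderInput::new(step),
///         &mut cache,
///     );
/// }
/// ```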
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use burn::tensor::Distribution;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());
    }

    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}