1use burn_core as burn;
2
3use alloc::vec::Vec;
4
5use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
6use crate::{
7 Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
8 attention::{MhaCache, MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
9 cache::TensorCache,
10};
11use burn::config::Config;
12use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
13use burn::tensor::{Bool, Tensor, backend::Backend};
14
/// Configuration to create a [TransformerEncoder] using
/// [init](TransformerEncoderConfig::init).
#[derive(Config, Debug)]
pub struct TransformerEncoderConfig {
    /// The size of the model (embedding dimension).
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of stacked encoder layers.
    pub n_layers: usize,
    /// The dropout rate, applied after attention and feed-forward. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// If true, layer norm is applied *before* each sub-block (pre-norm);
    /// otherwise after each residual sum (post-norm). Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// Whether to use the "quiet softmax" variant in attention
    /// (forwarded to the multi-head attention config). Default: false
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// Parameter initializer, forwarded to the attention and feed-forward
    /// sub-modules. Default: Kaiming-uniform with gain 1/sqrt(3).
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}
46
/// The transformer encoder module.
///
/// Created with [TransformerEncoderConfig::init]; the fields mirror the
/// configuration values so they can be shown in the module display.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The stack of encoder layers, applied in order during `forward`.
    pub layers: Vec<TransformerEncoderLayer<B>>,

    /// The size of the model (embedding dimension).
    pub d_model: usize,

    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,

    /// The number of attention heads.
    pub n_heads: usize,

    /// The number of layers.
    pub n_layers: usize,

    /// The dropout rate used by the sub-modules.
    pub dropout: f64,

    /// Whether pre-norm (true) or post-norm (false) is used.
    pub norm_first: bool,

    /// Whether the attention uses the quiet-softmax variant.
    pub quiet_softmax: bool,
}
81
82impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
83 fn custom_settings(&self) -> Option<DisplaySettings> {
84 DisplaySettings::new()
85 .with_new_line_after_attribute(false)
86 .optional()
87 }
88
89 fn custom_content(&self, content: Content) -> Option<Content> {
90 content
91 .add("d_model", &self.d_model)
92 .add("d_ff", &self.d_ff)
93 .add("n_heads", &self.n_heads)
94 .add("n_layers", &self.n_layers)
95 .add("dropout", &self.dropout)
96 .add("norm_first", &self.norm_first)
97 .add("quiet_softmax", &self.quiet_softmax)
98 .optional()
99 }
100}
101
/// Input bundle for [TransformerEncoder::forward]: the token tensor plus
/// optional padding and attention masks.
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    // [batch_size, seq_length, d_model] — as exercised by the tests below.
    tensor: Tensor<B, 3>,
    // Optional padding mask over [batch_size, seq_length].
    mask_pad: Option<Tensor<B, 2, Bool>>,
    // Optional attention mask (e.g. autoregressive/causal).
    mask_attn: Option<Tensor<B, 3, Bool>>,
}
109
110impl<B: Backend> TransformerEncoderInput<B> {
111 pub fn new(tensor: Tensor<B, 3>) -> Self {
113 Self {
114 tensor,
115 mask_pad: None,
116 mask_attn: None,
117 }
118 }
119
120 pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
122 self.mask_pad = Some(mask_pad);
123 self
124 }
125
126 pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
128 self.mask_attn = Some(mask_attn);
129 self
130 }
131}
132impl TransformerEncoderConfig {
133 pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
135 let layers = (0..self.n_layers)
136 .map(|_| TransformerEncoderLayer::new(self, device))
137 .collect::<Vec<_>>();
138
139 TransformerEncoder {
140 layers,
141 d_model: self.d_model,
142 d_ff: self.d_ff,
143 n_heads: self.n_heads,
144 n_layers: self.n_layers,
145 dropout: self.dropout,
146 norm_first: self.norm_first,
147 quiet_softmax: self.quiet_softmax,
148 }
149 }
150}
151
152impl<B: Backend> TransformerEncoder<B> {
153 pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
160 let mut x = input.tensor;
161
162 for layer in self.layers.iter() {
163 x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
164 }
165
166 x
167 }
168 pub fn forward_autoregressive_inference(
175 &self,
176 input: TransformerEncoderInput<B>,
177 cache: &mut TransformerEncoderAutoregressiveCache<B>,
178 ) -> Tensor<B, 3> {
179 let mut x = input.tensor;
180
181 for i in 0..self.layers.len() {
182 let layer = self.layers.get(i).unwrap();
183 let cache = cache.layers.get_mut(i).unwrap();
184
185 x = layer.forward_autoregressive_inference(
186 x,
187 input.mask_pad.clone(),
188 input.mask_attn.clone(),
189 cache,
190 );
191 }
192
193 x
194 }
195
196 pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
198 TransformerEncoderAutoregressiveCache::empty(self.layers.len())
199 }
200}
201
/// A single transformer encoder layer: multi-head self-attention followed by
/// a position-wise feed-forward network, each with residual + dropout + norm.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    /// Multi-head self-attention sub-module.
    pub mha: MultiHeadAttention<B>,
    /// Position-wise feed-forward sub-module.
    pub pwff: PositionWiseFeedForward<B>,
    /// Norm around the attention residual (pre-norm) or after it (post-norm).
    pub norm_1: LayerNorm<B>,
    /// Norm around the feed-forward residual (pre-norm) or after it (post-norm).
    pub norm_2: LayerNorm<B>,
    /// Dropout applied to each residual branch before the sum.
    pub dropout: Dropout,
    /// Pre-norm (true) vs post-norm (false) layout.
    pub norm_first: bool,
}
219
impl<B: Backend> TransformerEncoderLayer<B> {
    /// Builds one encoder layer from the shared encoder configuration,
    /// forwarding dropout, quiet-softmax and initializer settings to the
    /// attention and feed-forward sub-modules.
    fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    /// Forward pass of one layer: self-attention then feed-forward, each
    /// wrapped as `x = x + dropout(sub_block(...))` with layer norm placed
    /// before the sub-block (`norm_first`) or after the residual sum.
    ///
    /// NOTE: in the pre-norm path, `norm_2` is applied before the attention
    /// block and `norm_1` before the feed-forward block — the reverse of the
    /// post-norm ordering.
    fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        let x = input;
        let mut residual_path = x.clone();

        // Pre-norm: normalize the attention input (the residual skip keeps
        // the un-normalized `x`).
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        // Self-attention with optional padding/attention masks.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        // Residual connection around attention.
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Pre-norm: normalize the feed-forward input only; post-norm:
        // normalize the attention output in place and feed that forward.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        // Residual connection around the feed-forward block.
        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Post-norm: final normalization after the second residual sum.
        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    /// Cached variant of [`Self::forward`] for autoregressive inference:
    /// identical structure, but every sub-module call goes through its
    /// `TensorCache`/`MhaCache` so only new positions are computed.
    fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let x = input;
        let mut residual_path = x.clone();

        // Pre-norm on the attention input, cached along dim 1 (sequence).
        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        // Self-attention with optional masks, using the attention cache.
        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        // Residual connection around attention.
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Norm placement mirrors the uncached forward (see note above).
        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        // Residual connection around the cached feed-forward block.
        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Post-norm: final cached normalization.
        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}
358
/// Per-layer autoregressive cache: one cache per cached sub-module call in
/// [TransformerEncoderLayer::forward_autoregressive_inference].
struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    mha: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
}
365
366impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
367 fn empty() -> Self {
368 Self {
369 mha: MhaCache::autoregressive(),
370 pwff: TensorCache::empty(),
371 norm_1: TensorCache::empty(),
372 norm_2: TensorCache::empty(),
373 }
374 }
375}
376
/// Autoregressive cache for a whole [TransformerEncoder]:
/// holds one per-layer cache for each encoder layer.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}
383
384impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
385 fn empty(num_layers: usize) -> Self {
386 Self {
387 layers: (0..num_layers)
388 .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
389 .collect(),
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{TestBackend, attention::generate_autoregressive_mask};
    use burn::tensor::Distribution;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    // Cached autoregressive inference must match the full forward pass
    // in the post-norm configuration.
    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    // Same equivalence check for the pre-norm configuration.
    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }

    // Shared harness: runs the encoder once over the full sequence with a
    // causal mask, then token-by-token with the autoregressive cache, and
    // asserts both outputs agree within a permissive float tolerance.
    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        // Reference: one full-sequence forward with the causal mask.
        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        // Incremental: feed prefixes of growing length i, keeping only the
        // newest position from each cached forward.
        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        // Reassemble the incremental tokens along the sequence dimension.
        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::permissive());
    }

    // Pins the exact one-line display produced by the custom ModuleDisplay
    // impl, including the parameter count.
    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{transformer}"),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
}