use crate::tensor::Bool;
use alloc::vec::Vec;

use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::{
    self as burn,
    nn::{Initializer, attention::MhaCache, cache::TensorCache},
};
use crate::{
    config::Config,
    nn::{
        Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
        attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
    },
    tensor::{Tensor, backend::Backend},
};

/// Configuration to create a [transformer encoder](TransformerEncoder) using the
/// [init function](TransformerEncoderConfig::init).
#[derive(Config)]
pub struct TransformerEncoderConfig {
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate. Default: 0.1
    #[config(default = 0.1)]
    pub dropout: f64,
    /// Apply layer normalization before each sub-layer instead of after. Default: false
    #[config(default = false)]
    pub norm_first: bool,
    /// Use "quiet softmax" instead of regular softmax. Default: false
    ///
    /// Reference: <https://www.evanmiller.org/attention-is-off-by-one.html>
    #[config(default = false)]
    pub quiet_softmax: bool,
    /// The type of function used to initialize neural network parameters.
    #[config(
        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
    )]
    pub initializer: Initializer,
}

/// The transformer encoder module as described in the paper
/// [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
///
/// Should be created with [TransformerEncoderConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct TransformerEncoder<B: Backend> {
    /// The transformer encoder layers.
    pub layers: Vec<TransformerEncoderLayer<B>>,
    /// The size of the model.
    pub d_model: usize,
    /// The size of the position-wise feed-forward network.
    pub d_ff: usize,
    /// The number of attention heads.
    pub n_heads: usize,
    /// The number of layers.
    pub n_layers: usize,
    /// The dropout rate.
    pub dropout: f64,
    /// Whether layer normalization is applied before each sub-layer (pre-norm) instead of after.
    pub norm_first: bool,
    /// Whether "quiet softmax" is used instead of regular softmax.
    pub quiet_softmax: bool,
}

impl<B: Backend> ModuleDisplay for TransformerEncoder<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        DisplaySettings::new()
            .with_new_line_after_attribute(false)
            .optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        content
            .add("d_model", &self.d_model)
            .add("d_ff", &self.d_ff)
            .add("n_heads", &self.n_heads)
            .add("n_layers", &self.n_layers)
            .add("dropout", &self.dropout)
            .add("norm_first", &self.norm_first)
            .add("quiet_softmax", &self.quiet_softmax)
            .optional()
    }
}

/// [Transformer encoder](TransformerEncoder) forward pass input argument.
#[derive(Debug)]
pub struct TransformerEncoderInput<B: Backend> {
    tensor: Tensor<B, 3>,
    mask_pad: Option<Tensor<B, 2, Bool>>,
    mask_attn: Option<Tensor<B, 3, Bool>>,
}

impl<B: Backend> TransformerEncoderInput<B> {
    /// Create a [transformer encoder](TransformerEncoder) input argument.
    pub fn new(tensor: Tensor<B, 3>) -> Self {
        Self {
            tensor,
            mask_pad: None,
            mask_attn: None,
        }
    }

    /// Register the padding mask.
    pub fn mask_pad(mut self, mask_pad: Tensor<B, 2, Bool>) -> Self {
        self.mask_pad = Some(mask_pad);
        self
    }

    /// Register the attention mask.
    pub fn mask_attn(mut self, mask_attn: Tensor<B, 3, Bool>) -> Self {
        self.mask_attn = Some(mask_attn);
        self
    }
}

impl TransformerEncoderConfig {
    /// Initialize a new [transformer encoder](TransformerEncoder) module.
    pub fn init<B: Backend>(&self, device: &B::Device) -> TransformerEncoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerEncoderLayer::new(self, device))
            .collect::<Vec<_>>();

        TransformerEncoder {
            layers,
            d_model: self.d_model,
            d_ff: self.d_ff,
            n_heads: self.n_heads,
            n_layers: self.n_layers,
            dropout: self.dropout,
            norm_first: self.norm_first,
            quiet_softmax: self.quiet_softmax,
        }
    }
}

impl<B: Backend> TransformerEncoder<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward(&self, input: TransformerEncoderInput<B>) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for layer in self.layers.iter() {
            x = layer.forward(x, input.mask_pad.clone(), input.mask_attn.clone());
        }

        x
    }

    /// Applies the forward pass on the input tensor using an autoregressive cache.
    ///
    /// # Shapes
    ///
    /// - tensor: `[batch_size, seq_length, d_model]`
    /// - output: `[batch_size, seq_length, d_model]`
    pub fn forward_autoregressive_inference(
        &self,
        input: TransformerEncoderInput<B>,
        cache: &mut TransformerEncoderAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let mut x = input.tensor;

        for (layer, layer_cache) in self.layers.iter().zip(cache.layers.iter_mut()) {
            x = layer.forward_autoregressive_inference(
                x,
                input.mask_pad.clone(),
                input.mask_attn.clone(),
                layer_cache,
            );
        }

        x
    }

    /// Create an empty autoregressive cache.
    pub fn new_autoregressive_cache(&self) -> TransformerEncoderAutoregressiveCache<B> {
        TransformerEncoderAutoregressiveCache::empty(self.layers.len())
    }
}

/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    mha: MultiHeadAttention<B>,
    pwff: PositionWiseFeedForward<B>,
    norm_1: LayerNorm<B>,
    norm_2: LayerNorm<B>,
    dropout: Dropout,
    norm_first: bool,
}

impl<B: Backend> TransformerEncoderLayer<B> {
    fn new(config: &TransformerEncoderConfig, device: &B::Device) -> Self {
        let mha = MultiHeadAttentionConfig::new(config.d_model, config.n_heads)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .with_quiet_softmax(config.quiet_softmax)
            .init(device);
        let norm_1 = LayerNormConfig::new(config.d_model).init(device);
        let norm_2 = LayerNormConfig::new(config.d_model).init(device);
        let dropout = DropoutConfig::new(config.dropout).init();
        let pwff = PositionWiseFeedForwardConfig::new(config.d_model, config.d_ff)
            .with_initializer(config.initializer.clone())
            .with_dropout(config.dropout)
            .init(device);

        Self {
            mha,
            norm_1,
            norm_2,
            pwff,
            dropout,
            norm_first: config.norm_first,
        }
    }

    fn forward(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
    ) -> Tensor<B, 3> {
        let x = input;
        let mut residual_path = x.clone();

        // Self-attention block: with pre-norm (`norm_first`), `norm_2` is applied before
        // the attention; with post-norm, normalization happens after the residual
        // connections below.
        if self.norm_first {
            residual_path = self.norm_2.forward(residual_path)
        }

        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward(input_mhs).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        // Feed-forward block: pre-norm applies `norm_1` to the residual branch only,
        // while post-norm normalizes `x` in place after the attention residual.
        let residual_path = if self.norm_first {
            self.norm_1.forward(x.clone())
        } else {
            x = self.norm_1.forward(x);
            x.clone()
        };

        let residual_path = self.pwff.forward(residual_path);
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = self.norm_2.forward(x)
        }

        x
    }

    fn forward_autoregressive_inference(
        &self,
        input: Tensor<B, 3>,
        mask_pad: Option<Tensor<B, 2, Bool>>,
        mask_attn: Option<Tensor<B, 3, Bool>>,
        cache: &mut TransformerEncoderLayerAutoregressiveCache<B>,
    ) -> Tensor<B, 3> {
        let x = input;
        let mut residual_path = x.clone();

        if self.norm_first {
            residual_path = cache
                .norm_2
                .forward_autoregressive(residual_path, 1, |x| self.norm_2.forward(x))
        }

        let mut input_mhs = MhaInput::self_attn(residual_path);
        if let Some(mask_pad) = mask_pad {
            input_mhs = input_mhs.mask_pad(mask_pad);
        }
        if let Some(mask_attn) = mask_attn {
            input_mhs = input_mhs.mask_attn(mask_attn);
        }
        let residual_path = self.mha.forward_cache(input_mhs, &mut cache.mha).context;

        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        let residual_path = if self.norm_first {
            cache
                .norm_1
                .forward_autoregressive(x.clone(), 1, |x| self.norm_1.forward(x))
        } else {
            x = cache
                .norm_1
                .forward_autoregressive(x, 1, |x| self.norm_1.forward(x));
            x.clone()
        };

        let residual_path = cache
            .pwff
            .forward_autoregressive(residual_path, 1, |x| self.pwff.forward(x));
        let residual_path = self.dropout.forward(residual_path);
        let mut x = x + residual_path;

        if !self.norm_first {
            x = cache
                .norm_2
                .forward_autoregressive(x, 1, |x| self.norm_2.forward(x))
        }

        x
    }
}

/// Autoregressive cache for the [transformer encoder layer](TransformerEncoderLayer) module.
struct TransformerEncoderLayerAutoregressiveCache<B: Backend> {
    mha: MhaCache<B>,
    pwff: TensorCache<B, 3>,
    norm_1: TensorCache<B, 3>,
    norm_2: TensorCache<B, 3>,
}

impl<B: Backend> TransformerEncoderLayerAutoregressiveCache<B> {
    fn empty() -> Self {
        Self {
            mha: MhaCache::autoregressive(),
            pwff: TensorCache::empty(),
            norm_1: TensorCache::empty(),
            norm_2: TensorCache::empty(),
        }
    }
}

/// Autoregressive cache for the [transformer encoder](TransformerEncoder) module.
///
/// To be used during inference when decoding tokens.
pub struct TransformerEncoderAutoregressiveCache<B: Backend> {
    layers: Vec<TransformerEncoderLayerAutoregressiveCache<B>>,
}

impl<B: Backend> TransformerEncoderAutoregressiveCache<B> {
    fn empty(num_layers: usize) -> Self {
        Self {
            layers: (0..num_layers)
                .map(|_| TransformerEncoderLayerAutoregressiveCache::empty())
                .collect(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Distribution;
    use crate::{TestBackend, nn::attention::generate_autoregressive_mask};
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    #[test]
    fn test_autoregressive_norm_last() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers)
                .with_norm_first(false),
        )
    }

    #[test]
    fn test_autoregressive_norm_first() {
        let [d_model, d_ff, n_heads, num_layers] = [12, 24, 2, 3];
        test_autoregressive(
            TransformerEncoderConfig::new(d_model, d_ff, n_heads, num_layers).with_norm_first(true),
        )
    }
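
    // Minimal usage sketch: build an encoder from its config, attach an all-`false`
    // padding mask (constructed with `equal_elem` here only to exercise the input
    // builder), and check that the forward pass preserves the
    // `[batch_size, seq_length, d_model]` shape.
    #[test]
    fn test_forward_shape() {
        let [d_model, d_ff, n_heads, n_layers] = [12, 24, 2, 3];
        let [batch_size, seq_length] = [2, 5];
        let device = Default::default();
        let transformer = TransformerEncoderConfig::new(d_model, d_ff, n_heads, n_layers)
            .init::<TestBackend>(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        // All-`false` mask: no position is treated as padding.
        let mask_pad =
            Tensor::<TestBackend, 2>::zeros([batch_size, seq_length], &device).equal_elem(1.0);
        let input = TransformerEncoderInput::new(tensor).mask_pad(mask_pad);

        let output = transformer.forward(input);

        assert_eq!(output.dims(), [batch_size, seq_length, d_model]);
    }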

    fn test_autoregressive(config: TransformerEncoderConfig) {
        let [batch_size, seq_length, d_model] = [3, 4, config.d_model];
        let device = Default::default();
        let transformer = config.init(&device);

        let tensor = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_model],
            Distribution::Default,
            &device,
        );
        let mask_attn = generate_autoregressive_mask(batch_size, seq_length, &tensor.device());
        let input = TransformerEncoderInput::new(tensor.clone()).mask_attn(mask_attn);

        let output_1 = transformer.forward(input);
        let mut output_2 = Vec::new();
        let mut cache = transformer.new_autoregressive_cache();

        for i in 1..seq_length + 1 {
            let tensor = tensor.clone().slice([0..batch_size, 0..i, 0..d_model]);
            let input = TransformerEncoderInput::new(tensor.clone());
            let next_tok = transformer
                .forward_autoregressive_inference(input, &mut cache)
                .slice([0..batch_size, i - 1..i, 0..d_model]);
            output_2.push(next_tok);
        }

        let output_2 = Tensor::cat(output_2, 1);

        output_1
            .into_data()
            .assert_approx_eq::<FT>(&output_2.into_data(), Tolerance::rel_abs(1e-4, 1e-4));
    }

    #[test]
    fn display() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);
        let transformer = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{}", transformer),
            "TransformerEncoder {d_model: 2, d_ff: 4, n_heads: 2, \
            n_layers: 3, dropout: 0.1, norm_first: false, quiet_softmax: false, params: 162}"
        );
    }
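
    // Minimal configuration sketch: checks the defaults declared on
    // `TransformerEncoderConfig` (dropout 0.1, post-norm, regular softmax) and shows
    // how a single field is overridden with the generated `with_*` builder.
    #[test]
    fn config_defaults() {
        let config = TransformerEncoderConfig::new(2, 4, 2, 3);

        assert_eq!(config.dropout, 0.1);
        assert!(!config.norm_first);
        assert!(!config.quiet_softmax);

        let config = config.with_norm_first(true);
        assert!(config.norm_first);
    }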
}