use burn::module::{Ignored, Module};
use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig, RmsNorm, RmsNormConfig};
use burn::prelude::*;

use crate::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use crate::config::DdlConfig;
use crate::error::BaselineConfigError;
use crate::generation::{
    AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
};
use crate::mlp::{SwiGluMlp, SwiGluMlpConfig};
use crate::utils::create_causal_mask;

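/// A single pre-norm transformer block: RMSNorm followed by multi-head
/// attention with a residual connection, then RMSNorm followed by a SwiGLU
/// MLP with a second residual connection.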
#[derive(Module, Debug)]
pub struct BaselineTransformerBlock<B: Backend> {
    attn_norm: RmsNorm<B>,
    attn: MultiHeadAttention<B>,
    mlp_norm: RmsNorm<B>,
    mlp: SwiGluMlp<B>,
}

impl<B: Backend> BaselineTransformerBlock<B> {
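    /// Builds a block from the shared [`DdlConfig`]: attention geometry, RoPE
    /// theta, and the MLP hidden width are all derived from the config.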
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        let attn = MultiHeadAttentionConfig::new(
            config.d_model,
            config.num_heads,
            config.effective_head_dim(),
            config.max_seq_len,
        )
        .with_rope_theta(config.rope_theta)
        .init(device);
        let mlp = SwiGluMlpConfig::new(config.d_model, config.effective_mlp_hidden()).init(device);

        Self {
            attn_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            attn,
            mlp_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            mlp,
        }
    }

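    /// Runs the block on hidden states `x` of shape `[batch, seq_len, d_model]`,
    /// optionally restricting attention with `mask`.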
    pub fn forward_hidden(&self, x: Tensor<B, 3>, mask: Option<&Tensor<B, 3>>) -> Tensor<B, 3> {
        // Attention sub-block: normalize, attend, add the residual.
        let x_ctx = self.attn_norm.forward(x.clone());
        let x = x + self.attn.forward(x_ctx, mask);

        // MLP sub-block: normalize, apply SwiGLU, add the residual.
        let x_ctx = self.mlp_norm.forward(x.clone());
        x + self.mlp.forward(x_ctx)
    }
}

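/// A decoder-only baseline transformer: token embedding, a stack of
/// [`BaselineTransformerBlock`]s, a final RMSNorm, and an untied linear LM
/// head projecting back to vocabulary logits.
///
/// # Example
///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a
/// doctest); it assumes a `DdlConfig` value, a Burn backend `B`, and a
/// `device` are already in scope:
///
/// ```ignore
/// let model = BaselineTransformer::<B>::new(&config, &device);
/// // input_ids: [batch, seq_len] tensor of token ids.
/// let logits = model.forward(input_ids, None); // [batch, seq_len, vocab_size]
/// ```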
#[derive(Module, Debug)]
pub struct BaselineTransformer<B: Backend> {
    embedding: Embedding<B>,
    blocks: Vec<BaselineTransformerBlock<B>>,
    final_norm: RmsNorm<B>,
    lm_head: Linear<B>,
    config: Ignored<DdlConfig>,
}

impl<B: Backend> BaselineTransformer<B> {
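    /// Builds the model, panicking on an invalid configuration.
    ///
    /// # Panics
    ///
    /// Panics when [`Self::try_new`] returns a [`BaselineConfigError`]; call
    /// `try_new` directly to handle configuration errors without panicking.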
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        Self::try_new(config, device)
            .unwrap_or_else(|error| panic!("invalid baseline transformer configuration: {error}"))
    }

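    /// Fallible constructor: validates the config and rejects the unsupported
    /// `d_value != 1` case before allocating any parameters.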
    pub fn try_new(config: &DdlConfig, device: &B::Device) -> Result<Self, BaselineConfigError> {
        config.validate()?;
        if config.d_value != 1 {
            return Err(BaselineConfigError::UnsupportedValueDimension {
                d_value: config.d_value,
            });
        }

        let blocks = (0..config.num_layers)
            .map(|_| BaselineTransformerBlock::new(config, device))
            .collect();

        Ok(Self {
            embedding: EmbeddingConfig::new(config.vocab_size, config.d_model).init(device),
            blocks,
            final_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            lm_head: LinearConfig::new(config.d_model, config.vocab_size)
                .with_bias(false)
                .init(device),
            config: Ignored(config.clone()),
        })
    }

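    /// Returns the caller-provided mask, or builds a causal mask matching the
    /// input's batch size and sequence length when none is given.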
    fn resolve_mask(
        &self,
        input_ids: &Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        match mask {
            Some(mask) => mask.clone(),
            None => {
                let [batch_size, seq_len] = input_ids.dims();
                create_causal_mask::<B>(batch_size, seq_len, &input_ids.device())
            }
        }
    }

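    /// Embeds `input_ids`, runs every transformer block, and applies the final
    /// RMSNorm; returns hidden states of shape `[batch, seq_len, d_model]`.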
    pub fn forward_hidden(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut hidden = self.embedding.forward(input_ids);
        for block in &self.blocks {
            hidden = block.forward_hidden(hidden, Some(&mask));
        }
        self.final_norm.forward(hidden)
    }

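    /// Projects the final hidden states through the LM head; returns logits of
    /// shape `[batch, seq_len, vocab_size]`.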
    pub fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.lm_head.forward(self.forward_hidden(input_ids, mask))
    }

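    /// Convenience alias for [`Self::forward_logits`].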
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.forward_logits(input_ids, mask)
    }

    /// Maximum sequence length supported by the model, as recorded in the
    /// config.
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }

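    /// Autoregressively samples a continuation of `prompt_tokens` by
    /// delegating to [`generate_tokens`], with this model as the
    /// [`AutoregressiveModel`].
    ///
    /// ```ignore
    /// // Sketch: assumes a `GenerationConfig` value `gen_config` is in scope.
    /// let result = model.generate(&[1, 5, 9], &gen_config, &device)?;
    /// ```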
    pub fn generate(
        &self,
        prompt_tokens: &[usize],
        generation_config: &GenerationConfig,
        device: &B::Device,
    ) -> Result<GenerationResult, GenerationError> {
        generate_tokens(self, prompt_tokens, generation_config, device)
    }
}

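// Delegate the trait methods to the inherent implementations so that the
// generation utilities can drive this model through `AutoregressiveModel`.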
impl<B: Backend> AutoregressiveModel<B> for BaselineTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        BaselineTransformer::forward_logits(self, input_ids, mask)
    }

    fn max_seq_len(&self) -> usize {
        BaselineTransformer::max_seq_len(self)
    }
}