deep_delta_learning/baseline.rs

use burn::module::{Ignored, Module};
use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig, RmsNorm, RmsNormConfig};
use burn::prelude::*;

use crate::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use crate::config::DdlConfig;
use crate::error::BaselineConfigError;
use crate::generation::{
    AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
};
use crate::mlp::{SwiGluMlp, SwiGluMlpConfig};
use crate::utils::create_causal_mask;

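/// Pre-norm transformer block used by the dense baseline: RMSNorm followed by
/// multi-head self-attention with a residual connection, then RMSNorm followed by
/// a SwiGLU MLP with a second residual connection.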
#[derive(Module, Debug)]
pub struct BaselineTransformerBlock<B: Backend> {
    attn_norm: RmsNorm<B>,
    attn: MultiHeadAttention<B>,
    mlp_norm: RmsNorm<B>,
    mlp: SwiGluMlp<B>,
}

impl<B: Backend> BaselineTransformerBlock<B> {
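    /// Builds the attention and MLP sub-modules (and their RMSNorms) from the
    /// shared `DdlConfig`.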
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        let attn = MultiHeadAttentionConfig::new(
            config.d_model,
            config.num_heads,
            config.effective_head_dim(),
            config.max_seq_len,
        )
        .with_rope_theta(config.rope_theta)
        .init(device);
        let mlp = SwiGluMlpConfig::new(config.d_model, config.effective_mlp_hidden()).init(device);

        Self {
            attn_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            attn,
            mlp_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            mlp,
        }
    }

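    /// Runs the block on `[batch, seq, d_model]` hidden states; `mask` is forwarded
    /// to the attention sub-module.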
    pub fn forward_hidden(&self, x: Tensor<B, 3>, mask: Option<&Tensor<B, 3>>) -> Tensor<B, 3> {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let x = residual + self.attn.forward(x_ctx, mask);

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        residual + self.mlp.forward(x_ctx)
    }
}

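/// Dense baseline language model: token embedding, a stack of
/// `BaselineTransformerBlock`s, a final RMSNorm, and an untied LM head. The
/// `DdlConfig` is wrapped in `Ignored` so it travels with the module without being
/// treated as learnable state.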
#[derive(Module, Debug)]
pub struct BaselineTransformer<B: Backend> {
    embedding: Embedding<B>,
    blocks: Vec<BaselineTransformerBlock<B>>,
    final_norm: RmsNorm<B>,
    lm_head: Linear<B>,
    config: Ignored<DdlConfig>,
}

impl<B: Backend> BaselineTransformer<B> {
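    /// Like [`Self::try_new`], but panics if the configuration is invalid.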
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        Self::try_new(config, device)
            .unwrap_or_else(|error| panic!("invalid baseline transformer configuration: {error}"))
    }

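    /// Validates the configuration and builds the model. The baseline only supports
    /// `d_value == 1`; anything else is rejected with
    /// [`BaselineConfigError::UnsupportedValueDimension`].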
    pub fn try_new(config: &DdlConfig, device: &B::Device) -> Result<Self, BaselineConfigError> {
        config.validate()?;
        if config.d_value != 1 {
            return Err(BaselineConfigError::UnsupportedValueDimension {
                d_value: config.d_value,
            });
        }

        let blocks = (0..config.num_layers)
            .map(|_| BaselineTransformerBlock::new(config, device))
            .collect();

        Ok(Self {
            embedding: EmbeddingConfig::new(config.vocab_size, config.d_model).init(device),
            blocks,
            final_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            lm_head: LinearConfig::new(config.d_model, config.vocab_size)
                .with_bias(false)
                .init(device),
            config: Ignored(config.clone()),
        })
    }

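    /// Returns the caller-provided mask if present, otherwise builds a causal mask
    /// matching the input's batch size and sequence length.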
    fn resolve_mask(
        &self,
        input_ids: &Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        match mask {
            Some(mask) => mask.clone(),
            None => {
                let [batch_size, seq_len] = input_ids.dims();
                create_causal_mask::<B>(batch_size, seq_len, &input_ids.device())
            }
        }
    }

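    /// Embeds the input tokens, runs every block, and applies the final RMSNorm,
    /// yielding `[batch, seq, d_model]` hidden states.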
    pub fn forward_hidden(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut hidden = self.embedding.forward(input_ids);
        for block in &self.blocks {
            hidden = block.forward_hidden(hidden, Some(&mask));
        }
        self.final_norm.forward(hidden)
    }

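    /// Projects the final hidden states through the LM head, yielding
    /// `[batch, seq, vocab_size]` logits.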
    pub fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.lm_head.forward(self.forward_hidden(input_ids, mask))
    }

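    /// Convenience alias for [`Self::forward_logits`].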
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        self.forward_logits(input_ids, mask)
    }

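    /// Maximum sequence length the model was configured for.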
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }

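    /// Generates tokens autoregressively from `prompt_tokens` using the shared
    /// `generate_tokens` driver.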
    pub fn generate(
        &self,
        prompt_tokens: &[usize],
        generation_config: &GenerationConfig,
        device: &B::Device,
    ) -> Result<GenerationResult, GenerationError> {
        generate_tokens(self, prompt_tokens, generation_config, device)
    }
}

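/// Exposes the baseline model through the [`AutoregressiveModel`] interface used by
/// the generation utilities; both methods delegate to the inherent implementations.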
impl<B: Backend> AutoregressiveModel<B> for BaselineTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        BaselineTransformer::forward_logits(self, input_ids, mask)
    }

    fn max_seq_len(&self) -> usize {
        BaselineTransformer::max_seq_len(self)
    }
}
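
// Usage sketch (illustrative only): assumes a `DdlConfig` and a `GenerationConfig`
// are constructed elsewhere in the crate, and that a Burn backend such as
// `burn::backend::NdArray` is enabled.
//
//     let device = Default::default();
//     let model = BaselineTransformer::<NdArray>::new(&config, &device);
//     let logits = model.forward_logits(input_ids, None);
//     let generated = model.generate(&prompt_tokens, &generation_config, &device)?;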