// deep_delta_learning/transformer.rs
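//! The DDL transformer: `DdlTransformerBlock` layers (attention and SwiGLU
//! MLP sublayers, each followed by a `DeltaResBlock` update) stacked into
//! `DdlTransformer`, with an optional matrix-state path that is compressed
//! back to model width before the LM head.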

use burn::module::{Ignored, Module};
use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig, RmsNorm, RmsNormConfig};
use burn::prelude::*;

use crate::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use crate::compressor::{ChannelConvCompressor, TokenConvCompressor};
use crate::config::{CompressionVariant, DdlConfig};
use crate::delta_res_block::DeltaResBlock;
use crate::embed_expander::EmbeddingExpander;
use crate::generation::{
    AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
};
use crate::mlp::{SwiGluMlp, SwiGluMlpConfig};
use crate::spectral::{LayerDiagnostics, ModelDiagnostics};
use crate::utils::create_causal_mask;

#[cfg(feature = "spectral")]
use crate::spectral::{LayerTrace, SpectralDiagnostics, summarize_layer_traces};

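/// One transformer layer: an attention sublayer and a SwiGLU MLP sublayer,
/// each followed by a `DeltaResBlock` update. Matrix-state blocks (see
/// `DdlConfig::uses_matrix_state`) also own exactly one compressor that
/// projects the rank-4 state down to rank-3 activations ahead of each sublayer.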
#[derive(Module, Debug)]
pub struct DdlTransformerBlock<B: Backend> {
    attn_norm: RmsNorm<B>,
    attn: MultiHeadAttention<B>,
    attn_delta: DeltaResBlock<B>,
    mlp_norm: RmsNorm<B>,
    mlp: SwiGluMlp<B>,
    mlp_delta: DeltaResBlock<B>,
    token_compressor: Option<TokenConvCompressor<B>>,
    channel_compressor: Option<ChannelConvCompressor<B>>,
}

impl<B: Backend> DdlTransformerBlock<B> {
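    /// Builds the sublayers from `config`; a token or channel compressor is
    /// created only for matrix-state configurations, matching
    /// `config.compression`.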
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        let attn = MultiHeadAttentionConfig::new(
            config.d_model,
            config.num_heads,
            config.effective_head_dim(),
            config.max_seq_len,
        )
        .with_rope_theta(config.rope_theta)
        .init(device);
        let mlp = SwiGluMlpConfig::new(config.d_model, config.effective_mlp_hidden()).init(device);
        let attn_norm = RmsNormConfig::new(config.d_model)
            .with_epsilon(1e-6)
            .init(device);
        let mlp_norm = RmsNormConfig::new(config.d_model)
            .with_epsilon(1e-6)
            .init(device);
        let token_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::TokenConv))
        .then(|| {
            TokenConvCompressor::<B>::new(
                config.d_model,
                config.d_value,
                config.shortconv_kernel_size,
                device,
            )
        });
        let channel_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::ChannelConv))
        .then(|| ChannelConvCompressor::<B>::new(config.d_model, config.d_value, device));

        Self {
            attn_norm,
            attn,
            attn_delta: DeltaResBlock::new(config, device),
            mlp_norm,
            mlp,
            mlp_delta: DeltaResBlock::new(config, device),
            token_compressor,
            channel_compressor,
        }
    }

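    /// Projects the rank-4 matrix state down to rank-3 activations using
    /// whichever compressor this block owns.
    ///
    /// Panics unless the block holds exactly one compressor.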
    fn compress(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
        match (&self.token_compressor, &self.channel_compressor) {
            (Some(compressor), None) => compressor.forward(state),
            (None, Some(compressor)) => compressor.forward(state),
            _ => panic!("matrix-state block requires exactly one compressor"),
        }
    }

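    /// Vector-state forward pass (rank-3 activations), returning the updated
    /// activations together with per-sublayer diagnostics.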
    pub fn forward_vector(
        &self,
        x: Tensor<B, 3>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, LayerDiagnostics) {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (x, attention) = self
            .attn_delta
            .forward_vector(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (x, mlp) = self
            .mlp_delta
            .forward_vector(residual, mlp_out, x_ctx, x.clone());

        (x, LayerDiagnostics { attention, mlp })
    }

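    /// Vector-state forward pass without diagnostics collection.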
    fn forward_vector_hidden(&self, x: Tensor<B, 3>, mask: Option<&Tensor<B, 3>>) -> Tensor<B, 3> {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let x = self
            .attn_delta
            .forward_vector_hidden(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        self.mlp_delta
            .forward_vector_hidden(residual, mlp_out, x_ctx, x)
    }

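    /// Vector-state forward pass that records a full `LayerTrace` for
    /// spectral diagnostics.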
    #[cfg(feature = "spectral")]
    fn forward_vector_traced(
        &self,
        x: Tensor<B, 3>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, LayerTrace) {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (x, attention) =
            self.attn_delta
                .forward_vector_traced(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (x, mlp) = self
            .mlp_delta
            .forward_vector_traced(residual, mlp_out, x_ctx, x.clone());

        (x, LayerTrace { attention, mlp })
    }

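    /// Matrix-state forward pass: the state is compressed before each
    /// sublayer, and the delta blocks update the rank-4 state directly.
    /// Returns the updated state plus per-sublayer diagnostics.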
    pub fn forward_matrix(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 4>, LayerDiagnostics) {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (state, attention) = self
            .attn_delta
            .forward_matrix(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (state, mlp) = self.mlp_delta.forward_matrix(state, mlp_out, x_ctx, x_lin);

        (state, LayerDiagnostics { attention, mlp })
    }

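    /// Matrix-state forward pass without diagnostics collection.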
    fn forward_matrix_hidden(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 4> {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let state = self
            .attn_delta
            .forward_matrix_hidden(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        self.mlp_delta
            .forward_matrix_hidden(state, mlp_out, x_ctx, x_lin)
    }

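    /// Matrix-state forward pass that records a full `LayerTrace` for
    /// spectral diagnostics.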
    #[cfg(feature = "spectral")]
    fn forward_matrix_traced(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 4>, LayerTrace) {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (state, attention) = self
            .attn_delta
            .forward_matrix_traced(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (state, mlp) = self
            .mlp_delta
            .forward_matrix_traced(state, mlp_out, x_ctx, x_lin);

        (state, LayerTrace { attention, mlp })
    }
}

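/// The full model: token embedding, an `EmbeddingExpander` that lifts
/// embeddings into the matrix state (the `d_value != 1` path), a stack of
/// `DdlTransformerBlock`s, a final compressor for matrix-state models, and an
/// RMS-normed linear LM head.
///
/// Illustrative sketch only; `config`, `device`, and `input_ids` are assumed
/// to be in scope (see `crate::config::DdlConfig`):
/// ```ignore
/// let model = DdlTransformer::<B>::new(config, &device);
/// // Passing `None` falls back to a causal mask built from the input shape.
/// let logits = model.forward_logits(input_ids, None);
/// ```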
#[derive(Module, Debug)]
pub struct DdlTransformer<B: Backend> {
    embedding: Embedding<B>,
    embed_expander: EmbeddingExpander<B>,
    blocks: Vec<DdlTransformerBlock<B>>,
    final_token_compressor: Option<TokenConvCompressor<B>>,
    final_channel_compressor: Option<ChannelConvCompressor<B>>,
    final_norm: RmsNorm<B>,
    lm_head: Linear<B>,
    config: Ignored<DdlConfig>,
}

impl<B: Backend> DdlTransformer<B> {
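    /// Builds `config.num_layers` identical blocks plus the embedding,
    /// expander, optional final compressor, final norm, and LM head.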
    pub fn new(config: DdlConfig, device: &B::Device) -> Self {
        let blocks = (0..config.num_layers)
            .map(|_| DdlTransformerBlock::new(&config, device))
            .collect();
        let final_token_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::TokenConv))
        .then(|| {
            TokenConvCompressor::<B>::new(
                config.d_model,
                config.d_value,
                config.shortconv_kernel_size,
                device,
            )
        });
        let final_channel_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::ChannelConv))
        .then(|| ChannelConvCompressor::<B>::new(config.d_model, config.d_value, device));
        Self {
            embedding: EmbeddingConfig::new(config.vocab_size, config.d_model).init(device),
            embed_expander: EmbeddingExpander::new(
                config.d_model,
                config.d_value,
                config.embed_conv,
                config.shortconv_kernel_size,
                device,
            ),
            blocks,
            final_token_compressor,
            final_channel_compressor,
            final_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            lm_head: LinearConfig::new(config.d_model, config.vocab_size)
                .with_bias(false)
                .init(device),
            config: Ignored(config),
        }
    }

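    /// Compresses the final matrix state to rank-3 hidden activations before
    /// the LM head.
    ///
    /// Panics unless the model holds exactly one final compressor.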
    fn compress_final(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
        match (&self.final_token_compressor, &self.final_channel_compressor) {
            (Some(compressor), None) => compressor.forward(state),
            (None, Some(compressor)) => compressor.forward(state),
            _ => panic!("matrix-state model requires exactly one final compressor"),
        }
    }

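    /// Uses the caller's mask when provided, otherwise builds a causal mask
    /// matching the input shape.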
    fn resolve_mask(
        &self,
        input_ids: &Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        match mask {
            Some(mask) => mask.clone(),
            None => {
                let [batch_size, seq_len] = input_ids.dims();
                create_causal_mask::<B>(batch_size, seq_len, &input_ids.device())
            }
        }
    }

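    /// Computes logits without collecting diagnostics. Models with
    /// `d_value == 1` run on rank-3 activations; otherwise the embeddings are
    /// expanded into the rank-4 state and compressed again before the head.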
    pub fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let mask = self.resolve_mask(&input_ids, mask);

        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                hidden = block.forward_vector_hidden(hidden, Some(&mask));
            }
            return self.lm_head.forward(self.final_norm.forward(hidden));
        }

        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            state = block.forward_matrix_hidden(state, Some(&mask));
        }
        let hidden = self.compress_final(state);
        self.lm_head.forward(self.final_norm.forward(hidden))
    }

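    /// Computes logits along with per-layer `LayerDiagnostics`, following the
    /// same vector/matrix split as `forward_logits`.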
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, ModelDiagnostics) {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut diagnostics = Vec::with_capacity(self.blocks.len());

        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                let (next_hidden, layer_diag) = block.forward_vector(hidden, Some(&mask));
                hidden = next_hidden;
                diagnostics.push(layer_diag);
            }
            let logits = self.lm_head.forward(self.final_norm.forward(hidden));
            return (
                logits,
                ModelDiagnostics {
                    layers: diagnostics,
                },
            );
        }

        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            let (next_state, layer_diag) = block.forward_matrix(state, Some(&mask));
            state = next_state;
            diagnostics.push(layer_diag);
        }
        let hidden = self.compress_final(state);
        let logits = self.lm_head.forward(self.final_norm.forward(hidden));
        (
            logits,
            ModelDiagnostics {
                layers: diagnostics,
            },
        )
    }

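    /// Like `forward`, but records per-layer traces and summarizes them into
    /// `SpectralDiagnostics`.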
    #[cfg(feature = "spectral")]
    pub fn forward_with_spectral_diagnostics(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, ModelDiagnostics, SpectralDiagnostics) {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut layer_traces = Vec::with_capacity(self.blocks.len());

        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                let (next_hidden, layer_trace) = block.forward_vector_traced(hidden, Some(&mask));
                hidden = next_hidden;
                layer_traces.push(layer_trace);
            }
            let logits = self.lm_head.forward(self.final_norm.forward(hidden));
            let (diagnostics, spectral) = summarize_layer_traces(layer_traces);
            return (logits, diagnostics, spectral);
        }

        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            let (next_state, layer_trace) = block.forward_matrix_traced(state, Some(&mask));
            state = next_state;
            layer_traces.push(layer_trace);
        }
        let hidden = self.compress_final(state);
        let logits = self.lm_head.forward(self.final_norm.forward(hidden));
        let (diagnostics, spectral) = summarize_layer_traces(layer_traces);
        (logits, diagnostics, spectral)
    }

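    /// Maximum sequence length the model was configured with.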
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }

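    /// Autoregressively samples continuations of `prompt_tokens` via
    /// `generate_tokens`; decoding options live in `GenerationConfig`.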
    pub fn generate(
        &self,
        prompt_tokens: &[usize],
        generation_config: &GenerationConfig,
        device: &B::Device,
    ) -> Result<GenerationResult, GenerationError> {
        generate_tokens(self, prompt_tokens, generation_config, device)
    }
}

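// Delegating impl so the generic generation loop in `crate::generation` can
// drive this model.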
impl<B: Backend> AutoregressiveModel<B> for DdlTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        DdlTransformer::forward_logits(self, input_ids, mask)
    }

    fn max_seq_len(&self) -> usize {
        DdlTransformer::max_seq_len(self)
    }
}