use burn::module::{Ignored, Module};
use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig, RmsNorm, RmsNormConfig};
use burn::prelude::*;

use crate::attention::{MultiHeadAttention, MultiHeadAttentionConfig};
use crate::compressor::{ChannelConvCompressor, TokenConvCompressor};
use crate::config::{CompressionVariant, DdlConfig};
use crate::delta_res_block::DeltaResBlock;
use crate::embed_expander::EmbeddingExpander;
use crate::generation::{
    AutoregressiveModel, GenerationConfig, GenerationError, GenerationResult, generate_tokens,
};
use crate::mlp::{SwiGluMlp, SwiGluMlpConfig};
use crate::spectral::{LayerDiagnostics, ModelDiagnostics};
use crate::utils::create_causal_mask;

#[cfg(feature = "spectral")]
use crate::spectral::{LayerTrace, SpectralDiagnostics, summarize_layer_traces};

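/// One layer of the DDL transformer: a pre-norm multi-head attention sublayer
/// followed by a pre-norm SwiGLU MLP sublayer, each wrapped in a
/// delta-residual block. In matrix-state mode the block also owns exactly one
/// compressor that collapses the 4-D state back to a 3-D sequence ahead of
/// each sublayer.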
#[derive(Module, Debug)]
pub struct DdlTransformerBlock<B: Backend> {
    attn_norm: RmsNorm<B>,
    attn: MultiHeadAttention<B>,
    attn_delta: DeltaResBlock<B>,
    mlp_norm: RmsNorm<B>,
    mlp: SwiGluMlp<B>,
    mlp_delta: DeltaResBlock<B>,
    token_compressor: Option<TokenConvCompressor<B>>,
    channel_compressor: Option<ChannelConvCompressor<B>>,
}

impl<B: Backend> DdlTransformerBlock<B> {
    /// Builds one block's sublayers from the shared model config.
    pub fn new(config: &DdlConfig, device: &B::Device) -> Self {
        let attn = MultiHeadAttentionConfig::new(
            config.d_model,
            config.num_heads,
            config.effective_head_dim(),
            config.max_seq_len,
        )
        .with_rope_theta(config.rope_theta)
        .init(device);
        let mlp = SwiGluMlpConfig::new(config.d_model, config.effective_mlp_hidden()).init(device);
        let attn_norm = RmsNormConfig::new(config.d_model)
            .with_epsilon(1e-6)
            .init(device);
        let mlp_norm = RmsNormConfig::new(config.d_model)
            .with_epsilon(1e-6)
            .init(device);
        // In matrix-state mode, exactly one compressor is instantiated,
        // selected by the configured compression variant.
        let token_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::TokenConv))
        .then(|| {
            TokenConvCompressor::<B>::new(
                config.d_model,
                config.d_value,
                config.shortconv_kernel_size,
                device,
            )
        });
        let channel_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::ChannelConv))
        .then(|| ChannelConvCompressor::<B>::new(config.d_model, config.d_value, device));

        Self {
            attn_norm,
            attn,
            attn_delta: DeltaResBlock::new(config, device),
            mlp_norm,
            mlp,
            mlp_delta: DeltaResBlock::new(config, device),
            token_compressor,
            channel_compressor,
        }
    }

    /// Collapses the 4-D matrix state to a 3-D `[batch, seq, d_model]`
    /// sequence using whichever compressor this block was built with.
    fn compress(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
        match (&self.token_compressor, &self.channel_compressor) {
            (Some(compressor), None) => compressor.forward(state),
            (None, Some(compressor)) => compressor.forward(state),
            _ => panic!("matrix-state block requires exactly one compressor"),
        }
    }

    /// Vector-state forward pass that also returns per-layer diagnostics from
    /// the two delta-residual blocks.
    pub fn forward_vector(
        &self,
        x: Tensor<B, 3>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, LayerDiagnostics) {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (x, attention) = self
            .attn_delta
            .forward_vector(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (x, mlp) = self
            .mlp_delta
            .forward_vector(residual, mlp_out, x_ctx, x.clone());

        (x, LayerDiagnostics { attention, mlp })
    }

    /// Diagnostics-free variant of [`Self::forward_vector`] used on the hot
    /// path of `forward_logits`.
    fn forward_vector_hidden(&self, x: Tensor<B, 3>, mask: Option<&Tensor<B, 3>>) -> Tensor<B, 3> {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let x = self
            .attn_delta
            .forward_vector_hidden(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        self.mlp_delta
            .forward_vector_hidden(residual, mlp_out, x_ctx, x)
    }

    /// Traced variant of [`Self::forward_vector`] that additionally records
    /// per-sublayer spectral information.
    #[cfg(feature = "spectral")]
    fn forward_vector_traced(
        &self,
        x: Tensor<B, 3>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, LayerTrace) {
        let residual = x.clone();
        let x_ctx = self.attn_norm.forward(residual.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (x, attention) =
            self.attn_delta
                .forward_vector_traced(residual, attn_out, x_ctx, x.clone());

        let residual = x.clone();
        let x_ctx = self.mlp_norm.forward(residual.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (x, mlp) = self
            .mlp_delta
            .forward_vector_traced(residual, mlp_out, x_ctx, x.clone());

        (x, LayerTrace { attention, mlp })
    }

    /// Matrix-state forward pass: the 4-D state is compressed to a 3-D
    /// sequence before each sublayer, and the delta-residual blocks write the
    /// sublayer output back into the state.
    pub fn forward_matrix(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 4>, LayerDiagnostics) {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (state, attention) = self
            .attn_delta
            .forward_matrix(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (state, mlp) = self.mlp_delta.forward_matrix(state, mlp_out, x_ctx, x_lin);

        (state, LayerDiagnostics { attention, mlp })
    }

    /// Diagnostics-free variant of [`Self::forward_matrix`] used on the hot
    /// path of `forward_logits`.
    fn forward_matrix_hidden(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 4> {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let state = self
            .attn_delta
            .forward_matrix_hidden(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        self.mlp_delta
            .forward_matrix_hidden(state, mlp_out, x_ctx, x_lin)
    }

    /// Traced variant of [`Self::forward_matrix`].
    #[cfg(feature = "spectral")]
    fn forward_matrix_traced(
        &self,
        state: Tensor<B, 4>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 4>, LayerTrace) {
        let x_lin = self.compress(state.clone());
        let x_ctx = self.attn_norm.forward(x_lin.clone());
        let attn_out = self.attn.forward(x_ctx.clone(), mask);
        let (state, attention) = self
            .attn_delta
            .forward_matrix_traced(state, attn_out, x_ctx, x_lin);

        let x_lin = self.compress(state.clone());
        let x_ctx = self.mlp_norm.forward(x_lin.clone());
        let mlp_out = self.mlp.forward(x_ctx.clone());
        let (state, mlp) = self
            .mlp_delta
            .forward_matrix_traced(state, mlp_out, x_ctx, x_lin);

        (state, LayerTrace { attention, mlp })
    }
}

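/// Full DDL language model: token embedding, an embedding expander into
/// matrix state (when `d_value > 1`), a stack of [`DdlTransformerBlock`]s, a
/// final compressor and norm, and an untied LM head.
///
/// A minimal usage sketch. The `DdlConfig` construction is elided here (see
/// `crate::config` for the real builder), and `burn-ndarray` is only an
/// example backend choice:
///
/// ```ignore
/// use burn::backend::NdArray;
/// use burn::prelude::*;
///
/// let device = Default::default();
/// let config: DdlConfig = todo!("build via crate::config");
/// let model = DdlTransformer::<NdArray>::new(config, &device);
///
/// // [batch, seq] token ids in; [batch, seq, vocab_size] logits out.
/// let input_ids = Tensor::<NdArray, 2, Int>::zeros([1, 8], &device);
/// let logits = model.forward_logits(input_ids, None);
/// ```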
#[derive(Module, Debug)]
pub struct DdlTransformer<B: Backend> {
    embedding: Embedding<B>,
    embed_expander: EmbeddingExpander<B>,
    blocks: Vec<DdlTransformerBlock<B>>,
    final_token_compressor: Option<TokenConvCompressor<B>>,
    final_channel_compressor: Option<ChannelConvCompressor<B>>,
    final_norm: RmsNorm<B>,
    lm_head: Linear<B>,
    config: Ignored<DdlConfig>,
}

impl<B: Backend> DdlTransformer<B> {
    /// Initializes the embedding, expander, block stack, final compressor,
    /// final norm, and LM head on the given device.
    pub fn new(config: DdlConfig, device: &B::Device) -> Self {
        let blocks = (0..config.num_layers)
            .map(|_| DdlTransformerBlock::new(&config, device))
            .collect();
        // Mirrors the per-block compressor selection for the final
        // state-to-sequence projection ahead of the LM head.
        let final_token_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::TokenConv))
        .then(|| {
            TokenConvCompressor::<B>::new(
                config.d_model,
                config.d_value,
                config.shortconv_kernel_size,
                device,
            )
        });
        let final_channel_compressor = (config.uses_matrix_state()
            && matches!(config.compression, CompressionVariant::ChannelConv))
        .then(|| ChannelConvCompressor::<B>::new(config.d_model, config.d_value, device));
        Self {
            embedding: EmbeddingConfig::new(config.vocab_size, config.d_model).init(device),
            embed_expander: EmbeddingExpander::new(
                config.d_model,
                config.d_value,
                config.embed_conv,
                config.shortconv_kernel_size,
                device,
            ),
            blocks,
            final_token_compressor,
            final_channel_compressor,
            final_norm: RmsNormConfig::new(config.d_model)
                .with_epsilon(1e-6)
                .init(device),
            lm_head: LinearConfig::new(config.d_model, config.vocab_size)
                .with_bias(false)
                .init(device),
            config: Ignored(config),
        }
    }

    /// Collapses the final 4-D matrix state to `[batch, seq, d_model]` for
    /// the LM head.
    fn compress_final(&self, state: Tensor<B, 4>) -> Tensor<B, 3> {
        match (&self.final_token_compressor, &self.final_channel_compressor) {
            (Some(compressor), None) => compressor.forward(state),
            (None, Some(compressor)) => compressor.forward(state),
            _ => panic!("matrix-state model requires exactly one final compressor"),
        }
    }

    /// Returns the caller-supplied mask, or builds a causal mask matching the
    /// input shape when none is given.
    fn resolve_mask(
        &self,
        input_ids: &Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        match mask {
            Some(mask) => mask.clone(),
            None => {
                let [batch_size, seq_len] = input_ids.dims();
                create_causal_mask::<B>(batch_size, seq_len, &input_ids.device())
            }
        }
    }

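    /// Runs the model end to end and returns only the logits, skipping
    /// diagnostics collection.
    ///
    /// Shape sketch (batch and sequence sizes are illustrative):
    ///
    /// ```ignore
    /// // input_ids: [batch = 2, seq = 16] token ids
    /// // logits:    [2, 16, vocab_size]
    /// let logits = model.forward_logits(input_ids, None);
    /// ```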
    pub fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let mask = self.resolve_mask(&input_ids, mask);

        // Vector-state path: hidden states stay 3-D throughout.
        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                hidden = block.forward_vector_hidden(hidden, Some(&mask));
            }
            return self.lm_head.forward(self.final_norm.forward(hidden));
        }

        // Matrix-state path: expand embeddings to 4-D state, run the block
        // stack, then compress once for the LM head.
        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            state = block.forward_matrix_hidden(state, Some(&mask));
        }
        let hidden = self.compress_final(state);
        self.lm_head.forward(self.final_norm.forward(hidden))
    }

    /// Runs the model and collects per-layer delta-residual diagnostics
    /// alongside the logits.
    pub fn forward(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, ModelDiagnostics) {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut diagnostics = Vec::with_capacity(self.blocks.len());

        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                let (next_hidden, layer_diag) = block.forward_vector(hidden, Some(&mask));
                hidden = next_hidden;
                diagnostics.push(layer_diag);
            }
            let logits = self.lm_head.forward(self.final_norm.forward(hidden));
            return (
                logits,
                ModelDiagnostics {
                    layers: diagnostics,
                },
            );
        }

        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            let (next_state, layer_diag) = block.forward_matrix(state, Some(&mask));
            state = next_state;
            diagnostics.push(layer_diag);
        }
        let hidden = self.compress_final(state);
        let logits = self.lm_head.forward(self.final_norm.forward(hidden));
        (
            logits,
            ModelDiagnostics {
                layers: diagnostics,
            },
        )
    }

    /// Like [`Self::forward`], but additionally summarizes per-layer spectral
    /// traces into [`SpectralDiagnostics`].
    #[cfg(feature = "spectral")]
    pub fn forward_with_spectral_diagnostics(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, ModelDiagnostics, SpectralDiagnostics) {
        let mask = self.resolve_mask(&input_ids, mask);
        let mut layer_traces = Vec::with_capacity(self.blocks.len());

        if self.config.d_value == 1 {
            let mut hidden = self.embedding.forward(input_ids);
            for block in &self.blocks {
                let (next_hidden, layer_trace) = block.forward_vector_traced(hidden, Some(&mask));
                hidden = next_hidden;
                layer_traces.push(layer_trace);
            }
            let logits = self.lm_head.forward(self.final_norm.forward(hidden));
            let (diagnostics, spectral) = summarize_layer_traces(layer_traces);
            return (logits, diagnostics, spectral);
        }

        let mut state = self
            .embed_expander
            .forward(self.embedding.forward(input_ids));
        for block in &self.blocks {
            let (next_state, layer_trace) = block.forward_matrix_traced(state, Some(&mask));
            state = next_state;
            layer_traces.push(layer_trace);
        }
        let hidden = self.compress_final(state);
        let logits = self.lm_head.forward(self.final_norm.forward(hidden));
        let (diagnostics, spectral) = summarize_layer_traces(layer_traces);
        (logits, diagnostics, spectral)
    }

    /// Maximum sequence length the model was configured for.
    pub fn max_seq_len(&self) -> usize {
        self.config.max_seq_len
    }

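    /// Autoregressive decoding via [`generate_tokens`].
    ///
    /// A hedged usage sketch; `GenerationConfig` is assumed here to implement
    /// `Default` (see `crate::generation` for the actual options):
    ///
    /// ```ignore
    /// let prompt = [1usize, 5, 42];
    /// let gen_config = GenerationConfig::default();
    /// let result = model.generate(&prompt, &gen_config, &device)?;
    /// ```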
    pub fn generate(
        &self,
        prompt_tokens: &[usize],
        generation_config: &GenerationConfig,
        device: &B::Device,
    ) -> Result<GenerationResult, GenerationError> {
        generate_tokens(self, prompt_tokens, generation_config, device)
    }
}

// Delegates to the inherent methods so the shared generation loop in
// `crate::generation` can drive this model through the trait.
impl<B: Backend> AutoregressiveModel<B> for DdlTransformer<B> {
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        DdlTransformer::forward_logits(self, input_ids, mask)
    }

    fn max_seq_len(&self) -> usize {
        DdlTransformer::max_seq_len(self)
    }
}