axonml_llm/transformer.rs

//! Transformer Building Blocks
//!
//! Transformer encoder and decoder layers, blocks, and stacks.

use axonml_autograd::Variable;
use axonml_nn::{Module, Linear, Dropout, Parameter};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{zeros, ones};

use crate::attention::{MultiHeadSelfAttention, CausalSelfAttention};

/// Layer normalization.
#[derive(Debug)]
pub struct LayerNorm {
    /// Scale parameter
    pub weight: Parameter,
    /// Bias parameter
    pub bias: Parameter,
    /// Epsilon for numerical stability
    pub eps: f32,
    /// Normalized dimension
    pub dim: usize,
}

impl LayerNorm {
    /// Creates a new layer normalization.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            weight: Parameter::new(ones::<f32>(&[dim]), true),
            bias: Parameter::new(zeros::<f32>(&[dim]), true),
            eps,
            dim,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        // Normalize over last dimension
        let mean = input.mean_dim(-1, true);
        let variance = input.var_dim(-1, true);

        let x_normalized = input.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());

        // Scale and shift
        let weight_var = Variable::from_tensor_with_grad(self.weight.data().clone(), self.weight.requires_grad());
        let bias_var = Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());

        x_normalized.mul(&weight_var).add(&bias_var)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }
}
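
// LayerNorm computes y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias,
// with the mean and variance taken over the last dimension.
//
// Usage sketch (illustrative shapes; mirrors the tests at the bottom of this file):
//
//     let ln = LayerNorm::new(64, 1e-5);
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let y = ln.forward(&x); // same shape as `x`, normalized over the last dim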

/// Feed-forward network (MLP) used in transformers.
#[derive(Debug)]
pub struct FeedForward {
    /// First linear layer
    pub fc1: Linear,
    /// Second linear layer
    pub fc2: Linear,
    /// Dropout
    pub dropout: Dropout,
    /// Activation function type
    pub activation: String,
}

impl FeedForward {
    /// Creates a new feed-forward network.
    pub fn new(hidden_size: usize, intermediate_size: usize, dropout: f32, activation: &str) -> Self {
        Self {
            fc1: Linear::new(hidden_size, intermediate_size),
            fc2: Linear::new(intermediate_size, hidden_size),
            dropout: Dropout::new(dropout),
            activation: activation.to_string(),
        }
    }

    /// Applies the activation function.
    fn activate(&self, x: &Variable) -> Variable {
        match self.activation.as_str() {
            "gelu" => x.gelu(),
            "relu" => x.relu(),
            "silu" | "swish" => x.silu(),
            "tanh" => x.tanh(),
            _ => x.gelu(), // Default to GELU
        }
    }
}

impl Module for FeedForward {
    fn forward(&self, input: &Variable) -> Variable {
        let x = self.fc1.forward(input);
        let x = self.activate(&x);
        let x = self.dropout.forward(&x);
        self.fc2.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}
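
// FeedForward computes fc2(dropout(activation(fc1(x)))); unrecognized activation
// names fall back to GELU.
//
// Usage sketch (illustrative sizes; "silu" is one of the names handled by `activate`):
//
//     let ffn = FeedForward::new(64, 4 * 64, 0.1, "silu");
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let y = ffn.forward(&x); // shape [2, 8, 64]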

/// Transformer encoder block (BERT-style).
#[derive(Debug)]
pub struct TransformerEncoderBlock {
    /// Self-attention layer
    pub attention: MultiHeadSelfAttention,
    /// First layer norm (pre-attention or post-attention)
    pub ln1: LayerNorm,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// Second layer norm (pre-FFN or post-FFN)
    pub ln2: LayerNorm,
    /// Residual dropout
    pub dropout: Dropout,
    /// Whether to use pre-norm (like GPT) or post-norm (like original BERT)
    pub pre_norm: bool,
}

impl TransformerEncoderBlock {
    /// Creates a new transformer encoder block.
    pub fn new(
        hidden_size: usize,
        num_heads: usize,
        intermediate_size: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
        pre_norm: bool,
    ) -> Self {
        Self {
            attention: MultiHeadSelfAttention::new(hidden_size, num_heads, dropout),
            ln1: LayerNorm::new(hidden_size, layer_norm_eps),
            ffn: FeedForward::new(hidden_size, intermediate_size, dropout, activation),
            ln2: LayerNorm::new(hidden_size, layer_norm_eps),
            dropout: Dropout::new(dropout),
            pre_norm,
        }
    }

    /// Forward pass with optional attention mask.
    pub fn forward_with_mask(
        &self,
        hidden_states: &Variable,
        attention_mask: Option<&Tensor<f32>>,
    ) -> Variable {
        if self.pre_norm {
            // Pre-norm: LN -> Attention -> Residual, LN -> FFN -> Residual
            let residual = hidden_states.clone();
            let x = self.ln1.forward(hidden_states);
            let x = self.attention.forward_with_mask(&x, attention_mask);
            let x = self.dropout.forward(&x);
            let x = x.add(&residual);

            let residual = x.clone();
            let x = self.ln2.forward(&x);
            let x = self.ffn.forward(&x);
            let x = self.dropout.forward(&x);
            x.add(&residual)
        } else {
            // Post-norm: Attention -> Residual -> LN, FFN -> Residual -> LN
            let residual = hidden_states.clone();
            let x = self.attention.forward_with_mask(hidden_states, attention_mask);
            let x = self.dropout.forward(&x);
            let x = self.ln1.forward(&x.add(&residual));

            let residual = x.clone();
            let x = self.ffn.forward(&x);
            let x = self.dropout.forward(&x);
            self.ln2.forward(&x.add(&residual))
        }
    }
}

impl Module for TransformerEncoderBlock {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.attention.parameters());
        params.extend(self.ln1.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.ln2.parameters());
        params
    }

    fn train(&mut self) {
        self.attention.train();
        self.ffn.train();
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.attention.eval();
        self.ffn.eval();
        self.dropout.eval();
    }
}
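
// Usage sketch (illustrative; the mask shape below is an assumption: the mask is
// passed through unchanged to `MultiHeadSelfAttention::forward_with_mask`, which
// defines its actual shape and convention):
//
//     let mut block = TransformerEncoderBlock::new(64, 4, 256, 0.1, 1e-5, "gelu", true);
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let mask = ones::<f32>(&[2, 8]);
//     let y = block.forward_with_mask(&x, Some(&mask)); // shape [2, 8, 64]
//     block.eval(); // disable dropout for inference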

/// Transformer decoder block (GPT-style with causal attention).
#[derive(Debug)]
pub struct TransformerDecoderBlock {
    /// Causal self-attention
    pub attention: CausalSelfAttention,
    /// First layer norm
    pub ln1: LayerNorm,
    /// Feed-forward network
    pub ffn: FeedForward,
    /// Second layer norm
    pub ln2: LayerNorm,
}

impl TransformerDecoderBlock {
    /// Creates a new transformer decoder block.
    pub fn new(
        n_embd: usize,
        n_head: usize,
        max_seq_len: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
    ) -> Self {
        Self {
            attention: CausalSelfAttention::new(n_embd, n_head, max_seq_len, dropout),
            ln1: LayerNorm::new(n_embd, layer_norm_eps),
            ffn: FeedForward::new(n_embd, 4 * n_embd, dropout, activation),
            ln2: LayerNorm::new(n_embd, layer_norm_eps),
        }
    }
}

impl Module for TransformerDecoderBlock {
    fn forward(&self, input: &Variable) -> Variable {
        // GPT-2 style: Pre-norm with residual connections
        let x = input.clone();

        // Attention block
        let residual = x.clone();
        let x = self.ln1.forward(&x);
        let x = self.attention.forward(&x);
        let x = x.add(&residual);

        // FFN block
        let residual = x.clone();
        let x = self.ln2.forward(&x);
        let x = self.ffn.forward(&x);
        x.add(&residual)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.attention.parameters());
        params.extend(self.ln1.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.ln2.parameters());
        params
    }

    fn train(&mut self) {
        self.attention.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.attention.eval();
        self.ffn.eval();
    }
}
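
// Usage sketch (illustrative; `max_seq_len` is presumably the longest sequence the
// internal causal mask supports, and the FFN width is fixed at 4 * n_embd):
//
//     let block = TransformerDecoderBlock::new(64, 4, 128, 0.0, 1e-5, "gelu");
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let y = block.forward(&x); // shape [2, 8, 64]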

/// Stack of transformer encoder blocks.
#[derive(Debug)]
pub struct TransformerEncoder {
    /// Encoder layers
    pub layers: Vec<TransformerEncoderBlock>,
}

impl TransformerEncoder {
    /// Creates a new transformer encoder stack.
    pub fn new(
        num_layers: usize,
        hidden_size: usize,
        num_heads: usize,
        intermediate_size: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderBlock::new(
                    hidden_size,
                    num_heads,
                    intermediate_size,
                    dropout,
                    layer_norm_eps,
                    activation,
                    pre_norm,
                )
            })
            .collect();

        Self { layers }
    }

    /// Forward pass with optional attention mask.
    pub fn forward_with_mask(
        &self,
        hidden_states: &Variable,
        attention_mask: Option<&Tensor<f32>>,
    ) -> Variable {
        let mut output = hidden_states.clone();
        for layer in &self.layers {
            output = layer.forward_with_mask(&output, attention_mask);
        }
        output
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.layers.iter().flat_map(|l| l.parameters()).collect()
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }
}
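
// Usage sketch (illustrative; the same optional mask is forwarded to every layer):
//
//     let encoder = TransformerEncoder::new(6, 64, 4, 256, 0.1, 1e-5, "gelu", true);
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let mask = ones::<f32>(&[2, 8]); // assumed mask shape, see the block example above
//     let y = encoder.forward_with_mask(&x, Some(&mask)); // shape [2, 8, 64]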

/// Stack of transformer decoder blocks.
#[derive(Debug)]
pub struct TransformerDecoder {
    /// Decoder layers
    pub layers: Vec<TransformerDecoderBlock>,
    /// Final layer norm
    pub ln_f: LayerNorm,
}

impl TransformerDecoder {
    /// Creates a new transformer decoder stack.
    pub fn new(
        num_layers: usize,
        n_embd: usize,
        n_head: usize,
        max_seq_len: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderBlock::new(
                    n_embd,
                    n_head,
                    max_seq_len,
                    dropout,
                    layer_norm_eps,
                    activation,
                )
            })
            .collect();

        Self {
            layers,
            ln_f: LayerNorm::new(n_embd, layer_norm_eps),
        }
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output);
        }
        self.ln_f.forward(&output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        params.extend(self.ln_f.parameters());
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }
}
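
// Usage sketch (illustrative; note that the final layer norm `ln_f` is applied
// after the last block, following the GPT-2 pre-norm convention):
//
//     let decoder = TransformerDecoder::new(6, 64, 4, 128, 0.0, 1e-5, "gelu");
//     let x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     let hidden = decoder.forward(&x); // shape [2, 8, 64]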

/// Generic transformer block type selection.
#[derive(Debug)]
pub enum TransformerBlock {
    /// Encoder block (bidirectional attention)
    Encoder(TransformerEncoderBlock),
    /// Decoder block (causal attention)
    Decoder(TransformerDecoderBlock),
}

impl Module for TransformerBlock {
    fn forward(&self, input: &Variable) -> Variable {
        match self {
            TransformerBlock::Encoder(block) => block.forward(input),
            TransformerBlock::Decoder(block) => block.forward(input),
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        match self {
            TransformerBlock::Encoder(block) => block.parameters(),
            TransformerBlock::Decoder(block) => block.parameters(),
        }
    }

    fn train(&mut self) {
        match self {
            TransformerBlock::Encoder(block) => block.train(),
            TransformerBlock::Decoder(block) => block.train(),
        }
    }

    fn eval(&mut self) {
        match self {
            TransformerBlock::Encoder(block) => block.eval(),
            TransformerBlock::Decoder(block) => block.eval(),
        }
    }
}
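
// Usage sketch (illustrative): the enum lets heterogeneous blocks live in one
// collection while dispatching through the `Module` trait.
//
//     let blocks = vec![
//         TransformerBlock::Encoder(TransformerEncoderBlock::new(64, 4, 256, 0.0, 1e-5, "gelu", true)),
//         TransformerBlock::Decoder(TransformerDecoderBlock::new(64, 4, 128, 0.0, 1e-5, "gelu")),
//     ];
//     let mut x = Variable::new(Tensor::randn(&[2, 8, 64]), false);
//     for block in &blocks {
//         x = block.forward(&x);
//     }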

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(64, 1e-5);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = ln.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(64, 256, 0.0, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = ffn.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_encoder_block() {
        let block = TransformerEncoderBlock::new(64, 4, 256, 0.0, 1e-5, "gelu", false);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_decoder_block() {
        let block = TransformerDecoderBlock::new(64, 4, 128, 0.0, 1e-5, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(2, 64, 4, 256, 0.0, 1e-5, "gelu", false);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = encoder.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_transformer_decoder() {
        let decoder = TransformerDecoder::new(2, 64, 4, 128, 0.0, 1e-5, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = decoder.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }
}