Skip to main content

entrenar/transformer/
block.rs

1//! Transformer block module
2//!
3//! This module provides complete transformer blocks combining attention, FFN, and normalization.
4
5use crate::autograd::add;
6use crate::Tensor;
7use std::collections::HashMap;
8
9use super::attention::MultiHeadAttention;
10use super::config::TransformerConfig;
11use super::feedforward::FeedForward;
12use super::norm::RMSNorm;
13
14/// Complete transformer block
15pub struct TransformerBlock {
16    /// Configuration
17    config: TransformerConfig,
18    /// Layer index
19    layer_idx: usize,
20    /// Input layer normalization
21    pub input_norm: RMSNorm,
22    /// Self-attention
23    pub self_attn: MultiHeadAttention,
24    /// Post-attention layer normalization
25    pub post_attn_norm: RMSNorm,
26    /// Feed-forward network
27    pub ffn: FeedForward,
28}
29
30impl TransformerBlock {
31    /// Create new transformer block with initialized weights
32    pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
33        Self {
34            config: config.clone(),
35            layer_idx,
36            input_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
37            self_attn: MultiHeadAttention::new(config),
38            post_attn_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
39            ffn: FeedForward::new(config),
40        }
41    }
42
43    /// Create transformer block from parameter map
44    ///
45    /// Expected parameter names (following HuggingFace LLaMA convention):
46    /// - `model.layers.{layer_idx}.input_layernorm.weight`
47    /// - `model.layers.{layer_idx}.self_attn.*`
48    /// - `model.layers.{layer_idx}.post_attention_layernorm.weight`
49    /// - `model.layers.{layer_idx}.mlp.*`
50    pub fn from_params(
51        config: &TransformerConfig,
52        params: &HashMap<String, Tensor>,
53        layer_idx: usize,
54    ) -> Option<Self> {
55        let prefix = format!("model.layers.{layer_idx}");
56
57        let input_norm = RMSNorm::from_params(
58            params,
59            &format!("{prefix}.input_layernorm"),
60            config.rms_norm_eps,
61            config.hidden_size,
62        )?;
63
64        let self_attn =
65            MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
66
67        let post_attn_norm = RMSNorm::from_params(
68            params,
69            &format!("{prefix}.post_attention_layernorm"),
70            config.rms_norm_eps,
71            config.hidden_size,
72        )?;
73
74        let ffn = FeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
75
76        Some(Self { config: config.clone(), layer_idx, input_norm, self_attn, post_attn_norm, ffn })
77    }
78
79    /// Forward pass
80    ///
81    /// # Arguments
82    /// * `x` - Input tensor (seq_len * hidden_size, flattened)
83    /// * `seq_len` - Sequence length
84    ///
85    /// # Returns
86    /// Output tensor (seq_len * hidden_size, flattened)
87    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
88        let hidden_size = self.config.hidden_size;
89
90        // Pre-norm + attention + residual
91        let norm1 = self.input_norm.forward_batched(x, seq_len, hidden_size);
92        let attn_out = self.self_attn.forward(&norm1, seq_len);
93        let residual1 = add(x, &attn_out);
94
95        // Pre-norm + FFN + residual
96        let norm2 = self.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
97        let ffn_out = self.ffn.forward(&norm2, seq_len);
98        add(&residual1, &ffn_out)
99    }
100
101    /// Get layer index
102    pub fn layer_idx(&self) -> usize {
103        self.layer_idx
104    }
105
106    /// Get all parameters as a vector
107    pub fn parameters(&self) -> Vec<&Tensor> {
108        let mut params = vec![&self.input_norm.weight, &self.post_attn_norm.weight];
109        params.extend(self.self_attn.parameters());
110        params.extend(self.ffn.parameters());
111        params
112    }
113
114    /// Get all parameters as mutable references for optimizer
115    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
116        let mut params: Vec<&mut Tensor> = Vec::new();
117        params.push(&mut self.input_norm.weight);
118        params.push(&mut self.post_attn_norm.weight);
119        params.extend(self.self_attn.parameters_mut());
120        params.extend(self.ffn.parameters_mut());
121        params
122    }
123
124    /// Get named parameters for checkpoint serialization.
125    /// Returns (name, tensor) pairs matching HuggingFace weight conventions.
126    pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
127        let prefix = format!("model.layers.{}", self.layer_idx);
128        let mut params = vec![
129            (format!("{prefix}.input_layernorm.weight"), &self.input_norm.weight),
130            (format!("{prefix}.post_attention_layernorm.weight"), &self.post_attn_norm.weight),
131        ];
132        params.extend(self.self_attn.named_parameters(&format!("{prefix}.self_attn")));
133        params.push((format!("{prefix}.mlp.gate_proj.weight"), &self.ffn.w_gate));
134        params.push((format!("{prefix}.mlp.up_proj.weight"), &self.ffn.w_up));
135        params.push((format!("{prefix}.mlp.down_proj.weight"), &self.ffn.w_down));
136        params
137    }
138
139    /// ENT-282: Set a named parameter by suffix (after "model.layers.{idx}.").
140    pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
141        match suffix {
142            "input_layernorm.weight" => {
143                self.input_norm.weight = value;
144                true
145            }
146            "post_attention_layernorm.weight" => {
147                self.post_attn_norm.weight = value;
148                true
149            }
150            "mlp.gate_proj.weight" => {
151                self.ffn.w_gate = value;
152                true
153            }
154            "mlp.up_proj.weight" => {
155                self.ffn.w_up = value;
156                true
157            }
158            "mlp.down_proj.weight" => {
159                self.ffn.w_down = value;
160                true
161            }
162            _ => self.self_attn.set_named_parameter(suffix, value),
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_transformer_block_tiny() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let output = block.forward(&x, 2);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_transformer_block_layer_idx() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 5);
        assert_eq!(block.layer_idx(), 5);
    }

    #[test]
    fn test_transformer_block_parameters() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let params = block.parameters();
        // input_norm + post_attn_norm + 4 attn weights + 3 ffn = 9 (no biases by default)
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_transformer_block_from_params_success() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;

        let mut params = HashMap::new();

        // Input norm
        params.insert(
            "model.layers.0.input_layernorm.weight".to_string(),
            Tensor::from_vec(vec![1.0; hidden_size], true),
        );

        // Self-attention weights
        params.insert(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; q_dim * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.k_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; kv_hidden_size * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.v_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; kv_hidden_size * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.o_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * q_dim], true),
        );

        // Post-attention norm
        params.insert(
            "model.layers.0.post_attention_layernorm.weight".to_string(),
            Tensor::from_vec(vec![1.0; hidden_size], true),
        );

        // MLP weights
        params.insert(
            "model.layers.0.mlp.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "model.layers.0.mlp.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "model.layers.0.mlp.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );

        let block = TransformerBlock::from_params(&config, &params, 0)
            .expect("all required layer-0 parameters were inserted above");
        assert_eq!(block.layer_idx(), 0);
    }

    #[test]
    fn test_transformer_block_from_params_missing_norm() {
        let config = TransformerConfig::tiny();
        let params = HashMap::new();
        // Empty params - should fail
        let block = TransformerBlock::from_params(&config, &params, 0);
        assert!(block.is_none());
    }

    #[test]
    fn test_transformer_block_named_parameters_prefix_and_uniqueness() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 3);
        let named = block.named_parameters();

        // Every checkpoint key must be scoped to this block's layer index.
        assert!(named.iter().all(|(name, _)| name.starts_with("model.layers.3.")));

        // Keys must not collide, or one tensor would overwrite another when saved.
        let names: HashSet<&str> = named.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names.len(), named.len());

        for expected in [
            "model.layers.3.input_layernorm.weight",
            "model.layers.3.post_attention_layernorm.weight",
            "model.layers.3.mlp.gate_proj.weight",
            "model.layers.3.mlp.up_proj.weight",
            "model.layers.3.mlp.down_proj.weight",
        ] {
            assert!(names.contains(expected), "missing parameter name {expected}");
        }
    }

    #[test]
    fn test_transformer_block_set_named_parameter_block_owned() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let mut block = TransformerBlock::new(&config, 0);

        assert!(block.set_named_parameter(
            "input_layernorm.weight",
            Tensor::from_vec(vec![1.0; hidden_size], true),
        ));
        assert!(block.set_named_parameter(
            "post_attention_layernorm.weight",
            Tensor::from_vec(vec![1.0; hidden_size], true),
        ));

        let gate_len = block.ffn.w_gate.len();
        assert!(block.set_named_parameter(
            "mlp.gate_proj.weight",
            Tensor::from_vec(vec![0.1; gate_len], true),
        ));
        let up_len = block.ffn.w_up.len();
        assert!(block
            .set_named_parameter("mlp.up_proj.weight", Tensor::from_vec(vec![0.1; up_len], true)));
        let down_len = block.ffn.w_down.len();
        assert!(block.set_named_parameter(
            "mlp.down_proj.weight",
            Tensor::from_vec(vec![0.1; down_len], true),
        ));

        // The block must still run after its weights are replaced in place.
        let x = Tensor::from_vec(vec![0.1; 2 * hidden_size], true);
        assert_eq!(block.forward(&x, 2).len(), 2 * hidden_size);
    }
}