Skip to main content

entrenar/transformer/
encoder_block.rs

//! Encoder transformer block (BERT/RoBERTa architecture)
//!
//! Unlike decoder blocks (pre-norm RMSNorm + SwiGLU FFN), encoder blocks use:
//! - Post-norm architecture: residual + LayerNorm (not pre-norm)
//! - Bidirectional attention (no causal mask — same as existing `MultiHeadAttention`)
//! - GELU FFN with 2 projections (not SwiGLU with 3)
//! - LayerNorm with bias (not RMSNorm)
//!
//! # Contract (CLF-001)
//! - Output shape == input shape (seq_len × hidden_size)
//! - No NaN or Inf in output for finite input
12
13use crate::autograd::add;
14use crate::Tensor;
15use std::collections::HashMap;
16
17use super::attention::MultiHeadAttention;
18use super::config::TransformerConfig;
19use super::feedforward::EncoderFeedForward;
20use super::norm::LayerNorm;
21
22/// Encoder transformer block (BERT/RoBERTa).
23///
24/// Architecture: x → Attn(x) + x → LayerNorm → FFN + residual → LayerNorm
25/// (post-norm, matching HuggingFace BERT implementation)
26pub struct EncoderBlock {
27    /// Layer index
28    layer_idx: usize,
29    /// Self-attention (bidirectional — no causal mask)
30    pub self_attn: MultiHeadAttention,
31    /// Post-attention LayerNorm
32    pub attn_layernorm: LayerNorm,
33    /// Feed-forward network (GELU, 2 projections)
34    pub ffn: EncoderFeedForward,
35    /// Post-FFN LayerNorm
36    pub ffn_layernorm: LayerNorm,
37    /// Hidden size for batched operations
38    hidden_size: usize,
39}
40
41impl EncoderBlock {
42    /// Create new encoder block with default initialization
43    pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
44        let eps = config.rms_norm_eps; // reuse epsilon field
45        Self {
46            layer_idx,
47            self_attn: MultiHeadAttention::new(config),
48            attn_layernorm: LayerNorm::new(config.hidden_size, eps),
49            ffn: EncoderFeedForward::new(config),
50            ffn_layernorm: LayerNorm::new(config.hidden_size, eps),
51            hidden_size: config.hidden_size,
52        }
53    }
54
55    /// Create encoder block from pre-trained parameters.
56    ///
57    /// Expected weight names (after RoBERTa mapping):
58    /// - `encoder.layers.{i}.self_attn.{q,k,v,o}_proj.weight`
59    /// - `encoder.layers.{i}.input_layernorm.{weight,bias}`
60    /// - `encoder.layers.{i}.mlp.intermediate.dense.{weight,bias}`
61    /// - `encoder.layers.{i}.mlp.output.dense.{weight,bias}`
62    /// - `encoder.layers.{i}.post_attention_layernorm.{weight,bias}`
63    pub fn from_params(
64        config: &TransformerConfig,
65        params: &HashMap<String, Tensor>,
66        layer_idx: usize,
67    ) -> Option<Self> {
68        let prefix = format!("encoder.layers.{layer_idx}");
69        let eps = config.rms_norm_eps;
70
71        let self_attn =
72            MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
73
74        let attn_layernorm = LayerNorm::from_params(
75            params,
76            &format!("{prefix}.input_layernorm"),
77            eps,
78            config.hidden_size,
79        )?;
80
81        let ffn = EncoderFeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
82
83        let ffn_layernorm = LayerNorm::from_params(
84            params,
85            &format!("{prefix}.post_attention_layernorm"),
86            eps,
87            config.hidden_size,
88        )?;
89
90        Some(Self {
91            layer_idx,
92            self_attn,
93            attn_layernorm,
94            ffn,
95            ffn_layernorm,
96            hidden_size: config.hidden_size,
97        })
98    }
99
100    /// Forward pass (post-norm encoder architecture).
101    ///
102    /// ```text
103    /// h = LayerNorm(x + Attention(x))
104    /// out = LayerNorm(h + FFN(h))
105    /// ```
106    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
107        let h = self.hidden_size;
108
109        // Self-attention + residual + LayerNorm
110        let attn_out = self.self_attn.forward(x, seq_len);
111        let residual1 = add(x, &attn_out);
112        let norm1 = self.attn_layernorm.forward_batched(&residual1, seq_len, h);
113
114        // FFN + residual + LayerNorm
115        let ffn_out = self.ffn.forward(&norm1, seq_len);
116        let residual2 = add(&norm1, &ffn_out);
117        self.ffn_layernorm.forward_batched(&residual2, seq_len, h)
118    }
119
120    /// Get layer index
121    pub fn layer_idx(&self) -> usize {
122        self.layer_idx
123    }
124
125    /// Get all parameters (immutable)
126    pub fn parameters(&self) -> Vec<&Tensor> {
127        let mut params = Vec::new();
128        params.extend(self.self_attn.parameters());
129        params.push(&self.attn_layernorm.weight);
130        params.push(&self.attn_layernorm.bias);
131        params.extend(self.ffn.parameters());
132        params.push(&self.ffn_layernorm.weight);
133        params.push(&self.ffn_layernorm.bias);
134        params
135    }
136}
137
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a default-initialized block at `layer_idx` from the CodeBERT config.
    fn codebert_block(layer_idx: usize) -> (TransformerConfig, EncoderBlock) {
        let config = TransformerConfig::codebert();
        let block = EncoderBlock::new(&config, layer_idx);
        (config, block)
    }

    #[test]
    fn clf_001_encoder_block_output_shape() {
        let (config, block) = codebert_block(0);
        let seq_len = 4;
        let expected_len = seq_len * config.hidden_size;
        let input = Tensor::from_vec(vec![0.1; expected_len], true);
        let output = block.forward(&input, seq_len);
        assert_eq!(output.len(), expected_len);
    }

    #[test]
    fn clf_001_encoder_block_output_finite() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let input = Tensor::from_vec(vec![0.5; seq_len * config.hidden_size], true);
        let output = block.forward(&input, seq_len);
        let values = output.data();
        let all_finite = values.as_slice().unwrap().iter().all(|value| value.is_finite());
        assert!(all_finite, "encoder block output must be finite");
    }

    #[test]
    fn clf_001_encoder_block_layer_idx() {
        let (_, block) = codebert_block(7);
        assert_eq!(block.layer_idx(), 7);
    }

    #[test]
    fn clf_001_encoder_block_parameters_count() {
        let (_, block) = codebert_block(0);
        // Attention contributes 4 tensors (Q, K, V, O) and the FFN 4
        // (w_up, b_up, w_down, b_down); each LayerNorm adds weight + bias.
        let expected_count = 4 + 2 + 4 + 2;
        assert_eq!(block.parameters().len(), expected_count);
    }

    #[test]
    fn test_encoder_block_different_layer_indices() {
        let config = TransformerConfig::codebert();
        for layer in [0, 1, 5, 11] {
            assert_eq!(EncoderBlock::new(&config, layer).layer_idx(), layer);
        }
    }

    #[test]
    fn test_encoder_block_forward_preserves_shape() {
        let (config, block) = codebert_block(0);
        for seq_len in [1, 2, 4, 8] {
            let expected_len = seq_len * config.hidden_size;
            let input = Tensor::from_vec(vec![0.1; expected_len], true);
            let output = block.forward(&input, seq_len);
            assert_eq!(output.len(), expected_len, "Shape mismatch for seq_len={seq_len}");
        }
    }

    #[test]
    fn test_encoder_block_deterministic() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let input = Tensor::from_vec(vec![0.3; seq_len * config.hidden_size], true);

        // Two passes over the same input must produce identical values.
        let first = block.forward(&input, seq_len);
        let second = block.forward(&input, seq_len);

        let first_values = first.data();
        let second_values = second.data();
        assert_eq!(
            first_values.as_slice().unwrap(),
            second_values.as_slice().unwrap(),
            "Encoder block should be deterministic"
        );
    }

    #[test]
    fn test_encoder_block_from_params_missing() {
        let config = TransformerConfig::codebert();
        let no_weights: HashMap<String, Tensor> = HashMap::new();
        // An empty weight map cannot satisfy any sub-layer loader.
        assert!(EncoderBlock::from_params(&config, &no_weights, 0).is_none());
    }

    #[test]
    fn test_encoder_block_hidden_size() {
        let (config, block) = codebert_block(0);
        assert_eq!(block.hidden_size, config.hidden_size);
    }

    #[test]
    fn test_encoder_block_parameters_nonzero_length() {
        let (_, block) = codebert_block(0);
        for (i, tensor) in block.parameters().iter().enumerate() {
            assert!(!tensor.is_empty(), "Parameter {i} should have non-zero length");
        }
    }

    #[test]
    fn test_encoder_block_single_token() {
        let (config, block) = codebert_block(3);
        let input = Tensor::from_vec(vec![0.2; config.hidden_size], true);
        let output = block.forward(&input, 1);
        assert_eq!(output.len(), config.hidden_size);
        let values = output.data();
        assert!(values.as_slice().unwrap().iter().all(|value| value.is_finite()));
    }
}