Skip to main content

entrenar/transformer/
encoder_block.rs

1//! Encoder transformer block (BERT/RoBERTa architecture)
2//!
3//! Unlike decoder blocks (pre-norm RMSNorm + SwiGLU FFN), encoder blocks use:
4//! - Post-norm architecture: residual + LayerNorm (not pre-norm)
5//! - Bidirectional attention (no causal mask — same as existing `MultiHeadAttention`)
6//! - GELU FFN with 2 projections (not SwiGLU with 3)
7//! - LayerNorm with bias (not RMSNorm)
8//!
9//! # Contract (CLF-001)
10//! - Output shape == input shape (seq_len × hidden_size)
11//! - No NaN or Inf in output for finite input
12
13use crate::autograd::add;
14use crate::Tensor;
15use std::collections::HashMap;
16
17use super::attention::MultiHeadAttention;
18use super::config::TransformerConfig;
19use super::feedforward::EncoderFeedForward;
20use super::norm::LayerNorm;
21
22/// Encoder transformer block (BERT/RoBERTa).
23///
24/// Architecture: x → Attn(x) + x → LayerNorm → FFN + residual → LayerNorm
25/// (post-norm, matching HuggingFace BERT implementation)
26pub struct EncoderBlock {
27    /// Layer index
28    layer_idx: usize,
29    /// Self-attention (bidirectional — no causal mask)
30    pub self_attn: MultiHeadAttention,
31    /// Post-attention LayerNorm
32    pub attn_layernorm: LayerNorm,
33    /// Feed-forward network (GELU, 2 projections)
34    pub ffn: EncoderFeedForward,
35    /// Post-FFN LayerNorm
36    pub ffn_layernorm: LayerNorm,
37    /// Hidden size for batched operations
38    hidden_size: usize,
39}
40
41impl EncoderBlock {
42    /// Create new encoder block with default initialization
43    pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
44        let eps = config.rms_norm_eps; // reuse epsilon field
45        Self {
46            layer_idx,
47            self_attn: MultiHeadAttention::new(config),
48            attn_layernorm: LayerNorm::new(config.hidden_size, eps),
49            ffn: EncoderFeedForward::new(config),
50            ffn_layernorm: LayerNorm::new(config.hidden_size, eps),
51            hidden_size: config.hidden_size,
52        }
53    }
54
55    /// Create encoder block from pre-trained parameters.
56    ///
57    /// Expected weight names (after RoBERTa mapping):
58    /// - `encoder.layers.{i}.self_attn.{q,k,v,o}_proj.weight`
59    /// - `encoder.layers.{i}.input_layernorm.{weight,bias}`
60    /// - `encoder.layers.{i}.mlp.intermediate.dense.{weight,bias}`
61    /// - `encoder.layers.{i}.mlp.output.dense.{weight,bias}`
62    /// - `encoder.layers.{i}.post_attention_layernorm.{weight,bias}`
63    pub fn from_params(
64        config: &TransformerConfig,
65        params: &HashMap<String, Tensor>,
66        layer_idx: usize,
67    ) -> Option<Self> {
68        let prefix = format!("encoder.layers.{layer_idx}");
69        let eps = config.rms_norm_eps;
70
71        let self_attn =
72            MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
73
74        let attn_layernorm = LayerNorm::from_params(
75            params,
76            &format!("{prefix}.input_layernorm"),
77            eps,
78            config.hidden_size,
79        )?;
80
81        let ffn = EncoderFeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
82
83        let ffn_layernorm = LayerNorm::from_params(
84            params,
85            &format!("{prefix}.post_attention_layernorm"),
86            eps,
87            config.hidden_size,
88        )?;
89
90        Some(Self {
91            layer_idx,
92            self_attn,
93            attn_layernorm,
94            ffn,
95            ffn_layernorm,
96            hidden_size: config.hidden_size,
97        })
98    }
99
100    /// Forward pass (post-norm encoder architecture).
101    ///
102    /// ```text
103    /// h = LayerNorm(x + Attention(x))
104    /// out = LayerNorm(h + FFN(h))
105    /// ```
106    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
107        let h = self.hidden_size;
108
109        // Self-attention + residual + LayerNorm
110        let attn_out = self.self_attn.forward(x, seq_len);
111        let residual1 = add(x, &attn_out);
112        let norm1 = self.attn_layernorm.forward_batched(&residual1, seq_len, h);
113
114        // FFN + residual + LayerNorm
115        let ffn_out = self.ffn.forward(&norm1, seq_len);
116        let residual2 = add(&norm1, &ffn_out);
117        self.ffn_layernorm.forward_batched(&residual2, seq_len, h)
118    }
119
120    /// Get layer index
121    pub fn layer_idx(&self) -> usize {
122        self.layer_idx
123    }
124
125    /// Get all parameters (immutable)
126    pub fn parameters(&self) -> Vec<&Tensor> {
127        let mut params = Vec::new();
128        params.extend(self.self_attn.parameters());
129        params.push(&self.attn_layernorm.weight);
130        params.push(&self.attn_layernorm.bias);
131        params.extend(self.ffn.parameters());
132        params.push(&self.ffn_layernorm.weight);
133        params.push(&self.ffn_layernorm.bias);
134        params
135    }
136}
137
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shared fixture: the CodeBERT configuration plus a default-initialized
    /// block created at `layer_idx`.
    fn codebert_block(layer_idx: usize) -> (TransformerConfig, EncoderBlock) {
        let config = TransformerConfig::codebert();
        let block = EncoderBlock::new(&config, layer_idx);
        (config, block)
    }

    #[test]
    fn clf_001_encoder_block_output_shape() {
        let (config, block) = codebert_block(0);
        let tokens = 4;
        let input = Tensor::from_vec(vec![0.1; tokens * config.hidden_size], true);

        let output = block.forward(&input, tokens);

        assert_eq!(output.len(), tokens * config.hidden_size);
    }

    #[test]
    fn clf_001_encoder_block_output_finite() {
        let (config, block) = codebert_block(0);
        let tokens = 3;
        let input = Tensor::from_vec(vec![0.5; tokens * config.hidden_size], true);

        let output = block.forward(&input, tokens);

        let values = output.data();
        let all_finite = values.as_slice().unwrap().iter().all(|v| v.is_finite());
        assert!(all_finite, "encoder block output must be finite");
    }

    #[test]
    fn clf_001_encoder_block_layer_idx() {
        let (_config, block) = codebert_block(7);
        assert_eq!(block.layer_idx(), 7);
    }

    #[test]
    fn clf_001_encoder_block_parameters_count() {
        let (_config, block) = codebert_block(0);
        // Expected total is 15. History (updated 2026-05-09 alongside
        // FALSIFY-APR-PRETRAIN-INIT-POPULATE-COVERAGE-001): the codebert
        // config sets `use_bias: true`, and `MultiHeadAttention::new` now
        // allocates the Q/K/V projection biases in that case. They used to
        // be hardcoded to None, which hid 3 params per encoder layer.
        //
        // Breakdown:
        //   self_attn       4 weights (Q,K,V,O) + 3 biases (Q,K,V)
        //   attn_layernorm  2 (weight, bias)
        //   ffn             4 (w_up, b_up, w_down, b_down)
        //   ffn_layernorm   2 (weight, bias)
        //
        // Before the fix this test asserted 12, which was wrong: the 3
        // missing biases would have been silently dropped when populating
        // from any HF BERT/RoBERTa init.
        let param_count = block.parameters().len();
        assert_eq!(param_count, 15);
    }

    #[test]
    fn test_encoder_block_different_layer_indices() {
        let config = TransformerConfig::codebert();
        for layer in [0, 1, 5, 11] {
            let reported = EncoderBlock::new(&config, layer).layer_idx();
            assert_eq!(reported, layer);
        }
    }

    #[test]
    fn test_encoder_block_forward_preserves_shape() {
        let (config, block) = codebert_block(0);
        for seq_len in [1, 2, 4, 8] {
            let input = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
            let produced = block.forward(&input, seq_len);
            assert_eq!(
                produced.len(),
                seq_len * config.hidden_size,
                "Shape mismatch for seq_len={seq_len}"
            );
        }
    }

    #[test]
    fn test_encoder_block_deterministic() {
        let (config, block) = codebert_block(0);
        let tokens = 3;
        let input = Tensor::from_vec(vec![0.3; tokens * config.hidden_size], true);

        // Two passes over the same input must agree element for element.
        let first = block.forward(&input, tokens);
        let second = block.forward(&input, tokens);

        let first_data = first.data();
        let second_data = second.data();
        assert_eq!(
            first_data.as_slice().unwrap(),
            second_data.as_slice().unwrap(),
            "Encoder block should be deterministic"
        );
    }

    #[test]
    fn test_encoder_block_from_params_missing() {
        let config = TransformerConfig::codebert();
        let no_weights: HashMap<String, Tensor> = HashMap::new();

        // With nothing to load, construction must report failure.
        let loaded = EncoderBlock::from_params(&config, &no_weights, 0);

        assert!(loaded.is_none());
    }

    #[test]
    fn test_encoder_block_hidden_size() {
        let (config, block) = codebert_block(0);
        assert_eq!(block.hidden_size, config.hidden_size);
    }

    #[test]
    fn test_encoder_block_parameters_nonzero_length() {
        let (_config, block) = codebert_block(0);
        let tensors = block.parameters();
        for (i, tensor) in tensors.into_iter().enumerate() {
            assert!(!tensor.is_empty(), "Parameter {i} should have non-zero length");
        }
    }

    #[test]
    fn test_encoder_block_single_token() {
        let (config, block) = codebert_block(3);
        let input = Tensor::from_vec(vec![0.2; config.hidden_size], true);

        let output = block.forward(&input, 1);

        assert_eq!(output.len(), config.hidden_size);
        let values = output.data();
        let all_finite = values.as_slice().unwrap().iter().all(|v| v.is_finite());
        assert!(all_finite);
    }
}
265}