1use crate::autograd::add;
14use crate::Tensor;
15use std::collections::HashMap;
16
17use super::attention::MultiHeadAttention;
18use super::config::TransformerConfig;
19use super::feedforward::EncoderFeedForward;
20use super::norm::LayerNorm;
21
/// One encoder layer: a self-attention sub-layer and a feed-forward sub-layer,
/// each followed by its own `LayerNorm`.
///
/// `forward` runs the sub-layers in the order attention -> add input -> norm ->
/// feed-forward -> add -> norm (normalization is applied *after* each addition).
pub struct EncoderBlock {
    /// Position of this layer in the encoder stack. Returned by `layer_idx()` and
    /// used by `from_params` to build the `encoder.layers.{layer_idx}` key prefix.
    layer_idx: usize,
    /// Self-attention sub-layer; receives the block input directly.
    pub self_attn: MultiHeadAttention,
    /// Normalization applied to `input + attention output`.
    pub attn_layernorm: LayerNorm,
    /// Feed-forward sub-layer; receives the normalized attention result.
    pub ffn: EncoderFeedForward,
    /// Normalization applied to `normalized attention result + feed-forward output`;
    /// its output is the block output.
    pub ffn_layernorm: LayerNorm,
    /// Copy of `config.hidden_size`, passed to `LayerNorm::forward_batched` in `forward`.
    hidden_size: usize,
}
40
41impl EncoderBlock {
42 pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
44 let eps = config.rms_norm_eps; Self {
46 layer_idx,
47 self_attn: MultiHeadAttention::new(config),
48 attn_layernorm: LayerNorm::new(config.hidden_size, eps),
49 ffn: EncoderFeedForward::new(config),
50 ffn_layernorm: LayerNorm::new(config.hidden_size, eps),
51 hidden_size: config.hidden_size,
52 }
53 }
54
55 pub fn from_params(
64 config: &TransformerConfig,
65 params: &HashMap<String, Tensor>,
66 layer_idx: usize,
67 ) -> Option<Self> {
68 let prefix = format!("encoder.layers.{layer_idx}");
69 let eps = config.rms_norm_eps;
70
71 let self_attn =
72 MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
73
74 let attn_layernorm = LayerNorm::from_params(
75 params,
76 &format!("{prefix}.input_layernorm"),
77 eps,
78 config.hidden_size,
79 )?;
80
81 let ffn = EncoderFeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
82
83 let ffn_layernorm = LayerNorm::from_params(
84 params,
85 &format!("{prefix}.post_attention_layernorm"),
86 eps,
87 config.hidden_size,
88 )?;
89
90 Some(Self {
91 layer_idx,
92 self_attn,
93 attn_layernorm,
94 ffn,
95 ffn_layernorm,
96 hidden_size: config.hidden_size,
97 })
98 }
99
100 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
107 let h = self.hidden_size;
108
109 let attn_out = self.self_attn.forward(x, seq_len);
111 let residual1 = add(x, &attn_out);
112 let norm1 = self.attn_layernorm.forward_batched(&residual1, seq_len, h);
113
114 let ffn_out = self.ffn.forward(&norm1, seq_len);
116 let residual2 = add(&norm1, &ffn_out);
117 self.ffn_layernorm.forward_batched(&residual2, seq_len, h)
118 }
119
120 pub fn layer_idx(&self) -> usize {
122 self.layer_idx
123 }
124
125 pub fn parameters(&self) -> Vec<&Tensor> {
127 let mut params = Vec::new();
128 params.extend(self.self_attn.parameters());
129 params.push(&self.attn_layernorm.weight);
130 params.push(&self.attn_layernorm.bias);
131 params.extend(self.ffn.parameters());
132 params.push(&self.ffn_layernorm.weight);
133 params.push(&self.ffn_layernorm.bias);
134 params
135 }
136}
137
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds the CodeBERT config together with a fresh block at `layer_idx`.
    fn codebert_block(layer_idx: usize) -> (TransformerConfig, EncoderBlock) {
        let config = TransformerConfig::codebert();
        let block = EncoderBlock::new(&config, layer_idx);
        (config, block)
    }

    /// True when every element of `tensor` is finite.
    fn all_finite(tensor: &Tensor) -> bool {
        let data = tensor.data();
        let values = data.as_slice().unwrap();
        values.iter().all(|v| v.is_finite())
    }

    /// Output length equals input length for a small sequence.
    #[test]
    fn clf_001_encoder_block_output_shape() {
        let (config, block) = codebert_block(0);
        let seq_len = 4;
        let expected_len = seq_len * config.hidden_size;
        let x = Tensor::from_vec(vec![0.1; expected_len], true);
        assert_eq!(block.forward(&x, seq_len).len(), expected_len);
    }

    /// A constant input must not produce NaN or infinite values.
    #[test]
    fn clf_001_encoder_block_output_finite() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let x = Tensor::from_vec(vec![0.5; seq_len * config.hidden_size], true);
        let output = block.forward(&x, seq_len);
        assert!(all_finite(&output), "encoder block output must be finite");
    }

    /// The index given at construction is reported back unchanged.
    #[test]
    fn clf_001_encoder_block_layer_idx() {
        let (_, block) = codebert_block(7);
        assert_eq!(block.layer_idx(), 7);
    }

    /// Total tensor count for the CodeBERT config: four come from the two
    /// LayerNorms (weight + bias each), the rest from attention and
    /// feed-forward.
    #[test]
    fn clf_001_encoder_block_parameters_count() {
        let (_, block) = codebert_block(0);
        assert_eq!(block.parameters().len(), 15);
    }

    /// `layer_idx()` round-trips for several indices.
    #[test]
    fn test_encoder_block_different_layer_indices() {
        let config = TransformerConfig::codebert();
        for idx in [0, 1, 5, 11] {
            assert_eq!(EncoderBlock::new(&config, idx).layer_idx(), idx);
        }
    }

    /// The same block keeps the input length across different sequence lengths.
    #[test]
    fn test_encoder_block_forward_preserves_shape() {
        let (config, block) = codebert_block(0);
        for seq_len in [1, 2, 4, 8] {
            let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
            let output = block.forward(&x, seq_len);
            assert_eq!(
                output.len(),
                seq_len * config.hidden_size,
                "Shape mismatch for seq_len={seq_len}"
            );
        }
    }

    /// Two forward passes over the same input give identical values.
    #[test]
    fn test_encoder_block_deterministic() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let x = Tensor::from_vec(vec![0.3; seq_len * config.hidden_size], true);

        let first = block.forward(&x, seq_len);
        let second = block.forward(&x, seq_len);

        let first_data = first.data();
        let second_data = second.data();
        assert_eq!(
            first_data.as_slice().unwrap(),
            second_data.as_slice().unwrap(),
            "Encoder block should be deterministic"
        );
    }

    /// Loading from an empty parameter map yields `None`.
    #[test]
    fn test_encoder_block_from_params_missing() {
        let config = TransformerConfig::codebert();
        let no_params: HashMap<String, Tensor> = HashMap::new();
        assert!(EncoderBlock::from_params(&config, &no_params, 0).is_none());
    }

    /// The private `hidden_size` field mirrors the config value.
    #[test]
    fn test_encoder_block_hidden_size() {
        let (config, block) = codebert_block(0);
        assert_eq!(block.hidden_size, config.hidden_size);
    }

    /// Every reported parameter tensor holds at least one element.
    #[test]
    fn test_encoder_block_parameters_nonzero_length() {
        let (_, block) = codebert_block(0);
        for (i, p) in block.parameters().iter().enumerate() {
            assert!(!p.is_empty(), "Parameter {i} should have non-zero length");
        }
    }

    /// A one-token sequence keeps its length and stays finite.
    #[test]
    fn test_encoder_block_single_token() {
        let (config, block) = codebert_block(3);
        let x = Tensor::from_vec(vec![0.2; config.hidden_size], true);
        let output = block.forward(&x, 1);
        assert_eq!(output.len(), config.hidden_size);
        assert!(all_finite(&output));
    }
}