1use crate::autograd::add;
14use crate::Tensor;
15use std::collections::HashMap;
16
17use super::attention::MultiHeadAttention;
18use super::config::TransformerConfig;
19use super::feedforward::EncoderFeedForward;
20use super::norm::LayerNorm;
21
22pub struct EncoderBlock {
27 layer_idx: usize,
29 pub self_attn: MultiHeadAttention,
31 pub attn_layernorm: LayerNorm,
33 pub ffn: EncoderFeedForward,
35 pub ffn_layernorm: LayerNorm,
37 hidden_size: usize,
39}
40
41impl EncoderBlock {
42 pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
44 let eps = config.rms_norm_eps; Self {
46 layer_idx,
47 self_attn: MultiHeadAttention::new(config),
48 attn_layernorm: LayerNorm::new(config.hidden_size, eps),
49 ffn: EncoderFeedForward::new(config),
50 ffn_layernorm: LayerNorm::new(config.hidden_size, eps),
51 hidden_size: config.hidden_size,
52 }
53 }
54
55 pub fn from_params(
64 config: &TransformerConfig,
65 params: &HashMap<String, Tensor>,
66 layer_idx: usize,
67 ) -> Option<Self> {
68 let prefix = format!("encoder.layers.{layer_idx}");
69 let eps = config.rms_norm_eps;
70
71 let self_attn =
72 MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
73
74 let attn_layernorm = LayerNorm::from_params(
75 params,
76 &format!("{prefix}.input_layernorm"),
77 eps,
78 config.hidden_size,
79 )?;
80
81 let ffn = EncoderFeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
82
83 let ffn_layernorm = LayerNorm::from_params(
84 params,
85 &format!("{prefix}.post_attention_layernorm"),
86 eps,
87 config.hidden_size,
88 )?;
89
90 Some(Self {
91 layer_idx,
92 self_attn,
93 attn_layernorm,
94 ffn,
95 ffn_layernorm,
96 hidden_size: config.hidden_size,
97 })
98 }
99
100 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
107 let h = self.hidden_size;
108
109 let attn_out = self.self_attn.forward(x, seq_len);
111 let residual1 = add(x, &attn_out);
112 let norm1 = self.attn_layernorm.forward_batched(&residual1, seq_len, h);
113
114 let ffn_out = self.ffn.forward(&norm1, seq_len);
116 let residual2 = add(&norm1, &ffn_out);
117 self.ffn_layernorm.forward_batched(&residual2, seq_len, h)
118 }
119
120 pub fn layer_idx(&self) -> usize {
122 self.layer_idx
123 }
124
125 pub fn parameters(&self) -> Vec<&Tensor> {
127 let mut params = Vec::new();
128 params.extend(self.self_attn.parameters());
129 params.push(&self.attn_layernorm.weight);
130 params.push(&self.attn_layernorm.bias);
131 params.extend(self.ffn.parameters());
132 params.push(&self.ffn_layernorm.weight);
133 params.push(&self.ffn_layernorm.bias);
134 params
135 }
136}
137
#[cfg(test)]
// Tests call `unwrap()` on `as_slice()`; a panic there is an acceptable test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Returns `TransformerConfig::codebert()` together with a block built
    /// from it for `layer_idx`.
    fn codebert_block(layer_idx: usize) -> (TransformerConfig, EncoderBlock) {
        let config = TransformerConfig::codebert();
        let block = EncoderBlock::new(&config, layer_idx);
        (config, block)
    }

    /// Returns `true` when every element of `tensor` is finite.
    fn all_finite(tensor: &Tensor) -> bool {
        let data = tensor.data();
        let values = data.as_slice().unwrap();
        values.iter().all(|v| v.is_finite())
    }

    #[test]
    fn clf_001_encoder_block_output_shape() {
        let (config, block) = codebert_block(0);
        let seq_len = 4;
        let expected_len = seq_len * config.hidden_size;
        let x = Tensor::from_vec(vec![0.1; expected_len], true);
        assert_eq!(block.forward(&x, seq_len).len(), expected_len);
    }

    #[test]
    fn clf_001_encoder_block_output_finite() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let x = Tensor::from_vec(vec![0.5; seq_len * config.hidden_size], true);
        let output = block.forward(&x, seq_len);
        assert!(all_finite(&output), "encoder block output must be finite");
    }

    #[test]
    fn clf_001_encoder_block_layer_idx() {
        let (_, block) = codebert_block(7);
        assert_eq!(block.layer_idx(), 7);
    }

    #[test]
    fn clf_001_encoder_block_parameters_count() {
        let (_, block) = codebert_block(0);
        assert_eq!(block.parameters().len(), 12);
    }

    #[test]
    fn test_encoder_block_different_layer_indices() {
        let config = TransformerConfig::codebert();
        for idx in [0, 1, 5, 11] {
            assert_eq!(EncoderBlock::new(&config, idx).layer_idx(), idx);
        }
    }

    #[test]
    fn test_encoder_block_forward_preserves_shape() {
        let (config, block) = codebert_block(0);
        for seq_len in [1, 2, 4, 8] {
            let expected_len = seq_len * config.hidden_size;
            let x = Tensor::from_vec(vec![0.1; expected_len], true);
            let output = block.forward(&x, seq_len);
            assert_eq!(
                output.len(),
                expected_len,
                "Shape mismatch for seq_len={seq_len}"
            );
        }
    }

    #[test]
    fn test_encoder_block_deterministic() {
        let (config, block) = codebert_block(0);
        let seq_len = 3;
        let x = Tensor::from_vec(vec![0.3; seq_len * config.hidden_size], true);

        // Two passes over the same input must produce identical values.
        let first = block.forward(&x, seq_len);
        let second = block.forward(&x, seq_len);

        assert_eq!(
            first.data().as_slice().unwrap(),
            second.data().as_slice().unwrap(),
            "Encoder block should be deterministic"
        );
    }

    #[test]
    fn test_encoder_block_from_params_missing() {
        let config = TransformerConfig::codebert();
        let empty_params: HashMap<String, Tensor> = HashMap::new();
        assert!(EncoderBlock::from_params(&config, &empty_params, 0).is_none());
    }

    #[test]
    fn test_encoder_block_hidden_size() {
        let (config, block) = codebert_block(0);
        assert_eq!(block.hidden_size, config.hidden_size);
    }

    #[test]
    fn test_encoder_block_parameters_nonzero_length() {
        let (_, block) = codebert_block(0);
        for (i, p) in block.parameters().iter().enumerate() {
            assert!(!p.is_empty(), "Parameter {i} should have non-zero length");
        }
    }

    #[test]
    fn test_encoder_block_single_token() {
        let (config, block) = codebert_block(3);
        let x = Tensor::from_vec(vec![0.2; config.hidden_size], true);
        let output = block.forward(&x, 1);
        assert_eq!(output.len(), config.hidden_size);
        assert!(all_finite(&output));
    }
}