1use crate::autograd::add;
6use crate::Tensor;
7use std::collections::HashMap;
8
9use super::attention::MultiHeadAttention;
10use super::config::TransformerConfig;
11use super::feedforward::FeedForward;
12use super::norm::RMSNorm;
13
14pub struct TransformerBlock {
16 config: TransformerConfig,
18 layer_idx: usize,
20 pub input_norm: RMSNorm,
22 pub self_attn: MultiHeadAttention,
24 pub post_attn_norm: RMSNorm,
26 pub ffn: FeedForward,
28}
29
30impl TransformerBlock {
31 pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
33 Self {
34 config: config.clone(),
35 layer_idx,
36 input_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
37 self_attn: MultiHeadAttention::new(config),
38 post_attn_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
39 ffn: FeedForward::new(config),
40 }
41 }
42
43 pub fn from_params(
51 config: &TransformerConfig,
52 params: &HashMap<String, Tensor>,
53 layer_idx: usize,
54 ) -> Option<Self> {
55 let prefix = format!("model.layers.{layer_idx}");
56
57 let input_norm = RMSNorm::from_params(
58 params,
59 &format!("{prefix}.input_layernorm"),
60 config.rms_norm_eps,
61 config.hidden_size,
62 )?;
63
64 let self_attn =
65 MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
66
67 let post_attn_norm = RMSNorm::from_params(
68 params,
69 &format!("{prefix}.post_attention_layernorm"),
70 config.rms_norm_eps,
71 config.hidden_size,
72 )?;
73
74 let ffn = FeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
75
76 Some(Self { config: config.clone(), layer_idx, input_norm, self_attn, post_attn_norm, ffn })
77 }
78
79 pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
88 let hidden_size = self.config.hidden_size;
89
90 let norm1 = self.input_norm.forward_batched(x, seq_len, hidden_size);
92 let attn_out = self.self_attn.forward(&norm1, seq_len);
93 let residual1 = add(x, &attn_out);
94
95 let norm2 = self.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
97 let ffn_out = self.ffn.forward(&norm2, seq_len);
98 add(&residual1, &ffn_out)
99 }
100
101 pub fn layer_idx(&self) -> usize {
103 self.layer_idx
104 }
105
106 pub fn parameters(&self) -> Vec<&Tensor> {
108 let mut params = vec![&self.input_norm.weight, &self.post_attn_norm.weight];
109 params.extend(self.self_attn.parameters());
110 params.extend(self.ffn.parameters());
111 params
112 }
113
114 pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
116 let mut params: Vec<&mut Tensor> = Vec::new();
117 params.push(&mut self.input_norm.weight);
118 params.push(&mut self.post_attn_norm.weight);
119 params.extend(self.self_attn.parameters_mut());
120 params.extend(self.ffn.parameters_mut());
121 params
122 }
123
124 pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
127 let prefix = format!("model.layers.{}", self.layer_idx);
128 let mut params = vec![
129 (format!("{prefix}.input_layernorm.weight"), &self.input_norm.weight),
130 (format!("{prefix}.post_attention_layernorm.weight"), &self.post_attn_norm.weight),
131 ];
132 params.extend(self.self_attn.named_parameters(&format!("{prefix}.self_attn")));
133 params.push((format!("{prefix}.mlp.gate_proj.weight"), &self.ffn.w_gate));
134 params.push((format!("{prefix}.mlp.up_proj.weight"), &self.ffn.w_up));
135 params.push((format!("{prefix}.mlp.down_proj.weight"), &self.ffn.w_down));
136 params
137 }
138
139 pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
141 match suffix {
142 "input_layernorm.weight" => {
143 self.input_norm.weight = value;
144 true
145 }
146 "post_attention_layernorm.weight" => {
147 self.post_attn_norm.weight = value;
148 true
149 }
150 "mlp.gate_proj.weight" => {
151 self.ffn.w_gate = value;
152 true
153 }
154 "mlp.up_proj.weight" => {
155 self.ffn.w_up = value;
156 true
157 }
158 "mlp.down_proj.weight" => {
159 self.ffn.w_down = value;
160 true
161 }
162 _ => self.self_attn.set_named_parameter(suffix, value),
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a complete parameter map for `layer_idx`, with every tensor sized as the
    /// loaders expect for `config`.
    fn layer_params(config: &TransformerConfig, layer_idx: usize) -> HashMap<String, Tensor> {
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;
        let prefix = format!("model.layers.{layer_idx}");

        let entries = [
            ("input_layernorm.weight", 1.0, hidden_size),
            ("self_attn.q_proj.weight", 0.1, q_dim * hidden_size),
            ("self_attn.k_proj.weight", 0.1, kv_hidden_size * hidden_size),
            ("self_attn.v_proj.weight", 0.1, kv_hidden_size * hidden_size),
            ("self_attn.o_proj.weight", 0.1, hidden_size * q_dim),
            ("post_attention_layernorm.weight", 1.0, hidden_size),
            ("mlp.gate_proj.weight", 0.1, hidden_size * intermediate_size),
            ("mlp.up_proj.weight", 0.1, hidden_size * intermediate_size),
            ("mlp.down_proj.weight", 0.1, intermediate_size * hidden_size),
        ];
        entries
            .iter()
            .map(|&(name, fill, len)| {
                (format!("{prefix}.{name}"), Tensor::from_vec(vec![fill; len], true))
            })
            .collect()
    }

    #[test]
    fn test_transformer_block_tiny() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let output = block.forward(&x, 2);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_transformer_block_layer_idx() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 5);
        assert_eq!(block.layer_idx(), 5);
    }

    #[test]
    fn test_transformer_block_parameters() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let params = block.parameters();
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_transformer_block_named_parameters_match_parameters() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 3);
        let named = block.named_parameters();
        // Every tensor returned by `parameters` must be exported under exactly one name.
        assert_eq!(named.len(), block.parameters().len());
        assert!(named.iter().all(|(name, _)| name.starts_with("model.layers.3.")));
    }

    #[test]
    fn test_transformer_block_set_named_parameter() {
        let config = TransformerConfig::tiny();
        let mut block = TransformerBlock::new(&config, 0);
        let replacement = Tensor::from_vec(vec![2.0; config.hidden_size], true);
        assert!(block.set_named_parameter("input_layernorm.weight", replacement));

        let bogus = Tensor::from_vec(vec![0.0; config.hidden_size], true);
        assert!(!block.set_named_parameter("not_a_parameter.weight", bogus));
    }

    #[test]
    fn test_transformer_block_from_params_success() {
        let config = TransformerConfig::tiny();
        let params = layer_params(&config, 0);

        let block = TransformerBlock::from_params(&config, &params, 0)
            .expect("a complete parameter map should load");
        assert_eq!(block.layer_idx(), 0);
    }

    #[test]
    fn test_transformer_block_from_params_wrong_layer() {
        let config = TransformerConfig::tiny();
        let params = layer_params(&config, 0);
        // The weights live under layer 0, so loading layer 1 must find nothing.
        assert!(TransformerBlock::from_params(&config, &params, 1).is_none());
    }

    #[test]
    fn test_transformer_block_from_params_missing_norm() {
        let config = TransformerConfig::tiny();
        let params = HashMap::new();
        let block = TransformerBlock::from_params(&config, &params, 0);
        assert!(block.is_none());
    }
}