use crate::autograd::add;
use crate::Tensor;
use std::collections::HashMap;
use super::attention::MultiHeadAttention;
use super::config::TransformerConfig;
use super::feedforward::FeedForward;
use super::norm::RMSNorm;

/// A single pre-norm transformer decoder block: RMS-normalized self-attention
/// followed by an RMS-normalized feed-forward network, each wrapped in a
/// residual connection.
pub struct TransformerBlock {
    config: TransformerConfig,
    layer_idx: usize,
    pub input_norm: RMSNorm,
    pub self_attn: MultiHeadAttention,
    pub post_attn_norm: RMSNorm,
    pub ffn: FeedForward,
}

impl TransformerBlock {
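    /// Creates a block with freshly initialized sub-modules for layer `layer_idx`.
    ///
    /// A minimal usage sketch, assuming the `tiny` config and the
    /// `Tensor::from_vec(data, requires_grad)` constructor used by the tests
    /// below:
    ///
    /// ```ignore
    /// let config = TransformerConfig::tiny();
    /// let block = TransformerBlock::new(&config, 0);
    /// // Input is a flattened [seq_len, hidden_size] tensor (seq_len = 2 here).
    /// let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
    /// let y = block.forward(&x, 2);
    /// assert_eq!(y.len(), 2 * config.hidden_size);
    /// ```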
    pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
        Self {
            config: config.clone(),
            layer_idx,
            input_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
            self_attn: MultiHeadAttention::new(config),
            post_attn_norm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
            ffn: FeedForward::new(config),
        }
    }

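    /// Builds a block from a flat parameter map keyed with HuggingFace-style
    /// names under `model.layers.{layer_idx}`, returning `None` if any
    /// required tensor is missing.
    ///
    /// A sketch of the expected key layout, mirroring the map built in
    /// `test_transformer_block_from_params_success` below (tensor contents are
    /// placeholders):
    ///
    /// ```ignore
    /// let config = TransformerConfig::tiny();
    /// let mut params = HashMap::new();
    /// params.insert(
    ///     "model.layers.0.input_layernorm.weight".to_string(),
    ///     Tensor::from_vec(vec![1.0; config.hidden_size], true),
    /// );
    /// // ...plus q/k/v/o projections under "model.layers.0.self_attn.*",
    /// // "model.layers.0.post_attention_layernorm.weight", and the
    /// // "model.layers.0.mlp.{gate,up,down}_proj.weight" tensors.
    /// let block = TransformerBlock::from_params(&config, &params, 0)
    ///     .expect("all parameters present");
    /// ```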
    pub fn from_params(
        config: &TransformerConfig,
        params: &HashMap<String, Tensor>,
        layer_idx: usize,
    ) -> Option<Self> {
        let prefix = format!("model.layers.{layer_idx}");
        let input_norm = RMSNorm::from_params(
            params,
            &format!("{prefix}.input_layernorm"),
            config.rms_norm_eps,
            config.hidden_size,
        )?;
        let self_attn =
            MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
        let post_attn_norm = RMSNorm::from_params(
            params,
            &format!("{prefix}.post_attention_layernorm"),
            config.rms_norm_eps,
            config.hidden_size,
        )?;
        let ffn = FeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
        Some(Self { config: config.clone(), layer_idx, input_norm, self_attn, post_attn_norm, ffn })
    }

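    /// Runs the block on `x`, a flattened `[seq_len, hidden_size]` activation
    /// tensor, and returns a tensor of the same shape.
    ///
    /// Dataflow (pre-norm residual layout, matching the body below):
    ///
    /// ```text
    /// h = x + SelfAttn(RMSNorm(x))
    /// y = h + FFN(RMSNorm(h))
    /// ```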
    pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
        let hidden_size = self.config.hidden_size;
        let norm1 = self.input_norm.forward_batched(x, seq_len, hidden_size);
        let attn_out = self.self_attn.forward(&norm1, seq_len);
        let residual1 = add(x, &attn_out);
        let norm2 = self.post_attn_norm.forward_batched(&residual1, seq_len, hidden_size);
        let ffn_out = self.ffn.forward(&norm2, seq_len);
        add(&residual1, &ffn_out)
    }

    /// Index of this layer within the parent model.
    pub fn layer_idx(&self) -> usize {
        self.layer_idx
    }

    /// Returns references to every trainable tensor in the block: both norm
    /// weights plus the attention and feed-forward parameters.
    pub fn parameters(&self) -> Vec<&Tensor> {
        let mut params = vec![&self.input_norm.weight, &self.post_attn_norm.weight];
        params.extend(self.self_attn.parameters());
        params.extend(self.ffn.parameters());
        params
    }

    /// Returns mutable references to every trainable tensor in the block.
    pub fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        let mut params: Vec<&mut Tensor> = Vec::new();
        params.push(&mut self.input_norm.weight);
        params.push(&mut self.post_attn_norm.weight);
        params.extend(self.self_attn.parameters_mut());
        params.extend(self.ffn.parameters_mut());
        params
    }

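    /// Returns `(name, tensor)` pairs using the same HuggingFace-style naming
    /// scheme that `from_params` consumes.
    ///
    /// A small sketch of the names produced (here for `layer_idx = 3`):
    ///
    /// ```ignore
    /// let block = TransformerBlock::new(&config, 3);
    /// for (name, _tensor) in block.named_parameters() {
    ///     // e.g. "model.layers.3.input_layernorm.weight",
    ///     //      "model.layers.3.mlp.gate_proj.weight", ...
    ///     println!("{name}");
    /// }
    /// ```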
    pub fn named_parameters(&self) -> Vec<(String, &Tensor)> {
        let prefix = format!("model.layers.{}", self.layer_idx);
        let mut params = vec![
            (format!("{prefix}.input_layernorm.weight"), &self.input_norm.weight),
            (format!("{prefix}.post_attention_layernorm.weight"), &self.post_attn_norm.weight),
        ];
        params.extend(self.self_attn.named_parameters(&format!("{prefix}.self_attn")));
        params.push((format!("{prefix}.mlp.gate_proj.weight"), &self.ffn.w_gate));
        params.push((format!("{prefix}.mlp.up_proj.weight"), &self.ffn.w_up));
        params.push((format!("{prefix}.mlp.down_proj.weight"), &self.ffn.w_down));
        params
    }

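    /// Overwrites the parameter addressed by `suffix`, the portion of the name
    /// after `model.layers.{layer_idx}.`; unrecognized suffixes are forwarded
    /// to the attention module. Returns `true` if the suffix was handled.
    ///
    /// A minimal sketch, assuming an RMSNorm weight of `hidden_size` elements
    /// as in the tests below:
    ///
    /// ```ignore
    /// let mut block = TransformerBlock::new(&config, 0);
    /// let w = Tensor::from_vec(vec![1.0; config.hidden_size], true);
    /// assert!(block.set_named_parameter("input_layernorm.weight", w));
    /// ```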
    pub fn set_named_parameter(&mut self, suffix: &str, value: Tensor) -> bool {
        match suffix {
            "input_layernorm.weight" => {
                self.input_norm.weight = value;
                true
            }
            "post_attention_layernorm.weight" => {
                self.post_attn_norm.weight = value;
                true
            }
            "mlp.gate_proj.weight" => {
                self.ffn.w_gate = value;
                true
            }
            "mlp.up_proj.weight" => {
                self.ffn.w_up = value;
                true
            }
            "mlp.down_proj.weight" => {
                self.ffn.w_down = value;
                true
            }
            _ => self.self_attn.set_named_parameter(suffix, value),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transformer_block_tiny() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let x = Tensor::from_vec(vec![0.1; 2 * config.hidden_size], true);
        let output = block.forward(&x, 2);
        assert_eq!(output.len(), 2 * config.hidden_size);
    }

    #[test]
    fn test_transformer_block_layer_idx() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 5);
        assert_eq!(block.layer_idx(), 5);
    }

    #[test]
    fn test_transformer_block_parameters() {
        let config = TransformerConfig::tiny();
        let block = TransformerBlock::new(&config, 0);
        let params = block.parameters();
        assert_eq!(params.len(), 9);
    }

    #[test]
    fn test_transformer_block_from_params_success() {
        let config = TransformerConfig::tiny();
        let hidden_size = config.hidden_size;
        let q_dim = config.q_dim();
        let kv_hidden_size = config.num_kv_heads * config.head_dim();
        let intermediate_size = config.intermediate_size;
        let mut params = HashMap::new();
        params.insert(
            "model.layers.0.input_layernorm.weight".to_string(),
            Tensor::from_vec(vec![1.0; hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; q_dim * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.k_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; kv_hidden_size * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.v_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; kv_hidden_size * hidden_size], true),
        );
        params.insert(
            "model.layers.0.self_attn.o_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * q_dim], true),
        );
        params.insert(
            "model.layers.0.post_attention_layernorm.weight".to_string(),
            Tensor::from_vec(vec![1.0; hidden_size], true),
        );
        params.insert(
            "model.layers.0.mlp.gate_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "model.layers.0.mlp.up_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; hidden_size * intermediate_size], true),
        );
        params.insert(
            "model.layers.0.mlp.down_proj.weight".to_string(),
            Tensor::from_vec(vec![0.1; intermediate_size * hidden_size], true),
        );
        let block = TransformerBlock::from_params(&config, &params, 0);
        assert!(block.is_some());
        let block = block.expect("block should be built from a complete parameter map");
        assert_eq!(block.layer_idx(), 0);
    }

    #[test]
    fn test_transformer_block_from_params_missing_norm() {
        let config = TransformerConfig::tiny();
        let params = HashMap::new();
        let block = TransformerBlock::from_params(&config, &params, 0);
        assert!(block.is_none());
    }
}