use crate::autograd::add;
use crate::Tensor;
use std::collections::HashMap;
use super::attention::MultiHeadAttention;
use super::config::TransformerConfig;
use super::feedforward::EncoderFeedForward;
use super::norm::LayerNorm;
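/// One Transformer encoder layer in the post-LayerNorm style used by
/// BERT/CodeBERT: a multi-head self-attention sub-layer and a feed-forward
/// sub-layer, each wrapped in a residual connection followed by `LayerNorm`.
///
/// Illustrative usage (mirrors the tests below):
///
/// ```ignore
/// let config = TransformerConfig::codebert();
/// let block = EncoderBlock::new(&config, 0);
/// let x = Tensor::from_vec(vec![0.1; 4 * config.hidden_size], true);
/// let out = block.forward(&x, 4);
/// assert_eq!(out.len(), 4 * config.hidden_size);
/// ```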
pub struct EncoderBlock {
layer_idx: usize,
pub self_attn: MultiHeadAttention,
pub attn_layernorm: LayerNorm,
pub ffn: EncoderFeedForward,
pub ffn_layernorm: LayerNorm,
hidden_size: usize,
}
impl EncoderBlock {
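    /// Builds a freshly initialized encoder block for layer `layer_idx`.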
pub fn new(config: &TransformerConfig, layer_idx: usize) -> Self {
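        // The config exposes a single epsilon field (`rms_norm_eps`); it is
        // reused here as the `LayerNorm` epsilon.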
        let eps = config.rms_norm_eps;
        Self {
layer_idx,
self_attn: MultiHeadAttention::new(config),
attn_layernorm: LayerNorm::new(config.hidden_size, eps),
ffn: EncoderFeedForward::new(config),
ffn_layernorm: LayerNorm::new(config.hidden_size, eps),
hidden_size: config.hidden_size,
}
}
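    /// Loads an encoder block from a flat parameter map, returning `None` if
    /// any required tensor is missing. Keys follow the layout
    /// `encoder.layers.{layer_idx}.{self_attn,input_layernorm,mlp,post_attention_layernorm}`.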
pub fn from_params(
config: &TransformerConfig,
params: &HashMap<String, Tensor>,
layer_idx: usize,
) -> Option<Self> {
let prefix = format!("encoder.layers.{layer_idx}");
let eps = config.rms_norm_eps;
let self_attn =
MultiHeadAttention::from_params(config, params, &format!("{prefix}.self_attn"))?;
let attn_layernorm = LayerNorm::from_params(
params,
&format!("{prefix}.input_layernorm"),
eps,
config.hidden_size,
)?;
let ffn = EncoderFeedForward::from_params(config, params, &format!("{prefix}.mlp"))?;
let ffn_layernorm = LayerNorm::from_params(
params,
&format!("{prefix}.post_attention_layernorm"),
eps,
config.hidden_size,
)?;
Some(Self {
layer_idx,
self_attn,
attn_layernorm,
ffn,
ffn_layernorm,
hidden_size: config.hidden_size,
})
}
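    /// Runs one encoder layer over `x`, a flattened `[seq_len * hidden_size]`
    /// tensor, returning a tensor of the same shape. Each sub-layer applies
    /// its residual add first and normalizes afterwards (post-LN).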
pub fn forward(&self, x: &Tensor, seq_len: usize) -> Tensor {
let h = self.hidden_size;
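        // Self-attention sub-layer: attend, add the residual, then normalize.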
let attn_out = self.self_attn.forward(x, seq_len);
let residual1 = add(x, &attn_out);
let norm1 = self.attn_layernorm.forward_batched(&residual1, seq_len, h);
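        // Feed-forward sub-layer: same residual-then-normalize pattern.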
let ffn_out = self.ffn.forward(&norm1, seq_len);
let residual2 = add(&norm1, &ffn_out);
self.ffn_layernorm.forward_batched(&residual2, seq_len, h)
}
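    /// Index of this layer within the encoder stack.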
pub fn layer_idx(&self) -> usize {
self.layer_idx
}
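    /// Returns all trainable tensors: the attention and FFN parameters plus
    /// the weight and bias of each of the two `LayerNorm`s.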
pub fn parameters(&self) -> Vec<&Tensor> {
let mut params = Vec::new();
params.extend(self.self_attn.parameters());
params.push(&self.attn_layernorm.weight);
params.push(&self.attn_layernorm.bias);
params.extend(self.ffn.parameters());
params.push(&self.ffn_layernorm.weight);
params.push(&self.ffn_layernorm.bias);
params
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
#[test]
fn clf_001_encoder_block_output_shape() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
let seq_len = 4;
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
let output = block.forward(&x, seq_len);
assert_eq!(output.len(), seq_len * config.hidden_size);
}
#[test]
fn clf_001_encoder_block_output_finite() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
let seq_len = 3;
let x = Tensor::from_vec(vec![0.5; seq_len * config.hidden_size], true);
let output = block.forward(&x, seq_len);
let data = output.data();
let slice = data.as_slice().unwrap();
assert!(slice.iter().all(|v| v.is_finite()), "encoder block output must be finite");
}
#[test]
fn clf_001_encoder_block_layer_idx() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 7);
assert_eq!(block.layer_idx(), 7);
}
#[test]
fn clf_001_encoder_block_parameters_count() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
let params = block.parameters();
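        // The two LayerNorms contribute 4 tensors (weight + bias each); the
        // remaining 8 come from the attention and feed-forward modules.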
assert_eq!(params.len(), 12);
}
#[test]
fn test_encoder_block_different_layer_indices() {
let config = TransformerConfig::codebert();
for idx in [0, 1, 5, 11] {
let block = EncoderBlock::new(&config, idx);
assert_eq!(block.layer_idx(), idx);
}
}
#[test]
fn test_encoder_block_forward_preserves_shape() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
for seq_len in [1, 2, 4, 8] {
let x = Tensor::from_vec(vec![0.1; seq_len * config.hidden_size], true);
let output = block.forward(&x, seq_len);
assert_eq!(
output.len(),
seq_len * config.hidden_size,
"Shape mismatch for seq_len={seq_len}"
);
}
}
#[test]
fn test_encoder_block_deterministic() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
let seq_len = 3;
let x = Tensor::from_vec(vec![0.3; seq_len * config.hidden_size], true);
let out1 = block.forward(&x, seq_len);
let out2 = block.forward(&x, seq_len);
let d1 = out1.data();
let d2 = out2.data();
let s1 = d1.as_slice().unwrap();
let s2 = d2.as_slice().unwrap();
assert_eq!(s1, s2, "Encoder block should be deterministic");
}
#[test]
fn test_encoder_block_from_params_missing() {
let config = TransformerConfig::codebert();
let empty_params: HashMap<String, Tensor> = HashMap::new();
let result = EncoderBlock::from_params(&config, &empty_params, 0);
assert!(result.is_none());
}
#[test]
fn test_encoder_block_hidden_size() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
assert_eq!(block.hidden_size, config.hidden_size);
}
#[test]
fn test_encoder_block_parameters_nonzero_length() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 0);
let params = block.parameters();
for (i, p) in params.iter().enumerate() {
assert!(!p.is_empty(), "Parameter {i} should have non-zero length");
}
}
#[test]
fn test_encoder_block_single_token() {
let config = TransformerConfig::codebert();
let block = EncoderBlock::new(&config, 3);
let x = Tensor::from_vec(vec![0.2; config.hidden_size], true);
let output = block.forward(&x, 1);
assert_eq!(output.len(), config.hidden_size);
let data = output.data();
let slice = data.as_slice().unwrap();
assert!(slice.iter().all(|v| v.is_finite()));
}
}