//! Transformer building blocks: a pre-norm attention + feed-forward block and
//! a stack of such blocks finished with a final layer norm.

use crate::attention::{AttentionConfig, MultiHeadAttention, MultiHeadOutput};
use crate::error::Result;
use crate::feed_forward::{FeedForward, FeedForwardConfig};
use crate::normalize::LayerNorm;
use crate::residual;
use crate::tensor::Tensor;

/// A single pre-norm transformer block: multi-head self-attention followed by
/// a position-wise feed-forward network, each applied as
/// `x + sublayer(norm(x))` via [`residual::pre_norm_residual`].
#[derive(Debug, Clone)]
pub struct TransformerBlock {
    /// Multi-head self-attention sub-layer.
    pub attention: MultiHeadAttention,
    /// Position-wise feed-forward sub-layer.
    pub feed_forward: FeedForward,
    /// Layer norm applied to the input of the attention sub-layer.
    pub norm1: LayerNorm,
    /// Layer norm applied to the input of the feed-forward sub-layer.
    pub norm2: LayerNorm,
}

/// Output of a single [`TransformerBlock::forward`] pass.
#[derive(Debug, Clone)]
pub struct BlockOutput {
    /// Hidden states after both sub-layers; same shape as the input.
    pub hidden: Tensor,
    /// One attention-weight tensor per head.
    pub attention_weights: Vec<Tensor>,
}

/// Hyperparameters describing a full transformer model.
#[derive(Debug, Clone)]
pub struct TransformerConfig {
    /// Width of the residual stream (`d_model`).
    pub model_dim: usize,
    /// Number of attention heads per block.
    pub num_heads: usize,
    /// Inner (hidden) width of each feed-forward sub-layer.
    pub ffn_inner_dim: usize,
    /// Number of stacked transformer blocks.
    pub num_layers: usize,
    /// Size of the token vocabulary.
    pub vocab_size: usize,
    /// Maximum sequence length the model accepts.
    pub max_seq_len: usize,
}
impl TransformerConfig {
pub fn new(
model_dim: usize,
num_heads: usize,
num_layers: usize,
vocab_size: usize,
max_seq_len: usize,
) -> Self {
Self {
model_dim,
num_heads,
ffn_inner_dim: 4 * model_dim,
num_layers,
vocab_size,
max_seq_len,
}
}
}
impl TransformerBlock {
pub fn new(model_dim: usize, num_heads: usize, rng: &mut impl rand::Rng) -> Result<Self> {
Self::with_ffn_dim(model_dim, num_heads, 4 * model_dim, rng)
}
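
    /// Builds a block with an explicit feed-forward inner width.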
pub fn with_ffn_dim(
model_dim: usize,
num_heads: usize,
ffn_inner_dim: usize,
rng: &mut impl rand::Rng,
) -> Result<Self> {
let attn_config = AttentionConfig::new(model_dim, num_heads)?;
let ffn_config = FeedForwardConfig::custom(model_dim, ffn_inner_dim);
Ok(Self {
attention: MultiHeadAttention::new(attn_config, rng)?,
feed_forward: FeedForward::new(ffn_config, rng)?,
norm1: LayerNorm::new(model_dim),
norm2: LayerNorm::new(model_dim),
})
}

    /// Runs the block on `x`, optionally with causal (autoregressive) masking,
    /// returning the new hidden states and the per-head attention weights.
    pub fn forward(&self, x: &Tensor, causal: bool) -> Result<BlockOutput> {
        // `pre_norm_residual` returns only the sub-layer output, so the full
        // `MultiHeadOutput` (which carries the head weights) is captured here.
        let mut attn_output_holder: Option<MultiHeadOutput> = None;
let after_attn = residual::pre_norm_residual(x, &self.norm1, |normed| {
let mha_out = self.attention.forward(normed, causal)?;
let output = mha_out.output.clone();
attn_output_holder = Some(mha_out);
Ok(output)
})?;
let after_ffn = residual::pre_norm_residual(&after_attn, &self.norm2, |normed| {
self.feed_forward.forward(normed)
})?;
let attention_weights = attn_output_holder
.map(|o| o.head_weights)
.unwrap_or_default();
Ok(BlockOutput {
hidden: after_ffn,
attention_weights,
})
}
}

/// A sequence of [`TransformerBlock`]s followed by a final layer norm.
#[derive(Debug, Clone)]
pub struct TransformerStack {
    /// The blocks, applied in order.
    pub blocks: Vec<TransformerBlock>,
    /// Final layer norm applied after the last block.
    pub final_norm: LayerNorm,
}

/// Output of a [`TransformerStack::forward`] pass.
#[derive(Debug, Clone)]
pub struct StackOutput {
    /// Hidden states after the final layer norm; same shape as the input.
    pub hidden: Tensor,
    /// Attention weights indexed as `[layer][head]`.
    pub all_attention_weights: Vec<Vec<Tensor>>,
}
impl TransformerStack {
pub fn new(
num_layers: usize,
model_dim: usize,
num_heads: usize,
rng: &mut impl rand::Rng,
) -> Result<Self> {
Self::with_ffn_dim(num_layers, model_dim, num_heads, 4 * model_dim, rng)
}
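
    /// Builds the stack with an explicit feed-forward inner width.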
pub fn with_ffn_dim(
num_layers: usize,
model_dim: usize,
num_heads: usize,
ffn_inner_dim: usize,
rng: &mut impl rand::Rng,
) -> Result<Self> {
let mut blocks = Vec::with_capacity(num_layers);
for _ in 0..num_layers {
blocks.push(TransformerBlock::with_ffn_dim(
model_dim,
num_heads,
ffn_inner_dim,
rng,
)?);
}
Ok(Self {
blocks,
final_norm: LayerNorm::new(model_dim),
})
}

    /// Runs every block in order, then applies the final layer norm,
    /// collecting each layer's per-head attention weights along the way.
    pub fn forward(&self, x: &Tensor, causal: bool) -> Result<StackOutput> {
let mut hidden = x.clone();
let mut all_weights = Vec::with_capacity(self.blocks.len());
for block in &self.blocks {
let out = block.forward(&hidden, causal)?;
hidden = out.hidden;
all_weights.push(out.attention_weights);
}
        // Pre-norm blocks never normalize the residual stream itself, so apply
        // one final layer norm after the last block.
        hidden = self.final_norm.forward(&hidden)?;
Ok(StackOutput {
hidden,
all_attention_weights: all_weights,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_single_block() {
let mut rng = rand::rng();
let block = TransformerBlock::new(8, 2, &mut rng).unwrap();
let x = Tensor::randn(&[4, 8], &mut rng);
let out = block.forward(&x, true).unwrap();
assert_eq!(out.hidden.shape(), &[4, 8]);
assert_eq!(out.attention_weights.len(), 2);
}
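
    // Added sketch: with `causal == false` the block attends bidirectionally;
    // shapes and the per-head weight count should match the causal case above.
    #[test]
    fn test_single_block_non_causal() {
        let mut rng = rand::rng();
        let block = TransformerBlock::new(8, 2, &mut rng).unwrap();
        let x = Tensor::randn(&[4, 8], &mut rng);
        let out = block.forward(&x, false).unwrap();
        assert_eq!(out.hidden.shape(), &[4, 8]);
        assert_eq!(out.attention_weights.len(), 2);
    }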
#[test]
fn test_custom_ffn_dim() {
let mut rng = rand::rng();
let block = TransformerBlock::with_ffn_dim(8, 2, 16, &mut rng).unwrap();
assert_eq!(block.feed_forward.config.inner_dim, 16);
let x = Tensor::randn(&[4, 8], &mut rng);
let out = block.forward(&x, true).unwrap();
assert_eq!(out.hidden.shape(), &[4, 8]);
}
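
    // Added sketch: `TransformerConfig::new` fills in the conventional
    // feed-forward width of `4 * model_dim`.
    #[test]
    fn test_config_default_ffn_dim() {
        let config = TransformerConfig::new(8, 2, 2, 100, 16);
        assert_eq!(config.ffn_inner_dim, 32);
    }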
#[test]
fn test_stack() {
let mut rng = rand::rng();
let stack = TransformerStack::new(2, 8, 2, &mut rng).unwrap();
let x = Tensor::randn(&[3, 8], &mut rng);
let out = stack.forward(&x, true).unwrap();
assert_eq!(out.hidden.shape(), &[3, 8]);
assert_eq!(out.all_attention_weights.len(), 2);
}
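
    // Added sketch: every layer of the stack should report one attention-weight
    // tensor per head, mirroring the single-block assertion above.
    #[test]
    fn test_stack_per_layer_head_weights() {
        let mut rng = rand::rng();
        let stack = TransformerStack::new(2, 8, 2, &mut rng).unwrap();
        let x = Tensor::randn(&[3, 8], &mut rng);
        let out = stack.forward(&x, true).unwrap();
        for layer_weights in &out.all_attention_weights {
            assert_eq!(layer_weights.len(), 2);
        }
    }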
}