use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, LayerNorm, LayerNormConfig};
use burn::tensor::activation::gelu;
use crate::model::attention::Attention;
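
/// Pre-norm Transformer encoder block: LayerNorm -> multi-head self-attention with a
/// residual connection, followed by LayerNorm -> two-layer GELU MLP with a residual
/// connection.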
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    pub norm1: LayerNorm<B>,
    pub attn: Attention<B>,
    pub norm2: LayerNorm<B>,
    pub mlp_fc1: Linear<B>,
    pub mlp_fc2: Linear<B>,
}

impl<B: Backend> TransformerBlock<B> {
    /// Builds a block with embedding width `dim` and `n_heads` attention heads; the MLP
    /// hidden width is `dim * mlp_ratio`, `qkv_bias` is forwarded to the attention
    /// module, and `eps` is the LayerNorm epsilon.
    pub fn new(
        dim: usize,
        n_heads: usize,
        mlp_ratio: f64,
        qkv_bias: bool,
        eps: f64,
        device: &B::Device,
    ) -> Self {
        // Hidden width of the feed-forward MLP (e.g. 4 * dim for mlp_ratio = 4.0).
        let mlp_hidden = (dim as f64 * mlp_ratio) as usize;
        Self {
            norm1: LayerNormConfig::new(dim).with_epsilon(eps).init(device),
            attn: Attention::new(dim, n_heads, qkv_bias, device),
            norm2: LayerNormConfig::new(dim).with_epsilon(eps).init(device),
            mlp_fc1: LinearConfig::new(dim, mlp_hidden).with_bias(true).init(device),
            mlp_fc2: LinearConfig::new(mlp_hidden, dim).with_bias(true).init(device),
        }
    }

    /// Applies the block to a `[batch, seq_len, dim]` tensor and returns a tensor of
    /// the same shape.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // Attention sub-layer: pre-norm, then attention, then a residual connection.
        let h = self.attn.forward(self.norm1.forward(x.clone()));
        let x = x + h;
        // MLP sub-layer: pre-norm, then Linear -> GELU -> Linear, then a residual.
        let h = self
            .mlp_fc2
            .forward(gelu(self.mlp_fc1.forward(self.norm2.forward(x.clone()))));
        x + h
    }
}
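
#[cfg(test)]
mod tests {
    use super::*;
    // Sketch of a shape smoke test. It assumes burn's `ndarray` feature is enabled so
    // that `burn::backend::NdArray` is available; any other backend works the same way.
    use burn::backend::NdArray;

    #[test]
    fn forward_preserves_shape() {
        type B = NdArray;
        let device = Default::default();
        // Hyperparameters chosen only for the test: dim = 64, 4 heads, mlp_ratio = 4.0,
        // QKV bias enabled, LayerNorm epsilon = 1e-6.
        let block = TransformerBlock::<B>::new(64, 4, 4.0, true, 1e-6, &device);

        // Dummy input of shape [batch = 2, seq_len = 16, dim = 64].
        let x = Tensor::<B, 3>::zeros([2, 16, 64], &device);
        let y = block.forward(x);

        // The residual structure keeps the shape unchanged.
        assert_eq!(y.dims(), [2, 16, 64]);
    }
}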