use burn::prelude::*;
use burn::nn::{LayerNorm, LayerNormConfig};

use crate::config::ModelConfig;
use super::attention::Attention;
use super::mlp::Mlp;
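
/// One pre-norm Transformer encoder layer: multi-head self-attention
/// followed by a position-wise MLP, each wrapped in a residual connection.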
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    pub norm1: LayerNorm<B>,
    pub attn: Attention<B>,
    pub norm2: LayerNorm<B>,
    pub mlp: Mlp<B>,
}

impl<B: Backend> TransformerEncoderLayer<B> {
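    /// Initializes every sub-module on `device` from the model configuration.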
    pub fn new(cfg: &ModelConfig, device: &B::Device) -> Self {
        let d = cfg.feature_size;
        Self {
            norm1: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            attn: Attention::new(d, cfg.num_heads, device),
            norm2: LayerNormConfig::new(d)
                .with_epsilon(cfg.layer_norm_eps)
                .init(device),
            mlp: Mlp::new(d, cfg.dim_feedforward, device),
        }
    }
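
    /// Applies the layer to a rank-3 input (typically
    /// `[batch, seq_len, feature_size]`) and returns a tensor of the same
    /// shape. LayerNorm runs before each sub-layer (pre-norm), and each
    /// sub-layer's output is added back onto the residual stream.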
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
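        // Attention sub-block: h = x + Attn(LN(x)).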
        let h = x.clone() + self.attn.forward(self.norm1.forward(x));
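        // Feed-forward sub-block: out = h + Mlp(LN(h)).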
        h.clone() + self.mlp.forward(self.norm2.forward(h))
    }
}