use burn::prelude::*;
use crate::model::norm::OsfLayerNorm;
use crate::model::attention::Attention;
use crate::model::feedforward::FeedForward;
/// A transformer block made of an attention sub-layer followed by a
/// feed-forward sub-layer.
///
/// In `forward`, each sub-layer receives a layer-normalized copy of its input
/// and its output is added back onto the un-normalized input (residual add).
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
/// Normalization applied to the block input before the attention sub-layer.
pub attn_norm: OsfLayerNorm<B>,
/// Attention sub-layer.
pub attn: Attention<B>,
/// Normalization applied to the attention residual sum before the
/// feed-forward sub-layer.
pub ff_norm: OsfLayerNorm<B>,
/// Feed-forward sub-layer.
pub ff: FeedForward<B>,
}
impl<B: Backend> TransformerBlock<B> {
    /// Builds a transformer block whose sub-modules are created on `device`.
    ///
    /// # Arguments
    ///
    /// * `input_dim` - size handed to the attention norm and passed as the
    ///   first dimension argument of [`Attention::new`].
    /// * `output_dim` - second dimension argument of [`Attention::new`]; also
    ///   the size handed to the feed-forward norm and used as the first two
    ///   dimension arguments of [`FeedForward::new`].
    /// * `hidden_dim` - third dimension argument of [`FeedForward::new`].
    /// * `heads`, `dim_head`, `qkv_bias` - forwarded unchanged to
    ///   [`Attention::new`].
    /// * `device` - device passed to every sub-module constructor.
    ///
    /// Both norms are constructed with the constant `1e-5` as their second
    /// argument (presumably the epsilon — confirm against `OsfLayerNorm::new`).
    ///
    /// NOTE(review): `forward` adds the attention output directly onto its
    /// input, so the two shapes must be compatible for `+`; this presumably
    /// requires `input_dim == output_dim` — confirm against `Attention`.
    pub fn new(
        input_dim: usize,
        output_dim: usize,
        hidden_dim: usize,
        heads: usize,
        dim_head: usize,
        qkv_bias: bool,
        device: &B::Device,
    ) -> Self {
        Self {
            attn_norm: OsfLayerNorm::new(input_dim, 1e-5, device),
            attn: Attention::new(input_dim, output_dim, heads, dim_head, qkv_bias, device),
            ff_norm: OsfLayerNorm::new(output_dim, 1e-5, device),
            ff: FeedForward::new(output_dim, output_dim, hidden_dim, device),
        }
    }

    /// Runs the block on `x` and returns a rank-3 tensor.
    ///
    /// Computes `h = x + attn(attn_norm(x))` and then returns
    /// `h + ff(ff_norm(h))`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // The left operand is evaluated first, so `x` can be moved into the
        // norm on its last use; only the residual copy needs a clone.
        let h = x.clone() + self.attn.forward(self.attn_norm.forward(x));
        h.clone() + self.ff.forward(self.ff_norm.forward(h))
    }
}