use burn::prelude::*;
use crate::model::attention::Attention;
use crate::model::feedforward::FeedForward;
/// A single pre-residual transformer layer: an attention sub-layer followed by
/// a feed-forward sub-layer, each wrapped in an additive skip connection
/// (see `forward`). No normalization is applied here; if the model uses norm,
/// it presumably lives inside `Attention`/`FeedForward` — TODO confirm.
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
    // Attention sub-layer (project-local; exact variant not visible here).
    pub attn: Attention<B>,
    // Position-wise feed-forward sub-layer; may use GEGLU depending on `use_geglu`.
    pub ff: FeedForward<B>,
}
impl<B: Backend> TransformerBlock<B> {
    /// Builds one transformer block on `device`.
    ///
    /// * `dim` – model (embedding) width fed to both sub-layers.
    /// * `heads` / `head_dim` – attention head count and per-head width.
    /// * `mlp_dim` – hidden width of the feed-forward sub-layer.
    /// * `use_geglu` – selects the GEGLU variant of the feed-forward layer.
    pub fn new(
        dim: usize,
        heads: usize,
        head_dim: usize,
        mlp_dim: usize,
        use_geglu: bool,
        device: &B::Device,
    ) -> Self {
        let attn = Attention::new(dim, heads, head_dim, device);
        let ff = FeedForward::new(dim, mlp_dim, use_geglu, device);
        Self { attn, ff }
    }

    /// Applies attention then feed-forward, each with a residual connection:
    /// `y = ff(a) + a` where `a = attn(x) + x`.
    ///
    /// Input/output tensor is rank-3; assumed layout is (batch, seq, dim) —
    /// TODO confirm against `Attention::forward`.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        // First residual: attention output added back onto its input.
        let attended = self.attn.forward(input.clone()) + input;
        // Second residual: feed-forward output added back onto `attended`.
        self.ff.forward(attended.clone()) + attended
    }
}
/// A stack of `TransformerBlock`s applied sequentially (depth-first trunk of
/// the model). See `forward` for plain inference and `forward_with_layers`
/// for per-layer intermediate activations.
#[derive(Module, Debug)]
pub struct TransformerBackbone<B: Backend> {
    // Blocks applied in order; all share the same hyperparameters (see `new`).
    pub layers: Vec<TransformerBlock<B>>,
}
impl<B: Backend> TransformerBackbone<B> {
    /// Builds a backbone of `depth` identical transformer blocks on `device`.
    ///
    /// All hyperparameters (`dim`, `heads`, `head_dim`, `mlp_dim`,
    /// `use_geglu`) are shared by every block.
    pub fn new(
        dim: usize,
        depth: usize,
        heads: usize,
        head_dim: usize,
        mlp_dim: usize,
        use_geglu: bool,
        device: &B::Device,
    ) -> Self {
        let layers = (0..depth)
            .map(|_| TransformerBlock::new(dim, heads, head_dim, mlp_dim, use_geglu, device))
            .collect();
        Self { layers }
    }

    /// Runs the input through every block in order and returns the final
    /// activation. Tensor rank is 3; assumed layout (batch, seq, dim) —
    /// TODO confirm against `TransformerBlock::forward`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        self.layers
            .iter()
            .fold(x, |acc, block| block.forward(acc))
    }

    /// Like `forward`, but also records the activation after every block.
    ///
    /// Returns `layers.len() + 1` tensors: index 0 is the untouched input,
    /// index `i` (for `i >= 1`) is the output of block `i - 1`. Useful for
    /// feature extraction / skip connections into a decoder.
    pub fn forward_with_layers(&self, x: Tensor<B, 3>) -> Vec<Tensor<B, 3>> {
        // Final length is known up front: one entry per block plus the input.
        let mut out = Vec::with_capacity(self.layers.len() + 1);
        out.push(x.clone());
        let mut x = x;
        for block in &self.layers {
            x = block.forward(x);
            // Clone is a cheap handle copy for most backends — the tensor
            // itself stays on-device; TODO confirm for the backends in use.
            out.push(x.clone());
        }
        out
    }
}