use burn::prelude::*;
use burn::nn::Linear;
use crate::model::linear_zeros;
use crate::model::norm::AdaRMSNorm;
use crate::model::conditioner::FourierConditioner;
use crate::model::decoder_block::DecoderBlock;
use crate::model::rope::RotaryEmbedding;
/// Decoder stack: embeds the input, runs it through a sequence of
/// [`DecoderBlock`]s conditioned on a time embedding and projected encoder
/// output, then applies a final adaptive norm and output projection.
#[derive(Module, Debug)]
pub struct DecoderTransformer<B: Backend> {
    /// Input projection (`input_dim -> dim`) applied to `z` at the start of
    /// `forward`.
    pub tok_embeddings: Linear<B>,
    /// Turns the raw `time_t` tensor into the conditioning signal handed to
    /// every block and to the final norm.
    pub t_embedder: FourierConditioner<B>,
    /// Projection (`encoder_dim -> dim`) applied to the encoder output before
    /// it is passed to the blocks.
    pub encoder_proj: Linear<B>,
    /// Decoder blocks, applied in order.
    pub layers: Vec<DecoderBlock<B>>,
    /// Final normalization; receives the time conditioning alongside the
    /// hidden state.
    pub norm: AdaRMSNorm<B>,
    /// Output projection (`dim -> input_dim`) applied after the final norm.
    pub output: Linear<B>,
}
impl<B: Backend> DecoderTransformer<B> {
pub fn new(
input_dim: usize, encoder_dim: usize, dim: usize, t_dim: usize, n_layers: usize, head_dim: usize,
n_heads: usize,
n_kv_heads: usize,
hidden_dim: usize,
norm_eps: f64,
device: &B::Device,
) -> Self {
let layers = (0..n_layers)
.map(|_| DecoderBlock::new(
dim, t_dim, head_dim, n_heads, n_kv_heads, hidden_dim, norm_eps, device,
))
.collect();
Self {
tok_embeddings: linear_zeros(input_dim, dim, true, device),
t_embedder: FourierConditioner::new(t_dim, device),
encoder_proj: linear_zeros(encoder_dim, dim, true, device),
layers,
norm: AdaRMSNorm::new(t_dim, dim, norm_eps, device),
output: linear_zeros(dim, input_dim, false, device),
}
}
pub fn forward(
&self,
z: Tensor<B, 3>,
enc_out: Tensor<B, 3>,
time_t: Tensor<B, 3>,
tok_idx: Tensor<B, 2, Int>,
rope: &RotaryEmbedding<B>,
) -> Tensor<B, 3> {
let mut h = self.tok_embeddings.forward(z); let t = self.t_embedder.forward(time_t); let y = self.encoder_proj.forward(enc_out);
let freqs = rope.build_freqs_4d(tok_idx);
for layer in &self.layers {
h = layer.forward(h, y.clone(), t.clone(), freqs.clone(), freqs.clone());
}
self.output.forward(self.norm.forward(h, t))
}
}