use burn::prelude::*;
use burn::module::{Param, ParamId};
use burn::nn::Linear;
use crate::model::linear_zeros;
use crate::model::norm::RMSNorm;
use crate::model::encoder_block::EncoderBlock;
use crate::model::rope::RotaryEmbedding;
/// Transformer encoder that pairs a learned register vector with every input token, runs the
/// interleaved sequence through a stack of encoder blocks, and projects the hidden states
/// found at the register positions.
#[derive(Module, Debug)]
pub struct EncoderTransformer<B: Backend> {
/// Linear layer applied to the interleaved register/token sequence at the start of `forward`.
pub tok_embeddings: Linear<B>,
/// Learnable register parameter. `new` creates it with shape `[1, input_dim]`; `forward`
/// broadcasts it across the batch and sequence axes of the input.
pub registers: Param<Tensor<B, 2>>,
/// Encoder blocks, applied in order by `forward`.
pub layers: Vec<EncoderBlock<B>>,
/// Normalisation applied to the extracted register states before the output projection.
pub norm: RMSNorm<B>,
/// Final linear projection applied to the normalised register states.
pub output: Linear<B>,
/// `forward` views the hidden sequence as groups of `downsample_factor + 1` positions and
/// keeps the first position of each group.
/// NOTE(review): `forward` always builds groups of two (one register plus one token), so its
/// reshapes only have matching element counts when this is 1 — confirm whether other values
/// are meant to be supported.
pub downsample_factor: usize,
}
impl<B: Backend> EncoderTransformer<B> {
pub fn new(
input_dim: usize, output_dim: usize, dim: usize, n_layers: usize, head_dim: usize,
n_heads: usize,
n_kv_heads: usize,
hidden_dim: usize,
norm_eps: f64,
downsample_factor: usize, device: &B::Device,
) -> Self {
let layers = (0..n_layers)
.map(|_| EncoderBlock::new(
dim, head_dim, n_heads, n_kv_heads, hidden_dim, norm_eps, device,
))
.collect();
Self {
tok_embeddings: linear_zeros(input_dim, dim, true, device),
registers: Param::initialized(
ParamId::new(),
Tensor::zeros([1, input_dim], device),
),
layers,
norm: RMSNorm::new(dim, norm_eps, device),
output: linear_zeros(dim, output_dim, false, device),
downsample_factor,
}
}
pub fn forward(
&self,
token_values: Tensor<B, 3>,
tok_idx: Tensor<B, 2, Int>,
rope: &RotaryEmbedding<B>,
) -> Tensor<B, 3> {
let [b, s, d] = token_values.dims();
let df = self.downsample_factor;
let regs = self.registers
.val() .unsqueeze_dim::<3>(0) .expand([b, s, d]);
let interleaved = Tensor::stack::<4>(vec![regs, token_values], 2)
.reshape([b, s * (df + 1), d]);
let mut h = self.tok_embeddings.forward(interleaved);
let tok_idx_2x = repeat_interleave_rows(tok_idx, 2);
let freqs = rope.build_freqs_4d(tok_idx_2x);
for layer in &self.layers {
h = layer.forward(h, freqs.clone());
}
let hdim = h.dims()[2];
let registers = h
.reshape([b, s, df + 1, hdim]) .narrow(2, 0, 1) .reshape([b, s, hdim]);
self.output.forward(self.norm.forward(registers)) }
}
/// Repeats every row of a 2-D integer tensor `repeats` times, keeping the copies adjacent.
///
/// For an input of shape `[rows, cols]` the result has shape `[rows * repeats, cols]`, and
/// output row `i * repeats + k` (for `k < repeats`) is a copy of input row `i`.
fn repeat_interleave_rows<B: Backend>(
    t: Tensor<B, 2, Int>,
    repeats: usize,
) -> Tensor<B, 2, Int> {
    let [rows, cols] = t.dims();

    // Insert a singleton axis right after the row axis so each row is broadcast on its own.
    let with_copy_axis: Tensor<B, 3, Int> = t.unsqueeze_dim(1);
    let broadcast = with_copy_axis.expand([rows, repeats, cols]);

    // Merging the row and copy axes places all copies of a row next to each other.
    broadcast.reshape([rows * repeats, cols])
}