use burn::module::Module;
use burn::nn::{
Embedding, EmbeddingConfig, Linear, LinearConfig, RmsNorm, RmsNormConfig, RotaryEncoding,
RotaryEncodingConfig,
};
use burn::tensor::{activation, backend::Backend, Int, Tensor};
/// Number of positions precomputed in each rotary-encoding table.
///
/// Passed as the maximum sequence length to `RotaryEncodingConfig::new` for both the global
/// and the local RoPE tables (2^12 = 4096 positions).
const ROPE_CACHE_LEN: usize = 1 << 12;
/// Hyper-parameters of a Gemma 3 text decoder.
#[derive(Clone, Debug)]
pub struct GemmaConfig {
    /// Number of token ids (rows of the tied embedding / output matrix).
    pub vocab_size: usize,
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Inner width of the gated MLP.
    pub intermediate_size: usize,
    /// Number of decoder layers.
    pub num_layers: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads shared among the query heads.
    pub num_kv_heads: usize,
    /// Size of every attention head.
    pub head_dim: usize,
    /// RoPE base frequency used by full-attention layers.
    pub rope_theta: f32,
    /// RoPE base frequency used by the remaining (local) layers.
    pub rope_local_base_freq: f32,
    /// Sliding-window length, in tokens, intended for the local (non-full) attention layers.
    pub sliding_window: usize,
    /// Every `sliding_window_pattern`-th layer (counting from 1) uses full attention.
    pub sliding_window_pattern: usize,
    /// Attention logits are multiplied by `1 / sqrt(query_pre_attn_scalar)`.
    pub query_pre_attn_scalar: f64,
    /// Epsilon shared by every RMSNorm in the model.
    pub rms_norm_eps: f64,
}

impl GemmaConfig {
    /// Hyper-parameters of the 270M-parameter Gemma 3 checkpoint.
    pub fn gemma_3_270m() -> Self {
        let config = Self {
            vocab_size: 262_144,
            hidden_size: 640,
            intermediate_size: 2_048,
            num_layers: 18,
            num_heads: 4,
            num_kv_heads: 1,
            head_dim: 256,
            rope_theta: 1.0e6,
            rope_local_base_freq: 1.0e4,
            sliding_window: 512,
            sliding_window_pattern: 6,
            query_pre_attn_scalar: 256.0,
            rms_norm_eps: 1.0e-6,
        };
        config
    }

    /// Returns `true` when the 0-based `layer` uses full (global) attention, i.e. when its
    /// 1-based ordinal is a multiple of `sliding_window_pattern`.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if `sliding_window_pattern` is 0.
    pub fn is_full_attention(&self, layer: usize) -> bool {
        let ordinal = layer + 1;
        ordinal % self.sliding_window_pattern == 0
    }

    /// Multiplier applied to raw attention logits: `1 / sqrt(query_pre_attn_scalar)`.
    fn attn_scale(&self) -> f64 {
        self.query_pre_attn_scalar.sqrt().recip()
    }
}
/// Gated feed-forward network: `down_proj(gelu_tanh(gate_proj(x)) * up_proj(x))`.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    /// `hidden_size -> intermediate_size`, passed through GELU to form the gate.
    gate_proj: Linear<B>,
    /// `hidden_size -> intermediate_size`, multiplied element-wise by the gate.
    up_proj: Linear<B>,
    /// `intermediate_size -> hidden_size` output projection.
    down_proj: Linear<B>,
}

impl<B: Backend> Mlp<B> {
    /// Creates bias-free projections sized from `cfg`.
    fn init(cfg: &GemmaConfig, device: &B::Device) -> Self {
        let lin = |i, o| LinearConfig::new(i, o).with_bias(false).init(device);
        // Initialisation order (gate, up, down) is kept stable so seeded runs are reproducible.
        Self {
            gate_proj: lin(cfg.hidden_size, cfg.intermediate_size),
            up_proj: lin(cfg.hidden_size, cfg.intermediate_size),
            down_proj: lin(cfg.intermediate_size, cfg.hidden_size),
        }
    }

    /// Applies the gated MLP to `x`; the last dimension must be `hidden_size`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let gate = Self::gelu_tanh(self.gate_proj.forward(x.clone()));
        let up = self.up_proj.forward(x);
        self.down_proj.forward(gate * up)
    }

    /// Tanh approximation of GELU:
    /// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
    ///
    /// Gemma 3 checkpoints are trained with this approximation (`gelu_pytorch_tanh` in the
    /// Hugging Face config); Burn's `activation::gelu` is the exact erf-based form, which
    /// drifts slightly from the reference activations.
    fn gelu_tanh(x: Tensor<B, 3>) -> Tensor<B, 3> {
        let sqrt_2_over_pi = (2.0 / std::f64::consts::PI).sqrt();
        // Cube via multiplication rather than `powf`, which can yield NaN for negative bases
        // on backends that implement it as exp(log(x) * p).
        let cube = x.clone() * x.clone() * x.clone();
        let inner = (x.clone() + cube.mul_scalar(0.044_715)).mul_scalar(sqrt_2_over_pi);
        x * inner.tanh().add_scalar(1.0).mul_scalar(0.5)
    }
}
/// Grouped-query self-attention with per-head query/key RMSNorm and rotary position encoding.
#[derive(Module, Debug)]
pub struct Attention<B: Backend> {
    /// `hidden_size -> num_heads * head_dim` (queries).
    q_proj: Linear<B>,
    /// `hidden_size -> num_kv_heads * head_dim` (keys).
    k_proj: Linear<B>,
    /// `hidden_size -> num_kv_heads * head_dim` (values).
    v_proj: Linear<B>,
    /// `num_heads * head_dim -> hidden_size` output projection.
    o_proj: Linear<B>,
    /// RMSNorm over `head_dim`, applied to each query head before RoPE.
    q_norm: RmsNorm<B>,
    /// RMSNorm over `head_dim`, applied to each key head before RoPE.
    k_norm: RmsNorm<B>,
}

impl<B: Backend> Attention<B> {
    /// Creates bias-free projections and per-head norms sized from `cfg`.
    ///
    /// # Panics
    ///
    /// Panics if `num_kv_heads` is zero or does not divide `num_heads`: every key/value head
    /// must serve a whole group of query heads.
    fn init(cfg: &GemmaConfig, device: &B::Device) -> Self {
        assert!(
            cfg.num_kv_heads > 0 && cfg.num_heads % cfg.num_kv_heads == 0,
            "num_heads ({}) must be a multiple of a non-zero num_kv_heads ({})",
            cfg.num_heads,
            cfg.num_kv_heads
        );
        let q_out = cfg.num_heads * cfg.head_dim;
        let kv_out = cfg.num_kv_heads * cfg.head_dim;
        let lin = |i, o| LinearConfig::new(i, o).with_bias(false).init(device);
        let norm = || {
            RmsNormConfig::new(cfg.head_dim)
                .with_epsilon(cfg.rms_norm_eps)
                .init(device)
        };
        // Field initialisation order is kept stable so seeded runs are reproducible.
        Self {
            q_proj: lin(cfg.hidden_size, q_out),
            k_proj: lin(cfg.hidden_size, kv_out),
            v_proj: lin(cfg.hidden_size, kv_out),
            o_proj: lin(q_out, cfg.hidden_size),
            q_norm: norm(),
            k_norm: norm(),
        }
    }

    /// Attends `x` (`[batch, seq, hidden_size]`) over itself and returns `[batch, seq, hidden_size]`.
    ///
    /// `mask` is added to the scaled logits (`[batch, num_heads, seq, seq]`) before softmax, so
    /// it is expected to be additive (0 / -inf) and broadcastable to that shape.
    fn forward(
        &self,
        x: Tensor<B, 3>,
        cfg: &GemmaConfig,
        rope: &RotaryEncoding<B>,
        mask: Tensor<B, 4>,
    ) -> Tensor<B, 3> {
        let [batch, seq, _] = x.dims();
        let (h, kv, hd) = (cfg.num_heads, cfg.num_kv_heads, cfg.head_dim);

        let q = self.q_proj.forward(x.clone()).reshape([batch, seq, h, hd]);
        let k = self.k_proj.forward(x.clone()).reshape([batch, seq, kv, hd]);
        let v = self.v_proj.forward(x).reshape([batch, seq, kv, hd]);

        // Normalise each head over head_dim, then move heads ahead of seq:
        // [batch, heads, seq, head_dim].
        let q = self.q_norm.forward(q).swap_dims(1, 2);
        let k = self.k_norm.forward(k).swap_dims(1, 2);
        let v = v.swap_dims(1, 2);

        // RoPE is applied to the kv heads before they are expanded (cheaper, same result).
        let q = rope.forward(q);
        let k = Self::repeat_kv(rope.forward(k), h / kv);
        let v = Self::repeat_kv(v, h / kv);

        let scores = q.matmul(k.swap_dims(2, 3)).mul_scalar(cfg.attn_scale()) + mask;
        let probs = activation::softmax(scores, 3);
        let ctx = probs
            .matmul(v)
            .swap_dims(1, 2)
            .reshape([batch, seq, h * hd]);
        self.o_proj.forward(ctx)
    }

    /// Expands `[batch, kv_heads, seq, head_dim]` to `[batch, kv_heads * n_rep, seq, head_dim]`
    /// so that query head `i` is paired with key/value head `i / n_rep`.
    ///
    /// `repeat_dim` on the head axis would tile the heads (`[k0, k1, k0, k1]`), pairing query
    /// head `i` with kv head `i % kv_heads` instead. That is wrong whenever there is more than
    /// one kv head, so each head is repeated in place through an extra axis.
    fn repeat_kv(x: Tensor<B, 4>, n_rep: usize) -> Tensor<B, 4> {
        if n_rep == 1 {
            return x;
        }
        let [batch, kv_heads, seq, head_dim] = x.dims();
        x.reshape([batch, kv_heads, 1, seq, head_dim])
            .repeat_dim(2, n_rep)
            .reshape([batch, kv_heads * n_rep, seq, head_dim])
    }
}
/// One decoder block in which each sub-layer is wrapped by RMSNorm on both its input and its
/// output before the result is added back to the residual stream.
#[derive(Module, Debug)]
pub struct DecoderLayer<B: Backend> {
    /// Normalises the residual stream before self-attention.
    input_layernorm: RmsNorm<B>,
    /// Self-attention sub-layer.
    self_attn: Attention<B>,
    /// Normalises the attention output before it is added to the residual.
    post_attention_layernorm: RmsNorm<B>,
    /// Normalises the residual stream before the MLP.
    pre_feedforward_layernorm: RmsNorm<B>,
    /// Gated feed-forward sub-layer.
    mlp: Mlp<B>,
    /// Normalises the MLP output before it is added to the residual.
    post_feedforward_layernorm: RmsNorm<B>,
}

impl<B: Backend> DecoderLayer<B> {
    /// Creates one layer's parameters from `cfg`; sub-modules are built in declaration order.
    fn init(cfg: &GemmaConfig, device: &B::Device) -> Self {
        let make_norm = || {
            RmsNormConfig::new(cfg.hidden_size)
                .with_epsilon(cfg.rms_norm_eps)
                .init(device)
        };
        let input_layernorm = make_norm();
        let self_attn = Attention::init(cfg, device);
        let post_attention_layernorm = make_norm();
        let pre_feedforward_layernorm = make_norm();
        let mlp = Mlp::init(cfg, device);
        let post_feedforward_layernorm = make_norm();
        Self {
            input_layernorm,
            self_attn,
            post_attention_layernorm,
            pre_feedforward_layernorm,
            mlp,
            post_feedforward_layernorm,
        }
    }

    /// Runs attention then the MLP, each as a normalised residual branch on `x`.
    fn forward(
        &self,
        x: Tensor<B, 3>,
        cfg: &GemmaConfig,
        rope: &RotaryEncoding<B>,
        mask: Tensor<B, 4>,
    ) -> Tensor<B, 3> {
        let residual = x;

        let attn_in = self.input_layernorm.forward(residual.clone());
        let attn_out = self.self_attn.forward(attn_in, cfg, rope, mask);
        let residual = residual + self.post_attention_layernorm.forward(attn_out);

        let mlp_in = self.pre_feedforward_layernorm.forward(residual.clone());
        let mlp_out = self.mlp.forward(mlp_in);
        residual + self.post_feedforward_layernorm.forward(mlp_out)
    }
}
/// Decoder-only Gemma 3 text model whose output projection reuses the embedding matrix.
///
/// NOTE(review): Burn's `RmsNorm` and `RotaryEncoding` follow Burn's own conventions; when
/// loading external Gemma checkpoints, confirm the loader maps norm weights and q/k
/// projection layouts accordingly.
#[derive(Module, Debug)]
pub struct GemmaModel<B: Backend> {
    /// Token embedding table (`vocab_size x hidden_size`), also used for the logits.
    embed: Embedding<B>,
    /// Decoder layers, applied in order.
    layers: Vec<DecoderLayer<B>>,
    /// Final RMSNorm applied before the output projection.
    norm: RmsNorm<B>,
    /// RoPE table built with `rope_theta`, used by full-attention layers.
    rope_global: RotaryEncoding<B>,
    /// RoPE table built with `rope_local_base_freq`, used by sliding-window layers.
    rope_local: RotaryEncoding<B>,
    #[module(skip)]
    config: GemmaConfig,
}

impl<B: Backend> GemmaModel<B> {
    /// Builds a model with freshly initialised (untrained) parameters on `device`.
    pub fn init(cfg: GemmaConfig, device: &B::Device) -> Self {
        let layers = (0..cfg.num_layers)
            .map(|_| DecoderLayer::init(&cfg, device))
            .collect();
        let rope = |theta| {
            RotaryEncodingConfig::new(ROPE_CACHE_LEN, cfg.head_dim)
                .with_theta(theta)
                .init(device)
        };
        Self {
            embed: EmbeddingConfig::new(cfg.vocab_size, cfg.hidden_size).init(device),
            layers,
            norm: RmsNormConfig::new(cfg.hidden_size)
                .with_epsilon(cfg.rms_norm_eps)
                .init(device),
            rope_global: rope(cfg.rope_theta),
            rope_local: rope(cfg.rope_local_base_freq),
            config: cfg,
        }
    }

    /// Maps token ids `[batch, seq]` to logits `[batch, seq, vocab_size]`.
    ///
    /// Full-attention layers see the whole causal prefix. The other (local) layers only see
    /// the last `sliding_window` positions, including the current one.
    ///
    /// # Panics
    ///
    /// Panics if `seq` exceeds `ROPE_CACHE_LEN`, the number of positions in the RoPE tables.
    pub fn forward(&self, tokens: Tensor<B, 2, Int>) -> Tensor<B, 3> {
        let cfg = &self.config;
        let [_, seq] = tokens.dims();
        assert!(
            seq <= ROPE_CACHE_LEN,
            "sequence length {seq} exceeds the RoPE cache length {ROPE_CACHE_LEN}"
        );
        let device = tokens.device();

        // Token embeddings are scaled by sqrt(hidden_size) before the first layer.
        let scale = (cfg.hidden_size as f64).sqrt();
        let mut x = self.embed.forward(tokens).mul_scalar(scale);

        // Both masks are layer-independent, so they are built once per call.
        let global_mask = Self::attention_mask(seq, None, &device);
        let local_mask = Self::attention_mask(seq, Some(cfg.sliding_window), &device);

        for (i, layer) in self.layers.iter().enumerate() {
            let (rope, mask) = if cfg.is_full_attention(i) {
                (&self.rope_global, &global_mask)
            } else {
                (&self.rope_local, &local_mask)
            };
            x = layer.forward(x, cfg, rope, mask.clone());
        }

        // Logits reuse the (transposed) embedding matrix: [batch, seq, hidden] x [1, hidden, vocab].
        let x = self.norm.forward(x);
        let embed_t = self.embed.weight.val().transpose();
        x.matmul(embed_t.unsqueeze::<3>())
    }

    /// Builds an additive attention mask of shape `[1, 1, seq, seq]`.
    ///
    /// Entry `(q, k)` is 0 when query `q` may attend to key `k`, and `-inf` otherwise. Keys in
    /// the future (`k > q`) are always masked. With `Some(window)`, keys with `q - k >= window`
    /// are masked as well, so each query sees at most `window` positions (itself included).
    /// `window` is expected to be at least 1; a zero window would mask every key.
    fn attention_mask(seq: usize, window: Option<usize>, device: &B::Device) -> Tensor<B, 4> {
        let q_idx = Tensor::<B, 1, Int>::arange(0..seq as i64, device).reshape([seq, 1]);
        let k_idx = Tensor::<B, 1, Int>::arange(0..seq as i64, device).reshape([1, seq]);

        let future = k_idx.clone().greater(q_idx.clone());
        let mut mask = Tensor::<B, 2>::zeros([seq, seq], device).mask_fill(future, f32::NEG_INFINITY);

        if let Some(window) = window {
            // q - k >= window  <=>  k + window <= q
            let too_far = k_idx.add_scalar(window as i64).lower_equal(q_idx);
            mask = mask.mask_fill(too_far, f32::NEG_INFINITY);
        }

        mask.reshape([1, 1, seq, seq])
    }
}