use crate::graph::{Graph, NodeId};
/// Hyperparameters describing a Gemma4 model architecture.
///
/// Constructed via the size presets (e.g. `Gemma4Config::gemma4_1b()`)
/// and consumed by `build_graph` / `weight_names`.
#[derive(Debug, Clone, PartialEq)]
pub struct Gemma4Config {
    /// Number of entries in the token-embedding table / logit width.
    pub vocab_size: usize,
    /// Width of the residual stream.
    pub hidden_size: usize,
    /// Number of transformer blocks.
    pub num_hidden_layers: usize,
    /// Number of query heads.
    pub num_attention_heads: u32,
    /// Number of key/value heads (grouped-query attention when fewer
    /// than `num_attention_heads`).
    pub num_key_value_heads: u32,
    /// Width of the MLP gate/up projections.
    pub intermediate_size: usize,
    /// Epsilon used by every RMSNorm in the graph.
    pub rms_norm_eps: f32,
    /// Base frequency for rotary position embeddings.
    pub rope_theta: f32,
    /// Window size passed to sliding-window attention layers.
    pub sliding_window_size: u32,
    /// Period of global-attention layers: layers at
    /// `global_attn_offset + k * global_attn_period` use full causal
    /// attention, all others use the sliding window.
    pub global_attn_period: usize,
    /// Index of the first global-attention layer.
    pub global_attn_offset: usize,
    /// Whether RMSNorm is applied to the projected Q/K before RoPE.
    pub use_qk_norm: bool,
}
impl Gemma4Config {
    /// Preset for the 1B-parameter variant.
    ///
    /// NOTE(review): `hidden_size` (1856) is not divisible by
    /// `num_attention_heads` (14), so `head_dim()` truncates to 132 and
    /// heads * head_dim = 1848 != hidden_size. Every other preset below
    /// divides evenly — confirm these constants are intentional.
    pub fn gemma4_1b() -> Self {
        Self {
            vocab_size: 262144,
            hidden_size: 1856,
            num_hidden_layers: 26,
            num_attention_heads: 14,
            num_key_value_heads: 7,
            intermediate_size: 7424,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            sliding_window_size: 4096,
            global_attn_period: 4,
            global_attn_offset: 3,
            use_qk_norm: true,
        }
    }
    /// Preset for the 4B-parameter variant (head_dim = 2560 / 20 = 128).
    pub fn gemma4_4b() -> Self {
        Self {
            vocab_size: 262144,
            hidden_size: 2560,
            num_hidden_layers: 34,
            num_attention_heads: 20,
            num_key_value_heads: 10,
            intermediate_size: 10240,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            sliding_window_size: 4096,
            global_attn_period: 4,
            global_attn_offset: 3,
            use_qk_norm: true,
        }
    }
    /// Preset for the 12B-parameter variant (head_dim = 3840 / 16 = 240).
    pub fn gemma4_12b() -> Self {
        Self {
            vocab_size: 262144,
            hidden_size: 3840,
            num_hidden_layers: 48,
            num_attention_heads: 16,
            num_key_value_heads: 8,
            intermediate_size: 15360,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            sliding_window_size: 4096,
            global_attn_period: 4,
            global_attn_offset: 3,
            use_qk_norm: true,
        }
    }
    /// Preset for the 27B-parameter variant (head_dim = 4608 / 32 = 144).
    pub fn gemma4_27b() -> Self {
        Self {
            vocab_size: 262144,
            hidden_size: 4608,
            num_hidden_layers: 62,
            num_attention_heads: 32,
            num_key_value_heads: 16,
            intermediate_size: 18432,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            sliding_window_size: 4096,
            global_attn_period: 4,
            global_attn_offset: 3,
            use_qk_norm: true,
        }
    }
    /// Tiny configuration for unit tests: same structural layout as the
    /// real presets (4 layers keeps one global-attention layer at index 3)
    /// with minimal widths.
    pub fn small_test() -> Self {
        Self {
            vocab_size: 64,
            hidden_size: 32,
            num_hidden_layers: 4,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            intermediate_size: 64,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
            sliding_window_size: 8,
            global_attn_period: 4,
            global_attn_offset: 3,
            use_qk_norm: true,
        }
    }
    /// Per-head width of the attention projections.
    ///
    /// Integer division: silently truncates when `hidden_size` is not a
    /// multiple of `num_attention_heads` (see the NOTE on `gemma4_1b`).
    /// The `as u32` cast also truncates very large `hidden_size` values.
    pub fn head_dim(&self) -> u32 {
        self.hidden_size as u32 / self.num_attention_heads
    }
    /// Total width of the K/V projections
    /// (`num_key_value_heads * head_dim`).
    pub fn kv_dim(&self) -> usize {
        self.num_key_value_heads as usize * self.head_dim() as usize
    }
    /// True for layers that use full causal (global) attention: layer
    /// indices `global_attn_offset + k * global_attn_period` for k >= 0.
    /// All other layers use sliding-window attention.
    fn is_global_layer(&self, layer: usize) -> bool {
        layer >= self.global_attn_offset
            && (layer - self.global_attn_offset).is_multiple_of(self.global_attn_period)
    }
}
/// Builds the forward computation graph for a Gemma4 model over a
/// sequence of `seq_len` token ids and returns the `NodeId` of the
/// final logits (the `lm_head` matmul output).
///
/// Per layer: pre-norm attention (optional qk-norm, then RoPE) with a
/// residual add, followed by a pre-norm SwiGLU MLP with a residual add.
/// Layers for which `config.is_global_layer` holds use full causal
/// attention; all others use sliding-window attention.
///
/// NOTE(review): weight shapes are declared `[in_features, out_features]`
/// to suit `g.matmul(h, w)` — transposed relative to the usual HF
/// row-major layout (consistent with `transposed_weight_names` below);
/// confirm against the checkpoint loader.
pub fn build_graph(g: &mut Graph, config: &Gemma4Config, seq_len: usize) -> NodeId {
    let hidden = config.hidden_size;
    let kv_dim = config.kv_dim();
    let ffn = config.intermediate_size;
    let eps = config.rms_norm_eps;
    let theta = config.rope_theta;
    let head_dim = config.head_dim();
    // Token-id input and the embedding lookup that seeds the residual
    // stream `x`.
    let token_ids = g.input_u32("token_ids", &[seq_len]);
    let embed_weight = g.parameter("model.embed_tokens.weight", &[config.vocab_size, hidden]);
    let mut x = g.embedding(token_ids, embed_weight);
    for i in 0..config.num_hidden_layers {
        let prefix = format!("model.layers.{}", i);
        let is_global = config.is_global_layer(i);
        // --- Attention sub-block (pre-norm) ---
        let ln1_w = g.parameter(&format!("{}.input_layernorm.weight", prefix), &[hidden]);
        let h = g.rms_norm(x, ln1_w, eps);
        let wq = g.parameter(
            &format!("{}.self_attn.q_proj.weight", prefix),
            &[hidden, hidden],
        );
        let wk = g.parameter(
            &format!("{}.self_attn.k_proj.weight", prefix),
            &[hidden, kv_dim],
        );
        let wv = g.parameter(
            &format!("{}.self_attn.v_proj.weight", prefix),
            &[hidden, kv_dim],
        );
        let q = g.matmul(h, wq);
        let k = g.matmul(h, wk);
        let v = g.matmul(h, wv);
        // Optional RMSNorm on the projected Q/K, applied BEFORE RoPE.
        // NOTE(review): the norm weights span the entire projection
        // ([hidden] for Q, [kv_dim] for K), i.e. normalization is over
        // all heads at once rather than per `head_dim` — confirm this
        // matches the reference implementation.
        let (q, k) = if config.use_qk_norm {
            let qn_w = g.parameter(&format!("{}.self_attn.q_norm.weight", prefix), &[hidden]);
            let kn_w = g.parameter(&format!("{}.self_attn.k_norm.weight", prefix), &[kv_dim]);
            (g.rms_norm(q, qn_w, eps), g.rms_norm(k, kn_w, eps))
        } else {
            (q, k)
        };
        let q = g.rope(q, theta, head_dim);
        let k = g.rope(k, theta, head_dim);
        // Periodic global layers attend over the full causal prefix;
        // the rest are restricted to `sliding_window_size` positions.
        let attn = if is_global {
            g.causal_attention(
                q,
                k,
                v,
                config.num_attention_heads,
                config.num_key_value_heads,
                head_dim,
            )
        } else {
            g.sliding_window_attention(
                q,
                k,
                v,
                config.num_attention_heads,
                config.num_key_value_heads,
                head_dim,
                config.sliding_window_size,
            )
        };
        let wo = g.parameter(
            &format!("{}.self_attn.o_proj.weight", prefix),
            &[hidden, hidden],
        );
        let attn_out = g.matmul(attn, wo);
        // Residual connection around attention.
        x = g.add(x, attn_out);
        // --- MLP sub-block (pre-norm SwiGLU) ---
        let ln2_w = g.parameter(
            &format!("{}.post_attention_layernorm.weight", prefix),
            &[hidden],
        );
        let h = g.rms_norm(x, ln2_w, eps);
        let w_gate = g.parameter(&format!("{}.mlp.gate_proj.weight", prefix), &[hidden, ffn]);
        let w_up = g.parameter(&format!("{}.mlp.up_proj.weight", prefix), &[hidden, ffn]);
        let w_down = g.parameter(&format!("{}.mlp.down_proj.weight", prefix), &[ffn, hidden]);
        let gate = g.matmul(h, w_gate);
        let up = g.matmul(h, w_up);
        let ffn_out = g.swiglu(gate, up);
        let ffn_out = g.matmul(ffn_out, w_down);
        // Residual connection around the MLP.
        x = g.add(x, ffn_out);
    }
    // Final norm and projection to vocabulary logits.
    let final_ln_w = g.parameter("model.norm.weight", &[hidden]);
    x = g.rms_norm(x, final_ln_w, eps);
    let lm_head = g.parameter("lm_head.weight", &[hidden, config.vocab_size]);
    g.matmul(x, lm_head)
}
/// Returns the names of every weight tensor for `config`, in order:
/// embedding, per-layer weights, final norm, then the LM head.
///
/// The qk-norm entries are included only when `config.use_qk_norm` is
/// set, mirroring the parameters registered by `build_graph`.
pub fn weight_names(config: &Gemma4Config) -> Vec<String> {
    // 9 per-layer tensors (+2 qk-norm tensors when enabled) plus the
    // 3 top-level tensors; preallocate to avoid repeated regrowth.
    let per_layer = if config.use_qk_norm { 11 } else { 9 };
    let mut names = Vec::with_capacity(3 + config.num_hidden_layers * per_layer);
    names.push("model.embed_tokens.weight".to_string());
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{}", i);
        names.push(format!("{}.input_layernorm.weight", p));
        names.push(format!("{}.self_attn.q_proj.weight", p));
        names.push(format!("{}.self_attn.k_proj.weight", p));
        names.push(format!("{}.self_attn.v_proj.weight", p));
        names.push(format!("{}.self_attn.o_proj.weight", p));
        if config.use_qk_norm {
            names.push(format!("{}.self_attn.q_norm.weight", p));
            names.push(format!("{}.self_attn.k_norm.weight", p));
        }
        names.push(format!("{}.post_attention_layernorm.weight", p));
        names.push(format!("{}.mlp.gate_proj.weight", p));
        names.push(format!("{}.mlp.up_proj.weight", p));
        names.push(format!("{}.mlp.down_proj.weight", p));
    }
    names.push("model.norm.weight".to_string());
    names.push("lm_head.weight".to_string());
    names
}
/// Returns the names of the weight tensors stored transposed in the
/// checkpoint: every per-layer projection matrix plus the LM head
/// (norm and embedding weights are excluded).
pub fn transposed_weight_names(config: &Gemma4Config) -> Vec<String> {
    // Suffixes of the transposed tensors within each layer, in the
    // same order the original emitted them.
    const PER_LAYER: [&str; 7] = [
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    ];
    (0..config.num_hidden_layers)
        .flat_map(|layer| {
            PER_LAYER
                .iter()
                .map(move |suffix| format!("model.layers.{}.{}", layer, suffix))
        })
        .chain(std::iter::once("lm_head.weight".to_string()))
        .collect()
}