use crate::graph::{Graph, NodeId};
/// Configuration for a SmolLM2-style (LLaMA-architecture) transformer.
///
/// Field names mirror the Hugging Face `config.json` keys so checkpoint
/// metadata can be mapped over directly.
#[derive(Debug, Clone, PartialEq)]
pub struct SmolLM2Config {
    /// Token vocabulary size (rows of the embedding table).
    pub vocab_size: usize,
    /// Model / residual-stream dimension.
    pub hidden_size: usize,
    /// Number of transformer blocks.
    pub num_hidden_layers: usize,
    /// Number of query attention heads.
    pub num_attention_heads: u32,
    /// Number of key/value heads (grouped-query attention when fewer than
    /// `num_attention_heads`).
    pub num_key_value_heads: u32,
    /// Hidden dimension of the SwiGLU MLP.
    pub intermediate_size: usize,
    /// Epsilon used inside RMSNorm for numerical stability.
    pub rms_norm_eps: f32,
    /// Base frequency for rotary position embeddings (RoPE).
    pub rope_theta: f32,
    /// If true, the output projection reuses the embedding matrix.
    pub tie_word_embeddings: bool,
}

impl SmolLM2Config {
    /// Configuration matching the SmolLM2-135M checkpoint.
    pub fn smollm2_135m() -> Self {
        Self {
            vocab_size: 49152,
            hidden_size: 576,
            num_hidden_layers: 30,
            num_attention_heads: 9,
            num_key_value_heads: 3,
            intermediate_size: 1536,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            tie_word_embeddings: true,
        }
    }

    /// Tiny configuration for fast unit tests.
    pub fn small_test() -> Self {
        Self {
            vocab_size: 64,
            hidden_size: 32,
            num_hidden_layers: 2,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            intermediate_size: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            tie_word_embeddings: true,
        }
    }

    /// Mid-sized configuration for heavier tests.
    pub fn medium_test() -> Self {
        Self {
            vocab_size: 64,
            hidden_size: 128,
            num_hidden_layers: 8,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            intermediate_size: 256,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            tie_word_embeddings: true,
        }
    }

    /// Dimension of a single attention head.
    ///
    /// # Panics
    /// In debug builds, panics if `hidden_size` is not divisible by
    /// `num_attention_heads` (the integer division would silently truncate).
    pub fn head_dim(&self) -> u32 {
        debug_assert_eq!(
            self.hidden_size as u32 % self.num_attention_heads,
            0,
            "hidden_size must be divisible by num_attention_heads"
        );
        self.hidden_size as u32 / self.num_attention_heads
    }

    /// Total width of the key/value projections
    /// (`num_key_value_heads * head_dim`).
    pub fn kv_dim(&self) -> usize {
        self.num_key_value_heads as usize * self.head_dim() as usize
    }
}
/// Builds the forward (non-cached) compute graph over a full sequence.
///
/// Registers a u32 `token_ids` input of length `seq_len` and all model
/// parameters (named to match Hugging Face checkpoint keys), then wires the
/// embedding lookup, `num_hidden_layers` pre-norm transformer blocks
/// (RMSNorm -> RoPE -> causal attention -> residual add, then RMSNorm ->
/// SwiGLU MLP -> residual add), a final RMSNorm, and the output projection.
/// Returns the `NodeId` of the logits node.
///
/// Projection weights are declared `[in_features, out_features]` and consumed
/// via `g.matmul(h, w)` — i.e. transposed relative to the usual HF
/// `[out, in]` layout; see `transposed_weight_names`.
pub fn build_graph(g: &mut Graph, config: &SmolLM2Config, seq_len: usize) -> NodeId {
let hidden = config.hidden_size;
let kv_dim = config.kv_dim();
let ffn = config.intermediate_size;
let eps = config.rms_norm_eps;
let theta = config.rope_theta;
let token_ids = g.input_u32("token_ids", &[seq_len]);
let embed_weight = g.parameter("model.embed_tokens.weight", &[config.vocab_size, hidden]);
let mut x = g.embedding(token_ids, embed_weight);
for i in 0..config.num_hidden_layers {
let prefix = format!("model.layers.{}", i);
// Pre-attention RMSNorm; the residual stream `x` is kept for the skip add.
let ln1_w = g.parameter(&format!("{}.input_layernorm.weight", prefix), &[hidden]);
let h = g.rms_norm(x, ln1_w, eps);
let wq = g.parameter(
&format!("{}.self_attn.q_proj.weight", prefix),
&[hidden, hidden],
);
// K/V project down to kv_dim (< hidden under grouped-query attention).
let wk = g.parameter(
&format!("{}.self_attn.k_proj.weight", prefix),
&[hidden, kv_dim],
);
let wv = g.parameter(
&format!("{}.self_attn.v_proj.weight", prefix),
&[hidden, kv_dim],
);
let q = g.matmul(h, wq); let k = g.matmul(h, wk); let v = g.matmul(h, wv);
// Rotary position embeddings are applied to queries and keys only.
let q = g.rope(q, theta, config.head_dim());
let k = g.rope(k, theta, config.head_dim());
let attn = g.causal_attention(
q,
k,
v,
config.num_attention_heads,
config.num_key_value_heads,
config.head_dim(),
);
let wo = g.parameter(
&format!("{}.self_attn.o_proj.weight", prefix),
&[hidden, hidden],
);
let attn_out = g.matmul(attn, wo);
// First residual connection.
x = g.add(x, attn_out);
// Pre-MLP RMSNorm.
let ln2_w = g.parameter(
&format!("{}.post_attention_layernorm.weight", prefix),
&[hidden],
);
let h = g.rms_norm(x, ln2_w, eps);
let w_gate = g.parameter(&format!("{}.mlp.gate_proj.weight", prefix), &[hidden, ffn]);
let w_up = g.parameter(&format!("{}.mlp.up_proj.weight", prefix), &[hidden, ffn]);
let w_down = g.parameter(&format!("{}.mlp.down_proj.weight", prefix), &[ffn, hidden]);
// SwiGLU MLP: gated combination of `gate` and `up` (see Graph::swiglu),
// then down-projection back to the residual width.
let gate = g.matmul(h, w_gate); let up = g.matmul(h, w_up); let ffn_out = g.swiglu(gate, up);
let ffn_out = g.matmul(ffn_out, w_down);
// Second residual connection.
x = g.add(x, ffn_out);
}
let final_ln_w = g.parameter("model.norm.weight", &[hidden]);
x = g.rms_norm(x, final_ln_w, eps);
if config.tie_word_embeddings {
// Tied head: reuse the embedding matrix via a transposed-B matmul.
g.matmul_bt(x, embed_weight)
} else {
let lm_head = g.parameter("lm_head.weight", &[hidden, config.vocab_size]);
g.matmul(x, lm_head) }
}
/// Constructs a fresh graph computing the cross-entropy training loss.
///
/// Wraps `build_graph` with a `labels` input of shape
/// `[seq_len, vocab_size]` and registers the loss node as the graph's sole
/// output.
pub fn build_training_graph(config: &SmolLM2Config, seq_len: usize) -> Graph {
    let mut graph = Graph::new();
    let logits = build_graph(&mut graph, config, seq_len);
    let targets = graph.input("labels", &[seq_len, config.vocab_size]);
    let loss = graph.cross_entropy_loss(logits, targets);
    graph.set_outputs(vec![loss]);
    graph
}
/// Builds the prompt-prefill graph: the same forward pass as `build_graph`,
/// but additionally returns each layer's key nodes (post-RoPE) and value
/// nodes (no RoPE) so the caller can capture them to seed a KV cache before
/// switching to the decode graph.
///
/// Returns `(logits, k_outputs, v_outputs)`; the two Vecs are indexed by
/// layer, length `num_hidden_layers`.
pub fn build_prefill_graph(
g: &mut Graph,
config: &SmolLM2Config,
seq_len: usize,
) -> (NodeId, Vec<NodeId>, Vec<NodeId>) {
let hidden = config.hidden_size;
let kv_dim = config.kv_dim();
let ffn = config.intermediate_size;
let eps = config.rms_norm_eps;
let theta = config.rope_theta;
let token_ids = g.input_u32("token_ids", &[seq_len]);
let embed_weight = g.parameter("model.embed_tokens.weight", &[config.vocab_size, hidden]);
let mut x = g.embedding(token_ids, embed_weight);
// Per-layer K/V nodes to be exposed for cache seeding.
let mut k_outputs = Vec::new();
let mut v_outputs = Vec::new();
for i in 0..config.num_hidden_layers {
let prefix = format!("model.layers.{}", i);
// Pre-attention RMSNorm.
let ln1_w = g.parameter(&format!("{}.input_layernorm.weight", prefix), &[hidden]);
let h = g.rms_norm(x, ln1_w, eps);
let wq = g.parameter(
&format!("{}.self_attn.q_proj.weight", prefix),
&[hidden, hidden],
);
let wk = g.parameter(
&format!("{}.self_attn.k_proj.weight", prefix),
&[hidden, kv_dim],
);
let wv = g.parameter(
&format!("{}.self_attn.v_proj.weight", prefix),
&[hidden, kv_dim],
);
let q = g.matmul(h, wq);
let k = g.matmul(h, wk);
let v = g.matmul(h, wv);
let q = g.rope(q, theta, config.head_dim());
let k = g.rope(k, theta, config.head_dim());
// Expose the rotated keys and raw values — matching what cached
// attention consumes during decode.
k_outputs.push(k);
v_outputs.push(v);
let attn = g.causal_attention(
q,
k,
v,
config.num_attention_heads,
config.num_key_value_heads,
config.head_dim(),
);
let wo = g.parameter(
&format!("{}.self_attn.o_proj.weight", prefix),
&[hidden, hidden],
);
let attn_out = g.matmul(attn, wo);
// Attention residual.
x = g.add(x, attn_out);
let ln2_w = g.parameter(
&format!("{}.post_attention_layernorm.weight", prefix),
&[hidden],
);
let h = g.rms_norm(x, ln2_w, eps);
let w_gate = g.parameter(&format!("{}.mlp.gate_proj.weight", prefix), &[hidden, ffn]);
let w_up = g.parameter(&format!("{}.mlp.up_proj.weight", prefix), &[hidden, ffn]);
let w_down = g.parameter(&format!("{}.mlp.down_proj.weight", prefix), &[ffn, hidden]);
// SwiGLU MLP followed by the second residual.
let gate = g.matmul(h, w_gate);
let up = g.matmul(h, w_up);
let ffn_out = g.swiglu(gate, up);
let ffn_out = g.matmul(ffn_out, w_down);
x = g.add(x, ffn_out);
}
let final_ln_w = g.parameter("model.norm.weight", &[hidden]);
x = g.rms_norm(x, final_ln_w, eps);
let logits = if config.tie_word_embeddings {
g.matmul_bt(x, embed_weight)
} else {
let lm_head = g.parameter("lm_head.weight", &[hidden, config.vocab_size]);
g.matmul(x, lm_head)
};
(logits, k_outputs, v_outputs)
}
/// Builds the single-token decode graph with a fixed-size KV cache.
///
/// Inputs: a 1-element u32 `token_ids` (the new token) and a 1-element u32
/// `kv_pos` (the current position, used both as the RoPE offset and as the
/// cache write index). Per layer, K/V caches are declared as parameters of
/// shape `[max_seq_len, kv_dim]` so the caller can bind persistent buffers.
///
/// Returns `(logits, k_cache_params, v_cache_params)` — the logits node plus
/// the per-layer cache parameter nodes, each Vec of length
/// `num_hidden_layers`.
pub fn build_decode_graph(
g: &mut Graph,
config: &SmolLM2Config,
max_seq_len: usize,
) -> (NodeId, Vec<NodeId>, Vec<NodeId>) {
let hidden = config.hidden_size;
let kv_dim = config.kv_dim();
let ffn = config.intermediate_size;
let eps = config.rms_norm_eps;
let theta = config.rope_theta;
let token_ids = g.input_u32("token_ids", &[1]);
let kv_pos = g.input_u32("kv_pos", &[1]);
let embed_weight = g.parameter("model.embed_tokens.weight", &[config.vocab_size, hidden]);
let mut x = g.embedding(token_ids, embed_weight);
let mut k_cache_params = Vec::new();
let mut v_cache_params = Vec::new();
for i in 0..config.num_hidden_layers {
let prefix = format!("model.layers.{}", i);
// Pre-attention RMSNorm.
let ln1_w = g.parameter(&format!("{}.input_layernorm.weight", prefix), &[hidden]);
let h = g.rms_norm(x, ln1_w, eps);
let wq = g.parameter(
&format!("{}.self_attn.q_proj.weight", prefix),
&[hidden, hidden],
);
let wk = g.parameter(
&format!("{}.self_attn.k_proj.weight", prefix),
&[hidden, kv_dim],
);
let wv = g.parameter(
&format!("{}.self_attn.v_proj.weight", prefix),
&[hidden, kv_dim],
);
let q = g.matmul(h, wq); let k = g.matmul(h, wk); let v = g.matmul(h, wv);
// RoPE with a runtime offset: the single-row q/k are rotated as if at
// absolute position `kv_pos`.
let q = g.rope_dynamic_offset(q, theta, kv_pos, config.head_dim());
let k = g.rope_dynamic_offset(k, theta, kv_pos, config.head_dim());
let k_cache = g.parameter(&format!("kv_cache.layer.{}.k", i), &[max_seq_len, kv_dim]);
let v_cache = g.parameter(&format!("kv_cache.layer.{}.v", i), &[max_seq_len, kv_dim]);
k_cache_params.push(k_cache);
v_cache_params.push(v_cache);
// Write the new k/v row into the cache at kv_pos. The returned nodes are
// unused; presumably cache_write updates the cache parameter in place
// (or the graph sequences the write before cached_attention) — confirm
// against Graph::cache_write's semantics.
let _k_updated = g.cache_write(k, k_cache, kv_pos);
let _v_updated = g.cache_write(v, v_cache, kv_pos);
// Attention over cache rows [0, kv_pos] for the single query token.
let attn = g.cached_attention(
q,
k_cache,
v_cache,
kv_pos,
config.num_attention_heads,
config.num_key_value_heads,
config.head_dim(),
);
let wo = g.parameter(
&format!("{}.self_attn.o_proj.weight", prefix),
&[hidden, hidden],
);
let attn_out = g.matmul(attn, wo);
// Attention residual.
x = g.add(x, attn_out);
let ln2_w = g.parameter(
&format!("{}.post_attention_layernorm.weight", prefix),
&[hidden],
);
let h = g.rms_norm(x, ln2_w, eps);
let w_gate = g.parameter(&format!("{}.mlp.gate_proj.weight", prefix), &[hidden, ffn]);
let w_up = g.parameter(&format!("{}.mlp.up_proj.weight", prefix), &[hidden, ffn]);
let w_down = g.parameter(&format!("{}.mlp.down_proj.weight", prefix), &[ffn, hidden]);
// SwiGLU MLP followed by the second residual.
let gate = g.matmul(h, w_gate);
let up = g.matmul(h, w_up);
let ffn_out = g.swiglu(gate, up);
let ffn_out = g.matmul(ffn_out, w_down);
x = g.add(x, ffn_out);
}
let final_ln_w = g.parameter("model.norm.weight", &[hidden]);
x = g.rms_norm(x, final_ln_w, eps);
let logits = if config.tie_word_embeddings {
// Tied head: reuse the embedding matrix via a transposed-B matmul.
g.matmul_bt(x, embed_weight)
} else {
let lm_head = g.parameter("lm_head.weight", &[hidden, config.vocab_size]);
g.matmul(x, lm_head) };
(logits, k_cache_params, v_cache_params)
}
/// Returns every checkpoint parameter name the model graphs declare, in a
/// stable order: embedding, then the nine tensors of each layer, then the
/// final norm, then (untied models only) `lm_head.weight`.
pub fn weight_names(config: &SmolLM2Config) -> Vec<String> {
    // Exact count is known up front: embedding + 9 per layer + final norm,
    // plus at most one lm_head — preallocate to avoid regrowth.
    let mut names = Vec::with_capacity(9 * config.num_hidden_layers + 3);
    names.push("model.embed_tokens.weight".to_string());
    for i in 0..config.num_hidden_layers {
        let p = format!("model.layers.{}", i);
        names.push(format!("{}.input_layernorm.weight", p));
        names.push(format!("{}.self_attn.q_proj.weight", p));
        names.push(format!("{}.self_attn.k_proj.weight", p));
        names.push(format!("{}.self_attn.v_proj.weight", p));
        names.push(format!("{}.self_attn.o_proj.weight", p));
        names.push(format!("{}.post_attention_layernorm.weight", p));
        names.push(format!("{}.mlp.gate_proj.weight", p));
        names.push(format!("{}.mlp.up_proj.weight", p));
        names.push(format!("{}.mlp.down_proj.weight", p));
    }
    names.push("model.norm.weight".to_string());
    if !config.tie_word_embeddings {
        names.push("lm_head.weight".to_string());
    }
    names
}
/// Names of the projection matrices that the graphs declare in
/// `[in_features, out_features]` layout (transposed relative to the usual
/// Hugging Face `[out, in]` storage), plus `lm_head.weight` when the output
/// projection is untied. Ordering matches `weight_names`.
pub fn transposed_weight_names(config: &SmolLM2Config) -> Vec<String> {
    // The seven linear projections inside every transformer block.
    const PROJECTIONS: [&str; 7] = [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    ];
    let mut names: Vec<String> = (0..config.num_hidden_layers)
        .flat_map(|layer| {
            PROJECTIONS
                .iter()
                .map(move |proj| format!("model.layers.{}.{}.weight", layer, proj))
        })
        .collect();
    if !config.tie_word_embeddings {
        names.push("lm_head.weight".to_string());
    }
    names
}