use crate::graph::{Graph, NodeId};
/// Configuration for the vision encoder tower.
#[derive(Debug, Clone, PartialEq)]
pub struct VisionConfig {
    /// Input image side length in pixels (the patch grid assumes a square image).
    pub image_size: usize,
    /// Side length of each square patch in pixels.
    pub patch_size: usize,
    /// Encoder embedding width.
    pub hidden_size: usize,
    /// Number of attention heads (query heads == key/value heads here).
    pub num_attention_heads: u32,
    /// Number of transformer encoder layers.
    pub num_hidden_layers: usize,
    /// MLP inner width.
    pub intermediate_size: usize,
    /// Epsilon used by the encoder's layer norms.
    pub layer_norm_eps: f32,
}

impl VisionConfig {
    /// Number of patches per image: `(image_size / patch_size)^2`.
    ///
    /// Uses truncating integer division, so `image_size` is expected to be a
    /// multiple of `patch_size`; edge pixels beyond the grid are dropped.
    pub fn num_patches(&self) -> usize {
        let per_side = self.image_size / self.patch_size;
        per_side * per_side
    }

    /// Flattened length of one RGB patch: `3 * patch_size * patch_size`.
    pub fn patch_dim(&self) -> usize {
        3 * self.patch_size * self.patch_size
    }

    /// Per-head width of the attention projections.
    ///
    /// Debug builds assert that `hidden_size` divides evenly among the heads;
    /// a remainder would silently shrink the effective attention width.
    /// (The `as u32` cast assumes `hidden_size` fits in 32 bits, which holds
    /// for any realistic model width.)
    pub fn head_dim(&self) -> u32 {
        let hidden = self.hidden_size as u32;
        debug_assert_eq!(
            hidden % self.num_attention_heads,
            0,
            "hidden_size must be divisible by num_attention_heads"
        );
        hidden / self.num_attention_heads
    }
}
/// Configuration for the causal text decoder.
#[derive(Debug, Clone, PartialEq)]
pub struct TextConfig {
    /// Vocabulary size (rows of the token embedding / columns of the LM head).
    pub vocab_size: usize,
    /// Decoder embedding width.
    pub hidden_size: usize,
    /// Number of decoder layers.
    pub num_hidden_layers: usize,
    /// Number of query attention heads.
    pub num_attention_heads: u32,
    /// Number of key/value heads (grouped-query attention when smaller than
    /// `num_attention_heads`).
    pub num_key_value_heads: u32,
    /// MLP inner width.
    pub intermediate_size: usize,
    /// Epsilon used by the decoder's RMS norms.
    pub rms_norm_eps: f32,
    /// Base frequency for rotary position embeddings.
    pub rope_theta: f32,
}

impl TextConfig {
    /// Per-head width of the attention projections.
    ///
    /// Debug builds assert that `hidden_size` divides evenly among the query
    /// heads; a remainder would silently shrink the effective attention width.
    pub fn head_dim(&self) -> u32 {
        let hidden = self.hidden_size as u32;
        debug_assert_eq!(
            hidden % self.num_attention_heads,
            0,
            "hidden_size must be divisible by num_attention_heads"
        );
        hidden / self.num_attention_heads
    }

    /// Total width of the key/value projections:
    /// `num_key_value_heads * head_dim`.
    pub fn kv_dim(&self) -> usize {
        self.num_key_value_heads as usize * self.head_dim() as usize
    }
}
/// Full SmolVLM2 model configuration: vision tower, text decoder, and the
/// pixel-shuffle scale factor used by the modality connector.
pub struct SmolVLM2Config {
    pub vision: VisionConfig,
    pub text: TextConfig,
    pub scale_factor: usize,
}

impl SmolVLM2Config {
    /// Preset for the SmolVLM2-500M variant.
    ///
    /// NOTE(review): values look like they mirror the published checkpoint's
    /// config.json — verify against the actual checkpoint before relying on
    /// them.
    pub fn smolvlm2_500m() -> Self {
        let vision = VisionConfig {
            image_size: 512,
            patch_size: 16,
            hidden_size: 768,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            intermediate_size: 3072,
            layer_norm_eps: 1e-6,
        };
        let text = TextConfig {
            vocab_size: 49280,
            hidden_size: 960,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            num_key_value_heads: 5,
            intermediate_size: 2560,
            rms_norm_eps: 1e-5,
            rope_theta: 100000.0,
        };
        Self {
            vision,
            text,
            scale_factor: 4,
        }
    }

    /// Vision tokens fed to the decoder after pixel shuffling:
    /// `num_patches / scale_factor^2`.
    pub fn num_vision_tokens(&self) -> usize {
        let shuffle_area = self.scale_factor * self.scale_factor;
        self.vision.num_patches() / shuffle_area
    }

    /// Feature width entering the connector projection, i.e. the width of one
    /// token after `scale_factor^2` patch embeddings are shuffled together:
    /// `vision.hidden_size * scale_factor^2`.
    pub fn connector_input_dim(&self) -> usize {
        let shuffle_area = self.scale_factor * self.scale_factor;
        self.vision.hidden_size * shuffle_area
    }
}
/// Builds the vision transformer encoder subgraph.
///
/// Declares an input node "image_patches" of shape
/// `[num_patches, patch_dim]` and returns the node id of the
/// post-layernorm encoder output (`[num_patches, hidden_size]`).
pub fn build_vision_encoder(g: &mut Graph, config: &VisionConfig, num_patches: usize) -> NodeId {
    let hidden = config.hidden_size;
    let eps = config.layer_norm_eps;
    let heads = config.num_attention_heads;
    let head_dim = config.head_dim();

    // Patch embedding: linear projection of flattened patches, plus a
    // learned absolute position embedding.
    let patches = g.input("image_patches", &[num_patches, config.patch_dim()]);
    let embed_w = g.parameter(
        "model.vision_model.embeddings.patch_embedding.weight",
        &[config.patch_dim(), hidden],
    );
    let embed_b = g.parameter(
        "model.vision_model.embeddings.patch_embedding.bias",
        &[hidden],
    );
    let projected = g.matmul(patches, embed_w);
    let mut stream = g.bias_add(projected, embed_b);
    let pos = g.parameter(
        "model.vision_model.embeddings.position_embedding.weight",
        &[num_patches, hidden],
    );
    stream = g.add(stream, pos);

    for layer in 0..config.num_hidden_layers {
        let prefix = format!("model.vision_model.encoder.layers.{}", layer);
        let name = |suffix: &str| format!("{}.{}", prefix, suffix);

        // Pre-norm self-attention sub-block with residual connection.
        let ln1_w = g.parameter(&name("layer_norm1.weight"), &[hidden]);
        let ln1_b = g.parameter(&name("layer_norm1.bias"), &[hidden]);
        let normed = g.layer_norm(stream, ln1_w, ln1_b, eps);
        let wq = g.parameter(&name("self_attn.q_proj.weight"), &[hidden, hidden]);
        let bq = g.parameter(&name("self_attn.q_proj.bias"), &[hidden]);
        let wk = g.parameter(&name("self_attn.k_proj.weight"), &[hidden, hidden]);
        let bk = g.parameter(&name("self_attn.k_proj.bias"), &[hidden]);
        let wv = g.parameter(&name("self_attn.v_proj.weight"), &[hidden, hidden]);
        let bv = g.parameter(&name("self_attn.v_proj.bias"), &[hidden]);
        let q = g.matmul(normed, wq);
        let q = g.bias_add(q, bq);
        let k = g.matmul(normed, wk);
        let k = g.bias_add(k, bk);
        let v = g.matmul(normed, wv);
        let v = g.bias_add(v, bv);
        // Non-causal (full) attention; query and key/value head counts
        // are equal — no grouped-query attention in the vision tower.
        let attn = g.full_attention(q, k, v, heads, heads, head_dim);
        let wo = g.parameter(&name("self_attn.out_proj.weight"), &[hidden, hidden]);
        let bo = g.parameter(&name("self_attn.out_proj.bias"), &[hidden]);
        let attn_proj = g.matmul(attn, wo);
        let attn_proj = g.bias_add(attn_proj, bo);
        stream = g.add(stream, attn_proj);

        // Pre-norm GELU MLP sub-block with residual connection.
        let ln2_w = g.parameter(&name("layer_norm2.weight"), &[hidden]);
        let ln2_b = g.parameter(&name("layer_norm2.bias"), &[hidden]);
        let normed = g.layer_norm(stream, ln2_w, ln2_b, eps);
        let fc1_w = g.parameter(&name("mlp.fc1.weight"), &[hidden, config.intermediate_size]);
        let fc1_b = g.parameter(&name("mlp.fc1.bias"), &[config.intermediate_size]);
        let fc2_w = g.parameter(&name("mlp.fc2.weight"), &[config.intermediate_size, hidden]);
        let fc2_b = g.parameter(&name("mlp.fc2.bias"), &[hidden]);
        let up = g.matmul(normed, fc1_w);
        let up = g.bias_add(up, fc1_b);
        let act = g.gelu(up);
        let down = g.matmul(act, fc2_w);
        let down = g.bias_add(down, fc2_b);
        stream = g.add(stream, down);
    }

    // Final layer norm over the encoder output.
    let post_w = g.parameter("model.vision_model.post_layernorm.weight", &[hidden]);
    let post_b = g.parameter("model.vision_model.post_layernorm.bias", &[hidden]);
    g.layer_norm(stream, post_w, post_b, eps)
}
/// Builds the causal text decoder subgraph on top of the embedding node `x`
/// and returns the logits node (`[seq_len, vocab_size]`).
///
/// `_seq_len` is currently unused but kept to preserve the call signature.
pub fn build_text_decoder(
    g: &mut Graph,
    config: &TextConfig,
    mut x: NodeId,
    _seq_len: usize,
) -> NodeId {
    let hidden = config.hidden_size;
    let kv_dim = config.kv_dim();
    let ffn_dim = config.intermediate_size;
    let eps = config.rms_norm_eps;
    let theta = config.rope_theta;
    let head_dim = config.head_dim();

    for layer in 0..config.num_hidden_layers {
        let prefix = format!("model.text_model.layers.{}", layer);
        let name = |suffix: &str| format!("{}.{}", prefix, suffix);

        // Attention sub-block: RMS norm, rotary Q/K, grouped-query causal
        // attention, output projection, residual. No projection biases.
        let attn_norm_w = g.parameter(&name("input_layernorm.weight"), &[hidden]);
        let normed = g.rms_norm(x, attn_norm_w, eps);
        let wq = g.parameter(&name("self_attn.q_proj.weight"), &[hidden, hidden]);
        let wk = g.parameter(&name("self_attn.k_proj.weight"), &[hidden, kv_dim]);
        let wv = g.parameter(&name("self_attn.v_proj.weight"), &[hidden, kv_dim]);
        let q = g.matmul(normed, wq);
        let k = g.matmul(normed, wk);
        let v = g.matmul(normed, wv);
        let q = g.rope(q, theta, head_dim);
        let k = g.rope(k, theta, head_dim);
        let attn = g.causal_attention(
            q,
            k,
            v,
            config.num_attention_heads,
            config.num_key_value_heads,
            head_dim,
        );
        let wo = g.parameter(&name("self_attn.o_proj.weight"), &[hidden, hidden]);
        let attn_out = g.matmul(attn, wo);
        x = g.add(x, attn_out);

        // SwiGLU MLP sub-block with residual connection.
        let mlp_norm_w = g.parameter(&name("post_attention_layernorm.weight"), &[hidden]);
        let normed = g.rms_norm(x, mlp_norm_w, eps);
        let gate_w = g.parameter(&name("mlp.gate_proj.weight"), &[hidden, ffn_dim]);
        let up_w = g.parameter(&name("mlp.up_proj.weight"), &[hidden, ffn_dim]);
        let down_w = g.parameter(&name("mlp.down_proj.weight"), &[ffn_dim, hidden]);
        let gate = g.matmul(normed, gate_w);
        let up = g.matmul(normed, up_w);
        let fused = g.swiglu(gate, up);
        let mlp_out = g.matmul(fused, down_w);
        x = g.add(x, mlp_out);
    }

    // Final RMS norm, then projection through a dedicated `lm_head.weight`
    // parameter (declared separately from the token embedding).
    let final_norm_w = g.parameter("model.text_model.norm.weight", &[hidden]);
    x = g.rms_norm(x, final_norm_w, eps);
    let lm_head = g.parameter("lm_head.weight", &[hidden, config.vocab_size]);
    g.matmul(x, lm_head)
}
/// Assembles the end-to-end SmolVLM2 graph and returns the logits node.
///
/// NOTE(review): the vision encoder output, the projected vision features,
/// and the raw text embeddings are built but not wired into the decoder;
/// the decoder instead reads a separate "combined_embeds" input. This looks
/// like the host performs the pixel shuffle (feeding
/// "vision_features_shuffled") and the vision/text embedding merge between
/// graph executions — confirm against the runtime driver.
pub fn build_graph(g: &mut Graph, config: &SmolVLM2Config, text_seq_len: usize) -> NodeId {
    let num_patches = config.vision.num_patches();
    let num_vision_tokens = config.num_vision_tokens();
    let total_seq_len = num_vision_tokens + text_seq_len;

    // Vision tower; its output is consumed host-side (hence the underscore).
    let _vision_features = build_vision_encoder(g, &config.vision, num_patches);

    // Connector: project pixel-shuffled vision features into text space.
    let connector_input_dim = config.connector_input_dim();
    let proj_w = g.parameter(
        "model.connector.modality_projection.proj.weight",
        &[connector_input_dim, config.text.hidden_size],
    );
    let shuffled = g.input(
        "vision_features_shuffled",
        &[num_vision_tokens, connector_input_dim],
    );
    let _vision_projected = g.matmul(shuffled, proj_w);

    // Text token embeddings (also merged with vision tokens host-side).
    let token_ids = g.input_u32("token_ids", &[text_seq_len]);
    let embed_w = g.parameter(
        "model.text_model.embed_tokens.weight",
        &[config.text.vocab_size, config.text.hidden_size],
    );
    let _text_embeds = g.embedding(token_ids, embed_w);

    // Decoder runs over the host-merged vision+text embedding sequence.
    let combined = g.input("combined_embeds", &[total_seq_len, config.text.hidden_size]);
    build_text_decoder(g, &config.text, combined, total_seq_len)
}
/// Lists every parameter name the graph declares, in the same order the
/// build functions declare them.
pub fn weight_names(config: &SmolVLM2Config) -> Vec<String> {
    // Per-layer suffixes for the vision encoder, in declaration order.
    const VISION_LAYER: [&str; 16] = [
        "layer_norm1.weight",
        "layer_norm1.bias",
        "self_attn.q_proj.weight",
        "self_attn.q_proj.bias",
        "self_attn.k_proj.weight",
        "self_attn.k_proj.bias",
        "self_attn.v_proj.weight",
        "self_attn.v_proj.bias",
        "self_attn.out_proj.weight",
        "self_attn.out_proj.bias",
        "layer_norm2.weight",
        "layer_norm2.bias",
        "mlp.fc1.weight",
        "mlp.fc1.bias",
        "mlp.fc2.weight",
        "mlp.fc2.bias",
    ];
    // Per-layer suffixes for the text decoder, in declaration order.
    const TEXT_LAYER: [&str; 9] = [
        "input_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "post_attention_layernorm.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    ];

    let mut names: Vec<String> = vec![
        "model.vision_model.embeddings.patch_embedding.weight".into(),
        "model.vision_model.embeddings.patch_embedding.bias".into(),
        "model.vision_model.embeddings.position_embedding.weight".into(),
    ];
    for i in 0..config.vision.num_hidden_layers {
        let p = format!("model.vision_model.encoder.layers.{}", i);
        names.extend(VISION_LAYER.iter().map(|s| format!("{}.{}", p, s)));
    }
    names.push("model.vision_model.post_layernorm.weight".into());
    names.push("model.vision_model.post_layernorm.bias".into());
    names.push("model.connector.modality_projection.proj.weight".into());
    names.push("model.text_model.embed_tokens.weight".into());
    for i in 0..config.text.num_hidden_layers {
        let p = format!("model.text_model.layers.{}", i);
        names.extend(TEXT_LAYER.iter().map(|s| format!("{}.{}", p, s)));
    }
    names.push("model.text_model.norm.weight".into());
    names.push("lm_head.weight".into());
    names
}
/// Lists the 2-D projection/MLP weight names that need transposition at
/// load time (embedding tables and norm vectors are excluded).
///
/// NOTE(review): presumably these matrices are stored `[out, in]` in the
/// checkpoint while the graph's matmul expects `[in, out]` — confirm
/// against the weight loader.
pub fn transposed_weight_names(config: &SmolVLM2Config) -> Vec<String> {
    // Vision-encoder matrices that get transposed, per layer.
    const VISION_MATS: [&str; 6] = [
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.out_proj.weight",
        "mlp.fc1.weight",
        "mlp.fc2.weight",
    ];
    // Text-decoder matrices that get transposed, per layer.
    const TEXT_MATS: [&str; 7] = [
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    ];

    let mut names = Vec::new();
    for i in 0..config.vision.num_hidden_layers {
        let p = format!("model.vision_model.encoder.layers.{}", i);
        names.extend(VISION_MATS.iter().map(|s| format!("{}.{}", p, s)));
    }
    names.push("model.connector.modality_projection.proj.weight".into());
    for i in 0..config.text.num_hidden_layers {
        let p = format!("model.text_model.layers.{}", i);
        names.extend(TEXT_MATS.iter().map(|s| format!("{}.{}", p, s)));
    }
    names
}