use crate::graph::{Graph, NodeId};
use crate::models::smolvlm2::{SmolVLM2Config, TextConfig, VisionConfig};
/// Hyper-parameters of the action-expert transformer stack.
///
/// Heads and per-head width are stored as `u32`; [`ExpertConfig::kv_dim`]
/// widens them to `usize` for shape arithmetic.
pub struct ExpertConfig {
    /// Width of the expert residual stream.
    pub hidden_size: usize,
    /// Number of transformer layers in the expert.
    pub num_layers: usize,
    /// Number of attention (query) heads.
    pub num_attention_heads: u32,
    /// Number of key/value heads (grouped-query attention when fewer than
    /// `num_attention_heads`).
    pub num_key_value_heads: u32,
    /// Per-head dimension.
    pub head_dim: u32,
    /// Hidden width of the SwiGLU MLP.
    pub intermediate_size: usize,
    /// Epsilon used by the RMSNorm layers.
    pub rms_norm_eps: f32,
    /// Layers with `i % self_attn_every_n_layers == 0` use self-attention;
    /// the remaining layers cross-attend into the VLM key/value streams.
    pub self_attn_every_n_layers: usize,
}

impl ExpertConfig {
    /// Total key/value projection width: `num_key_value_heads * head_dim`.
    pub fn kv_dim(&self) -> usize {
        let kv_heads = self.num_key_value_heads as usize;
        let per_head = self.head_dim as usize;
        kv_heads * per_head
    }
}
/// Top-level SmolVLA configuration: the SmolVLM2 backbone plus the action expert.
pub struct SmolVLAConfig {
    /// Configuration of the SmolVLM2 vision-language backbone.
    pub vlm: SmolVLM2Config,
    /// Configuration of the action-expert transformer.
    pub expert: ExpertConfig,
    /// Padded width of action vectors (feeds the `noisy_actions` input and the
    /// output projection in the graph builders).
    pub max_action_dim: usize,
    /// Padded width of the robot state vector (feeds `observation_state` in
    /// `build_state_projection`).
    pub max_state_dim: usize,
    // NOTE(review): not consumed in this file — presumably the number of
    // actions predicted per inference call; confirm against the sampler.
    pub chunk_size: usize,
    // NOTE(review): not consumed in this file — presumably the number of
    // denoising/integration steps; confirm against the inference loop.
    pub num_steps: usize,
    // NOTE(review): not consumed in this file — presumably how many VLM layers
    // are actually used (16 of 32 in `smolvla_base`); confirm against the VLM
    // forward pass.
    pub num_vlm_layers: usize,
}
impl SmolVLAConfig {
    /// Tiny configuration intended for fast unit tests: two layers everywhere,
    /// 64-wide hidden states, a single 32x32 image patch grid.
    pub fn small_test() -> Self {
        let vision = VisionConfig {
            image_size: 32,
            patch_size: 16,
            hidden_size: 64,
            num_attention_heads: 2,
            num_hidden_layers: 2,
            intermediate_size: 128,
            layer_norm_eps: 1e-6,
        };
        let text = TextConfig {
            vocab_size: 256,
            hidden_size: 64,
            num_hidden_layers: 2,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            intermediate_size: 128,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
        };
        let expert = ExpertConfig {
            hidden_size: 64,
            num_layers: 2,
            num_attention_heads: 2,
            num_key_value_heads: 2,
            head_dim: 32,
            intermediate_size: 128,
            rms_norm_eps: 1e-5,
            self_attn_every_n_layers: 2,
        };
        Self {
            vlm: SmolVLM2Config {
                vision,
                text,
                scale_factor: 1,
            },
            expert,
            max_action_dim: 8,
            max_state_dim: 8,
            chunk_size: 4,
            num_steps: 2,
            num_vlm_layers: 2,
        }
    }

    /// Full-size configuration for the `smolvla_base` checkpoint.
    pub fn smolvla_base() -> Self {
        let vision = VisionConfig {
            image_size: 512,
            patch_size: 16,
            hidden_size: 768,
            num_attention_heads: 12,
            num_hidden_layers: 12,
            intermediate_size: 3072,
            layer_norm_eps: 1e-6,
        };
        let text = TextConfig {
            vocab_size: 49280,
            hidden_size: 960,
            num_hidden_layers: 32,
            num_attention_heads: 15,
            num_key_value_heads: 5,
            intermediate_size: 2560,
            rms_norm_eps: 1e-5,
            rope_theta: 100000.0,
        };
        let expert = ExpertConfig {
            hidden_size: 720,
            num_layers: 16,
            num_attention_heads: 15,
            num_key_value_heads: 5,
            head_dim: 64,
            intermediate_size: 2048,
            rms_norm_eps: 1e-5,
            self_attn_every_n_layers: 2,
        };
        Self {
            vlm: SmolVLM2Config {
                vision,
                text,
                scale_factor: 4,
            },
            expert,
            max_action_dim: 32,
            max_state_dim: 32,
            chunk_size: 50,
            num_steps: 10,
            num_vlm_layers: 16,
        }
    }
}
/// Builds the inference graph for the SmolVLA action expert.
///
/// Graph inputs registered here:
/// * `noisy_actions`: `[action_seq_len, max_action_dim]`
/// * `timestep`: `[1, expert_hidden * 2]` — pre-embedded timestep features
/// * `vlm_hidden`: `[vlm_seq_len, text_hidden]` (see NOTE below)
/// * `vlm_kv_layer_{i}`: `[vlm_seq_len, kv_dim]` for every cross-attention layer
///
/// Returns the node producing the predicted actions,
/// `[action_seq_len, max_action_dim]`. All linear weights are declared as
/// `[in_features, out_features]`, matching the `g.matmul(x, w)` convention
/// used throughout this file.
pub fn build_action_expert(
    g: &mut Graph,
    config: &SmolVLAConfig,
    action_seq_len: usize,
    vlm_seq_len: usize,
) -> NodeId {
    let expert = &config.expert;
    let expert_hidden = expert.hidden_size;
    let text_hidden = config.vlm.text.hidden_size;
    // Key/value width shared by self- and cross-attention projections.
    let kv_dim = expert.kv_dim();
    // Query/output width: heads x per-head dim.
    let attn_dim = expert.num_attention_heads as usize * expert.head_dim as usize;
    let eps = expert.rms_norm_eps;
    let noisy_actions = g.input("noisy_actions", &[action_seq_len, config.max_action_dim]);
    let timestep = g.input("timestep", &[1, expert_hidden * 2]);
    // Project noisy actions into the expert hidden space.
    let action_in_w = g.parameter(
        "model.action_in_proj.weight",
        &[config.max_action_dim, expert_hidden],
    );
    let action_in_b = g.parameter("model.action_in_proj.bias", &[expert_hidden]);
    let mut x = g.matmul(noisy_actions, action_in_w);
    x = g.bias_add(x, action_in_b);
    // Timestep MLP: linear -> SiLU -> linear, producing one embedding row that
    // is broadcast-added to every action position.
    let time_in_w = g.parameter(
        "model.action_time_mlp_in.weight",
        &[expert_hidden * 2, expert_hidden],
    );
    let time_in_b = g.parameter("model.action_time_mlp_in.bias", &[expert_hidden]);
    let time_out_w = g.parameter(
        "model.action_time_mlp_out.weight",
        &[expert_hidden, expert_hidden],
    );
    let time_out_b = g.parameter("model.action_time_mlp_out.bias", &[expert_hidden]);
    let time_h = g.matmul(timestep, time_in_w);
    let time_h = g.bias_add(time_h, time_in_b);
    let time_h = g.silu(time_h);
    let time_h = g.matmul(time_h, time_out_w);
    let time_embed = g.bias_add(time_h, time_out_b);
    x = g.broadcast_add(x, time_embed);
    // NOTE(review): registered as a graph input but never consumed below —
    // presumably kept so the graph's input set matches the caller's feed
    // order; confirm before removing.
    let _vlm_hidden = g.input("vlm_hidden", &[vlm_seq_len, text_hidden]);
    for i in 0..expert.num_layers {
        let prefix = format!("model.vlm_with_expert.lm_expert.layers.{}", i);
        // Layers with i % self_attn_every_n_layers == 0 self-attend over the
        // action tokens; the remaining layers cross-attend into per-layer VLM
        // key/value streams.
        let is_cross_attn = i % expert.self_attn_every_n_layers != 0;
        // Pre-norm before attention.
        let ln1_w = g.parameter(
            &format!("{}.input_layernorm.weight", prefix),
            &[expert_hidden],
        );
        let h = g.rms_norm(x, ln1_w, eps);
        if is_cross_attn {
            // Queries come from the action stream.
            let wq = g.parameter(
                &format!("{}.self_attn.q_proj.weight", prefix),
                &[expert_hidden, attn_dim],
            );
            let q = g.matmul(h, wq);
            let wk = g.parameter(
                &format!("{}.self_attn.k_proj.weight", prefix),
                &[kv_dim, kv_dim],
            );
            let wv = g.parameter(
                &format!("{}.self_attn.v_proj.weight", prefix),
                &[kv_dim, kv_dim],
            );
            // Keys/values are computed from the externally supplied VLM
            // key/value stream for this layer.
            let vlm_kv = g.input(&format!("vlm_kv_layer_{}", i), &[vlm_seq_len, kv_dim]);
            let k = g.matmul(vlm_kv, wk);
            let v = g.matmul(vlm_kv, wv);
            let attn = g.cross_attention(
                q,
                k,
                v,
                expert.num_attention_heads,
                expert.num_key_value_heads,
                expert.head_dim,
            );
            let wo = g.parameter(
                &format!("{}.self_attn.o_proj.weight", prefix),
                &[attn_dim, expert_hidden],
            );
            let attn_out = g.matmul(attn, wo);
            // Residual connection around attention.
            x = g.add(x, attn_out);
        } else {
            // Causal self-attention over the action tokens.
            let wq = g.parameter(
                &format!("{}.self_attn.q_proj.weight", prefix),
                &[expert_hidden, attn_dim],
            );
            let wk = g.parameter(
                &format!("{}.self_attn.k_proj.weight", prefix),
                &[expert_hidden, kv_dim],
            );
            let wv = g.parameter(
                &format!("{}.self_attn.v_proj.weight", prefix),
                &[expert_hidden, kv_dim],
            );
            let q = g.matmul(h, wq);
            let k = g.matmul(h, wk);
            let v = g.matmul(h, wv);
            let attn = g.causal_attention(
                q,
                k,
                v,
                expert.num_attention_heads,
                expert.num_key_value_heads,
                expert.head_dim,
            );
            let wo = g.parameter(
                &format!("{}.self_attn.o_proj.weight", prefix),
                &[attn_dim, expert_hidden],
            );
            let attn_out = g.matmul(attn, wo);
            // Residual connection around attention.
            x = g.add(x, attn_out);
        }
        // Pre-norm SwiGLU MLP with its own residual connection.
        let ln2_w = g.parameter(
            &format!("{}.post_attention_layernorm.weight", prefix),
            &[expert_hidden],
        );
        let h = g.rms_norm(x, ln2_w, eps);
        let w_gate = g.parameter(
            &format!("{}.mlp.gate_proj.weight", prefix),
            &[expert_hidden, expert.intermediate_size],
        );
        let w_up = g.parameter(
            &format!("{}.mlp.up_proj.weight", prefix),
            &[expert_hidden, expert.intermediate_size],
        );
        let w_down = g.parameter(
            &format!("{}.mlp.down_proj.weight", prefix),
            &[expert.intermediate_size, expert_hidden],
        );
        let gate = g.matmul(h, w_gate);
        let up = g.matmul(h, w_up);
        let gate_up = g.swiglu(gate, up);
        let ffn_out = g.matmul(gate_up, w_down);
        x = g.add(x, ffn_out);
    }
    // Final projection back to the padded action dimension.
    let action_out_w = g.parameter(
        "model.action_out_proj.weight",
        &[expert_hidden, config.max_action_dim],
    );
    let action_out_b = g.parameter("model.action_out_proj.bias", &[config.max_action_dim]);
    let out = g.matmul(x, action_out_w);
    g.bias_add(out, action_out_b)
}
/// Projects the padded robot state vector into the VLM text embedding space
/// via a single biased linear layer.
///
/// Registers the `observation_state` input (`[1, max_state_dim]`) and returns
/// the node producing `[1, text_hidden]`.
pub fn build_state_projection(g: &mut Graph, config: &SmolVLAConfig) -> NodeId {
    let text_hidden = config.vlm.text.hidden_size;
    let state = g.input("observation_state", &[1, config.max_state_dim]);
    let weight = g.parameter(
        "model.state_proj.weight",
        &[config.max_state_dim, text_hidden],
    );
    let bias = g.parameter("model.state_proj.bias", &[text_hidden]);
    let projected = g.matmul(state, weight);
    g.bias_add(projected, bias)
}
/// Enumerates every checkpoint tensor name the action expert loads:
/// the global projections / time-MLP weights, followed by the per-layer
/// weights in the order the graph builders declare them.
pub fn expert_weight_names(config: &SmolVLAConfig) -> Vec<String> {
    // Weights outside the transformer stack.
    let mut names: Vec<String> = [
        "model.state_proj.weight",
        "model.state_proj.bias",
        "model.action_in_proj.weight",
        "model.action_in_proj.bias",
        "model.action_out_proj.weight",
        "model.action_out_proj.bias",
        "model.action_time_mlp_in.weight",
        "model.action_time_mlp_in.bias",
        "model.action_time_mlp_out.weight",
        "model.action_time_mlp_out.bias",
    ]
    .iter()
    .map(|s| s.to_string())
    .collect();
    // Per-layer suffixes, matching the parameter names created in
    // `build_action_expert` / `build_action_expert_training`.
    const PER_LAYER: [&str; 9] = [
        "input_layernorm.weight",
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "post_attention_layernorm.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    ];
    for layer in 0..config.expert.num_layers {
        let prefix = format!("model.vlm_with_expert.lm_expert.layers.{}", layer);
        names.extend(
            PER_LAYER
                .iter()
                .map(|suffix| format!("{}.{}", prefix, suffix)),
        );
    }
    names
}
/// Builds a fresh training graph for the action expert: the same forward pass
/// as `build_action_expert`, followed by an MSE loss against a
/// `target_actions` input. The loss node is registered as the graph output.
///
/// Unlike the inference builder, attention here goes through
/// `Graph::multi_head_attn` with a trailing boolean flag (see NOTE in the
/// layer loop).
pub fn build_action_expert_training(
    config: &SmolVLAConfig,
    action_seq_len: usize,
    vlm_seq_len: usize,
) -> Graph {
    let mut g = Graph::new();
    let expert = &config.expert;
    let expert_hidden = expert.hidden_size;
    let kv_dim = expert.kv_dim();
    let eps = expert.rms_norm_eps;
    let num_heads = expert.num_attention_heads;
    let num_kv_heads = expert.num_key_value_heads;
    let hd = expert.head_dim;
    // Query/output width: heads x per-head dim.
    let q_dim = (num_heads * hd) as usize;
    // Numerically equal to `kv_dim` (both are num_kv_heads * head_dim); kept
    // as a separate u32-derived local.
    let kv_dim_full = (num_kv_heads * hd) as usize;
    let noisy_actions = g.input("noisy_actions", &[action_seq_len, config.max_action_dim]);
    let timestep = g.input("timestep", &[1, expert_hidden * 2]);
    // Project noisy actions into the expert hidden space.
    let action_in_w = g.parameter(
        "model.action_in_proj.weight",
        &[config.max_action_dim, expert_hidden],
    );
    let action_in_b = g.parameter("model.action_in_proj.bias", &[expert_hidden]);
    let mut x = g.matmul(noisy_actions, action_in_w);
    x = g.bias_add(x, action_in_b);
    // Timestep MLP: linear -> SiLU -> linear, broadcast-added to every
    // action position.
    let time_in_w = g.parameter(
        "model.action_time_mlp_in.weight",
        &[expert_hidden * 2, expert_hidden],
    );
    let time_in_b = g.parameter("model.action_time_mlp_in.bias", &[expert_hidden]);
    let time_out_w = g.parameter(
        "model.action_time_mlp_out.weight",
        &[expert_hidden, expert_hidden],
    );
    let time_out_b = g.parameter("model.action_time_mlp_out.bias", &[expert_hidden]);
    let time_h = g.matmul(timestep, time_in_w);
    let time_h = g.bias_add(time_h, time_in_b);
    let time_h = g.silu(time_h);
    let time_h = g.matmul(time_h, time_out_w);
    let time_embed = g.bias_add(time_h, time_out_b);
    x = g.broadcast_add(x, time_embed);
    for i in 0..expert.num_layers {
        let prefix = format!("model.vlm_with_expert.lm_expert.layers.{}", i);
        // Layers with i % self_attn_every_n_layers == 0 self-attend; the rest
        // cross-attend into the per-layer VLM key/value inputs.
        let is_cross_attn = i % expert.self_attn_every_n_layers != 0;
        // Pre-norm before attention.
        let ln1_w = g.parameter(
            &format!("{}.input_layernorm.weight", prefix),
            &[expert_hidden],
        );
        let h = g.rms_norm(x, ln1_w, eps);
        if is_cross_attn {
            // Queries from the action stream; keys/values from the VLM stream.
            let wq = g.parameter(
                &format!("{}.self_attn.q_proj.weight", prefix),
                &[expert_hidden, q_dim],
            );
            let q = g.matmul(h, wq);
            let vlm_kv = g.input(&format!("vlm_kv_layer_{}", i), &[vlm_seq_len, kv_dim]);
            let wk = g.parameter(
                &format!("{}.self_attn.k_proj.weight", prefix),
                &[kv_dim, kv_dim_full],
            );
            let wv = g.parameter(
                &format!("{}.self_attn.v_proj.weight", prefix),
                &[kv_dim, kv_dim_full],
            );
            let k = g.matmul(vlm_kv, wk);
            let v = g.matmul(vlm_kv, wv);
            // NOTE(review): the trailing bool presumably selects cross (true)
            // vs. causal self (false) attention, mirroring cross_attention /
            // causal_attention in `build_action_expert` — confirm against
            // Graph::multi_head_attn.
            let attn = g.multi_head_attn(q, k, v, num_heads, num_kv_heads, hd, true);
            let wo = g.parameter(
                &format!("{}.self_attn.o_proj.weight", prefix),
                &[q_dim, expert_hidden],
            );
            // Output projection plus residual connection.
            let attn_out = g.matmul(attn, wo);
            x = g.add(x, attn_out);
        } else {
            // Causal self-attention over the action tokens.
            let wq = g.parameter(
                &format!("{}.self_attn.q_proj.weight", prefix),
                &[expert_hidden, q_dim],
            );
            let wk = g.parameter(
                &format!("{}.self_attn.k_proj.weight", prefix),
                &[expert_hidden, kv_dim_full],
            );
            let wv = g.parameter(
                &format!("{}.self_attn.v_proj.weight", prefix),
                &[expert_hidden, kv_dim_full],
            );
            let q = g.matmul(h, wq);
            let k = g.matmul(h, wk);
            let v = g.matmul(h, wv);
            let attn = g.multi_head_attn(q, k, v, num_heads, num_kv_heads, hd, false);
            let wo = g.parameter(
                &format!("{}.self_attn.o_proj.weight", prefix),
                &[q_dim, expert_hidden],
            );
            // Output projection plus residual connection.
            let attn_out = g.matmul(attn, wo);
            x = g.add(x, attn_out);
        }
        // Pre-norm SwiGLU MLP with its own residual connection.
        let ln2_w = g.parameter(
            &format!("{}.post_attention_layernorm.weight", prefix),
            &[expert_hidden],
        );
        let h = g.rms_norm(x, ln2_w, eps);
        let w_gate = g.parameter(
            &format!("{}.mlp.gate_proj.weight", prefix),
            &[expert_hidden, expert.intermediate_size],
        );
        let w_up = g.parameter(
            &format!("{}.mlp.up_proj.weight", prefix),
            &[expert_hidden, expert.intermediate_size],
        );
        let w_down = g.parameter(
            &format!("{}.mlp.down_proj.weight", prefix),
            &[expert.intermediate_size, expert_hidden],
        );
        let gate = g.matmul(h, w_gate);
        let up = g.matmul(h, w_up);
        let gate_up = g.swiglu(gate, up);
        let ffn_out = g.matmul(gate_up, w_down);
        x = g.add(x, ffn_out);
    }
    // Final projection back to the padded action dimension.
    let action_out_w = g.parameter(
        "model.action_out_proj.weight",
        &[expert_hidden, config.max_action_dim],
    );
    let action_out_b = g.parameter("model.action_out_proj.bias", &[config.max_action_dim]);
    let out = g.matmul(x, action_out_w);
    let out = g.bias_add(out, action_out_b);
    // MSE loss: diff = out + (-target), then mean(diff * diff).
    let target = g.input("target_actions", &[action_seq_len, config.max_action_dim]);
    let neg_target = g.neg(target);
    let diff = g.add(out, neg_target);
    let sq_diff = g.mul(diff, diff);
    let loss = g.mean_all(sq_diff);
    g.set_outputs(vec![loss]);
    g
}
/// Names of expert weights that are stored transposed in the checkpoint
/// relative to the `[in_features, out_features]` layout declared by the graph
/// builders (biases and 1-D norm weights need no transpose and are omitted).
///
/// The per-layer MLP entries previously listed a fused
/// `mlp.gate_up_proj.weight`, which matches no parameter created by
/// `build_action_expert` / `build_action_expert_training` and is absent from
/// `expert_weight_names` — the fused name would never match, silently leaving
/// the gate/up weights untransposed. We list the two real parameter names
/// instead. (If a checkpoint variant really does fuse gate/up, handle the
/// split at load time, not here.)
pub fn expert_transposed_weight_names(config: &SmolVLAConfig) -> Vec<String> {
    let expert = &config.expert;
    let mut names: Vec<String> = vec![
        "model.state_proj.weight".into(),
        "model.action_in_proj.weight".into(),
        "model.action_out_proj.weight".into(),
        "model.action_time_mlp_in.weight".into(),
        "model.action_time_mlp_out.weight".into(),
    ];
    for i in 0..expert.num_layers {
        let p = format!("model.vlm_with_expert.lm_expert.layers.{}", i);
        names.push(format!("{}.self_attn.q_proj.weight", p));
        names.push(format!("{}.self_attn.k_proj.weight", p));
        names.push(format!("{}.self_attn.v_proj.weight", p));
        names.push(format!("{}.self_attn.o_proj.weight", p));
        names.push(format!("{}.mlp.gate_proj.weight", p));
        names.push(format!("{}.mlp.up_proj.weight", p));
        names.push(format!("{}.mlp.down_proj.weight", p));
    }
    names
}