use vyre::ir::Program;
use super::{
activation::embedding,
attention::mla::mla_decode,
linear::linear,
moe::{expert_mlp, moe_layer::moe_layer_route_and_accumulate},
norm::rms_norm,
};
/// Hyper-parameters consumed by [`build_forward_graph`].
///
/// Every field is a plain scalar, so the struct derives `PartialEq` in
/// addition to `Debug` and `Clone`; this lets callers and tests compare
/// configurations directly. (`Eq` is not derivable because of the `f32`
/// epsilon field.)
#[derive(Debug, Clone, PartialEq)]
pub struct Ds4FlashConfig {
    /// Vocabulary size; passed as the last dimension argument of the
    /// `lm_head` `linear` program.
    pub vocab_size: u32,
    /// Model width; passed to the embedding, MoE, shared-expert, RMS-norm and
    /// LM-head builders.
    pub hidden_dim: u32,
    /// Number of layers. `build_forward_graph` does not read this field: it
    /// emits one program per layer *type*, not per layer instance.
    pub num_layers: u32,
    /// Attention head count, forwarded to `mla_decode`.
    pub num_heads: u32,
    /// Per-head dimension, forwarded to `mla_decode`.
    pub head_dim: u32,
    /// KV LoRA rank, forwarded to `mla_decode`.
    pub kv_lora_rank: u32,
    /// RoPE head dimension for the query/key path, forwarded to `mla_decode`.
    pub qk_rope_head_dim: u32,
    /// Expert count, forwarded to `moe_layer_route_and_accumulate`.
    pub num_experts: u32,
    /// Top-k value, forwarded to `moe_layer_route_and_accumulate`.
    pub moe_top_k: u32,
    /// Middle dimension argument of the shared-expert `expert_mlp` program
    /// (passed between two `hidden_dim` arguments).
    pub shared_expert_hidden_dim: u32,
    /// Epsilon forwarded to `rms_norm`.
    pub rms_norm_eps: f32,
    /// Maximum sequence length; passed to `embedding` and to the prefill
    /// variant of `mla_decode` (the decode variant passes `1` in that slot).
    pub max_seq_len: u32,
}
impl Default for Ds4FlashConfig {
fn default() -> Self {
Self {
vocab_size: 129_280,
hidden_dim: 7_168,
num_layers: 61,
num_heads: 128,
head_dim: 128,
kv_lora_rank: 512,
qk_rope_head_dim: 64,
num_experts: 256,
moe_top_k: 8,
shared_expert_hidden_dim: 18_432,
rms_norm_eps: 1e-6,
max_seq_len: 4_096,
}
}
}
/// Builds one [`Program`] per layer type for the given configuration.
///
/// The returned vector always holds seven programs, in this fixed order:
///
/// 0. embedding
/// 1. MLA prefill
/// 2. MLA decode
/// 3. MoE route-and-accumulate
/// 4. shared expert MLP
/// 5. RMS norm
/// 6. LM head
///
/// `config.num_layers` is not consulted: the programs describe layer *types*
/// and are not replicated per layer here.
///
/// When a fallible builder returns `Err`, the corresponding slot is filled
/// with whatever `crate::invalid_program` yields for that stage, with a
/// `"Fix: <stage> build failed: <error>"` message.
pub fn build_forward_graph(config: &Ds4FlashConfig) -> Vec<Program> {
    /// Unwraps a builder result, delegating to `crate::invalid_program` with
    /// the stage's op id and a formatted message when the build failed.
    fn built_or_invalid<E: std::fmt::Display>(
        outcome: Result<Program, E>,
        op_id: &'static str,
        stage: &str,
    ) -> Program {
        outcome.unwrap_or_else(|e| {
            crate::invalid_program(op_id, format!("Fix: {stage} build failed: {e}"))
        })
    }

    let hidden = config.hidden_dim;

    // Programs are constructed top to bottom, which is also their final order.
    vec![
        // NOTE(review): `vocab_size` is not passed to `embedding`; confirm the
        // builder does not need the table's row count.
        embedding(
            "embed_table",
            "tokens",
            "embed_out",
            config.max_seq_len,
            hidden,
        ),
        // Prefill reuses the `mla_decode` builder; it differs from the decode
        // program only in the output buffer name and in passing `max_seq_len`
        // where decode passes 1.
        built_or_invalid(
            mla_decode(
                "q",
                "kv_cache",
                "kr_cache",
                "w_uk",
                "w_uv",
                "mla_prefill_out",
                config.max_seq_len,
                config.num_heads,
                config.head_dim,
                config.kv_lora_rank,
                config.qk_rope_head_dim,
            ),
            "vyre-libs::nn::mla_prefill",
            "mla_prefill",
        ),
        built_or_invalid(
            mla_decode(
                "q",
                "kv_cache",
                "kr_cache",
                "w_uk",
                "w_uv",
                "mla_decode_out",
                1,
                config.num_heads,
                config.head_dim,
                config.kv_lora_rank,
                config.qk_rope_head_dim,
            ),
            "vyre-libs::nn::mla_decode",
            "mla_decode",
        ),
        // NOTE(review): `hidden_dim` is passed both before and after
        // `num_experts`; verify the parameter meanings against
        // `moe_layer_route_and_accumulate`.
        built_or_invalid(
            moe_layer_route_and_accumulate(
                "moe_x",
                "w_router",
                "b_router",
                "expert_indices",
                "expert_weights",
                "expert_outputs",
                "moe_out",
                hidden,
                config.num_experts,
                hidden,
                config.moe_top_k,
            ),
            "vyre-libs::nn::moe_layer",
            "moe_layer",
        ),
        built_or_invalid(
            expert_mlp(
                "shared_x",
                "shared_w_gate",
                "shared_b_gate",
                "shared_w_up",
                "shared_b_up",
                "shared_w_down",
                "shared_b_down",
                "shared_out",
                hidden,
                config.shared_expert_hidden_dim,
                hidden,
            ),
            "vyre-libs::nn::shared_expert",
            "shared_expert",
        ),
        rms_norm("rms_in", "rms_out", hidden, config.rms_norm_eps),
        built_or_invalid(
            linear(
                "lm_head_x",
                "lm_head_w",
                "lm_head_b",
                "lm_head_out",
                hidden,
                config.vocab_size,
            ),
            "vyre-libs::nn::lm_head",
            "lm_head",
        ),
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    // Positions of the programs inside the vector returned by
    // `build_forward_graph`.
    const EMBED: usize = 0;
    const MLA_PREFILL: usize = 1;
    const MLA_DECODE: usize = 2;
    const RMS_NORM: usize = 5;
    const LM_HEAD: usize = 6;

    /// Builds the forward graph for `Ds4FlashConfig::default()`.
    fn default_programs() -> Vec<Program> {
        build_forward_graph(&Ds4FlashConfig::default())
    }

    /// The default configuration yields seven programs, each declaring at
    /// least one buffer.
    #[test]
    fn forward_graph_default_config_builds() {
        let programs = default_programs();
        assert_eq!(
            programs.len(),
            7,
            "expected 7 programs (one per layer type)"
        );

        let expected_names = [
            "embed",
            "mla_prefill",
            "mla_decode",
            "moe_layer",
            "shared_expert",
            "rms_norm",
            "lm_head",
        ];
        // The length check above guarantees the zip covers every program.
        for (name, program) in expected_names.iter().zip(&programs) {
            assert!(
                !program.buffers().is_empty(),
                "{} program should declare at least one buffer",
                name
            );
        }
    }

    /// A scaled-down configuration builds the same number of programs.
    #[test]
    fn forward_graph_small_config_builds() {
        let small = Ds4FlashConfig {
            vocab_size: 1_024,
            hidden_dim: 256,
            num_layers: 2,
            num_heads: 4,
            head_dim: 64,
            kv_lora_rank: 32,
            qk_rope_head_dim: 16,
            num_experts: 8,
            moe_top_k: 2,
            shared_expert_hidden_dim: 512,
            rms_norm_eps: 1e-5,
            max_seq_len: 128,
        };
        let programs = build_forward_graph(&small);
        assert_eq!(programs.len(), 7);
        programs
            .iter()
            .for_each(|program| assert!(!program.buffers().is_empty()));
    }

    /// The embedding program declares exactly three buffers.
    #[test]
    fn embed_program_has_correct_buffer_count() {
        let programs = default_programs();
        assert_eq!(programs[EMBED].buffers().len(), 3);
    }

    /// Despite its name, this test checks what prefill and decode have in
    /// common: the same buffer count and the same workgroup size.
    #[test]
    fn mla_prefill_and_decode_are_distinct() {
        let programs = default_programs();
        let (prefill, decode) = (&programs[MLA_PREFILL], &programs[MLA_DECODE]);
        assert_eq!(prefill.buffers().len(), decode.buffers().len());
        assert!(
            prefill.workgroup_size() == decode.workgroup_size(),
            "prefill and decode use the same workgroup dispatch shape"
        );
    }

    /// Every buffer of the RMS-norm program has element type `F32`.
    #[test]
    fn rms_norm_program_is_f32() {
        let programs = default_programs();
        for buf in programs[RMS_NORM].buffers() {
            assert_eq!(
                buf.element,
                vyre::ir::DataType::F32,
                "rms_norm uses F32 buffers"
            );
        }
    }

    /// The LM-head program declares at least four buffers.
    #[test]
    fn lm_head_program_has_expected_buffers() {
        let programs = default_programs();
        assert!(programs[LM_HEAD].buffers().len() >= 4);
    }
}