use ferrum_kernels::backend::{BackendMoeFused, QuantLlmBackend};
use ferrum_types::Result;
use crate::models::qwen3_moe::{Qwen3MoeLayerState, Qwen3MoeScratch};
use crate::models::qwen3_moe_profile::{
DEC_DOWN_US, DEC_GATE_US, DEC_ROUTE_US, DEC_SILU_US, DEC_UP_US, DEC_WSUM_US,
MOE_BATCHED_DECODE_DOWN_US, MOE_BATCHED_DECODE_GATE_US, MOE_BATCHED_DECODE_ROUTE_US,
MOE_BATCHED_DECODE_SILU_US, MOE_BATCHED_DECODE_UP_US, MOE_BATCHED_DECODE_WSUM_US,
MOE_PREFILL_DOWN_CALLS, MOE_PREFILL_DOWN_US, MOE_PREFILL_GATE_CALLS, MOE_PREFILL_GATE_US,
MOE_PREFILL_HOST_TOPK_CALLS, MOE_PREFILL_HOST_TOPK_US, MOE_PREFILL_SILU_CALLS,
MOE_PREFILL_SILU_US, MOE_PREFILL_UP_CALLS, MOE_PREFILL_UP_US, MOE_PREFILL_WSUM_CALLS,
MOE_PREFILL_WSUM_US,
};
use std::sync::atomic::AtomicU64;
use std::sync::OnceLock;
/// Process-wide switches for the MoE forward paths, parsed from environment
/// variables.
///
/// Every flag is `false` when none of the `FERRUM_*` knobs below are set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MoeForwardRuntimeConfig {
    /// `true` when `FERRUM_DECODE_OP_PROFILE` is present, whatever its value.
    decode_op_profile: bool,
    /// `true` when `FERRUM_MOE_FUSED_GATE_UP_SILU` is exactly `"0"`.
    fused_gate_up_silu_disabled: bool,
    /// `true` when `FERRUM_MOE_HOST_TOPK` is exactly `"1"`.
    moe_host_topk: bool,
    /// `true` when `FERRUM_MOE_DIRECT_DISPATCH` is exactly `"1"`.
    moe_direct_dispatch: bool,
}

impl MoeForwardRuntimeConfig {
    /// Builds the configuration from the current process environment.
    ///
    /// Reads the environment through `std::env::vars_os` instead of
    /// `std::env::vars`: the latter panics as soon as *any* variable in the
    /// process environment has a non-Unicode name or value, which would let an
    /// unrelated variable crash the first MoE forward pass. Variables whose
    /// name is not valid Unicode cannot match a knob and are skipped; values
    /// are converted lossily so a present knob is still seen as present.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().filter_map(|(name, value)| {
            let name = name.into_string().ok()?;
            Some((name, value.to_string_lossy().into_owned()))
        }))
    }

    /// Builds the configuration from an explicit list of `(name, value)` pairs.
    ///
    /// Unknown names are ignored. If a value-based knob appears more than once,
    /// the last occurrence wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut config = Self {
            decode_op_profile: false,
            fused_gate_up_silu_disabled: false,
            moe_host_topk: false,
            moe_direct_dispatch: false,
        };
        for (name, value) in vars {
            let value = value.as_ref();
            match name.as_ref() {
                // Presence-only switch: the value is deliberately not inspected.
                "FERRUM_DECODE_OP_PROFILE" => config.decode_op_profile = true,
                // Opt-out knob: only the literal "0" disables the fused path.
                "FERRUM_MOE_FUSED_GATE_UP_SILU" => {
                    config.fused_gate_up_silu_disabled = value == "0";
                }
                // Opt-in knobs: only the literal "1" enables them.
                "FERRUM_MOE_HOST_TOPK" => config.moe_host_topk = value == "1",
                "FERRUM_MOE_DIRECT_DISPATCH" => config.moe_direct_dispatch = value == "1",
                _ => {}
            }
        }
        config
    }
}
/// Returns the process-wide MoE runtime configuration.
///
/// The environment is parsed on the first call only; later calls return the
/// same cached value, so changing the environment afterwards has no effect.
fn moe_forward_runtime_config() -> &'static MoeForwardRuntimeConfig {
    static RUNTIME_CONFIG: OnceLock<MoeForwardRuntimeConfig> = OnceLock::new();
    let cached = RUNTIME_CONFIG.get_or_init(MoeForwardRuntimeConfig::from_env);
    cached
}
/// Runs the MoE feed-forward for a single decode token.
///
/// Stages, in order: top-k routing over `scratch.router_logits`, the expert
/// gate/up projections plus SiLU-multiply (one fused call when the backend
/// supports it and `FERRUM_MOE_FUSED_GATE_UP_SILU=0` is not set, three
/// separate calls otherwise), the down projection, and finally a weighted sum
/// of the expert outputs into `residual`.
///
/// When `next_norm_w` is `Some`, the final stage uses
/// `weighted_sum_residual_norm_stacked`, which additionally receives the norm
/// weight, `scratch.norm_out` as an output buffer, and `eps`.
///
/// Each stage is bracketed by a probe timer that is active only when
/// `FERRUM_DECODE_OP_PROFILE` is set; elapsed microseconds are accumulated
/// into the matching `DEC_*_US` counter.
///
/// # Panics
///
/// In debug builds, panics if `tokens != 1`.
///
/// # Errors
///
/// Propagates the first error returned by any backend or expert kernel call.
#[allow(clippy::too_many_arguments)]
pub(crate) fn moe_forward_stacked_decode_impl<B: QuantLlmBackend + BackendMoeFused>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
    residual: &mut B::Buffer,
    next_norm_w: Option<&B::Buffer>,
    eps: f32,
) -> Result<()> {
    use ferrum_kernels::backend::timer::{finish_probe_timer_traced, start_probe_timer_if};
    use ferrum_kernels::backend::Backend;
    use std::sync::atomic::Ordering;

    // Check the precondition before any kernel is launched, so a misuse is
    // reported before work has been issued with the offending `tokens`.
    debug_assert_eq!(
        tokens, 1,
        "moe_forward_stacked_decode_impl expects tokens=1 (prefill goes through moe_forward_batched_prefill_impl)"
    );

    let runtime_config = moe_forward_runtime_config();
    let decode_op_profile = runtime_config.decode_op_profile;
    let stage_t0 = |ctx: &mut B::Context| -> Option<<B as Backend>::Timer> {
        start_probe_timer_if::<B>(decode_op_profile, ctx)
    };
    let stage_end =
        |t: Option<<B as Backend>::Timer>, ctx: &mut B::Context, name: &str, c: &AtomicU64| {
            if let Some(us) = finish_probe_timer_traced::<B>(t, ctx, name, "moe", 0) {
                c.fetch_add(us, Ordering::Relaxed);
            }
        };

    // Routing: fills `ids_buf` / `weights_buf` from the router logits.
    let t0 = stage_t0(ctx);
    B::route_topk_softmax(
        ctx,
        &scratch.router_logits,
        &mut scratch.ids_buf,
        &mut scratch.weights_buf,
        tokens,
        n_exp,
        top_k,
        norm_topk_prob,
    )?;
    stage_end(t0, ctx, "route", &DEC_ROUTE_US);

    // Gate / up / SiLU: fused when available and not disabled by the env knob.
    let use_fused =
        B::supports_fused_moe_gate_up_silu() && !runtime_config.fused_gate_up_silu_disabled;
    if use_fused {
        let t0 = stage_t0(ctx);
        moe_layer.experts.gemv_gate_up_silu_fused(
            ctx,
            &scratch.norm_out,
            &scratch.ids_buf,
            &mut scratch.silu_stacked,
            top_k,
        )?;
        // The fused stage is accounted under the SiLU counter.
        stage_end(t0, ctx, "gate_up_silu_fused", &DEC_SILU_US);
    } else {
        let t0 = stage_t0(ctx);
        moe_layer.experts.gemv_gate(
            ctx,
            &scratch.norm_out,
            &scratch.ids_buf,
            &mut scratch.gate_out_stacked,
            top_k,
        )?;
        stage_end(t0, ctx, "gate", &DEC_GATE_US);

        let t0 = stage_t0(ctx);
        moe_layer.experts.gemv_up(
            ctx,
            &scratch.norm_out,
            &scratch.ids_buf,
            &mut scratch.up_out_stacked,
            top_k,
        )?;
        stage_end(t0, ctx, "up", &DEC_UP_US);

        let t0 = stage_t0(ctx);
        B::silu_mul_stacked(
            ctx,
            &scratch.gate_out_stacked,
            &scratch.up_out_stacked,
            &mut scratch.silu_stacked,
            top_k,
            inter,
        )?;
        stage_end(t0, ctx, "silu", &DEC_SILU_US);
    }

    // Down projection of the activated expert outputs.
    let t0 = stage_t0(ctx);
    moe_layer.experts.gemv_down(
        ctx,
        &scratch.silu_stacked,
        &scratch.ids_buf,
        &mut scratch.down_out_stacked,
        top_k,
        inter,
    )?;
    stage_end(t0, ctx, "down", &DEC_DOWN_US);

    // Weighted sum into the residual, optionally combined with the next norm.
    let t0 = stage_t0(ctx);
    if let Some(nnw) = next_norm_w {
        B::weighted_sum_residual_norm_stacked(
            ctx,
            &scratch.down_out_stacked,
            &scratch.weights_buf,
            residual,
            nnw,
            &mut scratch.norm_out,
            top_k,
            h,
            eps,
        )?;
    } else {
        B::weighted_sum_residual_stacked(
            ctx,
            &scratch.down_out_stacked,
            &scratch.weights_buf,
            residual,
            top_k,
            h,
        )?;
    }
    stage_end(t0, ctx, "wsum", &DEC_WSUM_US);

    Ok(())
}
/// Runs the batched MoE feed-forward over `tokens` tokens (the prefill path).
///
/// Stages, in order: routing plus per-expert grouping, gate GEMM, up GEMM,
/// SiLU-multiply, down GEMM, and a weighted sum into `scratch.moe_out`.
///
/// Routing has two variants:
/// * default: `route_topk_softmax` followed by `compute_ids_tpe_gpu`, with
///   `tokens * top_k` used as `max_per_expert`;
/// * `FERRUM_MOE_HOST_TOPK=1`: the router logits are copied to the host,
///   routed and grouped there, and the results are written back to the
///   backend buffers; `max_per_expert` is the value returned by
///   `compute_ids_tpe`.
///
/// The `gate_up_args_buf` / `down_args_buf` buffers are passed to the expert
/// GEMMs only on the default routing variant and only when
/// `FERRUM_MOE_DIRECT_DISPATCH=1` is not set; otherwise `None` is passed.
///
/// Each stage is bracketed by a probe timer that is active only when
/// `FERRUM_DECODE_OP_PROFILE` is set; elapsed microseconds and a call count
/// are accumulated into the matching `MOE_PREFILL_*` counters.
///
/// # Errors
///
/// Propagates the first error returned by any backend or expert kernel call.
#[allow(clippy::too_many_arguments)]
pub(crate) fn moe_forward_batched_prefill_impl<B: QuantLlmBackend + BackendMoeFused>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
) -> Result<()> {
    use ferrum_kernels::backend::timer::{finish_probe_timer_traced, start_probe_timer_if};
    use ferrum_kernels::backend::Backend;
    use std::sync::atomic::Ordering;

    let knobs = moe_forward_runtime_config();
    let profiling = knobs.decode_op_profile;
    let route_on_host = knobs.moe_host_topk;
    // Indirect dispatch needs the argument buffers that only the GPU routing
    // variant fills in, and can additionally be switched off by the env knob.
    let indirect_dispatch = !(route_on_host || knobs.moe_direct_dispatch);

    let begin_stage = |ctx: &mut B::Context| -> Option<<B as Backend>::Timer> {
        start_probe_timer_if::<B>(profiling, ctx)
    };
    let end_stage = |timer: Option<<B as Backend>::Timer>,
                     ctx: &mut B::Context,
                     name: &str,
                     total_us: &AtomicU64,
                     calls: &AtomicU64| {
        if let Some(elapsed) =
            finish_probe_timer_traced::<B>(timer, ctx, name, "moe_prefill", 0)
        {
            total_us.fetch_add(elapsed, Ordering::Relaxed);
            calls.fetch_add(1, Ordering::Relaxed);
        }
    };

    // --- Routing and per-expert grouping -----------------------------------
    // Both variants report under the "host_topk" stage name and counters.
    let max_per_expert = if route_on_host {
        use ferrum_kernels::moe_host::compute_ids_tpe;

        let timer = begin_stage(ctx);
        // Synchronize before reading the router logits back to the host.
        B::sync(ctx);
        let logits_host = B::to_vec(&scratch.router_logits, tokens * n_exp);
        let route = crate::moe::router::route(&logits_host, tokens, n_exp, top_k, norm_topk_prob);
        let (tpe_host, ids_host, host_max_per_expert) =
            compute_ids_tpe(&route.expert_ids, n_exp, tokens, top_k);
        B::write_typed::<i32>(ctx, &mut scratch.tpe_buf, &tpe_host);
        B::write_typed::<i32>(ctx, &mut scratch.ids_2d, &ids_host);
        B::write_typed::<f32>(ctx, &mut scratch.weights_2d, &route.expert_weights);
        end_stage(
            timer,
            ctx,
            "host_topk",
            &MOE_PREFILL_HOST_TOPK_US,
            &MOE_PREFILL_HOST_TOPK_CALLS,
        );
        host_max_per_expert
    } else {
        let timer = begin_stage(ctx);
        B::route_topk_softmax(
            ctx,
            &scratch.router_logits,
            &mut scratch.selected_ids_buf,
            &mut scratch.weights_2d,
            tokens,
            n_exp,
            top_k,
            norm_topk_prob,
        )?;
        B::compute_ids_tpe_gpu(
            ctx,
            &scratch.selected_ids_buf,
            &mut scratch.tpe_buf,
            &mut scratch.ids_2d,
            &mut scratch.gate_up_args_buf,
            &mut scratch.down_args_buf,
            tokens,
            n_exp,
            top_k,
            inter,
            h,
        )?;
        end_stage(
            timer,
            ctx,
            "host_topk",
            &MOE_PREFILL_HOST_TOPK_US,
            &MOE_PREFILL_HOST_TOPK_CALLS,
        );
        // No host-side count is available here, so use the total pair count.
        tokens * top_k
    };

    let gate_up_args = if indirect_dispatch {
        Some(&scratch.gate_up_args_buf)
    } else {
        None
    };
    let down_args = if indirect_dispatch {
        Some(&scratch.down_args_buf)
    } else {
        None
    };

    // --- Gate projection ----------------------------------------------------
    let timer = begin_stage(ctx);
    moe_layer.experts.gemm_gate(
        ctx,
        &scratch.norm_out,
        &scratch.ids_2d,
        &scratch.tpe_buf,
        &mut scratch.gate_out_stacked,
        gate_up_args,
        top_k,
        max_per_expert,
        tokens,
    )?;
    end_stage(timer, ctx, "gate", &MOE_PREFILL_GATE_US, &MOE_PREFILL_GATE_CALLS);

    // --- Up projection ------------------------------------------------------
    let timer = begin_stage(ctx);
    moe_layer.experts.gemm_up(
        ctx,
        &scratch.norm_out,
        &scratch.ids_2d,
        &scratch.tpe_buf,
        &mut scratch.up_out_stacked,
        gate_up_args,
        top_k,
        max_per_expert,
        tokens,
    )?;
    end_stage(timer, ctx, "up", &MOE_PREFILL_UP_US, &MOE_PREFILL_UP_CALLS);

    // --- SiLU-multiply over every (token, expert) pair ----------------------
    let total_pairs = tokens * top_k;
    let timer = begin_stage(ctx);
    B::silu_mul_batched(
        ctx,
        &scratch.gate_out_stacked,
        &scratch.up_out_stacked,
        &mut scratch.silu_stacked,
        total_pairs,
        inter,
    )?;
    end_stage(timer, ctx, "silu", &MOE_PREFILL_SILU_US, &MOE_PREFILL_SILU_CALLS);

    // --- Down projection ----------------------------------------------------
    let timer = begin_stage(ctx);
    moe_layer.experts.gemm_down(
        ctx,
        &scratch.silu_stacked,
        &scratch.ids_2d,
        &scratch.tpe_buf,
        &mut scratch.down_out_stacked,
        down_args,
        top_k,
        max_per_expert,
        tokens,
    )?;
    end_stage(timer, ctx, "down", &MOE_PREFILL_DOWN_US, &MOE_PREFILL_DOWN_CALLS);

    // --- Weighted sum of expert outputs into `moe_out` ----------------------
    let timer = begin_stage(ctx);
    B::weighted_sum_batched(
        ctx,
        &scratch.down_out_stacked,
        &scratch.weights_2d,
        &mut scratch.moe_out,
        tokens,
        top_k,
        h,
    )?;
    end_stage(timer, ctx, "wsum", &MOE_PREFILL_WSUM_US, &MOE_PREFILL_WSUM_CALLS);

    Ok(())
}
/// Runs the MoE feed-forward for a batch of `tokens` decode tokens.
///
/// Stages, in order: top-k routing over `scratch.router_logits`, the expert
/// gate/up projections plus SiLU-multiply (one fused call when
/// `B::supports_batched_moe_gate_up_silu()` is true, three separate calls
/// otherwise), the batched down projection, and a weighted sum of the expert
/// outputs into `scratch.moe_out`.
///
/// Each stage is bracketed by a probe timer that is active only when
/// `FERRUM_DECODE_OP_PROFILE` is set; elapsed microseconds are accumulated
/// into the matching `MOE_BATCHED_DECODE_*_US` counter.
///
/// NOTE(review): unlike `moe_forward_stacked_decode_impl`, this path does not
/// consult the `FERRUM_MOE_FUSED_GATE_UP_SILU` knob — confirm that is intended.
///
/// # Errors
///
/// Propagates the first error returned by any backend or expert kernel call.
#[allow(clippy::too_many_arguments)]
pub(crate) fn moe_forward_batched_decode_impl<B: QuantLlmBackend + BackendMoeFused>(
    ctx: &mut B::Context,
    moe_layer: &Qwen3MoeLayerState<B>,
    scratch: &mut Qwen3MoeScratch<B>,
    h: usize,
    inter: usize,
    top_k: usize,
    n_exp: usize,
    norm_topk_prob: bool,
    tokens: usize,
) -> Result<()> {
    use ferrum_kernels::backend::timer::{finish_probe_timer_traced, start_probe_timer_if};
    use ferrum_kernels::backend::Backend;
    use std::sync::atomic::Ordering;

    let profiling = moe_forward_runtime_config().decode_op_profile;
    let begin_stage = |ctx: &mut B::Context| -> Option<<B as Backend>::Timer> {
        start_probe_timer_if::<B>(profiling, ctx)
    };
    let end_stage = |timer: Option<<B as Backend>::Timer>,
                     ctx: &mut B::Context,
                     name: &str,
                     total_us: &AtomicU64| {
        if let Some(elapsed) =
            finish_probe_timer_traced::<B>(timer, ctx, name, "moe_batched", 0)
        {
            total_us.fetch_add(elapsed, Ordering::Relaxed);
        }
    };

    // Number of (token, selected expert) pairs handled by this call.
    let total_pairs = tokens * top_k;

    // --- Routing ------------------------------------------------------------
    let timer = begin_stage(ctx);
    B::route_topk_softmax(
        ctx,
        &scratch.router_logits,
        &mut scratch.selected_ids_buf,
        &mut scratch.weights_2d,
        tokens,
        n_exp,
        top_k,
        norm_topk_prob,
    )?;
    end_stage(timer, ctx, "route", &MOE_BATCHED_DECODE_ROUTE_US);

    // --- Gate / up / SiLU ---------------------------------------------------
    // The trailing `h, 0` arguments are presumably the per-token and per-pair
    // input strides (every pair of a token reads the same `norm_out` row) —
    // TODO confirm against the `gemv_*_batched` signatures.
    if B::supports_batched_moe_gate_up_silu() {
        let timer = begin_stage(ctx);
        moe_layer.experts.gemv_gate_up_silu_batched_fused(
            ctx,
            &scratch.norm_out,
            &scratch.selected_ids_buf,
            &mut scratch.silu_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        end_stage(timer, ctx, "silu", &MOE_BATCHED_DECODE_SILU_US);
    } else {
        let timer = begin_stage(ctx);
        moe_layer.experts.gemv_gate_batched(
            ctx,
            &scratch.norm_out,
            &scratch.selected_ids_buf,
            &mut scratch.gate_out_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        end_stage(timer, ctx, "gate", &MOE_BATCHED_DECODE_GATE_US);

        let timer = begin_stage(ctx);
        moe_layer.experts.gemv_up_batched(
            ctx,
            &scratch.norm_out,
            &scratch.selected_ids_buf,
            &mut scratch.up_out_stacked,
            tokens,
            top_k,
            h,
            0,
        )?;
        end_stage(timer, ctx, "up", &MOE_BATCHED_DECODE_UP_US);

        let timer = begin_stage(ctx);
        B::silu_mul_batched(
            ctx,
            &scratch.gate_out_stacked,
            &scratch.up_out_stacked,
            &mut scratch.silu_stacked,
            total_pairs,
            inter,
        )?;
        end_stage(timer, ctx, "silu", &MOE_BATCHED_DECODE_SILU_US);
    }

    // --- Down projection ----------------------------------------------------
    // Here the trailing arguments are `top_k * inter` and `inter`; presumably
    // the same stride pair as above, now over `silu_stacked` — TODO confirm.
    let timer = begin_stage(ctx);
    moe_layer.experts.gemv_down_batched(
        ctx,
        &scratch.silu_stacked,
        &scratch.selected_ids_buf,
        &mut scratch.down_out_stacked,
        tokens,
        top_k,
        top_k * inter,
        inter,
    )?;
    end_stage(timer, ctx, "down", &MOE_BATCHED_DECODE_DOWN_US);

    // --- Weighted sum of expert outputs into `moe_out` ----------------------
    let timer = begin_stage(ctx);
    B::weighted_sum_batched(
        ctx,
        &scratch.down_out_stacked,
        &scratch.weights_2d,
        &mut scratch.moe_out,
        tokens,
        top_k,
        h,
    )?;
    end_stage(timer, ctx, "wsum", &MOE_BATCHED_DECODE_WSUM_US);

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::MoeForwardRuntimeConfig;

    /// Every value-based knob set to its "non-default" value flips its flag.
    #[test]
    fn moe_forward_runtime_config_parses_startup_knobs() {
        let config = MoeForwardRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_FUSED_GATE_UP_SILU", "0"),
            ("FERRUM_MOE_HOST_TOPK", "1"),
            ("FERRUM_MOE_DIRECT_DISPATCH", "1"),
        ]);
        assert!(config.fused_gate_up_silu_disabled);
        assert!(config.moe_host_topk);
        assert!(config.moe_direct_dispatch);
    }

    /// Knobs explicitly set to their default value leave the fast paths on.
    #[test]
    fn moe_forward_runtime_config_keeps_default_fast_paths() {
        let config = MoeForwardRuntimeConfig::from_env_vars([
            ("FERRUM_MOE_FUSED_GATE_UP_SILU", "1"),
            ("FERRUM_MOE_HOST_TOPK", "0"),
            ("FERRUM_MOE_DIRECT_DISPATCH", "0"),
        ]);
        assert!(!config.fused_gate_up_silu_disabled);
        assert!(!config.moe_host_topk);
        assert!(!config.moe_direct_dispatch);
    }

    /// With no knobs at all (and unrelated variables ignored) every flag is off.
    #[test]
    fn moe_forward_runtime_config_defaults_when_no_knobs_are_set() {
        let expected = MoeForwardRuntimeConfig {
            decode_op_profile: false,
            fused_gate_up_silu_disabled: false,
            moe_host_topk: false,
            moe_direct_dispatch: false,
        };
        let empty = MoeForwardRuntimeConfig::from_env_vars(std::iter::empty::<(&str, &str)>());
        assert_eq!(empty, expected);
        let unrelated = MoeForwardRuntimeConfig::from_env_vars([("PATH", "/usr/bin")]);
        assert_eq!(unrelated, expected);
    }

    /// `FERRUM_DECODE_OP_PROFILE` is a presence-only switch: any value,
    /// including an empty one or "0", turns profiling on.
    #[test]
    fn moe_forward_runtime_config_decode_op_profile_is_presence_based() {
        for value in ["1", "0", ""] {
            let config =
                MoeForwardRuntimeConfig::from_env_vars([("FERRUM_DECODE_OP_PROFILE", value)]);
            assert!(config.decode_op_profile);
            assert!(!config.fused_gate_up_silu_disabled);
            assert!(!config.moe_host_topk);
            assert!(!config.moe_direct_dispatch);
        }
    }
}