use cudarc::driver::{CudaStream, LaunchConfig, PushKernelArg};
use ferrum_bench_core::{global_profile, profile_fields_from_json};
use ferrum_types::{FerrumError, Result};
use half::f16;
use std::sync::{Arc, OnceLock};
use super::CudaBackend;
use crate::backend::{Backend, BackendMoeFused, CudaBuf};
use crate::ptx;
/// Runtime switches for the MoE routing debug dump, taken from the process environment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MoeDumpRuntimeConfig {
    /// True when `FERRUM_MOE_DUMP` is present in the environment (any value, even empty).
    enabled: bool,
    /// Parsed value of `FERRUM_MOE_DUMP_BATCH_X_TOPK`; `None` when unset or not a `usize`.
    batch_x_topk_filter: Option<usize>,
}

/// Returns the process-wide dump configuration, reading the environment on first use only.
///
/// The two variables are looked up directly instead of iterating `std::env::vars()`:
/// that iterator panics as soon as *any* entry in the environment is not valid Unicode,
/// which would let an unrelated variable crash a purely diagnostic code path.
///
/// * `FERRUM_MOE_DUMP` - presence alone enables dumping; the value is ignored.
/// * `FERRUM_MOE_DUMP_BATCH_X_TOPK` - optional `usize`; unparsable or non-Unicode values
///   are treated as "no filter".
fn moe_dump_runtime_config() -> &'static MoeDumpRuntimeConfig {
    static CONFIG: OnceLock<MoeDumpRuntimeConfig> = OnceLock::new();
    CONFIG.get_or_init(|| MoeDumpRuntimeConfig {
        enabled: std::env::var_os("FERRUM_MOE_DUMP").is_some(),
        batch_x_topk_filter: std::env::var("FERRUM_MOE_DUMP_BATCH_X_TOPK")
            .ok()
            .and_then(|value| value.parse::<usize>().ok()),
    })
}
/// Emits a one-shot debug dump of the MoE block-alignment outputs.
///
/// The dump is skipped unless `moe_dump_runtime_config()` reports it enabled and, when a
/// `batch_x_topk` filter is configured, `batch_x_topk` equals it. Only the first call that
/// passes those checks dumps anything: the `DUMPED` flag is set before the device reads,
/// so every later call in the process is suppressed even if those reads fail.
///
/// The routing buffers are copied device-to-host on `stream` (which is synchronized after
/// each copy), summarized on stderr and, when the global profile is enabled, also pushed
/// as a `moe_dump` profile event.
///
/// * `kind` - label embedded in the log prefix and in the profile event.
/// * `sorted_token_ids`, `block_ids`, `total_tokens_post_pad` - device buffers read as `i32`.
/// * `batch_x_topk`, `num_experts`, `block_size` - reported alongside the buffers;
///   `block_size` also converts the padded token count into a block count.
///
/// This is best-effort diagnostics: failed copies degrade to empty previews, and a
/// `block_size` of zero yields zero active blocks instead of a divide-by-zero panic.
fn maybe_dump_moe_routing(
    kind: &str,
    stream: &Arc<CudaStream>,
    sorted_token_ids: &CudaBuf,
    block_ids: &CudaBuf,
    total_tokens_post_pad: &CudaBuf,
    batch_x_topk: usize,
    num_experts: usize,
    block_size: usize,
) {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Number of leading `sorted_token_ids` entries shown in the dump.
    const SORTED_PREVIEW_LEN: usize = 48;
    /// Number of leading `block_ids` entries shown in the dump.
    const BLOCK_PREVIEW_LEN: usize = 32;

    let config = moe_dump_runtime_config();
    if !config.enabled {
        return;
    }
    if config
        .batch_x_topk_filter
        .is_some_and(|target| target != batch_x_topk)
    {
        return;
    }

    // One dump per process: whoever flips the flag first proceeds, everyone else bails.
    static DUMPED: AtomicBool = AtomicBool::new(false);
    if DUMPED.swap(true, Ordering::Relaxed) {
        return;
    }

    // Copies up to `len` leading i32 values of `buf` to the host; any failure yields an
    // empty vector so the dump never turns into an error for the caller.
    let read_i32 = |buf: &CudaBuf, len: usize| -> Vec<i32> {
        let n = len.min(buf.len());
        if n == 0 {
            return Vec::new();
        }
        let view = buf.as_i32().slice(0..n);
        let mut host = vec![0i32; n];
        if stream.memcpy_dtoh(&view, host.as_mut_slice()).is_err() {
            return Vec::new();
        }
        if stream.synchronize().is_err() {
            return Vec::new();
        }
        host
    };

    // Only the preview prefix of `sorted_token_ids` is ever printed, so only that much is
    // copied. `block_ids` is read in full because the unique-expert count walks every
    // active block.
    let st = read_i32(sorted_token_ids, SORTED_PREVIEW_LEN);
    let bi = read_i32(block_ids, block_ids.len());
    let tp = read_i32(total_tokens_post_pad, 1);

    // -1 marks "could not read the padded token count".
    let total_post_pad = tp.first().copied().unwrap_or(-1);
    let total_blocks = if total_post_pad > 0 {
        // `checked_div` keeps a zero `block_size` from panicking inside a debug helper.
        (total_post_pad as usize)
            .checked_div(block_size)
            .map_or(0, |blocks| blocks.min(bi.len()))
    } else {
        0
    };

    // Count distinct expert ids among the active blocks, ignoring negative or
    // out-of-range entries.
    let mut seen = vec![false; num_experts];
    let mut unique_experts = 0usize;
    for &expert_id in bi.iter().take(total_blocks) {
        if expert_id >= 0 {
            let expert_idx = expert_id as usize;
            if expert_idx < seen.len() && !seen[expert_idx] {
                seen[expert_idx] = true;
                unique_experts += 1;
            }
        }
    }

    let n_show = SORTED_PREVIEW_LEN.min(st.len());
    let n_bi = BLOCK_PREVIEW_LEN.min(bi.len());
    eprintln!(
        "[MOE_DUMP:{kind}] batch_x_topk={batch_x_topk} block_size={block_size} \
         num_experts={num_experts} total_post_pad={total_post_pad} \
         active_blocks={total_blocks} unique_experts={unique_experts}",
    );
    eprintln!(
        "[MOE_DUMP:{kind}] sorted_token_ids[0..{n_show}] = {:?}",
        &st[..n_show]
    );
    eprintln!("[MOE_DUMP:{kind}] block_ids[0..{n_bi}] = {:?}", &bi[..n_bi]);

    let profile = global_profile();
    if profile.is_enabled() {
        // The result of pushing the event is deliberately ignored (best-effort).
        let _ = profile.push_event(
            "moe_dump",
            profile_fields_from_json(serde_json::json!({
                "kind": kind,
                "batch_x_topk": batch_x_topk,
                "block_size": block_size,
                "num_experts": num_experts,
                "total_post_pad": total_post_pad,
                "active_blocks": total_blocks,
                "unique_experts": unique_experts,
                "sorted_token_ids_preview": &st[..n_show],
                "block_ids_preview": &bi[..n_bi],
            })),
            profile_fields_from_json(serde_json::json!({})),
            false,
        );
    }
}
impl BackendMoeFused for CudaBackend {
/// Launches the `moe_router_topk_softmax_f16` kernel on the context's stream.
///
/// The launch uses a grid of `batch` blocks with 32 threads each and requests
/// `num_experts * 4` bytes of dynamic shared memory. `batch`, `num_experts`, `top_k`
/// and `norm_topk_prob` (as 0/1) are passed to the kernel as 32-bit integers after the
/// three buffers `logits`, `out_ids` and `out_weights`.
///
/// # Errors
/// Returns a model error when the kernel launch is rejected.
fn route_topk_softmax(
    ctx: &mut Self::Context,
    logits: &Self::Buffer,
    out_ids: &mut Self::Buffer,
    out_weights: &mut Self::Buffer,
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
) -> Result<()> {
    let kernel = ctx.func(
        "moe_router_topk_softmax",
        ptx::MOE_ROUTER,
        "moe_router_topk_softmax_f16",
    );

    // Scalar arguments are handed to the kernel as i32.
    let rows = batch as i32;
    let experts = num_experts as i32;
    let k = top_k as i32;
    let normalize = i32::from(norm_topk_prob);

    let launch_cfg = LaunchConfig {
        grid_dim: (batch as u32, 1, 1),
        block_dim: (32, 1, 1),
        // 4 bytes of dynamic shared memory per expert.
        shared_mem_bytes: (num_experts as u32) * 4,
    };

    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(logits);
    args.arg(out_ids);
    args.arg(out_weights);
    args.arg(&rows);
    args.arg(&experts);
    args.arg(&k);
    args.arg(&normalize);

    // SAFETY: relies on the caller passing device buffers laid out the way the kernel
    // expects for these sizes; nothing here verifies that.
    unsafe { args.launch(launch_cfg) }
        .map(|_| ())
        .map_err(|e| FerrumError::model(format!("moe_router launch: {e}")))
}
/// Runs the top-k softmax router kernel and copies its result into host vectors.
///
/// Scratch buffers cached on `ctx` (`moe_route_ids` / `moe_route_weights`) receive the
/// kernel output; they are reallocated whenever `batch * top_k` exceeds
/// `ctx.moe_route_capacity`. The scratch is allocated as `f16` with two elements per
/// pair, i.e. 4 bytes per pair, and is reinterpreted as `i32` ids and `f32` weights for
/// the device-to-host copies.
///
/// On success `out_ids_host` and `out_weights_host` each hold exactly `batch * top_k`
/// entries. When there are no pairs (`batch == 0` or `top_k == 0`) both vectors are
/// emptied and no kernel is launched; previously that case could hit the
/// "should be allocated" panic because the capacity check never allocates for zero pairs.
///
/// # Errors
/// Returns a model error if scratch allocation, the kernel launch, a device-to-host copy
/// or the final stream synchronization fails.
///
/// # Panics
/// Panics if a scratch buffer is missing although `moe_route_capacity` says it was
/// allocated (an internal invariant violation).
fn try_gpu_route_topk_into_host(
    ctx: &mut Self::Context,
    logits_dev: &Self::Buffer,
    out_ids_host: &mut Vec<u32>,
    out_weights_host: &mut Vec<f32>,
    batch: usize,
    num_experts: usize,
    top_k: usize,
    norm_topk_prob: bool,
) -> Result<()> {
    let total_pairs = batch * top_k;

    // Empty routing request: nothing to compute, and a zero-sized launch/scratch lookup
    // must not be attempted.
    if total_pairs == 0 {
        out_ids_host.clear();
        out_weights_host.clear();
        return Ok(());
    }

    // Grow the cached scratch buffers when this request is larger than any seen so far.
    if ctx.moe_route_capacity < total_pairs {
        let stream = ctx.stream.clone();
        // Two f16 per pair == 4 bytes per pair, enough for one i32 id or one f32 weight.
        let nf16 = 2 * total_pairs;
        ctx.moe_route_ids = Some(
            stream
                .alloc_zeros::<f16>(nf16)
                .map_err(|e| FerrumError::model(format!("alloc moe_route_ids: {e}")))?,
        );
        ctx.moe_route_weights = Some(
            stream
                .alloc_zeros::<f16>(nf16)
                .map_err(|e| FerrumError::model(format!("alloc moe_route_weights: {e}")))?,
        );
        // Capacity is only bumped after both allocations succeeded.
        ctx.moe_route_capacity = total_pairs;
    }

    let func = ctx.func(
        "moe_router_topk_softmax",
        ptx::MOE_ROUTER,
        "moe_router_topk_softmax_f16",
    );
    let batch_i32 = batch as i32;
    let n_exp_i32 = num_experts as i32;
    let top_k_i32 = top_k as i32;
    let norm_i32 = i32::from(norm_topk_prob);
    // 4 bytes of dynamic shared memory per expert.
    let smem_bytes = (num_experts as u32) * 4;
    let stream = ctx.stream.clone();

    // Scope the mutable scratch borrows to the launch so they can be re-borrowed
    // immutably for the copies below.
    {
        let ids_dev = ctx
            .moe_route_ids
            .as_mut()
            .expect("moe_route_ids should be allocated");
        let weights_dev = ctx
            .moe_route_weights
            .as_mut()
            .expect("moe_route_weights should be allocated");
        let mut b = stream.launch_builder(&func);
        b.arg(logits_dev);
        b.arg(ids_dev);
        b.arg(weights_dev);
        b.arg(&batch_i32);
        b.arg(&n_exp_i32);
        b.arg(&top_k_i32);
        b.arg(&norm_i32);
        // SAFETY: both scratch buffers hold at least `total_pairs * 4` bytes (see the
        // allocation above); the layout of `logits_dev` is the caller's responsibility
        // and is not verified here.
        unsafe {
            b.launch(LaunchConfig {
                grid_dim: (batch as u32, 1, 1),
                block_dim: (32, 1, 1),
                shared_mem_bytes: smem_bytes,
            })
        }
        .map_err(|e| FerrumError::model(format!("moe_router launch: {e}")))?;
    }

    // Size the host outputs to exactly one entry per (token, expert) pair.
    out_ids_host.clear();
    out_ids_host.resize(total_pairs, 0u32);
    out_weights_host.clear();
    out_weights_host.resize(total_pairs, 0.0f32);

    let ids_dev = ctx
        .moe_route_ids
        .as_ref()
        .expect("moe_route_ids should be allocated");
    let weights_dev = ctx
        .moe_route_weights
        .as_ref()
        .expect("moe_route_weights should be allocated");

    // SAFETY: the scratch was allocated with 4 bytes per pair, and every bit pattern is
    // a valid `i32`.
    let ids_view = unsafe {
        ids_dev
            .transmute::<i32>(total_pairs)
            .ok_or_else(|| FerrumError::model("ids transmute size mismatch"))?
    };
    // SAFETY: same sizing argument as above; the bytes are interpreted as `f32`.
    let weights_view = unsafe {
        weights_dev
            .transmute::<f32>(total_pairs)
            .ok_or_else(|| FerrumError::model("weights transmute size mismatch"))?
    };
    // SAFETY: `out_ids_host` was just resized to exactly `total_pairs` elements, `u32`
    // and `i32` share size and alignment, and the vector is not touched again while this
    // reinterpreted slice is alive.
    let out_ids_i32: &mut [i32] = unsafe {
        std::slice::from_raw_parts_mut(out_ids_host.as_mut_ptr() as *mut i32, total_pairs)
    };

    stream
        .memcpy_dtoh(&ids_view, out_ids_i32)
        .map_err(|e| FerrumError::model(format!("dtoh route ids: {e}")))?;
    stream
        .memcpy_dtoh(&weights_view, out_weights_host.as_mut_slice())
        .map_err(|e| FerrumError::model(format!("dtoh route weights: {e}")))?;
    // Wait for the stream before handing the host vectors back to the caller.
    stream
        .synchronize()
        .map_err(|e| FerrumError::model(format!("dtoh sync: {e}")))?;
    Ok(())
}
/// Launches the `moe_build_pairs_by_token` kernel as a single block of 256 threads.
///
/// The kernel receives `expert_ids`, `pairs_by_token`, `packed_token_idx` and
/// `expert_offsets`, followed by `batch_x_topk`, `num_experts` and `top_k` as 32-bit
/// integers, and is given `num_experts * 4` bytes of dynamic shared memory.
///
/// # Errors
/// Returns a model error when `num_experts` exceeds 256 or when the launch fails.
fn moe_build_pairs_by_token(
    ctx: &mut Self::Context,
    expert_ids: &Self::Buffer,
    pairs_by_token: &mut Self::Buffer,
    packed_token_idx: &mut Self::Buffer,
    expert_offsets: &mut Self::Buffer,
    batch_x_topk: usize,
    num_experts: usize,
    top_k: usize,
) -> Result<()> {
    const EXPERT_LIMIT: usize = 256;
    if num_experts > EXPERT_LIMIT {
        let message = format!(
            "moe_build_pairs_by_token: num_experts={num_experts} > MAX 256 (shmem limit)"
        );
        return Err(FerrumError::model(message));
    }

    let kernel = ctx.func(
        "moe_build_pairs_by_token",
        ptx::MOE_BUILD_PAIRS,
        "moe_build_pairs_by_token",
    );

    // Scalar arguments are handed to the kernel as i32.
    let pair_count = batch_x_topk as i32;
    let expert_count = num_experts as i32;
    let k = top_k as i32;

    let launch_cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (256, 1, 1),
        // 4 bytes of dynamic shared memory per expert.
        shared_mem_bytes: (num_experts as u32) * 4,
    };

    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(expert_ids);
    args.arg(pairs_by_token);
    args.arg(packed_token_idx);
    args.arg(expert_offsets);
    args.arg(&pair_count);
    args.arg(&expert_count);
    args.arg(&k);

    // SAFETY: relies on the caller passing device buffers sized the way the kernel
    // expects; nothing here verifies that.
    unsafe { args.launch(launch_cfg) }
        .map(|_| ())
        .map_err(|e| FerrumError::model(format!("moe_build_pairs_by_token launch: {e}")))
}
/// Launches the `moe_align_block_size_f32` kernel and optionally dumps its outputs.
///
/// The kernel runs as a single block of 256 threads without dynamic shared memory. It
/// receives `expert_ids_per_pair`, `sorted_token_ids`, `block_ids` and
/// `total_tokens_post_pad`, followed by `batch_x_topk`, `num_experts`, `block_size` and
/// `sorted_max_size` as 32-bit integers. After a successful launch the three output
/// buffers are handed to `maybe_dump_moe_routing` under the label `"packed"`.
///
/// # Errors
/// Returns a model error when `num_experts` exceeds 256 or when the launch fails.
fn moe_align_block_size(
    ctx: &mut Self::Context,
    expert_ids_per_pair: &Self::Buffer,
    sorted_token_ids: &mut Self::Buffer,
    block_ids: &mut Self::Buffer,
    total_tokens_post_pad: &mut Self::Buffer,
    batch_x_topk: usize,
    num_experts: usize,
    block_size: usize,
    sorted_max_size: usize,
) -> Result<()> {
    const EXPERT_LIMIT: usize = 256;
    if num_experts > EXPERT_LIMIT {
        let message = format!(
            "moe_align_block_size: num_experts={num_experts} exceeds compile-time MAX_NUM_EXPERTS=256"
        );
        return Err(FerrumError::model(message));
    }

    let kernel = ctx.func(
        "moe_align_block_size",
        ptx::MOE_ALIGN_BLOCK_SIZE,
        "moe_align_block_size_f32",
    );

    // Scalar arguments are handed to the kernel as i32.
    let pair_count = batch_x_topk as i32;
    let expert_count = num_experts as i32;
    let tokens_per_block = block_size as i32;
    let sorted_capacity = sorted_max_size as i32;

    let launch_cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = ctx.stream.clone();
    {
        // Reborrow the output buffers so they stay usable for the dump afterwards.
        let mut args = stream.launch_builder(&kernel);
        args.arg(&*expert_ids_per_pair);
        args.arg(&mut *sorted_token_ids);
        args.arg(&mut *block_ids);
        args.arg(&mut *total_tokens_post_pad);
        args.arg(&pair_count);
        args.arg(&expert_count);
        args.arg(&tokens_per_block);
        args.arg(&sorted_capacity);

        // SAFETY: relies on the caller passing device buffers sized the way the kernel
        // expects; nothing here verifies that.
        unsafe { args.launch(launch_cfg) }
            .map_err(|e| FerrumError::model(format!("moe_align_block_size launch: {e}")))?;
    }

    maybe_dump_moe_routing(
        "packed",
        &stream,
        sorted_token_ids,
        block_ids,
        total_tokens_post_pad,
        batch_x_topk,
        num_experts,
        block_size,
    );
    Ok(())
}
/// Launches the `moe_align_block_size_pair_ids_f32` kernel and optionally dumps its outputs.
///
/// The kernel runs as a single block of 256 threads without dynamic shared memory. It
/// receives `expert_ids_per_pair`, `sorted_token_ids`, `block_ids` and
/// `total_tokens_post_pad`, followed by `batch_x_topk`, `num_experts`, `block_size` and
/// `sorted_max_size` as 32-bit integers. After a successful launch the three output
/// buffers are handed to `maybe_dump_moe_routing` under the label `"pair_ids"`.
///
/// # Errors
/// Returns a model error when `num_experts` exceeds 256 or when the launch fails.
fn moe_align_block_size_pair_ids(
    ctx: &mut Self::Context,
    expert_ids_per_pair: &Self::Buffer,
    sorted_token_ids: &mut Self::Buffer,
    block_ids: &mut Self::Buffer,
    total_tokens_post_pad: &mut Self::Buffer,
    batch_x_topk: usize,
    num_experts: usize,
    block_size: usize,
    sorted_max_size: usize,
) -> Result<()> {
    const EXPERT_LIMIT: usize = 256;
    if num_experts > EXPERT_LIMIT {
        let message = format!(
            "moe_align_block_size_pair_ids: num_experts={num_experts} exceeds compile-time MAX_NUM_EXPERTS=256"
        );
        return Err(FerrumError::model(message));
    }

    let kernel = ctx.func(
        "moe_align_block_size_pair_ids",
        ptx::MOE_ALIGN_BLOCK_SIZE_PAIR_IDS,
        "moe_align_block_size_pair_ids_f32",
    );

    // Scalar arguments are handed to the kernel as i32.
    let pair_count = batch_x_topk as i32;
    let expert_count = num_experts as i32;
    let tokens_per_block = block_size as i32;
    let sorted_capacity = sorted_max_size as i32;

    let launch_cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (256, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = ctx.stream.clone();
    // Reborrow the output buffers so they stay usable for the dump afterwards.
    let mut args = stream.launch_builder(&kernel);
    args.arg(&*expert_ids_per_pair);
    args.arg(&mut *sorted_token_ids);
    args.arg(&mut *block_ids);
    args.arg(&mut *total_tokens_post_pad);
    args.arg(&pair_count);
    args.arg(&expert_count);
    args.arg(&tokens_per_block);
    args.arg(&sorted_capacity);

    // SAFETY: relies on the caller passing device buffers sized the way the kernel
    // expects; nothing here verifies that.
    unsafe { args.launch(launch_cfg) }
        .map_err(|e| FerrumError::model(format!("moe_align_block_size_pair_ids launch: {e}")))?;

    maybe_dump_moe_routing(
        "pair_ids",
        &stream,
        sorted_token_ids,
        block_ids,
        total_tokens_post_pad,
        batch_x_topk,
        num_experts,
        block_size,
    );
    Ok(())
}
/// Launches the `moe_combine_f16` kernel on the context's stream.
///
/// The grid is `ceil(hidden / 256)` by `batch` blocks of 256 threads, with no dynamic
/// shared memory. The kernel receives `packed_down`, `pairs_by_token`, `pair_weights`
/// and `out`, followed by `batch`, `hidden` and `top_k` as 32-bit integers.
/// `_total_pairs` is accepted for interface compatibility and is not used.
///
/// When `batch` or `hidden` is zero there is no output element to produce, so the call
/// returns without launching. A launch with a zero-sized grid dimension is rejected by
/// the driver, which would otherwise turn an empty batch into a panic here.
///
/// # Panics
/// Panics if the kernel launch fails; this signature has no way to report an error.
fn moe_combine(
    ctx: &mut Self::Context,
    packed_down: &Self::Buffer,
    pairs_by_token: &Self::Buffer,
    pair_weights: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    hidden: usize,
    top_k: usize,
    _total_pairs: usize,
) {
    // Empty output: skip the launch instead of tripping the `expect` below.
    if batch == 0 || hidden == 0 {
        return;
    }

    let func = ctx.func("moe_combine", ptx::MOE_COMBINE, "moe_combine_f16");

    // Scalar arguments are handed to the kernel as i32.
    let batch_i32 = batch as i32;
    let hidden_i32 = hidden as i32;
    let top_k_i32 = top_k as i32;

    // Enough 256-thread blocks along x to cover `hidden`.
    let block = 256u32;
    let grid_x = ((hidden as u32) + block - 1) / block;

    let stream = ctx.stream.clone();
    let mut b = stream.launch_builder(&func);
    b.arg(packed_down);
    b.arg(pairs_by_token);
    b.arg(pair_weights);
    b.arg(out);
    b.arg(&batch_i32);
    b.arg(&hidden_i32);
    b.arg(&top_k_i32);

    // SAFETY: relies on the caller passing device buffers sized the way the kernel
    // expects for these dimensions; nothing here verifies that.
    unsafe {
        b.launch(LaunchConfig {
            grid_dim: (grid_x, batch as u32, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        })
    }
    .expect("moe_combine launch");
}
/// Launches the `weighted_sum_batched_f16` kernel on the context's stream.
///
/// The grid is `ceil(hidden / 256)` by `batch` blocks of 256 threads, with no dynamic
/// shared memory. The kernel receives `slots`, `weights` and `out`, followed by `batch`,
/// `top_k` and `hidden` as 32-bit integers.
///
/// # Errors
/// Returns a model error when the kernel launch is rejected.
fn weighted_sum_batched(
    ctx: &mut Self::Context,
    slots: &Self::Buffer,
    weights: &Self::Buffer,
    out: &mut Self::Buffer,
    batch: usize,
    top_k: usize,
    hidden: usize,
) -> Result<()> {
    let kernel = ctx.func(
        "weighted_sum_batched",
        ptx::MOE_COMBINE,
        "weighted_sum_batched_f16",
    );

    // Scalar arguments are handed to the kernel as i32.
    let rows = batch as i32;
    let k = top_k as i32;
    let width = hidden as i32;

    // Enough 256-thread blocks along x to cover `hidden`.
    let threads_per_block = 256u32;
    let blocks_x = ((hidden as u32) + threads_per_block - 1) / threads_per_block;
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks_x, batch as u32, 1),
        block_dim: (threads_per_block, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = ctx.stream.clone();
    let mut args = stream.launch_builder(&kernel);
    args.arg(slots);
    args.arg(weights);
    args.arg(out);
    args.arg(&rows);
    args.arg(&k);
    args.arg(&width);

    // SAFETY: relies on the caller passing device buffers sized the way the kernel
    // expects for these dimensions; nothing here verifies that.
    unsafe { args.launch(launch_cfg) }
        .map(|_| ())
        .map_err(|e| FerrumError::model(format!("weighted_sum_batched launch: {e}")))
}
/// Copies three host-side `i32` routing arrays to the device and bundles them.
///
/// Each slice is uploaded with `clone_htod` on the context's stream, in the order
/// `sorted_token_ids`, `expert_ids`, `num_tokens_past_padded`, and wrapped via
/// `CudaBuf::from_i32` into the matching `MoeRouting` field.
///
/// # Errors
/// Returns a model error naming the array whose host-to-device copy failed; later
/// arrays are not uploaded once one copy fails.
#[cfg(feature = "vllm-moe-marlin")]
fn upload_moe_routing(
    ctx: &mut Self::Context,
    sorted_token_ids: &[i32],
    expert_ids: &[i32],
    num_tokens_past_padded: &[i32],
) -> Result<crate::backend::traits::MoeRouting<Self>> {
    use cudarc::driver::CudaSlice;

    let stream = ctx.stream.clone();

    // Uploads one host slice; `label` only feeds the error message.
    let to_device = |host: &[i32], label: &str| -> Result<CudaSlice<i32>> {
        stream
            .clone_htod(host)
            .map_err(|e| FerrumError::model(format!("htod {label}: {e}")))
    };

    let sorted_dev = to_device(sorted_token_ids, "sorted_token_ids")?;
    let experts_dev = to_device(expert_ids, "expert_ids")?;
    let padded_dev = to_device(num_tokens_past_padded, "num_tokens_past_padded")?;

    Ok(crate::backend::traits::MoeRouting {
        sorted_token_ids: crate::backend::CudaBuf::from_i32(sorted_dev),
        expert_ids: crate::backend::CudaBuf::from_i32(experts_dev),
        num_tokens_past_padded: crate::backend::CudaBuf::from_i32(padded_dev),
    })
}
}