//! WebGPU implementations of the MoE operators: top-k routing and grouped GEMM.

use crate::error::{Error, Result};
use crate::ops::impl_generic::architecture::moe::{
moe_permute_tokens_impl, moe_unpermute_tokens_impl,
};
use crate::ops::traits::architecture::moe::{MoEActivation, MoEOps};
use numr::dtype::DType;
use numr::runtime::wgpu::{WgpuClient, WgpuRuntime, get_buffer};
use numr::tensor::Tensor;
use wgpu::BufferUsages;
const MOE_ROUTING_SHADER: &str = include_str!("../shaders/architecture/moe_routing.wgsl");
const MOE_GROUPED_GEMM_SHADER: &str = include_str!("../shaders/architecture/moe_grouped_gemm.wgsl");
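/// Maps the requested activation to the matching WGSL entry point in the
/// grouped-GEMM shader.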
fn grouped_gemm_entry_point(activation: MoEActivation) -> &'static str {
match activation {
MoEActivation::None => "moe_grouped_gemm_f32",
MoEActivation::SiLU => "moe_grouped_gemm_silu_f32",
MoEActivation::GeLU => "moe_grouped_gemm_gelu_f32",
}
}
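/// Returns an error unless the tensor is F32; the WebGPU MoE shaders only
/// provide F32 variants.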
fn validate_f32(t: &numr::tensor::Tensor<WgpuRuntime>, op: &str) -> Result<()> {
if t.dtype() != DType::F32 {
return Err(Error::InvalidArgument {
arg: "dtype",
reason: format!("{}: WebGPU MoE requires F32, got {:?}", op, t.dtype()),
});
}
Ok(())
}
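/// Uniform parameters for the routing shader. The layout must mirror the WGSL
/// uniform struct, so the trailing field pads the size to 16 bytes.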
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MoERoutingParams {
num_tokens: u32,
num_experts: u32,
k: u32,
_pad: u32,
}
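/// Uniform parameters for the grouped-GEMM shader, padded to 16 bytes to match
/// the WGSL uniform struct.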
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MoEGemmParams {
in_dim: u32,
out_dim: u32,
num_experts: u32,
_pad: u32,
}
impl MoEOps<WgpuRuntime> for WgpuClient {
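    /// Selects the top-`k` experts for each token from a `[num_tokens, num_experts]`
    /// logits tensor, returning `(indices, weights)` of shape `[num_tokens, k]`
    /// (`I32` expert ids and `F32` routing weights).
    ///
    /// Usage sketch (assumes a `WgpuClient` named `client` and an F32 `logits`
    /// tensor already exist; not compiled as a doctest):
    ///
    /// ```ignore
    /// // Route each token to its two highest-scoring experts.
    /// let (indices, weights) = client.moe_top_k_routing(&logits, 2)?;
    /// // indices: [num_tokens, 2] expert ids, weights: [num_tokens, 2] gates.
    /// ```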
fn moe_top_k_routing(
&self,
logits: &Tensor<WgpuRuntime>,
k: usize,
) -> Result<(Tensor<WgpuRuntime>, Tensor<WgpuRuntime>)> {
validate_f32(logits, "moe_top_k_routing")?;
let shape = logits.shape();
if shape.len() != 2 {
return Err(Error::InvalidArgument {
arg: "logits",
reason: format!(
"expected 2D [num_tokens, num_experts], got {}D",
shape.len()
),
});
}
let num_tokens = shape[0];
let num_experts = shape[1];
if k == 0 || k > num_experts {
return Err(Error::InvalidArgument {
arg: "k",
reason: format!("k={} must be in [1, num_experts={}]", k, num_experts),
});
}
let out_indices =
Tensor::<WgpuRuntime>::empty(&[num_tokens, k], DType::I32, logits.device());
let out_weights =
Tensor::<WgpuRuntime>::empty(&[num_tokens, k], DType::F32, logits.device());
let logits_buf = get_buffer(logits.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "logits buffer not found".into(),
})?;
let indices_buf =
get_buffer(out_indices.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "indices buffer not found".into(),
})?;
let weights_buf =
get_buffer(out_weights.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "weights buffer not found".into(),
})?;
let params = MoERoutingParams {
num_tokens: num_tokens as u32,
num_experts: num_experts as u32,
k: k as u32,
_pad: 0,
};
let params_buf = self.wgpu_device().create_buffer(&wgpu::BufferDescriptor {
label: Some("moe_routing_params"),
size: std::mem::size_of::<MoERoutingParams>() as u64,
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.wgpu_queue()
            .write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));
let cache = self.pipeline_cache();
let module = cache.get_or_create_module("moe_routing_f32", MOE_ROUTING_SHADER);
let layout = cache.get_or_create_layout(numr::runtime::wgpu::shaders::LayoutKey {
num_storage_buffers: 3,
num_uniform_buffers: 1,
num_readonly_storage: 1,
});
let pipeline =
cache.get_or_create_pipeline("moe_routing_f32", "moe_routing_f32", &module, &layout);
let bind_group = cache.create_bind_group(
&layout,
            &[&logits_buf, &indices_buf, &weights_buf, &params_buf],
);
let mut encoder =
self.wgpu_device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("moe_routing"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("moe_routing"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
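            // One workgroup per token; the shader picks that token's top-k experts.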
pass.dispatch_workgroups(num_tokens as u32, 1, 1);
}
self.wgpu_queue().submit(std::iter::once(encoder.finish()));
Ok((out_indices, out_weights))
}
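    /// Delegates to the runtime-generic permutation: tokens are grouped by
    /// their assigned expert so each expert sees a contiguous slice.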
fn moe_permute_tokens(
&self,
tokens: &Tensor<WgpuRuntime>,
indices: &Tensor<WgpuRuntime>,
num_experts: usize,
) -> Result<(
Tensor<WgpuRuntime>,
Tensor<WgpuRuntime>,
Tensor<WgpuRuntime>,
)> {
moe_permute_tokens_impl(self, tokens, indices, num_experts)
}
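    /// Delegates to the runtime-generic unpermutation: expert outputs are
    /// scattered back to token order and combined with the routing weights.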
fn moe_unpermute_tokens(
&self,
expert_output: &Tensor<WgpuRuntime>,
sort_indices: &Tensor<WgpuRuntime>,
weights: &Tensor<WgpuRuntime>,
num_tokens: usize,
) -> Result<Tensor<WgpuRuntime>> {
moe_unpermute_tokens_impl(self, expert_output, sort_indices, weights, num_tokens)
}
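    /// Grouped GEMM with no fused activation.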
fn moe_grouped_gemm(
&self,
permuted_tokens: &Tensor<WgpuRuntime>,
expert_weights: &Tensor<WgpuRuntime>,
expert_offsets: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
launch_grouped_gemm_wgpu(
self,
permuted_tokens,
expert_weights,
expert_offsets,
grouped_gemm_entry_point(MoEActivation::None),
)
}
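    /// Grouped GEMM with the chosen activation fused into the kernel.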
fn moe_grouped_gemm_fused(
&self,
permuted_tokens: &Tensor<WgpuRuntime>,
expert_weights: &Tensor<WgpuRuntime>,
expert_offsets: &Tensor<WgpuRuntime>,
activation: MoEActivation,
) -> Result<Tensor<WgpuRuntime>> {
launch_grouped_gemm_wgpu(
self,
permuted_tokens,
expert_weights,
expert_offsets,
grouped_gemm_entry_point(activation),
)
}
}
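/// Shared launcher for the grouped-GEMM entry points.
///
/// Multiplies `permuted_tokens` (`[total_tokens, in_dim]`) by the per-expert
/// weight matrices (`[num_experts, in_dim, out_dim]`), producing a
/// `[total_tokens, out_dim]` output. `expert_offsets` tells the kernel which
/// contiguous rows belong to each expert, and `entry_point` selects the plain
/// or activation-fused kernel variant.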
fn launch_grouped_gemm_wgpu(
client: &WgpuClient,
permuted_tokens: &Tensor<WgpuRuntime>,
expert_weights: &Tensor<WgpuRuntime>,
expert_offsets: &Tensor<WgpuRuntime>,
entry_point: &'static str,
) -> Result<Tensor<WgpuRuntime>> {
validate_f32(permuted_tokens, "moe_grouped_gemm")?;
validate_f32(expert_weights, "moe_grouped_gemm")?;
let pt_shape = permuted_tokens.shape();
let ew_shape = expert_weights.shape();
if pt_shape.len() != 2 {
return Err(Error::InvalidArgument {
arg: "permuted_tokens",
reason: format!("expected 2D, got {}D", pt_shape.len()),
});
}
if ew_shape.len() != 3 {
return Err(Error::InvalidArgument {
arg: "expert_weights",
reason: format!("expected 3D, got {}D", ew_shape.len()),
});
}
if pt_shape[1] != ew_shape[1] {
return Err(Error::InvalidArgument {
arg: "expert_weights",
reason: format!(
"in_dim mismatch: tokens {}, weights {}",
pt_shape[1], ew_shape[1]
),
});
}
let total_tokens = pt_shape[0];
let in_dim = pt_shape[1];
let num_experts = ew_shape[0];
let out_dim = ew_shape[2];
let device = permuted_tokens.device();
let output = Tensor::<WgpuRuntime>::empty(&[total_tokens, out_dim], DType::F32, device);
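    // Nothing to do when no tokens were routed; return the empty output directly.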
if total_tokens == 0 {
return Ok(output);
}
let tokens_buf =
get_buffer(permuted_tokens.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "tokens buffer not found".into(),
})?;
let weights_buf =
get_buffer(expert_weights.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "weights buffer not found".into(),
})?;
let offsets_buf =
get_buffer(expert_offsets.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "offsets buffer not found".into(),
})?;
let output_buf = get_buffer(output.storage().ptr()).ok_or_else(|| Error::KernelError {
reason: "output buffer not found".into(),
})?;
let params = MoEGemmParams {
in_dim: in_dim as u32,
out_dim: out_dim as u32,
num_experts: num_experts as u32,
_pad: 0,
};
let params_buf = client.wgpu_device().create_buffer(&wgpu::BufferDescriptor {
label: Some("moe_gemm_params"),
size: std::mem::size_of::<MoEGemmParams>() as u64,
usage: BufferUsages::UNIFORM | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
client
.wgpu_queue()
        .write_buffer(&params_buf, 0, bytemuck::bytes_of(&params));
let cache = client.pipeline_cache();
let module = cache.get_or_create_module("moe_grouped_gemm", MOE_GROUPED_GEMM_SHADER);
let layout = cache.get_or_create_layout(numr::runtime::wgpu::shaders::LayoutKey {
num_storage_buffers: 4,
num_uniform_buffers: 1,
num_readonly_storage: 3,
});
let pipeline = cache.get_or_create_pipeline(entry_point, entry_point, &module, &layout);
let bind_group = cache.create_bind_group(
&layout,
&[
&tokens_buf,
&weights_buf,
&offsets_buf,
&output_buf,
            &params_buf,
],
);
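    // Dispatch one workgroup per 16x16 output tile: x tiles the output columns,
    // y tiles the token rows, and z walks the experts (the kernel consults
    // expert_offsets to restrict each expert to its own rows).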
const TILE: u32 = 16;
let grid_x = (out_dim as u32).div_ceil(TILE);
let grid_y = (total_tokens as u32).div_ceil(TILE);
let grid_z = num_experts as u32;
let mut encoder =
client
.wgpu_device()
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("moe_grouped_gemm"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("moe_grouped_gemm"),
timestamp_writes: None,
});
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, Some(&bind_group), &[]);
pass.dispatch_workgroups(grid_x, grid_y, grid_z);
}
client
.wgpu_queue()
.submit(std::iter::once(encoder.finish()));
Ok(output)
}