use bytemuck::{Pod, Zeroable};
use super::cached_dispatch;
use crate::backend::WgpuCtx;
use crate::backend::pipelines::Pipelines;
use crate::error::{Result, RullamaError};
use crate::gguf::GgmlDtype;
/// Parameter block handed to `cached_dispatch` by `diffusion_attention_chained`.
///
/// `#[repr(C)]` + `Pod` make the struct byte-copyable; its field order and total size
/// presumably mirror the shader-side params struct, so do not reorder or resize fields
/// without updating the WGSL source (NOTE(review): confirm against the shader).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct DiffusionAttnParams {
    n_tokens: u32,
    n_heads: u32,
    n_kv_heads: u32,
    head_dim: u32,
    prompt_len: u32,
    n_swa: u32,
    /// `bool` encoded as 0/1.
    swa_layer: u32,
    /// Explicit padding: rounds the struct up to 32 bytes (a multiple of 16).
    _pad: u32,
}
/// Records a `diffusion_attention` compute dispatch into `enc`.
///
/// The shape and sliding-window settings are packed into a `DiffusionAttnParams` block.
/// The function then dispatches `p.diffusion_attention` with `q`, `k`, `v`, `o` bound in
/// that order and a dispatch size of `(n_tokens, n_heads, 1)`.
///
/// `swa_layer` is forwarded as a 0/1 flag next to `n_swa`. Presumably it toggles
/// sliding-window masking for this layer (confirm against the shader).
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn diffusion_attention_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    q: &wgpu::Buffer,
    k: &wgpu::Buffer,
    v: &wgpu::Buffer,
    o: &wgpu::Buffer,
    n_tokens: usize,
    n_heads: usize,
    n_kv_heads: usize,
    head_dim: usize,
    prompt_len: usize,
    n_swa: usize,
    swa_layer: bool,
) {
    let params = DiffusionAttnParams {
        n_tokens: n_tokens as u32,
        n_heads: n_heads as u32,
        n_kv_heads: n_kv_heads as u32,
        head_dim: head_dim as u32,
        prompt_len: prompt_len as u32,
        n_swa: n_swa as u32,
        swa_layer: u32::from(swa_layer),
        _pad: 0,
    };
    // NOTE(review): wgpu's default `max_compute_workgroups_per_dimension` is 65535; larger
    // `n_tokens` / `n_heads` would exceed it unless the device limit was raised — confirm.
    cached_dispatch(
        ctx,
        enc,
        &p.diffusion_attention,
        "diffusion_attention",
        &[q, k, v, o],
        &params,
        (n_tokens as u32, n_heads as u32, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_router_chained`.
///
/// `#[repr(C)]` + `Pod` make the struct byte-copyable; its layout presumably mirrors the
/// shader-side params struct (NOTE(review): confirm before reordering fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeRouterParams {
    d_model: u32,
    n_experts: u32,
    top_k: u32,
    eps: f32,
    /// 1 when a real scale buffer is bound, 0 when the dummy buffer stands in.
    has_scale: u32,
    /// Explicit padding: rounds the struct up to 32 bytes (a multiple of 16).
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Records a single-position `moe_router` dispatch into `enc`.
///
/// Binds `x`, the optional `scale`, `router_w`, `expert_ids` and `expert_weights` in that
/// order and dispatches exactly one workgroup `(1, 1, 1)`. When `scale` is `None`, the
/// `dummy` buffer fills its binding slot and `has_scale` is 0. Presumably this is because
/// the pipeline always expects a buffer at that binding.
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_router_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    scale: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    router_w: &wgpu::Buffer,
    expert_ids: &wgpu::Buffer,
    expert_weights: &wgpu::Buffer,
    d_model: usize,
    n_experts: usize,
    top_k: usize,
    eps: f32,
) {
    let params = MoeRouterParams {
        d_model: d_model as u32,
        n_experts: n_experts as u32,
        top_k: top_k as u32,
        eps,
        has_scale: u32::from(scale.is_some()),
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.moe_router,
        "moe_router",
        &[
            x,
            scale.unwrap_or(dummy),
            router_w,
            expert_ids,
            expert_weights,
        ],
        &params,
        (1, 1, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_expert_matmul_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeExpertMatmulParams {
    k: u32,
    n: u32,
    slot: u32,
    /// Quant blocks in one expert slice: `(k / block_elems) * n`.
    slice_blocks: u32,
}
/// Records a `moe_expert_matmul` dispatch for a single expert slot.
///
/// The function selects the quantized pipeline that matches `dtype` and binds `w`, `ids`,
/// `x`, `y` in that order. The dispatch size is `(ceil(n / 64), 1, 1)`.
///
/// `slot` is forwarded to the kernel unchanged. Presumably it picks which entry of `ids`
/// names the expert whose weight slice is read (confirm against the shader).
///
/// `slice_blocks` is the number of quant blocks in one `n x k` expert slice.
///
/// # Errors
///
/// Returns `RullamaError::Inference` in any of these cases:
/// - `dtype` is not Q4_K, Q5_0 or Q8_0.
/// - `k` is not a multiple of the dtype's block size. Integer division would otherwise
///   silently undercount `slice_blocks`.
/// - The slice's block count does not fit in a `u32`.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_expert_matmul_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    ids: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    k: usize,
    n: usize,
    slot: usize,
    dtype: GgmlDtype,
) -> Result<()> {
    let pipeline = match dtype {
        GgmlDtype::Q4_K => &p.moe_expert_matmul_q4_k,
        GgmlDtype::Q5_0 => &p.moe_expert_matmul_q5_0,
        GgmlDtype::Q8_0 => &p.moe_expert_matmul_q8_0,
        other => {
            return Err(RullamaError::Inference(format!(
                "moe expert matmul: unsupported quant dtype {other:?} (expected Q4_K, Q5_0, or Q8_0)"
            )));
        }
    };
    let block_elems = dtype.block_elems();
    // A partial trailing block would be dropped by the division below, making
    // `slice_blocks` smaller than the real per-expert weight slice.
    if k % block_elems != 0 {
        return Err(RullamaError::Inference(format!(
            "moe expert matmul: k={k} is not a multiple of the quant block size ({block_elems})"
        )));
    }
    let blocks_per_row = k / block_elems;
    let slice_blocks = u32::try_from(blocks_per_row * n).map_err(|_| {
        RullamaError::Inference(format!(
            "moe expert matmul: expert slice of {n} rows x {blocks_per_row} blocks overflows u32"
        ))
    })?;
    let params = MoeExpertMatmulParams {
        k: k as u32,
        n: n as u32,
        slot: slot as u32,
        slice_blocks,
    };
    cached_dispatch(
        ctx,
        enc,
        pipeline,
        "moe_expert_matmul",
        &[w, ids, x, y],
        &params,
        ((n as u32).div_ceil(64), 1, 1),
    );
    Ok(())
}
/// Parameter block handed to `cached_dispatch` by `moe_geglu_halves_batched_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeGegluBatchedParams {
    rows: u32,
    n_ff: u32,
    /// Explicit padding: rounds the struct up to 16 bytes.
    _p0: u32,
    _p1: u32,
}
/// Records a `moe_geglu_halves_batched` dispatch over `rows` rows of width `n_ff`.
///
/// Binds `gu` (input) and `y` (output) and dispatches `(ceil(n_ff / 64), rows, 1)`.
/// Presumably each `gu` row holds gate/up halves that the kernel combines with a GEGLU
/// activation (confirm against the shader).
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
pub fn moe_geglu_halves_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gu: &wgpu::Buffer,
    y: &wgpu::Buffer,
    rows: usize,
    n_ff: usize,
) {
    let params = MoeGegluBatchedParams {
        rows: rows as u32,
        n_ff: n_ff as u32,
        _p0: 0,
        _p1: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.moe_geglu_halves_batched,
        "moe_geglu_halves_batched",
        &[gu, y],
        &params,
        ((n_ff as u32).div_ceil(64), rows as u32, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_combine_batched_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeCombineBatchedParams {
    n_pos: u32,
    d_model: u32,
    top_k: u32,
    /// 1 when a real down-scale buffer is bound, 0 when the dummy buffer stands in.
    has_down_scale: u32,
}
/// Records a `moe_combine_batched` dispatch over `n_pos` positions.
///
/// Binds `slots`, `expert_ids`, `expert_weights`, the optional `down_scale`, and `y` in
/// that order. The dispatch size is `(ceil(d_model / 64), n_pos, 1)`.
///
/// When `down_scale` is `None`, `dummy` fills its binding slot and `has_down_scale` is 0.
///
/// Presumably the kernel sums the `top_k` per-slot expert outputs into `y`, weighted by
/// `expert_weights` (confirm against the shader).
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_combine_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    slots: &wgpu::Buffer,
    expert_ids: &wgpu::Buffer,
    expert_weights: &wgpu::Buffer,
    down_scale: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n_pos: usize,
    d_model: usize,
    top_k: usize,
) {
    let params = MoeCombineBatchedParams {
        n_pos: n_pos as u32,
        d_model: d_model as u32,
        top_k: top_k as u32,
        has_down_scale: u32::from(down_scale.is_some()),
    };
    cached_dispatch(
        ctx,
        enc,
        &p.moe_combine_batched,
        "moe_combine_batched",
        &[
            slots,
            expert_ids,
            expert_weights,
            down_scale.unwrap_or(dummy),
            y,
        ],
        &params,
        ((d_model as u32).div_ceil(64), n_pos as u32, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_expert_matmul_batched_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeExpertBatchedParams {
    k: u32,
    n: u32,
    top_k: u32,
    /// Quant blocks in one expert slice: `(k / block_elems) * n`.
    slice_blocks: u32,
}
/// Records a `moe_expert_matmul_batched` dispatch covering `n_pos * top_k` expert rows.
///
/// The function selects the quantized pipeline that matches `dtype` and binds `w`, `ids`,
/// `x`, `y` in that order. The dispatch size is `(ceil(n / 64), n_pos * top_k, 1)`.
///
/// `slice_blocks` is the number of quant blocks in one `n x k` expert slice.
///
/// # Errors
///
/// Returns `RullamaError::Inference` in any of these cases:
/// - `dtype` is not Q4_K, Q5_0 or Q8_0.
/// - `k` is not a multiple of the dtype's block size. Integer division would otherwise
///   silently undercount `slice_blocks`.
/// - The slice's block count does not fit in a `u32`.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_expert_matmul_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    w: &wgpu::Buffer,
    ids: &wgpu::Buffer,
    x: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n_pos: usize,
    k: usize,
    n: usize,
    top_k: usize,
    dtype: GgmlDtype,
) -> Result<()> {
    let pipeline = match dtype {
        GgmlDtype::Q4_K => &p.moe_expert_matmul_batched_q4_k,
        GgmlDtype::Q5_0 => &p.moe_expert_matmul_batched_q5_0,
        GgmlDtype::Q8_0 => &p.moe_expert_matmul_batched_q8_0,
        other => {
            return Err(RullamaError::Inference(format!(
                "moe expert matmul (batched): unsupported dtype {other:?} (expected Q4_K, Q5_0, or Q8_0)"
            )));
        }
    };
    let block_elems = dtype.block_elems();
    // A partial trailing block would be dropped by the division below, making
    // `slice_blocks` smaller than the real per-expert weight slice.
    if k % block_elems != 0 {
        return Err(RullamaError::Inference(format!(
            "moe expert matmul (batched): k={k} is not a multiple of the quant block size ({block_elems})"
        )));
    }
    let blocks_per_row = k / block_elems;
    let slice_blocks = u32::try_from(blocks_per_row * n).map_err(|_| {
        RullamaError::Inference(format!(
            "moe expert matmul (batched): expert slice of {n} rows x {blocks_per_row} blocks overflows u32"
        ))
    })?;
    let params = MoeExpertBatchedParams {
        k: k as u32,
        n: n as u32,
        top_k: top_k as u32,
        slice_blocks,
    };
    // NOTE(review): wgpu's default `max_compute_workgroups_per_dimension` is 65535, so a
    // long prompt with `n_pos * top_k > 65535` would exceed it unless the device limit was
    // raised — confirm.
    cached_dispatch(
        ctx,
        enc,
        pipeline,
        "moe_expert_matmul_batched",
        &[w, ids, x, y],
        &params,
        ((n as u32).div_ceil(64), (n_pos * top_k) as u32, 1),
    );
    Ok(())
}
/// Parameter block handed to `cached_dispatch` by `moe_router_batched_chained`.
///
/// `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params struct
/// (NOTE(review): confirm before reordering fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeRouterBatchedParams {
    n_pos: u32,
    d_model: u32,
    n_experts: u32,
    top_k: u32,
    eps: f32,
    /// 1 when a real scale buffer is bound, 0 when the dummy buffer stands in.
    has_scale: u32,
    /// Explicit padding: rounds the struct up to 32 bytes (a multiple of 16).
    _pad0: u32,
    _pad1: u32,
}
/// Records a `moe_router_batched` dispatch with one workgroup per position.
///
/// Binds `x`, the optional `scale`, `router_w`, `expert_ids` and `expert_weights` in that
/// order. The dispatch size is `(n_pos, 1, 1)`.
///
/// When `scale` is `None`, `dummy` fills its binding slot and `has_scale` is 0.
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_router_batched_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    x: &wgpu::Buffer,
    scale: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    router_w: &wgpu::Buffer,
    expert_ids: &wgpu::Buffer,
    expert_weights: &wgpu::Buffer,
    n_pos: usize,
    d_model: usize,
    n_experts: usize,
    top_k: usize,
    eps: f32,
) {
    let params = MoeRouterBatchedParams {
        n_pos: n_pos as u32,
        d_model: d_model as u32,
        n_experts: n_experts as u32,
        top_k: top_k as u32,
        eps,
        has_scale: u32::from(scale.is_some()),
        _pad0: 0,
        _pad1: 0,
    };
    // NOTE(review): wgpu's default `max_compute_workgroups_per_dimension` is 65535; an
    // `n_pos` above that would exceed it unless the device limit was raised — confirm.
    cached_dispatch(
        ctx,
        enc,
        &p.moe_router_batched,
        "moe_router_batched",
        &[
            x,
            scale.unwrap_or(dummy),
            router_w,
            expert_ids,
            expert_weights,
        ],
        &params,
        (n_pos as u32, 1, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_geglu_halves_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeGegluParams {
    n_ff: u32,
    /// Explicit padding: rounds the struct up to 16 bytes.
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
}
/// Records a single-row `moe_geglu_halves` dispatch of width `n_ff`.
///
/// Binds `gu` (input) and `y` (output) and dispatches `(ceil(n_ff / 64), 1, 1)`.
/// Presumably `gu` holds gate/up halves that the kernel combines with a GEGLU activation
/// (confirm against the shader).
///
/// `n_ff` is narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
pub fn moe_geglu_halves_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    gu: &wgpu::Buffer,
    y: &wgpu::Buffer,
    n_ff: usize,
) {
    let params = MoeGegluParams {
        n_ff: n_ff as u32,
        _pad0: 0,
        _pad1: 0,
        _pad2: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.moe_geglu_halves,
        "moe_geglu_halves",
        &[gu, y],
        &params,
        ((n_ff as u32).div_ceil(64), 1, 1),
    );
}
/// Parameter block handed to `cached_dispatch` by `moe_combine_chained`.
///
/// 16 bytes, `#[repr(C)]` + `Pod`; layout presumably mirrors the shader-side params
/// struct (NOTE(review): confirm before changing fields).
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct MoeCombineParams {
    d_model: u32,
    top_k: u32,
    /// 1 when a real down-scale buffer is bound, 0 when the dummy buffer stands in.
    has_down_scale: u32,
    /// Explicit padding: rounds the struct up to 16 bytes.
    _pad0: u32,
}
/// Records a single-position `moe_combine` dispatch.
///
/// Binds `slots`, `expert_ids`, `expert_weights`, the optional `down_scale`, and `y` in
/// that order. The dispatch size is `(ceil(d_model / 64), 1, 1)`.
///
/// When `down_scale` is `None`, `dummy` fills its binding slot and `has_down_scale` is 0.
///
/// Presumably the kernel sums the `top_k` per-slot expert outputs into `y`, weighted by
/// `expert_weights` (confirm against the shader).
///
/// Sizes are narrowed with `as u32`, so values above `u32::MAX` would wrap silently.
// Each buffer binding and shape value is passed individually, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn moe_combine_chained(
    ctx: &WgpuCtx,
    p: &Pipelines,
    enc: &mut wgpu::CommandEncoder,
    slots: &wgpu::Buffer,
    expert_ids: &wgpu::Buffer,
    expert_weights: &wgpu::Buffer,
    down_scale: Option<&wgpu::Buffer>,
    dummy: &wgpu::Buffer,
    y: &wgpu::Buffer,
    d_model: usize,
    top_k: usize,
) {
    let params = MoeCombineParams {
        d_model: d_model as u32,
        top_k: top_k as u32,
        has_down_scale: u32::from(down_scale.is_some()),
        _pad0: 0,
    };
    cached_dispatch(
        ctx,
        enc,
        &p.moe_combine,
        "moe_combine",
        &[
            slots,
            expert_ids,
            expert_weights,
            down_scale.unwrap_or(dummy),
            y,
        ],
        &params,
        ((d_model as u32).div_ceil(64), 1, 1),
    );
}