1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
//! GPU MoE weighted accumulate + shared expert add + optional residual.
//!
//! Replaces the CPU weighted accumulate loop in `build_moe_ffn_layer_gpu_q`
//! (Step 3e) and folds in the shared expert gated addition (Step 5+sh_gate add)
//! and optionally the post-FFN residual, all in one GPU kernel.
//!
//! This eliminates:
//! - CPU download of `y_all` (n_tokens * top_k * h floats)
//! - CPU download of `y_s` (n_tokens * h floats)
//! - CPU download of `sh_logit` (n_tokens floats)
//! - CPU weighted accumulate loop
//! - CPU sigmoid(sh_logit) * y_s addition
//! - The `residual_add_gpu` commit (for MoeQ path)
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{CommandEncoder, KernelArg, as_bytes};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the fused MoE weighted-reduce kernel.
///
/// The text of `../shaders/moe_weighted_reduce.metal` is embedded into the
/// binary at compile time via `include_str!`, so no shader file needs to be
/// located at runtime. [`register`] associates this source with the kernel
/// name `moe_weighted_reduce_f32`.
pub static MOE_WEIGHTED_REDUCE_SHADER_SOURCE: &str =
include_str!("../shaders/moe_weighted_reduce.metal");
/// Make the fused MoE weighted-reduce kernel available through `registry`.
///
/// Associates the kernel name `moe_weighted_reduce_f32` with the embedded
/// Metal source in [`MOE_WEIGHTED_REDUCE_SHADER_SOURCE`]. The same name is
/// later used to look the pipeline up when dispatching.
pub fn register(registry: &mut KernelRegistry) {
    // Name under which the shader source is filed in the registry.
    const KERNEL_NAME: &str = "moe_weighted_reduce_f32";

    registry.register_source(KERNEL_NAME, MOE_WEIGHTED_REDUCE_SHADER_SOURCE);
}
/// GPU-side params struct.
///
/// Passed to the kernel as raw bytes at argument index 0, so it is
/// `#[repr(C)]` and `Pod`: four consecutive `u32` fields with no padding.
/// NOTE(review): the field order presumably mirrors a matching struct in
/// `moe_weighted_reduce.metal` — the shader is not visible here, so confirm
/// before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct MoeWeightedReduceGpuParams {
// Number of tokens (rows of the output).
n_tokens: u32,
// Number of routed experts per token.
top_k: u32,
// Hidden size (columns of the output).
h: u32,
// Boolean flag encoded as an integer (`bool as u32`): 1 = add residual, 0 = skip.
add_residual: u32,
}
// Threads per threadgroup along x. The dispatch launches `ceil(h / TG_SIZE)`
// threadgroups along x, so this is also the divisor for the grid width.
const TG_SIZE: u64 = 256;
/// Dispatch fused MoE weighted accumulate + shared expert + optional residual.
///
/// Computes:
/// `output[t, h] = sum_{k} expert_w[t*top_k+k] * y_expert[(t*top_k+k), h]
///               + sh_gate[t] * y_shared[t, h]
///               [+ residual[t, h]]`
///
/// The kernel is only encoded onto `encoder`; this function does not commit
/// or wait for the command buffer.
///
/// # Arguments
///
/// * `encoder`      — Command encoder the dispatch is recorded on.
/// * `registry`     — Kernel registry used to obtain the
///                    `moe_weighted_reduce_f32` pipeline.
/// * `device`       — Device whose Metal handle is used for the pipeline lookup.
/// * `expert_w`     — F32 `[n_tokens * top_k]` routing weights (post-renorm).
/// * `y_expert`     — F32 `[n_tokens * top_k, h]` expert down-projection outputs.
/// * `sh_gate`      — F32 `[n_tokens]` shared gate scalars (sigmoid of sh_logit).
/// * `y_shared`     — F32 `[n_tokens, h]` shared expert output.
/// * `residual`     — F32 `[n_tokens, h]`. Always bound to the kernel, so a
///                    valid buffer must be supplied; when `add_residual` is
///                    false it may be any dummy buffer and its size is not
///                    checked.
/// * `output`       — F32 `[n_tokens, h]` output buffer.
/// * `n_tokens`     — Number of tokens; must be > 0.
/// * `top_k`        — Routed experts per token; must be > 0.
/// * `h`            — Hidden size; must be > 0.
/// * `add_residual` — If true, add the residual buffer to the output.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when any of `n_tokens`, `top_k`, `h`
/// is zero, when an expected byte size does not fit in `usize`, or when a
/// buffer (including `residual` if `add_residual` is true) is smaller than
/// its expected byte size. Propagates any error from the pipeline lookup.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_moe_weighted_reduce(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    expert_w: &MlxBuffer,
    y_expert: &MlxBuffer,
    sh_gate: &MlxBuffer,
    y_shared: &MlxBuffer,
    residual: &MlxBuffer, // dummy/unused buffer when add_residual=false
    output: &mut MlxBuffer,
    n_tokens: u32,
    top_k: u32,
    h: u32,
    add_residual: bool,
) -> Result<()> {
    let f32_sz = DType::F32.size_of();

    if n_tokens == 0 || top_k == 0 || h == 0 {
        return Err(MlxError::InvalidArgument(
            "moe_weighted_reduce: n_tokens, top_k, h must be > 0".into(),
        ));
    }

    // Byte size of an F32 tensor with the given dimensions. Uses checked
    // multiplication so that a wrapped product can never make an undersized
    // buffer pass the validation below.
    let byte_size = |dims: &[usize]| -> Result<usize> {
        dims.iter()
            .try_fold(f32_sz, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| {
                MlxError::InvalidArgument(
                    "moe_weighted_reduce: buffer size computation overflowed".into(),
                )
            })
    };

    let tokens = n_tokens as usize;
    let experts_per_token = top_k as usize;
    let hidden = h as usize;

    let expect_ew = byte_size(&[tokens, experts_per_token])?;
    let expect_ye = byte_size(&[tokens, experts_per_token, hidden])?;
    let expect_sg = byte_size(&[tokens])?;
    // y_shared, residual and output all share the `[n_tokens, h]` shape.
    let expect_th = byte_size(&[tokens, hidden])?;

    macro_rules! check_buf {
        ($buf:expr, $expected:expr, $name:literal) => {
            if $buf.byte_len() < $expected {
                return Err(MlxError::InvalidArgument(format!(
                    "moe_weighted_reduce: {} too small (expected {}, got {})",
                    $name,
                    $expected,
                    $buf.byte_len()
                )));
            }
        };
    }

    check_buf!(expert_w, expect_ew, "expert_w");
    check_buf!(y_expert, expect_ye, "y_expert");
    check_buf!(sh_gate, expect_sg, "sh_gate");
    check_buf!(y_shared, expect_th, "y_shared");
    if add_residual {
        // Only validated when it is actually requested; callers pass a dummy
        // buffer of arbitrary size otherwise.
        check_buf!(residual, expect_th, "residual");
    }
    check_buf!(output, expect_th, "output");

    let gpu_params = MoeWeightedReduceGpuParams {
        n_tokens,
        top_k,
        h,
        add_residual: add_residual as u32,
    };

    let pipeline = registry.get_pipeline("moe_weighted_reduce_f32", device.metal_device())?;

    // Grid: (ceil(h/256), n_tokens, 1).
    let threadgroups = MTLSize::new(
        (h as u64 + TG_SIZE - 1) / TG_SIZE,
        n_tokens as u64,
        1,
    );
    let threadgroup_size = MTLSize::new(TG_SIZE, 1, 1);

    // Argument indices 0..=6 are the binding slots handed to the encoder.
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Bytes(as_bytes(&gpu_params))),
            (1, KernelArg::Buffer(expert_w)),
            (2, KernelArg::Buffer(y_expert)),
            (3, KernelArg::Buffer(sh_gate)),
            (4, KernelArg::Buffer(y_shared)),
            (5, KernelArg::Buffer(residual)),
            (6, KernelArg::Buffer(output)),
        ],
        &[], // no threadgroup memory
        threadgroups,
        threadgroup_size,
    );

    Ok(())
}