// mlx_native/ops/moe_gate.rs
use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

/// Geometry and numerics for one fused MoE (mixture-of-experts) gating
/// dispatch; consumed by [`moe_gate`].
///
/// All fields are plain values, so the struct derives the standard
/// value-type traits for free.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MoeGateParams {
    /// Width of each token's hidden-state vector.
    pub hidden_dim: usize,
    /// Total number of experts scored by the router (kernel limit: 128).
    pub n_experts: usize,
    /// Number of experts selected per token; must be `<= n_experts`.
    pub top_k: usize,
    /// Number of tokens (rows) in the hidden-state buffer.
    pub seq_len: usize,
    /// Numerical-stability epsilon forwarded to the kernel as scalar
    /// argument 9 (named for RMS norm; the Metal shader is not visible here).
    pub rms_eps: f32,
}
29
30#[allow(clippy::too_many_arguments)]
51pub fn moe_gate(
52 encoder: &mut CommandEncoder,
53 registry: &mut KernelRegistry,
54 device: &metal::DeviceRef,
55 hidden_state: &MlxBuffer,
56 router_weights: &MlxBuffer,
57 norm_weight: &MlxBuffer,
58 per_expert_scale: &MlxBuffer,
59 out_expert_ids: &MlxBuffer,
60 out_weights: &MlxBuffer,
61 params: &MoeGateParams,
62) -> Result<()> {
63 if params.hidden_dim == 0 {
65 return Err(MlxError::InvalidArgument(
66 "moe_gate: hidden_dim must be > 0".into(),
67 ));
68 }
69 if params.n_experts == 0 {
70 return Err(MlxError::InvalidArgument(
71 "moe_gate: n_experts must be > 0".into(),
72 ));
73 }
74 if params.top_k == 0 {
75 return Err(MlxError::InvalidArgument(
76 "moe_gate: top_k must be > 0".into(),
77 ));
78 }
79 if params.seq_len == 0 {
80 return Err(MlxError::InvalidArgument(
81 "moe_gate: seq_len must be > 0".into(),
82 ));
83 }
84 if params.top_k > params.n_experts {
85 return Err(MlxError::InvalidArgument(format!(
86 "moe_gate: top_k ({}) must be <= n_experts ({})",
87 params.top_k, params.n_experts
88 )));
89 }
90 if params.n_experts > 128 {
91 return Err(MlxError::InvalidArgument(format!(
92 "moe_gate: n_experts ({}) exceeds max 128 (shader fixed-size array limit)",
93 params.n_experts
94 )));
95 }
96
97 let bf16_size = 2usize;
99 let f32_size = std::mem::size_of::<f32>();
100 let u32_size = std::mem::size_of::<u32>();
101
102 let expected_hidden_bytes = params.seq_len * params.hidden_dim * bf16_size;
103 if hidden_state.byte_len() < expected_hidden_bytes {
104 return Err(MlxError::InvalidArgument(format!(
105 "moe_gate: hidden_state buffer too small: need {} bytes, have {}",
106 expected_hidden_bytes,
107 hidden_state.byte_len()
108 )));
109 }
110
111 let expected_router_bytes = params.n_experts * params.hidden_dim * f32_size;
112 if router_weights.byte_len() < expected_router_bytes {
113 return Err(MlxError::InvalidArgument(format!(
114 "moe_gate: router_weights buffer too small: need {} bytes, have {}",
115 expected_router_bytes,
116 router_weights.byte_len()
117 )));
118 }
119
120 let expected_norm_bytes = params.hidden_dim * f32_size;
121 if norm_weight.byte_len() < expected_norm_bytes {
122 return Err(MlxError::InvalidArgument(format!(
123 "moe_gate: norm_weight buffer too small: need {} bytes, have {}",
124 expected_norm_bytes,
125 norm_weight.byte_len()
126 )));
127 }
128
129 let expected_scale_bytes = params.n_experts * f32_size;
130 if per_expert_scale.byte_len() < expected_scale_bytes {
131 return Err(MlxError::InvalidArgument(format!(
132 "moe_gate: per_expert_scale buffer too small: need {} bytes, have {}",
133 expected_scale_bytes,
134 per_expert_scale.byte_len()
135 )));
136 }
137
138 let expected_ids_bytes = params.seq_len * params.top_k * u32_size;
139 if out_expert_ids.byte_len() < expected_ids_bytes {
140 return Err(MlxError::InvalidArgument(format!(
141 "moe_gate: out_expert_ids buffer too small: need {} bytes, have {}",
142 expected_ids_bytes,
143 out_expert_ids.byte_len()
144 )));
145 }
146
147 let expected_weights_bytes = params.seq_len * params.top_k * f32_size;
148 if out_weights.byte_len() < expected_weights_bytes {
149 return Err(MlxError::InvalidArgument(format!(
150 "moe_gate: out_weights buffer too small: need {} bytes, have {}",
151 expected_weights_bytes,
152 out_weights.byte_len()
153 )));
154 }
155
156 let pipeline = registry.get_pipeline("moe_gate", device)?;
158
159 let tg_threads: u64 = 128;
162
163 let threadgroups = MTLSize::new(params.seq_len as u64, 1, 1);
165 let threadgroup_size = MTLSize::new(tg_threads, 1, 1);
166
167 let shared_bytes =
174 ((params.hidden_dim + params.n_experts) * std::mem::size_of::<f32>()) as u64;
175
176 let hidden_dim_u32 = params.hidden_dim as u32;
178 let n_experts_u32 = params.n_experts as u32;
179 let top_k_u32 = params.top_k as u32;
180 let rms_eps_f32 = params.rms_eps;
181
182 use crate::encoder::{KernelArg, as_bytes};
183
184 encoder.encode_threadgroups_with_args_and_shared(
185 pipeline,
186 &[
187 (0, KernelArg::Buffer(hidden_state)),
188 (1, KernelArg::Buffer(router_weights)),
189 (2, KernelArg::Buffer(norm_weight)),
190 (3, KernelArg::Buffer(per_expert_scale)),
191 (4, KernelArg::Buffer(out_expert_ids)),
192 (5, KernelArg::Buffer(out_weights)),
193 (6, KernelArg::Bytes(as_bytes(&hidden_dim_u32))),
194 (7, KernelArg::Bytes(as_bytes(&n_experts_u32))),
195 (8, KernelArg::Bytes(as_bytes(&top_k_u32))),
196 (9, KernelArg::Bytes(as_bytes(&rms_eps_f32))),
197 ],
198 &[(0, shared_bytes)],
199 threadgroups,
200 threadgroup_size,
201 );
202
203 Ok(())
204}