1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::{CapturedOpKind, CommandEncoder};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Metal shader source for the RMS-norm kernels, embedded into the binary at
/// compile time from `../shaders/rms_norm.metal`.
///
/// Every kernel name registered by `register` is associated with this same
/// source string.
pub static RMS_NORM_SHADER_SOURCE: &str = include_str!("../shaders/rms_norm.metal");
18
19pub fn register(registry: &mut KernelRegistry) {
21 registry.register_source("rms_norm_f32", RMS_NORM_SHADER_SOURCE);
22 registry.register_source("rms_norm_f16", RMS_NORM_SHADER_SOURCE);
23 registry.register_source("rms_norm_bf16", RMS_NORM_SHADER_SOURCE);
24 registry.register_source("rms_norm_no_scale_bf16", RMS_NORM_SHADER_SOURCE);
25 registry.register_source("rms_norm_no_scale_f32", RMS_NORM_SHADER_SOURCE);
26 registry.register_source("rms_norm_mul_f32", RMS_NORM_SHADER_SOURCE);
28 registry.register_source("rms_norm_mul_f16", RMS_NORM_SHADER_SOURCE);
29 registry.register_source("rms_norm_mul_bf16", RMS_NORM_SHADER_SOURCE);
30}
31
32fn fused_rms_norm_mul_kernel_name(dtype: DType) -> Result<&'static str> {
34 match dtype {
35 DType::F32 => Ok("rms_norm_mul_f32"),
36 DType::F16 => Ok("rms_norm_mul_f16"),
37 DType::BF16 => Ok("rms_norm_mul_bf16"),
38 _ => Err(MlxError::InvalidArgument(format!(
39 "Fused RMS norm+mul unsupported dtype: {}",
40 dtype
41 ))),
42 }
43}
44
45pub fn dispatch_rms_norm(
65 encoder: &mut CommandEncoder,
66 registry: &mut KernelRegistry,
67 device: &metal::DeviceRef,
68 input: &MlxBuffer,
69 weight: &MlxBuffer,
70 output: &MlxBuffer,
71 params_buf: &MlxBuffer,
72 rows: u32,
73 dim: u32,
74) -> Result<()> {
75 if rows == 0 || dim == 0 {
76 return Err(MlxError::InvalidArgument(
77 "RMS norm rows and dim must be > 0".into(),
78 ));
79 }
80
81 let expected = (rows as usize) * (dim as usize);
82 if input.element_count() != expected {
83 return Err(MlxError::InvalidArgument(format!(
84 "RMS norm input element count {} != rows({}) * dim({})",
85 input.element_count(),
86 rows,
87 dim
88 )));
89 }
90 if output.element_count() != expected {
91 return Err(MlxError::InvalidArgument(format!(
92 "RMS norm output element count {} != rows({}) * dim({})",
93 output.element_count(),
94 rows,
95 dim
96 )));
97 }
98
99 let kernel_name = match input.dtype() {
100 DType::F32 => "rms_norm_f32",
101 DType::F16 => "rms_norm_f16",
102 DType::BF16 => "rms_norm_bf16",
103 _ => {
104 return Err(MlxError::InvalidArgument(format!(
105 "RMS norm unsupported dtype: {}",
106 input.dtype()
107 )));
108 }
109 };
110
111 let pipeline = registry.get_pipeline(kernel_name, device)?;
112
113 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
116
117 let shared_mem_bytes = tg_size * 4; encoder.set_op_kind(CapturedOpKind::RmsNorm);
123
124 encoder.encode_threadgroups_with_shared(
125 pipeline,
126 &[
127 (0, input),
128 (1, weight),
129 (2, output),
130 (3, params_buf),
131 ],
132 &[(0, shared_mem_bytes)],
133 MTLSize::new(rows as u64, 1, 1),
134 MTLSize::new(tg_size, 1, 1),
135 );
136
137 Ok(())
138}
139
140pub fn dispatch_rms_norm_no_scale_bf16(
160 encoder: &mut CommandEncoder,
161 registry: &mut KernelRegistry,
162 device: &metal::DeviceRef,
163 input: &MlxBuffer,
164 output: &MlxBuffer,
165 params_buf: &MlxBuffer,
166 rows: u32,
167 dim: u32,
168) -> Result<()> {
169 if rows == 0 || dim == 0 {
170 return Err(MlxError::InvalidArgument(
171 "RMS norm no_scale: rows and dim must be > 0".into(),
172 ));
173 }
174
175 let expected = (rows as usize) * (dim as usize);
176 if input.element_count() != expected {
177 return Err(MlxError::InvalidArgument(format!(
178 "RMS norm no_scale: input element count {} != rows({}) * dim({})",
179 input.element_count(),
180 rows,
181 dim
182 )));
183 }
184 if output.element_count() != expected {
185 return Err(MlxError::InvalidArgument(format!(
186 "RMS norm no_scale: output element count {} != rows({}) * dim({})",
187 output.element_count(),
188 rows,
189 dim
190 )));
191 }
192
193 let pipeline = registry.get_pipeline("rms_norm_no_scale_bf16", device)?;
194
195 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
198
199 let shared_mem_bytes = tg_size * 4; encoder.encode_threadgroups_with_shared(
203 pipeline,
204 &[
205 (0, input),
206 (1, output),
207 (2, params_buf),
208 ],
209 &[(0, shared_mem_bytes)],
210 MTLSize::new(rows as u64, 1, 1),
211 MTLSize::new(tg_size, 1, 1),
212 );
213
214 Ok(())
215}
216
217pub fn dispatch_rms_norm_no_scale_f32(
237 encoder: &mut CommandEncoder,
238 registry: &mut KernelRegistry,
239 device: &metal::DeviceRef,
240 input: &MlxBuffer,
241 output: &MlxBuffer,
242 params_buf: &MlxBuffer,
243 rows: u32,
244 dim: u32,
245) -> Result<()> {
246 if rows == 0 || dim == 0 {
247 return Err(MlxError::InvalidArgument(
248 "RMS norm no_scale f32: rows and dim must be > 0".into(),
249 ));
250 }
251
252 let expected = (rows as usize) * (dim as usize);
253 if input.element_count() != expected {
254 return Err(MlxError::InvalidArgument(format!(
255 "RMS norm no_scale f32: input element count {} != rows({}) * dim({})",
256 input.element_count(),
257 rows,
258 dim
259 )));
260 }
261 if output.element_count() != expected {
262 return Err(MlxError::InvalidArgument(format!(
263 "RMS norm no_scale f32: output element count {} != rows({}) * dim({})",
264 output.element_count(),
265 rows,
266 dim
267 )));
268 }
269
270 let pipeline = registry.get_pipeline("rms_norm_no_scale_f32", device)?;
271
272 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
273 let shared_mem_bytes = tg_size * 4;
274
275 encoder.encode_threadgroups_with_shared(
276 pipeline,
277 &[
278 (0, input),
279 (1, output),
280 (2, params_buf),
281 ],
282 &[(0, shared_mem_bytes)],
283 MTLSize::new(rows as u64, 1, 1),
284 MTLSize::new(tg_size, 1, 1),
285 );
286
287 Ok(())
288}
289
290#[allow(clippy::too_many_arguments)]
314pub fn dispatch_rms_norm_mul(
315 encoder: &mut CommandEncoder,
316 registry: &mut KernelRegistry,
317 device: &metal::DeviceRef,
318 input: &MlxBuffer,
319 norm_weight: &MlxBuffer,
320 scale_weight: &MlxBuffer,
321 output: &MlxBuffer,
322 params_buf: &MlxBuffer,
323 rows: u32,
324 dim: u32,
325) -> Result<()> {
326 if rows == 0 || dim == 0 {
327 return Err(MlxError::InvalidArgument(
328 "Fused RMS norm+mul: rows and dim must be > 0".into(),
329 ));
330 }
331
332 let expected = (rows as usize) * (dim as usize);
333 if input.element_count() != expected {
334 return Err(MlxError::InvalidArgument(format!(
335 "Fused RMS norm+mul: input element count {} != rows({}) * dim({})",
336 input.element_count(),
337 rows,
338 dim
339 )));
340 }
341
342 let kernel_name = fused_rms_norm_mul_kernel_name(input.dtype())?;
343 let pipeline = registry.get_pipeline(kernel_name, device)?;
344
345 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
346 let shared_mem_bytes = tg_size * 4; encoder.encode_threadgroups_with_shared(
349 pipeline,
350 &[
351 (0, input),
352 (1, norm_weight),
353 (2, scale_weight),
354 (3, output),
355 (4, params_buf),
356 ],
357 &[(0, shared_mem_bytes)],
358 MTLSize::new(rows as u64, 1, 1),
359 MTLSize::new(tg_size, 1, 1),
360 );
361
362 Ok(())
363}