Skip to main content

mlx_native/ops/
rms_norm.rs

1//! RMS Normalization GPU dispatch.
2//!
3//! Computes: `x * rsqrt(mean(x^2) + eps) * weight`
4//!
5//! The mean is computed over the last dimension.  eps=1e-6 is the standard
6//! value for Gemma 4.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::{CapturedOpKind, CommandEncoder};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
16/// MSL source for the RMS norm kernels (embedded at compile time).
17pub static RMS_NORM_SHADER_SOURCE: &str = include_str!("../shaders/rms_norm.metal");
18
19/// Register RMS norm shader sources with the given kernel registry.
20pub fn register(registry: &mut KernelRegistry) {
21    registry.register_source("rms_norm_f32", RMS_NORM_SHADER_SOURCE);
22    registry.register_source("rms_norm_f16", RMS_NORM_SHADER_SOURCE);
23    registry.register_source("rms_norm_bf16", RMS_NORM_SHADER_SOURCE);
24    registry.register_source("rms_norm_no_scale_bf16", RMS_NORM_SHADER_SOURCE);
25    registry.register_source("rms_norm_no_scale_f32", RMS_NORM_SHADER_SOURCE);
26    // Fused RMS norm + elementwise multiply (Phase 4e.2)
27    registry.register_source("rms_norm_mul_f32", RMS_NORM_SHADER_SOURCE);
28    registry.register_source("rms_norm_mul_f16", RMS_NORM_SHADER_SOURCE);
29    registry.register_source("rms_norm_mul_bf16", RMS_NORM_SHADER_SOURCE);
30}
31
32/// Select the fused RMS norm + multiply kernel name based on input dtype.
33fn fused_rms_norm_mul_kernel_name(dtype: DType) -> Result<&'static str> {
34    match dtype {
35        DType::F32 => Ok("rms_norm_mul_f32"),
36        DType::F16 => Ok("rms_norm_mul_f16"),
37        DType::BF16 => Ok("rms_norm_mul_bf16"),
38        _ => Err(MlxError::InvalidArgument(format!(
39            "Fused RMS norm+mul unsupported dtype: {}",
40            dtype
41        ))),
42    }
43}
44
45/// Dispatch an RMS normalization operation on the GPU.
46///
47/// # Arguments
48///
49/// * `encoder`    - Command encoder to record the dispatch into.
50/// * `registry`   - Kernel registry (must have rms_norm sources registered).
51/// * `device`     - Metal device for pipeline compilation.
52/// * `input`      - Input buffer of shape `[rows, dim]` (f32, f16, or bf16).
53/// * `weight`     - Weight buffer of shape `[dim]` (same dtype as input; f32 for f32/f16 kernels, bf16 for bf16).
54/// * `output`     - Output buffer (same dtype and shape as input).
55/// * `params_buf` - Params buffer containing `[eps, dim]` as f32.
56/// * `rows`       - Number of rows to normalize.
57/// * `dim`        - Dimension of the last axis.
58///
59/// # Errors
60///
61/// Returns `MlxError::InvalidArgument` if:
62/// - Input dtype is not f32, f16, or bf16.
63/// - Input element count does not match rows * dim.
64pub fn dispatch_rms_norm(
65    encoder: &mut CommandEncoder,
66    registry: &mut KernelRegistry,
67    device: &metal::DeviceRef,
68    input: &MlxBuffer,
69    weight: &MlxBuffer,
70    output: &MlxBuffer,
71    params_buf: &MlxBuffer,
72    rows: u32,
73    dim: u32,
74) -> Result<()> {
75    if rows == 0 || dim == 0 {
76        return Err(MlxError::InvalidArgument(
77            "RMS norm rows and dim must be > 0".into(),
78        ));
79    }
80
81    let expected = (rows as usize) * (dim as usize);
82    if input.element_count() != expected {
83        return Err(MlxError::InvalidArgument(format!(
84            "RMS norm input element count {} != rows({}) * dim({})",
85            input.element_count(),
86            rows,
87            dim
88        )));
89    }
90    if output.element_count() != expected {
91        return Err(MlxError::InvalidArgument(format!(
92            "RMS norm output element count {} != rows({}) * dim({})",
93            output.element_count(),
94            rows,
95            dim
96        )));
97    }
98
99    let kernel_name = match input.dtype() {
100        DType::F32 => "rms_norm_f32",
101        DType::F16 => "rms_norm_f16",
102        DType::BF16 => "rms_norm_bf16",
103        _ => {
104            return Err(MlxError::InvalidArgument(format!(
105                "RMS norm unsupported dtype: {}",
106                input.dtype()
107            )));
108        }
109    };
110
111    let pipeline = registry.get_pipeline(kernel_name, device)?;
112
113    // One threadgroup per row.  Threadgroup size must be a power of 2
114    // for the tree reduction to work correctly.
115    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
116
117    // Threadgroup shared memory: tg_size floats for the reduction.
118    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
119
120    // Tag for the fusion pass (Phase 4e.2): RMS norm can fuse with a
121    // subsequent elementwise multiply.
122    encoder.set_op_kind(CapturedOpKind::RmsNorm);
123
124    encoder.encode_threadgroups_with_shared(
125        pipeline,
126        &[
127            (0, input),
128            (1, weight),
129            (2, output),
130            (3, params_buf),
131        ],
132        &[(0, shared_mem_bytes)],
133        MTLSize::new(rows as u64, 1, 1),
134        MTLSize::new(tg_size, 1, 1),
135    );
136
137    Ok(())
138}
139
140/// Dispatch an RMS normalization without learned scale (bf16 only).
141///
142/// Computes: `output = x * rsqrt(mean(x^2) + eps)` — no weight multiplication.
143/// Used for per-head V normalization in Gemma 4.
144///
145/// # Arguments
146///
147/// * `encoder`    - Command encoder to record the dispatch into.
148/// * `registry`   - Kernel registry (must have rms_norm_no_scale_bf16 registered).
149/// * `device`     - Metal device for pipeline compilation.
150/// * `input`      - Input buffer of shape `[rows, dim]` (bf16).
151/// * `output`     - Output buffer (same dtype and shape as input).
152/// * `params_buf` - Params buffer containing `[eps, dim]` as f32.
153/// * `rows`       - Number of rows to normalize.
154/// * `dim`        - Dimension of the last axis.
155///
156/// # Errors
157///
158/// Returns `MlxError::InvalidArgument` if parameters are invalid.
159pub fn dispatch_rms_norm_no_scale_bf16(
160    encoder: &mut CommandEncoder,
161    registry: &mut KernelRegistry,
162    device: &metal::DeviceRef,
163    input: &MlxBuffer,
164    output: &MlxBuffer,
165    params_buf: &MlxBuffer,
166    rows: u32,
167    dim: u32,
168) -> Result<()> {
169    if rows == 0 || dim == 0 {
170        return Err(MlxError::InvalidArgument(
171            "RMS norm no_scale: rows and dim must be > 0".into(),
172        ));
173    }
174
175    let expected = (rows as usize) * (dim as usize);
176    if input.element_count() != expected {
177        return Err(MlxError::InvalidArgument(format!(
178            "RMS norm no_scale: input element count {} != rows({}) * dim({})",
179            input.element_count(),
180            rows,
181            dim
182        )));
183    }
184    if output.element_count() != expected {
185        return Err(MlxError::InvalidArgument(format!(
186            "RMS norm no_scale: output element count {} != rows({}) * dim({})",
187            output.element_count(),
188            rows,
189            dim
190        )));
191    }
192
193    let pipeline = registry.get_pipeline("rms_norm_no_scale_bf16", device)?;
194
195    // One threadgroup per row.  Threadgroup size must be a power of 2
196    // for the tree reduction to work correctly.
197    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
198
199    // Threadgroup shared memory: tg_size floats for the reduction.
200    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
201
202    encoder.encode_threadgroups_with_shared(
203        pipeline,
204        &[
205            (0, input),
206            (1, output),
207            (2, params_buf),
208        ],
209        &[(0, shared_mem_bytes)],
210        MTLSize::new(rows as u64, 1, 1),
211        MTLSize::new(tg_size, 1, 1),
212    );
213
214    Ok(())
215}
216
217/// Dispatch an RMS normalization without learned scale (f32).
218///
219/// Computes: `output = x * rsqrt(mean(x^2) + eps)` -- no weight multiplication.
220/// Used for per-head V normalization in Gemma 4 when activations are f32.
221///
222/// # Arguments
223///
224/// * `encoder`    - Command encoder to record the dispatch into.
225/// * `registry`   - Kernel registry (must have rms_norm_no_scale_f32 registered).
226/// * `device`     - Metal device for pipeline compilation.
227/// * `input`      - Input buffer of shape `[rows, dim]` (f32).
228/// * `output`     - Output buffer (same dtype and shape as input).
229/// * `params_buf` - Params buffer containing `[eps, dim]` as f32.
230/// * `rows`       - Number of rows to normalize.
231/// * `dim`        - Dimension of the last axis.
232///
233/// # Errors
234///
235/// Returns `MlxError::InvalidArgument` if parameters are invalid.
236pub fn dispatch_rms_norm_no_scale_f32(
237    encoder: &mut CommandEncoder,
238    registry: &mut KernelRegistry,
239    device: &metal::DeviceRef,
240    input: &MlxBuffer,
241    output: &MlxBuffer,
242    params_buf: &MlxBuffer,
243    rows: u32,
244    dim: u32,
245) -> Result<()> {
246    if rows == 0 || dim == 0 {
247        return Err(MlxError::InvalidArgument(
248            "RMS norm no_scale f32: rows and dim must be > 0".into(),
249        ));
250    }
251
252    let expected = (rows as usize) * (dim as usize);
253    if input.element_count() != expected {
254        return Err(MlxError::InvalidArgument(format!(
255            "RMS norm no_scale f32: input element count {} != rows({}) * dim({})",
256            input.element_count(),
257            rows,
258            dim
259        )));
260    }
261    if output.element_count() != expected {
262        return Err(MlxError::InvalidArgument(format!(
263            "RMS norm no_scale f32: output element count {} != rows({}) * dim({})",
264            output.element_count(),
265            rows,
266            dim
267        )));
268    }
269
270    let pipeline = registry.get_pipeline("rms_norm_no_scale_f32", device)?;
271
272    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
273    let shared_mem_bytes = tg_size * 4;
274
275    encoder.encode_threadgroups_with_shared(
276        pipeline,
277        &[
278            (0, input),
279            (1, output),
280            (2, params_buf),
281        ],
282        &[(0, shared_mem_bytes)],
283        MTLSize::new(rows as u64, 1, 1),
284        MTLSize::new(tg_size, 1, 1),
285    );
286
287    Ok(())
288}
289
290/// Dispatch a fused RMS normalization + elementwise multiply.
291///
292/// Computes: `output = (input * rsqrt(mean(input^2) + eps) * weight) * scale`
293///
294/// This replaces the two-dispatch pattern:
295///   1. `rms_norm(input, weight) -> tmp`
296///   2. `elementwise_mul(tmp, scale) -> output`
297///
298/// with a single kernel pass, eliminating one barrier and one global memory
299/// round-trip for the intermediate `tmp` buffer.
300///
301/// # Arguments
302///
303/// * `encoder`      - Command encoder to record the dispatch into.
304/// * `registry`     - Kernel registry.
305/// * `device`       - Metal device for pipeline compilation.
306/// * `input`        - Input buffer of shape `[rows, dim]`.
307/// * `norm_weight`  - Norm weight buffer of shape `[dim]`.
308/// * `scale_weight` - Scale (MUL operand) buffer of shape `[rows, dim]`.
309/// * `output`       - Output buffer of shape `[rows, dim]`.
310/// * `params_buf`   - Params buffer containing `[eps, dim]` as f32.
311/// * `rows`         - Number of rows.
312/// * `dim`          - Dimension of the last axis.
313#[allow(clippy::too_many_arguments)]
314pub fn dispatch_rms_norm_mul(
315    encoder: &mut CommandEncoder,
316    registry: &mut KernelRegistry,
317    device: &metal::DeviceRef,
318    input: &MlxBuffer,
319    norm_weight: &MlxBuffer,
320    scale_weight: &MlxBuffer,
321    output: &MlxBuffer,
322    params_buf: &MlxBuffer,
323    rows: u32,
324    dim: u32,
325) -> Result<()> {
326    if rows == 0 || dim == 0 {
327        return Err(MlxError::InvalidArgument(
328            "Fused RMS norm+mul: rows and dim must be > 0".into(),
329        ));
330    }
331
332    let expected = (rows as usize) * (dim as usize);
333    if input.element_count() != expected {
334        return Err(MlxError::InvalidArgument(format!(
335            "Fused RMS norm+mul: input element count {} != rows({}) * dim({})",
336            input.element_count(),
337            rows,
338            dim
339        )));
340    }
341
342    let kernel_name = fused_rms_norm_mul_kernel_name(input.dtype())?;
343    let pipeline = registry.get_pipeline(kernel_name, device)?;
344
345    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
346    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
347
348    encoder.encode_threadgroups_with_shared(
349        pipeline,
350        &[
351            (0, input),
352            (1, norm_weight),
353            (2, scale_weight),
354            (3, output),
355            (4, params_buf),
356        ],
357        &[(0, shared_mem_bytes)],
358        MTLSize::new(rows as u64, 1, 1),
359        MTLSize::new(tg_size, 1, 1),
360    );
361
362    Ok(())
363}