mlx_native/ops/l2_norm.rs
1//! L2 Normalization GPU dispatch.
2//!
3//! Computes: `x / sqrt(sum(x^2) + eps)` over the last dimension.
4//!
5//! Used by Gated DeltaNet to normalize Q and K after the conv1d state update
6//! (ADR-013 Decision 3; spec derived from the mathematical definition of
7//! L2 norm, not from llama.cpp source).
8//!
9//! Reduction is always performed in f32 for numerical stability regardless
10//! of input dtype.
11//!
12//! # Invariants
13//!
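//! * Input and output share the same shape `[rows, dim]` and dtype
//!   (`dispatch_l2_norm_scale_f32` additionally requires f32).
//! * `params_buf` must hold exactly `[eps, dim as f32]` as two contiguous f32
//!   for `dispatch_l2_norm`, or `[eps, dim as f32, scale]` (three contiguous
//!   f32) for `dispatch_l2_norm_scale_f32`. Neither dispatcher checks this.
//! * `rows > 0`, `dim > 0`, `input.element_count() == rows * dim`.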
17//!
18//! # Threadgroup shape
19//!
20//! One threadgroup per row; threadgroup size = `min(256, next_power_of_two(dim))`.
21//! Shared memory of `tg_size` floats is used for the tree reduction.
22use metal::MTLSize;
23
24use crate::buffer::MlxBuffer;
25use crate::dtypes::DType;
26use crate::encoder::CommandEncoder;
27use crate::error::{MlxError, Result};
28use crate::kernel_registry::KernelRegistry;
29
/// MSL source for the L2 norm kernels, embedded at compile time with
/// `include_str!` from `../shaders/l2_norm.metal`.
///
/// A single translation unit backs every `l2_norm_*` kernel name; `register`
/// maps each of those names to this same source string.
pub static L2_NORM_SHADER_SOURCE: &str = include_str!("../shaders/l2_norm.metal");
32
33/// Register L2 norm shader sources with the given kernel registry.
34///
35/// Currently registered via `KernelRegistry::new()`'s static table;
36/// this free function exists for symmetry with other ops' registration
37/// helpers and may be used by tests that construct an empty registry.
38pub fn register(registry: &mut KernelRegistry) {
39 registry.register_source("l2_norm_f32", L2_NORM_SHADER_SOURCE);
40 registry.register_source("l2_norm_f16", L2_NORM_SHADER_SOURCE);
41 registry.register_source("l2_norm_bf16", L2_NORM_SHADER_SOURCE);
42 // ADR-015 iter59a — fused L2 norm + scalar multiply.
43 registry.register_source("l2_norm_scale_f32", L2_NORM_SHADER_SOURCE);
44}
45
46/// Dispatch an L2 normalization operation on the GPU.
47///
48/// # Arguments
49///
50/// * `encoder` - Command encoder to record the dispatch into.
51/// * `registry` - Kernel registry (must have l2_norm sources registered).
52/// * `device` - Metal device for pipeline compilation.
53/// * `input` - Input buffer of shape `[rows, dim]` (f32, f16, or bf16).
54/// * `output` - Output buffer (same dtype and shape as input).
55/// * `params_buf` - Params buffer containing `[eps, dim]` as two f32 values.
56/// * `rows` - Number of rows to normalize.
57/// * `dim` - Dimension of the last axis.
58///
59/// # Errors
60///
61/// Returns `MlxError::InvalidArgument` if:
62/// - `rows == 0` or `dim == 0`.
63/// - `input.element_count() != rows * dim`.
64/// - `output.element_count() != rows * dim`.
65/// - Input and output dtypes differ.
66/// - Input dtype is not f32, f16, or bf16.
67pub fn dispatch_l2_norm(
68 encoder: &mut CommandEncoder,
69 registry: &mut KernelRegistry,
70 device: &metal::DeviceRef,
71 input: &MlxBuffer,
72 output: &MlxBuffer,
73 params_buf: &MlxBuffer,
74 rows: u32,
75 dim: u32,
76) -> Result<()> {
77 if rows == 0 || dim == 0 {
78 return Err(MlxError::InvalidArgument(
79 "L2 norm rows and dim must be > 0".into(),
80 ));
81 }
82
83 let expected = (rows as usize) * (dim as usize);
84 if input.element_count() != expected {
85 return Err(MlxError::InvalidArgument(format!(
86 "L2 norm input element count {} != rows({}) * dim({})",
87 input.element_count(),
88 rows,
89 dim
90 )));
91 }
92 if output.element_count() != expected {
93 return Err(MlxError::InvalidArgument(format!(
94 "L2 norm output element count {} != rows({}) * dim({})",
95 output.element_count(),
96 rows,
97 dim
98 )));
99 }
100 if input.dtype() != output.dtype() {
101 return Err(MlxError::InvalidArgument(format!(
102 "L2 norm input/output dtype mismatch: {} vs {}",
103 input.dtype(),
104 output.dtype()
105 )));
106 }
107
108 let kernel_name = match input.dtype() {
109 DType::F32 => "l2_norm_f32",
110 DType::F16 => "l2_norm_f16",
111 DType::BF16 => "l2_norm_bf16",
112 _ => {
113 return Err(MlxError::InvalidArgument(format!(
114 "L2 norm unsupported dtype: {}",
115 input.dtype()
116 )));
117 }
118 };
119
120 let pipeline = registry.get_pipeline(kernel_name, device)?;
121
122 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
123 let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
124
125 encoder.encode_threadgroups_with_shared(
126 pipeline,
127 &[(0, input), (1, output), (2, params_buf)],
128 &[(0, shared_mem_bytes)],
129 MTLSize::new(rows as u64, 1, 1),
130 MTLSize::new(tg_size, 1, 1),
131 );
132
133 Ok(())
134}
135
136/// Dispatch a fused L2 normalization + scalar multiply on the GPU.
137///
138/// Computes `output = (input / sqrt(sum(input^2) + eps)) * scale` over the
139/// last dimension, in a single kernel pass.
140///
141/// ADR-015 iter59a — fuses the post-conv1d Q-path's `dispatch_l2_norm` +
142/// `scalar_mul_f32` pair on Qwen3.5/3.6 DeltaNet into one dispatch,
143/// eliminating one dispatch per DN layer per prefill chunk and per decode
144/// token. Bit-equivalent to the unfused sequence (same f32 accumulation,
145/// same `rsqrt`; the scalar is multiplied into `rsqrt(...)` once per row
146/// instead of into every element after the fact, which is associative-real
147/// equivalent).
148///
149/// # Arguments
150///
151/// * `encoder` - Command encoder to record the dispatch into.
152/// * `registry` - Kernel registry (must have `l2_norm_scale_f32` registered).
153/// * `device` - Metal device for pipeline compilation.
154/// * `input` - Input buffer of shape `[rows, dim]` (f32).
155/// * `output` - Output buffer (f32, same shape as input).
156/// * `params_buf` - Params buffer holding `[eps, dim as f32, scale]` (3 × f32).
157/// * `rows` - Number of rows to normalize.
158/// * `dim` - Dimension of the last axis.
159///
160/// # Errors
161///
162/// Returns `MlxError::InvalidArgument` if:
163/// - `rows == 0` or `dim == 0`.
164/// - `input.element_count() != rows * dim`.
165/// - `output.element_count() != rows * dim`.
166/// - Input and output dtypes differ or are not f32.
167pub fn dispatch_l2_norm_scale_f32(
168 encoder: &mut CommandEncoder,
169 registry: &mut KernelRegistry,
170 device: &metal::DeviceRef,
171 input: &MlxBuffer,
172 output: &MlxBuffer,
173 params_buf: &MlxBuffer,
174 rows: u32,
175 dim: u32,
176) -> Result<()> {
177 if rows == 0 || dim == 0 {
178 return Err(MlxError::InvalidArgument(
179 "L2 norm scale rows and dim must be > 0".into(),
180 ));
181 }
182
183 let expected = (rows as usize) * (dim as usize);
184 if input.element_count() != expected {
185 return Err(MlxError::InvalidArgument(format!(
186 "L2 norm scale input element count {} != rows({}) * dim({})",
187 input.element_count(),
188 rows,
189 dim
190 )));
191 }
192 if output.element_count() != expected {
193 return Err(MlxError::InvalidArgument(format!(
194 "L2 norm scale output element count {} != rows({}) * dim({})",
195 output.element_count(),
196 rows,
197 dim
198 )));
199 }
200 if input.dtype() != output.dtype() {
201 return Err(MlxError::InvalidArgument(format!(
202 "L2 norm scale input/output dtype mismatch: {} vs {}",
203 input.dtype(),
204 output.dtype()
205 )));
206 }
207 if input.dtype() != DType::F32 {
208 return Err(MlxError::InvalidArgument(format!(
209 "L2 norm scale only supports f32 (got {})",
210 input.dtype()
211 )));
212 }
213
214 let pipeline = registry.get_pipeline("l2_norm_scale_f32", device)?;
215
216 let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
217 let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
218
219 encoder.encode_threadgroups_with_shared(
220 pipeline,
221 &[(0, input), (1, output), (2, params_buf)],
222 &[(0, shared_mem_bytes)],
223 MTLSize::new(rows as u64, 1, 1),
224 MTLSize::new(tg_size, 1, 1),
225 );
226
227 Ok(())
228}