Skip to main content

mlx_native/ops/
l2_norm.rs

1//! L2 Normalization GPU dispatch.
2//!
3//! Computes: `x / sqrt(sum(x^2) + eps)` over the last dimension.
4//!
5//! Used by Gated DeltaNet to normalize Q and K after the conv1d state update
6//! (ADR-013 Decision 3; spec derived from the mathematical definition of
7//! L2 norm, not from llama.cpp source).
8//!
9//! Reduction is always performed in f32 for numerical stability regardless
10//! of input dtype.
11//!
12//! # Invariants
13//!
14//! * Input and output share the same shape `[rows, dim]` and dtype.
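//! * `params_buf` must hold exactly `[eps, dim as f32]` as two contiguous f32
//!   for `dispatch_l2_norm`; `dispatch_l2_norm_scale_f32` expects a third f32,
//!   `[eps, dim as f32, scale]`.
//! * `rows > 0`, `dim > 0`, `input.element_count() == rows * dim`.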
17//!
18//! # Threadgroup shape
19//!
20//! One threadgroup per row; threadgroup size = `min(256, next_power_of_two(dim))`.
21//! Shared memory of `tg_size` floats is used for the tree reduction.
22use metal::MTLSize;
23
24use crate::buffer::MlxBuffer;
25use crate::dtypes::DType;
26use crate::encoder::CommandEncoder;
27use crate::error::{MlxError, Result};
28use crate::kernel_registry::KernelRegistry;
29
/// MSL source for the L2 norm kernels, embedded into the binary at compile
/// time from `../shaders/l2_norm.metal` (resolved relative to this file).
///
/// [`register`] associates this one source string with every L2 norm kernel
/// name, so all L2 norm entry points are expected to live in that single
/// shader file.
pub static L2_NORM_SHADER_SOURCE: &str = include_str!("../shaders/l2_norm.metal");
32
33/// Register L2 norm shader sources with the given kernel registry.
34///
35/// Currently registered via `KernelRegistry::new()`'s static table;
36/// this free function exists for symmetry with other ops' registration
37/// helpers and may be used by tests that construct an empty registry.
38pub fn register(registry: &mut KernelRegistry) {
39    registry.register_source("l2_norm_f32", L2_NORM_SHADER_SOURCE);
40    registry.register_source("l2_norm_f16", L2_NORM_SHADER_SOURCE);
41    registry.register_source("l2_norm_bf16", L2_NORM_SHADER_SOURCE);
42    // ADR-015 iter59a — fused L2 norm + scalar multiply.
43    registry.register_source("l2_norm_scale_f32", L2_NORM_SHADER_SOURCE);
44}
45
46/// Dispatch an L2 normalization operation on the GPU.
47///
48/// # Arguments
49///
50/// * `encoder`    - Command encoder to record the dispatch into.
51/// * `registry`   - Kernel registry (must have l2_norm sources registered).
52/// * `device`     - Metal device for pipeline compilation.
53/// * `input`      - Input buffer of shape `[rows, dim]` (f32, f16, or bf16).
54/// * `output`     - Output buffer (same dtype and shape as input).
55/// * `params_buf` - Params buffer containing `[eps, dim]` as two f32 values.
56/// * `rows`       - Number of rows to normalize.
57/// * `dim`        - Dimension of the last axis.
58///
59/// # Errors
60///
61/// Returns `MlxError::InvalidArgument` if:
62/// - `rows == 0` or `dim == 0`.
63/// - `input.element_count() != rows * dim`.
64/// - `output.element_count() != rows * dim`.
65/// - Input and output dtypes differ.
66/// - Input dtype is not f32, f16, or bf16.
67pub fn dispatch_l2_norm(
68    encoder: &mut CommandEncoder,
69    registry: &mut KernelRegistry,
70    device: &metal::DeviceRef,
71    input: &MlxBuffer,
72    output: &MlxBuffer,
73    params_buf: &MlxBuffer,
74    rows: u32,
75    dim: u32,
76) -> Result<()> {
77    if rows == 0 || dim == 0 {
78        return Err(MlxError::InvalidArgument(
79            "L2 norm rows and dim must be > 0".into(),
80        ));
81    }
82
83    let expected = (rows as usize) * (dim as usize);
84    if input.element_count() != expected {
85        return Err(MlxError::InvalidArgument(format!(
86            "L2 norm input element count {} != rows({}) * dim({})",
87            input.element_count(),
88            rows,
89            dim
90        )));
91    }
92    if output.element_count() != expected {
93        return Err(MlxError::InvalidArgument(format!(
94            "L2 norm output element count {} != rows({}) * dim({})",
95            output.element_count(),
96            rows,
97            dim
98        )));
99    }
100    if input.dtype() != output.dtype() {
101        return Err(MlxError::InvalidArgument(format!(
102            "L2 norm input/output dtype mismatch: {} vs {}",
103            input.dtype(),
104            output.dtype()
105        )));
106    }
107
108    let kernel_name = match input.dtype() {
109        DType::F32 => "l2_norm_f32",
110        DType::F16 => "l2_norm_f16",
111        DType::BF16 => "l2_norm_bf16",
112        _ => {
113            return Err(MlxError::InvalidArgument(format!(
114                "L2 norm unsupported dtype: {}",
115                input.dtype()
116            )));
117        }
118    };
119
120    let pipeline = registry.get_pipeline(kernel_name, device)?;
121
122    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
123    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
124
125    encoder.encode_threadgroups_with_shared(
126        pipeline,
127        &[(0, input), (1, output), (2, params_buf)],
128        &[(0, shared_mem_bytes)],
129        MTLSize::new(rows as u64, 1, 1),
130        MTLSize::new(tg_size, 1, 1),
131    );
132
133    Ok(())
134}
135
136/// Dispatch a fused L2 normalization + scalar multiply on the GPU.
137///
138/// Computes `output = (input / sqrt(sum(input^2) + eps)) * scale` over the
139/// last dimension, in a single kernel pass.
140///
141/// ADR-015 iter59a — fuses the post-conv1d Q-path's `dispatch_l2_norm` +
142/// `scalar_mul_f32` pair on Qwen3.5/3.6 DeltaNet into one dispatch,
143/// eliminating one dispatch per DN layer per prefill chunk and per decode
144/// token.  Bit-equivalent to the unfused sequence (same f32 accumulation,
145/// same `rsqrt`; the scalar is multiplied into `rsqrt(...)` once per row
146/// instead of into every element after the fact, which is associative-real
147/// equivalent).
148///
149/// # Arguments
150///
151/// * `encoder`    - Command encoder to record the dispatch into.
152/// * `registry`   - Kernel registry (must have `l2_norm_scale_f32` registered).
153/// * `device`     - Metal device for pipeline compilation.
154/// * `input`      - Input buffer of shape `[rows, dim]` (f32).
155/// * `output`     - Output buffer (f32, same shape as input).
156/// * `params_buf` - Params buffer holding `[eps, dim as f32, scale]` (3 × f32).
157/// * `rows`       - Number of rows to normalize.
158/// * `dim`        - Dimension of the last axis.
159///
160/// # Errors
161///
162/// Returns `MlxError::InvalidArgument` if:
163/// - `rows == 0` or `dim == 0`.
164/// - `input.element_count() != rows * dim`.
165/// - `output.element_count() != rows * dim`.
166/// - Input and output dtypes differ or are not f32.
167pub fn dispatch_l2_norm_scale_f32(
168    encoder: &mut CommandEncoder,
169    registry: &mut KernelRegistry,
170    device: &metal::DeviceRef,
171    input: &MlxBuffer,
172    output: &MlxBuffer,
173    params_buf: &MlxBuffer,
174    rows: u32,
175    dim: u32,
176) -> Result<()> {
177    if rows == 0 || dim == 0 {
178        return Err(MlxError::InvalidArgument(
179            "L2 norm scale rows and dim must be > 0".into(),
180        ));
181    }
182
183    let expected = (rows as usize) * (dim as usize);
184    if input.element_count() != expected {
185        return Err(MlxError::InvalidArgument(format!(
186            "L2 norm scale input element count {} != rows({}) * dim({})",
187            input.element_count(),
188            rows,
189            dim
190        )));
191    }
192    if output.element_count() != expected {
193        return Err(MlxError::InvalidArgument(format!(
194            "L2 norm scale output element count {} != rows({}) * dim({})",
195            output.element_count(),
196            rows,
197            dim
198        )));
199    }
200    if input.dtype() != output.dtype() {
201        return Err(MlxError::InvalidArgument(format!(
202            "L2 norm scale input/output dtype mismatch: {} vs {}",
203            input.dtype(),
204            output.dtype()
205        )));
206    }
207    if input.dtype() != DType::F32 {
208        return Err(MlxError::InvalidArgument(format!(
209            "L2 norm scale only supports f32 (got {})",
210            input.dtype()
211        )));
212    }
213
214    let pipeline = registry.get_pipeline("l2_norm_scale_f32", device)?;
215
216    let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
217    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
218
219    encoder.encode_threadgroups_with_shared(
220        pipeline,
221        &[(0, input), (1, output), (2, params_buf)],
222        &[(0, shared_mem_bytes)],
223        MTLSize::new(rows as u64, 1, 1),
224        MTLSize::new(tg_size, 1, 1),
225    );
226
227    Ok(())
228}