1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
//! Fused residual addition + RMS normalization GPU dispatch (bf16).
//!
//! Replaces two separate dispatches — `elementwise_add_bf16` followed by
//! `rms_norm_bf16` — with a single kernel launch per transformer sub-layer.
//! Saves approximately 136 kernel dispatches per Gemma 4 forward pass.
//!
//! The Metal kernel reads each row once, computes the elementwise sum,
//! optionally writes the un-normed sum for subsequent residual use,
//! reduces for RMS, then normalizes and writes the output — the intermediate
//! summed buffer is never materialized in GPU memory between the two ops.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal Shading Language source for the fused residual-add + RMS-norm kernel.
///
/// The text of `../shaders/fused_residual_norm_bf16.metal` is baked into the
/// binary by `include_str!`, so no shader file is needed at runtime.
pub static FUSED_RESIDUAL_NORM_SHADER_SOURCE: &str =
    include_str!("../shaders/fused_residual_norm_bf16.metal");
/// Add the fused residual-norm shader source to `registry`.
///
/// The source is registered under the name `"fused_residual_norm_bf16"`, which
/// is the name `dispatch_fused_residual_norm_bf16` later passes to
/// `KernelRegistry::get_pipeline`.
pub fn register(registry: &mut KernelRegistry) {
    let (kernel_name, msl_source) =
        ("fused_residual_norm_bf16", FUSED_RESIDUAL_NORM_SHADER_SOURCE);
    registry.register_source(kernel_name, msl_source);
}
/// GPU params struct — must match `FusedResidualNormParams` in the MSL shader exactly.
///
/// `#[repr(C)]` pins field order and layout so the raw bytes bound at buffer
/// index 5 (via `as_bytes`) line up with the shader-side struct. Do NOT reorder,
/// add, or retype fields without changing the shader in lockstep. All four
/// fields are 4 bytes wide, so the struct has no padding (16 bytes total), which
/// the `bytemuck::Pod` derive requires.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedResidualNormParams {
/// Hidden dimension (last axis size, elements per row).
dim: u32,
/// Number of rows (tokens) in the batch.
rows: u32,
/// RMS normalization epsilon.
eps: f32,
/// Boolean flag encoded as u32; the dispatcher sets it from `sum_output.is_some()`.
write_sum: u32, // 0 = do not write sum_output, nonzero = write it
}
/// Dispatch a fused residual addition + RMS normalization operation.
///
/// Computes `normed = rms_norm(residual + input, weight, eps)` in a single pass.
/// Optionally also writes the un-normalized sum `residual + input` to `sum_output`
/// so the caller can use it as the residual stream for the next layer without a
/// second elementwise kernel.
///
/// The kernel is launched with one threadgroup per row. Each threadgroup is the
/// smallest power of two `>= dim` threads wide, capped at 256.
///
/// # Arguments
///
/// * `encoder` - Command encoder to record the dispatch into.
/// * `registry` - Kernel registry (must have fused_residual_norm_bf16 registered).
/// * `device` - Metal device for pipeline compilation.
/// * `residual` - bf16 buffer of shape `[rows, dim]` — the residual stream.
/// * `input` - bf16 buffer of shape `[rows, dim]` — the sublayer output to add.
/// * `weight` - bf16 buffer of shape `[dim]` — RMS norm learned scale.
/// * `normed_output` - bf16 output buffer of shape `[rows, dim]` — normalized result.
/// * `sum_output` - Optional bf16 output buffer of shape `[rows, dim]`. When
///   `Some`, the un-normed sum is written here (for use as the next
///   layer's residual). When `None`, that write is skipped.
/// * `rows` - Number of rows (tokens) in the batch.
/// * `dim` - Hidden dimension (last axis size).
/// * `eps` - RMS normalization epsilon (1e-6 for Gemma 4).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// * `rows` or `dim` is zero;
/// * `residual`, `input`, `normed_output`, or `sum_output` (when given) does
///   not hold exactly `rows * dim` elements;
/// * `weight` does not hold exactly `dim` elements.
///
/// Any error from `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_residual_norm_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    normed_output: &MlxBuffer,
    sum_output: Option<&MlxBuffer>,
    rows: u32,
    dim: u32,
    eps: f32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_residual_norm: rows and dim must be > 0".into(),
        ));
    }

    // Every [rows, dim] operand must hold exactly rows * dim elements.
    let expected = (rows as usize) * (dim as usize);
    let check_rows_dim = |label: &str, buf: &MlxBuffer| -> Result<()> {
        let count = buf.element_count();
        if count == expected {
            Ok(())
        } else {
            Err(MlxError::InvalidArgument(format!(
                "fused_residual_norm: {} element count {} != rows({}) * dim({})",
                label, count, rows, dim,
            )))
        }
    };
    check_rows_dim("residual", residual)?;
    check_rows_dim("input", input)?;
    check_rows_dim("normed_output", normed_output)?;
    if let Some(sum_buf) = sum_output {
        check_rows_dim("sum_output", sum_buf)?;
    }

    // `weight` is documented as shape [dim]. It was previously unchecked, so a
    // short buffer would reach the kernel, which presumably reads one scale per
    // column and would run past the end of the buffer.
    if weight.element_count() != dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "fused_residual_norm: weight element count {} != dim({})",
            weight.element_count(),
            dim,
        )));
    }

    let pipeline = registry.get_pipeline("fused_residual_norm_bf16", device)?;

    // One threadgroup per row. Its width is the smallest power of two >= dim,
    // capped at 256. Clamping before rounding gives exactly the same result as
    // `min(256, dim.next_power_of_two())`. It cannot overflow, whereas
    // `next_power_of_two()` overflows for dim > 2^31 (panic in debug, 0 in release).
    let tg_size = u64::from(dim.min(256).next_power_of_two());

    // Threadgroup scratch: one f32 per thread at threadgroup-memory index 0.
    // NOTE(review): the original note says the shader first caches element sums
    // here, then reuses it for the sum-of-squares reduction. Confirm against
    // fused_residual_norm_bf16.metal before changing the size.
    let shared_mem_bytes = tg_size * std::mem::size_of::<f32>() as u64;

    let write_sum = sum_output.is_some();
    let gpu_params = GpuFusedResidualNormParams {
        dim,
        rows,
        eps,
        write_sum: u32::from(write_sum),
    };

    // When sum_output is None we still bind a dummy at buffer 4 to satisfy
    // Metal's requirement that all declared buffers are bound. The shader is
    // expected to check write_sum before touching that pointer, so nothing is
    // written through the alias.
    let sum_buf = sum_output.unwrap_or(normed_output);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(normed_output)),
            (4, KernelArg::Buffer(sum_buf)),
            (5, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}