use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::{CapturedOpKind, CommandEncoder};
use crate::env_flags::env_default_true;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the RMS-norm kernels, embedded at build time from
/// `../shaders/rms_norm.metal` and registered under each kernel name listed in
/// [`register`].
pub static RMS_NORM_SHADER_SOURCE: &str = include_str!("../shaders/rms_norm.metal");
/// Associate [`RMS_NORM_SHADER_SOURCE`] with every RMS-norm kernel name it
/// provides, in `registry`.
///
/// NOTE(review): `rms_norm_f32_v2`, `rms_norm_no_scale_f32_v2`,
/// `rms_norm_f32_triple`, `fused_post_attn_triple_norm_f32` and
/// `fused_post_ff_norm2_endlayer_f32` are requested via `get_pipeline` in this
/// file but are not registered here — presumably registered elsewhere; confirm.
pub fn register(registry: &mut KernelRegistry) {
    // Registration order matches the historical order of individual calls.
    const KERNEL_NAMES: [&str; 8] = [
        "rms_norm_f32",
        "rms_norm_f16",
        "rms_norm_bf16",
        "rms_norm_no_scale_bf16",
        "rms_norm_no_scale_f32",
        "rms_norm_mul_f32",
        "rms_norm_mul_f16",
        "rms_norm_mul_bf16",
    ];
    for kernel in KERNEL_NAMES {
        registry.register_source(kernel, RMS_NORM_SHADER_SOURCE);
    }
}
/// Select the fused RMS-norm + multiply kernel name for an element dtype.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] for any dtype other than F32, F16 or
/// BF16.
fn fused_rms_norm_mul_kernel_name(dtype: DType) -> Result<&'static str> {
    let kernel = match dtype {
        DType::F32 => "rms_norm_mul_f32",
        DType::F16 => "rms_norm_mul_f16",
        DType::BF16 => "rms_norm_mul_bf16",
        unsupported => {
            return Err(MlxError::InvalidArgument(format!(
                "Fused RMS norm+mul unsupported dtype: {}",
                unsupported
            )));
        }
    };
    Ok(kernel)
}
/// Encode a weighted RMS normalization over a `[rows, dim]` buffer.
///
/// `input` is normalized row by row into `output`, with `weight` bound at
/// buffer index 1 as the per-column scale. `params_buf` is bound unchanged at
/// buffer index 3; its layout is defined by the shader.
///
/// Kernel selection:
/// * F32 input with `dim % 4 == 0` uses `rms_norm_f32_v2`, gated by
///   `env_default_true("HF2Q_RMS_NORM_V2")`.
/// * Otherwise `rms_norm_f32`, `rms_norm_f16` or `rms_norm_bf16`, chosen by
///   the input dtype.
///
/// One threadgroup is dispatched per row. Its width is `dim` rounded up to a
/// power of two and capped at 256, with a minimum of 32 for the v2 kernel.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if any of these hold:
/// * `rows` or `dim` is zero;
/// * `input` or `output` does not hold exactly `rows * dim` elements;
/// * `weight` holds fewer than `dim` elements;
/// * the input dtype is not F32/F16/BF16.
///
/// Errors from pipeline lookup are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "RMS norm rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    // NOTE(review): this assumes the kernel reads one weight per column; the
    // triple and fused variants in this file require exactly `dim` weights.
    // A shorter buffer would be read out of bounds on the GPU, so reject it.
    // Longer buffers are still accepted, for backward compatibility.
    if weight.element_count() < dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm weight element count {} < dim({})",
            weight.element_count(),
            dim
        )));
    }
    let use_v2 = matches!(input.dtype(), DType::F32)
        && (dim % 4 == 0)
        && env_default_true("HF2Q_RMS_NORM_V2");
    let kernel_name = if use_v2 {
        "rms_norm_f32_v2"
    } else {
        match input.dtype() {
            DType::F32 => "rms_norm_f32",
            DType::F16 => "rms_norm_f16",
            DType::BF16 => "rms_norm_bf16",
            _ => {
                return Err(MlxError::InvalidArgument(format!(
                    "RMS norm unsupported dtype: {}",
                    input.dtype()
                )));
            }
        }
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;
    // Clamp before rounding: `dim.next_power_of_two()` overflows for
    // `dim > 2^31`. The capped result is identical for every smaller `dim`.
    let base_tg = u64::from(dim.min(256).next_power_of_two());
    // The v2 kernel always runs with at least 32 threads.
    let tg_size = if use_v2 { base_tg.max(32) } else { base_tg };
    // Shared memory sizing:
    // * v2: one 4-byte slot per 32 threads (presumably one partial per
    //   SIMD group);
    // * otherwise: one 4-byte slot per thread.
    let shared_mem_bytes = if use_v2 {
        (tg_size / 32).max(1) * 4
    } else {
        tg_size * 4
    };
    encoder.set_op_kind(CapturedOpKind::RmsNorm);
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, weight),
            (2, output),
            (3, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encode the `rms_norm_f32_triple` kernel.
///
/// One f32 `input` of shape `[rows, dim]` is normalized against three weight
/// vectors, writing three outputs in a single dispatch. Presumably each output
/// is `output_x = rms_norm(input) * weight_x`; the shader defines the exact
/// math. `params_buf` is bound unchanged at buffer index 7.
///
/// One threadgroup is dispatched per row. Its width is `dim` rounded up to a
/// power of two and capped at 256.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if any of these hold:
/// * `rows` or `dim` is zero;
/// * `input` or any output does not hold exactly `rows * dim` elements;
/// * any weight does not hold exactly `dim` elements;
/// * any of these buffers is not f32.
///
/// Errors from pipeline lookup are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm_f32_triple(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    weight_a: &MlxBuffer,
    weight_b: &MlxBuffer,
    weight_c: &MlxBuffer,
    output_a: &MlxBuffer,
    output_b: &MlxBuffer,
    output_c: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "RMS norm triple: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    // Activations: full [rows, dim] f32 tensors.
    for (name, buf) in [
        ("input", input),
        ("output_a", output_a),
        ("output_b", output_b),
        ("output_c", output_c),
    ] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "RMS norm triple: {} element count {} != rows({}) * dim({})",
                name,
                buf.element_count(),
                rows,
                dim
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "RMS norm triple: {} must be f32, got {}",
                name,
                buf.dtype()
            )));
        }
    }
    // Weights: one f32 value per column.
    for (name, buf) in [
        ("weight_a", weight_a),
        ("weight_b", weight_b),
        ("weight_c", weight_c),
    ] {
        if buf.element_count() != dim as usize {
            return Err(MlxError::InvalidArgument(format!(
                "RMS norm triple: {} element count {} != dim({})",
                name,
                buf.element_count(),
                dim
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "RMS norm triple: {} must be f32, got {}",
                name,
                buf.dtype()
            )));
        }
    }
    let pipeline = registry.get_pipeline("rms_norm_f32_triple", device)?;
    // Clamp before rounding so `next_power_of_two` cannot overflow for huge
    // `dim`. The result is unchanged for every `dim <= 2^31`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // One 4-byte shared slot per thread.
    let shared_mem_bytes = tg_size * 4;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, weight_a),
            (2, weight_b),
            (3, weight_c),
            (4, output_a),
            (5, output_b),
            (6, output_c),
            (7, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encode the `fused_post_attn_triple_norm_f32` kernel over `[rows, dim]` f32
/// activations.
///
/// Buffer bindings:
/// * 0–1: `hidden`, `attn_out`;
/// * 2–5: `post_attn_w` and the three norm weights;
/// * 6–9: `residual_out` and the three outputs;
/// * 10: `{eps, dim}`, passed as inline bytes.
///
/// NOTE(review): the names suggest the kernel adds the post-attention-normed
/// `attn_out` into `hidden` (written to `residual_out`), then applies three
/// RMS norms to that residual. Confirm against the shader.
///
/// One threadgroup is dispatched per row. Its width is `dim` rounded up to a
/// power of two and capped at 256.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if any of these hold:
/// * `rows` or `dim` is zero;
/// * any activation buffer holds fewer than `rows * dim` elements or is not
///   f32;
/// * any weight buffer does not hold exactly `dim` elements.
///
/// Errors from pipeline lookup are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_post_attn_triple_norm_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    hidden: &MlxBuffer,
    attn_out: &MlxBuffer,
    post_attn_w: &MlxBuffer,
    weight_a: &MlxBuffer,
    weight_b: &MlxBuffer,
    weight_c: &MlxBuffer,
    residual_out: &MlxBuffer,
    output_a: &MlxBuffer,
    output_b: &MlxBuffer,
    output_c: &MlxBuffer,
    eps: f32,
    rows: u32,
    dim: u32,
) -> Result<()> {
    use super::encode_helpers::{as_bytes, KernelArg};

    /// Inline kernel parameters. `repr(C)` keeps the field order and layout
    /// stable; it must match the shader-side parameter struct.
    #[repr(C)]
    #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Params {
        eps: f32,
        dim: u32,
    }

    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_post_attn_triple_norm_f32: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    // Activations may be larger than needed (e.g. pooled buffers); only too-small
    // buffers are rejected.
    for (name, buf) in [
        ("hidden", hidden),
        ("attn_out", attn_out),
        ("residual_out", residual_out),
        ("output_a", output_a),
        ("output_b", output_b),
        ("output_c", output_c),
    ] {
        if buf.element_count() < expected {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_attn_triple_norm_f32: {} size {} < rows({}) * dim({})",
                name,
                buf.element_count(),
                rows,
                dim
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_attn_triple_norm_f32: {} must be f32, got {}",
                name,
                buf.dtype()
            )));
        }
    }
    for (name, buf) in [
        ("post_attn_w", post_attn_w),
        ("weight_a", weight_a),
        ("weight_b", weight_b),
        ("weight_c", weight_c),
    ] {
        if buf.element_count() != dim as usize {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_attn_triple_norm_f32: {} size {} != dim({})",
                name,
                buf.element_count(),
                dim
            )));
        }
    }
    let pipeline = registry.get_pipeline("fused_post_attn_triple_norm_f32", device)?;
    // Clamp before rounding so `next_power_of_two` cannot overflow for huge `dim`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // One 4-byte shared slot per thread.
    let shared_mem_bytes = tg_size * 4;
    let params = Params { eps, dim };
    encoder.set_op_kind(CapturedOpKind::Other);
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Buffer(hidden)),
            (1, KernelArg::Buffer(attn_out)),
            (2, KernelArg::Buffer(post_attn_w)),
            (3, KernelArg::Buffer(weight_a)),
            (4, KernelArg::Buffer(weight_b)),
            (5, KernelArg::Buffer(weight_c)),
            (6, KernelArg::Buffer(residual_out)),
            (7, KernelArg::Buffer(output_a)),
            (8, KernelArg::Buffer(output_b)),
            (9, KernelArg::Buffer(output_c)),
            (10, KernelArg::Bytes(as_bytes(&params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encode the `fused_post_ff_norm2_endlayer_f32` kernel over `[rows, dim]` f32
/// activations.
///
/// Buffer bindings:
/// * 0–2: `attn_out`, `moe_accum`, `residual`;
/// * 3–4: `w2`, `w3`;
/// * 5: `layer_scalar`;
/// * 6: `mlp_down`;
/// * 7: `hidden`;
/// * 8: `{eps, dim, scalar_is_vector}`, passed as inline bytes.
///
/// `layer_scalar` must hold `dim` values when `scalar_is_vector` is true, and
/// at least one value otherwise.
///
/// NOTE(review): the exact end-of-layer formula is defined by the shader;
/// confirm which inputs are read and which are written.
///
/// One threadgroup is dispatched per row. Its width is `dim` rounded up to a
/// power of two and capped at 256.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if any of these hold:
/// * `rows` or `dim` is zero;
/// * any activation buffer holds fewer than `rows * dim` elements or is not
///   f32;
/// * `w2`/`w3` does not hold exactly `dim` elements;
/// * `layer_scalar` is too small.
///
/// Errors from pipeline lookup are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_post_ff_norm2_endlayer_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    attn_out: &MlxBuffer,
    moe_accum: &MlxBuffer,
    residual: &MlxBuffer,
    w2: &MlxBuffer,
    w3: &MlxBuffer,
    layer_scalar: &MlxBuffer,
    mlp_down: &MlxBuffer,
    hidden: &MlxBuffer,
    eps: f32,
    rows: u32,
    dim: u32,
    scalar_is_vector: bool,
) -> Result<()> {
    use super::encode_helpers::{as_bytes, KernelArg};

    /// Inline kernel parameters. `repr(C)` keeps the field order and layout
    /// stable; it must match the shader-side parameter struct.
    /// `scalar_is_vector` is a 0/1 flag because `bool` is not `Pod`.
    #[repr(C)]
    #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct Params {
        eps: f32,
        dim: u32,
        scalar_is_vector: u32,
    }

    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_post_ff_norm2_endlayer_f32: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    // Activations may be larger than needed; only too-small buffers are rejected.
    for (name, buf) in [
        ("attn_out", attn_out),
        ("moe_accum", moe_accum),
        ("residual", residual),
        ("mlp_down", mlp_down),
        ("hidden", hidden),
    ] {
        if buf.element_count() < expected {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_ff_norm2_endlayer_f32: {} size {} < rows({}) * dim({})",
                name,
                buf.element_count(),
                rows,
                dim
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_ff_norm2_endlayer_f32: {} must be f32, got {}",
                name,
                buf.dtype()
            )));
        }
    }
    for (name, buf) in [("w2", w2), ("w3", w3)] {
        if buf.element_count() != dim as usize {
            return Err(MlxError::InvalidArgument(format!(
                "fused_post_ff_norm2_endlayer_f32: {} size {} != dim({})",
                name,
                buf.element_count(),
                dim
            )));
        }
    }
    let expected_scalar = if scalar_is_vector { dim as usize } else { 1 };
    if layer_scalar.element_count() < expected_scalar {
        return Err(MlxError::InvalidArgument(format!(
            "fused_post_ff_norm2_endlayer_f32: layer_scalar size {} < expected {}",
            layer_scalar.element_count(),
            expected_scalar
        )));
    }
    let pipeline = registry.get_pipeline("fused_post_ff_norm2_endlayer_f32", device)?;
    // Clamp before rounding so `next_power_of_two` cannot overflow for huge `dim`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // One 4-byte shared slot per thread.
    let shared_mem_bytes = tg_size * 4;
    let params = Params {
        eps,
        dim,
        scalar_is_vector: u32::from(scalar_is_vector),
    };
    encoder.set_op_kind(CapturedOpKind::Other);
    encoder.encode_threadgroups_with_args_and_shared(
        pipeline,
        &[
            (0, KernelArg::Buffer(attn_out)),
            (1, KernelArg::Buffer(moe_accum)),
            (2, KernelArg::Buffer(residual)),
            (3, KernelArg::Buffer(w2)),
            (4, KernelArg::Buffer(w3)),
            (5, KernelArg::Buffer(layer_scalar)),
            (6, KernelArg::Buffer(mlp_down)),
            (7, KernelArg::Buffer(hidden)),
            (8, KernelArg::Bytes(as_bytes(&params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
pub fn dispatch_rms_norm_no_scale_bf16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
rows: u32,
dim: u32,
) -> Result<()> {
if rows == 0 || dim == 0 {
return Err(MlxError::InvalidArgument(
"RMS norm no_scale: rows and dim must be > 0".into(),
));
}
let expected = (rows as usize) * (dim as usize);
if input.element_count() != expected {
return Err(MlxError::InvalidArgument(format!(
"RMS norm no_scale: input element count {} != rows({}) * dim({})",
input.element_count(),
rows,
dim
)));
}
if output.element_count() != expected {
return Err(MlxError::InvalidArgument(format!(
"RMS norm no_scale: output element count {} != rows({}) * dim({})",
output.element_count(),
rows,
dim
)));
}
let pipeline = registry.get_pipeline("rms_norm_no_scale_bf16", device)?;
let tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
let shared_mem_bytes = tg_size * 4;
encoder.encode_threadgroups_with_shared(
pipeline,
&[
(0, input),
(1, output),
(2, params_buf),
],
&[(0, shared_mem_bytes)],
MTLSize::new(rows as u64, 1, 1),
MTLSize::new(tg_size, 1, 1),
);
Ok(())
}
pub fn dispatch_rms_norm_no_scale_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
rows: u32,
dim: u32,
) -> Result<()> {
if rows == 0 || dim == 0 {
return Err(MlxError::InvalidArgument(
"RMS norm no_scale f32: rows and dim must be > 0".into(),
));
}
let expected = (rows as usize) * (dim as usize);
if input.element_count() != expected {
return Err(MlxError::InvalidArgument(format!(
"RMS norm no_scale f32: input element count {} != rows({}) * dim({})",
input.element_count(),
rows,
dim
)));
}
if output.element_count() != expected {
return Err(MlxError::InvalidArgument(format!(
"RMS norm no_scale f32: output element count {} != rows({}) * dim({})",
output.element_count(),
rows,
dim
)));
}
let use_v2 = (dim % 4 == 0) && env_default_true("HF2Q_RMS_NORM_V2");
let kernel_name = if use_v2 {
"rms_norm_no_scale_f32_v2"
} else {
"rms_norm_no_scale_f32"
};
let pipeline = registry.get_pipeline(kernel_name, device)?;
let mut tg_size = std::cmp::min(256, dim.next_power_of_two()) as u64;
if use_v2 && tg_size < 32 {
tg_size = 32;
}
let shared_mem_bytes = if use_v2 {
(tg_size / 32).max(1) * 4
} else {
tg_size * 4
};
encoder.encode_threadgroups_with_shared(
pipeline,
&[
(0, input),
(1, output),
(2, params_buf),
],
&[(0, shared_mem_bytes)],
MTLSize::new(rows as u64, 1, 1),
MTLSize::new(tg_size, 1, 1),
);
Ok(())
}
/// Encode the fused RMS-norm + multiply kernel over a `[rows, dim]` buffer.
///
/// The kernel (`rms_norm_mul_{f32,f16,bf16}`) is chosen from `input.dtype()`.
/// Buffer bindings:
/// * 0: `input`;
/// * 1: `norm_weight`;
/// * 2: `scale_weight`;
/// * 3: `output`;
/// * 4: `params_buf`.
///
/// One threadgroup is dispatched per row. Its width is `dim` rounded up to a
/// power of two and capped at 256.
///
/// NOTE(review): the expected size of `scale_weight` is not visible here, so
/// it is left unvalidated. Confirm against the shader.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if any of these hold:
/// * `rows` or `dim` is zero;
/// * `input` does not hold exactly `rows * dim` elements;
/// * `output` holds fewer than `rows * dim` elements;
/// * `norm_weight` holds fewer than `dim` elements;
/// * the dtype is unsupported.
///
/// Errors from pipeline lookup are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm_mul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    norm_weight: &MlxBuffer,
    scale_weight: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "Fused RMS norm+mul: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "Fused RMS norm+mul: input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    // Previously unchecked: a too-small output would be written out of bounds
    // on the GPU. Larger buffers stay accepted, for backward compatibility.
    if output.element_count() < expected {
        return Err(MlxError::InvalidArgument(format!(
            "Fused RMS norm+mul: output element count {} < rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    // NOTE(review): this assumes one norm weight per column, as the other
    // variants in this file require. Only buffers too small for that are
    // rejected.
    if norm_weight.element_count() < dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "Fused RMS norm+mul: norm_weight element count {} < dim({})",
            norm_weight.element_count(),
            dim
        )));
    }
    let kernel_name = fused_rms_norm_mul_kernel_name(input.dtype())?;
    let pipeline = registry.get_pipeline(kernel_name, device)?;
    // Clamp before rounding so `next_power_of_two` cannot overflow for huge `dim`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // One 4-byte shared slot per thread.
    let shared_mem_bytes = tg_size * 4;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, norm_weight),
            (2, scale_weight),
            (3, output),
            (4, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}