use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::{CapturedOpKind, CommandEncoder};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source text for every RMS-norm kernel variant.
///
/// Embedded at compile time from `../shaders/rms_norm.metal` (relative to this
/// Rust source file). Every kernel name registered by [`register`] is backed by
/// this same source string.
pub static RMS_NORM_SHADER_SOURCE: &str = include_str!("../shaders/rms_norm.metal");
/// Registers all RMS-norm kernel names with `registry`.
///
/// Covers the weighted (`rms_norm_*`), unweighted (`rms_norm_no_scale_*`) and
/// fused norm+multiply (`rms_norm_mul_*`) variants. Every name is mapped to the
/// shared [`RMS_NORM_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    // Kept in the original registration order.
    const KERNEL_NAMES: [&str; 8] = [
        "rms_norm_f32",
        "rms_norm_f16",
        "rms_norm_bf16",
        "rms_norm_no_scale_bf16",
        "rms_norm_no_scale_f32",
        "rms_norm_mul_f32",
        "rms_norm_mul_f16",
        "rms_norm_mul_bf16",
    ];

    for &kernel in KERNEL_NAMES.iter() {
        registry.register_source(kernel, RMS_NORM_SHADER_SOURCE);
    }
}
/// Maps an input dtype to the name of its fused RMS-norm + multiply kernel.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] for any dtype other than F32, F16 or
/// BF16.
fn fused_rms_norm_mul_kernel_name(dtype: DType) -> Result<&'static str> {
    let name = match dtype {
        DType::F32 => "rms_norm_mul_f32",
        DType::F16 => "rms_norm_mul_f16",
        DType::BF16 => "rms_norm_mul_bf16",
        unsupported => {
            return Err(MlxError::InvalidArgument(format!(
                "Fused RMS norm+mul unsupported dtype: {}",
                unsupported
            )));
        }
    };
    Ok(name)
}
/// Encodes a row-wise RMS-norm dispatch that uses a weight buffer.
///
/// Launches one threadgroup per row (`rows` threadgroups). Each threadgroup
/// has `min(256, dim.next_power_of_two())` threads and 4 bytes of threadgroup
/// memory per thread, rounded up to Metal's 16-byte minimum granularity. The
/// kernel (`rms_norm_f32` / `rms_norm_f16` / `rms_norm_bf16`) is chosen from
/// `input.dtype()`. The encoder's captured op kind is set to
/// [`CapturedOpKind::RmsNorm`].
///
/// Buffer bindings: 0 = `input`, 1 = `weight`, 2 = `output`, 3 = `params_buf`.
/// The contents of `params_buf` are read only by the shader.
///
/// NOTE(review): `weight`'s length is not validated here. Presumably it must
/// hold at least `dim` elements; confirm against the shader.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] when:
/// - `rows` or `dim` is zero,
/// - `input` or `output` does not hold exactly `rows * dim` elements, or
/// - the input dtype is not F32/F16/BF16.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "RMS norm rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    let kernel_name = match input.dtype() {
        DType::F32 => "rms_norm_f32",
        DType::F16 => "rms_norm_f16",
        DType::BF16 => "rms_norm_bf16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "RMS norm unsupported dtype: {}",
                input.dtype()
            )));
        }
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Clamp before rounding: `dim.next_power_of_two()` overflows for
    // dim > 2^31 (panics in debug, yields 0 in release). For every other dim
    // this equals the former `min(256, dim.next_power_of_two())`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread. Metal requires threadgroup memory lengths to be a
    // multiple of 16 bytes. `tg_size * 4` is a power of two >= 4, so `max(16)`
    // fixes the tg_size = 1 or 2 case (dim <= 2).
    let shared_mem_bytes = (tg_size * 4).max(16);

    encoder.set_op_kind(CapturedOpKind::RmsNorm);
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, weight),
            (2, output),
            (3, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encodes a row-wise RMS-norm dispatch without a weight buffer, using the
/// `rms_norm_no_scale_bf16` kernel.
///
/// Launches one threadgroup per row (`rows` threadgroups). Each threadgroup
/// has `min(256, dim.next_power_of_two())` threads and 4 bytes of threadgroup
/// memory per thread, rounded up to Metal's 16-byte minimum granularity.
///
/// Buffer bindings: 0 = `input`, 1 = `output`, 2 = `params_buf`.
///
/// NOTE(review): the buffer dtypes are not checked. The kernel name suggests
/// BF16 buffers are expected; confirm at call sites.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if `rows` or `dim` is zero, or if
/// `input` or `output` does not hold exactly `rows * dim` elements. Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm_no_scale_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "RMS norm no_scale: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm no_scale: input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm no_scale: output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    let pipeline = registry.get_pipeline("rms_norm_no_scale_bf16", device)?;

    // Clamp before rounding: `dim.next_power_of_two()` overflows for
    // dim > 2^31. For every other dim this equals the former
    // `min(256, dim.next_power_of_two())`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread, rounded up to Metal's required 16-byte multiple
    // (`tg_size * 4` is a power of two >= 4).
    let shared_mem_bytes = (tg_size * 4).max(16);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encodes a row-wise RMS-norm dispatch without a weight buffer, using the
/// `rms_norm_no_scale_f32` kernel.
///
/// Launches one threadgroup per row (`rows` threadgroups). Each threadgroup
/// has `min(256, dim.next_power_of_two())` threads and 4 bytes of threadgroup
/// memory per thread, rounded up to Metal's 16-byte minimum granularity.
///
/// Buffer bindings: 0 = `input`, 1 = `output`, 2 = `params_buf`.
///
/// NOTE(review): the buffer dtypes are not checked. The kernel name suggests
/// F32 buffers are expected; confirm at call sites.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] if `rows` or `dim` is zero, or if
/// `input` or `output` does not hold exactly `rows * dim` elements. Pipeline
/// lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm_no_scale_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "RMS norm no_scale f32: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm no_scale f32: input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "RMS norm no_scale f32: output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    let pipeline = registry.get_pipeline("rms_norm_no_scale_f32", device)?;

    // Clamp before rounding: `dim.next_power_of_two()` overflows for
    // dim > 2^31. For every other dim this equals the former
    // `min(256, dim.next_power_of_two())`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread, rounded up to Metal's required 16-byte multiple
    // (`tg_size * 4` is a power of two >= 4).
    let shared_mem_bytes = (tg_size * 4).max(16);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encodes a fused RMS-norm + multiply dispatch.
///
/// Launches one threadgroup per row (`rows` threadgroups). Each threadgroup
/// has `min(256, dim.next_power_of_two())` threads and 4 bytes of threadgroup
/// memory per thread, rounded up to Metal's 16-byte minimum granularity. The
/// kernel (`rms_norm_mul_{f32,f16,bf16}`) is chosen from `input.dtype()`.
///
/// Buffer bindings: 0 = `input`, 1 = `norm_weight`, 2 = `scale_weight`,
/// 3 = `output`, 4 = `params_buf`.
///
/// NOTE(review): the lengths of `norm_weight` and `scale_weight` are not
/// validated here; confirm the shader's expectations.
///
/// # Errors
/// Returns [`MlxError::InvalidArgument`] when:
/// - `rows` or `dim` is zero,
/// - `input` or `output` does not hold exactly `rows * dim` elements, or
/// - the input dtype is not F32/F16/BF16.
///
/// Pipeline lookup errors are propagated.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rms_norm_mul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    norm_weight: &MlxBuffer,
    scale_weight: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "Fused RMS norm+mul: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "Fused RMS norm+mul: input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    // The other RMS-norm dispatchers validate the output size; without this
    // check, an undersized output buffer would be written out of bounds.
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "Fused RMS norm+mul: output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    let kernel_name = fused_rms_norm_mul_kernel_name(input.dtype())?;
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Clamp before rounding: `dim.next_power_of_two()` overflows for
    // dim > 2^31. For every other dim this equals the former
    // `min(256, dim.next_power_of_two())`.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread, rounded up to Metal's required 16-byte multiple
    // (`tg_size * 4` is a power of two >= 4).
    let shared_mem_bytes = (tg_size * 4).max(16);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, norm_weight),
            (2, scale_weight),
            (3, output),
            (4, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}