use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal shader source for the fused residual-add + RMS-norm bf16 kernels.
/// One source file provides both the weighted and the no-weight variants,
/// so `register` maps both kernel names to this same string.
pub static FUSED_NORM_ADD_SHADER_SOURCE: &str =
include_str!("../shaders/fused_norm_add_bf16.metal");
/// Registers the fused norm+add bf16 kernels with the registry.
///
/// Both kernel names share a single shader source file; the pipeline for
/// each function is compiled lazily on first `get_pipeline` lookup.
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["fused_norm_add_bf16", "fused_norm_add_no_weight_bf16"] {
        registry.register_source(kernel_name, FUSED_NORM_ADD_SHADER_SOURCE);
    }
}
/// Picks a threadgroup width for a row-reduction over `dim` elements:
/// the next power of two covering `dim`, capped at 256 threads.
///
/// Clamps *before* calling `next_power_of_two`: for `dim > 2^31` that call
/// overflows `u32` (panics in debug builds, returns 0 in release), and a
/// zero-sized threadgroup would silently dispatch no work. Any `dim >= 256`
/// maps to the 256 cap anyway, so the clamp is behavior-preserving for all
/// non-overflowing inputs.
#[inline]
fn tg_size_for_dim(dim: u32) -> u64 {
    if dim >= 256 {
        256
    } else {
        u64::from(dim.next_power_of_two())
    }
}
/// Encodes the weighted fused residual-add + norm bf16 kernel.
///
/// Launches one threadgroup per row (`rows` groups of `tg_size_for_dim(dim)`
/// threads) with one f32 shared-memory slot per thread for the reduction.
/// All row-shaped buffers must hold exactly `rows * dim` elements and
/// `weight` exactly `dim`; violations return `InvalidArgument`.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_norm_add_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    dim: u32,
    rows: u32,
    eps: f32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_norm_add: rows and dim must be > 0".into(),
        ));
    }
    let total_elems = (rows as usize) * (dim as usize);
    // Shared validator for every rows*dim-shaped buffer; the emitted message
    // is identical to spelling each check out by hand.
    let check_row_buffer = |label: &str, buf: &MlxBuffer| -> Result<()> {
        if buf.element_count() != total_elems {
            return Err(MlxError::InvalidArgument(format!(
                "fused_norm_add: {} element count {} != rows({}) * dim({})",
                label,
                buf.element_count(),
                rows,
                dim,
            )));
        }
        Ok(())
    };
    check_row_buffer("residual", residual)?;
    check_row_buffer("input", input)?;
    if weight.element_count() != dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "fused_norm_add: weight element count {} != dim({})",
            weight.element_count(),
            dim,
        )));
    }
    check_row_buffer("output", output)?;
    let pso = registry.get_pipeline("fused_norm_add_bf16", device)?;
    let threads_per_group = tg_size_for_dim(dim);
    // One f32 reduction slot per thread.
    let threadgroup_bytes = threads_per_group * 4;
    encode_threadgroups_with_args_and_shared(
        encoder,
        pso,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(output)),
            (4, KernelArg::Bytes(as_bytes(&dim))),
            (5, KernelArg::Bytes(as_bytes(&rows))),
            (6, KernelArg::Bytes(as_bytes(&eps))),
        ],
        &[(0, threadgroup_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}
/// Encodes the no-weight fused residual-add + norm bf16 kernel.
///
/// Same launch geometry as the weighted variant (one threadgroup per row,
/// one f32 shared slot per thread) but without a weight buffer binding.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_norm_add_no_weight_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    output: &MlxBuffer,
    dim: u32,
    rows: u32,
    eps: f32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_norm_add_no_weight: rows and dim must be > 0".into(),
        ));
    }
    let total_elems = (rows as usize) * (dim as usize);
    // Every bound buffer is rows*dim-shaped here, so a single validator
    // covers them all; emitted messages match the hand-written form.
    let check_row_buffer = |label: &str, buf: &MlxBuffer| -> Result<()> {
        if buf.element_count() != total_elems {
            return Err(MlxError::InvalidArgument(format!(
                "fused_norm_add_no_weight: {} element count {} != rows({}) * dim({})",
                label,
                buf.element_count(),
                rows,
                dim,
            )));
        }
        Ok(())
    };
    check_row_buffer("residual", residual)?;
    check_row_buffer("input", input)?;
    check_row_buffer("output", output)?;
    let pso = registry.get_pipeline("fused_norm_add_no_weight_bf16", device)?;
    let threads_per_group = tg_size_for_dim(dim);
    // One f32 reduction slot per thread.
    let threadgroup_bytes = threads_per_group * 4;
    encode_threadgroups_with_args_and_shared(
        encoder,
        pso,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&dim))),
            (4, KernelArg::Bytes(as_bytes(&rows))),
            (5, KernelArg::Bytes(as_bytes(&eps))),
        ],
        &[(0, threadgroup_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}
/// Encodes the weighted fused residual-add + norm f32 kernel.
///
/// One threadgroup per row, one f32 shared-memory slot per thread for the
/// reduction. Buffer bindings mirror `dispatch_fused_norm_add_bf16`.
///
/// Unlike the bf16 path, this function previously skipped element-count
/// validation; the same checks are applied here for consistency, so a
/// mis-sized buffer returns `InvalidArgument` instead of dispatching a
/// kernel that reads/writes out of range.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_norm_add_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    dim: u32,
    rows: u32,
    eps: f32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_norm_add_f32: rows and dim must be > 0".into(),
        ));
    }
    let expected = (rows as usize) * (dim as usize);
    for (label, buf) in [("residual", residual), ("input", input), ("output", output)] {
        if buf.element_count() != expected {
            return Err(MlxError::InvalidArgument(format!(
                "fused_norm_add_f32: {} element count {} != rows({}) * dim({})",
                label,
                buf.element_count(),
                rows,
                dim,
            )));
        }
    }
    if weight.element_count() != dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "fused_norm_add_f32: weight element count {} != dim({})",
            weight.element_count(),
            dim,
        )));
    }
    let pipeline = registry.get_pipeline("fused_norm_add_f32", device)?;
    let tg_size = tg_size_for_dim(dim);
    // One f32 reduction slot per thread.
    let shared_mem_bytes = tg_size * 4;
    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(output)),
            (4, KernelArg::Bytes(as_bytes(&dim))),
            (5, KernelArg::Bytes(as_bytes(&rows))),
            (6, KernelArg::Bytes(as_bytes(&eps))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Argument block for the `fused_residual_norm_f32` kernel.
///
/// `#[repr(C)]` so the field order and layout match the corresponding
/// Metal-side struct byte-for-byte; passed via `as_bytes` at buffer slot 5.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedResidualNormF32Params {
    // Elements per row (normalization axis length).
    dim: u32,
    // Number of rows (one threadgroup is launched per row).
    rows: u32,
    // Epsilon added inside the norm for numerical stability.
    eps: f32,
    // Non-zero when the kernel should also write the residual sum output.
    write_sum: u32,
}
/// Encodes the fused residual-add + norm f32 kernel, optionally writing the
/// pre-norm residual sum to `sum_output`.
///
/// Shared memory is sized to `max(tg_size, dim)` f32 slots. The kernel
/// always binds buffer slot 4; when the caller requests no sum output,
/// `normed_output` is rebound there as a dummy and `write_sum` is cleared so
/// the kernel never stores through it.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_residual_norm_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    normed_output: &MlxBuffer,
    sum_output: Option<&MlxBuffer>,
    rows: u32,
    dim: u32,
    eps: f32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_residual_norm_f32: rows and dim must be > 0".into(),
        ));
    }
    let pso = registry.get_pipeline("fused_residual_norm_f32", device)?;
    let threads_per_group = tg_size_for_dim(dim);
    let slot_count = std::cmp::max(threads_per_group as u32, dim);
    let threadgroup_bytes = (slot_count as u64) * 4;
    let params = GpuFusedResidualNormF32Params {
        dim,
        rows,
        eps,
        write_sum: u32::from(sum_output.is_some()),
    };
    // Slot 4 must always be bound; fall back to the normed output as a dummy.
    let sum_binding = match sum_output {
        Some(buf) => buf,
        None => normed_output,
    };
    encode_threadgroups_with_args_and_shared(
        encoder,
        pso,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(normed_output)),
            (4, KernelArg::Buffer(sum_binding)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        &[(0, threadgroup_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}
/// Argument block for the `fused_residual_norm_scalar_f32` kernel.
///
/// `#[repr(C)]` so the field order and layout match the corresponding
/// Metal-side struct byte-for-byte; passed via `as_bytes` at buffer slot 5.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedResidualNormScalarF32Params {
    // Elements per row (normalization axis length).
    dim: u32,
    // Number of rows (one threadgroup is launched per row).
    rows: u32,
    // Epsilon added inside the norm for numerical stability.
    eps: f32,
    // Non-zero when `scalar` holds one value per row rather than a single value.
    scalar_is_vector: u32,
}
/// Encodes the fused residual-add + norm f32 kernel with an extra scalar
/// (or per-row vector, per `scalar_is_vector`) multiplier buffer.
///
/// Shared memory is sized to `max(tg_size, dim)` f32 slots; one threadgroup
/// is launched per row.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_residual_norm_scalar_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    scalar: &MlxBuffer,
    rows: u32,
    dim: u32,
    eps: f32,
    scalar_is_vector: bool,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_residual_norm_scalar_f32: rows and dim must be > 0".into(),
        ));
    }
    let pso = registry.get_pipeline("fused_residual_norm_scalar_f32", device)?;
    let threads_per_group = tg_size_for_dim(dim);
    let slot_count = std::cmp::max(threads_per_group as u32, dim);
    let threadgroup_bytes = (slot_count as u64) * 4;
    let params = GpuFusedResidualNormScalarF32Params {
        dim,
        rows,
        eps,
        scalar_is_vector: u32::from(scalar_is_vector),
    };
    encode_threadgroups_with_args_and_shared(
        encoder,
        pso,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(output)),
            (4, KernelArg::Buffer(scalar)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        &[(0, threadgroup_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}
/// Argument block for the `fused_moe_routing_f32` kernel.
///
/// `#[repr(C)]` so the field order and layout match the corresponding
/// Metal-side struct byte-for-byte; passed via `as_bytes` at buffer slot 4.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedMoeRoutingParams {
    // Total number of experts (length of the logits row).
    num_experts: u32,
    // Number of experts selected per token.
    top_k: u32,
}
/// Encodes the MoE routing kernel (top-k expert selection over one logits
/// row) as a single threadgroup of up to 64 threads.
///
/// Shared memory holds `2 * num_experts + tg_size` f32 slots — presumably
/// logits scratch plus reduction space; confirm against the shader.
///
/// Fixes vs. the previous version: the threadgroup size is clamped *before*
/// `next_power_of_two` (which overflows `u32` for `num_experts > 2^31`:
/// panic in debug, 0 in release), and the shared-slot count is computed in
/// `u64` so `2 * num_experts` cannot wrap.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_moe_routing_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    logits: &MlxBuffer,
    expert_ids: &MlxBuffer,
    routing_weights: &MlxBuffer,
    per_expert_scale: &MlxBuffer,
    num_experts: u32,
    top_k: u32,
) -> Result<()> {
    if num_experts == 0 || top_k == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_moe_routing_f32: num_experts and top_k must be > 0".into(),
        ));
    }
    let pipeline = registry.get_pipeline("fused_moe_routing_f32", device)?;
    let gpu_params = GpuFusedMoeRoutingParams {
        num_experts,
        top_k,
    };
    // Clamp before next_power_of_two to avoid u32 overflow; any
    // num_experts >= 64 maps to the 64-thread cap anyway.
    let tg_size: u64 = if num_experts >= 64 {
        64
    } else {
        u64::from(num_experts.next_power_of_two())
    };
    // Computed in u64 so `2 * num_experts` cannot wrap a u32.
    let shared_slots = 2 * u64::from(num_experts) + tg_size;
    let shared_mem_bytes = shared_slots * 4;
    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(logits)),
            (1, KernelArg::Buffer(expert_ids)),
            (2, KernelArg::Buffer(routing_weights)),
            (3, KernelArg::Buffer(per_expert_scale)),
            (4, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(1, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Argument block for the `fused_norm_add_scalar_f32` kernel.
///
/// `#[repr(C)]` so the field order and layout match the corresponding
/// Metal-side struct byte-for-byte; passed via `as_bytes` at buffer slot 5.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedNormAddScalarF32Params {
    // Elements per row (normalization axis length).
    dim: u32,
    // Number of rows (one threadgroup is launched per row).
    rows: u32,
    // Epsilon added inside the norm for numerical stability.
    eps: f32,
    // Non-zero when `scalar` holds one value per row rather than a single value.
    scalar_is_vector: u32,
}
/// Encodes the fused norm+add f32 kernel with an extra scalar (or per-row
/// vector, per `scalar_is_vector`) multiplier buffer.
///
/// One threadgroup per row with one f32 shared-memory slot per thread for
/// the reduction.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_norm_add_scalar_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    output: &MlxBuffer,
    scalar: &MlxBuffer,
    rows: u32,
    dim: u32,
    eps: f32,
    scalar_is_vector: bool,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_norm_add_scalar_f32: rows and dim must be > 0".into(),
        ));
    }
    let pso = registry.get_pipeline("fused_norm_add_scalar_f32", device)?;
    let threads_per_group = tg_size_for_dim(dim);
    // One f32 reduction slot per thread.
    let threadgroup_bytes = threads_per_group * 4;
    let params = GpuFusedNormAddScalarF32Params {
        dim,
        rows,
        eps,
        scalar_is_vector: u32::from(scalar_is_vector),
    };
    encode_threadgroups_with_args_and_shared(
        encoder,
        pso,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(output)),
            (4, KernelArg::Buffer(scalar)),
            (5, KernelArg::Bytes(as_bytes(&params))),
        ],
        &[(0, threadgroup_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}