use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal shader source text for the `fused_residual_norm_bf16` kernel,
/// embedded into the binary at compile time from
/// `../shaders/fused_residual_norm_bf16.metal` (path relative to this file).
pub static FUSED_RESIDUAL_NORM_SHADER_SOURCE: &str =
include_str!("../shaders/fused_residual_norm_bf16.metal");
/// Registers the embedded fused residual-norm shader source with `registry`
/// under the kernel name `fused_residual_norm_bf16`.
///
/// The dispatch function in this file looks the pipeline up by that same name.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "fused_residual_norm_bf16";
    registry.register_source(kernel_name, FUSED_RESIDUAL_NORM_SHADER_SOURCE);
}
/// Scalar parameters passed to the kernel as raw bytes (argument index 5 in
/// `dispatch_fused_residual_norm_bf16`).
///
/// `#[repr(C)]` pins the field order and layout; the `Pod`/`Zeroable` derives
/// allow the value to be viewed as a plain byte slice.
// NOTE(review): presumably mirrors a parameter struct declared in the Metal
// shader — confirm the field order there before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedResidualNormParams {
    /// The `dim` argument of the dispatch function.
    dim: u32,
    /// The `rows` argument of the dispatch function.
    rows: u32,
    /// The `eps` argument of the dispatch function.
    eps: f32,
    /// 1 when the caller supplied a `sum_output` buffer, 0 otherwise.
    write_sum: u32,
}
/// Encodes one dispatch of the `fused_residual_norm_bf16` kernel on `encoder`.
///
/// Argument bindings, as passed to the encode helper:
///
/// | index | value                                                   |
/// |-------|---------------------------------------------------------|
/// | 0     | `residual`                                              |
/// | 1     | `input`                                                 |
/// | 2     | `weight`                                                |
/// | 3     | `normed_output`                                         |
/// | 4     | `sum_output`, or `normed_output` when it is `None`      |
/// | 5     | `GpuFusedResidualNormParams` (`dim`, `rows`, `eps`, `write_sum`) as bytes |
///
/// The helper receives a `rows x 1 x 1` size followed by a `tg_size x 1 x 1`
/// size, where `tg_size` is `dim` rounded up to a power of two and capped at
/// 256, plus a `(0, tg_size * 4)` shared-memory entry.
///
/// When `sum_output` is `None` the `write_sum` flag sent to the kernel is 0 and
/// `normed_output` is bound at index 4 purely as a placeholder.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `rows` or `dim` is zero, when
/// `rows * dim` does not fit in `usize`, or when `residual`, `input`,
/// `normed_output`, or a supplied `sum_output` does not hold exactly
/// `rows * dim` elements. Any error from `KernelRegistry::get_pipeline` is
/// propagated unchanged.
// The kernel needs every one of these bindings, so the long parameter list is
// inherent to the operation rather than a refactoring candidate.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_residual_norm_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    residual: &MlxBuffer,
    input: &MlxBuffer,
    weight: &MlxBuffer,
    normed_output: &MlxBuffer,
    sum_output: Option<&MlxBuffer>,
    rows: u32,
    dim: u32,
    eps: f32,
) -> Result<()> {
    /// Upper bound on the threadgroup width used for this kernel.
    const MAX_THREADS_PER_GROUP: u32 = 256;
    /// Bytes of shared memory reserved per thread in the threadgroup.
    const SHARED_BYTES_PER_THREAD: u64 = 4;

    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_residual_norm: rows and dim must be > 0".into(),
        ));
    }

    // u32 -> usize is lossless on the targets this builds for; the product is
    // checked so a 32-bit `usize` cannot silently wrap.
    let expected = (rows as usize)
        .checked_mul(dim as usize)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "fused_residual_norm: rows({}) * dim({}) overflows usize",
                rows, dim,
            ))
        })?;

    // Shared shape check for every rows x dim buffer; `name` is spliced into
    // the error message so each buffer keeps its own diagnostic.
    let check_len = |name: &str, buf: &MlxBuffer| -> Result<()> {
        let actual = buf.element_count();
        if actual == expected {
            Ok(())
        } else {
            Err(MlxError::InvalidArgument(format!(
                "fused_residual_norm: {} element count {} != rows({}) * dim({})",
                name, actual, rows, dim,
            )))
        }
    };

    check_len("residual", residual)?;
    check_len("input", input)?;
    check_len("normed_output", normed_output)?;
    if let Some(sum_buf) = sum_output {
        check_len("sum_output", sum_buf)?;
    }
    // NOTE(review): `weight` is not size-checked here. It presumably needs at
    // least `dim` elements — confirm against the shader before adding a check.

    let pipeline = registry.get_pipeline("fused_residual_norm_bf16", device)?;

    // Clamp before rounding up: `dim.next_power_of_two()` overflows for
    // dim > 2^31 (panic in debug builds, 0 in release). Clamping first yields
    // the same value for every non-overflowing `dim` and can never overflow.
    let tg_size = u64::from(dim.min(MAX_THREADS_PER_GROUP).next_power_of_two());
    let shared_mem_bytes = tg_size * SHARED_BYTES_PER_THREAD;

    let gpu_params = GpuFusedResidualNormParams {
        dim,
        rows,
        eps,
        write_sum: u32::from(sum_output.is_some()),
    };

    // Index 4 always needs a buffer bound; without a sum output, reuse
    // `normed_output` as a stand-in (the kernel is told via `write_sum == 0`).
    let sum_binding = sum_output.unwrap_or(normed_output);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(residual)),
            (1, KernelArg::Buffer(input)),
            (2, KernelArg::Buffer(weight)),
            (3, KernelArg::Buffer(normed_output)),
            (4, KernelArg::Buffer(sum_binding)),
            (5, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );

    Ok(())
}