use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args_and_shared, KernelArg};
/// Metal shader source for the bf16 fused head-norm + RoPE kernel
/// (cos/sin-cache variant, see `dispatch_fused_head_norm_rope_bf16`).
pub static FUSED_HEAD_NORM_ROPE_SHADER_SOURCE: &str =
include_str!("../shaders/fused_head_norm_rope_bf16.metal");
/// Metal shader source for the f32 fused head-norm + RoPE kernel
/// (position/theta variant, see `dispatch_fused_head_norm_rope_f32`).
pub static FUSED_HEAD_NORM_ROPE_F32_SHADER_SOURCE: &str =
include_str!("../shaders/fused_head_norm_rope_f32.metal");
pub fn register(registry: &mut KernelRegistry) {
registry.register_source(
"fused_head_norm_rope_bf16",
FUSED_HEAD_NORM_ROPE_SHADER_SOURCE,
);
}
/// Parameter block for the bf16 kernel, passed as raw bytes at buffer
/// index 5 by `dispatch_fused_head_norm_rope_bf16`. `#[repr(C)]` plus
/// `bytemuck::Pod` make the byte view well-defined; field order and types
/// are assumed to mirror the constant struct declared in
/// fused_head_norm_rope_bf16.metal — confirm against the shader if edited.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedHeadNormRopeParams {
// Elements per attention head.
head_dim: u32,
// Number of heads; the dispatcher launches one threadgroup per head.
n_heads: u32,
// Number of rotated (cos, sin) pairs; validated <= head_dim / 2.
half_rope_dim: u32,
// Norm epsilon, consumed by the shader.
eps: f32,
// 1 when a norm-weight buffer was supplied, 0 otherwise (the weight
// binding then aliases the input buffer as a harmless placeholder).
has_weight: u32, }
/// Encodes the fused head-norm + RoPE bf16 kernel onto `encoder`.
///
/// Validates argument shapes, then dispatches one threadgroup per head with
/// up to 256 threads each and a per-threadgroup shared-memory scratch area.
/// When `norm_weight` is `None`, the weight binding is filled with `input`
/// as a placeholder and the shader is told via `has_weight = 0` to ignore it.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `n_heads`/`head_dim` are zero,
/// `half_rope_dim` exceeds `head_dim / 2`, `input`/`output` do not hold
/// exactly `n_heads * head_dim` elements, or either rotation cache holds
/// fewer than `half_rope_dim` elements.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    cos_cache: &MlxBuffer,
    sin_cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    eps: f32,
) -> Result<()> {
    // --- shape validation (same order and messages as before) ---
    if n_heads == 0 || head_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope: n_heads and head_dim must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }
    let expected_elements = (n_heads as usize) * (head_dim as usize);
    // Both the main tensors must hold exactly n_heads * head_dim elements.
    let check_tensor = |name: &str, buf: &MlxBuffer| -> Result<()> {
        if buf.element_count() == expected_elements {
            Ok(())
        } else {
            Err(MlxError::InvalidArgument(format!(
                "fused_head_norm_rope: {} element count {} != n_heads({}) * head_dim({})",
                name,
                buf.element_count(),
                n_heads,
                head_dim,
            )))
        }
    };
    check_tensor("input", input)?;
    check_tensor("output", output)?;
    // The rotation caches only need to cover the rotated prefix.
    let check_cache = |name: &str, buf: &MlxBuffer| -> Result<()> {
        if buf.element_count() >= half_rope_dim as usize {
            Ok(())
        } else {
            Err(MlxError::InvalidArgument(format!(
                "fused_head_norm_rope: {} element count {} < half_rope_dim ({})",
                name,
                buf.element_count(),
                half_rope_dim,
            )))
        }
    };
    check_cache("cos_cache", cos_cache)?;
    check_cache("sin_cache", sin_cache)?;

    // --- launch geometry ---
    let pipeline = registry.get_pipeline("fused_head_norm_rope_bf16", device)?;
    // One threadgroup per head, sized to the head (power of two, capped at 256).
    let tg_size = u64::from(head_dim.next_power_of_two().min(256));
    // 4 bytes of threadgroup scratch per thread.
    let shared_mem_bytes = 4 * tg_size;

    let gpu_params = GpuFusedHeadNormRopeParams {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: norm_weight.is_some() as u32,
    };
    // Placeholder binding when no weight is given; the shader skips it
    // because has_weight == 0.
    let weight_buf = match norm_weight {
        Some(w) => w,
        None => input,
    };
    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Buffer(cos_cache)),
            (4, KernelArg::Buffer(sin_cache)),
            (5, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(n_heads as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Parameter block for the f32 kernel, passed as raw bytes at buffer
/// index 3 by `dispatch_fused_head_norm_rope_f32`. `#[repr(C)]` plus
/// `bytemuck::Pod` make the byte view well-defined; layout is assumed to
/// mirror the constant struct in fused_head_norm_rope_f32.metal — confirm
/// against the shader if edited.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedHeadNormRopeF32Params {
// Elements per attention head.
head_dim: u32,
// Number of heads; one threadgroup is dispatched per head.
n_heads: u32,
// Number of rotated pairs; validated <= head_dim / 2.
half_rope_dim: u32,
// Norm epsilon, consumed by the shader.
eps: f32,
// 1 when a norm-weight buffer was supplied, 0 otherwise.
has_weight: u32,
// RoPE base frequency; the shader derives angles from positions + theta
// instead of precomputed cos/sin caches.
theta: f32,
// 1 when a frequency-factors buffer was supplied, 0 otherwise.
has_freq_factors: u32,
// Explicit padding — presumably to match the shader struct's size or
// alignment; keep in sync with the Metal side.
_pad: u32,
}
/// Encodes the fused head-norm + RoPE f32 kernel onto `encoder`.
///
/// Unlike the bf16 variant, this kernel derives rotation angles on the GPU
/// from `positions_buf` and `theta` (optionally scaled by `freq_factors`)
/// instead of reading precomputed cos/sin caches. One threadgroup is
/// dispatched per head; shared memory is sized to hold at least one f32
/// per thread and per head element. When `norm_weight` / `freq_factors`
/// are `None`, `input` is bound as a placeholder and the shader is told to
/// ignore the binding via the `has_weight` / `has_freq_factors` flags.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when `n_heads`/`head_dim` are zero,
/// `half_rope_dim` exceeds `head_dim / 2`, or `input`/`output` do not hold
/// exactly `n_heads * head_dim` elements. (The element-count checks mirror
/// the bf16 dispatcher; previously this variant skipped them, letting
/// mis-sized buffers reach the GPU unchecked.)
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    eps: f32,
    theta: f32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope_f32: n_heads and head_dim must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_f32: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }
    // Validate tensor sizes, consistent with dispatch_fused_head_norm_rope_bf16.
    let expected_elements = (n_heads as usize) * (head_dim as usize);
    if input.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_f32: input element count {} != n_heads({}) * head_dim({})",
            input.element_count(),
            n_heads,
            head_dim,
        )));
    }
    if output.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_f32: output element count {} != n_heads({}) * head_dim({})",
            output.element_count(),
            n_heads,
            head_dim,
        )));
    }
    let pipeline = registry.get_pipeline("fused_head_norm_rope_f32", device)?;
    // One threadgroup per head; threadgroup width is a power of two capped at 256.
    let tg_size = std::cmp::min(256, head_dim.next_power_of_two()) as u64;
    // The f32 shader needs shared slots for the larger of the threadgroup
    // width and the full head (4 bytes per f32 slot).
    let shared_slots = std::cmp::max(tg_size as u32, head_dim);
    let shared_mem_bytes = (shared_slots as u64) * 4;
    let has_weight = norm_weight.is_some();
    let has_ff = freq_factors.is_some();
    let gpu_params = GpuFusedHeadNormRopeF32Params {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(has_weight),
        theta,
        has_freq_factors: u32::from(has_ff),
        _pad: 0,
    };
    // Placeholder bindings; the shader ignores them when the matching
    // has_* flag is 0 (Metal requires every buffer slot to be bound).
    let weight_buf = norm_weight.unwrap_or(input);
    let ff_buf = freq_factors.unwrap_or(input);
    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            // NOTE: params sit at index 3 here, unlike the bf16 variant
            // where they are at index 5 — the two shaders bind differently.
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            (4, KernelArg::Buffer(positions_buf)),
            (5, KernelArg::Buffer(ff_buf)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(n_heads as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}