use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args_and_shared, KernelArg};
/// Text of `../shaders/fused_head_norm_rope_bf16.metal`, embedded at compile time.
///
/// This is the source that [`register`] hands to the kernel registry under the
/// name `"fused_head_norm_rope_bf16"`.
pub static FUSED_HEAD_NORM_ROPE_SHADER_SOURCE: &str =
include_str!("../shaders/fused_head_norm_rope_bf16.metal");
/// Text of `../shaders/fused_head_norm_rope_f32.metal`, embedded at compile time.
///
/// NOTE(review): `register` in this file does not register this source; it is
/// presumably registered elsewhere (the f32 dispatch functions look up
/// `"fused_head_norm_rope_f32"` pipelines) — confirm.
pub static FUSED_HEAD_NORM_ROPE_F32_SHADER_SOURCE: &str =
include_str!("../shaders/fused_head_norm_rope_f32.metal");
/// Registers the bf16 shader source with `registry` under the name
/// `"fused_head_norm_rope_bf16"`.
///
/// NOTE(review): only the bf16 source is registered here. The other pipeline
/// names requested by the dispatch functions in this file are presumably
/// registered elsewhere — confirm.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "fused_head_norm_rope_bf16";
    let shader_source = FUSED_HEAD_NORM_ROPE_SHADER_SOURCE;
    registry.register_source(kernel_name, shader_source);
}
/// Parameter block sent as raw bytes to the `fused_head_norm_rope_bf16` pipeline
/// (bound at argument index 5 by `dispatch_fused_head_norm_rope_bf16`).
///
/// `#[repr(C)]` with five 4-byte fields: 20 bytes, no padding.
/// NOTE(review): the field order presumably has to mirror the parameter struct
/// declared in the Metal shader — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedHeadNormRopeParams {
// Elements per head (the dispatcher requires input/output to hold n_heads * head_dim).
head_dim: u32,
// Number of heads; the dispatcher launches one threadgroup per head.
n_heads: u32,
// Validated by the dispatcher to be <= head_dim / 2 and <= cos/sin cache length.
half_rope_dim: u32,
// Presumably the normalization epsilon — confirm against the shader.
eps: f32,
// 1 when a norm weight buffer was supplied, 0 when a placeholder is bound instead.
has_weight: u32, }
/// Validates the arguments and encodes one dispatch of the
/// `fused_head_norm_rope_bf16` pipeline.
///
/// Argument bindings: 0 = `input`, 1 = `output`, 2 = `norm_weight`, 3 = `cos_cache`,
/// 4 = `sin_cache`, 5 = the `GpuFusedHeadNormRopeParams` block as raw bytes.
/// When `norm_weight` is `None`, `input` is bound in slot 2 as a placeholder and
/// `has_weight` is set to 0 in the parameter block.
///
/// The grid is `n_heads` threadgroups of `min(256, head_dim.next_power_of_two())`
/// threads each, and `4 * threadgroup_width` bytes are requested through the
/// shared-memory entry at index 0.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when:
/// - `n_heads` or `head_dim` is zero;
/// - `half_rope_dim > head_dim / 2`;
/// - `input` or `output` does not hold exactly `n_heads * head_dim` elements;
/// - `cos_cache` or `sin_cache` holds fewer than `half_rope_dim` elements.
///
/// Any error from `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    cos_cache: &MlxBuffer,
    sin_cache: &MlxBuffer,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    eps: f32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope: n_heads and head_dim must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }

    let expected_elements = (n_heads as usize) * (head_dim as usize);
    if input.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: input element count {} != n_heads({}) * head_dim({})",
            input.element_count(),
            n_heads,
            head_dim,
        )));
    }
    if output.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: output element count {} != n_heads({}) * head_dim({})",
            output.element_count(),
            n_heads,
            head_dim,
        )));
    }
    if cos_cache.element_count() < half_rope_dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: cos_cache element count {} < half_rope_dim ({})",
            cos_cache.element_count(),
            half_rope_dim,
        )));
    }
    if sin_cache.element_count() < half_rope_dim as usize {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope: sin_cache element count {} < half_rope_dim ({})",
            sin_cache.element_count(),
            half_rope_dim,
        )));
    }

    let pipeline = registry.get_pipeline("fused_head_norm_rope_bf16", device)?;

    // Clamp to 256 *before* rounding up. This equals
    // `min(256, head_dim.next_power_of_two())` for every head_dim, but cannot
    // overflow: `next_power_of_two` panics (debug) or wraps to 0 (release) once
    // head_dim exceeds 2^31, which would yield a zero-width threadgroup.
    let tg_size = u64::from(head_dim.min(256).next_power_of_two());
    // Four bytes of threadgroup memory per thread.
    let shared_mem_bytes = tg_size * 4;

    let gpu_params = GpuFusedHeadNormRopeParams {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(norm_weight.is_some()),
    };

    // Slot 2 must always have a buffer bound; `input` stands in when there is
    // no weight and `has_weight == 0` signals that to the kernel.
    let weight_buf = norm_weight.unwrap_or(input);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Buffer(cos_cache)),
            (4, KernelArg::Buffer(sin_cache)),
            (5, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(n_heads as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Parameter block sent as raw bytes (argument index 3) to the
/// `fused_head_norm_rope_f32` / `fused_head_norm_rope_f32_v2` pipelines by the
/// f32 dispatch functions in this file.
///
/// `#[repr(C)]` with eleven 4-byte fields: 44 bytes, no padding.
/// NOTE(review): the field order presumably has to mirror the parameter struct
/// declared in the Metal shader — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuFusedHeadNormRopeF32Params {
// Elements per head; also drives the threadgroup width and shared-memory size.
head_dim: u32,
// Number of heads; the grid is n_heads (times seq_len for batch dispatches) threadgroups.
n_heads: u32,
// Validated by every dispatcher to be <= head_dim / 2.
half_rope_dim: u32,
// Presumably the normalization epsilon — confirm against the shader.
eps: f32,
// 1 when a norm weight buffer was supplied, 0 when a placeholder is bound.
has_weight: u32,
// Presumably the RoPE base frequency — confirm against the shader.
theta: f32,
// 1 when a freq_factors buffer was supplied, 0 when a placeholder is bound.
has_freq_factors: u32,
// 1 when a bf16 output buffer is bound at argument index 6, else 0.
has_bf16_output: u32,
// bf16_permuted: caller-requested flag, only set when an optional output is present.
// seq_len: 1 for the single-position dispatch, the caller's seq_len for batch dispatches.
// has_f32_perm_output: 1 when an f32 output buffer is bound at argument index 7, else 0.
bf16_permuted: u32, seq_len: u32, has_f32_perm_output: u32, }
/// Validates the dimensions and encodes one dispatch of the f32 fused
/// head-norm/RoPE pipeline with `seq_len = 1` and no optional outputs.
///
/// The pipeline is `"fused_head_norm_rope_f32_v2"` when `head_dim` is a multiple
/// of 4 and the `HF2Q_FUSED_HEAD_NORM_ROPE_V2` flag (as reported by
/// `env_flags::cached_env_default_true`) is on; otherwise
/// `"fused_head_norm_rope_f32"`.
///
/// Argument bindings: 0 = `input`, 1 = `output`, 2 = `norm_weight`, 3 = parameter
/// bytes, 4 = `positions_buf`, 5 = `freq_factors`, 6 and 7 = `input` as
/// placeholders (both optional-output flags are 0). Absent `norm_weight` /
/// `freq_factors` are likewise replaced by `input`, with the matching `has_*`
/// flag cleared.
///
/// The grid is `n_heads` threadgroups of `min(256, head_dim.next_power_of_two())`
/// threads, with `4 * max(threadgroup_width, head_dim)` bytes requested through
/// the shared-memory entry at index 0.
///
/// NOTE(review): buffer element counts are not validated here, unlike the
/// single-position bf16 dispatcher — callers must size them correctly.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when `n_heads` or `head_dim` is zero, or
/// when `half_rope_dim > head_dim / 2`. Any error from
/// `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    eps: f32,
    theta: f32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope_f32: n_heads and head_dim must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_f32: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }

    // Cache cell for the environment flag; starts at -1 (presumably "not read
    // yet" — confirm against `env_flags`).
    static CACHED_FUSED_HEAD_NORM_ROPE_V2: std::sync::atomic::AtomicI8 =
        std::sync::atomic::AtomicI8::new(-1);
    // The flag is only consulted when head_dim is a multiple of 4.
    let use_v2 = (head_dim % 4 == 0)
        && crate::env_flags::cached_env_default_true(
            &CACHED_FUSED_HEAD_NORM_ROPE_V2,
            "HF2Q_FUSED_HEAD_NORM_ROPE_V2",
        );
    let kernel_name = if use_v2 {
        "fused_head_norm_rope_f32_v2"
    } else {
        "fused_head_norm_rope_f32"
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Clamp to 256 *before* rounding up. This equals
    // `min(256, head_dim.next_power_of_two())` for every head_dim, but cannot
    // overflow: `next_power_of_two` panics (debug) or wraps to 0 (release) once
    // head_dim exceeds 2^31, which would yield a zero-width threadgroup.
    let tg_threads = head_dim.min(256).next_power_of_two();
    let tg_size = u64::from(tg_threads);
    // Four bytes per slot, with at least one slot per element of a head.
    let shared_slots = tg_threads.max(head_dim);
    let shared_mem_bytes = u64::from(shared_slots) * 4;

    let gpu_params = GpuFusedHeadNormRopeF32Params {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(norm_weight.is_some()),
        theta,
        has_freq_factors: u32::from(freq_factors.is_some()),
        has_bf16_output: 0,
        bf16_permuted: 0,
        seq_len: 1,
        has_f32_perm_output: 0,
    };

    // Every slot needs a bound buffer; `input` stands in for absent optionals.
    let weight_buf = norm_weight.unwrap_or(input);
    let ff_buf = freq_factors.unwrap_or(input);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            (4, KernelArg::Buffer(positions_buf)),
            (5, KernelArg::Buffer(ff_buf)),
            (6, KernelArg::Buffer(input)),
            (7, KernelArg::Buffer(input)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(n_heads as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Validates the dimensions and encodes one dispatch of the
/// `fused_head_norm_rope_batch_bf16` pipeline.
///
/// Argument bindings: 0 = `input`, 1 = `output`, 2 = `norm_weight`, 3 = parameter
/// bytes, 4 = `positions_buf`, 5 = `freq_factors`. Absent `norm_weight` /
/// `freq_factors` are replaced by `input` as a placeholder, with the matching
/// `has_*` flag cleared in the parameter block.
///
/// The grid is `n_heads * seq_len` threadgroups of
/// `min(256, head_dim.next_power_of_two())` threads, with
/// `4 * max(threadgroup_width, head_dim)` bytes requested through the
/// shared-memory entry at index 0. Note that `seq_len` only affects the grid
/// size; it is not part of the parameter block.
///
/// NOTE(review): buffer element counts are not validated here — callers must
/// size them correctly.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when `n_heads`, `head_dim` or `seq_len`
/// is zero, or when `half_rope_dim > head_dim / 2`. Any error from
/// `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_batch_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    seq_len: u32,
    eps: f32,
    theta: f32,
) -> Result<()> {
    /// Parameter block for the batch bf16 pipeline: eight 4-byte fields
    /// (32 bytes), the last being explicit padding.
    #[repr(C)]
    #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    struct GpuBatchBf16Params {
        head_dim: u32,
        n_heads: u32,
        half_rope_dim: u32,
        eps: f32,
        has_weight: u32,
        theta: f32,
        has_freq_factors: u32,
        _pad: u32,
    }

    if n_heads == 0 || head_dim == 0 || seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope_batch_bf16: n_heads, head_dim, seq_len must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_batch_bf16: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }

    let pipeline = registry.get_pipeline("fused_head_norm_rope_batch_bf16", device)?;

    // Clamp to 256 *before* rounding up. This equals
    // `min(256, head_dim.next_power_of_two())` for every head_dim, but cannot
    // overflow: `next_power_of_two` panics (debug) or wraps to 0 (release) once
    // head_dim exceeds 2^31, which would yield a zero-width threadgroup.
    let tg_threads = head_dim.min(256).next_power_of_two();
    let tg_size = u64::from(tg_threads);
    // Four bytes per slot, with at least one slot per element of a head.
    let shared_slots = tg_threads.max(head_dim);
    let shared_mem_bytes = u64::from(shared_slots) * 4;

    let gpu_params = GpuBatchBf16Params {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(norm_weight.is_some()),
        theta,
        has_freq_factors: u32::from(freq_factors.is_some()),
        _pad: 0,
    };

    // Every slot needs a bound buffer; `input` stands in for absent optionals.
    let weight_buf = norm_weight.unwrap_or(input);
    let ff_buf = freq_factors.unwrap_or(input);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            (4, KernelArg::Buffer(positions_buf)),
            (5, KernelArg::Buffer(ff_buf)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new((n_heads as u64) * (seq_len as u64), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Validates the dimensions and encodes one batch dispatch of the
/// `fused_head_norm_rope_f32` pipeline with no optional outputs.
///
/// Argument bindings: 0 = `input`, 1 = `output`, 2 = `norm_weight`, 3 = parameter
/// bytes (`GpuFusedHeadNormRopeF32Params`, carrying `seq_len`), 4 =
/// `positions_buf`, 5 = `freq_factors`, 6 and 7 = `input` as placeholders (both
/// optional-output flags are 0). Absent `norm_weight` / `freq_factors` are
/// likewise replaced by `input`, with the matching `has_*` flag cleared.
///
/// The grid is `n_heads * seq_len` threadgroups of
/// `min(256, head_dim.next_power_of_two())` threads, with
/// `4 * max(threadgroup_width, head_dim)` bytes requested through the
/// shared-memory entry at index 0.
///
/// NOTE(review): buffer element counts are not validated here — callers must
/// size them correctly.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when `n_heads`, `head_dim` or `seq_len`
/// is zero, or when `half_rope_dim > head_dim / 2`. Any error from
/// `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_batch_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    seq_len: u32,
    eps: f32,
    theta: f32,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope_batch_f32: n_heads, head_dim, seq_len must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_batch_f32: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }

    let pipeline = registry.get_pipeline("fused_head_norm_rope_f32", device)?;

    // Clamp to 256 *before* rounding up. This equals
    // `min(256, head_dim.next_power_of_two())` for every head_dim, but cannot
    // overflow: `next_power_of_two` panics (debug) or wraps to 0 (release) once
    // head_dim exceeds 2^31, which would yield a zero-width threadgroup.
    let tg_threads = head_dim.min(256).next_power_of_two();
    let tg_size = u64::from(tg_threads);
    // Four bytes per slot, with at least one slot per element of a head.
    let shared_slots = tg_threads.max(head_dim);
    let shared_mem_bytes = u64::from(shared_slots) * 4;

    let gpu_params = GpuFusedHeadNormRopeF32Params {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(norm_weight.is_some()),
        theta,
        has_freq_factors: u32::from(freq_factors.is_some()),
        has_bf16_output: 0,
        bf16_permuted: 0,
        seq_len,
        has_f32_perm_output: 0,
    };

    // Every slot needs a bound buffer; `input` stands in for absent optionals.
    let weight_buf = norm_weight.unwrap_or(input);
    let ff_buf = freq_factors.unwrap_or(input);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            (4, KernelArg::Buffer(positions_buf)),
            (5, KernelArg::Buffer(ff_buf)),
            (6, KernelArg::Buffer(input)),
            (7, KernelArg::Buffer(input)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new((n_heads as u64) * (seq_len as u64), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Convenience wrapper around
/// `dispatch_fused_head_norm_rope_batch_f32_with_bf16_f32_perm` that passes `None`
/// for the `output_f32_perm` buffer and forwards every other argument unchanged.
///
/// # Errors
///
/// Returns whatever the delegated call returns.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_batch_f32_with_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    output_bf16: Option<&MlxBuffer>,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    seq_len: u32,
    eps: f32,
    theta: f32,
    bf16_permuted: bool,
) -> Result<()> {
    // This variant never produces the extra f32 output.
    let output_f32_perm: Option<&MlxBuffer> = None;

    dispatch_fused_head_norm_rope_batch_f32_with_bf16_f32_perm(
        encoder,
        registry,
        device,
        input,
        output,
        output_bf16,
        output_f32_perm,
        norm_weight,
        positions_buf,
        freq_factors,
        n_heads,
        head_dim,
        half_rope_dim,
        seq_len,
        eps,
        theta,
        bf16_permuted,
    )
}
/// Validates the arguments and encodes one batch dispatch of the
/// `fused_head_norm_rope_f32` pipeline, optionally with a bf16 output and/or an
/// additional f32 output.
///
/// Argument bindings: 0 = `input`, 1 = `output`, 2 = `norm_weight`, 3 = parameter
/// bytes (`GpuFusedHeadNormRopeF32Params`), 4 = `positions_buf`, 5 =
/// `freq_factors`, 6 = `output_bf16`, 7 = `output_f32_perm`. Every absent optional
/// buffer is replaced by `input` as a placeholder, with the matching `has_*` flag
/// cleared in the parameter block. `bf16_permuted` is only forwarded as 1 when at
/// least one of the two optional outputs is present.
///
/// The grid is `n_heads * seq_len` threadgroups of
/// `min(256, head_dim.next_power_of_two())` threads, with
/// `4 * max(threadgroup_width, head_dim)` bytes requested through the
/// shared-memory entry at index 0.
///
/// NOTE(review): `input`, `output` and `positions_buf` element counts are not
/// validated here — callers must size them correctly.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when:
/// - `n_heads`, `head_dim` or `seq_len` is zero;
/// - `half_rope_dim > head_dim / 2`;
/// - a supplied `output_bf16` or `output_f32_perm` holds fewer than
///   `seq_len * n_heads * head_dim` elements.
///
/// Any error from `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_fused_head_norm_rope_batch_f32_with_bf16_f32_perm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    output_bf16: Option<&MlxBuffer>,
    output_f32_perm: Option<&MlxBuffer>,
    norm_weight: Option<&MlxBuffer>,
    positions_buf: &MlxBuffer,
    freq_factors: Option<&MlxBuffer>,
    n_heads: u32,
    head_dim: u32,
    half_rope_dim: u32,
    seq_len: u32,
    eps: f32,
    theta: f32,
    bf16_permuted: bool,
) -> Result<()> {
    if n_heads == 0 || head_dim == 0 || seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "fused_head_norm_rope_batch_f32_with_bf16: n_heads, head_dim, seq_len must be > 0".into(),
        ));
    }
    if half_rope_dim > head_dim / 2 {
        return Err(MlxError::InvalidArgument(format!(
            "fused_head_norm_rope_batch_f32_with_bf16: half_rope_dim ({}) must be <= head_dim/2 ({})",
            half_rope_dim,
            head_dim / 2,
        )));
    }

    // The product of three u32 values can reach ~2^96, which would silently wrap
    // in `usize` on release builds and let an undersized buffer pass the check.
    // u128 holds the exact value, so oversized requests are rejected instead.
    let expected = u128::from(seq_len) * u128::from(n_heads) * u128::from(head_dim);
    if let Some(buf) = output_bf16 {
        let actual = buf.element_count();
        if (actual as u128) < expected {
            return Err(MlxError::InvalidArgument(format!(
                "fused_head_norm_rope_batch_f32_with_bf16: output_bf16 element count {} < expected {}",
                actual, expected
            )));
        }
    }
    if let Some(buf) = output_f32_perm {
        let actual = buf.element_count();
        if (actual as u128) < expected {
            return Err(MlxError::InvalidArgument(format!(
                "fused_head_norm_rope_batch_f32_with_bf16_f32_perm: output_f32_perm element count {} < expected {}",
                actual, expected
            )));
        }
    }

    let pipeline = registry.get_pipeline("fused_head_norm_rope_f32", device)?;

    // Clamp to 256 *before* rounding up. This equals
    // `min(256, head_dim.next_power_of_two())` for every head_dim, but cannot
    // overflow: `next_power_of_two` panics (debug) or wraps to 0 (release) once
    // head_dim exceeds 2^31, which would yield a zero-width threadgroup.
    let tg_threads = head_dim.min(256).next_power_of_two();
    let tg_size = u64::from(tg_threads);
    // Four bytes per slot, with at least one slot per element of a head.
    let shared_slots = tg_threads.max(head_dim);
    let shared_mem_bytes = u64::from(shared_slots) * 4;

    let has_bf16 = output_bf16.is_some();
    let has_f32_perm = output_f32_perm.is_some();
    let gpu_params = GpuFusedHeadNormRopeF32Params {
        head_dim,
        n_heads,
        half_rope_dim,
        eps,
        has_weight: u32::from(norm_weight.is_some()),
        theta,
        has_freq_factors: u32::from(freq_factors.is_some()),
        has_bf16_output: u32::from(has_bf16),
        // The permutation flag is meaningless without an optional output to apply it to.
        bf16_permuted: u32::from(bf16_permuted && (has_bf16 || has_f32_perm)),
        seq_len,
        has_f32_perm_output: u32::from(has_f32_perm),
    };

    // Every slot needs a bound buffer; `input` stands in for absent optionals.
    let weight_buf = norm_weight.unwrap_or(input);
    let ff_buf = freq_factors.unwrap_or(input);
    let bf16_buf = output_bf16.unwrap_or(input);
    let f32_perm_buf = output_f32_perm.unwrap_or(input);

    encode_threadgroups_with_args_and_shared(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Buffer(weight_buf)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
            (4, KernelArg::Buffer(positions_buf)),
            (5, KernelArg::Buffer(ff_buf)),
            (6, KernelArg::Buffer(bf16_buf)),
            (7, KernelArg::Buffer(f32_perm_buf)),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new((n_heads as u64) * (seq_len as u64), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}