use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source text of the RoPE kernels, embedded at compile time from
/// `../shaders/rope.metal`.
///
/// [`register`] associates this same source with every RoPE kernel name.
pub static ROPE_SHADER_SOURCE: &str = include_str!("../shaders/rope.metal");
/// Registers [`ROPE_SHADER_SOURCE`] with `registry` under each RoPE kernel name
/// used by the dispatch functions in this module.
///
/// Every name maps to the same shader source; registration order is
/// `rope_f32`, `rope_f16`, `rope_bf16`, `rope_neox_bf16`, `rope_neox_f32`.
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in [
        "rope_f32",
        "rope_f16",
        "rope_bf16",
        "rope_neox_bf16",
        "rope_neox_f32",
    ] {
        registry.register_source(kernel_name, ROPE_SHADER_SOURCE);
    }
}
/// Encodes a RoPE (rotary position embedding) kernel dispatch.
///
/// The kernel is picked from `input.dtype()` (`rope_f32`, `rope_f16` or
/// `rope_bf16`) and encoded with a grid size of `head_dim / 2` x `seq_len`
/// and a threadgroup size of at most 64 x 4.
///
/// Buffer bindings: 0 = `input`, 1 = `output`, 2 = `params_buf`,
/// 3 = `positions_buf`. The contents of `params_buf` and `positions_buf` are
/// interpreted by the shader and are not validated here.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when
/// - `head_dim` is odd,
/// - `head_dim` or `seq_len` is zero,
/// - `input` or `output` does not hold exactly `seq_len * head_dim` elements,
/// - the input dtype is not F32, F16 or BF16, or
/// - the output dtype differs from the input dtype.
///
/// Any error returned by `KernelRegistry::get_pipeline` is propagated.
// Same arity as the other RoPE dispatchers in this module; grouping the
// arguments would change the public signature.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    positions_buf: &MlxBuffer,
    seq_len: u32,
    head_dim: u32,
) -> Result<()> {
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE head_dim must be even, got {}",
            head_dim
        )));
    }
    if head_dim == 0 || seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "RoPE head_dim and seq_len must be > 0".into(),
        ));
    }

    let expected_elements = (seq_len as usize) * (head_dim as usize);
    if input.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE input element count {} != seq_len({}) * head_dim({})",
            input.element_count(),
            seq_len,
            head_dim
        )));
    }
    if output.element_count() != expected_elements {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE output element count {} != seq_len({}) * head_dim({})",
            output.element_count(),
            seq_len,
            head_dim
        )));
    }

    // The kernel is typed by the input dtype, so the output must use the same
    // dtype: equal element counts alone do not guarantee equal byte sizes.
    let (kernel_name, output_dtype_matches) = match input.dtype() {
        DType::F32 => ("rope_f32", matches!(output.dtype(), DType::F32)),
        DType::F16 => ("rope_f16", matches!(output.dtype(), DType::F16)),
        DType::BF16 => ("rope_bf16", matches!(output.dtype(), DType::BF16)),
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "RoPE unsupported dtype: {}",
                input.dtype()
            )));
        }
    };
    if !output_dtype_matches {
        return Err(MlxError::InvalidArgument(format!(
            "RoPE output dtype {} does not match input dtype {}",
            output.dtype(),
            input.dtype()
        )));
    }

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // The grid is head_dim / 2 wide; head_dim is even and non-zero here, so
    // half_dim >= 1 and the threadgroup is never zero-sized.
    let half_dim = head_dim / 2;
    let tg_x = std::cmp::min(64, half_dim as u64);
    let tg_y = std::cmp::min(4, seq_len as u64);

    encoder.encode(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
            (3, positions_buf),
        ],
        MTLSize::new(half_dim as u64, seq_len as u64, 1),
        MTLSize::new(tg_x, tg_y, 1),
    );
    Ok(())
}
/// Inline argument block for the `rope_neox_bf16` kernel;
/// `dispatch_rope_neox_bf16` passes it as raw bytes at argument index 4.
///
/// `#[repr(C)]` fixes the field order and layout. The shader-side struct is
/// not visible from this file and is assumed to match — TODO confirm against
/// `rope.metal`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRopeNeoxParams {
// Number of heads, forwarded unchanged from the dispatcher's `n_heads` argument.
n_heads: u32,
// Always written as 0 by the dispatcher; presumably padding that keeps the
// struct at 8 bytes for the shader — verify against the shader definition.
_pad: u32,
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_neox_bf16(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
positions_buf: &MlxBuffer,
seq_len: u32,
n_heads: u32,
head_dim: u32,
rope_dim: u32,
) -> Result<()> {
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
if rope_dim % 2 != 0 {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox rope_dim must be even, got {}",
rope_dim
)));
}
if rope_dim > head_dim {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox rope_dim ({}) must be <= head_dim ({})",
rope_dim, head_dim
)));
}
if head_dim == 0 || seq_len == 0 || n_heads == 0 {
return Err(MlxError::InvalidArgument(
"RoPE neox head_dim, seq_len, and n_heads must be > 0".into(),
));
}
let n_rows = (seq_len as usize) * (n_heads as usize);
let expected_elements = n_rows * (head_dim as usize);
if input.element_count() != expected_elements {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
input.element_count(),
seq_len,
n_heads,
head_dim
)));
}
if output.element_count() != expected_elements {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
output.element_count(),
seq_len,
n_heads,
head_dim
)));
}
let pipeline = registry.get_pipeline("rope_neox_bf16", device)?;
let half_rope = rope_dim / 2;
let gpu_rope_params = GpuRopeNeoxParams {
n_heads,
_pad: 0,
};
let tg_x = std::cmp::min(64, half_rope as u64);
let tg_y = std::cmp::min(4, n_rows as u64);
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(input)),
(1, KernelArg::Buffer(output)),
(2, KernelArg::Buffer(params_buf)),
(3, KernelArg::Buffer(positions_buf)),
(4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
],
MTLSize::new(half_rope as u64, n_rows as u64, 1),
MTLSize::new(tg_x, tg_y, 1),
);
Ok(())
}
/// Inline argument block for the `rope_neox_f32` kernel;
/// `dispatch_rope_neox_f32` passes it as raw bytes at argument index 4.
///
/// `#[repr(C)]` fixes the field order and layout. The shader-side struct is
/// not visible from this file and is assumed to match — TODO confirm against
/// `rope.metal`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuRopeNeoxF32Params {
// Number of heads, forwarded unchanged from the dispatcher's `n_heads` argument.
n_heads: u32,
// 1 when the caller supplied a `freq_factors` buffer, 0 otherwise
// (the dispatcher sets it via `u32::from(bool)`).
has_freq_factors: u32,
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_neox_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
input: &MlxBuffer,
output: &MlxBuffer,
params_buf: &MlxBuffer,
positions_buf: &MlxBuffer,
freq_factors: Option<&MlxBuffer>,
seq_len: u32,
n_heads: u32,
head_dim: u32,
rope_dim: u32,
) -> Result<()> {
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
if rope_dim % 2 != 0 {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox f32 rope_dim must be even, got {}",
rope_dim
)));
}
if rope_dim > head_dim {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox f32 rope_dim ({}) must be <= head_dim ({})",
rope_dim, head_dim
)));
}
if head_dim == 0 || seq_len == 0 || n_heads == 0 {
return Err(MlxError::InvalidArgument(
"RoPE neox f32 head_dim, seq_len, and n_heads must be > 0".into(),
));
}
let n_rows = (seq_len as usize) * (n_heads as usize);
let expected_elements = n_rows * (head_dim as usize);
if input.element_count() != expected_elements {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox f32 input element count {} != seq_len({}) * n_heads({}) * head_dim({})",
input.element_count(),
seq_len,
n_heads,
head_dim
)));
}
if output.element_count() != expected_elements {
return Err(MlxError::InvalidArgument(format!(
"RoPE neox f32 output element count {} != seq_len({}) * n_heads({}) * head_dim({})",
output.element_count(),
seq_len,
n_heads,
head_dim
)));
}
let pipeline = registry.get_pipeline("rope_neox_f32", device)?;
let half_rope = rope_dim / 2;
let has_ff = freq_factors.is_some();
let gpu_rope_params = GpuRopeNeoxF32Params {
n_heads,
has_freq_factors: u32::from(has_ff),
};
let ff_buf = freq_factors.unwrap_or(input);
let tg_x = std::cmp::min(64, half_rope as u64);
let tg_y = std::cmp::min(4, n_rows as u64);
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(input)),
(1, KernelArg::Buffer(output)),
(2, KernelArg::Buffer(params_buf)),
(3, KernelArg::Buffer(positions_buf)),
(4, KernelArg::Bytes(as_bytes(&gpu_rope_params))),
(5, KernelArg::Buffer(ff_buf)),
],
MTLSize::new(half_rope as u64, n_rows as u64, 1),
MTLSize::new(tg_x, tg_y, 1),
);
Ok(())
}