use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source text for the `vision_2d_rope` kernels, embedded at
/// compile time from `../shaders/vision_2d_rope.metal` (path relative to this
/// source file). `register` associates this same text with both the
/// `vision_2d_rope_f32` and `vision_2d_rope_bf16` kernel names.
pub static VISION_2D_ROPE_SHADER_SOURCE: &str =
include_str!("../shaders/vision_2d_rope.metal");
/// Registers the shared shader source with `registry` under every kernel name
/// this module dispatches.
///
/// Both names map to the same source text; the dispatch function picks the
/// name matching the input dtype.
pub fn register(registry: &mut KernelRegistry) {
    // Registration order is kept: f32 first, then bf16.
    const KERNEL_NAMES: [&str; 2] = ["vision_2d_rope_f32", "vision_2d_rope_bf16"];

    for &kernel_name in KERNEL_NAMES.iter() {
        registry.register_source(kernel_name, VISION_2D_ROPE_SHADER_SOURCE);
    }
}
/// Validates the arguments and encodes one `vision_2d_rope` kernel dispatch.
///
/// The kernel name is chosen from `input`'s dtype: `vision_2d_rope_f32` for
/// `F32`, `vision_2d_rope_bf16` for `BF16`. Buffers are bound at indices
/// 0 = `input`, 1 = `output`, 2 = `params_buf`, 3 = `pos_x`, 4 = `pos_y`.
/// The two sizes handed to `encoder.encode` are
/// `(head_dim / 4, seq_len * n_heads, 1)` and
/// `(min(64, head_dim / 4), min(4, seq_len * n_heads), 1)`.
///
/// # Arguments
///
/// * `input` / `output` - must share a dtype and each hold exactly
///   `seq_len * n_heads * head_dim` elements.
/// * `params_buf` - bound as-is; this function does not validate it.
///   NOTE(review): presumably the buffer produced by
///   `build_vision_2d_rope_params` — confirm against the shader.
/// * `pos_x` / `pos_y` - `u32` or `i32` buffers with exactly `seq_len` elements.
/// * `seq_len`, `n_heads`, `head_dim` - all non-zero; `head_dim` must be a
///   multiple of 4.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when any of the checks above fails, or
/// when `seq_len * n_heads` (or the total element count) cannot be represented
/// without overflow. Errors from `registry.get_pipeline` are propagated.
// The parameters mirror the kernel's buffer bindings plus its shape scalars,
// so the long flat argument list is intentional.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_vision_2d_rope(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    pos_x: &MlxBuffer,
    pos_y: &MlxBuffer,
    seq_len: u32,
    n_heads: u32,
    head_dim: u32,
) -> Result<()> {
    if head_dim == 0 || seq_len == 0 || n_heads == 0 {
        return Err(MlxError::InvalidArgument(
            "vision_2d_rope: head_dim, seq_len, n_heads must all be > 0".into(),
        ));
    }
    if head_dim % 4 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "vision_2d_rope: head_dim ({}) must be divisible by 4 (need clean d_half/d_quarter split)",
            head_dim
        )));
    }

    // Keep the row count within u32 range instead of letting a narrowing cast
    // wrap: a wrapped value would silently shrink the dispatch so that only
    // part of the output is written.
    // NOTE(review): assumes the kernel indexes rows with 32-bit ids — confirm
    // against the shader.
    let n_rows = (seq_len as usize)
        .checked_mul(n_heads as usize)
        .filter(|&rows| rows <= u32::MAX as usize)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "vision_2d_rope: seq_len({}) * n_heads({}) exceeds the u32 row limit",
                seq_len, n_heads
            ))
        })?;
    // Checked so a wrapped product can never accidentally match a buffer size.
    let elements = n_rows.checked_mul(head_dim as usize).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "vision_2d_rope: seq_len({}) * n_heads({}) * head_dim({}) overflows usize",
            seq_len, n_heads, head_dim
        ))
    })?;

    if input.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "vision_2d_rope: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
            input.element_count(),
            seq_len,
            n_heads,
            head_dim,
            elements
        )));
    }
    if output.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "vision_2d_rope: output element count {} != {}",
            output.element_count(),
            elements
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "vision_2d_rope: input/output dtype mismatch {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }

    // Both position buffers obey the same rules. Lengths are checked for both
    // before dtypes so the first reported error is stable: x length, y length,
    // x dtype, y dtype.
    let position_buffers = [("pos_x", pos_x), ("pos_y", pos_y)];
    let expected_pos = seq_len as usize;
    for &(label, buf) in &position_buffers {
        if buf.element_count() != expected_pos {
            return Err(MlxError::InvalidArgument(format!(
                "vision_2d_rope: {} length {} != seq_len {}",
                label,
                buf.element_count(),
                seq_len
            )));
        }
    }
    for &(label, buf) in &position_buffers {
        let dtype = buf.dtype();
        if !matches!(dtype, DType::U32 | DType::I32) {
            return Err(MlxError::InvalidArgument(format!(
                "vision_2d_rope: {} must be u32 or i32 (got {})",
                label, dtype
            )));
        }
    }

    let kernel_name = match input.dtype() {
        DType::F32 => "vision_2d_rope_f32",
        DType::BF16 => "vision_2d_rope_bf16",
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "vision_2d_rope: unsupported dtype {}",
                other
            )));
        }
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Widening conversions only: head_dim / 4 is a u32 and n_rows was bounded
    // by u32::MAX above, so neither value can be truncated here.
    let grid_x = u64::from(head_dim / 4);
    let grid_y = n_rows as u64;
    let tg_x = grid_x.min(64).max(1);
    let tg_y = grid_y.min(4).max(1);

    encoder.encode(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
            (3, pos_x),
            (4, pos_y),
        ],
        MTLSize::new(grid_x, grid_y, 1),
        MTLSize::new(tg_x, tg_y, 1),
    );
    Ok(())
}
/// Allocates a four-element `F32` buffer and fills it with the kernel
/// parameters.
///
/// Slot layout: `[theta, head_dim as f32, n_heads as f32, 0.0]`.
/// NOTE(review): the trailing `0.0` looks like padding or a reserved slot —
/// confirm against the shader's parameter layout.
///
/// # Errors
///
/// Propagates any error from `device.alloc_buffer` or from viewing the buffer
/// as a mutable `f32` slice.
pub fn build_vision_2d_rope_params(
    device: &crate::MlxDevice,
    theta: f32,
    head_dim: u32,
    n_heads: u32,
) -> Result<MlxBuffer> {
    let values = [theta, head_dim as f32, n_heads as f32, 0.0];

    // 16 bytes = four f32 values of four bytes each; shape is a single axis of 4.
    let mut buffer = device.alloc_buffer(16, DType::F32, vec![4])?;
    buffer.as_mut_slice::<f32>()?[..4].copy_from_slice(&values);

    Ok(buffer)
}