use std::cell::RefCell;
use std::collections::HashMap;
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the `rope_multi` kernels, embedded into the binary
/// at compile time from `../shaders/rope_multi.metal` (path relative to this
/// source file).
///
/// `register` registers this same source text under both the
/// `rope_multi_f32` and `rope_multi_bf16` kernel names.
pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
/// Registers the `rope_multi` shader source with `registry` once per kernel
/// entry point.
///
/// Both the `f32` and `bf16` kernel names map to the same source text,
/// [`ROPE_MULTI_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    // Order matches the original registration order: f32 first, then bf16.
    for kernel_name in ["rope_multi_f32", "rope_multi_bf16"] {
        registry.register_source(kernel_name, ROPE_MULTI_SHADER_SOURCE);
    }
}
/// Rotary-embedding variant requested from the `rope_multi` kernels.
///
/// The discriminant is cast to `u32` and written into the `rope_params`
/// buffer by `build_rope_multi_buffers`, so these numeric values are exactly
/// what the shader receives.
// NOTE(review): 8 and 40 look like ggml's MROPE / IMROPE rope-type constants —
// confirm against the shader before changing either value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum RopeMultiMode {
/// Sent to the kernel as mode value `8`.
Mrope = 8,
/// Sent to the kernel as mode value `40`.
Imrope = 40,
}
/// Shape and configuration of one `rope_multi` dispatch.
///
/// Constraints enforced by `validate` before a dispatch is encoded:
/// - `head_dim`, `rope_dim`, `n_heads`, `seq_len` must all be non-zero;
/// - `head_dim` and `rope_dim` must be even, with `rope_dim <= head_dim`;
/// - `freq_base` must be finite and strictly positive;
/// - the input and output buffers must hold `seq_len * n_heads * head_dim`
///   elements, and the positions buffer `4 * seq_len` elements.
///
/// `mode` is uploaded to the kernel as its `u32` discriminant and `sections`
/// is uploaded verbatim as four `u32` values; neither is checked by
/// `validate`.
#[derive(Debug, Clone, Copy)]
pub struct RopeMultiParams {
pub head_dim: u32,
pub rope_dim: u32, pub n_heads: u32,
pub seq_len: u32,
pub freq_base: f32,
pub mode: RopeMultiMode,
pub sections: [u32; 4],
}
/// Checks that `p` and the tensor buffers describe a consistent `rope_multi`
/// call.
///
/// Verified, in order:
/// - `head_dim`, `rope_dim`, `n_heads`, `seq_len` are all non-zero;
/// - `head_dim` and `rope_dim` are even and `rope_dim <= head_dim`;
/// - `freq_base` is finite and strictly positive;
/// - `input` and `output` each hold exactly `seq_len * n_heads * head_dim`
///   elements and share a dtype;
/// - `positions` holds exactly `4 * seq_len` elements of dtype `I32` or `U32`.
///
/// `p.sections`, `p.mode` and the concrete input dtype are not inspected here.
///
/// # Errors
/// Returns `MlxError::InvalidArgument` describing the first failed check,
/// including the case where the expected element counts do not fit in `usize`.
fn validate(
    p: &RopeMultiParams,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
) -> Result<()> {
    if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
        return Err(MlxError::InvalidArgument(
            "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
        ));
    }
    if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(
            "rope_multi: head_dim and rope_dim must be even".into(),
        ));
    }
    if p.rope_dim > p.head_dim {
        return Err(MlxError::InvalidArgument(
            "rope_multi: rope_dim must be <= head_dim".into(),
        ));
    }
    if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: freq_base must be finite and positive, got {}",
            p.freq_base
        )));
    }

    // Three u32 factors can exceed usize::MAX. An unchecked product would wrap
    // in release builds and could make a too-small buffer look correctly
    // sized, so reject the shape explicitly instead.
    let elements = (p.seq_len as usize)
        .checked_mul(p.n_heads as usize)
        .and_then(|rows| rows.checked_mul(p.head_dim as usize))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "rope_multi: seq_len({}) * n_heads({}) * head_dim({}) overflows usize",
                p.seq_len, p.n_heads, p.head_dim
            ))
        })?;
    if input.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
            input.element_count(),
            p.seq_len,
            p.n_heads,
            p.head_dim,
            elements
        )));
    }
    if output.element_count() != elements {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: output element count {} != {}",
            output.element_count(),
            elements
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: input/output dtype mismatch {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }

    // Four position values are expected per sequence position.
    let expected_positions = (p.seq_len as usize).checked_mul(4).ok_or_else(|| {
        MlxError::InvalidArgument(format!(
            "rope_multi: 4 * seq_len({}) overflows usize",
            p.seq_len
        ))
    })?;
    if positions.element_count() != expected_positions {
        return Err(MlxError::InvalidArgument(format!(
            "rope_multi: positions length {} != 4 * seq_len({}) = {}",
            positions.element_count(),
            p.seq_len,
            expected_positions
        )));
    }
    match positions.dtype() {
        DType::I32 | DType::U32 => {}
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_multi: positions must be i32 or u32 (got {})",
                other
            )));
        }
    }
    Ok(())
}
/// Validates the arguments and encodes one `rope_multi` kernel dispatch.
///
/// The kernel is chosen from the input dtype: `F32` selects `rope_multi_f32`
/// and `BF16` selects `rope_multi_bf16`.
///
/// Buffers are bound at these indices:
/// `0` input, `1` output, `2` `params_buf`, `3` positions,
/// `4` `rope_params_buf`, `5` `sections_buf`.
/// The three parameter buffers are expected to have the layout produced by
/// `build_rope_multi_buffers`; their contents are not checked here.
///
/// The grid is `(head_dim / 2, seq_len * n_heads, 1)`. The threadgroup is at
/// most 256 threads: up to 256 along x, with the remainder of that budget
/// along y.
///
/// # Errors
/// - any `MlxError::InvalidArgument` produced by `validate`;
/// - `MlxError::InvalidArgument` if the input dtype is neither `F32` nor `BF16`;
/// - whatever `registry.get_pipeline` returns on failure.
// The argument count mirrors the kernel's buffer bindings one-to-one.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_multi(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    positions: &MlxBuffer,
    params_buf: &MlxBuffer,
    rope_params_buf: &MlxBuffer,
    sections_buf: &MlxBuffer,
    p: RopeMultiParams,
) -> Result<()> {
    validate(&p, input, output, positions)?;

    let kernel_name = match input.dtype() {
        DType::F32 => "rope_multi_f32",
        DType::BF16 => "rope_multi_bf16",
        other => {
            return Err(MlxError::InvalidArgument(format!(
                "rope_multi: unsupported dtype {}",
                other
            )));
        }
    };
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let half_dim = p.head_dim / 2;
    // Widen before multiplying: the product of two u32 values always fits in
    // u64, whereas a u32 product could overflow (panic in debug builds, wrap
    // to a too-small grid in release builds).
    let n_rows = u64::from(p.seq_len) * u64::from(p.n_heads);
    let grid = MTLSize::new(u64::from(half_dim), n_rows, 1);

    // Threadgroup budget of 256 threads, filled along x first, then y.
    let tg_x = std::cmp::min(half_dim, 256).max(1);
    let remain = (256u32 / tg_x).max(1);
    let tg_y = std::cmp::min(n_rows, u64::from(remain)).max(1);
    let tg = MTLSize::new(u64::from(tg_x), tg_y, 1);

    encoder.encode(
        pipeline,
        &[
            (0, input),
            (1, output),
            (2, params_buf),
            (3, positions),
            (4, rope_params_buf),
            (5, sections_buf),
        ],
        grid,
        tg,
    );
    Ok(())
}
/// The three small parameter buffers consumed by a `rope_multi` dispatch,
/// grouped so they can be stored as a single cache value.
///
/// The fields correspond, in order, to the tuple returned by
/// `build_rope_multi_buffers`.
pub struct RopeMultiBufferPack {
/// Float parameters; bound at buffer index 2 by `dispatch_rope_multi`.
pub params_buf: MlxBuffer,
/// Integer parameters; bound at buffer index 4 by `dispatch_rope_multi`.
pub rope_params_buf: MlxBuffer,
/// The four section values; bound at buffer index 5 by `dispatch_rope_multi`.
pub sections_buf: MlxBuffer,
}
/// Hashable identity of a `RopeMultiBufferPack` in the per-thread cache.
///
/// It covers the device plus every field of `RopeMultiParams`, so two calls
/// share cached buffers only when all of these match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct RopeMultiCacheKey {
// Address of the `MlxDevice` value the caller passed by reference.
// NOTE(review): this is an address, not a stable device id — a device that is
// moved, or dropped and replaced by another at the same address, would
// change or alias the key. Confirm callers keep the device pinned and alive.
device_ptr: usize,
head_dim: u32,
rope_dim: u32,
n_heads: u32,
seq_len: u32,
// Raw bit pattern of `freq_base`; `f32` itself implements neither `Eq` nor
// `Hash`, so the bits stand in for the value.
freq_base_bits: u32,
// `RopeMultiMode` discriminant as `u32`.
mode: u32,
sections: [u32; 4],
}
impl RopeMultiCacheKey {
fn from_params(device: &crate::MlxDevice, p: &RopeMultiParams) -> Self {
Self {
device_ptr: device as *const _ as usize,
head_dim: p.head_dim,
rope_dim: p.rope_dim,
n_heads: p.n_heads,
seq_len: p.seq_len,
freq_base_bits: p.freq_base.to_bits(),
mode: p.mode as u32,
sections: p.sections,
}
}
}
// Per-thread cache of parameter buffer packs, keyed by device address and the
// full `RopeMultiParams` contents. Entries are added by
// `dispatch_rope_multi_cached` and removed only by `clear_rope_pack_cache`.
// NOTE(review): `seq_len` is part of the key, so each distinct sequence length
// adds an entry and nothing evicts — confirm callers clear the cache if they
// dispatch with many different shapes.
thread_local! {
static ROPE_PACK_CACHE: RefCell<HashMap<RopeMultiCacheKey, RopeMultiBufferPack>> =
RefCell::new(HashMap::new());
}
/// Removes every cached buffer pack from the calling thread's cache.
///
/// Caches belonging to other threads are not touched.
pub fn clear_rope_pack_cache() {
    ROPE_PACK_CACHE.with(|cache| {
        let mut packs = cache.borrow_mut();
        packs.clear();
    });
}
/// Returns how many buffer packs are cached on the calling thread.
pub fn rope_pack_cache_len() -> usize {
    ROPE_PACK_CACHE.with(|cache| {
        let packs = cache.borrow();
        packs.len()
    })
}
#[allow(clippy::too_many_arguments)]
pub fn dispatch_rope_multi_cached(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &crate::MlxDevice,
input: &MlxBuffer,
output: &MlxBuffer,
positions: &MlxBuffer,
p: RopeMultiParams,
) -> Result<()> {
let key = RopeMultiCacheKey::from_params(device, &p);
ROPE_PACK_CACHE.with(|cell| {
let mut map = cell.borrow_mut();
if !map.contains_key(&key) {
let (params_buf, rope_params_buf, sections_buf) =
build_rope_multi_buffers(device, p)?;
map.insert(
key,
RopeMultiBufferPack {
params_buf,
rope_params_buf,
sections_buf,
},
);
}
let pack = map
.get(&key)
.expect("inserted above if missing; cache is single-threaded");
dispatch_rope_multi(
encoder,
registry,
device.metal_device(),
input,
output,
positions,
&pack.params_buf,
&pack.rope_params_buf,
&pack.sections_buf,
p,
)
})
}
/// Allocates and fills the three parameter buffers for a `rope_multi`
/// dispatch.
///
/// Each buffer is 16 bytes (four 4-byte scalars) with shape `[4]`. The
/// returned tuple is `(params, rope_params, sections)`:
///
/// | buffer        | dtype | `[0]`       | `[1]`          | `[2]`          | `[3]` |
/// |---------------|-------|-------------|----------------|----------------|-------|
/// | `params`      | `F32` | `freq_base` | `head_dim`     | `rope_dim`     | `0.0` |
/// | `rope_params` | `U32` | `n_heads`   | `mode` as u32  | `seq_len`      | `0`   |
/// | `sections`    | `U32` | `sections[0]` | `sections[1]` | `sections[2]` | `sections[3]` |
///
/// No validation of `p` is performed here.
///
/// # Errors
/// Propagates any error from `alloc_buffer` or `as_mut_slice`.
pub fn build_rope_multi_buffers(
    device: &crate::MlxDevice,
    p: RopeMultiParams,
) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
    // Float block: dimensions are converted to f32; the last slot is padding.
    let float_block = [p.freq_base, p.head_dim as f32, p.rope_dim as f32, 0.0];
    let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
    params.as_mut_slice::<f32>()?[..4].copy_from_slice(&float_block);

    // Integer block: the mode goes in as its u32 discriminant; the last slot
    // is padding.
    let uint_block = [p.n_heads, p.mode as u32, p.seq_len, 0];
    let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
    rope_params.as_mut_slice::<u32>()?[..4].copy_from_slice(&uint_block);

    // Section block: the four section values, copied verbatim.
    let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
    sections.as_mut_slice::<u32>()?[..4].copy_from_slice(&p.sections);

    Ok((params, rope_params, sections))
}