1use std::cell::RefCell;
29use std::collections::HashMap;
30
31use metal::MTLSize;
32
33use crate::buffer::MlxBuffer;
34use crate::dtypes::DType;
35use crate::encoder::CommandEncoder;
36use crate::error::{MlxError, Result};
37use crate::kernel_registry::KernelRegistry;
38
/// Metal shader source for the multi-section RoPE kernels.
///
/// The text of `../shaders/rope_multi.metal` is embedded at compile time via
/// `include_str!`; `register` attaches this same source to both the
/// `rope_multi_f32` and `rope_multi_bf16` kernel names.
pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
40
41pub fn register(registry: &mut KernelRegistry) {
42 registry.register_source("rope_multi_f32", ROPE_MULTI_SHADER_SOURCE);
43 registry.register_source("rope_multi_bf16", ROPE_MULTI_SHADER_SOURCE);
44}
45
/// Selects which multi-section rotary-embedding variant the kernel applies.
///
/// The enum is `#[repr(u32)]`; its discriminant is cast to `u32` and written
/// into the integer parameter buffer by `build_rope_multi_buffers`, and is
/// also part of the buffer-cache key.
// NOTE(review): the values 8 and 40 look like ggml's GGML_ROPE_TYPE_MROPE /
// GGML_ROPE_TYPE_IMROPE constants — confirm against the shader before changing them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum RopeMultiMode {
    /// Discriminant `8`.
    Mrope = 8,
    /// Discriminant `40`.
    Imrope = 40,
}
56
/// Shape and frequency parameters for one multi-section RoPE dispatch.
///
/// The constraints listed on each field are the ones enforced by `validate`.
#[derive(Debug, Clone, Copy)]
pub struct RopeMultiParams {
    /// Elements per head. Must be non-zero and even.
    pub head_dim: u32,
    /// Must be non-zero, even, and `<= head_dim`.
    // NOTE(review): presumably the number of leading dimensions that are rotated — confirm in the shader.
    pub rope_dim: u32,
    /// Number of heads. Must be non-zero.
    pub n_heads: u32,
    /// Sequence length. Must be non-zero; the positions buffer must hold `4 * seq_len` elements.
    pub seq_len: u32,
    /// Frequency base. Must be finite and strictly positive.
    pub freq_base: f32,
    /// RoPE variant; forwarded to the kernel as its `u32` discriminant.
    pub mode: RopeMultiMode,
    /// Four `u32` values copied verbatim into the sections buffer.
    // NOTE(review): presumably per-axis section sizes; they are not validated here — confirm expected invariants.
    pub sections: [u32; 4],
}
70
71fn validate(
72 p: &RopeMultiParams,
73 input: &MlxBuffer,
74 output: &MlxBuffer,
75 positions: &MlxBuffer,
76) -> Result<()> {
77 if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
78 return Err(MlxError::InvalidArgument(
79 "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
80 ));
81 }
82 if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
83 return Err(MlxError::InvalidArgument(
84 "rope_multi: head_dim and rope_dim must be even".into(),
85 ));
86 }
87 if p.rope_dim > p.head_dim {
88 return Err(MlxError::InvalidArgument(
89 "rope_multi: rope_dim must be <= head_dim".into(),
90 ));
91 }
92 if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
93 return Err(MlxError::InvalidArgument(format!(
94 "rope_multi: freq_base must be finite and positive, got {}",
95 p.freq_base
96 )));
97 }
98
99 let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
100 let elements = n_rows * (p.head_dim as usize);
101 if input.element_count() != elements {
102 return Err(MlxError::InvalidArgument(format!(
103 "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
104 input.element_count(),
105 p.seq_len,
106 p.n_heads,
107 p.head_dim,
108 elements
109 )));
110 }
111 if output.element_count() != elements {
112 return Err(MlxError::InvalidArgument(format!(
113 "rope_multi: output element count {} != {}",
114 output.element_count(),
115 elements
116 )));
117 }
118 if input.dtype() != output.dtype() {
119 return Err(MlxError::InvalidArgument(format!(
120 "rope_multi: input/output dtype mismatch {} vs {}",
121 input.dtype(),
122 output.dtype()
123 )));
124 }
125
126 let expected_positions = 4 * (p.seq_len as usize);
127 if positions.element_count() != expected_positions {
128 return Err(MlxError::InvalidArgument(format!(
129 "rope_multi: positions length {} != 4 * seq_len({}) = {}",
130 positions.element_count(),
131 p.seq_len,
132 expected_positions
133 )));
134 }
135 match positions.dtype() {
136 DType::I32 | DType::U32 => {}
137 other => {
138 return Err(MlxError::InvalidArgument(format!(
139 "rope_multi: positions must be i32 or u32 (got {})",
140 other
141 )));
142 }
143 }
144
145 Ok(())
146}
147
148#[allow(clippy::too_many_arguments)]
160pub fn dispatch_rope_multi(
161 encoder: &mut CommandEncoder,
162 registry: &mut KernelRegistry,
163 device: &metal::DeviceRef,
164 input: &MlxBuffer,
165 output: &MlxBuffer,
166 positions: &MlxBuffer,
167 params_buf: &MlxBuffer,
168 rope_params_buf: &MlxBuffer,
169 sections_buf: &MlxBuffer,
170 p: RopeMultiParams,
171) -> Result<()> {
172 validate(&p, input, output, positions)?;
173
174 let kernel_name = match input.dtype() {
175 DType::F32 => "rope_multi_f32",
176 DType::BF16 => "rope_multi_bf16",
177 other => {
178 return Err(MlxError::InvalidArgument(format!(
179 "rope_multi: unsupported dtype {}",
180 other
181 )));
182 }
183 };
184
185 let pipeline = registry.get_pipeline(kernel_name, device)?;
186
187 let half_dim = p.head_dim / 2;
188 let n_rows = p.seq_len * p.n_heads;
189
190 let grid = MTLSize::new(half_dim as u64, n_rows as u64, 1);
192
193 let tg_x = std::cmp::min(half_dim, 256).max(1);
194 let remain = (256u32 / tg_x).max(1);
195 let tg_y = std::cmp::min(n_rows, remain).max(1);
196 let tg = MTLSize::new(tg_x as u64, tg_y as u64, 1);
197
198 encoder.encode(
199 pipeline,
200 &[
201 (0, input),
202 (1, output),
203 (2, params_buf),
204 (3, positions),
205 (4, rope_params_buf),
206 (5, sections_buf),
207 ],
208 grid,
209 tg,
210 );
211
212 Ok(())
213}
214
/// The three small parameter buffers consumed by a `rope_multi` dispatch.
///
/// Field names match the corresponding `dispatch_rope_multi` arguments.
pub struct RopeMultiBufferPack {
    /// Bound at buffer index 2 by `dispatch_rope_multi`.
    pub params_buf: MlxBuffer,
    /// Bound at buffer index 4 by `dispatch_rope_multi`.
    pub rope_params_buf: MlxBuffer,
    /// Bound at buffer index 5 by `dispatch_rope_multi`.
    pub sections_buf: MlxBuffer,
}
230
/// Hashable identity of a cached [`RopeMultiBufferPack`].
///
/// Holds a device identifier plus every field of `RopeMultiParams`, each in a
/// form that implements `Eq` and `Hash`. Two dispatches share cached buffers
/// only when all of these values are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct RopeMultiCacheKey {
    /// Pointer-derived integer identifying the device (see `from_params`).
    device_ptr: usize,
    head_dim: u32,
    rope_dim: u32,
    n_heads: u32,
    seq_len: u32,
    /// Raw bit pattern of `freq_base`; `f32` itself is neither `Eq` nor `Hash`.
    freq_base_bits: u32,
    /// `RopeMultiMode` discriminant.
    mode: u32,
    sections: [u32; 4],
}
245
246impl RopeMultiCacheKey {
247 fn from_params(device: &crate::MlxDevice, p: &RopeMultiParams) -> Self {
248 Self {
249 device_ptr: device as *const _ as usize,
250 head_dim: p.head_dim,
251 rope_dim: p.rope_dim,
252 n_heads: p.n_heads,
253 seq_len: p.seq_len,
254 freq_base_bits: p.freq_base.to_bits(),
255 mode: p.mode as u32,
256 sections: p.sections,
257 }
258 }
259}
260
thread_local! {
    /// Per-thread cache of parameter buffer packs, keyed by `RopeMultiCacheKey`.
    ///
    /// `dispatch_rope_multi_cached` inserts an entry the first time it sees a
    /// key. In the code visible here nothing evicts entries; they are removed
    /// only by `clear_rope_pack_cache`, so the map grows with the number of
    /// distinct parameter combinations used on the thread.
    static ROPE_PACK_CACHE: RefCell<HashMap<RopeMultiCacheKey, RopeMultiBufferPack>> =
        RefCell::new(HashMap::new());
}
270
271pub fn clear_rope_pack_cache() {
277 ROPE_PACK_CACHE.with(|cell| cell.borrow_mut().clear());
278}
279
280pub fn rope_pack_cache_len() -> usize {
282 ROPE_PACK_CACHE.with(|cell| cell.borrow().len())
283}
284
285#[allow(clippy::too_many_arguments)]
303pub fn dispatch_rope_multi_cached(
304 encoder: &mut CommandEncoder,
305 registry: &mut KernelRegistry,
306 device: &crate::MlxDevice,
307 input: &MlxBuffer,
308 output: &MlxBuffer,
309 positions: &MlxBuffer,
310 p: RopeMultiParams,
311) -> Result<()> {
312 let key = RopeMultiCacheKey::from_params(device, &p);
313 ROPE_PACK_CACHE.with(|cell| {
314 let mut map = cell.borrow_mut();
315 if !map.contains_key(&key) {
316 let (params_buf, rope_params_buf, sections_buf) =
317 build_rope_multi_buffers(device, p)?;
318 map.insert(
319 key,
320 RopeMultiBufferPack {
321 params_buf,
322 rope_params_buf,
323 sections_buf,
324 },
325 );
326 }
327 let pack = map
328 .get(&key)
329 .expect("inserted above if missing; cache is single-threaded");
330 dispatch_rope_multi(
331 encoder,
332 registry,
333 device.metal_device(),
334 input,
335 output,
336 positions,
337 &pack.params_buf,
338 &pack.rope_params_buf,
339 &pack.sections_buf,
340 p,
341 )
342 })
343}
344
345pub fn build_rope_multi_buffers(
349 device: &crate::MlxDevice,
350 p: RopeMultiParams,
351) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
352 let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
353 {
354 let s = params.as_mut_slice::<f32>()?;
355 s[0] = p.freq_base;
356 s[1] = p.head_dim as f32;
357 s[2] = p.rope_dim as f32;
358 s[3] = 0.0;
359 }
360 let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
361 {
362 let s = rope_params.as_mut_slice::<u32>()?;
363 s[0] = p.n_heads;
364 s[1] = p.mode as u32;
365 s[2] = p.seq_len;
366 s[3] = 0;
367 }
368 let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
369 {
370 let s = sections.as_mut_slice::<u32>()?;
371 s[0] = p.sections[0];
372 s[1] = p.sections[1];
373 s[2] = p.sections[2];
374 s[3] = p.sections[3];
375 }
376 Ok((params, rope_params, sections))
377}