1use metal::MTLSize;
29
30use crate::buffer::MlxBuffer;
31use crate::dtypes::DType;
32use crate::encoder::CommandEncoder;
33use crate::error::{MlxError, Result};
34use crate::kernel_registry::KernelRegistry;
35
36pub static ROPE_MULTI_SHADER_SOURCE: &str = include_str!("../shaders/rope_multi.metal");
37
38pub fn register(registry: &mut KernelRegistry) {
39 registry.register_source("rope_multi_f32", ROPE_MULTI_SHADER_SOURCE);
40 registry.register_source("rope_multi_bf16", ROPE_MULTI_SHADER_SOURCE);
41}
42
/// Selects which multi-section RoPE variant the kernel applies.
///
/// The enum is `#[repr(u32)]` and its discriminant is what gets written into
/// the `u32` parameter buffer (via `mode as u32`), so the numeric values are
/// part of the host/shader contract and must not be renumbered.
///
/// NOTE(review): 8 and 40 look like ggml's `GGML_ROPE_TYPE_MROPE` /
/// `GGML_ROPE_TYPE_IMROPE` constants — confirm against the shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u32)]
pub enum RopeMultiMode {
    /// Encoded as `8`.
    Mrope = 8,
    /// Encoded as `40`.
    Imrope = 40,
}
53
54#[derive(Debug, Clone, Copy)]
56pub struct RopeMultiParams {
57 pub head_dim: u32,
58 pub rope_dim: u32, pub n_heads: u32,
60 pub seq_len: u32,
61 pub freq_base: f32,
62 pub mode: RopeMultiMode,
63 pub sections: [u32; 4],
66}
67
68fn validate(
69 p: &RopeMultiParams,
70 input: &MlxBuffer,
71 output: &MlxBuffer,
72 positions: &MlxBuffer,
73) -> Result<()> {
74 if p.head_dim == 0 || p.rope_dim == 0 || p.n_heads == 0 || p.seq_len == 0 {
75 return Err(MlxError::InvalidArgument(
76 "rope_multi: head_dim, rope_dim, n_heads, seq_len must all be > 0".into(),
77 ));
78 }
79 if p.head_dim % 2 != 0 || p.rope_dim % 2 != 0 {
80 return Err(MlxError::InvalidArgument(
81 "rope_multi: head_dim and rope_dim must be even".into(),
82 ));
83 }
84 if p.rope_dim > p.head_dim {
85 return Err(MlxError::InvalidArgument(
86 "rope_multi: rope_dim must be <= head_dim".into(),
87 ));
88 }
89 if !p.freq_base.is_finite() || p.freq_base <= 0.0 {
90 return Err(MlxError::InvalidArgument(format!(
91 "rope_multi: freq_base must be finite and positive, got {}",
92 p.freq_base
93 )));
94 }
95
96 let n_rows = (p.seq_len as usize) * (p.n_heads as usize);
97 let elements = n_rows * (p.head_dim as usize);
98 if input.element_count() != elements {
99 return Err(MlxError::InvalidArgument(format!(
100 "rope_multi: input element count {} != seq_len({}) * n_heads({}) * head_dim({}) = {}",
101 input.element_count(),
102 p.seq_len,
103 p.n_heads,
104 p.head_dim,
105 elements
106 )));
107 }
108 if output.element_count() != elements {
109 return Err(MlxError::InvalidArgument(format!(
110 "rope_multi: output element count {} != {}",
111 output.element_count(),
112 elements
113 )));
114 }
115 if input.dtype() != output.dtype() {
116 return Err(MlxError::InvalidArgument(format!(
117 "rope_multi: input/output dtype mismatch {} vs {}",
118 input.dtype(),
119 output.dtype()
120 )));
121 }
122
123 let expected_positions = 4 * (p.seq_len as usize);
124 if positions.element_count() != expected_positions {
125 return Err(MlxError::InvalidArgument(format!(
126 "rope_multi: positions length {} != 4 * seq_len({}) = {}",
127 positions.element_count(),
128 p.seq_len,
129 expected_positions
130 )));
131 }
132 match positions.dtype() {
133 DType::I32 | DType::U32 => {}
134 other => {
135 return Err(MlxError::InvalidArgument(format!(
136 "rope_multi: positions must be i32 or u32 (got {})",
137 other
138 )));
139 }
140 }
141
142 Ok(())
143}
144
145#[allow(clippy::too_many_arguments)]
157pub fn dispatch_rope_multi(
158 encoder: &mut CommandEncoder,
159 registry: &mut KernelRegistry,
160 device: &metal::DeviceRef,
161 input: &MlxBuffer,
162 output: &MlxBuffer,
163 positions: &MlxBuffer,
164 params_buf: &MlxBuffer,
165 rope_params_buf: &MlxBuffer,
166 sections_buf: &MlxBuffer,
167 p: RopeMultiParams,
168) -> Result<()> {
169 validate(&p, input, output, positions)?;
170
171 let kernel_name = match input.dtype() {
172 DType::F32 => "rope_multi_f32",
173 DType::BF16 => "rope_multi_bf16",
174 other => {
175 return Err(MlxError::InvalidArgument(format!(
176 "rope_multi: unsupported dtype {}",
177 other
178 )));
179 }
180 };
181
182 let pipeline = registry.get_pipeline(kernel_name, device)?;
183
184 let half_dim = p.head_dim / 2;
185 let n_rows = p.seq_len * p.n_heads;
186
187 let grid = MTLSize::new(half_dim as u64, n_rows as u64, 1);
189
190 let tg_x = std::cmp::min(half_dim, 256).max(1);
191 let remain = (256u32 / tg_x).max(1);
192 let tg_y = std::cmp::min(n_rows, remain).max(1);
193 let tg = MTLSize::new(tg_x as u64, tg_y as u64, 1);
194
195 encoder.encode(
196 pipeline,
197 &[
198 (0, input),
199 (1, output),
200 (2, params_buf),
201 (3, positions),
202 (4, rope_params_buf),
203 (5, sections_buf),
204 ],
205 grid,
206 tg,
207 );
208
209 Ok(())
210}
211
212pub fn build_rope_multi_buffers(
216 device: &crate::MlxDevice,
217 p: RopeMultiParams,
218) -> Result<(MlxBuffer, MlxBuffer, MlxBuffer)> {
219 let mut params = device.alloc_buffer(4 * 4, DType::F32, vec![4])?;
220 {
221 let s = params.as_mut_slice::<f32>()?;
222 s[0] = p.freq_base;
223 s[1] = p.head_dim as f32;
224 s[2] = p.rope_dim as f32;
225 s[3] = 0.0;
226 }
227 let mut rope_params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
228 {
229 let s = rope_params.as_mut_slice::<u32>()?;
230 s[0] = p.n_heads;
231 s[1] = p.mode as u32;
232 s[2] = p.seq_len;
233 s[3] = 0;
234 }
235 let mut sections = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
236 {
237 let s = sections.as_mut_slice::<u32>()?;
238 s[0] = p.sections[0];
239 s[1] = p.sections[1];
240 s[2] = p.sections[2];
241 s[3] = p.sections[3];
242 }
243 Ok((params, rope_params, sections))
244}