1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
16#[repr(C)]
20#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
21struct GpuTransposeParams {
22 rows: u32,
23 cols: u32,
24}
25
26#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41 encoder: &mut CommandEncoder,
42 registry: &mut KernelRegistry,
43 device: &metal::DeviceRef,
44 input: &MlxBuffer,
45 output: &MlxBuffer,
46 rows: usize,
47 cols: usize,
48 dtype: DType,
49) -> Result<()> {
50 if rows == 0 {
51 return Err(MlxError::InvalidArgument(
52 "transpose_2d: rows must be > 0".into(),
53 ));
54 }
55 if cols == 0 {
56 return Err(MlxError::InvalidArgument(
57 "transpose_2d: cols must be > 0".into(),
58 ));
59 }
60
61 let kernel_name = match dtype {
62 DType::F32 => "transpose_2d_f32",
63 DType::F16 => "transpose_2d_f16",
64 _ => {
65 return Err(MlxError::InvalidArgument(format!(
66 "transpose_2d: unsupported dtype {dtype}"
67 )));
68 }
69 };
70
71 let elem_bytes = rows * cols * dtype.size_of();
72 if input.byte_len() < elem_bytes {
73 return Err(MlxError::InvalidArgument(format!(
74 "transpose_2d: input buffer too small: need {} bytes, have {}",
75 elem_bytes,
76 input.byte_len()
77 )));
78 }
79 if output.byte_len() < elem_bytes {
80 return Err(MlxError::InvalidArgument(format!(
81 "transpose_2d: output buffer too small: need {} bytes, have {}",
82 elem_bytes,
83 output.byte_len()
84 )));
85 }
86
87 let pipeline = registry.get_pipeline(kernel_name, device)?;
88
89 let gpu_params = GpuTransposeParams {
90 rows: rows as u32,
91 cols: cols as u32,
92 };
93
94 let grid = MTLSize::new(cols as u64, rows as u64, 1);
96 let tg = MTLSize::new(
97 std::cmp::min(16, cols as u64),
98 std::cmp::min(16, rows as u64),
99 1,
100 );
101
102 encode_with_args(
103 encoder,
104 pipeline,
105 &[
106 (0, KernelArg::Buffer(input)),
107 (1, KernelArg::Buffer(output)),
108 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
109 ],
110 grid,
111 tg,
112 );
113
114 Ok(())
115}
116
117#[repr(C)]
121#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
122struct GpuPermute021Params {
123 dim_a: u32,
124 dim_b: u32,
125 dim_c: u32,
126}
127
128pub fn permute_021_f32(
143 encoder: &mut CommandEncoder,
144 registry: &mut KernelRegistry,
145 device: &metal::DeviceRef,
146 input: &MlxBuffer,
147 output: &MlxBuffer,
148 dim_a: usize,
149 dim_b: usize,
150 dim_c: usize,
151) -> Result<()> {
152 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
153 return Err(MlxError::InvalidArgument(
154 "permute_021_f32: all dimensions must be > 0".into(),
155 ));
156 }
157
158 let total_elements = dim_a * dim_b * dim_c;
159 let elem_bytes = total_elements * 4; if input.byte_len() < elem_bytes {
161 return Err(MlxError::InvalidArgument(format!(
162 "permute_021_f32: input buffer too small: need {} bytes, have {}",
163 elem_bytes,
164 input.byte_len()
165 )));
166 }
167 if output.byte_len() < elem_bytes {
168 return Err(MlxError::InvalidArgument(format!(
169 "permute_021_f32: output buffer too small: need {} bytes, have {}",
170 elem_bytes,
171 output.byte_len()
172 )));
173 }
174
175 let pipeline = registry.get_pipeline("permute_021_f32", device)?;
176
177 let gpu_params = GpuPermute021Params {
178 dim_a: dim_a as u32,
179 dim_b: dim_b as u32,
180 dim_c: dim_c as u32,
181 };
182
183 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
184 let tg = MTLSize::new(
185 std::cmp::min(64, dim_c as u64),
186 std::cmp::min(4, dim_b as u64),
187 std::cmp::min(4, dim_a as u64),
188 );
189
190 encode_with_args(
191 encoder,
192 pipeline,
193 &[
194 (0, KernelArg::Buffer(input)),
195 (1, KernelArg::Buffer(output)),
196 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
197 ],
198 grid,
199 tg,
200 );
201
202 Ok(())
203}
204
205pub fn transpose_last2_bf16(
214 encoder: &mut CommandEncoder,
215 registry: &mut KernelRegistry,
216 device: &metal::DeviceRef,
217 input: &MlxBuffer,
218 output: &MlxBuffer,
219 dim_a: usize,
220 dim_b: usize,
221 dim_c: usize,
222) -> Result<()> {
223 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
224 return Err(MlxError::InvalidArgument(
225 "transpose_last2_bf16: all dimensions must be > 0".into(),
226 ));
227 }
228
229 let total_elements = dim_a * dim_b * dim_c;
230 let elem_bytes = total_elements * 2;
231 if input.byte_len() < elem_bytes {
232 return Err(MlxError::InvalidArgument(format!(
233 "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
234 elem_bytes, input.byte_len()
235 )));
236 }
237 if output.byte_len() < elem_bytes {
238 return Err(MlxError::InvalidArgument(format!(
239 "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
240 elem_bytes, output.byte_len()
241 )));
242 }
243
244 let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
245
246 let gpu_params = GpuPermute021Params {
247 dim_a: dim_a as u32,
248 dim_b: dim_b as u32,
249 dim_c: dim_c as u32,
250 };
251
252 let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
255 let tg = MTLSize::new(
256 std::cmp::min(16, dim_b as u64),
257 std::cmp::min(16, dim_c as u64),
258 std::cmp::min(4, dim_a as u64),
259 );
260
261 encode_with_args(
262 encoder,
263 pipeline,
264 &[
265 (0, KernelArg::Buffer(input)),
266 (1, KernelArg::Buffer(output)),
267 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
268 ],
269 grid,
270 tg,
271 );
272
273 Ok(())
274}
275
276pub fn permute_021_bf16(
277 encoder: &mut CommandEncoder,
278 registry: &mut KernelRegistry,
279 device: &metal::DeviceRef,
280 input: &MlxBuffer,
281 output: &MlxBuffer,
282 dim_a: usize,
283 dim_b: usize,
284 dim_c: usize,
285) -> Result<()> {
286 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
287 return Err(MlxError::InvalidArgument(
288 "permute_021_bf16: all dimensions must be > 0".into(),
289 ));
290 }
291
292 let total_elements = dim_a * dim_b * dim_c;
293 let elem_bytes = total_elements * 2; if input.byte_len() < elem_bytes {
295 return Err(MlxError::InvalidArgument(format!(
296 "permute_021_bf16: input buffer too small: need {} bytes, have {}",
297 elem_bytes,
298 input.byte_len()
299 )));
300 }
301 if output.byte_len() < elem_bytes {
302 return Err(MlxError::InvalidArgument(format!(
303 "permute_021_bf16: output buffer too small: need {} bytes, have {}",
304 elem_bytes,
305 output.byte_len()
306 )));
307 }
308
309 let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
310
311 let gpu_params = GpuPermute021Params {
312 dim_a: dim_a as u32,
313 dim_b: dim_b as u32,
314 dim_c: dim_c as u32,
315 };
316
317 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
319 let tg = MTLSize::new(
320 std::cmp::min(64, dim_c as u64),
321 std::cmp::min(4, dim_b as u64),
322 std::cmp::min(4, dim_a as u64),
323 );
324
325 encode_with_args(
326 encoder,
327 pipeline,
328 &[
329 (0, KernelArg::Buffer(input)),
330 (1, KernelArg::Buffer(output)),
331 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
332 ],
333 grid,
334 tg,
335 );
336
337 Ok(())
338}
339
340pub fn permute_021_bf16_to_f32(
346 encoder: &mut CommandEncoder,
347 registry: &mut KernelRegistry,
348 device: &metal::DeviceRef,
349 input: &MlxBuffer,
350 output: &MlxBuffer,
351 dim_a: usize,
352 dim_b: usize,
353 dim_c: usize,
354) -> Result<()> {
355 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
356 return Err(MlxError::InvalidArgument(
357 "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
358 ));
359 }
360
361 let total_elements = dim_a * dim_b * dim_c;
362 let in_bytes = total_elements * 2; let out_bytes = total_elements * 4; if input.byte_len() < in_bytes {
365 return Err(MlxError::InvalidArgument(format!(
366 "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
367 in_bytes, input.byte_len()
368 )));
369 }
370 if output.byte_len() < out_bytes {
371 return Err(MlxError::InvalidArgument(format!(
372 "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
373 out_bytes, output.byte_len()
374 )));
375 }
376
377 let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
378
379 let gpu_params = GpuPermute021Params {
380 dim_a: dim_a as u32,
381 dim_b: dim_b as u32,
382 dim_c: dim_c as u32,
383 };
384
385 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
386 let tg = MTLSize::new(
387 std::cmp::min(64, dim_c as u64),
388 std::cmp::min(4, dim_b as u64),
389 std::cmp::min(4, dim_a as u64),
390 );
391
392 encode_with_args(
393 encoder,
394 pipeline,
395 &[
396 (0, KernelArg::Buffer(input)),
397 (1, KernelArg::Buffer(output)),
398 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
399 ],
400 grid,
401 tg,
402 );
403
404 Ok(())
405}