1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Kernel parameters for `transpose_2d`, bound to the dispatch as raw bytes
/// (via `as_bytes` / `KernelArg::Bytes` at buffer index 2).
///
/// `#[repr(C)]` pins the field order and layout so the bytes handed to the GPU
/// are well defined; `bytemuck::Pod` is what allows viewing the struct as bytes.
/// NOTE(review): presumably mirrors a shader-side struct with two `uint`
/// fields in this order — confirm against the Metal kernel source.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuTransposeParams {
    /// Row count of the matrix, as passed to `transpose_2d`.
    rows: u32,
    /// Column count of the matrix, as passed to `transpose_2d`.
    cols: u32,
}
25
26#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41 encoder: &mut CommandEncoder,
42 registry: &mut KernelRegistry,
43 device: &metal::DeviceRef,
44 input: &MlxBuffer,
45 output: &MlxBuffer,
46 rows: usize,
47 cols: usize,
48 dtype: DType,
49) -> Result<()> {
50 if rows == 0 {
51 return Err(MlxError::InvalidArgument(
52 "transpose_2d: rows must be > 0".into(),
53 ));
54 }
55 if cols == 0 {
56 return Err(MlxError::InvalidArgument(
57 "transpose_2d: cols must be > 0".into(),
58 ));
59 }
60
61 let kernel_name = match dtype {
62 DType::F32 => "transpose_2d_f32",
63 DType::F16 => "transpose_2d_f16",
64 _ => {
65 return Err(MlxError::InvalidArgument(format!(
66 "transpose_2d: unsupported dtype {dtype}"
67 )));
68 }
69 };
70
71 if input.dtype() != dtype {
76 return Err(MlxError::InvalidArgument(format!(
77 "transpose_2d: input dtype {} != dtype param {}",
78 input.dtype(), dtype,
79 )));
80 }
81 if output.dtype() != dtype {
82 return Err(MlxError::InvalidArgument(format!(
83 "transpose_2d: output dtype {} != dtype param {}",
84 output.dtype(), dtype,
85 )));
86 }
87
88 let elem_bytes = rows * cols * dtype.size_of();
89 if input.byte_len() < elem_bytes {
90 return Err(MlxError::InvalidArgument(format!(
91 "transpose_2d: input buffer too small: need {} bytes, have {}",
92 elem_bytes,
93 input.byte_len()
94 )));
95 }
96 if output.byte_len() < elem_bytes {
97 return Err(MlxError::InvalidArgument(format!(
98 "transpose_2d: output buffer too small: need {} bytes, have {}",
99 elem_bytes,
100 output.byte_len()
101 )));
102 }
103
104 let pipeline = registry.get_pipeline(kernel_name, device)?;
105
106 let gpu_params = GpuTransposeParams {
107 rows: rows as u32,
108 cols: cols as u32,
109 };
110
111 let grid = MTLSize::new(cols as u64, rows as u64, 1);
113 let tg = MTLSize::new(
114 std::cmp::min(16, cols as u64),
115 std::cmp::min(16, rows as u64),
116 1,
117 );
118
119 encode_with_args(
120 encoder,
121 pipeline,
122 &[
123 (0, KernelArg::Buffer(input)),
124 (1, KernelArg::Buffer(output)),
125 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
126 ],
127 grid,
128 tg,
129 );
130
131 Ok(())
132}
133
/// Kernel parameters shared by the 3-D permute/transpose wrappers
/// (`permute_021_f32`, `permute_021_bf16`, `permute_021_bf16_to_f32`,
/// `transpose_last2_bf16`, `transpose_last2_f16`), bound as raw bytes at
/// buffer index 2.
///
/// `#[repr(C)]` pins the field order and layout; `bytemuck::Pod` allows the
/// struct to be viewed as bytes. NOTE(review): presumably mirrors a
/// shader-side struct with three `uint` fields in this order — confirm against
/// the Metal kernels.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuPermute021Params {
    /// First (outermost, presumably) dimension of the input tensor.
    dim_a: u32,
    /// Second dimension of the input tensor.
    dim_b: u32,
    /// Third (innermost, presumably) dimension of the input tensor.
    dim_c: u32,
}
144
145pub fn permute_021_f32(
160 encoder: &mut CommandEncoder,
161 registry: &mut KernelRegistry,
162 device: &metal::DeviceRef,
163 input: &MlxBuffer,
164 output: &MlxBuffer,
165 dim_a: usize,
166 dim_b: usize,
167 dim_c: usize,
168) -> Result<()> {
169 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
170 return Err(MlxError::InvalidArgument(
171 "permute_021_f32: all dimensions must be > 0".into(),
172 ));
173 }
174
175 let total_elements = dim_a * dim_b * dim_c;
176 let elem_bytes = total_elements * 4; if input.byte_len() < elem_bytes {
178 return Err(MlxError::InvalidArgument(format!(
179 "permute_021_f32: input buffer too small: need {} bytes, have {}",
180 elem_bytes,
181 input.byte_len()
182 )));
183 }
184 if output.byte_len() < elem_bytes {
185 return Err(MlxError::InvalidArgument(format!(
186 "permute_021_f32: output buffer too small: need {} bytes, have {}",
187 elem_bytes,
188 output.byte_len()
189 )));
190 }
191
192 let pipeline = registry.get_pipeline("permute_021_f32", device)?;
193
194 let gpu_params = GpuPermute021Params {
195 dim_a: dim_a as u32,
196 dim_b: dim_b as u32,
197 dim_c: dim_c as u32,
198 };
199
200 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
201 let tg = MTLSize::new(
202 std::cmp::min(64, dim_c as u64),
203 std::cmp::min(4, dim_b as u64),
204 std::cmp::min(4, dim_a as u64),
205 );
206
207 encode_with_args(
208 encoder,
209 pipeline,
210 &[
211 (0, KernelArg::Buffer(input)),
212 (1, KernelArg::Buffer(output)),
213 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
214 ],
215 grid,
216 tg,
217 );
218
219 Ok(())
220}
221
222pub fn transpose_last2_bf16(
231 encoder: &mut CommandEncoder,
232 registry: &mut KernelRegistry,
233 device: &metal::DeviceRef,
234 input: &MlxBuffer,
235 output: &MlxBuffer,
236 dim_a: usize,
237 dim_b: usize,
238 dim_c: usize,
239) -> Result<()> {
240 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
241 return Err(MlxError::InvalidArgument(
242 "transpose_last2_bf16: all dimensions must be > 0".into(),
243 ));
244 }
245
246 let total_elements = dim_a * dim_b * dim_c;
247 let elem_bytes = total_elements * 2;
248 if input.byte_len() < elem_bytes {
249 return Err(MlxError::InvalidArgument(format!(
250 "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
251 elem_bytes, input.byte_len()
252 )));
253 }
254 if output.byte_len() < elem_bytes {
255 return Err(MlxError::InvalidArgument(format!(
256 "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
257 elem_bytes, output.byte_len()
258 )));
259 }
260
261 let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
262
263 let gpu_params = GpuPermute021Params {
264 dim_a: dim_a as u32,
265 dim_b: dim_b as u32,
266 dim_c: dim_c as u32,
267 };
268
269 let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
272 let tg = MTLSize::new(
273 std::cmp::min(16, dim_b as u64),
274 std::cmp::min(16, dim_c as u64),
275 std::cmp::min(4, dim_a as u64),
276 );
277
278 encode_with_args(
279 encoder,
280 pipeline,
281 &[
282 (0, KernelArg::Buffer(input)),
283 (1, KernelArg::Buffer(output)),
284 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
285 ],
286 grid,
287 tg,
288 );
289
290 Ok(())
291}
292
293pub fn transpose_last2_f16(
309 encoder: &mut CommandEncoder,
310 registry: &mut KernelRegistry,
311 device: &metal::DeviceRef,
312 input: &MlxBuffer,
313 output: &MlxBuffer,
314 dim_a: usize,
315 dim_b: usize,
316 dim_c: usize,
317) -> Result<()> {
318 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
319 return Err(MlxError::InvalidArgument(
320 "transpose_last2_f16: all dimensions must be > 0".into(),
321 ));
322 }
323
324 let total_elements = dim_a * dim_b * dim_c;
325 let elem_bytes = total_elements * 2; if input.byte_len() < elem_bytes {
327 return Err(MlxError::InvalidArgument(format!(
328 "transpose_last2_f16: input buffer too small: need {} bytes, have {}",
329 elem_bytes, input.byte_len()
330 )));
331 }
332 if output.byte_len() < elem_bytes {
333 return Err(MlxError::InvalidArgument(format!(
334 "transpose_last2_f16: output buffer too small: need {} bytes, have {}",
335 elem_bytes, output.byte_len()
336 )));
337 }
338
339 let pipeline = registry.get_pipeline("transpose_last2_f16", device)?;
340
341 let gpu_params = GpuPermute021Params {
342 dim_a: dim_a as u32,
343 dim_b: dim_b as u32,
344 dim_c: dim_c as u32,
345 };
346
347 let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
352 let tg = MTLSize::new(
353 std::cmp::min(16, dim_b as u64),
354 std::cmp::min(16, dim_c as u64),
355 std::cmp::min(4, dim_a as u64),
356 );
357
358 encode_with_args(
359 encoder,
360 pipeline,
361 &[
362 (0, KernelArg::Buffer(input)),
363 (1, KernelArg::Buffer(output)),
364 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
365 ],
366 grid,
367 tg,
368 );
369
370 Ok(())
371}
372
373pub fn permute_021_bf16(
374 encoder: &mut CommandEncoder,
375 registry: &mut KernelRegistry,
376 device: &metal::DeviceRef,
377 input: &MlxBuffer,
378 output: &MlxBuffer,
379 dim_a: usize,
380 dim_b: usize,
381 dim_c: usize,
382) -> Result<()> {
383 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
384 return Err(MlxError::InvalidArgument(
385 "permute_021_bf16: all dimensions must be > 0".into(),
386 ));
387 }
388
389 let total_elements = dim_a * dim_b * dim_c;
390 let elem_bytes = total_elements * 2; if input.byte_len() < elem_bytes {
392 return Err(MlxError::InvalidArgument(format!(
393 "permute_021_bf16: input buffer too small: need {} bytes, have {}",
394 elem_bytes,
395 input.byte_len()
396 )));
397 }
398 if output.byte_len() < elem_bytes {
399 return Err(MlxError::InvalidArgument(format!(
400 "permute_021_bf16: output buffer too small: need {} bytes, have {}",
401 elem_bytes,
402 output.byte_len()
403 )));
404 }
405
406 let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
407
408 let gpu_params = GpuPermute021Params {
409 dim_a: dim_a as u32,
410 dim_b: dim_b as u32,
411 dim_c: dim_c as u32,
412 };
413
414 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
416 let tg = MTLSize::new(
417 std::cmp::min(64, dim_c as u64),
418 std::cmp::min(4, dim_b as u64),
419 std::cmp::min(4, dim_a as u64),
420 );
421
422 encode_with_args(
423 encoder,
424 pipeline,
425 &[
426 (0, KernelArg::Buffer(input)),
427 (1, KernelArg::Buffer(output)),
428 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
429 ],
430 grid,
431 tg,
432 );
433
434 Ok(())
435}
436
437pub fn permute_021_bf16_to_f32(
443 encoder: &mut CommandEncoder,
444 registry: &mut KernelRegistry,
445 device: &metal::DeviceRef,
446 input: &MlxBuffer,
447 output: &MlxBuffer,
448 dim_a: usize,
449 dim_b: usize,
450 dim_c: usize,
451) -> Result<()> {
452 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
453 return Err(MlxError::InvalidArgument(
454 "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
455 ));
456 }
457
458 let total_elements = dim_a * dim_b * dim_c;
459 let in_bytes = total_elements * 2; let out_bytes = total_elements * 4; if input.byte_len() < in_bytes {
462 return Err(MlxError::InvalidArgument(format!(
463 "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
464 in_bytes, input.byte_len()
465 )));
466 }
467 if output.byte_len() < out_bytes {
468 return Err(MlxError::InvalidArgument(format!(
469 "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
470 out_bytes, output.byte_len()
471 )));
472 }
473
474 let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
475
476 let gpu_params = GpuPermute021Params {
477 dim_a: dim_a as u32,
478 dim_b: dim_b as u32,
479 dim_c: dim_c as u32,
480 };
481
482 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
483 let tg = MTLSize::new(
484 std::cmp::min(64, dim_c as u64),
485 std::cmp::min(4, dim_b as u64),
486 std::cmp::min(4, dim_a as u64),
487 );
488
489 encode_with_args(
490 encoder,
491 pipeline,
492 &[
493 (0, KernelArg::Buffer(input)),
494 (1, KernelArg::Buffer(output)),
495 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
496 ],
497 grid,
498 tg,
499 );
500
501 Ok(())
502}