1use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// Parameter block for the `transpose_2d_f32` / `transpose_2d_f16` kernels.
///
/// `transpose_2d` passes this struct to the kernel as raw bytes (via
/// `as_bytes`) at argument index 2. `#[repr(C)]` fixes field order and layout
/// so the bytes are predictable. NOTE(review): this assumes the Metal source
/// declares a matching `{ uint rows; uint cols; }` struct; confirm against the
/// shader before changing any field.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuTransposeParams {
    /// The `rows` argument of `transpose_2d`, narrowed to `u32`.
    rows: u32,
    /// The `cols` argument of `transpose_2d`, narrowed to `u32`.
    cols: u32,
}
25
26#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41 encoder: &mut CommandEncoder,
42 registry: &mut KernelRegistry,
43 device: &metal::DeviceRef,
44 input: &MlxBuffer,
45 output: &MlxBuffer,
46 rows: usize,
47 cols: usize,
48 dtype: DType,
49) -> Result<()> {
50 if rows == 0 {
51 return Err(MlxError::InvalidArgument(
52 "transpose_2d: rows must be > 0".into(),
53 ));
54 }
55 if cols == 0 {
56 return Err(MlxError::InvalidArgument(
57 "transpose_2d: cols must be > 0".into(),
58 ));
59 }
60
61 let kernel_name = match dtype {
62 DType::F32 => "transpose_2d_f32",
63 DType::F16 => "transpose_2d_f16",
64 _ => {
65 return Err(MlxError::InvalidArgument(format!(
66 "transpose_2d: unsupported dtype {dtype}"
67 )));
68 }
69 };
70
71 let elem_bytes = rows * cols * dtype.size_of();
72 if input.byte_len() < elem_bytes {
73 return Err(MlxError::InvalidArgument(format!(
74 "transpose_2d: input buffer too small: need {} bytes, have {}",
75 elem_bytes,
76 input.byte_len()
77 )));
78 }
79 if output.byte_len() < elem_bytes {
80 return Err(MlxError::InvalidArgument(format!(
81 "transpose_2d: output buffer too small: need {} bytes, have {}",
82 elem_bytes,
83 output.byte_len()
84 )));
85 }
86
87 let pipeline = registry.get_pipeline(kernel_name, device)?;
88
89 let gpu_params = GpuTransposeParams {
90 rows: rows as u32,
91 cols: cols as u32,
92 };
93
94 let grid = MTLSize::new(cols as u64, rows as u64, 1);
96 let tg = MTLSize::new(
97 std::cmp::min(16, cols as u64),
98 std::cmp::min(16, rows as u64),
99 1,
100 );
101
102 encode_with_args(
103 encoder,
104 pipeline,
105 &[
106 (0, KernelArg::Buffer(input)),
107 (1, KernelArg::Buffer(output)),
108 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
109 ],
110 grid,
111 tg,
112 );
113
114 Ok(())
115}
116
/// Parameter block shared by the `permute_021_*` and `transpose_last2_*`
/// kernels.
///
/// Each wrapper passes this struct to its kernel as raw bytes (via `as_bytes`)
/// at argument index 2. `#[repr(C)]` fixes field order and layout.
/// NOTE(review): this assumes every one of those Metal kernels declares a
/// matching `{ uint dim_a; uint dim_b; uint dim_c; }` struct; confirm against
/// the shaders before changing any field.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuPermute021Params {
    /// Size of the first (outermost) dimension passed by the caller, as `u32`.
    dim_a: u32,
    /// Size of the second dimension passed by the caller, as `u32`.
    dim_b: u32,
    /// Size of the third dimension passed by the caller, as `u32`.
    dim_c: u32,
}
127
128pub fn permute_021_f32(
143 encoder: &mut CommandEncoder,
144 registry: &mut KernelRegistry,
145 device: &metal::DeviceRef,
146 input: &MlxBuffer,
147 output: &MlxBuffer,
148 dim_a: usize,
149 dim_b: usize,
150 dim_c: usize,
151) -> Result<()> {
152 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
153 return Err(MlxError::InvalidArgument(
154 "permute_021_f32: all dimensions must be > 0".into(),
155 ));
156 }
157
158 let total_elements = dim_a * dim_b * dim_c;
159 let elem_bytes = total_elements * 4; if input.byte_len() < elem_bytes {
161 return Err(MlxError::InvalidArgument(format!(
162 "permute_021_f32: input buffer too small: need {} bytes, have {}",
163 elem_bytes,
164 input.byte_len()
165 )));
166 }
167 if output.byte_len() < elem_bytes {
168 return Err(MlxError::InvalidArgument(format!(
169 "permute_021_f32: output buffer too small: need {} bytes, have {}",
170 elem_bytes,
171 output.byte_len()
172 )));
173 }
174
175 let pipeline = registry.get_pipeline("permute_021_f32", device)?;
176
177 let gpu_params = GpuPermute021Params {
178 dim_a: dim_a as u32,
179 dim_b: dim_b as u32,
180 dim_c: dim_c as u32,
181 };
182
183 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
184 let tg = MTLSize::new(
185 std::cmp::min(64, dim_c as u64),
186 std::cmp::min(4, dim_b as u64),
187 std::cmp::min(4, dim_a as u64),
188 );
189
190 encode_with_args(
191 encoder,
192 pipeline,
193 &[
194 (0, KernelArg::Buffer(input)),
195 (1, KernelArg::Buffer(output)),
196 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
197 ],
198 grid,
199 tg,
200 );
201
202 Ok(())
203}
204
205pub fn transpose_last2_bf16(
214 encoder: &mut CommandEncoder,
215 registry: &mut KernelRegistry,
216 device: &metal::DeviceRef,
217 input: &MlxBuffer,
218 output: &MlxBuffer,
219 dim_a: usize,
220 dim_b: usize,
221 dim_c: usize,
222) -> Result<()> {
223 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
224 return Err(MlxError::InvalidArgument(
225 "transpose_last2_bf16: all dimensions must be > 0".into(),
226 ));
227 }
228
229 let total_elements = dim_a * dim_b * dim_c;
230 let elem_bytes = total_elements * 2;
231 if input.byte_len() < elem_bytes {
232 return Err(MlxError::InvalidArgument(format!(
233 "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
234 elem_bytes, input.byte_len()
235 )));
236 }
237 if output.byte_len() < elem_bytes {
238 return Err(MlxError::InvalidArgument(format!(
239 "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
240 elem_bytes, output.byte_len()
241 )));
242 }
243
244 let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
245
246 let gpu_params = GpuPermute021Params {
247 dim_a: dim_a as u32,
248 dim_b: dim_b as u32,
249 dim_c: dim_c as u32,
250 };
251
252 let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
255 let tg = MTLSize::new(
256 std::cmp::min(16, dim_b as u64),
257 std::cmp::min(16, dim_c as u64),
258 std::cmp::min(4, dim_a as u64),
259 );
260
261 encode_with_args(
262 encoder,
263 pipeline,
264 &[
265 (0, KernelArg::Buffer(input)),
266 (1, KernelArg::Buffer(output)),
267 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
268 ],
269 grid,
270 tg,
271 );
272
273 Ok(())
274}
275
276pub fn transpose_last2_f16(
292 encoder: &mut CommandEncoder,
293 registry: &mut KernelRegistry,
294 device: &metal::DeviceRef,
295 input: &MlxBuffer,
296 output: &MlxBuffer,
297 dim_a: usize,
298 dim_b: usize,
299 dim_c: usize,
300) -> Result<()> {
301 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
302 return Err(MlxError::InvalidArgument(
303 "transpose_last2_f16: all dimensions must be > 0".into(),
304 ));
305 }
306
307 let total_elements = dim_a * dim_b * dim_c;
308 let elem_bytes = total_elements * 2; if input.byte_len() < elem_bytes {
310 return Err(MlxError::InvalidArgument(format!(
311 "transpose_last2_f16: input buffer too small: need {} bytes, have {}",
312 elem_bytes, input.byte_len()
313 )));
314 }
315 if output.byte_len() < elem_bytes {
316 return Err(MlxError::InvalidArgument(format!(
317 "transpose_last2_f16: output buffer too small: need {} bytes, have {}",
318 elem_bytes, output.byte_len()
319 )));
320 }
321
322 let pipeline = registry.get_pipeline("transpose_last2_f16", device)?;
323
324 let gpu_params = GpuPermute021Params {
325 dim_a: dim_a as u32,
326 dim_b: dim_b as u32,
327 dim_c: dim_c as u32,
328 };
329
330 let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
335 let tg = MTLSize::new(
336 std::cmp::min(16, dim_b as u64),
337 std::cmp::min(16, dim_c as u64),
338 std::cmp::min(4, dim_a as u64),
339 );
340
341 encode_with_args(
342 encoder,
343 pipeline,
344 &[
345 (0, KernelArg::Buffer(input)),
346 (1, KernelArg::Buffer(output)),
347 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
348 ],
349 grid,
350 tg,
351 );
352
353 Ok(())
354}
355
356pub fn permute_021_bf16(
357 encoder: &mut CommandEncoder,
358 registry: &mut KernelRegistry,
359 device: &metal::DeviceRef,
360 input: &MlxBuffer,
361 output: &MlxBuffer,
362 dim_a: usize,
363 dim_b: usize,
364 dim_c: usize,
365) -> Result<()> {
366 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
367 return Err(MlxError::InvalidArgument(
368 "permute_021_bf16: all dimensions must be > 0".into(),
369 ));
370 }
371
372 let total_elements = dim_a * dim_b * dim_c;
373 let elem_bytes = total_elements * 2; if input.byte_len() < elem_bytes {
375 return Err(MlxError::InvalidArgument(format!(
376 "permute_021_bf16: input buffer too small: need {} bytes, have {}",
377 elem_bytes,
378 input.byte_len()
379 )));
380 }
381 if output.byte_len() < elem_bytes {
382 return Err(MlxError::InvalidArgument(format!(
383 "permute_021_bf16: output buffer too small: need {} bytes, have {}",
384 elem_bytes,
385 output.byte_len()
386 )));
387 }
388
389 let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
390
391 let gpu_params = GpuPermute021Params {
392 dim_a: dim_a as u32,
393 dim_b: dim_b as u32,
394 dim_c: dim_c as u32,
395 };
396
397 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
399 let tg = MTLSize::new(
400 std::cmp::min(64, dim_c as u64),
401 std::cmp::min(4, dim_b as u64),
402 std::cmp::min(4, dim_a as u64),
403 );
404
405 encode_with_args(
406 encoder,
407 pipeline,
408 &[
409 (0, KernelArg::Buffer(input)),
410 (1, KernelArg::Buffer(output)),
411 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
412 ],
413 grid,
414 tg,
415 );
416
417 Ok(())
418}
419
420pub fn permute_021_bf16_to_f32(
426 encoder: &mut CommandEncoder,
427 registry: &mut KernelRegistry,
428 device: &metal::DeviceRef,
429 input: &MlxBuffer,
430 output: &MlxBuffer,
431 dim_a: usize,
432 dim_b: usize,
433 dim_c: usize,
434) -> Result<()> {
435 if dim_a == 0 || dim_b == 0 || dim_c == 0 {
436 return Err(MlxError::InvalidArgument(
437 "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
438 ));
439 }
440
441 let total_elements = dim_a * dim_b * dim_c;
442 let in_bytes = total_elements * 2; let out_bytes = total_elements * 4; if input.byte_len() < in_bytes {
445 return Err(MlxError::InvalidArgument(format!(
446 "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
447 in_bytes, input.byte_len()
448 )));
449 }
450 if output.byte_len() < out_bytes {
451 return Err(MlxError::InvalidArgument(format!(
452 "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
453 out_bytes, output.byte_len()
454 )));
455 }
456
457 let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
458
459 let gpu_params = GpuPermute021Params {
460 dim_a: dim_a as u32,
461 dim_b: dim_b as u32,
462 dim_c: dim_c as u32,
463 };
464
465 let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
466 let tg = MTLSize::new(
467 std::cmp::min(64, dim_c as u64),
468 std::cmp::min(4, dim_b as u64),
469 std::cmp::min(4, dim_a as u64),
470 );
471
472 encode_with_args(
473 encoder,
474 pipeline,
475 &[
476 (0, KernelArg::Buffer(input)),
477 (1, KernelArg::Buffer(output)),
478 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
479 ],
480 grid,
481 tg,
482 );
483
484 Ok(())
485}