Skip to main content

mlx_native/ops/
transpose.rs

1//! GPU-accelerated 2D matrix transpose.
2//!
3//! Transposes a 2D matrix `[rows, cols]` to `[cols, rows]`.
4//! Supports F32 and F16 dtypes.
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
16/// MSL-compatible params struct for 2D transpose.
17///
18/// Must match `TransposeParams` in `elementwise.metal`.
19#[repr(C)]
20#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
21struct GpuTransposeParams {
22    rows: u32,
23    cols: u32,
24}
25
26/// Encode a 2D matrix transpose: `output[col, row] = input[row, col]`.
27///
28/// # Buffer expectations
29///
30/// * `input`  — `[rows, cols]` in the given dtype
31/// * `output` — `[cols, rows]` in the given dtype (must be pre-allocated)
32///
33/// # Errors
34///
35/// Returns `MlxError::InvalidArgument` if:
36/// * `rows` or `cols` is zero
37/// * `dtype` is not F32 or F16
38/// * Buffers are too small
39#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41    encoder: &mut CommandEncoder,
42    registry: &mut KernelRegistry,
43    device: &metal::DeviceRef,
44    input: &MlxBuffer,
45    output: &MlxBuffer,
46    rows: usize,
47    cols: usize,
48    dtype: DType,
49) -> Result<()> {
50    if rows == 0 {
51        return Err(MlxError::InvalidArgument(
52            "transpose_2d: rows must be > 0".into(),
53        ));
54    }
55    if cols == 0 {
56        return Err(MlxError::InvalidArgument(
57            "transpose_2d: cols must be > 0".into(),
58        ));
59    }
60
61    let kernel_name = match dtype {
62        DType::F32 => "transpose_2d_f32",
63        DType::F16 => "transpose_2d_f16",
64        _ => {
65            return Err(MlxError::InvalidArgument(format!(
66                "transpose_2d: unsupported dtype {dtype}"
67            )));
68        }
69    };
70
71    // hf2q ADR-030 iter-117 — defense-in-depth dtype check.  The caller-
72    // supplied `dtype` parameter selects the kernel pipeline; passing
73    // buffers of a DIFFERENT dtype would silently mis-stride.  Mirrors
74    // the pattern landed at hf2q ADR-030 iter-110→113.
75    if input.dtype() != dtype {
76        return Err(MlxError::InvalidArgument(format!(
77            "transpose_2d: input dtype {} != dtype param {}",
78            input.dtype(), dtype,
79        )));
80    }
81    if output.dtype() != dtype {
82        return Err(MlxError::InvalidArgument(format!(
83            "transpose_2d: output dtype {} != dtype param {}",
84            output.dtype(), dtype,
85        )));
86    }
87
88    let elem_bytes = rows * cols * dtype.size_of();
89    if input.byte_len() < elem_bytes {
90        return Err(MlxError::InvalidArgument(format!(
91            "transpose_2d: input buffer too small: need {} bytes, have {}",
92            elem_bytes,
93            input.byte_len()
94        )));
95    }
96    if output.byte_len() < elem_bytes {
97        return Err(MlxError::InvalidArgument(format!(
98            "transpose_2d: output buffer too small: need {} bytes, have {}",
99            elem_bytes,
100            output.byte_len()
101        )));
102    }
103
104    let pipeline = registry.get_pipeline(kernel_name, device)?;
105
106    let gpu_params = GpuTransposeParams {
107        rows: rows as u32,
108        cols: cols as u32,
109    };
110
111    // 2D grid: (cols, rows)
112    let grid = MTLSize::new(cols as u64, rows as u64, 1);
113    let tg = MTLSize::new(
114        std::cmp::min(16, cols as u64),
115        std::cmp::min(16, rows as u64),
116        1,
117    );
118
119    encode_with_args(
120        encoder,
121        pipeline,
122        &[
123            (0, KernelArg::Buffer(input)),
124            (1, KernelArg::Buffer(output)),
125            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
126        ],
127        grid,
128        tg,
129    );
130
131    Ok(())
132}
133
134/// MSL-compatible params struct for 3D permute [A, B, C] -> [B, A, C].
135///
136/// Must match `Permute021Params` in `elementwise.metal`.
137#[repr(C)]
138#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
139struct GpuPermute021Params {
140    dim_a: u32,
141    dim_b: u32,
142    dim_c: u32,
143}
144
145/// Encode a 3D permutation: `[A, B, C] -> [B, A, C]` (bf16).
146///
147/// This is used to convert between `[seq_len, n_heads, head_dim]` and
148/// `[n_heads, seq_len, head_dim]` layouts.
149///
150/// # Buffer expectations
151///
152/// * `input`  — `[dim_a, dim_b, dim_c]` in bf16
153/// * `output` — `[dim_b, dim_a, dim_c]` in bf16 (must be pre-allocated)
154///
155/// # Errors
156///
157/// Returns `MlxError::InvalidArgument` if any dimension is zero or buffers
158/// are too small.
159pub fn permute_021_f32(
160    encoder: &mut CommandEncoder,
161    registry: &mut KernelRegistry,
162    device: &metal::DeviceRef,
163    input: &MlxBuffer,
164    output: &MlxBuffer,
165    dim_a: usize,
166    dim_b: usize,
167    dim_c: usize,
168) -> Result<()> {
169    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
170        return Err(MlxError::InvalidArgument(
171            "permute_021_f32: all dimensions must be > 0".into(),
172        ));
173    }
174
175    let total_elements = dim_a * dim_b * dim_c;
176    let elem_bytes = total_elements * 4; // f32 = 4 bytes
177    if input.byte_len() < elem_bytes {
178        return Err(MlxError::InvalidArgument(format!(
179            "permute_021_f32: input buffer too small: need {} bytes, have {}",
180            elem_bytes,
181            input.byte_len()
182        )));
183    }
184    if output.byte_len() < elem_bytes {
185        return Err(MlxError::InvalidArgument(format!(
186            "permute_021_f32: output buffer too small: need {} bytes, have {}",
187            elem_bytes,
188            output.byte_len()
189        )));
190    }
191
192    let pipeline = registry.get_pipeline("permute_021_f32", device)?;
193
194    let gpu_params = GpuPermute021Params {
195        dim_a: dim_a as u32,
196        dim_b: dim_b as u32,
197        dim_c: dim_c as u32,
198    };
199
200    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
201    let tg = MTLSize::new(
202        std::cmp::min(64, dim_c as u64),
203        std::cmp::min(4, dim_b as u64),
204        std::cmp::min(4, dim_a as u64),
205    );
206
207    encode_with_args(
208        encoder,
209        pipeline,
210        &[
211            (0, KernelArg::Buffer(input)),
212            (1, KernelArg::Buffer(output)),
213            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
214        ],
215        grid,
216        tg,
217    );
218
219    Ok(())
220}
221
222/// Swap the last two axes of a 3D bf16 tensor: [A, B, C] -> [A, C, B].
223///
224/// Used by hf2q's non-flash-attention prefill path to transpose V from
225/// its natural `[nkv, seq, hd]` post-RoPE layout to the `[nkv, hd, seq]`
226/// layout the `scores @ V` matmul consumes (the contract dim of our
227/// tensor-mm kernel is the inner-most axis of src0).
228///
229/// One dispatch covers all A batches; each thread copies a single bf16.
230pub fn transpose_last2_bf16(
231    encoder: &mut CommandEncoder,
232    registry: &mut KernelRegistry,
233    device: &metal::DeviceRef,
234    input: &MlxBuffer,
235    output: &MlxBuffer,
236    dim_a: usize,
237    dim_b: usize,
238    dim_c: usize,
239) -> Result<()> {
240    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
241        return Err(MlxError::InvalidArgument(
242            "transpose_last2_bf16: all dimensions must be > 0".into(),
243        ));
244    }
245
246    let total_elements = dim_a * dim_b * dim_c;
247    let elem_bytes = total_elements * 2;
248    if input.byte_len() < elem_bytes {
249        return Err(MlxError::InvalidArgument(format!(
250            "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
251            elem_bytes, input.byte_len()
252        )));
253    }
254    if output.byte_len() < elem_bytes {
255        return Err(MlxError::InvalidArgument(format!(
256            "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
257            elem_bytes, output.byte_len()
258        )));
259    }
260
261    let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
262
263    let gpu_params = GpuPermute021Params {
264        dim_a: dim_a as u32,
265        dim_b: dim_b as u32,
266        dim_c: dim_c as u32,
267    };
268
269    // Grid: (dim_b, dim_c, dim_a).  Kernel maps (gid.x, gid.y, gid.z) →
270    // (b, c, a); see shaders/elementwise.metal::transpose_last2_bf16.
271    let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
272    let tg = MTLSize::new(
273        std::cmp::min(16, dim_b as u64),
274        std::cmp::min(16, dim_c as u64),
275        std::cmp::min(4, dim_a as u64),
276    );
277
278    encode_with_args(
279        encoder,
280        pipeline,
281        &[
282            (0, KernelArg::Buffer(input)),
283            (1, KernelArg::Buffer(output)),
284            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
285        ],
286        grid,
287        tg,
288    );
289
290    Ok(())
291}
292
293/// Swap the last two axes of a 3D f16 tensor: [A, B, C] -> [A, C, B].
294///
295/// F16 sibling of [`transpose_last2_bf16`].  Added for ADR-005 Phase 2c
296/// iter-129 (gemma4v ViT precision parity): peer's `flash_attn_ext`
297/// stages V at F16 (10-bit mantissa) and runs `simdgroup_half8x8` MMA
298/// on V@scores.  Pre-iter-129 hf2q cast V F32→BF16 (7-bit mantissa)
299/// before the transpose, capturing the dominant residual per-block
300/// cascade after iter-128's weight-matmul F16 fix.
301///
302/// Used by `vit_attention_gpu` to materialise V_t (V transposed over
303/// seq↔hd) so the F16-staged `scores @ V` matmul
304/// (`dense_matmul_f16_f32_tensor`) can consume it at the tile geometry
305/// the kernel expects.
306///
307/// One dispatch covers all A batches; each thread copies a single half.
308pub fn transpose_last2_f16(
309    encoder: &mut CommandEncoder,
310    registry: &mut KernelRegistry,
311    device: &metal::DeviceRef,
312    input: &MlxBuffer,
313    output: &MlxBuffer,
314    dim_a: usize,
315    dim_b: usize,
316    dim_c: usize,
317) -> Result<()> {
318    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
319        return Err(MlxError::InvalidArgument(
320            "transpose_last2_f16: all dimensions must be > 0".into(),
321        ));
322    }
323
324    let total_elements = dim_a * dim_b * dim_c;
325    let elem_bytes = total_elements * 2; // f16 = 2 bytes
326    if input.byte_len() < elem_bytes {
327        return Err(MlxError::InvalidArgument(format!(
328            "transpose_last2_f16: input buffer too small: need {} bytes, have {}",
329            elem_bytes, input.byte_len()
330        )));
331    }
332    if output.byte_len() < elem_bytes {
333        return Err(MlxError::InvalidArgument(format!(
334            "transpose_last2_f16: output buffer too small: need {} bytes, have {}",
335            elem_bytes, output.byte_len()
336        )));
337    }
338
339    let pipeline = registry.get_pipeline("transpose_last2_f16", device)?;
340
341    let gpu_params = GpuPermute021Params {
342        dim_a: dim_a as u32,
343        dim_b: dim_b as u32,
344        dim_c: dim_c as u32,
345    };
346
347    // Grid: (dim_b, dim_c, dim_a).  Kernel maps (gid.x, gid.y, gid.z) →
348    // (b, c, a); see shaders/elementwise.metal::transpose_last2_f16.
349    // Same threadgroup geometry as the BF16 sibling — bfloat and half
350    // share the 16-bit storage size, so the dispatch is byte-identical.
351    let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
352    let tg = MTLSize::new(
353        std::cmp::min(16, dim_b as u64),
354        std::cmp::min(16, dim_c as u64),
355        std::cmp::min(4, dim_a as u64),
356    );
357
358    encode_with_args(
359        encoder,
360        pipeline,
361        &[
362            (0, KernelArg::Buffer(input)),
363            (1, KernelArg::Buffer(output)),
364            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
365        ],
366        grid,
367        tg,
368    );
369
370    Ok(())
371}
372
373pub fn permute_021_bf16(
374    encoder: &mut CommandEncoder,
375    registry: &mut KernelRegistry,
376    device: &metal::DeviceRef,
377    input: &MlxBuffer,
378    output: &MlxBuffer,
379    dim_a: usize,
380    dim_b: usize,
381    dim_c: usize,
382) -> Result<()> {
383    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
384        return Err(MlxError::InvalidArgument(
385            "permute_021_bf16: all dimensions must be > 0".into(),
386        ));
387    }
388
389    let total_elements = dim_a * dim_b * dim_c;
390    let elem_bytes = total_elements * 2; // bf16 = 2 bytes
391    if input.byte_len() < elem_bytes {
392        return Err(MlxError::InvalidArgument(format!(
393            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
394            elem_bytes,
395            input.byte_len()
396        )));
397    }
398    if output.byte_len() < elem_bytes {
399        return Err(MlxError::InvalidArgument(format!(
400            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
401            elem_bytes,
402            output.byte_len()
403        )));
404    }
405
406    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
407
408    let gpu_params = GpuPermute021Params {
409        dim_a: dim_a as u32,
410        dim_b: dim_b as u32,
411        dim_c: dim_c as u32,
412    };
413
414    // 3D grid: (dim_c, dim_b, dim_a), each thread copies one element
415    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
416    let tg = MTLSize::new(
417        std::cmp::min(64, dim_c as u64),
418        std::cmp::min(4, dim_b as u64),
419        std::cmp::min(4, dim_a as u64),
420    );
421
422    encode_with_args(
423        encoder,
424        pipeline,
425        &[
426            (0, KernelArg::Buffer(input)),
427            (1, KernelArg::Buffer(output)),
428            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
429        ],
430        grid,
431        tg,
432    );
433
434    Ok(())
435}
436
437/// Fused permute_021 + bf16→f32 cast.  Replaces the two-pass sequence
438/// `permute_021_bf16(bf16 → bf16) ; cast_bf16_to_f32(bf16 → f32)` with a
439/// single dispatch that reads bf16 in [A, B, C] order and writes f32 in
440/// [B, A, C] order, halving the global-memory traffic on the post-FA SDPA
441/// output buffer.  Wave P4.10.
442pub fn permute_021_bf16_to_f32(
443    encoder: &mut CommandEncoder,
444    registry: &mut KernelRegistry,
445    device: &metal::DeviceRef,
446    input: &MlxBuffer,
447    output: &MlxBuffer,
448    dim_a: usize,
449    dim_b: usize,
450    dim_c: usize,
451) -> Result<()> {
452    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
453        return Err(MlxError::InvalidArgument(
454            "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
455        ));
456    }
457
458    let total_elements = dim_a * dim_b * dim_c;
459    let in_bytes = total_elements * 2; // bf16
460    let out_bytes = total_elements * 4; // f32
461    if input.byte_len() < in_bytes {
462        return Err(MlxError::InvalidArgument(format!(
463            "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
464            in_bytes, input.byte_len()
465        )));
466    }
467    if output.byte_len() < out_bytes {
468        return Err(MlxError::InvalidArgument(format!(
469            "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
470            out_bytes, output.byte_len()
471        )));
472    }
473
474    let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
475
476    let gpu_params = GpuPermute021Params {
477        dim_a: dim_a as u32,
478        dim_b: dim_b as u32,
479        dim_c: dim_c as u32,
480    };
481
482    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
483    let tg = MTLSize::new(
484        std::cmp::min(64, dim_c as u64),
485        std::cmp::min(4, dim_b as u64),
486        std::cmp::min(4, dim_a as u64),
487    );
488
489    encode_with_args(
490        encoder,
491        pipeline,
492        &[
493            (0, KernelArg::Buffer(input)),
494            (1, KernelArg::Buffer(output)),
495            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
496        ],
497        grid,
498        tg,
499    );
500
501    Ok(())
502}