Skip to main content

mlx_native/ops/
transpose.rs

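//! GPU-accelerated transposes and axis permutations.
//!
//! * [`transpose_2d`] transposes a 2D matrix `[rows, cols]` to `[cols, rows]`
//!   (F32 and F16).
//! * The `permute_021_*` functions permute a 3D tensor `[A, B, C]` to
//!   `[B, A, C]` (f32, bf16, and a fused bf16 → f32 variant).
//! * The `transpose_last2_*` functions swap the last two axes of a 3D tensor,
//!   `[A, B, C]` to `[A, C, B]` (bf16 and f16).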
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
16/// MSL-compatible params struct for 2D transpose.
17///
18/// Must match `TransposeParams` in `elementwise.metal`.
19#[repr(C)]
20#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
21struct GpuTransposeParams {
22    rows: u32,
23    cols: u32,
24}
25
26/// Encode a 2D matrix transpose: `output[col, row] = input[row, col]`.
27///
28/// # Buffer expectations
29///
30/// * `input`  — `[rows, cols]` in the given dtype
31/// * `output` — `[cols, rows]` in the given dtype (must be pre-allocated)
32///
33/// # Errors
34///
35/// Returns `MlxError::InvalidArgument` if:
36/// * `rows` or `cols` is zero
37/// * `dtype` is not F32 or F16
38/// * Buffers are too small
39#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41    encoder: &mut CommandEncoder,
42    registry: &mut KernelRegistry,
43    device: &metal::DeviceRef,
44    input: &MlxBuffer,
45    output: &MlxBuffer,
46    rows: usize,
47    cols: usize,
48    dtype: DType,
49) -> Result<()> {
50    if rows == 0 {
51        return Err(MlxError::InvalidArgument(
52            "transpose_2d: rows must be > 0".into(),
53        ));
54    }
55    if cols == 0 {
56        return Err(MlxError::InvalidArgument(
57            "transpose_2d: cols must be > 0".into(),
58        ));
59    }
60
61    let kernel_name = match dtype {
62        DType::F32 => "transpose_2d_f32",
63        DType::F16 => "transpose_2d_f16",
64        _ => {
65            return Err(MlxError::InvalidArgument(format!(
66                "transpose_2d: unsupported dtype {dtype}"
67            )));
68        }
69    };
70
71    let elem_bytes = rows * cols * dtype.size_of();
72    if input.byte_len() < elem_bytes {
73        return Err(MlxError::InvalidArgument(format!(
74            "transpose_2d: input buffer too small: need {} bytes, have {}",
75            elem_bytes,
76            input.byte_len()
77        )));
78    }
79    if output.byte_len() < elem_bytes {
80        return Err(MlxError::InvalidArgument(format!(
81            "transpose_2d: output buffer too small: need {} bytes, have {}",
82            elem_bytes,
83            output.byte_len()
84        )));
85    }
86
87    let pipeline = registry.get_pipeline(kernel_name, device)?;
88
89    let gpu_params = GpuTransposeParams {
90        rows: rows as u32,
91        cols: cols as u32,
92    };
93
94    // 2D grid: (cols, rows)
95    let grid = MTLSize::new(cols as u64, rows as u64, 1);
96    let tg = MTLSize::new(
97        std::cmp::min(16, cols as u64),
98        std::cmp::min(16, rows as u64),
99        1,
100    );
101
102    encode_with_args(
103        encoder,
104        pipeline,
105        &[
106            (0, KernelArg::Buffer(input)),
107            (1, KernelArg::Buffer(output)),
108            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
109        ],
110        grid,
111        tg,
112    );
113
114    Ok(())
115}
116
117/// MSL-compatible params struct for 3D permute [A, B, C] -> [B, A, C].
118///
119/// Must match `Permute021Params` in `elementwise.metal`.
120#[repr(C)]
121#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
122struct GpuPermute021Params {
123    dim_a: u32,
124    dim_b: u32,
125    dim_c: u32,
126}
127
128/// Encode a 3D permutation: `[A, B, C] -> [B, A, C]` (bf16).
129///
130/// This is used to convert between `[seq_len, n_heads, head_dim]` and
131/// `[n_heads, seq_len, head_dim]` layouts.
132///
133/// # Buffer expectations
134///
135/// * `input`  — `[dim_a, dim_b, dim_c]` in bf16
136/// * `output` — `[dim_b, dim_a, dim_c]` in bf16 (must be pre-allocated)
137///
138/// # Errors
139///
140/// Returns `MlxError::InvalidArgument` if any dimension is zero or buffers
141/// are too small.
142pub fn permute_021_f32(
143    encoder: &mut CommandEncoder,
144    registry: &mut KernelRegistry,
145    device: &metal::DeviceRef,
146    input: &MlxBuffer,
147    output: &MlxBuffer,
148    dim_a: usize,
149    dim_b: usize,
150    dim_c: usize,
151) -> Result<()> {
152    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
153        return Err(MlxError::InvalidArgument(
154            "permute_021_f32: all dimensions must be > 0".into(),
155        ));
156    }
157
158    let total_elements = dim_a * dim_b * dim_c;
159    let elem_bytes = total_elements * 4; // f32 = 4 bytes
160    if input.byte_len() < elem_bytes {
161        return Err(MlxError::InvalidArgument(format!(
162            "permute_021_f32: input buffer too small: need {} bytes, have {}",
163            elem_bytes,
164            input.byte_len()
165        )));
166    }
167    if output.byte_len() < elem_bytes {
168        return Err(MlxError::InvalidArgument(format!(
169            "permute_021_f32: output buffer too small: need {} bytes, have {}",
170            elem_bytes,
171            output.byte_len()
172        )));
173    }
174
175    let pipeline = registry.get_pipeline("permute_021_f32", device)?;
176
177    let gpu_params = GpuPermute021Params {
178        dim_a: dim_a as u32,
179        dim_b: dim_b as u32,
180        dim_c: dim_c as u32,
181    };
182
183    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
184    let tg = MTLSize::new(
185        std::cmp::min(64, dim_c as u64),
186        std::cmp::min(4, dim_b as u64),
187        std::cmp::min(4, dim_a as u64),
188    );
189
190    encode_with_args(
191        encoder,
192        pipeline,
193        &[
194            (0, KernelArg::Buffer(input)),
195            (1, KernelArg::Buffer(output)),
196            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
197        ],
198        grid,
199        tg,
200    );
201
202    Ok(())
203}
204
205/// Swap the last two axes of a 3D bf16 tensor: [A, B, C] -> [A, C, B].
206///
207/// Used by hf2q's non-flash-attention prefill path to transpose V from
208/// its natural `[nkv, seq, hd]` post-RoPE layout to the `[nkv, hd, seq]`
209/// layout the `scores @ V` matmul consumes (the contract dim of our
210/// tensor-mm kernel is the inner-most axis of src0).
211///
212/// One dispatch covers all A batches; each thread copies a single bf16.
213pub fn transpose_last2_bf16(
214    encoder: &mut CommandEncoder,
215    registry: &mut KernelRegistry,
216    device: &metal::DeviceRef,
217    input: &MlxBuffer,
218    output: &MlxBuffer,
219    dim_a: usize,
220    dim_b: usize,
221    dim_c: usize,
222) -> Result<()> {
223    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
224        return Err(MlxError::InvalidArgument(
225            "transpose_last2_bf16: all dimensions must be > 0".into(),
226        ));
227    }
228
229    let total_elements = dim_a * dim_b * dim_c;
230    let elem_bytes = total_elements * 2;
231    if input.byte_len() < elem_bytes {
232        return Err(MlxError::InvalidArgument(format!(
233            "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
234            elem_bytes, input.byte_len()
235        )));
236    }
237    if output.byte_len() < elem_bytes {
238        return Err(MlxError::InvalidArgument(format!(
239            "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
240            elem_bytes, output.byte_len()
241        )));
242    }
243
244    let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
245
246    let gpu_params = GpuPermute021Params {
247        dim_a: dim_a as u32,
248        dim_b: dim_b as u32,
249        dim_c: dim_c as u32,
250    };
251
252    // Grid: (dim_b, dim_c, dim_a).  Kernel maps (gid.x, gid.y, gid.z) →
253    // (b, c, a); see shaders/elementwise.metal::transpose_last2_bf16.
254    let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
255    let tg = MTLSize::new(
256        std::cmp::min(16, dim_b as u64),
257        std::cmp::min(16, dim_c as u64),
258        std::cmp::min(4, dim_a as u64),
259    );
260
261    encode_with_args(
262        encoder,
263        pipeline,
264        &[
265            (0, KernelArg::Buffer(input)),
266            (1, KernelArg::Buffer(output)),
267            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
268        ],
269        grid,
270        tg,
271    );
272
273    Ok(())
274}
275
276/// Swap the last two axes of a 3D f16 tensor: [A, B, C] -> [A, C, B].
277///
278/// F16 sibling of [`transpose_last2_bf16`].  Added for ADR-005 Phase 2c
279/// iter-129 (gemma4v ViT precision parity): peer's `flash_attn_ext`
280/// stages V at F16 (10-bit mantissa) and runs `simdgroup_half8x8` MMA
281/// on V@scores.  Pre-iter-129 hf2q cast V F32→BF16 (7-bit mantissa)
282/// before the transpose, capturing the dominant residual per-block
283/// cascade after iter-128's weight-matmul F16 fix.
284///
285/// Used by `vit_attention_gpu` to materialise V_t (V transposed over
286/// seq↔hd) so the F16-staged `scores @ V` matmul
287/// (`dense_matmul_f16_f32_tensor`) can consume it at the tile geometry
288/// the kernel expects.
289///
290/// One dispatch covers all A batches; each thread copies a single half.
291pub fn transpose_last2_f16(
292    encoder: &mut CommandEncoder,
293    registry: &mut KernelRegistry,
294    device: &metal::DeviceRef,
295    input: &MlxBuffer,
296    output: &MlxBuffer,
297    dim_a: usize,
298    dim_b: usize,
299    dim_c: usize,
300) -> Result<()> {
301    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
302        return Err(MlxError::InvalidArgument(
303            "transpose_last2_f16: all dimensions must be > 0".into(),
304        ));
305    }
306
307    let total_elements = dim_a * dim_b * dim_c;
308    let elem_bytes = total_elements * 2; // f16 = 2 bytes
309    if input.byte_len() < elem_bytes {
310        return Err(MlxError::InvalidArgument(format!(
311            "transpose_last2_f16: input buffer too small: need {} bytes, have {}",
312            elem_bytes, input.byte_len()
313        )));
314    }
315    if output.byte_len() < elem_bytes {
316        return Err(MlxError::InvalidArgument(format!(
317            "transpose_last2_f16: output buffer too small: need {} bytes, have {}",
318            elem_bytes, output.byte_len()
319        )));
320    }
321
322    let pipeline = registry.get_pipeline("transpose_last2_f16", device)?;
323
324    let gpu_params = GpuPermute021Params {
325        dim_a: dim_a as u32,
326        dim_b: dim_b as u32,
327        dim_c: dim_c as u32,
328    };
329
330    // Grid: (dim_b, dim_c, dim_a).  Kernel maps (gid.x, gid.y, gid.z) →
331    // (b, c, a); see shaders/elementwise.metal::transpose_last2_f16.
332    // Same threadgroup geometry as the BF16 sibling — bfloat and half
333    // share the 16-bit storage size, so the dispatch is byte-identical.
334    let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
335    let tg = MTLSize::new(
336        std::cmp::min(16, dim_b as u64),
337        std::cmp::min(16, dim_c as u64),
338        std::cmp::min(4, dim_a as u64),
339    );
340
341    encode_with_args(
342        encoder,
343        pipeline,
344        &[
345            (0, KernelArg::Buffer(input)),
346            (1, KernelArg::Buffer(output)),
347            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
348        ],
349        grid,
350        tg,
351    );
352
353    Ok(())
354}
355
356pub fn permute_021_bf16(
357    encoder: &mut CommandEncoder,
358    registry: &mut KernelRegistry,
359    device: &metal::DeviceRef,
360    input: &MlxBuffer,
361    output: &MlxBuffer,
362    dim_a: usize,
363    dim_b: usize,
364    dim_c: usize,
365) -> Result<()> {
366    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
367        return Err(MlxError::InvalidArgument(
368            "permute_021_bf16: all dimensions must be > 0".into(),
369        ));
370    }
371
372    let total_elements = dim_a * dim_b * dim_c;
373    let elem_bytes = total_elements * 2; // bf16 = 2 bytes
374    if input.byte_len() < elem_bytes {
375        return Err(MlxError::InvalidArgument(format!(
376            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
377            elem_bytes,
378            input.byte_len()
379        )));
380    }
381    if output.byte_len() < elem_bytes {
382        return Err(MlxError::InvalidArgument(format!(
383            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
384            elem_bytes,
385            output.byte_len()
386        )));
387    }
388
389    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
390
391    let gpu_params = GpuPermute021Params {
392        dim_a: dim_a as u32,
393        dim_b: dim_b as u32,
394        dim_c: dim_c as u32,
395    };
396
397    // 3D grid: (dim_c, dim_b, dim_a), each thread copies one element
398    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
399    let tg = MTLSize::new(
400        std::cmp::min(64, dim_c as u64),
401        std::cmp::min(4, dim_b as u64),
402        std::cmp::min(4, dim_a as u64),
403    );
404
405    encode_with_args(
406        encoder,
407        pipeline,
408        &[
409            (0, KernelArg::Buffer(input)),
410            (1, KernelArg::Buffer(output)),
411            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
412        ],
413        grid,
414        tg,
415    );
416
417    Ok(())
418}
419
420/// Fused permute_021 + bf16→f32 cast.  Replaces the two-pass sequence
421/// `permute_021_bf16(bf16 → bf16) ; cast_bf16_to_f32(bf16 → f32)` with a
422/// single dispatch that reads bf16 in [A, B, C] order and writes f32 in
423/// [B, A, C] order, halving the global-memory traffic on the post-FA SDPA
424/// output buffer.  Wave P4.10.
425pub fn permute_021_bf16_to_f32(
426    encoder: &mut CommandEncoder,
427    registry: &mut KernelRegistry,
428    device: &metal::DeviceRef,
429    input: &MlxBuffer,
430    output: &MlxBuffer,
431    dim_a: usize,
432    dim_b: usize,
433    dim_c: usize,
434) -> Result<()> {
435    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
436        return Err(MlxError::InvalidArgument(
437            "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
438        ));
439    }
440
441    let total_elements = dim_a * dim_b * dim_c;
442    let in_bytes = total_elements * 2; // bf16
443    let out_bytes = total_elements * 4; // f32
444    if input.byte_len() < in_bytes {
445        return Err(MlxError::InvalidArgument(format!(
446            "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
447            in_bytes, input.byte_len()
448        )));
449    }
450    if output.byte_len() < out_bytes {
451        return Err(MlxError::InvalidArgument(format!(
452            "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
453            out_bytes, output.byte_len()
454        )));
455    }
456
457    let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
458
459    let gpu_params = GpuPermute021Params {
460        dim_a: dim_a as u32,
461        dim_b: dim_b as u32,
462        dim_c: dim_c as u32,
463    };
464
465    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
466    let tg = MTLSize::new(
467        std::cmp::min(64, dim_c as u64),
468        std::cmp::min(4, dim_b as u64),
469        std::cmp::min(4, dim_a as u64),
470    );
471
472    encode_with_args(
473        encoder,
474        pipeline,
475        &[
476            (0, KernelArg::Buffer(input)),
477            (1, KernelArg::Buffer(output)),
478            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
479        ],
480        grid,
481        tg,
482    );
483
484    Ok(())
485}