Skip to main content

mlx_native/ops/
transpose.rs

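//! GPU-accelerated transpose and permute kernels.
//!
//! * [`transpose_2d`] — 2D matrix `[rows, cols]` -> `[cols, rows]` (F32 and F16).
//! * [`permute_021_f32`] / [`permute_021_bf16`] — 3D `[A, B, C]` -> `[B, A, C]`.
//! * [`permute_021_bf16_to_f32`] — 3D `[A, B, C]` -> `[B, A, C]` fused with a bf16 -> f32 cast.
//! * [`transpose_last2_bf16`] — 3D `[A, B, C]` -> `[A, C, B]` (bf16).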
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
16/// MSL-compatible params struct for 2D transpose.
17///
18/// Must match `TransposeParams` in `elementwise.metal`.
19#[repr(C)]
20#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
21struct GpuTransposeParams {
22    rows: u32,
23    cols: u32,
24}
25
26/// Encode a 2D matrix transpose: `output[col, row] = input[row, col]`.
27///
28/// # Buffer expectations
29///
30/// * `input`  — `[rows, cols]` in the given dtype
31/// * `output` — `[cols, rows]` in the given dtype (must be pre-allocated)
32///
33/// # Errors
34///
35/// Returns `MlxError::InvalidArgument` if:
36/// * `rows` or `cols` is zero
37/// * `dtype` is not F32 or F16
38/// * Buffers are too small
39#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41    encoder: &mut CommandEncoder,
42    registry: &mut KernelRegistry,
43    device: &metal::DeviceRef,
44    input: &MlxBuffer,
45    output: &MlxBuffer,
46    rows: usize,
47    cols: usize,
48    dtype: DType,
49) -> Result<()> {
50    if rows == 0 {
51        return Err(MlxError::InvalidArgument(
52            "transpose_2d: rows must be > 0".into(),
53        ));
54    }
55    if cols == 0 {
56        return Err(MlxError::InvalidArgument(
57            "transpose_2d: cols must be > 0".into(),
58        ));
59    }
60
61    let kernel_name = match dtype {
62        DType::F32 => "transpose_2d_f32",
63        DType::F16 => "transpose_2d_f16",
64        _ => {
65            return Err(MlxError::InvalidArgument(format!(
66                "transpose_2d: unsupported dtype {dtype}"
67            )));
68        }
69    };
70
71    let elem_bytes = rows * cols * dtype.size_of();
72    if input.byte_len() < elem_bytes {
73        return Err(MlxError::InvalidArgument(format!(
74            "transpose_2d: input buffer too small: need {} bytes, have {}",
75            elem_bytes,
76            input.byte_len()
77        )));
78    }
79    if output.byte_len() < elem_bytes {
80        return Err(MlxError::InvalidArgument(format!(
81            "transpose_2d: output buffer too small: need {} bytes, have {}",
82            elem_bytes,
83            output.byte_len()
84        )));
85    }
86
87    let pipeline = registry.get_pipeline(kernel_name, device)?;
88
89    let gpu_params = GpuTransposeParams {
90        rows: rows as u32,
91        cols: cols as u32,
92    };
93
94    // 2D grid: (cols, rows)
95    let grid = MTLSize::new(cols as u64, rows as u64, 1);
96    let tg = MTLSize::new(
97        std::cmp::min(16, cols as u64),
98        std::cmp::min(16, rows as u64),
99        1,
100    );
101
102    encode_with_args(
103        encoder,
104        pipeline,
105        &[
106            (0, KernelArg::Buffer(input)),
107            (1, KernelArg::Buffer(output)),
108            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
109        ],
110        grid,
111        tg,
112    );
113
114    Ok(())
115}
116
117/// MSL-compatible params struct for 3D permute [A, B, C] -> [B, A, C].
118///
119/// Must match `Permute021Params` in `elementwise.metal`.
120#[repr(C)]
121#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
122struct GpuPermute021Params {
123    dim_a: u32,
124    dim_b: u32,
125    dim_c: u32,
126}
127
128/// Encode a 3D permutation: `[A, B, C] -> [B, A, C]` (bf16).
129///
130/// This is used to convert between `[seq_len, n_heads, head_dim]` and
131/// `[n_heads, seq_len, head_dim]` layouts.
132///
133/// # Buffer expectations
134///
135/// * `input`  — `[dim_a, dim_b, dim_c]` in bf16
136/// * `output` — `[dim_b, dim_a, dim_c]` in bf16 (must be pre-allocated)
137///
138/// # Errors
139///
140/// Returns `MlxError::InvalidArgument` if any dimension is zero or buffers
141/// are too small.
142pub fn permute_021_f32(
143    encoder: &mut CommandEncoder,
144    registry: &mut KernelRegistry,
145    device: &metal::DeviceRef,
146    input: &MlxBuffer,
147    output: &MlxBuffer,
148    dim_a: usize,
149    dim_b: usize,
150    dim_c: usize,
151) -> Result<()> {
152    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
153        return Err(MlxError::InvalidArgument(
154            "permute_021_f32: all dimensions must be > 0".into(),
155        ));
156    }
157
158    let total_elements = dim_a * dim_b * dim_c;
159    let elem_bytes = total_elements * 4; // f32 = 4 bytes
160    if input.byte_len() < elem_bytes {
161        return Err(MlxError::InvalidArgument(format!(
162            "permute_021_f32: input buffer too small: need {} bytes, have {}",
163            elem_bytes,
164            input.byte_len()
165        )));
166    }
167    if output.byte_len() < elem_bytes {
168        return Err(MlxError::InvalidArgument(format!(
169            "permute_021_f32: output buffer too small: need {} bytes, have {}",
170            elem_bytes,
171            output.byte_len()
172        )));
173    }
174
175    let pipeline = registry.get_pipeline("permute_021_f32", device)?;
176
177    let gpu_params = GpuPermute021Params {
178        dim_a: dim_a as u32,
179        dim_b: dim_b as u32,
180        dim_c: dim_c as u32,
181    };
182
183    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
184    let tg = MTLSize::new(
185        std::cmp::min(64, dim_c as u64),
186        std::cmp::min(4, dim_b as u64),
187        std::cmp::min(4, dim_a as u64),
188    );
189
190    encode_with_args(
191        encoder,
192        pipeline,
193        &[
194            (0, KernelArg::Buffer(input)),
195            (1, KernelArg::Buffer(output)),
196            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
197        ],
198        grid,
199        tg,
200    );
201
202    Ok(())
203}
204
205/// Swap the last two axes of a 3D bf16 tensor: [A, B, C] -> [A, C, B].
206///
207/// Used by hf2q's non-flash-attention prefill path to transpose V from
208/// its natural `[nkv, seq, hd]` post-RoPE layout to the `[nkv, hd, seq]`
209/// layout the `scores @ V` matmul consumes (the contract dim of our
210/// tensor-mm kernel is the inner-most axis of src0).
211///
212/// One dispatch covers all A batches; each thread copies a single bf16.
213pub fn transpose_last2_bf16(
214    encoder: &mut CommandEncoder,
215    registry: &mut KernelRegistry,
216    device: &metal::DeviceRef,
217    input: &MlxBuffer,
218    output: &MlxBuffer,
219    dim_a: usize,
220    dim_b: usize,
221    dim_c: usize,
222) -> Result<()> {
223    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
224        return Err(MlxError::InvalidArgument(
225            "transpose_last2_bf16: all dimensions must be > 0".into(),
226        ));
227    }
228
229    let total_elements = dim_a * dim_b * dim_c;
230    let elem_bytes = total_elements * 2;
231    if input.byte_len() < elem_bytes {
232        return Err(MlxError::InvalidArgument(format!(
233            "transpose_last2_bf16: input buffer too small: need {} bytes, have {}",
234            elem_bytes, input.byte_len()
235        )));
236    }
237    if output.byte_len() < elem_bytes {
238        return Err(MlxError::InvalidArgument(format!(
239            "transpose_last2_bf16: output buffer too small: need {} bytes, have {}",
240            elem_bytes, output.byte_len()
241        )));
242    }
243
244    let pipeline = registry.get_pipeline("transpose_last2_bf16", device)?;
245
246    let gpu_params = GpuPermute021Params {
247        dim_a: dim_a as u32,
248        dim_b: dim_b as u32,
249        dim_c: dim_c as u32,
250    };
251
252    // Grid: (dim_b, dim_c, dim_a).  Kernel maps (gid.x, gid.y, gid.z) →
253    // (b, c, a); see shaders/elementwise.metal::transpose_last2_bf16.
254    let grid = MTLSize::new(dim_b as u64, dim_c as u64, dim_a as u64);
255    let tg = MTLSize::new(
256        std::cmp::min(16, dim_b as u64),
257        std::cmp::min(16, dim_c as u64),
258        std::cmp::min(4, dim_a as u64),
259    );
260
261    encode_with_args(
262        encoder,
263        pipeline,
264        &[
265            (0, KernelArg::Buffer(input)),
266            (1, KernelArg::Buffer(output)),
267            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
268        ],
269        grid,
270        tg,
271    );
272
273    Ok(())
274}
275
276pub fn permute_021_bf16(
277    encoder: &mut CommandEncoder,
278    registry: &mut KernelRegistry,
279    device: &metal::DeviceRef,
280    input: &MlxBuffer,
281    output: &MlxBuffer,
282    dim_a: usize,
283    dim_b: usize,
284    dim_c: usize,
285) -> Result<()> {
286    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
287        return Err(MlxError::InvalidArgument(
288            "permute_021_bf16: all dimensions must be > 0".into(),
289        ));
290    }
291
292    let total_elements = dim_a * dim_b * dim_c;
293    let elem_bytes = total_elements * 2; // bf16 = 2 bytes
294    if input.byte_len() < elem_bytes {
295        return Err(MlxError::InvalidArgument(format!(
296            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
297            elem_bytes,
298            input.byte_len()
299        )));
300    }
301    if output.byte_len() < elem_bytes {
302        return Err(MlxError::InvalidArgument(format!(
303            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
304            elem_bytes,
305            output.byte_len()
306        )));
307    }
308
309    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
310
311    let gpu_params = GpuPermute021Params {
312        dim_a: dim_a as u32,
313        dim_b: dim_b as u32,
314        dim_c: dim_c as u32,
315    };
316
317    // 3D grid: (dim_c, dim_b, dim_a), each thread copies one element
318    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
319    let tg = MTLSize::new(
320        std::cmp::min(64, dim_c as u64),
321        std::cmp::min(4, dim_b as u64),
322        std::cmp::min(4, dim_a as u64),
323    );
324
325    encode_with_args(
326        encoder,
327        pipeline,
328        &[
329            (0, KernelArg::Buffer(input)),
330            (1, KernelArg::Buffer(output)),
331            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
332        ],
333        grid,
334        tg,
335    );
336
337    Ok(())
338}
339
340/// Fused permute_021 + bf16→f32 cast.  Replaces the two-pass sequence
341/// `permute_021_bf16(bf16 → bf16) ; cast_bf16_to_f32(bf16 → f32)` with a
342/// single dispatch that reads bf16 in [A, B, C] order and writes f32 in
343/// [B, A, C] order, halving the global-memory traffic on the post-FA SDPA
344/// output buffer.  Wave P4.10.
345pub fn permute_021_bf16_to_f32(
346    encoder: &mut CommandEncoder,
347    registry: &mut KernelRegistry,
348    device: &metal::DeviceRef,
349    input: &MlxBuffer,
350    output: &MlxBuffer,
351    dim_a: usize,
352    dim_b: usize,
353    dim_c: usize,
354) -> Result<()> {
355    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
356        return Err(MlxError::InvalidArgument(
357            "permute_021_bf16_to_f32: all dimensions must be > 0".into(),
358        ));
359    }
360
361    let total_elements = dim_a * dim_b * dim_c;
362    let in_bytes = total_elements * 2; // bf16
363    let out_bytes = total_elements * 4; // f32
364    if input.byte_len() < in_bytes {
365        return Err(MlxError::InvalidArgument(format!(
366            "permute_021_bf16_to_f32: input buffer too small: need {} bytes, have {}",
367            in_bytes, input.byte_len()
368        )));
369    }
370    if output.byte_len() < out_bytes {
371        return Err(MlxError::InvalidArgument(format!(
372            "permute_021_bf16_to_f32: output buffer too small: need {} bytes, have {}",
373            out_bytes, output.byte_len()
374        )));
375    }
376
377    let pipeline = registry.get_pipeline("permute_021_bf16_to_f32", device)?;
378
379    let gpu_params = GpuPermute021Params {
380        dim_a: dim_a as u32,
381        dim_b: dim_b as u32,
382        dim_c: dim_c as u32,
383    };
384
385    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
386    let tg = MTLSize::new(
387        std::cmp::min(64, dim_c as u64),
388        std::cmp::min(4, dim_b as u64),
389        std::cmp::min(4, dim_a as u64),
390    );
391
392    encode_with_args(
393        encoder,
394        pipeline,
395        &[
396            (0, KernelArg::Buffer(input)),
397            (1, KernelArg::Buffer(output)),
398            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
399        ],
400        grid,
401        tg,
402    );
403
404    Ok(())
405}