Skip to main content

mlx_native/ops/
transpose.rs

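//! GPU-accelerated 2D matrix transpose and 3D axis permutation.
//!
//! * `transpose_2d` transposes a 2D matrix `[rows, cols]` to `[cols, rows]`
//!   (F32 and F16 dtypes).
//! * `permute_021_f32` / `permute_021_bf16` permute a 3D tensor
//!   `[A, B, C]` to `[B, A, C]`.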
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::CommandEncoder;
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
16/// MSL-compatible params struct for 2D transpose.
17///
18/// Must match `TransposeParams` in `elementwise.metal`.
19#[repr(C)]
20#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
21struct GpuTransposeParams {
22    rows: u32,
23    cols: u32,
24}
25
26/// Encode a 2D matrix transpose: `output[col, row] = input[row, col]`.
27///
28/// # Buffer expectations
29///
30/// * `input`  — `[rows, cols]` in the given dtype
31/// * `output` — `[cols, rows]` in the given dtype (must be pre-allocated)
32///
33/// # Errors
34///
35/// Returns `MlxError::InvalidArgument` if:
36/// * `rows` or `cols` is zero
37/// * `dtype` is not F32 or F16
38/// * Buffers are too small
39#[allow(clippy::too_many_arguments)]
40pub fn transpose_2d(
41    encoder: &mut CommandEncoder,
42    registry: &mut KernelRegistry,
43    device: &metal::DeviceRef,
44    input: &MlxBuffer,
45    output: &MlxBuffer,
46    rows: usize,
47    cols: usize,
48    dtype: DType,
49) -> Result<()> {
50    if rows == 0 {
51        return Err(MlxError::InvalidArgument(
52            "transpose_2d: rows must be > 0".into(),
53        ));
54    }
55    if cols == 0 {
56        return Err(MlxError::InvalidArgument(
57            "transpose_2d: cols must be > 0".into(),
58        ));
59    }
60
61    let kernel_name = match dtype {
62        DType::F32 => "transpose_2d_f32",
63        DType::F16 => "transpose_2d_f16",
64        _ => {
65            return Err(MlxError::InvalidArgument(format!(
66                "transpose_2d: unsupported dtype {dtype}"
67            )));
68        }
69    };
70
71    let elem_bytes = rows * cols * dtype.size_of();
72    if input.byte_len() < elem_bytes {
73        return Err(MlxError::InvalidArgument(format!(
74            "transpose_2d: input buffer too small: need {} bytes, have {}",
75            elem_bytes,
76            input.byte_len()
77        )));
78    }
79    if output.byte_len() < elem_bytes {
80        return Err(MlxError::InvalidArgument(format!(
81            "transpose_2d: output buffer too small: need {} bytes, have {}",
82            elem_bytes,
83            output.byte_len()
84        )));
85    }
86
87    let pipeline = registry.get_pipeline(kernel_name, device)?;
88
89    let gpu_params = GpuTransposeParams {
90        rows: rows as u32,
91        cols: cols as u32,
92    };
93
94    // 2D grid: (cols, rows)
95    let grid = MTLSize::new(cols as u64, rows as u64, 1);
96    let tg = MTLSize::new(
97        std::cmp::min(16, cols as u64),
98        std::cmp::min(16, rows as u64),
99        1,
100    );
101
102    encode_with_args(
103        encoder,
104        pipeline,
105        &[
106            (0, KernelArg::Buffer(input)),
107            (1, KernelArg::Buffer(output)),
108            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
109        ],
110        grid,
111        tg,
112    );
113
114    Ok(())
115}
116
117/// MSL-compatible params struct for 3D permute [A, B, C] -> [B, A, C].
118///
119/// Must match `Permute021Params` in `elementwise.metal`.
120#[repr(C)]
121#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
122struct GpuPermute021Params {
123    dim_a: u32,
124    dim_b: u32,
125    dim_c: u32,
126}
127
128/// Encode a 3D permutation: `[A, B, C] -> [B, A, C]` (bf16).
129///
130/// This is used to convert between `[seq_len, n_heads, head_dim]` and
131/// `[n_heads, seq_len, head_dim]` layouts.
132///
133/// # Buffer expectations
134///
135/// * `input`  — `[dim_a, dim_b, dim_c]` in bf16
136/// * `output` — `[dim_b, dim_a, dim_c]` in bf16 (must be pre-allocated)
137///
138/// # Errors
139///
140/// Returns `MlxError::InvalidArgument` if any dimension is zero or buffers
141/// are too small.
142pub fn permute_021_f32(
143    encoder: &mut CommandEncoder,
144    registry: &mut KernelRegistry,
145    device: &metal::DeviceRef,
146    input: &MlxBuffer,
147    output: &MlxBuffer,
148    dim_a: usize,
149    dim_b: usize,
150    dim_c: usize,
151) -> Result<()> {
152    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
153        return Err(MlxError::InvalidArgument(
154            "permute_021_f32: all dimensions must be > 0".into(),
155        ));
156    }
157
158    let total_elements = dim_a * dim_b * dim_c;
159    let elem_bytes = total_elements * 4; // f32 = 4 bytes
160    if input.byte_len() < elem_bytes {
161        return Err(MlxError::InvalidArgument(format!(
162            "permute_021_f32: input buffer too small: need {} bytes, have {}",
163            elem_bytes,
164            input.byte_len()
165        )));
166    }
167    if output.byte_len() < elem_bytes {
168        return Err(MlxError::InvalidArgument(format!(
169            "permute_021_f32: output buffer too small: need {} bytes, have {}",
170            elem_bytes,
171            output.byte_len()
172        )));
173    }
174
175    let pipeline = registry.get_pipeline("permute_021_f32", device)?;
176
177    let gpu_params = GpuPermute021Params {
178        dim_a: dim_a as u32,
179        dim_b: dim_b as u32,
180        dim_c: dim_c as u32,
181    };
182
183    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
184    let tg = MTLSize::new(
185        std::cmp::min(64, dim_c as u64),
186        std::cmp::min(4, dim_b as u64),
187        std::cmp::min(4, dim_a as u64),
188    );
189
190    encode_with_args(
191        encoder,
192        pipeline,
193        &[
194            (0, KernelArg::Buffer(input)),
195            (1, KernelArg::Buffer(output)),
196            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
197        ],
198        grid,
199        tg,
200    );
201
202    Ok(())
203}
204
205pub fn permute_021_bf16(
206    encoder: &mut CommandEncoder,
207    registry: &mut KernelRegistry,
208    device: &metal::DeviceRef,
209    input: &MlxBuffer,
210    output: &MlxBuffer,
211    dim_a: usize,
212    dim_b: usize,
213    dim_c: usize,
214) -> Result<()> {
215    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
216        return Err(MlxError::InvalidArgument(
217            "permute_021_bf16: all dimensions must be > 0".into(),
218        ));
219    }
220
221    let total_elements = dim_a * dim_b * dim_c;
222    let elem_bytes = total_elements * 2; // bf16 = 2 bytes
223    if input.byte_len() < elem_bytes {
224        return Err(MlxError::InvalidArgument(format!(
225            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
226            elem_bytes,
227            input.byte_len()
228        )));
229    }
230    if output.byte_len() < elem_bytes {
231        return Err(MlxError::InvalidArgument(format!(
232            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
233            elem_bytes,
234            output.byte_len()
235        )));
236    }
237
238    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
239
240    let gpu_params = GpuPermute021Params {
241        dim_a: dim_a as u32,
242        dim_b: dim_b as u32,
243        dim_c: dim_c as u32,
244    };
245
246    // 3D grid: (dim_c, dim_b, dim_a), each thread copies one element
247    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
248    let tg = MTLSize::new(
249        std::cmp::min(64, dim_c as u64),
250        std::cmp::min(4, dim_b as u64),
251        std::cmp::min(4, dim_a as u64),
252    );
253
254    encode_with_args(
255        encoder,
256        pipeline,
257        &[
258            (0, KernelArg::Buffer(input)),
259            (1, KernelArg::Buffer(output)),
260            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
261        ],
262        grid,
263        tg,
264    );
265
266    Ok(())
267}