use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};

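/// Shader-side parameters for the 2-D transpose kernels. `#[repr(C)]` and the
/// `bytemuck` derives keep the struct's layout byte-compatible with the
/// constant buffer the Metal kernel reads at buffer index 2.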
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuTransposeParams {
    rows: u32,
    cols: u32,
}

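/// Encodes a 2-D transpose of a `rows x cols` matrix from `input` into
/// `output`, which then holds the transposed `cols x rows` matrix. Only `F32`
/// and `F16` inputs are supported; both buffers must hold at least
/// `rows * cols * dtype.size_of()` bytes.
///
/// # Example (sketch)
/// ```ignore
/// // Assumes an encoder, registry, device, and appropriately sized buffers
/// // (`a`: 128 x 256 f32, `a_t`: 256 x 128 f32) set up elsewhere.
/// transpose_2d(&mut encoder, &mut registry, &device, &a, &a_t, 128, 256, DType::F32)?;
/// ```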
#[allow(clippy::too_many_arguments)]
pub fn transpose_2d(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    rows: usize,
    cols: usize,
    dtype: DType,
) -> Result<()> {
    if rows == 0 {
        return Err(MlxError::InvalidArgument(
            "transpose_2d: rows must be > 0".into(),
        ));
    }
    if cols == 0 {
        return Err(MlxError::InvalidArgument(
            "transpose_2d: cols must be > 0".into(),
        ));
    }

    let kernel_name = match dtype {
        DType::F32 => "transpose_2d_f32",
        DType::F16 => "transpose_2d_f16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "transpose_2d: unsupported dtype {dtype}"
            )));
        }
    };

    let elem_bytes = rows * cols * dtype.size_of();
    if input.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "transpose_2d: input buffer too small: need {} bytes, have {}",
            elem_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "transpose_2d: output buffer too small: need {} bytes, have {}",
            elem_bytes,
            output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let gpu_params = GpuTransposeParams {
        rows: rows as u32,
        cols: cols as u32,
    };

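    // One thread per element: x indexes columns, y indexes rows. Threadgroups
    // are 16x16 tiles, clamped so they never exceed the grid for small inputs.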
    let grid = MTLSize::new(cols as u64, rows as u64, 1);
    let tg = MTLSize::new(
        std::cmp::min(16, cols as u64),
        std::cmp::min(16, rows as u64),
        1,
    );

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}

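/// Shader-side parameters for the 0-2-1 permute kernels, mirroring the Metal
/// constant buffer layout via `#[repr(C)]`.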
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuPermute021Params {
    dim_a: u32,
    dim_b: u32,
    dim_c: u32,
}

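/// Encodes a `(a, b, c) -> (a, c, b)` permutation of an `f32` tensor: the two
/// innermost axes are swapped while the outermost axis stays in place. Both
/// buffers must hold at least `dim_a * dim_b * dim_c * 4` bytes.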
#[allow(clippy::too_many_arguments)]
pub fn permute_021_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    dim_a: usize,
    dim_b: usize,
    dim_c: usize,
) -> Result<()> {
    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
        return Err(MlxError::InvalidArgument(
            "permute_021_f32: all dimensions must be > 0".into(),
        ));
    }

    let total_elements = dim_a * dim_b * dim_c;
    let elem_bytes = total_elements * 4; // f32 is 4 bytes per element
    if input.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_f32: input buffer too small: need {} bytes, have {}",
            elem_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_f32: output buffer too small: need {} bytes, have {}",
            elem_bytes,
            output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("permute_021_f32", device)?;

    let gpu_params = GpuPermute021Params {
        dim_a: dim_a as u32,
        dim_b: dim_b as u32,
        dim_c: dim_c as u32,
    };

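    // One thread per element, with x spanning the innermost axis. Threadgroups
    // are 64x4x4 (1024 threads), clamped so they never exceed the grid.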
    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
    let tg = MTLSize::new(
        std::cmp::min(64, dim_c as u64),
        std::cmp::min(4, dim_b as u64),
        std::cmp::min(4, dim_a as u64),
    );

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}

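/// Encodes the same `(a, b, c) -> (a, c, b)` permutation as
/// [`permute_021_f32`], but for `bf16` tensors. Both buffers must hold at
/// least `dim_a * dim_b * dim_c * 2` bytes.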
#[allow(clippy::too_many_arguments)]
pub fn permute_021_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    dim_a: usize,
    dim_b: usize,
    dim_c: usize,
) -> Result<()> {
    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
        return Err(MlxError::InvalidArgument(
            "permute_021_bf16: all dimensions must be > 0".into(),
        ));
    }

    let total_elements = dim_a * dim_b * dim_c;
    let elem_bytes = total_elements * 2; // bf16 is 2 bytes per element
    if input.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
            elem_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
            elem_bytes,
            output.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;

    let gpu_params = GpuPermute021Params {
        dim_a: dim_a as u32,
        dim_b: dim_b as u32,
        dim_c: dim_c as u32,
    };

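    // One thread per element, with x spanning the innermost axis. Threadgroups
    // are 64x4x4 (1024 threads), clamped so they never exceed the grid.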
    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
    let tg = MTLSize::new(
        std::cmp::min(64, dim_c as u64),
        std::cmp::min(4, dim_b as u64),
        std::cmp::min(4, dim_a as u64),
    );

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}