use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Parameter block for the `transpose_2d_*` Metal kernels.
///
/// `#[repr(C)]` so the field layout matches the corresponding struct on the
/// shader side; the whole struct is uploaded with `set_bytes`-style argument
/// passing via `KernelArg::Bytes(as_bytes(..))` in `transpose_2d`.
/// `bytemuck::Pod`/`Zeroable` make the byte-level view in `as_bytes` sound.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuTransposeParams {
    // Matrix dimensions of the source (input is rows x cols, row-major
    // presumably — the kernel defines the exact layout; confirm shader side).
    rows: u32,
    cols: u32,
}
/// Encodes a GPU 2-D matrix transpose of a `rows` x `cols` matrix from
/// `input` into `output`, dispatching one thread per element.
///
/// Supported dtypes are `F32` and `F16` (each selects a dedicated kernel).
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when:
/// - `rows` or `cols` is 0, or does not fit in `u32` (the kernel param type),
/// - `dtype` is unsupported,
/// - `rows * cols * dtype.size_of()` overflows `usize`,
/// - either buffer is smaller than the matrix it must hold.
#[allow(clippy::too_many_arguments)]
pub fn transpose_2d(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    rows: usize,
    cols: usize,
    dtype: DType,
) -> Result<()> {
    if rows == 0 {
        return Err(MlxError::InvalidArgument(
            "transpose_2d: rows must be > 0".into(),
        ));
    }
    if cols == 0 {
        return Err(MlxError::InvalidArgument(
            "transpose_2d: cols must be > 0".into(),
        ));
    }
    // The kernel param struct carries u32 dims; reject anything that a plain
    // `as u32` cast would silently truncate.
    let rows_u32 = u32::try_from(rows).map_err(|_| {
        MlxError::InvalidArgument(format!("transpose_2d: rows {rows} exceeds u32::MAX"))
    })?;
    let cols_u32 = u32::try_from(cols).map_err(|_| {
        MlxError::InvalidArgument(format!("transpose_2d: cols {cols} exceeds u32::MAX"))
    })?;
    let kernel_name = match dtype {
        DType::F32 => "transpose_2d_f32",
        DType::F16 => "transpose_2d_f16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "transpose_2d: unsupported dtype {dtype}"
            )));
        }
    };
    // Total byte size of the matrix; checked so an overflow can't wrap around
    // and defeat the buffer-size validation below.
    let total_bytes = rows
        .checked_mul(cols)
        .and_then(|n| n.checked_mul(dtype.size_of()))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "transpose_2d: rows * cols * element size overflows usize".into(),
            )
        })?;
    if input.byte_len() < total_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "transpose_2d: input buffer too small: need {} bytes, have {}",
            total_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < total_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "transpose_2d: output buffer too small: need {} bytes, have {}",
            total_bytes,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline(kernel_name, device)?;
    let gpu_params = GpuTransposeParams {
        rows: rows_u32,
        cols: cols_u32,
    };
    // One thread per element: x walks columns, y walks rows.
    let grid = MTLSize::new(cols as u64, rows as u64, 1);
    // 16x16 threadgroup (256 threads), clamped down for small matrices.
    let tg = MTLSize::new(
        std::cmp::min(16, cols as u64),
        std::cmp::min(16, rows as u64),
        1,
    );
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );
    Ok(())
}
/// Parameter block for the `permute_021_bf16` Metal kernel.
///
/// `#[repr(C)]` so the layout matches the shader-side struct; uploaded to the
/// GPU via `KernelArg::Bytes(as_bytes(..))` in `permute_021_bf16`.
/// `bytemuck::Pod`/`Zeroable` make the byte-level reinterpretation sound.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuPermute021Params {
    dim_a: u32, // outermost dimension (dispatched on grid z)
    dim_b: u32, // middle dimension (dispatched on grid y)
    dim_c: u32, // innermost dimension (dispatched on grid x)
}
/// Encodes a GPU (0, 2, 1) axis permutation of a BF16 tensor shaped
/// `(dim_a, dim_b, dim_c)` from `input` into `output`, one thread per
/// element (grid x/y/z = dim_c/dim_b/dim_a).
///
/// # Errors
/// Returns `MlxError::InvalidArgument` when:
/// - any dimension is 0, or does not fit in `u32` (the kernel param type),
/// - the total byte size overflows `usize`,
/// - either buffer is smaller than the tensor it must hold.
pub fn permute_021_bf16(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    dim_a: usize,
    dim_b: usize,
    dim_c: usize,
) -> Result<()> {
    if dim_a == 0 || dim_b == 0 || dim_c == 0 {
        return Err(MlxError::InvalidArgument(
            "permute_021_bf16: all dimensions must be > 0".into(),
        ));
    }
    // The kernel param struct carries u32 dims; reject anything that a plain
    // `as u32` cast would silently truncate.
    let dim_a_u32 = u32::try_from(dim_a).map_err(|_| {
        MlxError::InvalidArgument(format!("permute_021_bf16: dim_a {dim_a} exceeds u32::MAX"))
    })?;
    let dim_b_u32 = u32::try_from(dim_b).map_err(|_| {
        MlxError::InvalidArgument(format!("permute_021_bf16: dim_b {dim_b} exceeds u32::MAX"))
    })?;
    let dim_c_u32 = u32::try_from(dim_c).map_err(|_| {
        MlxError::InvalidArgument(format!("permute_021_bf16: dim_c {dim_c} exceeds u32::MAX"))
    })?;
    // BF16 elements are 2 bytes wide. Checked multiply so an overflow can't
    // wrap around and defeat the buffer-size validation below.
    const BF16_SIZE: usize = 2;
    let total_bytes = dim_a
        .checked_mul(dim_b)
        .and_then(|n| n.checked_mul(dim_c))
        .and_then(|n| n.checked_mul(BF16_SIZE))
        .ok_or_else(|| {
            MlxError::InvalidArgument(
                "permute_021_bf16: total byte size overflows usize".into(),
            )
        })?;
    if input.byte_len() < total_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_bf16: input buffer too small: need {} bytes, have {}",
            total_bytes,
            input.byte_len()
        )));
    }
    if output.byte_len() < total_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "permute_021_bf16: output buffer too small: need {} bytes, have {}",
            total_bytes,
            output.byte_len()
        )));
    }
    let pipeline = registry.get_pipeline("permute_021_bf16", device)?;
    let gpu_params = GpuPermute021Params {
        dim_a: dim_a_u32,
        dim_b: dim_b_u32,
        dim_c: dim_c_u32,
    };
    // One thread per element; innermost dim on x for coalesced access.
    let grid = MTLSize::new(dim_c as u64, dim_b as u64, dim_a as u64);
    // 64x4x4 threadgroup (1024 threads), clamped down for small tensors.
    let tg = MTLSize::new(
        std::cmp::min(64, dim_c as u64),
        std::cmp::min(4, dim_b as u64),
        std::cmp::min(4, dim_a as u64),
    );
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );
    Ok(())
}