use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Text of `../shaders/copy.metal`, embedded into the binary at compile time
/// via `include_str!`.
pub static COPY_SHADER_SOURCE: &str = include_str!("../shaders/copy.metal");
/// Registers the copy shader source with `registry` under the kernel name
/// `strided_copy_f32`.
pub fn register(registry: &mut KernelRegistry) {
    // The same name is used later to look the pipeline up at dispatch time.
    let kernel_name = "strided_copy_f32";
    registry.register_source(kernel_name, COPY_SHADER_SOURCE);
}
/// Parameter block handed to the `strided_copy_f32` kernel as raw bytes
/// (bound at argument index 2 by `dispatch_strided_copy_f32`).
///
/// `#[repr(C)]` pins the field order and the `Pod` derive allows the struct to
/// be viewed as a byte slice.
// NOTE(review): the layout presumably has to match a struct declared in
// `copy.metal` — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuStridedCopyParams {
// Copied verbatim from `StridedCopyParams::rows`.
rows: u32,
// Copied verbatim from `StridedCopyParams::cols`.
cols: u32,
// Copied verbatim from `StridedCopyParams::stride_row`.
stride_row: u32,
// Copied verbatim from `StridedCopyParams::stride_col`.
stride_col: u32,
}
/// Caller-facing description of a strided `f32` copy, consumed by
/// `dispatch_strided_copy_f32`.
pub struct StridedCopyParams {
/// Number of rows to copy; must be non-zero.
pub rows: u32,
/// Number of columns to copy; must be non-zero.
pub cols: u32,
/// Source stride between consecutive rows. The host-side bounds check
/// treats this as a count of `f32` elements, not bytes.
pub stride_row: u32,
/// Source stride between consecutive columns. The host-side bounds check
/// treats this as a count of `f32` elements, not bytes.
pub stride_col: u32,
}
pub fn dispatch_strided_copy_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
src: &MlxBuffer,
dst: &MlxBuffer,
params: &StridedCopyParams,
) -> Result<()> {
if params.rows == 0 || params.cols == 0 {
return Err(MlxError::InvalidArgument(
"strided_copy_f32: rows and cols must be > 0".into(),
));
}
let dst_bytes = params.rows as usize * params.cols as usize * 4;
if dst.byte_len() < dst_bytes {
return Err(MlxError::InvalidArgument(format!(
"strided_copy_f32: dst buffer too small: need {} bytes, have {}",
dst_bytes,
dst.byte_len()
)));
}
let max_src_idx = (params.rows as usize - 1) * params.stride_row as usize
+ (params.cols as usize - 1) * params.stride_col as usize;
let src_min_bytes = (max_src_idx + 1) * 4;
if src.byte_len() < src_min_bytes {
return Err(MlxError::InvalidArgument(format!(
"strided_copy_f32: src buffer too small: need at least {} bytes for stride access, have {}",
src_min_bytes,
src.byte_len()
)));
}
let pipeline = registry.get_pipeline("strided_copy_f32", device)?;
let gpu_params = GpuStridedCopyParams {
rows: params.rows,
cols: params.cols,
stride_row: params.stride_row,
stride_col: params.stride_col,
};
let grid = MTLSize::new(params.cols as u64, params.rows as u64, 1);
let tg = MTLSize::new(std::cmp::min(256, params.cols as u64), 1, 1);
encode_with_args(
encoder,
pipeline,
&[
(0, KernelArg::Buffer(src)),
(1, KernelArg::Buffer(dst)),
(2, KernelArg::Bytes(as_bytes(&gpu_params))),
],
grid,
tg,
);
Ok(())
}