1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
/// Metal shader source embedded into the binary at compile time from
/// `../shaders/copy.metal` (path relative to this source file).
///
/// Registered under the kernel name `strided_copy_f32` by [`register`].
pub static COPY_SHADER_SOURCE: &str = include_str!("../shaders/copy.metal");
19
20pub fn register(registry: &mut KernelRegistry) {
22 registry.register_source("strided_copy_f32", COPY_SHADER_SOURCE);
23}
24
/// Kernel-side parameter block for `strided_copy_f32`.
///
/// The dispatcher passes it to the shader as raw bytes at argument index 2.
/// `#[repr(C)]` fixes the field order and layout. The `bytemuck::Pod` and
/// `Zeroable` derives allow the struct to be viewed as a plain byte slice.
///
/// NOTE(review): this layout presumably mirrors a params struct declared in
/// `shaders/copy.metal`. Keep the field order and types in sync with the
/// shader (confirm there).
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuStridedCopyParams {
    /// Number of rows to copy.
    rows: u32,
    /// Number of columns (elements per row) to copy.
    cols: u32,
    /// Stride between consecutive rows of `src`, in `f32` elements (not bytes).
    stride_row: u32,
    /// Stride between consecutive columns of `src`, in `f32` elements (not bytes).
    stride_col: u32,
}
36
/// Shape and source strides for [`dispatch_strided_copy_f32`].
///
/// The copy covers a `rows x cols` region. Strides describe how that region is
/// laid out in the source buffer, measured in `f32` elements rather than bytes.
/// The dispatcher rejects `rows == 0` or `cols == 0`.
///
/// This is a small plain-data type. It derives `Debug`, `Clone`, `Copy`,
/// `PartialEq`, `Eq` and `Hash`, so callers can log, copy, compare and key on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StridedCopyParams {
    /// Number of rows to copy (must be > 0).
    pub rows: u32,
    /// Number of columns (elements per row) to copy (must be > 0).
    pub cols: u32,
    /// Distance in `f32` elements between the starts of consecutive rows in `src`.
    pub stride_row: u32,
    /// Distance in `f32` elements between consecutive columns in `src`.
    pub stride_col: u32,
}
48
49pub fn dispatch_strided_copy_f32(
68 encoder: &mut CommandEncoder,
69 registry: &mut KernelRegistry,
70 device: &metal::DeviceRef,
71 src: &MlxBuffer,
72 dst: &MlxBuffer,
73 params: &StridedCopyParams,
74) -> Result<()> {
75 if params.rows == 0 || params.cols == 0 {
76 return Err(MlxError::InvalidArgument(
77 "strided_copy_f32: rows and cols must be > 0".into(),
78 ));
79 }
80
81 let dst_bytes = params.rows as usize * params.cols as usize * 4;
83 if dst.byte_len() < dst_bytes {
84 return Err(MlxError::InvalidArgument(format!(
85 "strided_copy_f32: dst buffer too small: need {} bytes, have {}",
86 dst_bytes,
87 dst.byte_len()
88 )));
89 }
90
91 let max_src_idx = (params.rows as usize - 1) * params.stride_row as usize
94 + (params.cols as usize - 1) * params.stride_col as usize;
95 let src_min_bytes = (max_src_idx + 1) * 4;
96 if src.byte_len() < src_min_bytes {
97 return Err(MlxError::InvalidArgument(format!(
98 "strided_copy_f32: src buffer too small: need at least {} bytes for stride access, have {}",
99 src_min_bytes,
100 src.byte_len()
101 )));
102 }
103
104 let pipeline = registry.get_pipeline("strided_copy_f32", device)?;
105
106 let gpu_params = GpuStridedCopyParams {
107 rows: params.rows,
108 cols: params.cols,
109 stride_row: params.stride_row,
110 stride_col: params.stride_col,
111 };
112
113 let grid = MTLSize::new(params.cols as u64, params.rows as u64, 1);
114 let tg = MTLSize::new(std::cmp::min(256, params.cols as u64), 1, 1);
115
116 encode_with_args(
117 encoder,
118 pipeline,
119 &[
120 (0, KernelArg::Buffer(src)),
121 (1, KernelArg::Buffer(dst)),
122 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
123 ],
124 grid,
125 tg,
126 );
127
128 Ok(())
129}