1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
17pub static COPY_SHADER_SOURCE: &str = include_str!("../shaders/copy.metal");
19
20pub fn register(registry: &mut KernelRegistry) {
22 registry.register_source("strided_copy_f32", COPY_SHADER_SOURCE);
23 registry.register_source("offset_copy_f32", COPY_SHADER_SOURCE);
24}
25
/// Parameter block uploaded to the `strided_copy_f32` kernel.
///
/// `dispatch_strided_copy_f32` binds it as raw bytes at buffer index 2 (via
/// `as_bytes`), which is why it is `#[repr(C)]` and `Pod`.
/// NOTE(review): field order and widths presumably mirror the params struct declared
/// in `shaders/copy.metal` — keep the two in sync when editing either.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuStridedCopyParams {
    /// Number of rows to copy (also the dispatch grid height).
    rows: u32,
    /// Number of columns to copy (also the dispatch grid width).
    cols: u32,
    /// Source stride between consecutive rows, in `f32` elements.
    stride_row: u32,
    /// Source stride between consecutive columns, in `f32` elements.
    stride_col: u32,
}
37
/// Host-side description of a strided 2-D `f32` copy for
/// `dispatch_strided_copy_f32`.
///
/// The copy reads `rows x cols` values from the source. Its bounds check treats
/// element `(r, c)` as living at source element index
/// `r * stride_row + c * stride_col`, and the destination holds `rows * cols`
/// values. Strides are counted in `f32` elements, not bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StridedCopyParams {
    /// Number of rows to copy; must be non-zero.
    pub rows: u32,
    /// Number of columns to copy; must be non-zero.
    pub cols: u32,
    /// Source stride between consecutive rows, in `f32` elements.
    pub stride_row: u32,
    /// Source stride between consecutive columns, in `f32` elements.
    pub stride_col: u32,
}
49
50pub fn dispatch_strided_copy_f32(
69 encoder: &mut CommandEncoder,
70 registry: &mut KernelRegistry,
71 device: &metal::DeviceRef,
72 src: &MlxBuffer,
73 dst: &MlxBuffer,
74 params: &StridedCopyParams,
75) -> Result<()> {
76 if params.rows == 0 || params.cols == 0 {
77 return Err(MlxError::InvalidArgument(
78 "strided_copy_f32: rows and cols must be > 0".into(),
79 ));
80 }
81
82 let dst_bytes = params.rows as usize * params.cols as usize * 4;
84 if dst.byte_len() < dst_bytes {
85 return Err(MlxError::InvalidArgument(format!(
86 "strided_copy_f32: dst buffer too small: need {} bytes, have {}",
87 dst_bytes,
88 dst.byte_len()
89 )));
90 }
91
92 let max_src_idx = (params.rows as usize - 1) * params.stride_row as usize
95 + (params.cols as usize - 1) * params.stride_col as usize;
96 let src_min_bytes = (max_src_idx + 1) * 4;
97 if src.byte_len() < src_min_bytes {
98 return Err(MlxError::InvalidArgument(format!(
99 "strided_copy_f32: src buffer too small: need at least {} bytes for stride access, have {}",
100 src_min_bytes,
101 src.byte_len()
102 )));
103 }
104
105 let pipeline = registry.get_pipeline("strided_copy_f32", device)?;
106
107 let gpu_params = GpuStridedCopyParams {
108 rows: params.rows,
109 cols: params.cols,
110 stride_row: params.stride_row,
111 stride_col: params.stride_col,
112 };
113
114 let grid = MTLSize::new(params.cols as u64, params.rows as u64, 1);
115 let tg = MTLSize::new(std::cmp::min(256, params.cols as u64), 1, 1);
116
117 encode_with_args(
118 encoder,
119 pipeline,
120 &[
121 (0, KernelArg::Buffer(src)),
122 (1, KernelArg::Buffer(dst)),
123 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
124 ],
125 grid,
126 tg,
127 );
128
129 Ok(())
130}
131
/// Parameter block uploaded to the `offset_copy_f32` kernel.
///
/// `dispatch_copy_f32` binds it as raw bytes at buffer index 2 (via `as_bytes`),
/// which is why it is `#[repr(C)]` and `Pod`.
/// NOTE(review): field order and widths presumably mirror the params struct declared
/// in `shaders/copy.metal` — keep the two in sync when editing either.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuOffsetCopyParams {
    /// First source element to read, in `f32` elements.
    src_offset: u32,
    /// First destination element to write, in `f32` elements.
    dst_offset: u32,
    /// Number of elements to copy (also the dispatch grid width).
    count: u32,
}
140
141pub fn dispatch_copy_f32(
146 encoder: &mut CommandEncoder,
147 registry: &mut KernelRegistry,
148 device: &metal::DeviceRef,
149 src: &MlxBuffer,
150 dst: &MlxBuffer,
151 src_offset: usize,
152 dst_offset: usize,
153 count: usize,
154) -> Result<()> {
155 if count == 0 {
156 return Ok(()); }
158 let src_end_bytes = (src_offset + count) * 4;
159 let dst_end_bytes = (dst_offset + count) * 4;
160 if src.byte_len() < src_end_bytes {
161 return Err(MlxError::InvalidArgument(format!(
162 "offset_copy_f32: src too small: need {} bytes (offset {} + count {}), have {}",
163 src_end_bytes, src_offset, count, src.byte_len()
164 )));
165 }
166 if dst.byte_len() < dst_end_bytes {
167 return Err(MlxError::InvalidArgument(format!(
168 "offset_copy_f32: dst too small: need {} bytes (offset {} + count {}), have {}",
169 dst_end_bytes, dst_offset, count, dst.byte_len()
170 )));
171 }
172
173 let pipeline = registry.get_pipeline("offset_copy_f32", device)?;
174
175 let gpu_params = GpuOffsetCopyParams {
176 src_offset: src_offset as u32,
177 dst_offset: dst_offset as u32,
178 count: count as u32,
179 };
180
181 let grid = MTLSize::new(count as u64, 1, 1);
182 let tg = MTLSize::new(std::cmp::min(256, count as u64), 1, 1);
183
184 encode_with_args(
185 encoder,
186 pipeline,
187 &[
188 (0, KernelArg::Buffer(src)),
189 (1, KernelArg::Buffer(dst)),
190 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
191 ],
192 grid,
193 tg,
194 );
195
196 Ok(())
197}