1use metal::MTLSize;
16
17use crate::buffer::MlxBuffer;
18use crate::dtypes::DType;
19use crate::encoder::CommandEncoder;
20use crate::error::{MlxError, Result};
21use crate::kernel_registry::KernelRegistry;
22
23pub static SLICE_CONCAT_2D_SHADER_SOURCE: &str =
24 include_str!("../shaders/slice_concat_2d.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
27 registry.register_source("slice_2d_cols_f32", SLICE_CONCAT_2D_SHADER_SOURCE);
28 registry.register_source("copy_2d_cols_into_f32", SLICE_CONCAT_2D_SHADER_SOURCE);
29}
30
31#[allow(clippy::too_many_arguments)]
35pub fn dispatch_slice_2d_cols_f32(
36 encoder: &mut CommandEncoder,
37 registry: &mut KernelRegistry,
38 device: &metal::DeviceRef,
39 input: &MlxBuffer,
40 output: &MlxBuffer,
41 params_buf: &MlxBuffer,
42 rows: u32,
43 in_cols: u32,
44 out_cols: u32,
45 start_col: u32,
46) -> Result<()> {
47 if rows == 0 || in_cols == 0 || out_cols == 0 {
48 return Err(MlxError::InvalidArgument(
49 "slice_2d_cols: rows/in_cols/out_cols must all be > 0".into(),
50 ));
51 }
52 if start_col + out_cols > in_cols {
53 return Err(MlxError::InvalidArgument(format!(
54 "slice_2d_cols: start_col({start_col}) + out_cols({out_cols}) > in_cols({in_cols})"
55 )));
56 }
57 if input.element_count() != (rows as usize) * (in_cols as usize) {
58 return Err(MlxError::InvalidArgument(format!(
59 "slice_2d_cols: input element count {} != rows({rows}) * in_cols({in_cols})",
60 input.element_count(),
61 )));
62 }
63 if output.element_count() != (rows as usize) * (out_cols as usize) {
64 return Err(MlxError::InvalidArgument(format!(
65 "slice_2d_cols: output element count {} != rows({rows}) * out_cols({out_cols})",
66 output.element_count(),
67 )));
68 }
69 for (label, buf) in [("input", input), ("output", output)] {
70 if buf.dtype() != DType::F32 {
71 return Err(MlxError::InvalidArgument(format!(
72 "slice_2d_cols: {label} dtype {} not f32",
73 buf.dtype()
74 )));
75 }
76 }
77 if params_buf.byte_len() < 12 {
78 return Err(MlxError::InvalidArgument(format!(
79 "slice_2d_cols: params_buf too small (need 12 bytes for 3×u32, got {})",
80 params_buf.byte_len()
81 )));
82 }
83
84 let pipeline = registry.get_pipeline("slice_2d_cols_f32", device)?;
85 encoder.encode(
86 pipeline,
87 &[(0, input), (1, output), (2, params_buf)],
88 MTLSize::new(out_cols as u64, rows as u64, 1),
89 MTLSize::new(
90 std::cmp::min(out_cols as u64, 32),
91 std::cmp::min(rows as u64, 8),
92 1,
93 ),
94 );
95 Ok(())
96}
97
98#[allow(clippy::too_many_arguments)]
104pub fn dispatch_copy_2d_cols_into_f32(
105 encoder: &mut CommandEncoder,
106 registry: &mut KernelRegistry,
107 device: &metal::DeviceRef,
108 src: &MlxBuffer,
109 dst: &MlxBuffer,
110 params_buf: &MlxBuffer,
111 rows: u32,
112 src_cols: u32,
113 dst_cols: u32,
114 start_col: u32,
115) -> Result<()> {
116 if rows == 0 || src_cols == 0 || dst_cols == 0 {
117 return Err(MlxError::InvalidArgument(
118 "copy_2d_cols_into: rows/src_cols/dst_cols must all be > 0".into(),
119 ));
120 }
121 if start_col + src_cols > dst_cols {
122 return Err(MlxError::InvalidArgument(format!(
123 "copy_2d_cols_into: start_col({start_col}) + src_cols({src_cols}) > dst_cols({dst_cols})"
124 )));
125 }
126 if src.element_count() != (rows as usize) * (src_cols as usize) {
127 return Err(MlxError::InvalidArgument(format!(
128 "copy_2d_cols_into: src element count {} != rows({rows}) * src_cols({src_cols})",
129 src.element_count(),
130 )));
131 }
132 if dst.element_count() != (rows as usize) * (dst_cols as usize) {
133 return Err(MlxError::InvalidArgument(format!(
134 "copy_2d_cols_into: dst element count {} != rows({rows}) * dst_cols({dst_cols})",
135 dst.element_count(),
136 )));
137 }
138 for (label, buf) in [("src", src), ("dst", dst)] {
139 if buf.dtype() != DType::F32 {
140 return Err(MlxError::InvalidArgument(format!(
141 "copy_2d_cols_into: {label} dtype {} not f32",
142 buf.dtype()
143 )));
144 }
145 }
146 if params_buf.byte_len() < 12 {
147 return Err(MlxError::InvalidArgument(format!(
148 "copy_2d_cols_into: params_buf too small (need 12 bytes for 3×u32, got {})",
149 params_buf.byte_len()
150 )));
151 }
152
153 let pipeline = registry.get_pipeline("copy_2d_cols_into_f32", device)?;
154 encoder.encode(
155 pipeline,
156 &[(0, src), (1, dst), (2, params_buf)],
157 MTLSize::new(src_cols as u64, rows as u64, 1),
158 MTLSize::new(
159 std::cmp::min(src_cols as u64, 32),
160 std::cmp::min(rows as u64, 8),
161 1,
162 ),
163 );
164 Ok(())
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an f32 device buffer of the given `shape` and fills it with `data`.
    fn build_device_buf(device: &MlxDevice, data: &[f32], shape: Vec<usize>) -> MlxBuffer {
        // 4 bytes per f32 element.
        let n_bytes = data.len() * 4;
        let mut buf = device
            .alloc_buffer(n_bytes, DType::F32, shape)
            .expect("alloc");
        buf.as_mut_slice::<f32>().expect("as_mut").copy_from_slice(data);
        buf
    }

    /// Writes `vals` into the leading `u32` slots of a params buffer.
    fn write_params_u32(buf: &mut MlxBuffer, vals: &[u32]) {
        let slice: &mut [u32] = buf.as_mut_slice().expect("params as_mut");
        slice[..vals.len()].copy_from_slice(vals);
    }

    /// The slice kernel must reproduce, bit for bit, the column window
    /// `[start_col, start_col + out_cols)` of every input row.
    #[test]
    fn slice_2d_cols_byte_identical_to_cpu() {
        let device = MlxDevice::new().expect("device");
        let rows = 4u32;
        let in_cols = 12u32;
        let out_cols = 4u32;
        let start_col = 5u32;
        let input: Vec<f32> = (0..rows * in_cols).map(|i| (i as f32) * 0.5 - 1.0).collect();
        let in_buf = build_device_buf(&device, &input, vec![rows as usize, in_cols as usize]);
        let out_buf = build_device_buf(
            &device,
            &vec![0.0_f32; (rows * out_cols) as usize],
            vec![rows as usize, out_cols as usize],
        );
        // 12 bytes = three u32 parameters; the buffer is only used as raw storage.
        let mut params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
        write_params_u32(&mut params, &[in_cols, out_cols, start_col]);

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_slice_2d_cols_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
            &params,
            rows,
            in_cols,
            out_cols,
            start_col,
        )
        .expect("slice dispatch");
        encoder.commit_and_wait().expect("commit");

        let gpu = out_buf.as_slice::<f32>().unwrap();
        for r in 0..rows as usize {
            for c in 0..out_cols as usize {
                let expected = input[r * in_cols as usize + start_col as usize + c];
                assert_eq!(
                    gpu[r * out_cols as usize + c].to_bits(),
                    expected.to_bits(),
                    "mismatch at ({r},{c})"
                );
            }
        }
    }

    /// The copy kernel must place `src` into the target column window of `dst`
    /// bit for bit and leave every other `dst` element at its sentinel value.
    #[test]
    fn copy_2d_cols_into_byte_identical_to_cpu() {
        let device = MlxDevice::new().expect("device");
        let rows = 3u32;
        let src_cols = 4u32;
        let dst_cols = 12u32;
        let start_col = 5u32;
        let src: Vec<f32> = (0..rows * src_cols).map(|i| (i as f32) * 0.7 + 1.5).collect();
        // 999.0 is a sentinel used to detect writes outside the target window.
        let dst_init: Vec<f32> = vec![999.0; (rows * dst_cols) as usize];
        let src_buf = build_device_buf(&device, &src, vec![rows as usize, src_cols as usize]);
        let dst_buf = build_device_buf(
            &device,
            &dst_init,
            vec![rows as usize, dst_cols as usize],
        );
        let mut params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
        write_params_u32(&mut params, &[src_cols, dst_cols, start_col]);

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_copy_2d_cols_into_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &src_buf,
            &dst_buf,
            &params,
            rows,
            src_cols,
            dst_cols,
            start_col,
        )
        .expect("copy dispatch");
        encoder.commit_and_wait().expect("commit");

        let gpu = dst_buf.as_slice::<f32>().unwrap();
        for r in 0..rows as usize {
            for c in 0..dst_cols as usize {
                let expected = if c >= start_col as usize
                    && c < (start_col + src_cols) as usize
                {
                    src[r * src_cols as usize + (c - start_col as usize)]
                } else {
                    999.0
                };
                assert_eq!(
                    gpu[r * dst_cols as usize + c].to_bits(),
                    expected.to_bits(),
                    "mismatch at ({r},{c})"
                );
            }
        }
    }

    /// Slicing the input into equal column chunks and copying each chunk back
    /// to the same column offset must reconstruct the input exactly.
    #[test]
    fn slice_then_copy_back_round_trips() {
        let device = MlxDevice::new().expect("device");
        let rows = 5u32;
        let cols = 16u32;
        let chunk = 4u32;
        let n_chunks = cols / chunk;
        let input: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.13 - 2.5).collect();
        let in_buf = build_device_buf(&device, &input, vec![rows as usize, cols as usize]);

        let dst_buf = build_device_buf(
            &device,
            &vec![0.0_f32; (rows * cols) as usize],
            vec![rows as usize, cols as usize],
        );

        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        // NOTE(review): the per-chunk scratch and params buffers are dropped at
        // the end of each iteration, before `commit_and_wait`; this presumably
        // relies on the encoder keeping bound buffers alive — confirm.
        for h in 0..n_chunks {
            let start = h * chunk;
            let temp_buf = device
                .alloc_buffer(
                    (rows * chunk * 4) as usize,
                    DType::F32,
                    vec![rows as usize, chunk as usize],
                )
                .expect("temp");
            let mut p_slice = device.alloc_buffer(12, DType::F32, vec![3]).expect("p_slice");
            write_params_u32(&mut p_slice, &[cols, chunk, start]);
            dispatch_slice_2d_cols_f32(
                &mut encoder,
                &mut registry,
                device.metal_device(),
                &in_buf,
                &temp_buf,
                &p_slice,
                rows,
                cols,
                chunk,
                start,
            )
            .unwrap();
            // The copy below reads what the slice above wrote into `temp_buf`.
            encoder.memory_barrier();
            let mut p_copy = device.alloc_buffer(12, DType::F32, vec![3]).expect("p_copy");
            write_params_u32(&mut p_copy, &[chunk, cols, start]);
            dispatch_copy_2d_cols_into_f32(
                &mut encoder,
                &mut registry,
                device.metal_device(),
                &temp_buf,
                &dst_buf,
                &p_copy,
                rows,
                chunk,
                cols,
                start,
            )
            .unwrap();
            encoder.memory_barrier();
        }
        encoder.commit_and_wait().expect("commit");

        let gpu = dst_buf.as_slice::<f32>().unwrap();
        for (i, (g, c)) in gpu.iter().zip(input.iter()).enumerate() {
            assert_eq!(g.to_bits(), c.to_bits(), "round-trip mismatch at {i}");
        }
    }

    /// A window that runs past the last input column (10 + 4 > 12) must be
    /// rejected during validation, before anything is encoded.
    #[test]
    fn slice_rejects_out_of_range() {
        let device = MlxDevice::new().expect("device");
        let in_buf = device
            .alloc_buffer(4 * 12 * 4, DType::F32, vec![4, 12])
            .expect("in");
        let out_buf = device
            .alloc_buffer(4 * 4 * 4, DType::F32, vec![4, 4])
            .expect("out");
        let params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        let err = dispatch_slice_2d_cols_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
            &params,
            4,
            12,
            4,
            10,
        )
        .expect_err("must reject");
        assert!(format!("{err}").contains("> in_cols"));
    }
}