Skip to main content

mlx_native/ops/
slice_concat_2d.rs

1//! 2-D row-major slice + concat-by-column primitives.
2//!
3//! Used by hf2q's ADR-020 Track 1 multi-head SDPA on GpuTape:
4//! Q/K/V tensors are sliced into per-head views, each head runs the
5//! single-head SDPA chain, and per-head context outputs are
6//! concatenated back into the full attention output.
7//!
8//! Two kernels:
9//! - `slice_2d_cols_f32(input[rows, in_cols], output[rows, out_cols], (in_cols, out_cols, start_col))`
10//!   produces `output[r, c] = input[r, start_col + c]`.
11//! - `copy_2d_cols_into_f32(src[rows, src_cols], dst[rows, dst_cols], (src_cols, dst_cols, start))`
12//!   writes `dst[r, start + c] = src[r, c]` for `c < src_cols`.  Caller
13//!   pre-zeros (or pre-populates) `dst`; this kernel writes the slab only.
14
15use metal::MTLSize;
16
17use crate::buffer::MlxBuffer;
18use crate::dtypes::DType;
19use crate::encoder::CommandEncoder;
20use crate::error::{MlxError, Result};
21use crate::kernel_registry::KernelRegistry;
22
/// Metal shader source shared by both kernels in this module, embedded at
/// compile time from `../shaders/slice_concat_2d.metal`.
pub static SLICE_CONCAT_2D_SHADER_SOURCE: &str =
    include_str!("../shaders/slice_concat_2d.metal");
25
26pub fn register(registry: &mut KernelRegistry) {
27    registry.register_source("slice_2d_cols_f32", SLICE_CONCAT_2D_SHADER_SOURCE);
28    registry.register_source("copy_2d_cols_into_f32", SLICE_CONCAT_2D_SHADER_SOURCE);
29}
30
31/// Slice `output[r, c] = input[r, start_col + c]` for `c < out_cols`.
32///
33/// `params_buf` must be at least 12 bytes (3 × u32: in_cols, out_cols, start_col).
34#[allow(clippy::too_many_arguments)]
35pub fn dispatch_slice_2d_cols_f32(
36    encoder: &mut CommandEncoder,
37    registry: &mut KernelRegistry,
38    device: &metal::DeviceRef,
39    input: &MlxBuffer,
40    output: &MlxBuffer,
41    params_buf: &MlxBuffer,
42    rows: u32,
43    in_cols: u32,
44    out_cols: u32,
45    start_col: u32,
46) -> Result<()> {
47    if rows == 0 || in_cols == 0 || out_cols == 0 {
48        return Err(MlxError::InvalidArgument(
49            "slice_2d_cols: rows/in_cols/out_cols must all be > 0".into(),
50        ));
51    }
52    if start_col + out_cols > in_cols {
53        return Err(MlxError::InvalidArgument(format!(
54            "slice_2d_cols: start_col({start_col}) + out_cols({out_cols}) > in_cols({in_cols})"
55        )));
56    }
57    if input.element_count() != (rows as usize) * (in_cols as usize) {
58        return Err(MlxError::InvalidArgument(format!(
59            "slice_2d_cols: input element count {} != rows({rows}) * in_cols({in_cols})",
60            input.element_count(),
61        )));
62    }
63    if output.element_count() != (rows as usize) * (out_cols as usize) {
64        return Err(MlxError::InvalidArgument(format!(
65            "slice_2d_cols: output element count {} != rows({rows}) * out_cols({out_cols})",
66            output.element_count(),
67        )));
68    }
69    for (label, buf) in [("input", input), ("output", output)] {
70        if buf.dtype() != DType::F32 {
71            return Err(MlxError::InvalidArgument(format!(
72                "slice_2d_cols: {label} dtype {} not f32",
73                buf.dtype()
74            )));
75        }
76    }
77    if params_buf.byte_len() < 12 {
78        return Err(MlxError::InvalidArgument(format!(
79            "slice_2d_cols: params_buf too small (need 12 bytes for 3×u32, got {})",
80            params_buf.byte_len()
81        )));
82    }
83
84    let pipeline = registry.get_pipeline("slice_2d_cols_f32", device)?;
85    encoder.encode(
86        pipeline,
87        &[(0, input), (1, output), (2, params_buf)],
88        MTLSize::new(out_cols as u64, rows as u64, 1),
89        MTLSize::new(
90            std::cmp::min(out_cols as u64, 32),
91            std::cmp::min(rows as u64, 8),
92            1,
93        ),
94    );
95    Ok(())
96}
97
98/// Write `src[rows, src_cols]` into `dst[rows, dst_cols]` at column
99/// offset `start_col`.  Does NOT touch dst columns outside the slab —
100/// caller pre-zeros (or pre-populates) `dst`.
101///
102/// `params_buf` must be at least 12 bytes (3 × u32: src_cols, dst_cols, start_col).
103#[allow(clippy::too_many_arguments)]
104pub fn dispatch_copy_2d_cols_into_f32(
105    encoder: &mut CommandEncoder,
106    registry: &mut KernelRegistry,
107    device: &metal::DeviceRef,
108    src: &MlxBuffer,
109    dst: &MlxBuffer,
110    params_buf: &MlxBuffer,
111    rows: u32,
112    src_cols: u32,
113    dst_cols: u32,
114    start_col: u32,
115) -> Result<()> {
116    if rows == 0 || src_cols == 0 || dst_cols == 0 {
117        return Err(MlxError::InvalidArgument(
118            "copy_2d_cols_into: rows/src_cols/dst_cols must all be > 0".into(),
119        ));
120    }
121    if start_col + src_cols > dst_cols {
122        return Err(MlxError::InvalidArgument(format!(
123            "copy_2d_cols_into: start_col({start_col}) + src_cols({src_cols}) > dst_cols({dst_cols})"
124        )));
125    }
126    if src.element_count() != (rows as usize) * (src_cols as usize) {
127        return Err(MlxError::InvalidArgument(format!(
128            "copy_2d_cols_into: src element count {} != rows({rows}) * src_cols({src_cols})",
129            src.element_count(),
130        )));
131    }
132    if dst.element_count() != (rows as usize) * (dst_cols as usize) {
133        return Err(MlxError::InvalidArgument(format!(
134            "copy_2d_cols_into: dst element count {} != rows({rows}) * dst_cols({dst_cols})",
135            dst.element_count(),
136        )));
137    }
138    for (label, buf) in [("src", src), ("dst", dst)] {
139        if buf.dtype() != DType::F32 {
140            return Err(MlxError::InvalidArgument(format!(
141                "copy_2d_cols_into: {label} dtype {} not f32",
142                buf.dtype()
143            )));
144        }
145    }
146    if params_buf.byte_len() < 12 {
147        return Err(MlxError::InvalidArgument(format!(
148            "copy_2d_cols_into: params_buf too small (need 12 bytes for 3×u32, got {})",
149            params_buf.byte_len()
150        )));
151    }
152
153    let pipeline = registry.get_pipeline("copy_2d_cols_into_f32", device)?;
154    encoder.encode(
155        pipeline,
156        &[(0, src), (1, dst), (2, params_buf)],
157        MTLSize::new(src_cols as u64, rows as u64, 1),
158        MTLSize::new(
159            std::cmp::min(src_cols as u64, 32),
160            std::cmp::min(rows as u64, 8),
161            1,
162        ),
163    );
164    Ok(())
165}
166
167#[cfg(test)]
168mod tests {
169    use super::*;
170    use crate::device::MlxDevice;
171
172    fn build_device_buf(device: &MlxDevice, data: &[f32], shape: Vec<usize>) -> MlxBuffer {
173        let n_bytes = data.len() * 4;
174        let mut buf = device
175            .alloc_buffer(n_bytes, DType::F32, shape)
176            .expect("alloc");
177        buf.as_mut_slice::<f32>().expect("as_mut").copy_from_slice(data);
178        buf
179    }
180
181    fn write_params_u32(buf: &mut MlxBuffer, vals: &[u32]) {
182        let slice: &mut [u32] = buf.as_mut_slice().expect("params as_mut");
183        slice[..vals.len()].copy_from_slice(vals);
184    }
185
186    #[test]
187    fn slice_2d_cols_byte_identical_to_cpu() {
188        let device = MlxDevice::new().expect("device");
189        let rows = 4u32;
190        let in_cols = 12u32;
191        let out_cols = 4u32;
192        let start_col = 5u32;
193        let input: Vec<f32> = (0..rows * in_cols).map(|i| (i as f32) * 0.5 - 1.0).collect();
194        let in_buf = build_device_buf(&device, &input, vec![rows as usize, in_cols as usize]);
195        let out_buf = build_device_buf(
196            &device,
197            &vec![0.0_f32; (rows * out_cols) as usize],
198            vec![rows as usize, out_cols as usize],
199        );
200        let mut params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
201        write_params_u32(&mut params, &[in_cols, out_cols, start_col]);
202
203        let mut registry = KernelRegistry::new();
204        register(&mut registry);
205        let mut encoder = device.command_encoder().expect("encoder");
206        dispatch_slice_2d_cols_f32(
207            &mut encoder,
208            &mut registry,
209            device.metal_device(),
210            &in_buf,
211            &out_buf,
212            &params,
213            rows,
214            in_cols,
215            out_cols,
216            start_col,
217        )
218        .expect("slice dispatch");
219        encoder.commit_and_wait().expect("commit");
220
221        let gpu = out_buf.as_slice::<f32>().unwrap();
222        for r in 0..rows as usize {
223            for c in 0..out_cols as usize {
224                let expected = input[r * in_cols as usize + start_col as usize + c];
225                assert_eq!(
226                    gpu[r * out_cols as usize + c].to_bits(),
227                    expected.to_bits(),
228                    "mismatch at ({r},{c})"
229                );
230            }
231        }
232    }
233
234    #[test]
235    fn copy_2d_cols_into_byte_identical_to_cpu() {
236        // Pre-fill dst with sentinel 999.0; copy src into a slab;
237        // verify slab matches src and surrounding cells are untouched.
238        let device = MlxDevice::new().expect("device");
239        let rows = 3u32;
240        let src_cols = 4u32;
241        let dst_cols = 12u32;
242        let start_col = 5u32;
243        let src: Vec<f32> = (0..rows * src_cols).map(|i| (i as f32) * 0.7 + 1.5).collect();
244        let dst_init: Vec<f32> = vec![999.0; (rows * dst_cols) as usize];
245        let src_buf = build_device_buf(&device, &src, vec![rows as usize, src_cols as usize]);
246        let dst_buf = build_device_buf(
247            &device,
248            &dst_init,
249            vec![rows as usize, dst_cols as usize],
250        );
251        let mut params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
252        write_params_u32(&mut params, &[src_cols, dst_cols, start_col]);
253
254        let mut registry = KernelRegistry::new();
255        register(&mut registry);
256        let mut encoder = device.command_encoder().expect("encoder");
257        dispatch_copy_2d_cols_into_f32(
258            &mut encoder,
259            &mut registry,
260            device.metal_device(),
261            &src_buf,
262            &dst_buf,
263            &params,
264            rows,
265            src_cols,
266            dst_cols,
267            start_col,
268        )
269        .expect("copy dispatch");
270        encoder.commit_and_wait().expect("commit");
271
272        let gpu = dst_buf.as_slice::<f32>().unwrap();
273        for r in 0..rows as usize {
274            for c in 0..dst_cols as usize {
275                let expected = if c >= start_col as usize
276                    && c < (start_col + src_cols) as usize
277                {
278                    src[r * src_cols as usize + (c - start_col as usize)]
279                } else {
280                    999.0
281                };
282                assert_eq!(
283                    gpu[r * dst_cols as usize + c].to_bits(),
284                    expected.to_bits(),
285                    "mismatch at ({r},{c})"
286                );
287            }
288        }
289    }
290
291    #[test]
292    fn slice_then_copy_back_round_trips() {
293        // Sanity: slice every column from a tensor, copy each slice back
294        // into a fresh dst at the same column offset; result must equal
295        // the original tensor.
296        let device = MlxDevice::new().expect("device");
297        let rows = 5u32;
298        let cols = 16u32;
299        let chunk = 4u32;
300        let n_chunks = cols / chunk;
301        let input: Vec<f32> = (0..rows * cols).map(|i| (i as f32) * 0.13 - 2.5).collect();
302        let in_buf = build_device_buf(&device, &input, vec![rows as usize, cols as usize]);
303
304        // Build accumulator dst initialized to 0.
305        let dst_buf = build_device_buf(
306            &device,
307            &vec![0.0_f32; (rows * cols) as usize],
308            vec![rows as usize, cols as usize],
309        );
310
311        let mut registry = KernelRegistry::new();
312        register(&mut registry);
313        let mut encoder = device.command_encoder().expect("encoder");
314        for h in 0..n_chunks {
315            let start = h * chunk;
316            // slice → temp
317            let temp_buf = device
318                .alloc_buffer(
319                    (rows * chunk * 4) as usize,
320                    DType::F32,
321                    vec![rows as usize, chunk as usize],
322                )
323                .expect("temp");
324            let mut p_slice = device.alloc_buffer(12, DType::F32, vec![3]).expect("p_slice");
325            write_params_u32(&mut p_slice, &[cols, chunk, start]);
326            dispatch_slice_2d_cols_f32(
327                &mut encoder,
328                &mut registry,
329                device.metal_device(),
330                &in_buf,
331                &temp_buf,
332                &p_slice,
333                rows,
334                cols,
335                chunk,
336                start,
337            )
338            .unwrap();
339            encoder.memory_barrier();
340            // copy temp → dst at start
341            let mut p_copy = device.alloc_buffer(12, DType::F32, vec![3]).expect("p_copy");
342            write_params_u32(&mut p_copy, &[chunk, cols, start]);
343            dispatch_copy_2d_cols_into_f32(
344                &mut encoder,
345                &mut registry,
346                device.metal_device(),
347                &temp_buf,
348                &dst_buf,
349                &p_copy,
350                rows,
351                chunk,
352                cols,
353                start,
354            )
355            .unwrap();
356            encoder.memory_barrier();
357        }
358        encoder.commit_and_wait().expect("commit");
359
360        let gpu = dst_buf.as_slice::<f32>().unwrap();
361        for (i, (g, c)) in gpu.iter().zip(input.iter()).enumerate() {
362            assert_eq!(g.to_bits(), c.to_bits(), "round-trip mismatch at {i}");
363        }
364    }
365
366    #[test]
367    fn slice_rejects_out_of_range() {
368        let device = MlxDevice::new().expect("device");
369        let in_buf = device
370            .alloc_buffer(4 * 12 * 4, DType::F32, vec![4, 12])
371            .expect("in");
372        let out_buf = device
373            .alloc_buffer(4 * 4 * 4, DType::F32, vec![4, 4])
374            .expect("out");
375        let params = device.alloc_buffer(12, DType::F32, vec![3]).expect("params");
376        let mut registry = KernelRegistry::new();
377        register(&mut registry);
378        let mut encoder = device.command_encoder().expect("encoder");
379        let err = dispatch_slice_2d_cols_f32(
380            &mut encoder,
381            &mut registry,
382            device.metal_device(),
383            &in_buf,
384            &out_buf,
385            &params,
386            4,
387            12,
388            4,
389            10, // 10 + 4 = 14 > 12
390        )
391        .expect_err("must reject");
392        assert!(format!("{err}").contains("> in_cols"));
393    }
394}