// mlx_native/ops/copy.rs
1//! GPU-accelerated strided copy for making tensors contiguous.
2//!
3//! Copies a 2D strided tensor to a contiguous layout:
4//!   `dst[row * cols + col] = src[row * stride_row + col * stride_col]`
5//!
6//! Used after transpose/permute operations to produce contiguous memory.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
/// MSL source for the strided copy kernel (embedded at compile time).
///
/// The same source file provides both the `strided_copy_f32` and
/// `offset_copy_f32` kernels; `register` registers it under both names.
pub static COPY_SHADER_SOURCE: &str = include_str!("../shaders/copy.metal");
19
20/// Register strided copy shader source with the given kernel registry.
21pub fn register(registry: &mut KernelRegistry) {
22    registry.register_source("strided_copy_f32", COPY_SHADER_SOURCE);
23    registry.register_source("offset_copy_f32", COPY_SHADER_SOURCE);
24}
25
/// MSL-compatible params struct for strided copy.
///
/// Must match `StridedCopyParams` in `copy.metal`. `#[repr(C)]` plus the
/// `bytemuck` derives let this be passed to the kernel as raw bytes via
/// `as_bytes`; do not reorder or resize fields without updating the shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuStridedCopyParams {
    /// Number of rows in the output.
    rows: u32,
    /// Number of columns in the output.
    cols: u32,
    /// Element stride between consecutive rows of the source.
    stride_row: u32,
    /// Element stride between consecutive columns of the source.
    stride_col: u32,
}
37
/// Parameters for a strided copy operation.
///
/// Describes the 2D view being copied: the output is `rows x cols`
/// contiguous, and source element `(row, col)` is read from
/// `row * stride_row + col * stride_col`. For a transposed view the
/// strides are simply swapped relative to the contiguous layout.
pub struct StridedCopyParams {
    /// Number of rows in the output.
    pub rows: u32,
    /// Number of columns in the output.
    pub cols: u32,
    /// Stride (in elements) between rows in the source.
    pub stride_row: u32,
    /// Stride (in elements) between columns in the source.
    pub stride_col: u32,
}
49
50/// Dispatch a strided copy operation on the GPU.
51///
52/// Copies a 2D strided tensor to contiguous layout:
53///   `dst[row * cols + col] = src[row * stride_row + col * stride_col]`
54///
55/// # Arguments
56///
57/// * `encoder`  - Command encoder to record the dispatch into.
58/// * `registry` - Kernel registry (must have `strided_copy_f32` registered).
59/// * `device`   - Metal device for pipeline compilation.
60/// * `src`      - Source buffer (f32, strided layout).
61/// * `dst`      - Destination buffer (f32, contiguous output).
62/// * `params`   - Copy parameters (rows, cols, strides).
63///
64/// # Errors
65///
66/// Returns `MlxError::InvalidArgument` if dimensions are 0 or buffers are
67/// too small.
68pub fn dispatch_strided_copy_f32(
69    encoder: &mut CommandEncoder,
70    registry: &mut KernelRegistry,
71    device: &metal::DeviceRef,
72    src: &MlxBuffer,
73    dst: &MlxBuffer,
74    params: &StridedCopyParams,
75) -> Result<()> {
76    if params.rows == 0 || params.cols == 0 {
77        return Err(MlxError::InvalidArgument(
78            "strided_copy_f32: rows and cols must be > 0".into(),
79        ));
80    }
81
82    // Check destination buffer size (contiguous output).
83    let dst_bytes = params.rows as usize * params.cols as usize * 4;
84    if dst.byte_len() < dst_bytes {
85        return Err(MlxError::InvalidArgument(format!(
86            "strided_copy_f32: dst buffer too small: need {} bytes, have {}",
87            dst_bytes,
88            dst.byte_len()
89        )));
90    }
91
92    // Source buffer must be large enough for the maximum strided access.
93    // Max index = (rows-1)*stride_row + (cols-1)*stride_col
94    let max_src_idx = (params.rows as usize - 1) * params.stride_row as usize
95        + (params.cols as usize - 1) * params.stride_col as usize;
96    let src_min_bytes = (max_src_idx + 1) * 4;
97    if src.byte_len() < src_min_bytes {
98        return Err(MlxError::InvalidArgument(format!(
99            "strided_copy_f32: src buffer too small: need at least {} bytes for stride access, have {}",
100            src_min_bytes,
101            src.byte_len()
102        )));
103    }
104
105    let pipeline = registry.get_pipeline("strided_copy_f32", device)?;
106
107    let gpu_params = GpuStridedCopyParams {
108        rows: params.rows,
109        cols: params.cols,
110        stride_row: params.stride_row,
111        stride_col: params.stride_col,
112    };
113
114    let grid = MTLSize::new(params.cols as u64, params.rows as u64, 1);
115    let tg = MTLSize::new(std::cmp::min(256, params.cols as u64), 1, 1);
116
117    encode_with_args(
118        encoder,
119        pipeline,
120        &[
121            (0, KernelArg::Buffer(src)),
122            (1, KernelArg::Buffer(dst)),
123            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
124        ],
125        grid,
126        tg,
127    );
128
129    Ok(())
130}
131
/// GPU-side params for offset copy.
///
/// Layout must match the offset-copy params struct in `copy.metal`
/// (same shader source as the strided copy kernel). `#[repr(C)]` plus
/// the `bytemuck` derives allow passing it as raw bytes via `as_bytes`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuOffsetCopyParams {
    /// Element offset into the source buffer where reading starts.
    src_offset: u32,
    /// Element offset into the destination buffer where writing starts.
    dst_offset: u32,
    /// Number of f32 elements to copy.
    count: u32,
}
140
141/// Copy `count` f32 elements from `src[src_offset..]` to `dst[dst_offset..]`.
142///
143/// Used during prefill to scatter/gather rows between large prefill buffers
144/// and single-token activation buffers.
145pub fn dispatch_copy_f32(
146    encoder: &mut CommandEncoder,
147    registry: &mut KernelRegistry,
148    device: &metal::DeviceRef,
149    src: &MlxBuffer,
150    dst: &MlxBuffer,
151    src_offset: usize,
152    dst_offset: usize,
153    count: usize,
154) -> Result<()> {
155    if count == 0 {
156        return Ok(()); // no-op
157    }
158    let src_end_bytes = (src_offset + count) * 4;
159    let dst_end_bytes = (dst_offset + count) * 4;
160    if src.byte_len() < src_end_bytes {
161        return Err(MlxError::InvalidArgument(format!(
162            "offset_copy_f32: src too small: need {} bytes (offset {} + count {}), have {}",
163            src_end_bytes, src_offset, count, src.byte_len()
164        )));
165    }
166    if dst.byte_len() < dst_end_bytes {
167        return Err(MlxError::InvalidArgument(format!(
168            "offset_copy_f32: dst too small: need {} bytes (offset {} + count {}), have {}",
169            dst_end_bytes, dst_offset, count, dst.byte_len()
170        )));
171    }
172
173    let pipeline = registry.get_pipeline("offset_copy_f32", device)?;
174
175    let gpu_params = GpuOffsetCopyParams {
176        src_offset: src_offset as u32,
177        dst_offset: dst_offset as u32,
178        count: count as u32,
179    };
180
181    let grid = MTLSize::new(count as u64, 1, 1);
182    let tg = MTLSize::new(std::cmp::min(256, count as u64), 1, 1);
183
184    encode_with_args(
185        encoder,
186        pipeline,
187        &[
188            (0, KernelArg::Buffer(src)),
189            (1, KernelArg::Buffer(dst)),
190            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
191        ],
192        grid,
193        tg,
194    );
195
196    Ok(())
197}