// mlx_native/ops/copy.rs
1//! GPU-accelerated strided copy for making tensors contiguous.
2//!
3//! Copies a 2D strided tensor to a contiguous layout:
4//!   `dst[row * cols + col] = src[row * stride_row + col * stride_col]`
5//!
6//! Used after transpose/permute operations to produce contiguous memory.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
/// MSL source for the strided copy kernel (embedded at compile time).
///
/// Pulled in via `include_str!`, so the compiled crate needs no runtime
/// access to the `.metal` file; the text is handed to the kernel registry
/// by [`register`] and turned into a pipeline on first use.
pub static COPY_SHADER_SOURCE: &str = include_str!("../shaders/copy.metal");
19
20/// Register strided copy shader source with the given kernel registry.
21pub fn register(registry: &mut KernelRegistry) {
22    registry.register_source("strided_copy_f32", COPY_SHADER_SOURCE);
23}
24
/// MSL-compatible params struct for strided copy.
///
/// Must match `StridedCopyParams` in `copy.metal`.
///
/// `#[repr(C)]` pins the field order and layout so the raw bytes passed to
/// the GPU (via `KernelArg::Bytes(as_bytes(..))` in
/// `dispatch_strided_copy_f32`) line up with the shader-side struct — do
/// not reorder or resize fields without updating `copy.metal` to match.
/// `Pod`/`Zeroable` enable the safe byte view taken by `as_bytes`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuStridedCopyParams {
    // Output row count.
    rows: u32,
    // Output column count.
    cols: u32,
    // Source stride between rows, in elements.
    stride_row: u32,
    // Source stride between columns, in elements.
    stride_col: u32,
}
36
/// Parameters for a strided copy operation.
///
/// Describes a 2D strided source view and the contiguous destination shape.
/// All strides are in *elements* (f32), not bytes. For example, a transposed
/// `rows x cols` view of a contiguous matrix has `stride_row = 1` and
/// `stride_col = rows`.
pub struct StridedCopyParams {
    /// Number of rows in the output.
    pub rows: u32,
    /// Number of columns in the output.
    pub cols: u32,
    /// Stride (in elements) between rows in the source.
    pub stride_row: u32,
    /// Stride (in elements) between columns in the source.
    pub stride_col: u32,
}
48
49/// Dispatch a strided copy operation on the GPU.
50///
51/// Copies a 2D strided tensor to contiguous layout:
52///   `dst[row * cols + col] = src[row * stride_row + col * stride_col]`
53///
54/// # Arguments
55///
56/// * `encoder`  - Command encoder to record the dispatch into.
57/// * `registry` - Kernel registry (must have `strided_copy_f32` registered).
58/// * `device`   - Metal device for pipeline compilation.
59/// * `src`      - Source buffer (f32, strided layout).
60/// * `dst`      - Destination buffer (f32, contiguous output).
61/// * `params`   - Copy parameters (rows, cols, strides).
62///
63/// # Errors
64///
65/// Returns `MlxError::InvalidArgument` if dimensions are 0 or buffers are
66/// too small.
67pub fn dispatch_strided_copy_f32(
68    encoder: &mut CommandEncoder,
69    registry: &mut KernelRegistry,
70    device: &metal::DeviceRef,
71    src: &MlxBuffer,
72    dst: &MlxBuffer,
73    params: &StridedCopyParams,
74) -> Result<()> {
75    if params.rows == 0 || params.cols == 0 {
76        return Err(MlxError::InvalidArgument(
77            "strided_copy_f32: rows and cols must be > 0".into(),
78        ));
79    }
80
81    // Check destination buffer size (contiguous output).
82    let dst_bytes = params.rows as usize * params.cols as usize * 4;
83    if dst.byte_len() < dst_bytes {
84        return Err(MlxError::InvalidArgument(format!(
85            "strided_copy_f32: dst buffer too small: need {} bytes, have {}",
86            dst_bytes,
87            dst.byte_len()
88        )));
89    }
90
91    // Source buffer must be large enough for the maximum strided access.
92    // Max index = (rows-1)*stride_row + (cols-1)*stride_col
93    let max_src_idx = (params.rows as usize - 1) * params.stride_row as usize
94        + (params.cols as usize - 1) * params.stride_col as usize;
95    let src_min_bytes = (max_src_idx + 1) * 4;
96    if src.byte_len() < src_min_bytes {
97        return Err(MlxError::InvalidArgument(format!(
98            "strided_copy_f32: src buffer too small: need at least {} bytes for stride access, have {}",
99            src_min_bytes,
100            src.byte_len()
101        )));
102    }
103
104    let pipeline = registry.get_pipeline("strided_copy_f32", device)?;
105
106    let gpu_params = GpuStridedCopyParams {
107        rows: params.rows,
108        cols: params.cols,
109        stride_row: params.stride_row,
110        stride_col: params.stride_col,
111    };
112
113    let grid = MTLSize::new(params.cols as u64, params.rows as u64, 1);
114    let tg = MTLSize::new(std::cmp::min(256, params.cols as u64), 1, 1);
115
116    encode_with_args(
117        encoder,
118        pipeline,
119        &[
120            (0, KernelArg::Buffer(src)),
121            (1, KernelArg::Buffer(dst)),
122            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
123        ],
124        grid,
125        tg,
126    );
127
128    Ok(())
129}