Skip to main content

mlx_native/ops/
gather.rs

1//! GPU-accelerated gather / index_select along dim=0.
2//!
3//! Gathers rows from a 2D source tensor using an index array:
4//! `output[i, :] = src[indices[i], :]`
5//!
6//! Used for MoE scale factor gathering by expert index.
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
17/// MSL source for the gather kernel (embedded at compile time).
18pub static GATHER_SHADER_SOURCE: &str = include_str!("../shaders/gather.metal");
19
20/// Register gather shader source with the given kernel registry.
21pub fn register(registry: &mut KernelRegistry) {
22    registry.register_source("gather_f32", GATHER_SHADER_SOURCE);
23}
24
25/// MSL-compatible params struct for gather.
26///
27/// Must match `GatherParams` in `gather.metal`.
28#[repr(C)]
29#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
30struct GpuGatherParams {
31    row_width: u32,
32    n_indices: u32,
33    src_rows: u32,
34}
35
36/// Dispatch a gather / index_select operation on the GPU.
37///
38/// Gathers rows from `src` using `indices`:
39///   `output[i, j] = src[indices[i], j]`
40///
41/// # Arguments
42///
43/// * `encoder`   - Command encoder to record the dispatch into.
44/// * `registry`  - Kernel registry (must have `gather_f32` registered).
45/// * `device`    - Metal device for pipeline compilation.
46/// * `src`       - Source buffer of shape `[src_rows, row_width]` (f32).
47/// * `indices`   - Index buffer of shape `[n_indices]` (u32).
48/// * `output`    - Output buffer of shape `[n_indices, row_width]` (f32).
49/// * `src_rows`  - Number of rows in the source tensor.
50/// * `row_width` - Number of columns (elements per row).
51/// * `n_indices` - Number of indices to gather.
52///
53/// # Errors
54///
55/// Returns `MlxError::InvalidArgument` if any dimension is 0 or buffers are
56/// too small.
57#[allow(clippy::too_many_arguments)]
58pub fn dispatch_gather_f32(
59    encoder: &mut CommandEncoder,
60    registry: &mut KernelRegistry,
61    device: &metal::DeviceRef,
62    src: &MlxBuffer,
63    indices: &MlxBuffer,
64    output: &MlxBuffer,
65    src_rows: u32,
66    row_width: u32,
67    n_indices: u32,
68) -> Result<()> {
69    if src_rows == 0 || row_width == 0 || n_indices == 0 {
70        return Err(MlxError::InvalidArgument(
71            "gather_f32: all dimensions must be > 0".into(),
72        ));
73    }
74
75    let src_bytes = src_rows as usize * row_width as usize * 4;
76    if src.byte_len() < src_bytes {
77        return Err(MlxError::InvalidArgument(format!(
78            "gather_f32: src buffer too small: need {} bytes, have {}",
79            src_bytes,
80            src.byte_len()
81        )));
82    }
83    let idx_bytes = n_indices as usize * 4;
84    if indices.byte_len() < idx_bytes {
85        return Err(MlxError::InvalidArgument(format!(
86            "gather_f32: indices buffer too small: need {} bytes, have {}",
87            idx_bytes,
88            indices.byte_len()
89        )));
90    }
91    let out_bytes = n_indices as usize * row_width as usize * 4;
92    if output.byte_len() < out_bytes {
93        return Err(MlxError::InvalidArgument(format!(
94            "gather_f32: output buffer too small: need {} bytes, have {}",
95            out_bytes,
96            output.byte_len()
97        )));
98    }
99
100    let pipeline = registry.get_pipeline("gather_f32", device)?;
101
102    let gpu_params = GpuGatherParams {
103        row_width,
104        n_indices,
105        src_rows,
106    };
107
108    let grid = MTLSize::new(row_width as u64, n_indices as u64, 1);
109    let tg = MTLSize::new(std::cmp::min(256, row_width as u64), 1, 1);
110
111    encode_with_args(
112        encoder,
113        pipeline,
114        &[
115            (0, KernelArg::Buffer(src)),
116            (1, KernelArg::Buffer(indices)),
117            (2, KernelArg::Buffer(output)),
118            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
119        ],
120        grid,
121        tg,
122    );
123
124    Ok(())
125}