1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
16
17pub static GATHER_SHADER_SOURCE: &str = include_str!("../shaders/gather.metal");
19
20pub fn register(registry: &mut KernelRegistry) {
22 registry.register_source("gather_f32", GATHER_SHADER_SOURCE);
23}
24
25#[repr(C)]
29#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
30struct GpuGatherParams {
31 row_width: u32,
32 n_indices: u32,
33 src_rows: u32,
34}
35
36#[allow(clippy::too_many_arguments)]
58pub fn dispatch_gather_f32(
59 encoder: &mut CommandEncoder,
60 registry: &mut KernelRegistry,
61 device: &metal::DeviceRef,
62 src: &MlxBuffer,
63 indices: &MlxBuffer,
64 output: &MlxBuffer,
65 src_rows: u32,
66 row_width: u32,
67 n_indices: u32,
68) -> Result<()> {
69 if src_rows == 0 || row_width == 0 || n_indices == 0 {
70 return Err(MlxError::InvalidArgument(
71 "gather_f32: all dimensions must be > 0".into(),
72 ));
73 }
74
75 let src_bytes = src_rows as usize * row_width as usize * 4;
76 if src.byte_len() < src_bytes {
77 return Err(MlxError::InvalidArgument(format!(
78 "gather_f32: src buffer too small: need {} bytes, have {}",
79 src_bytes,
80 src.byte_len()
81 )));
82 }
83 let idx_bytes = n_indices as usize * 4;
84 if indices.byte_len() < idx_bytes {
85 return Err(MlxError::InvalidArgument(format!(
86 "gather_f32: indices buffer too small: need {} bytes, have {}",
87 idx_bytes,
88 indices.byte_len()
89 )));
90 }
91 let out_bytes = n_indices as usize * row_width as usize * 4;
92 if output.byte_len() < out_bytes {
93 return Err(MlxError::InvalidArgument(format!(
94 "gather_f32: output buffer too small: need {} bytes, have {}",
95 out_bytes,
96 output.byte_len()
97 )));
98 }
99
100 let pipeline = registry.get_pipeline("gather_f32", device)?;
101
102 let gpu_params = GpuGatherParams {
103 row_width,
104 n_indices,
105 src_rows,
106 };
107
108 let grid = MTLSize::new(row_width as u64, n_indices as u64, 1);
109 let tg = MTLSize::new(std::cmp::min(256, row_width as u64), 1, 1);
110
111 encode_with_args(
112 encoder,
113 pipeline,
114 &[
115 (0, KernelArg::Buffer(src)),
116 (1, KernelArg::Buffer(indices)),
117 (2, KernelArg::Buffer(output)),
118 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
119 ],
120 grid,
121 tg,
122 );
123
124 Ok(())
125}