use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
/// Metal shader source for the gather kernel, embedded into the binary at
/// compile time from `../shaders/gather.metal` (path relative to this file).
pub static GATHER_SHADER_SOURCE: &str = include_str!("../shaders/gather.metal");
/// Registers the embedded gather shader source with `registry` under the
/// kernel name `gather_f32`, the same name [`dispatch_gather_f32`] later uses
/// to look up its pipeline.
pub fn register(registry: &mut KernelRegistry) {
    // Name under which the source is registered; must stay in sync with the
    // name passed to `get_pipeline` at dispatch time.
    const KERNEL_NAME: &str = "gather_f32";

    registry.register_source(KERNEL_NAME, GATHER_SHADER_SOURCE);
}
/// Parameter block handed to the `gather_f32` kernel as raw bytes (bound at
/// argument index 3 by `dispatch_gather_f32`).
///
/// `#[repr(C)]` fixes the field order and layout, and the `Pod`/`Zeroable`
/// derives allow the struct to be viewed as a plain byte slice.
// NOTE(review): field order and types presumably have to mirror the parameter
// struct declared in `gather.metal` — confirm against the shader before
// reordering or adding fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuGatherParams {
// Number of elements per row; the host-side size checks treat each element
// as 4 bytes.
row_width: u32,
// Number of entries in the indices buffer.
n_indices: u32,
// Number of rows in the source buffer.
src_rows: u32,
}
/// Validates the buffers for a `gather_f32` dispatch and encodes it.
///
/// The kernel is bound with `src` at index 0, `indices` at index 1, `output`
/// at index 2 and a [`GpuGatherParams`] block at index 3. The grid passed to
/// the encoder is `row_width x n_indices x 1`, with a threadgroup of
/// `min(256, row_width) x 1 x 1`.
///
/// NOTE(review): the shader itself is not visible from this file; presumably
/// it copies row `indices[i]` of `src` into row `i` of `output` — confirm
/// against `gather.metal`.
///
/// # Parameters
///
/// * `encoder` - command encoder the dispatch is recorded into.
/// * `registry` / `device` - used to obtain the `gather_f32` pipeline.
/// * `src` - source buffer; must hold at least `src_rows * row_width * 4` bytes.
/// * `indices` - index buffer; must hold at least `n_indices * 4` bytes.
/// * `output` - destination buffer; must hold at least
///   `n_indices * row_width * 4` bytes.
/// * `src_rows`, `row_width`, `n_indices` - dimensions; all must be non-zero.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when any dimension is zero, when a
/// required byte size does not fit in `usize`, or when a buffer is smaller
/// than required. Errors from `KernelRegistry::get_pipeline` are propagated
/// unchanged.
// The arguments map one-to-one onto the kernel bindings plus its three
// dimensions; bundling them into a struct would break existing callers, so the
// lint is silenced instead.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gather_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    src: &MlxBuffer,
    indices: &MlxBuffer,
    output: &MlxBuffer,
    src_rows: u32,
    row_width: u32,
    n_indices: u32,
) -> Result<()> {
    /// Bytes per element assumed by the host-side size checks, for both the
    /// f32 data and the index entries.
    const ELEM_BYTES: usize = 4;

    /// Ensures `buffer` can hold `ELEM_BYTES * product(dims)` bytes.
    ///
    /// The product is computed with checked arithmetic: two large `u32`
    /// dimensions times 4 can exceed `usize::MAX`, and a wrapped result would
    /// let an undersized buffer slip past the check.
    fn require_capacity(label: &str, buffer: &MlxBuffer, dims: &[u32]) -> Result<()> {
        let needed = dims
            .iter()
            // `u32 -> usize` is lossless on the 32/64-bit targets Metal supports.
            .try_fold(ELEM_BYTES, |acc, &dim| acc.checked_mul(dim as usize))
            .ok_or_else(|| {
                MlxError::InvalidArgument(format!(
                    "gather_f32: {} byte size overflows usize",
                    label
                ))
            })?;

        let available = buffer.byte_len();
        if available < needed {
            return Err(MlxError::InvalidArgument(format!(
                "gather_f32: {} buffer too small: need {} bytes, have {}",
                label, needed, available
            )));
        }
        Ok(())
    }

    if src_rows == 0 || row_width == 0 || n_indices == 0 {
        return Err(MlxError::InvalidArgument(
            "gather_f32: all dimensions must be > 0".into(),
        ));
    }

    // Checked in the same order as before so the first reported problem is
    // unchanged for callers that match on the message.
    require_capacity("src", src, &[src_rows, row_width])?;
    require_capacity("indices", indices, &[n_indices])?;
    require_capacity("output", output, &[n_indices, row_width])?;

    let pipeline = registry.get_pipeline("gather_f32", device)?;

    let gpu_params = GpuGatherParams {
        row_width,
        n_indices,
        src_rows,
    };

    // One grid column per row element, one grid row per index entry.
    let grid = MTLSize::new(u64::from(row_width), u64::from(n_indices), 1);
    // Cap the threadgroup width at 256 while never exceeding the grid width.
    let tg = MTLSize::new(u64::from(row_width).min(256), 1, 1);

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(src)),
            (1, KernelArg::Buffer(indices)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}