use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{encode_with_args, KernelArg};
/// Metal shader source for the gather benchmark kernels.
///
/// Embedded into the binary at compile time from `../shaders/gather_bench.metal`
/// (path resolved relative to this source file by `include_str!`). `register`
/// registers this same source under both `gather_bench_nibble` and
/// `gather_bench_f16_seq`.
pub static GATHER_BENCH_SHADER_SOURCE: &str = include_str!("../shaders/gather_bench.metal");
/// Registers the gather benchmark kernels with `registry`.
///
/// Both kernel names (`gather_bench_nibble` and `gather_bench_f16_seq`) are
/// registered against the same shader source, [`GATHER_BENCH_SHADER_SOURCE`],
/// in that order.
pub fn register(registry: &mut KernelRegistry) {
    // Every kernel in this module lives in the one shared .metal file.
    const KERNEL_NAMES: [&str; 2] = ["gather_bench_nibble", "gather_bench_f16_seq"];
    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, GATHER_BENCH_SHADER_SOURCE);
    }
}
/// Encodes one dispatch of the `gather_bench_nibble` kernel onto `encoder`.
///
/// Argument table bound for the kernel:
/// - buffer 0: `packed`
/// - buffer 1: `centroids`
/// - index 2: `capacity` (inline `u32` bytes)
/// - index 3: `head_dim` (inline `u32` bytes)
/// - buffer 4: `out`
///
/// The grid is `head_dim` threads wide (x) by `capacity` rows (y). Threadgroups
/// are `min(256, head_dim)` threads along x.
///
/// NOTE(review): the even-`head_dim` requirement suggests two 4-bit values are
/// packed per byte of `packed`. Confirm against `gather_bench.metal`.
///
/// If `capacity == 0` or `head_dim == 0`, nothing is encoded and `Ok(())` is
/// returned (after the `head_dim` parity check below).
///
/// # Errors
///
/// - [`MlxError::InvalidArgument`] if `head_dim` is odd. This is checked even
///   when there is no work to encode.
/// - Any error returned by `KernelRegistry::get_pipeline`.
// NOTE(review): 7 parameters does not exceed clippy's default
// `too-many-arguments-threshold` (7); this allow only matters if the project
// configures a lower threshold.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_gather_nibble(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    packed: &MlxBuffer,
    centroids: &MlxBuffer,
    out: &MlxBuffer,
    capacity: u32,
    head_dim: u32,
) -> Result<()> {
    // Validate before the empty-work short circuit so that an invalid
    // `head_dim` is rejected consistently, independent of `capacity`.
    if head_dim % 2 != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "gather_bench_nibble: head_dim must be even, got {}",
            head_dim
        )));
    }
    if capacity == 0 || head_dim == 0 {
        return Ok(());
    }
    let pipeline = registry.get_pipeline("gather_bench_nibble", device)?;
    // Scalars are passed inline as native-endian bytes.
    let capacity_bytes = capacity.to_ne_bytes();
    let head_dim_bytes = head_dim.to_ne_bytes();
    // Cap the threadgroup width at 256 threads; small head dims use one group per row.
    let tg_x = std::cmp::min(256, u64::from(head_dim));
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(packed)),
            (1, KernelArg::Buffer(centroids)),
            (2, KernelArg::Bytes(&capacity_bytes)),
            (3, KernelArg::Bytes(&head_dim_bytes)),
            (4, KernelArg::Buffer(out)),
        ],
        MTLSize::new(u64::from(head_dim), u64::from(capacity), 1),
        MTLSize::new(tg_x, 1, 1),
    );
    Ok(())
}
/// Encodes one dispatch of the `gather_bench_f16_seq` kernel onto `encoder`.
///
/// Argument table bound for the kernel:
/// - buffer 0: `cache`
/// - index 1: `capacity` (inline `u32` bytes)
/// - index 2: `head_dim` (inline `u32` bytes)
/// - buffer 3: `out`
///
/// The grid is `head_dim` threads wide (x) by `capacity` rows (y). Threadgroups
/// hold at most 256 threads along x.
///
/// Returns `Ok(())` without encoding anything when either dimension is zero.
///
/// # Errors
///
/// Propagates any error returned by `KernelRegistry::get_pipeline`.
pub fn dispatch_gather_f16_seq(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    cache: &MlxBuffer,
    out: &MlxBuffer,
    capacity: u32,
    head_dim: u32,
) -> Result<()> {
    let nothing_to_encode = capacity == 0 || head_dim == 0;
    if nothing_to_encode {
        return Ok(());
    }

    let pipeline = registry.get_pipeline("gather_bench_f16_seq", device)?;

    // Scalar kernel arguments travel inline as raw native-endian bytes.
    let capacity_arg = capacity.to_ne_bytes();
    let head_dim_arg = head_dim.to_ne_bytes();

    let (grid_width, grid_rows) = (u64::from(head_dim), u64::from(capacity));
    let grid = MTLSize::new(grid_width, grid_rows, 1);
    let threadgroup = MTLSize::new(grid_width.min(256), 1, 1);

    let bindings = [
        (0, KernelArg::Buffer(cache)),
        (1, KernelArg::Bytes(&capacity_arg)),
        (2, KernelArg::Bytes(&head_dim_arg)),
        (3, KernelArg::Buffer(out)),
    ];
    encode_with_args(encoder, pipeline, &bindings, grid, threadgroup);

    Ok(())
}