use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the `argmax_f32` kernel, embedded into the binary at
/// compile time from `../shaders/argmax.metal` (path relative to this source file).
pub static ARGMAX_SHADER_SOURCE: &str = include_str!("../shaders/argmax.metal");
/// Adds the argmax shader source to `registry` under the kernel name
/// `"argmax_f32"`.
///
/// This is the same name that [`dispatch_argmax_f32`] later passes to
/// `get_pipeline`, so this must be called before any dispatch.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "argmax_f32";
    registry.register_source(kernel_name, ARGMAX_SHADER_SOURCE);
}
/// Encodes a single-threadgroup dispatch of the `argmax_f32` kernel on `encoder`.
///
/// The kernel reads `n_elements` values from `input` and writes to `out_index`
/// and `out_value` (presumably the position and value of the largest element —
/// confirm against `argmax.metal`).
///
/// Buffer bindings: `0 = input`, `1 = out_index`, `2 = out_value`,
/// `3 = params_buf`. `params_buf` is passed through unchecked. It presumably
/// carries `n_elements` for the shader; TODO confirm its layout against the
/// shader.
///
/// Shared-memory bindings:
/// - index 0 holds one `f32` per thread;
/// - index 1 holds one `u32` per thread.
///
/// The launch uses exactly one threadgroup of `argmax_threadgroup_size(n_elements)`
/// threads.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] if:
/// - `n_elements` is 0;
/// - `input.element_count()` differs from `n_elements`;
/// - `out_index` or `out_value` has no elements.
///
/// Any error from `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_argmax_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    out_index: &MlxBuffer,
    out_value: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "argmax_f32: n_elements must be > 0".into(),
        ));
    }
    if input.element_count() != n_elements as usize {
        return Err(MlxError::InvalidArgument(format!(
            "argmax_f32: input element count {} != n_elements {}",
            input.element_count(),
            n_elements
        )));
    }
    if out_index.element_count() < 1 {
        return Err(MlxError::InvalidArgument(
            "argmax_f32: out_index must have at least 1 element".into(),
        ));
    }
    if out_value.element_count() < 1 {
        return Err(MlxError::InvalidArgument(
            "argmax_f32: out_value must have at least 1 element".into(),
        ));
    }

    let pipeline = registry.get_pipeline("argmax_f32", device)?;

    let tg_size = argmax_threadgroup_size(n_elements);
    // One f32 slot and one u32 slot per thread in the threadgroup.
    let float_shared = tg_size * std::mem::size_of::<f32>() as u64;
    let uint_shared = tg_size * std::mem::size_of::<u32>() as u64;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, input),
            (1, out_index),
            (2, out_value),
            (3, params_buf),
        ],
        &[(0, float_shared), (1, uint_shared)],
        // A single threadgroup handles the whole input.
        MTLSize::new(1, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}

/// Upper bound on the threadgroup width requested by [`dispatch_argmax_f32`].
///
/// NOTE(review): this is assumed not to exceed the pipeline's
/// `maxTotalThreadsPerThreadgroup` for this kernel. Confirm on target hardware.
const ARGMAX_MAX_THREADGROUP_SIZE: u32 = 1024;

/// Threadgroup width for an argmax over `n_elements` values.
///
/// The result is the next power of two that is at least `n_elements`, capped at
/// [`ARGMAX_MAX_THREADGROUP_SIZE`].
///
/// The cap is applied *before* rounding up. That way `u32::next_power_of_two`
/// never sees a value above `2^31`. For such values it would panic in debug
/// builds and return 0 in release builds, which would launch an empty
/// threadgroup.
fn argmax_threadgroup_size(n_elements: u32) -> u64 {
    u64::from(
        n_elements
            .min(ARGMAX_MAX_THREADGROUP_SIZE)
            .next_power_of_two(),
    )
}