use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the `top_k_f32` kernel, embedded at compile time
/// via `include_str!` from `../shaders/top_k.metal` (relative to this file).
pub static TOP_K_SHADER_SOURCE: &str = include_str!("../shaders/top_k.metal");

/// Adds the top-k shader source to `registry` under the name `"top_k_f32"`.
///
/// [`dispatch_top_k_f32`] later asks the registry for a pipeline under that
/// same name, so this must be called before dispatching.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "top_k_f32";
    registry.register_source(kernel_name, TOP_K_SHADER_SOURCE);
}
/// Encodes the `top_k_f32` kernel over the first `n_elements` values of `input`.
///
/// Buffers are bound as: 0 = `input`, 1 = `out_indices`, 2 = `out_values`,
/// 3 = `params_buf`. Two threadgroup-memory scratch arrays (indices 0 and 1)
/// are sized from `k`. The grid is a single 1x1x1 threadgroup of 32, 64 or
/// 128 threads, depending on `k`.
///
/// `n_elements` and `k` are only used here for validation and sizing; they
/// are not bound to the kernel directly.
/// NOTE(review): assumes `params_buf` already encodes the same `n_elements`
/// and `k` — confirm against the caller and the shader's params struct.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when any of the following holds:
/// - `n_elements` or `k` is zero;
/// - `k` exceeds 128;
/// - `input` holds fewer than `n_elements` elements;
/// - `out_indices` or `out_values` holds fewer than `k` elements.
///
/// Errors from `registry.get_pipeline` are propagated unchanged.
pub fn dispatch_top_k_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    out_indices: &MlxBuffer,
    out_values: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
    k: u32,
) -> Result<()> {
    // Largest supported k; the error text below says it mirrors MAX_K in the shader.
    const MAX_K: u32 = 128;

    let invalid = |msg: String| -> Result<()> { Err(MlxError::InvalidArgument(msg)) };

    if n_elements == 0 || k == 0 {
        return invalid(String::from("top_k_f32: n_elements and k must be > 0"));
    }
    if k > MAX_K {
        return invalid(format!(
            "top_k_f32: k ({}) must be <= 128 (MAX_K in shader)",
            k
        ));
    }

    let input_len = input.element_count();
    if input_len < n_elements as usize {
        return invalid(format!(
            "top_k_f32: input element count {} < n_elements {}",
            input_len, n_elements
        ));
    }

    // Each output must have room for k results; out_indices is checked first.
    for (label, buf) in [("out_indices", out_indices), ("out_values", out_values)] {
        let len = buf.element_count();
        if len < k as usize {
            return invalid(format!("top_k_f32: {} ({}) < k ({})", label, len, k));
        }
    }

    let pipeline = registry.get_pipeline("top_k_f32", device)?;

    // Shrink the threadgroup as k grows so threads_per_group * k stays at
    // 4096. Each scratch array is then at most 16 KiB, 32 KiB combined.
    // Presumably this targets Metal's 32 KiB threadgroup-memory limit — confirm.
    let threads_per_group: u64 = if k <= 32 {
        128
    } else if k <= 64 {
        64
    } else {
        32
    };
    let slots = threads_per_group * u64::from(k);
    let value_scratch_bytes = slots * std::mem::size_of::<f32>() as u64;
    let index_scratch_bytes = slots * std::mem::size_of::<u32>() as u64;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, out_indices), (2, out_values), (3, params_buf)],
        &[(0, value_scratch_bytes), (1, index_scratch_bytes)],
        MTLSize::new(1, 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );
    Ok(())
}