use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};
/// Text of the argsort shader, embedded into the binary at compile time from
/// `../shaders/argsort.metal` (path relative to this source file).
pub static ARGSORT_SHADER_SOURCE: &str = include_str!("../shaders/argsort.metal");
/// Registers the argsort shader source with `registry` under the kernel name
/// `argsort_desc_f32`.
///
/// The same name is later used to look the pipeline up when dispatching, so the
/// two must stay in sync.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "argsort_desc_f32";
    registry.register_source(kernel_name, ARGSORT_SHADER_SOURCE);
}
/// Parameter block that the dispatcher passes to the `argsort_desc_f32` kernel
/// as raw bytes.
///
/// `#[repr(C)]` keeps the two fields in declaration order with C layout, and the
/// `bytemuck` derives mark the type as plain data.
///
/// NOTE(review): presumably this layout must mirror the parameter struct declared
/// in `argsort.metal` — confirm against the shader before reordering, adding, or
/// retyping fields.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuArgsortParams {
/// Number of elements in one row, as supplied by the caller of the dispatcher.
row_len: u32,
/// Number of rows, as supplied by the caller of the dispatcher.
batch_size: u32,
}
/// Validates the arguments and encodes one dispatch of the `argsort_desc_f32`
/// kernel onto `encoder`.
///
/// Judging by the kernel name, this is a per-row descending argsort over `f32`
/// values; the sorting itself lives in the shader and is not visible here.
///
/// # Arguments
///
/// * `encoder` - command encoder the dispatch is recorded on.
/// * `registry` - kernel registry used to obtain the `argsort_desc_f32` pipeline.
/// * `device` - Metal device handed to the registry for the pipeline lookup.
/// * `input` - source buffer; must hold at least `batch_size * row_len * 4` bytes.
/// * `output` - destination buffer; must hold at least `batch_size * row_len * 4` bytes.
/// * `batch_size` - number of rows; must be non-zero.
/// * `row_len` - elements per row; must be in `1..=256`.
///
/// # Dispatch shape
///
/// Two sizes are passed to `encode_threadgroups_with_args`: `(batch_size, 1, 1)`
/// followed by `(t, 1, 1)`, where `t` is `row_len` rounded up to a power of two
/// and capped at 256. Presumably these are the threadgroup count and the threads
/// per threadgroup, i.e. one threadgroup per row — confirm against the helper.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when, checked in this order:
/// `row_len` is zero, `row_len` exceeds 256, `batch_size` is zero, `input` is
/// too small, or `output` is too small. Any error from the registry's pipeline
/// lookup is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_argsort_desc_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    batch_size: u32,
    row_len: u32,
) -> Result<()> {
    // Upper bound on the row length; also the cap on the threadgroup width.
    const MAX_ROW_LEN: u32 = 256;
    // Both buffers are sized at four bytes per element.
    const BYTES_PER_ELEMENT: usize = 4;

    // Shape validation. The row-length checks come before the batch check so
    // that the reported error is the same whichever combination is invalid.
    match row_len {
        0 => {
            return Err(MlxError::InvalidArgument(
                "argsort_desc_f32: row_len must be > 0".into(),
            ));
        }
        len if len > MAX_ROW_LEN => {
            return Err(MlxError::InvalidArgument(format!(
                "argsort_desc_f32: row_len {} exceeds max 256 (shared memory limit)",
                len
            )));
        }
        _ => {}
    }
    if batch_size == 0 {
        return Err(MlxError::InvalidArgument(
            "argsort_desc_f32: batch_size must be > 0".into(),
        ));
    }

    // Capacity validation: input first, then output, against the same requirement.
    let element_count = batch_size as usize * row_len as usize;
    let required_bytes = element_count * BYTES_PER_ELEMENT;

    let input_capacity = input.byte_len();
    if input_capacity < required_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "argsort_desc_f32: input buffer too small: need {} bytes, have {}",
            required_bytes, input_capacity
        )));
    }

    let output_capacity = output.byte_len();
    if output_capacity < required_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "argsort_desc_f32: output buffer too small: need {} bytes, have {}",
            required_bytes, output_capacity
        )));
    }

    // Power-of-two width covering the row; `row_len <= 256` already bounds it,
    // the explicit cap just makes the limit visible at the call site.
    let threads_per_group = u64::from(row_len.next_power_of_two().min(MAX_ROW_LEN));
    let group_count = u64::from(batch_size);

    let gpu_params = GpuArgsortParams {
        row_len,
        batch_size,
    };
    let pipeline = registry.get_pipeline("argsort_desc_f32", device)?;

    // Binding slots: 0 = input buffer, 1 = output buffer, 2 = parameter bytes.
    encode_threadgroups_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(input)),
            (1, KernelArg::Buffer(output)),
            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        MTLSize::new(group_count, 1, 1),
        MTLSize::new(threads_per_group, 1, 1),
    );

    Ok(())
}