// mlx_native/ops/argsort.rs

use metal::MTLSize;

use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};

pub static ARGSORT_SHADER_SOURCE: &str = include_str!("../shaders/argsort.metal");
18pub fn register(registry: &mut KernelRegistry) {
20 registry.register_source("argsort_desc_f32", ARGSORT_SHADER_SOURCE);
21}
22
23#[repr(C)]
27#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
28struct GpuArgsortParams {
29 row_len: u32,
30 batch_size: u32,
31}
32
33#[allow(clippy::too_many_arguments)]
55pub fn dispatch_argsort_desc_f32(
56 encoder: &mut CommandEncoder,
57 registry: &mut KernelRegistry,
58 device: &metal::DeviceRef,
59 input: &MlxBuffer,
60 output: &MlxBuffer,
61 batch_size: u32,
62 row_len: u32,
63) -> Result<()> {
64 if row_len == 0 {
65 return Err(MlxError::InvalidArgument(
66 "argsort_desc_f32: row_len must be > 0".into(),
67 ));
68 }
69 if row_len > 256 {
70 return Err(MlxError::InvalidArgument(format!(
71 "argsort_desc_f32: row_len {} exceeds max 256 (shared memory limit)",
72 row_len
73 )));
74 }
75 if batch_size == 0 {
76 return Err(MlxError::InvalidArgument(
77 "argsort_desc_f32: batch_size must be > 0".into(),
78 ));
79 }
80
81 let total = batch_size as usize * row_len as usize;
82 let input_bytes = total * 4; if input.byte_len() < input_bytes {
84 return Err(MlxError::InvalidArgument(format!(
85 "argsort_desc_f32: input buffer too small: need {} bytes, have {}",
86 input_bytes,
87 input.byte_len()
88 )));
89 }
90 let output_bytes = total * 4; if output.byte_len() < output_bytes {
92 return Err(MlxError::InvalidArgument(format!(
93 "argsort_desc_f32: output buffer too small: need {} bytes, have {}",
94 output_bytes,
95 output.byte_len()
96 )));
97 }
98
99 let pipeline = registry.get_pipeline("argsort_desc_f32", device)?;
100
101 let gpu_params = GpuArgsortParams {
102 row_len,
103 batch_size,
104 };
105
106 let tg_size = std::cmp::min(256, row_len.next_power_of_two()) as u64;
108
109 encode_threadgroups_with_args(
110 encoder,
111 pipeline,
112 &[
113 (0, KernelArg::Buffer(input)),
114 (1, KernelArg::Buffer(output)),
115 (2, KernelArg::Bytes(as_bytes(&gpu_params))),
116 ],
117 MTLSize::new(batch_size as u64, 1, 1),
118 MTLSize::new(tg_size, 1, 1),
119 );
120
121 Ok(())
122}