// mlx_native/ops/argsort.rs

//! GPU-accelerated argsort (descending) for MoE top-K routing.
//!
//! Sorts indices by value in descending order using a bitonic sort kernel.
//! Each row is sorted inside a single threadgroup, so a row may hold at most
//! 256 elements (for MoE routing: up to 256 experts per row).
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::encoder::CommandEncoder;
10use crate::error::{MlxError, Result};
11use crate::kernel_registry::KernelRegistry;
12
13use super::encode_helpers::{as_bytes, encode_threadgroups_with_args, KernelArg};
14
15/// MSL source for the argsort kernel (embedded at compile time).
16pub static ARGSORT_SHADER_SOURCE: &str = include_str!("../shaders/argsort.metal");
17
18/// Register argsort shader source with the given kernel registry.
19pub fn register(registry: &mut KernelRegistry) {
20    registry.register_source("argsort_desc_f32", ARGSORT_SHADER_SOURCE);
21}
22
23/// MSL-compatible params struct for argsort.
24///
25/// Must match `ArgsortParams` in `argsort.metal`.
26#[repr(C)]
27#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
28struct GpuArgsortParams {
29    row_len: u32,
30    batch_size: u32,
31}
32
33/// Dispatch an argsort (descending) operation on the GPU.
34///
35/// For each row of `input`, produces a permutation of indices `[0..row_len)`
36/// such that `input[row][output[row][0]] >= input[row][output[row][1]] >= ...`.
37///
38/// # Arguments
39///
40/// * `encoder`    - Command encoder to record the dispatch into.
41/// * `registry`   - Kernel registry (must have `argsort_desc_f32` registered).
42/// * `device`     - Metal device for pipeline compilation.
43/// * `input`      - Input buffer of shape `[batch_size, row_len]` (f32).
44/// * `output`     - Output buffer of shape `[batch_size, row_len]` (u32) — sorted indices.
45/// * `batch_size` - Number of rows.
46/// * `row_len`    - Number of elements per row (must be <= 256).
47///
48/// # Errors
49///
50/// Returns `MlxError::InvalidArgument` if:
51/// - `row_len` is 0 or > 256
52/// - `batch_size` is 0
53/// - Buffers are too small
54#[allow(clippy::too_many_arguments)]
55pub fn dispatch_argsort_desc_f32(
56    encoder: &mut CommandEncoder,
57    registry: &mut KernelRegistry,
58    device: &metal::DeviceRef,
59    input: &MlxBuffer,
60    output: &MlxBuffer,
61    batch_size: u32,
62    row_len: u32,
63) -> Result<()> {
64    if row_len == 0 {
65        return Err(MlxError::InvalidArgument(
66            "argsort_desc_f32: row_len must be > 0".into(),
67        ));
68    }
69    if row_len > 256 {
70        return Err(MlxError::InvalidArgument(format!(
71            "argsort_desc_f32: row_len {} exceeds max 256 (shared memory limit)",
72            row_len
73        )));
74    }
75    if batch_size == 0 {
76        return Err(MlxError::InvalidArgument(
77            "argsort_desc_f32: batch_size must be > 0".into(),
78        ));
79    }
80
81    let total = batch_size as usize * row_len as usize;
82    let input_bytes = total * 4; // f32
83    if input.byte_len() < input_bytes {
84        return Err(MlxError::InvalidArgument(format!(
85            "argsort_desc_f32: input buffer too small: need {} bytes, have {}",
86            input_bytes,
87            input.byte_len()
88        )));
89    }
90    let output_bytes = total * 4; // u32
91    if output.byte_len() < output_bytes {
92        return Err(MlxError::InvalidArgument(format!(
93            "argsort_desc_f32: output buffer too small: need {} bytes, have {}",
94            output_bytes,
95            output.byte_len()
96        )));
97    }
98
99    let pipeline = registry.get_pipeline("argsort_desc_f32", device)?;
100
101    let gpu_params = GpuArgsortParams {
102        row_len,
103        batch_size,
104    };
105
106    // One threadgroup per row, threadgroup size = next power of two of row_len.
107    let tg_size = std::cmp::min(256, row_len.next_power_of_two()) as u64;
108
109    encode_threadgroups_with_args(
110        encoder,
111        pipeline,
112        &[
113            (0, KernelArg::Buffer(input)),
114            (1, KernelArg::Buffer(output)),
115            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
116        ],
117        MTLSize::new(batch_size as u64, 1, 1),
118        MTLSize::new(tg_size, 1, 1),
119    );
120
121    Ok(())
122}