Skip to main content

mlx_native/ops/
top_k.rs

1//! GPU top-K dispatch — returns the K largest elements of a float array.
2//!
3//! Used by the Q8 lm_head rerank path to avoid a full 1 MB logits readback.
4//! After Q8 matmul writes the full vocabulary of logits, this kernel selects
5//! the top-K on GPU; only K * 8 bytes of (index, value) pairs come back to
6//! CPU for exact F32 reranking.
7//!
//! Output order is NOT guaranteed — callers that need sorted order should
//! sort themselves. The rerank path does not need sorted output: it only
//! takes the argmax over the reranked logits.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19pub static TOP_K_SHADER_SOURCE: &str = include_str!("../shaders/top_k.metal");
20
21pub fn register(registry: &mut KernelRegistry) {
22    registry.register_source("top_k_f32", TOP_K_SHADER_SOURCE);
23}
24
25/// Dispatch a top-K selection on the GPU.
26///
27/// # Arguments
28///
29/// * `encoder`     - Command encoder to record the dispatch into.
30/// * `registry`    - Kernel registry (must have `top_k_f32` registered).
31/// * `device`      - Metal device for pipeline compilation.
32/// * `input`       - Input buffer `[n_elements]` (f32).
33/// * `out_indices` - Output buffer `[k]` (u32) — indices of top-K elements.
34/// * `out_values`  - Output buffer `[k]` (f32) — values of top-K elements.
35/// * `params_buf`  - Params buffer `[2]` (u32) — `[n_elements, k]`.
36/// * `n_elements`  - Number of elements in `input`.
37/// * `k`           - Number of top elements to return. Must be <= 128.
38///
39/// # Errors
40///
41/// Returns `MlxError::InvalidArgument` if `n_elements == 0`, `k == 0`,
42/// `k > 128`, or buffer sizes don't match.
43pub fn dispatch_top_k_f32(
44    encoder: &mut CommandEncoder,
45    registry: &mut KernelRegistry,
46    device: &metal::DeviceRef,
47    input: &MlxBuffer,
48    out_indices: &MlxBuffer,
49    out_values: &MlxBuffer,
50    params_buf: &MlxBuffer,
51    n_elements: u32,
52    k: u32,
53) -> Result<()> {
54    if n_elements == 0 || k == 0 {
55        return Err(MlxError::InvalidArgument(
56            "top_k_f32: n_elements and k must be > 0".into(),
57        ));
58    }
59    if k > 128 {
60        return Err(MlxError::InvalidArgument(format!(
61            "top_k_f32: k ({}) must be <= 128 (MAX_K in shader)",
62            k
63        )));
64    }
65    if input.element_count() < n_elements as usize {
66        return Err(MlxError::InvalidArgument(format!(
67            "top_k_f32: input element count {} < n_elements {}",
68            input.element_count(),
69            n_elements
70        )));
71    }
72    if out_indices.element_count() < k as usize {
73        return Err(MlxError::InvalidArgument(format!(
74            "top_k_f32: out_indices ({}) < k ({})",
75            out_indices.element_count(),
76            k
77        )));
78    }
79    if out_values.element_count() < k as usize {
80        return Err(MlxError::InvalidArgument(format!(
81            "top_k_f32: out_values ({}) < k ({})",
82            out_values.element_count(),
83            k
84        )));
85    }
86
87    let pipeline = registry.get_pipeline("top_k_f32", device)?;
88
89    // tg_size choice: threadgroup shared memory on Apple Silicon is ~32 KB.
90    // Shared = tg_size * K * (4 + 4) bytes.
91    //   K=64  → tg_size <= 64  (32 KB)
92    //   K=32  → tg_size <= 128
93    //   K=128 → tg_size <= 32
94    // Use tg_size=32 for K up to 128 to be safe across Apple generations.
95    let tg_size: u64 = match k {
96        1..=32 => 128,
97        33..=64 => 64,
98        _ => 32,
99    };
100
101    // Shared memory: tg_size * K each for values (float) and indices (uint).
102    let float_shared = tg_size * (k as u64) * 4;
103    let uint_shared  = tg_size * (k as u64) * 4;
104
105    encoder.encode_threadgroups_with_shared(
106        pipeline,
107        &[
108            (0, input),
109            (1, out_indices),
110            (2, out_values),
111            (3, params_buf),
112        ],
113        &[(0, float_shared), (1, uint_shared)],
114        MTLSize::new(1, 1, 1),          // single threadgroup
115        MTLSize::new(tg_size, 1, 1),
116    );
117
118    Ok(())
119}