mlx_native/ops/top_k.rs
1//! GPU top-K dispatch — returns the K largest elements of a float array.
2//!
3//! Used by the Q8 lm_head rerank path to avoid a full 1 MB logits readback.
4//! After Q8 matmul writes the full vocabulary of logits, this kernel selects
5//! the top-K on GPU; only K * 8 bytes of (index, value) pairs come back to
6//! CPU for exact F32 reranking.
7//!
8//! Output order is NOT guaranteed — callers that need sorted order should
9//! sort themselves. The rerank path sorts implicitly by picking argmax over
10//! the reranked logits.
11
12use metal::MTLSize;
13
14use crate::buffer::MlxBuffer;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
19pub static TOP_K_SHADER_SOURCE: &str = include_str!("../shaders/top_k.metal");
20
21pub fn register(registry: &mut KernelRegistry) {
22 registry.register_source("top_k_f32", TOP_K_SHADER_SOURCE);
23}
24
25/// Dispatch a top-K selection on the GPU.
26///
27/// # Arguments
28///
29/// * `encoder` - Command encoder to record the dispatch into.
30/// * `registry` - Kernel registry (must have `top_k_f32` registered).
31/// * `device` - Metal device for pipeline compilation.
32/// * `input` - Input buffer `[n_elements]` (f32).
33/// * `out_indices` - Output buffer `[k]` (u32) — indices of top-K elements.
34/// * `out_values` - Output buffer `[k]` (f32) — values of top-K elements.
35/// * `params_buf` - Params buffer `[2]` (u32) — `[n_elements, k]`.
36/// * `n_elements` - Number of elements in `input`.
37/// * `k` - Number of top elements to return. Must be <= 128.
38///
39/// # Errors
40///
41/// Returns `MlxError::InvalidArgument` if `n_elements == 0`, `k == 0`,
42/// `k > 128`, or buffer sizes don't match.
43pub fn dispatch_top_k_f32(
44 encoder: &mut CommandEncoder,
45 registry: &mut KernelRegistry,
46 device: &metal::DeviceRef,
47 input: &MlxBuffer,
48 out_indices: &MlxBuffer,
49 out_values: &MlxBuffer,
50 params_buf: &MlxBuffer,
51 n_elements: u32,
52 k: u32,
53) -> Result<()> {
54 if n_elements == 0 || k == 0 {
55 return Err(MlxError::InvalidArgument(
56 "top_k_f32: n_elements and k must be > 0".into(),
57 ));
58 }
59 if k > 128 {
60 return Err(MlxError::InvalidArgument(format!(
61 "top_k_f32: k ({}) must be <= 128 (MAX_K in shader)",
62 k
63 )));
64 }
65 if input.element_count() < n_elements as usize {
66 return Err(MlxError::InvalidArgument(format!(
67 "top_k_f32: input element count {} < n_elements {}",
68 input.element_count(),
69 n_elements
70 )));
71 }
72 if out_indices.element_count() < k as usize {
73 return Err(MlxError::InvalidArgument(format!(
74 "top_k_f32: out_indices ({}) < k ({})",
75 out_indices.element_count(),
76 k
77 )));
78 }
79 if out_values.element_count() < k as usize {
80 return Err(MlxError::InvalidArgument(format!(
81 "top_k_f32: out_values ({}) < k ({})",
82 out_values.element_count(),
83 k
84 )));
85 }
86
87 let pipeline = registry.get_pipeline("top_k_f32", device)?;
88
89 // tg_size choice: threadgroup shared memory on Apple Silicon is ~32 KB.
90 // Shared = tg_size * K * (4 + 4) bytes.
91 // K=64 → tg_size <= 64 (32 KB)
92 // K=32 → tg_size <= 128
93 // K=128 → tg_size <= 32
94 // Use tg_size=32 for K up to 128 to be safe across Apple generations.
95 let tg_size: u64 = match k {
96 1..=32 => 128,
97 33..=64 => 64,
98 _ => 32,
99 };
100
101 // Shared memory: tg_size * K each for values (float) and indices (uint).
102 let float_shared = tg_size * (k as u64) * 4;
103 let uint_shared = tg_size * (k as u64) * 4;
104
105 encoder.encode_threadgroups_with_shared(
106 pipeline,
107 &[
108 (0, input),
109 (1, out_indices),
110 (2, out_values),
111 (3, params_buf),
112 ],
113 &[(0, float_shared), (1, uint_shared)],
114 MTLSize::new(1, 1, 1), // single threadgroup
115 MTLSize::new(tg_size, 1, 1),
116 );
117
118 Ok(())
119}