mlx_native/ops/argmax.rs
1//! Greedy argmax GPU dispatch — finds the index of the maximum value in a
2//! float array entirely on the GPU.
3//!
4//! For greedy (temperature=0) decoding with vocab_size=262144, this replaces
5//! a 1MB GPU→CPU logits readback with an 8-byte readback: the (index, value)
6//! pair. The kernel uses a single threadgroup with shared-memory tree
7//! reduction.
8
9use metal::MTLSize;
10
11use crate::buffer::MlxBuffer;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// MSL source for the argmax kernel (embedded at compile time).
///
/// The text is pulled in with `include_str!` from `../shaders/argmax.metal`
/// (relative to this file), so no shader file is read at runtime.
/// [`register`] hands this string to the kernel registry under the name
/// `argmax_f32`.
pub static ARGMAX_SHADER_SOURCE: &str = include_str!("../shaders/argmax.metal");
18
19/// Register argmax shader source with the given kernel registry.
20pub fn register(registry: &mut KernelRegistry) {
21 registry.register_source("argmax_f32", ARGMAX_SHADER_SOURCE);
22}
23
24/// Dispatch an argmax operation on the GPU.
25///
26/// Finds the index of the maximum element in `input` and writes the result to
27/// `out_index` and `out_value`. The entire reduction runs in a single Metal
28/// threadgroup, returning 8 bytes instead of the full vocab-size logits array.
29///
30/// # Arguments
31///
32/// * `encoder` - Command encoder to record the dispatch into.
33/// * `registry` - Kernel registry (must have `argmax_f32` registered).
34/// * `device` - Metal device for pipeline compilation.
35/// * `input` - Input buffer of shape `[n_elements]` (f32).
36/// * `out_index` - Output buffer `[1]` (u32) — index of the maximum element.
37/// * `out_value` - Output buffer `[1]` (f32) — value of the maximum element.
38/// * `params_buf` - Params buffer `[1]` (u32) — contains `n_elements`.
39/// * `n_elements` - Number of elements in `input`.
40///
41/// # Errors
42///
43/// Returns `MlxError::InvalidArgument` if:
44/// - `n_elements` is 0.
45/// - `input` element count does not match `n_elements`.
46/// - `out_index` or `out_value` element count is not 1.
47pub fn dispatch_argmax_f32(
48 encoder: &mut CommandEncoder,
49 registry: &mut KernelRegistry,
50 device: &metal::DeviceRef,
51 input: &MlxBuffer,
52 out_index: &MlxBuffer,
53 out_value: &MlxBuffer,
54 params_buf: &MlxBuffer,
55 n_elements: u32,
56) -> Result<()> {
57 if n_elements == 0 {
58 return Err(MlxError::InvalidArgument(
59 "argmax_f32: n_elements must be > 0".into(),
60 ));
61 }
62 if input.element_count() != n_elements as usize {
63 return Err(MlxError::InvalidArgument(format!(
64 "argmax_f32: input element count {} != n_elements {}",
65 input.element_count(),
66 n_elements
67 )));
68 }
69 if out_index.element_count() < 1 {
70 return Err(MlxError::InvalidArgument(
71 "argmax_f32: out_index must have at least 1 element".into(),
72 ));
73 }
74 if out_value.element_count() < 1 {
75 return Err(MlxError::InvalidArgument(
76 "argmax_f32: out_value must have at least 1 element".into(),
77 ));
78 }
79
80 let pipeline = registry.get_pipeline("argmax_f32", device)?;
81
82 // Threadgroup size: next power-of-two of n_elements, capped at 1024.
83 // Must be a power of 2 for the tree reduction to be correct.
84 let tg_size = std::cmp::min(1024, n_elements.next_power_of_two()) as u64;
85
86 // Shared memory:
87 // index 0 — tg_size floats for value reduction
88 // index 1 — tg_size uints for index tracking
89 let float_shared = tg_size * 4; // sizeof(float) = 4
90 let uint_shared = tg_size * 4; // sizeof(uint) = 4
91
92 encoder.encode_threadgroups_with_shared(
93 pipeline,
94 &[
95 (0, input),
96 (1, out_index),
97 (2, out_value),
98 (3, params_buf),
99 ],
100 &[(0, float_shared), (1, uint_shared)],
101 MTLSize::new(1, 1, 1), // single threadgroup
102 MTLSize::new(tg_size, 1, 1),
103 );
104
105 Ok(())
106}