mlx_native/ops/softmax_sample.rs
1//! Temperature-scaled softmax + categorical sample, entirely on GPU.
2//!
3//! For stochastic (temperature > 0) decoding this replaces the 1MB
4//! GPU→CPU logits readback with an 8-byte readback: the sampled token index
5//! and its log-probability.
6//!
7//! The kernel runs three parallel passes (max, exp-sum, normalize) using
8//! threadgroup reductions, then a sequential CDF scan by thread 0 to draw the
9//! sample.
10
11use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
/// MSL source for the softmax_sample kernel (embedded at compile time).
///
/// The text is pulled in with `include_str!` from
/// `../shaders/softmax_sample.metal` (resolved relative to this source file),
/// so the shader ships inside the compiled binary. [`register`] hands this
/// string to the kernel registry under the name `softmax_sample_f32`.
pub static SOFTMAX_SAMPLE_SHADER_SOURCE: &str =
    include_str!("../shaders/softmax_sample.metal");
21
22/// Register the softmax_sample shader source with the given kernel registry.
23pub fn register(registry: &mut KernelRegistry) {
24 registry.register_source("softmax_sample_f32", SOFTMAX_SAMPLE_SHADER_SOURCE);
25}
26
27/// Dispatch a temperature-scaled softmax + categorical sample on the GPU.
28///
29/// Computes `softmax(logits / temperature)` entirely on the GPU, then samples
30/// one token index using the provided uniform random value. Only 8 bytes
31/// (token_id u32 + logprob f32) are transferred back to the CPU.
32///
33/// # Arguments
34///
35/// * `encoder` - Command encoder to record the dispatch into.
36/// * `registry` - Kernel registry (must have `softmax_sample_f32` registered).
37/// * `device` - Metal device for pipeline compilation.
38/// * `logits` - Input logits buffer `[n_elements]` (f32).
39/// * `scratch` - Scratch buffer `[n_elements]` (f32) used for intermediate
40/// probability values. May be a transient allocation; must
41/// not alias `logits`.
42/// * `out_token` - Output buffer `[1]` (u32) — sampled token index.
43/// * `out_logprob` - Output buffer `[1]` (f32) — log-probability of the
44/// sampled token.
45/// * `params_buf` - Params buffer `[3]` (f32) containing:
46/// `[n_elements as f32, temperature, random_val]`
47/// * `n_elements` - Vocabulary size (number of logits).
48/// * `temperature` - Sampling temperature (must be > 0.0).
49/// * `random_val` - Uniform random value in `[0, 1)` for categorical sample.
50///
51/// # Errors
52///
53/// Returns `MlxError::InvalidArgument` if:
54/// - `n_elements` is 0.
55/// - `temperature` is not positive.
56/// - `random_val` is not in `[0, 1)`.
57/// - Buffer sizes are inconsistent.
58pub fn dispatch_softmax_sample_f32(
59 encoder: &mut CommandEncoder,
60 registry: &mut KernelRegistry,
61 device: &metal::DeviceRef,
62 logits: &MlxBuffer,
63 scratch: &MlxBuffer,
64 out_token: &MlxBuffer,
65 out_logprob: &MlxBuffer,
66 params_buf: &MlxBuffer,
67 n_elements: u32,
68 temperature: f32,
69 random_val: f32,
70) -> Result<()> {
71 if n_elements == 0 {
72 return Err(MlxError::InvalidArgument(
73 "softmax_sample_f32: n_elements must be > 0".into(),
74 ));
75 }
76 if temperature <= 0.0 {
77 return Err(MlxError::InvalidArgument(format!(
78 "softmax_sample_f32: temperature must be > 0, got {}",
79 temperature
80 )));
81 }
82 if !(0.0..1.0).contains(&random_val) {
83 return Err(MlxError::InvalidArgument(format!(
84 "softmax_sample_f32: random_val must be in [0, 1), got {}",
85 random_val
86 )));
87 }
88 if logits.element_count() != n_elements as usize {
89 return Err(MlxError::InvalidArgument(format!(
90 "softmax_sample_f32: logits element count {} != n_elements {}",
91 logits.element_count(),
92 n_elements
93 )));
94 }
95 if scratch.element_count() != n_elements as usize {
96 return Err(MlxError::InvalidArgument(format!(
97 "softmax_sample_f32: scratch element count {} != n_elements {}",
98 scratch.element_count(),
99 n_elements
100 )));
101 }
102 if out_token.element_count() < 1 {
103 return Err(MlxError::InvalidArgument(
104 "softmax_sample_f32: out_token must have at least 1 element".into(),
105 ));
106 }
107 if out_logprob.element_count() < 1 {
108 return Err(MlxError::InvalidArgument(
109 "softmax_sample_f32: out_logprob must have at least 1 element".into(),
110 ));
111 }
112
113 let pipeline = registry.get_pipeline("softmax_sample_f32", device)?;
114
115 // Threadgroup size: next power-of-two of n_elements, capped at 1024.
116 let tg_size = std::cmp::min(1024, n_elements.next_power_of_two()) as u64;
117
118 // Shared memory: tg_size floats for the reduction (max pass and sum pass).
119 let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
120
121 encoder.encode_threadgroups_with_shared(
122 pipeline,
123 &[
124 (0, logits),
125 (1, scratch),
126 (2, out_token),
127 (3, out_logprob),
128 (4, params_buf),
129 ],
130 &[(0, shared_mem_bytes)],
131 MTLSize::new(1, 1, 1), // single threadgroup
132 MTLSize::new(tg_size, 1, 1),
133 );
134
135 Ok(())
136}