Skip to main content

mlx_native/ops/
softmax_sample.rs

//! Temperature-scaled softmax + categorical sample, entirely on the GPU.
//!
//! For stochastic (temperature > 0) decoding this replaces the full GPU→CPU
//! logits readback (about 1 MB for a large vocabulary) with an 8-byte
//! readback: the sampled token index and its log-probability.
//!
//! The kernel runs in a single threadgroup. It performs three parallel passes
//! (max, exp-sum, normalize) using threadgroup reductions, then thread 0 runs
//! a sequential CDF scan to draw the sample.
10
11use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::encoder::CommandEncoder;
15use crate::error::{MlxError, Result};
16use crate::kernel_registry::KernelRegistry;
17
18/// MSL source for the softmax_sample kernel (embedded at compile time).
19pub static SOFTMAX_SAMPLE_SHADER_SOURCE: &str =
20    include_str!("../shaders/softmax_sample.metal");
21
22/// Register the softmax_sample shader source with the given kernel registry.
23pub fn register(registry: &mut KernelRegistry) {
24    registry.register_source("softmax_sample_f32", SOFTMAX_SAMPLE_SHADER_SOURCE);
25}
26
27/// Dispatch a temperature-scaled softmax + categorical sample on the GPU.
28///
29/// Computes `softmax(logits / temperature)` entirely on the GPU, then samples
30/// one token index using the provided uniform random value.  Only 8 bytes
31/// (token_id u32 + logprob f32) are transferred back to the CPU.
32///
33/// # Arguments
34///
35/// * `encoder`      - Command encoder to record the dispatch into.
36/// * `registry`     - Kernel registry (must have `softmax_sample_f32` registered).
37/// * `device`       - Metal device for pipeline compilation.
38/// * `logits`       - Input logits buffer `[n_elements]` (f32).
39/// * `scratch`      - Scratch buffer `[n_elements]` (f32) used for intermediate
40///                    probability values.  May be a transient allocation; must
41///                    not alias `logits`.
42/// * `out_token`    - Output buffer `[1]` (u32) — sampled token index.
43/// * `out_logprob`  - Output buffer `[1]` (f32) — log-probability of the
44///                    sampled token.
45/// * `params_buf`   - Params buffer `[3]` (f32) containing:
46///                    `[n_elements as f32, temperature, random_val]`
47/// * `n_elements`   - Vocabulary size (number of logits).
48/// * `temperature`  - Sampling temperature (must be > 0.0).
49/// * `random_val`   - Uniform random value in `[0, 1)` for categorical sample.
50///
51/// # Errors
52///
53/// Returns `MlxError::InvalidArgument` if:
54/// - `n_elements` is 0.
55/// - `temperature` is not positive.
56/// - `random_val` is not in `[0, 1)`.
57/// - Buffer sizes are inconsistent.
58pub fn dispatch_softmax_sample_f32(
59    encoder: &mut CommandEncoder,
60    registry: &mut KernelRegistry,
61    device: &metal::DeviceRef,
62    logits: &MlxBuffer,
63    scratch: &MlxBuffer,
64    out_token: &MlxBuffer,
65    out_logprob: &MlxBuffer,
66    params_buf: &MlxBuffer,
67    n_elements: u32,
68    temperature: f32,
69    random_val: f32,
70) -> Result<()> {
71    if n_elements == 0 {
72        return Err(MlxError::InvalidArgument(
73            "softmax_sample_f32: n_elements must be > 0".into(),
74        ));
75    }
76    if temperature <= 0.0 {
77        return Err(MlxError::InvalidArgument(format!(
78            "softmax_sample_f32: temperature must be > 0, got {}",
79            temperature
80        )));
81    }
82    if !(0.0..1.0).contains(&random_val) {
83        return Err(MlxError::InvalidArgument(format!(
84            "softmax_sample_f32: random_val must be in [0, 1), got {}",
85            random_val
86        )));
87    }
88    if logits.element_count() != n_elements as usize {
89        return Err(MlxError::InvalidArgument(format!(
90            "softmax_sample_f32: logits element count {} != n_elements {}",
91            logits.element_count(),
92            n_elements
93        )));
94    }
95    if scratch.element_count() != n_elements as usize {
96        return Err(MlxError::InvalidArgument(format!(
97            "softmax_sample_f32: scratch element count {} != n_elements {}",
98            scratch.element_count(),
99            n_elements
100        )));
101    }
102    if out_token.element_count() < 1 {
103        return Err(MlxError::InvalidArgument(
104            "softmax_sample_f32: out_token must have at least 1 element".into(),
105        ));
106    }
107    if out_logprob.element_count() < 1 {
108        return Err(MlxError::InvalidArgument(
109            "softmax_sample_f32: out_logprob must have at least 1 element".into(),
110        ));
111    }
112
113    let pipeline = registry.get_pipeline("softmax_sample_f32", device)?;
114
115    // Threadgroup size: next power-of-two of n_elements, capped at 1024.
116    let tg_size = std::cmp::min(1024, n_elements.next_power_of_two()) as u64;
117
118    // Shared memory: tg_size floats for the reduction (max pass and sum pass).
119    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
120
121    encoder.encode_threadgroups_with_shared(
122        pipeline,
123        &[
124            (0, logits),
125            (1, scratch),
126            (2, out_token),
127            (3, out_logprob),
128            (4, params_buf),
129        ],
130        &[(0, shared_mem_bytes)],
131        MTLSize::new(1, 1, 1),       // single threadgroup
132        MTLSize::new(tg_size, 1, 1),
133    );
134
135    Ok(())
136}