1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
//! Temperature-scaled softmax + categorical sample, entirely on GPU.
//!
//! For stochastic (temperature > 0) decoding this replaces the 1MB
//! GPU→CPU logits readback with an 8-byte readback: the sampled token index
//! and its log-probability.
//!
//! The kernel runs three parallel passes (max, exp-sum, normalize) using
//! threadgroup reductions, then a sequential CDF scan by thread 0 to draw the
//! sample.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// MSL source for the softmax_sample kernel (embedded at compile time).
pub static SOFTMAX_SAMPLE_SHADER_SOURCE: &str =
include_str!("../shaders/softmax_sample.metal");
/// Register the softmax_sample shader source with the given kernel registry.
pub fn register(registry: &mut KernelRegistry) {
registry.register_source("softmax_sample_f32", SOFTMAX_SAMPLE_SHADER_SOURCE);
}
/// Dispatch a temperature-scaled softmax + categorical sample on the GPU.
///
/// Computes `softmax(logits / temperature)` entirely on the GPU, then samples
/// one token index using the provided uniform random value. Only 8 bytes
/// (token_id u32 + logprob f32) are transferred back to the CPU.
///
/// # Arguments
///
/// * `encoder` - Command encoder to record the dispatch into.
/// * `registry` - Kernel registry (must have `softmax_sample_f32` registered).
/// * `device` - Metal device for pipeline compilation.
/// * `logits` - Input logits buffer `[n_elements]` (f32).
/// * `scratch` - Scratch buffer `[n_elements]` (f32) used for intermediate
/// probability values. May be a transient allocation; must
/// not alias `logits`.
/// * `out_token` - Output buffer `[1]` (u32) — sampled token index.
/// * `out_logprob` - Output buffer `[1]` (f32) — log-probability of the
/// sampled token.
/// * `params_buf` - Params buffer `[3]` (f32) containing:
/// `[n_elements as f32, temperature, random_val]`
/// * `n_elements` - Vocabulary size (number of logits).
/// * `temperature` - Sampling temperature (must be > 0.0).
/// * `random_val` - Uniform random value in `[0, 1)` for categorical sample.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// - `n_elements` is 0.
/// - `temperature` is not positive.
/// - `random_val` is not in `[0, 1)`.
/// - Buffer sizes are inconsistent.
pub fn dispatch_softmax_sample_f32(
encoder: &mut CommandEncoder,
registry: &mut KernelRegistry,
device: &metal::DeviceRef,
logits: &MlxBuffer,
scratch: &MlxBuffer,
out_token: &MlxBuffer,
out_logprob: &MlxBuffer,
params_buf: &MlxBuffer,
n_elements: u32,
temperature: f32,
random_val: f32,
) -> Result<()> {
if n_elements == 0 {
return Err(MlxError::InvalidArgument(
"softmax_sample_f32: n_elements must be > 0".into(),
));
}
if temperature <= 0.0 {
return Err(MlxError::InvalidArgument(format!(
"softmax_sample_f32: temperature must be > 0, got {}",
temperature
)));
}
if !(0.0..1.0).contains(&random_val) {
return Err(MlxError::InvalidArgument(format!(
"softmax_sample_f32: random_val must be in [0, 1), got {}",
random_val
)));
}
if logits.element_count() != n_elements as usize {
return Err(MlxError::InvalidArgument(format!(
"softmax_sample_f32: logits element count {} != n_elements {}",
logits.element_count(),
n_elements
)));
}
if scratch.element_count() != n_elements as usize {
return Err(MlxError::InvalidArgument(format!(
"softmax_sample_f32: scratch element count {} != n_elements {}",
scratch.element_count(),
n_elements
)));
}
if out_token.element_count() < 1 {
return Err(MlxError::InvalidArgument(
"softmax_sample_f32: out_token must have at least 1 element".into(),
));
}
if out_logprob.element_count() < 1 {
return Err(MlxError::InvalidArgument(
"softmax_sample_f32: out_logprob must have at least 1 element".into(),
));
}
let pipeline = registry.get_pipeline("softmax_sample_f32", device)?;
// Threadgroup size: next power-of-two of n_elements, capped at 1024.
let tg_size = std::cmp::min(1024, n_elements.next_power_of_two()) as u64;
// Shared memory: tg_size floats for the reduction (max pass and sum pass).
let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
encoder.encode_threadgroups_with_shared(
pipeline,
&[
(0, logits),
(1, scratch),
(2, out_token),
(3, out_logprob),
(4, params_buf),
],
&[(0, shared_mem_bytes)],
MTLSize::new(1, 1, 1), // single threadgroup
MTLSize::new(tg_size, 1, 1),
);
Ok(())
}