baracuda_kernels/random/topk_topp_sampling.rs
//! Sort-free top-K / top-P / min-P sampling — Phase 46 (FlashInfer
//! cherry-pick).
//!
//! Faster decode-time alternative to baracuda's existing
//! `topk + softmax + multinomial` pipeline. The op takes a
//! row-normalized probability tensor (one row per request) and produces
//! one sampled index per row.
//!
//! ## Variants
//!
//! Each variant maps to a dedicated FlashInfer launcher. Pick the
//! sampler that matches your filter:
//!
//! - [`SamplerKind::TopK`] — keep only the top-`K` cells.
//! - [`SamplerKind::TopP`] — keep the smallest set of largest cells
//!   whose cumulative mass exceeds `top_p`.
//! - [`SamplerKind::MinP`] — keep cells whose probability is
//!   `>= min_p * max_prob_in_row`.
//! - [`SamplerKind::TopKTopP`] — combine top-K and top-P (the
//!   canonical decode hot path for Llama / Mistral / Gemma).
//!
//! ## Determinism
//!
//! The sampler is internally rejection-based, drawing a fresh uniform
//! `u ~ U(0, 1)` per row from a philox stream seeded with
//! `(seed_val, offset_val)`. With `deterministic == true`, FlashInfer
//! falls back to a sort-based tiebreaker in the rare ambiguous-cell
//! case (cells where the cumulative-sum boundary lands exactly on a
//! cell start).
//!
//! Calling the sampler twice on the same hardware with the same
//! `(seed_val, offset_val)` and identical `probs` is bit-stable.
//!
//! ## Caller contract
//!
//! - `probs` : `[batch, vocab]` row-major, contiguous f32. Each row must
//!   be non-negative and sum to ~1 (typically the output of a softmax
//!   over the logits).
//! - `output` : `[batch]` contiguous i32, written.
//! - `valid` : `[batch]` u8 bool, written only when `Some`. 1 means the
//!   sample was accepted; 0 means rejection sampling timed out and
//!   the caller should re-draw with a fresh seed.
42//! the caller should re-draw with a fresh seed.
43
44
45use baracuda_cutlass::{Error, Result};
46use baracuda_driver::Stream;
47use baracuda_kernels_types::{
48 ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
49 PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
50};
51
52
/// Filter applied by the sort-free sampler before a cell is drawn.
///
/// Only `PartialEq` is derived, not `Eq`: the `top_p` / `min_p` payloads
/// are `f32`, and `f32` cannot implement `Eq` because NaN != NaN.
#[derive(Copy, Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum SamplerKind {
    /// Top-K: keep the `top_k` most probable cells of each row.
    TopK {
        /// Number of cells kept per row; valid range `[1, vocab_size]`.
        top_k: i32,
    },
    /// Top-P: keep the smallest set of most probable cells whose
    /// cumulative mass is greater than `top_p`.
    TopP {
        /// Cumulative-mass threshold; valid range `(0, 1]`.
        top_p: f32,
    },
    /// Min-P: keep every cell with `prob >= min_p * row_max`.
    MinP {
        /// Fraction of the row's largest probability; valid range `(0, 1]`.
        min_p: f32,
    },
    /// Top-K filter followed by a top-P filter — the canonical decode
    /// sampler.
    TopKTopP {
        /// Number of cells kept per row; valid range `[1, vocab_size]`.
        top_k: i32,
        /// Cumulative-mass threshold; valid range `(0, 1]`.
        top_p: f32,
    },
}
83
84/// Descriptor for a sort-free sampling op.
85#[derive(Copy, Clone, Debug)]
86pub struct TopKTopPSamplingDescriptor {
87 /// Batch size (rows of `probs`).
88 pub batch_size: i32,
89 /// Vocabulary size (columns of `probs`).
90 pub vocab_size: i32,
91 /// Sampler family + filter parameters.
92 pub sampler: SamplerKind,
93 /// If true, FlashInfer falls back to a sort-based tiebreaker on
94 /// ambiguous cells. Documented in the module docstring.
95 pub deterministic: bool,
96}
97
98/// Args bundle for a sort-free sampling launch.
99pub struct TopKTopPSamplingArgs<'a> {
100 /// Row-normalized probabilities `[batch, vocab]` f32.
101 pub probs: TensorRef<'a, f32, 2>,
102 /// Sampled indices `[batch]` i32 (written).
103 pub output: TensorMut<'a, i32, 1>,
104 /// Optional per-row "sample accepted" flags `[batch]` u8 bool.
105 /// `None` to skip emitting them.
106 pub valid: Option<TensorMut<'a, u8, 1>>,
107 /// RNG seed (shared across the batch).
108 pub seed_val: u64,
109 /// RNG philox offset.
110 pub offset_val: u64,
111}
112
113/// Sort-free top-K / top-P / min-P sampling plan.
114///
115/// Routes to FlashInfer's `Top*FromProb` family. Requires the
116/// `flashinfer` cargo feature.
117pub struct TopKTopPSamplingPlan {
118 desc: TopKTopPSamplingDescriptor,
119 sku: KernelSku,
120}
121
122impl TopKTopPSamplingPlan {
123 /// Pick a sampler kernel + validate filter parameters against the
124 /// descriptor. Returns `Error::InvalidProblem` for out-of-range
125 /// `top_k` / `top_p` / `min_p` values.
126 pub fn select(
127 _stream: &Stream,
128 desc: &TopKTopPSamplingDescriptor,
129 _pref: PlanPreference,
130 ) -> Result<Self> {
131 if desc.batch_size <= 0 || desc.vocab_size <= 0 {
132 return Err(Error::InvalidProblem(
133 "TopKTopPSamplingPlan: batch_size / vocab_size must be positive",
134 ));
135 }
136 match desc.sampler {
137 SamplerKind::TopK { top_k } => {
138 if top_k <= 0 || top_k > desc.vocab_size {
139 return Err(Error::InvalidProblem(
140 "TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
141 ));
142 }
143 }
144 SamplerKind::TopP { top_p } => {
145 if !(top_p > 0.0 && top_p <= 1.0) {
146 return Err(Error::InvalidProblem(
147 "TopKTopPSamplingPlan: top_p must be in (0, 1]",
148 ));
149 }
150 }
151 SamplerKind::MinP { min_p } => {
152 if !(min_p > 0.0 && min_p <= 1.0) {
153 return Err(Error::InvalidProblem(
154 "TopKTopPSamplingPlan: min_p must be in (0, 1]",
155 ));
156 }
157 }
158 SamplerKind::TopKTopP { top_k, top_p } => {
159 if top_k <= 0 || top_k > desc.vocab_size {
160 return Err(Error::InvalidProblem(
161 "TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
162 ));
163 }
164 if !(top_p > 0.0 && top_p <= 1.0) {
165 return Err(Error::InvalidProblem(
166 "TopKTopPSamplingPlan: top_p must be in (0, 1]",
167 ));
168 }
169 }
170 }
171 let precision_guarantee = PrecisionGuarantee {
172 math_precision: MathPrecision::F32,
173 accumulator: ElementKind::F32,
174 bit_stable_on_same_hardware: true,
175 // Deterministic across same-(seed, offset) repeat. Cell-
176 // selection determinism on tiebreaker cells is controlled
177 // by `desc.deterministic` (FlashInfer's sort-fallback).
178 deterministic: desc.deterministic,
179 };
180 let sku = KernelSku {
181 category: OpCategory::Random,
182 op: RandomKind::Multinomial as u16,
183 element: ElementKind::F32,
184 aux_element: Some(ElementKind::I32),
185 layout: None,
186 epilogue: None,
187 arch: ArchSku::Sm80,
188 backend: BackendKind::FlashInfer,
189 precision_guarantee,
190 };
191 Ok(Self { desc: *desc, sku })
192 }
193
194 /// Validate args against the descriptor (shape + contiguity check).
195 pub fn can_implement(&self, args: &TopKTopPSamplingArgs<'_>) -> Result<()> {
196 if args.probs.shape != [self.desc.batch_size, self.desc.vocab_size] {
197 return Err(Error::InvalidProblem(
198 "TopKTopPSamplingPlan: probs shape must be [batch_size, vocab_size]",
199 ));
200 }
201 if args.output.shape != [self.desc.batch_size] {
202 return Err(Error::InvalidProblem(
203 "TopKTopPSamplingPlan: output shape must be [batch_size]",
204 ));
205 }
206 if let Some(v) = &args.valid {
207 if v.shape != [self.desc.batch_size] {
208 return Err(Error::InvalidProblem(
209 "TopKTopPSamplingPlan: valid shape must be [batch_size]",
210 ));
211 }
212 }
213 if !args.probs.is_contiguous() || !args.output.is_contiguous() {
214 return Err(Error::Unsupported(
215 "TopKTopPSamplingPlan: probs / output must be contiguous",
216 ));
217 }
218 Ok(())
219 }
220
221 /// Required workspace bytes (always 0 — sampling is workspace-free).
222 #[inline]
223 pub fn workspace_size(&self) -> usize {
224 0
225 }
226
227 /// SKU identity (telemetry / autotuner key).
228 #[inline]
229 pub fn sku(&self) -> KernelSku {
230 self.sku
231 }
232
233 /// Numerical guarantees of this plan.
234 #[inline]
235 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
236 self.sku.precision_guarantee
237 }
238
239 /// Launch the selected sampler on the supplied stream.
240 pub fn run(
241 &self,
242 stream: &Stream,
243 _workspace: Workspace<'_>,
244 args: TopKTopPSamplingArgs<'_>,
245 ) -> Result<()> {
246 self.can_implement(&args)?;
247 #[cfg(not(feature = "flashinfer"))]
248 {
249 let _ = (stream, &args);
250 Err(Error::Unsupported(
251 "TopKTopPSamplingPlan: `flashinfer` cargo feature is not enabled",
252 ))
253 }
254 #[cfg(feature = "flashinfer")]
255 {
256 let stream_ptr = stream.as_raw() as *mut c_void;
257 let probs_ptr = args.probs.data.as_raw().0 as *const c_void;
258 let output_ptr = args.output.data.as_raw().0 as *mut c_void;
259 let valid_ptr = match &args.valid {
260 Some(v) => v.data.as_raw().0 as *mut c_void,
261 None => core::ptr::null_mut::<c_void>(),
262 };
263 let det_flag = if self.desc.deterministic { 1 } else { 0 };
264
265 let status = match self.desc.sampler {
266 SamplerKind::TopK { top_k } => unsafe {
267 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_sampling_f32_run(
268 self.desc.batch_size,
269 self.desc.vocab_size,
270 top_k,
271 det_flag,
272 args.seed_val,
273 args.offset_val,
274 probs_ptr,
275 output_ptr,
276 valid_ptr,
277 stream_ptr,
278 )
279 },
280 SamplerKind::TopP { top_p } => unsafe {
281 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_p_sampling_f32_run(
282 self.desc.batch_size,
283 self.desc.vocab_size,
284 top_p,
285 det_flag,
286 args.seed_val,
287 args.offset_val,
288 probs_ptr,
289 output_ptr,
290 valid_ptr,
291 stream_ptr,
292 )
293 },
294 SamplerKind::MinP { min_p } => unsafe {
295 baracuda_kernels_sys::baracuda_kernels_flashinfer_min_p_sampling_f32_run(
296 self.desc.batch_size,
297 self.desc.vocab_size,
298 min_p,
299 det_flag,
300 args.seed_val,
301 args.offset_val,
302 probs_ptr,
303 output_ptr,
304 valid_ptr,
305 stream_ptr,
306 )
307 },
308 SamplerKind::TopKTopP { top_k, top_p } => unsafe {
309 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_run(
310 self.desc.batch_size,
311 self.desc.vocab_size,
312 top_k,
313 top_p,
314 det_flag,
315 args.seed_val,
316 args.offset_val,
317 probs_ptr,
318 output_ptr,
319 valid_ptr,
320 stream_ptr,
321 )
322 },
323 };
324 map_status(status)
325 }
326 }
327}