Skip to main content

baracuda_kernels/random/
topk_topp_sampling.rs

1//! Sort-free top-K / top-P / min-P sampling — Phase 46 (FlashInfer
2//! cherry-pick).
3//!
4//! Faster decode-time alternative to baracuda's existing
5//! `topk + softmax + multinomial` pipeline. The op takes a row-
6//! normalized probability tensor (one row per request) and produces
7//! one sampled index per row.
8//!
9//! ## Variants
10//!
11//! Each variant maps to a dedicated FlashInfer launcher. Pick the
12//! sampler that matches your filter:
13//!
14//! - [`SamplerKind::TopK`] — keep only the top-`K` cells.
15//! - [`SamplerKind::TopP`] — keep the smallest set of largest cells
16//!   whose cumulative mass exceeds `top_p`.
17//! - [`SamplerKind::MinP`] — keep cells whose probability
18//!   `>= min_p * max_prob_in_row`.
19//! - [`SamplerKind::TopKTopP`] — combine top-K and top-P (the
20//!   canonical decode hot path for Llama / Mistral / Gemma).
21//!
22//! ## Determinism
23//!
24//! The sampler is internally rejection-based, drawing a fresh uniform
25//! `u ~ U(0, 1)` per row from a philox stream seeded with
26//! `(seed_val, offset_val)`. With `deterministic == true`, FlashInfer
27//! falls back to a sort-based tiebreaker on the rare ambiguous-cell
28//! case (cells where the cumulative-sum boundary lands exactly on a
29//! cell start).
30//!
//! Calling the sampler twice with the same descriptor, the same
//! `(seed_val, offset_val)`, and identical `probs` on the same hardware
//! produces bit-identical output.
33//!
34//! ## Caller contract
35//!
//! - `probs` : `[batch, vocab]` row-major, contiguous f32. Each row must
//!   be non-negative and sum to ~1 (typically the output of a softmax
//!   over the logits).
//! - `output` : `[batch]` contiguous i32, written.
//! - `valid` : `[batch]` u8 bool, written only when provided (`Some`).
//!   1 means the sample was accepted; 0 means rejection sampling timed
//!   out and the caller should re-draw with a fresh seed.
43
44
#[cfg(feature = "flashinfer")]
use core::ffi::c_void;

use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
    ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
    PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
};
51
52
/// Selects the sort-free sampler (filter family plus its parameters).
///
/// The derive list stops at `PartialEq`: the `top_p` / `min_p` payloads
/// are `f32`, and since `NaN != NaN` the type cannot implement `Eq`.
#[derive(Copy, Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum SamplerKind {
    /// Retain only the `top_k` highest-probability cells of each row.
    TopK {
        /// How many cells survive per row; valid range is `[1, vocab_size]`.
        top_k: i32,
    },
    /// Retain the smallest set of highest-probability cells whose
    /// cumulative mass exceeds `top_p`.
    TopP {
        /// Cumulative-mass threshold; valid range is `(0, 1]`.
        top_p: f32,
    },
    /// Retain cells satisfying `prob >= min_p * row_max`.
    MinP {
        /// Fraction of the row's largest probability; valid range is `(0, 1]`.
        min_p: f32,
    },
    /// Apply the top-K filter, then the top-P filter — the usual decode
    /// sampler.
    TopKTopP {
        /// How many cells survive per row; valid range is `[1, vocab_size]`.
        top_k: i32,
        /// Cumulative-mass threshold; valid range is `(0, 1]`.
        top_p: f32,
    },
}
83
84/// Descriptor for a sort-free sampling op.
85#[derive(Copy, Clone, Debug)]
86pub struct TopKTopPSamplingDescriptor {
87    /// Batch size (rows of `probs`).
88    pub batch_size: i32,
89    /// Vocabulary size (columns of `probs`).
90    pub vocab_size: i32,
91    /// Sampler family + filter parameters.
92    pub sampler: SamplerKind,
93    /// If true, FlashInfer falls back to a sort-based tiebreaker on
94    /// ambiguous cells. Documented in the module docstring.
95    pub deterministic: bool,
96}
97
98/// Args bundle for a sort-free sampling launch.
99pub struct TopKTopPSamplingArgs<'a> {
100    /// Row-normalized probabilities `[batch, vocab]` f32.
101    pub probs: TensorRef<'a, f32, 2>,
102    /// Sampled indices `[batch]` i32 (written).
103    pub output: TensorMut<'a, i32, 1>,
104    /// Optional per-row "sample accepted" flags `[batch]` u8 bool.
105    /// `None` to skip emitting them.
106    pub valid: Option<TensorMut<'a, u8, 1>>,
107    /// RNG seed (shared across the batch).
108    pub seed_val: u64,
109    /// RNG philox offset.
110    pub offset_val: u64,
111}
112
113/// Sort-free top-K / top-P / min-P sampling plan.
114///
115/// Routes to FlashInfer's `Top*FromProb` family. Requires the
116/// `flashinfer` cargo feature.
117pub struct TopKTopPSamplingPlan {
118    desc: TopKTopPSamplingDescriptor,
119    sku: KernelSku,
120}
121
122impl TopKTopPSamplingPlan {
123    /// Pick a sampler kernel + validate filter parameters against the
124    /// descriptor. Returns `Error::InvalidProblem` for out-of-range
125    /// `top_k` / `top_p` / `min_p` values.
126    pub fn select(
127        _stream: &Stream,
128        desc: &TopKTopPSamplingDescriptor,
129        _pref: PlanPreference,
130    ) -> Result<Self> {
131        if desc.batch_size <= 0 || desc.vocab_size <= 0 {
132            return Err(Error::InvalidProblem(
133                "TopKTopPSamplingPlan: batch_size / vocab_size must be positive",
134            ));
135        }
136        match desc.sampler {
137            SamplerKind::TopK { top_k } => {
138                if top_k <= 0 || top_k > desc.vocab_size {
139                    return Err(Error::InvalidProblem(
140                        "TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
141                    ));
142                }
143            }
144            SamplerKind::TopP { top_p } => {
145                if !(top_p > 0.0 && top_p <= 1.0) {
146                    return Err(Error::InvalidProblem(
147                        "TopKTopPSamplingPlan: top_p must be in (0, 1]",
148                    ));
149                }
150            }
151            SamplerKind::MinP { min_p } => {
152                if !(min_p > 0.0 && min_p <= 1.0) {
153                    return Err(Error::InvalidProblem(
154                        "TopKTopPSamplingPlan: min_p must be in (0, 1]",
155                    ));
156                }
157            }
158            SamplerKind::TopKTopP { top_k, top_p } => {
159                if top_k <= 0 || top_k > desc.vocab_size {
160                    return Err(Error::InvalidProblem(
161                        "TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
162                    ));
163                }
164                if !(top_p > 0.0 && top_p <= 1.0) {
165                    return Err(Error::InvalidProblem(
166                        "TopKTopPSamplingPlan: top_p must be in (0, 1]",
167                    ));
168                }
169            }
170        }
171        let precision_guarantee = PrecisionGuarantee {
172            math_precision: MathPrecision::F32,
173            accumulator: ElementKind::F32,
174            bit_stable_on_same_hardware: true,
175            // Deterministic across same-(seed, offset) repeat. Cell-
176            // selection determinism on tiebreaker cells is controlled
177            // by `desc.deterministic` (FlashInfer's sort-fallback).
178            deterministic: desc.deterministic,
179        };
180        let sku = KernelSku {
181            category: OpCategory::Random,
182            op: RandomKind::Multinomial as u16,
183            element: ElementKind::F32,
184            aux_element: Some(ElementKind::I32),
185            layout: None,
186            epilogue: None,
187            arch: ArchSku::Sm80,
188            backend: BackendKind::FlashInfer,
189            precision_guarantee,
190        };
191        Ok(Self { desc: *desc, sku })
192    }
193
194    /// Validate args against the descriptor (shape + contiguity check).
195    pub fn can_implement(&self, args: &TopKTopPSamplingArgs<'_>) -> Result<()> {
196        if args.probs.shape != [self.desc.batch_size, self.desc.vocab_size] {
197            return Err(Error::InvalidProblem(
198                "TopKTopPSamplingPlan: probs shape must be [batch_size, vocab_size]",
199            ));
200        }
201        if args.output.shape != [self.desc.batch_size] {
202            return Err(Error::InvalidProblem(
203                "TopKTopPSamplingPlan: output shape must be [batch_size]",
204            ));
205        }
206        if let Some(v) = &args.valid {
207            if v.shape != [self.desc.batch_size] {
208                return Err(Error::InvalidProblem(
209                    "TopKTopPSamplingPlan: valid shape must be [batch_size]",
210                ));
211            }
212        }
213        if !args.probs.is_contiguous() || !args.output.is_contiguous() {
214            return Err(Error::Unsupported(
215                "TopKTopPSamplingPlan: probs / output must be contiguous",
216            ));
217        }
218        Ok(())
219    }
220
221    /// Required workspace bytes (always 0 — sampling is workspace-free).
222    #[inline]
223    pub fn workspace_size(&self) -> usize {
224        0
225    }
226
227    /// SKU identity (telemetry / autotuner key).
228    #[inline]
229    pub fn sku(&self) -> KernelSku {
230        self.sku
231    }
232
233    /// Numerical guarantees of this plan.
234    #[inline]
235    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
236        self.sku.precision_guarantee
237    }
238
239    /// Launch the selected sampler on the supplied stream.
240    pub fn run(
241        &self,
242        stream: &Stream,
243        _workspace: Workspace<'_>,
244        args: TopKTopPSamplingArgs<'_>,
245    ) -> Result<()> {
246        self.can_implement(&args)?;
247        #[cfg(not(feature = "flashinfer"))]
248        {
249            let _ = (stream, &args);
250            Err(Error::Unsupported(
251                "TopKTopPSamplingPlan: `flashinfer` cargo feature is not enabled",
252            ))
253        }
254        #[cfg(feature = "flashinfer")]
255        {
256            let stream_ptr = stream.as_raw() as *mut c_void;
257            let probs_ptr = args.probs.data.as_raw().0 as *const c_void;
258            let output_ptr = args.output.data.as_raw().0 as *mut c_void;
259            let valid_ptr = match &args.valid {
260                Some(v) => v.data.as_raw().0 as *mut c_void,
261                None => core::ptr::null_mut::<c_void>(),
262            };
263            let det_flag = if self.desc.deterministic { 1 } else { 0 };
264
265            let status = match self.desc.sampler {
266                SamplerKind::TopK { top_k } => unsafe {
267                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_sampling_f32_run(
268                        self.desc.batch_size,
269                        self.desc.vocab_size,
270                        top_k,
271                        det_flag,
272                        args.seed_val,
273                        args.offset_val,
274                        probs_ptr,
275                        output_ptr,
276                        valid_ptr,
277                        stream_ptr,
278                    )
279                },
280                SamplerKind::TopP { top_p } => unsafe {
281                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_p_sampling_f32_run(
282                        self.desc.batch_size,
283                        self.desc.vocab_size,
284                        top_p,
285                        det_flag,
286                        args.seed_val,
287                        args.offset_val,
288                        probs_ptr,
289                        output_ptr,
290                        valid_ptr,
291                        stream_ptr,
292                    )
293                },
294                SamplerKind::MinP { min_p } => unsafe {
295                    baracuda_kernels_sys::baracuda_kernels_flashinfer_min_p_sampling_f32_run(
296                        self.desc.batch_size,
297                        self.desc.vocab_size,
298                        min_p,
299                        det_flag,
300                        args.seed_val,
301                        args.offset_val,
302                        probs_ptr,
303                        output_ptr,
304                        valid_ptr,
305                        stream_ptr,
306                    )
307                },
308                SamplerKind::TopKTopP { top_k, top_p } => unsafe {
309                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_run(
310                        self.desc.batch_size,
311                        self.desc.vocab_size,
312                        top_k,
313                        top_p,
314                        det_flag,
315                        args.seed_val,
316                        args.offset_val,
317                        probs_ptr,
318                        output_ptr,
319                        valid_ptr,
320                        stream_ptr,
321                    )
322                },
323            };
324            map_status(status)
325        }
326    }
327}