Skip to main content

baracuda_kernels/random/
perrow_spec_sampling.rs

1//! Per-row sampling + speculative-decode verification — Phase 66 Tier 2.
2//!
3//! Companions to [`TopKTopPSamplingPlan`](super::TopKTopPSamplingPlan):
4//!
5//! - [`PerRowSamplingPlan`] — the same sort-free samplers, but the filter
6//!   threshold is a device array `[batch]` (one value per request), the
7//!   canonical serving case. Routes to FlashInfer's `*_arr` entry points.
8//! - [`SpeculativeSamplingPlan`] — speculative-decode accept/reject
9//!   verification (`ChainSpeculativeSampling`).
10//!
11//! Both require the `flashinfer` cargo feature for `run`.
12
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
18    PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
19};
20
21
/// Selects which per-row sampler a [`PerRowSamplingPlan`] launches.
///
/// The per-row thresholds themselves are not stored here: they are supplied
/// as device arrays in the launch args, not in the descriptor.
///
/// Derives `Hash` in addition to the comparison traits so the sampler kind
/// can be used as (part of) a map or set key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PerRowSampler {
    /// Per-row top-K (`top_k_arr` i32 required).
    TopK,
    /// Per-row top-P (`top_p_arr` required).
    TopP,
    /// Per-row min-P (`min_p_arr` required).
    MinP,
    /// Per-row top-K + top-P (`top_k_arr` i32 + `top_p_arr` f32 required).
    TopKTopP,
}
36
37/// Descriptor for a per-row sort-free sampling op.
38#[derive(Copy, Clone, Debug)]
39pub struct PerRowSamplingDescriptor {
40    /// Batch size (rows of `probs`).
41    pub batch_size: i32,
42    /// Vocabulary size (columns of `probs`).
43    pub vocab_size: i32,
44    /// Sampler family.
45    pub sampler: PerRowSampler,
46    /// Sort-based tiebreak on ambiguous cells.
47    pub deterministic: bool,
48}
49
50/// Args for a per-row sampling launch. Supply the array(s) the chosen
51/// [`PerRowSampler`] needs; leave the others `None`.
52pub struct PerRowSamplingArgs<'a> {
53    /// Row-normalized probabilities `[batch, vocab]` f32.
54    pub probs: TensorRef<'a, f32, 2>,
55    /// Per-row top-K cells `[batch]` i32.
56    pub top_k_arr: Option<TensorRef<'a, i32, 1>>,
57    /// Per-row top-P cutoff `[batch]` f32.
58    pub top_p_arr: Option<TensorRef<'a, f32, 1>>,
59    /// Per-row min-P multiplier `[batch]` f32.
60    pub min_p_arr: Option<TensorRef<'a, f32, 1>>,
61    /// Sampled indices `[batch]` i32 (written).
62    pub output: TensorMut<'a, i32, 1>,
63    /// Optional per-row "sample accepted" flags `[batch]` u8.
64    pub valid: Option<TensorMut<'a, u8, 1>>,
65    /// RNG seed.
66    pub seed_val: u64,
67    /// RNG philox offset.
68    pub offset_val: u64,
69}
70
71/// Per-row sort-free sampling plan.
72pub struct PerRowSamplingPlan {
73    desc: PerRowSamplingDescriptor,
74    sku: KernelSku,
75}
76
77impl PerRowSamplingPlan {
78    /// Validate the descriptor.
79    pub fn select(
80        _stream: &Stream,
81        desc: &PerRowSamplingDescriptor,
82        _pref: PlanPreference,
83    ) -> Result<Self> {
84        if desc.batch_size <= 0 || desc.vocab_size <= 0 {
85            return Err(Error::InvalidProblem(
86                "PerRowSamplingPlan: batch_size / vocab_size must be positive",
87            ));
88        }
89        let precision_guarantee = PrecisionGuarantee {
90            math_precision: MathPrecision::F32,
91            accumulator: ElementKind::F32,
92            bit_stable_on_same_hardware: true,
93            deterministic: desc.deterministic,
94        };
95        let sku = KernelSku {
96            category: OpCategory::Random,
97            op: RandomKind::Multinomial as u16,
98            element: ElementKind::F32,
99            aux_element: Some(ElementKind::I32),
100            layout: None,
101            epilogue: None,
102            arch: ArchSku::Sm80,
103            backend: BackendKind::FlashInfer,
104            precision_guarantee,
105        };
106        Ok(Self { desc: *desc, sku })
107    }
108
109    /// Validate args (shapes + presence of the required array).
110    pub fn can_implement(&self, args: &PerRowSamplingArgs<'_>) -> Result<()> {
111        let b = self.desc.batch_size;
112        if args.probs.shape != [b, self.desc.vocab_size] {
113            return Err(Error::InvalidProblem(
114                "PerRowSamplingPlan: probs shape must be [batch, vocab]",
115            ));
116        }
117        if args.output.shape != [b] {
118            return Err(Error::InvalidProblem("PerRowSamplingPlan: output shape must be [batch]"));
119        }
120        let need_k = matches!(self.desc.sampler, PerRowSampler::TopK | PerRowSampler::TopKTopP);
121        let need_p = matches!(self.desc.sampler, PerRowSampler::TopP | PerRowSampler::TopKTopP);
122        let need_minp = matches!(self.desc.sampler, PerRowSampler::MinP);
123        if need_k && args.top_k_arr.is_none() {
124            return Err(Error::InvalidProblem("PerRowSamplingPlan: top_k_arr required"));
125        }
126        if need_p && args.top_p_arr.is_none() {
127            return Err(Error::InvalidProblem("PerRowSamplingPlan: top_p_arr required"));
128        }
129        if need_minp && args.min_p_arr.is_none() {
130            return Err(Error::InvalidProblem("PerRowSamplingPlan: min_p_arr required"));
131        }
132        if let Some(t) = &args.top_k_arr {
133            if t.shape != [b] {
134                return Err(Error::InvalidProblem("PerRowSamplingPlan: top_k_arr must be [batch]"));
135            }
136        }
137        if let Some(t) = &args.top_p_arr {
138            if t.shape != [b] {
139                return Err(Error::InvalidProblem("PerRowSamplingPlan: top_p_arr must be [batch]"));
140            }
141        }
142        if let Some(t) = &args.min_p_arr {
143            if t.shape != [b] {
144                return Err(Error::InvalidProblem("PerRowSamplingPlan: min_p_arr must be [batch]"));
145            }
146        }
147        if let Some(v) = &args.valid {
148            if v.shape != [b] {
149                return Err(Error::InvalidProblem("PerRowSamplingPlan: valid must be [batch]"));
150            }
151        }
152        if !args.probs.is_contiguous() || !args.output.is_contiguous() {
153            return Err(Error::Unsupported(
154                "PerRowSamplingPlan: probs / output must be contiguous",
155            ));
156        }
157        Ok(())
158    }
159
160    /// Workspace bytes — always 0.
161    #[inline]
162    pub fn workspace_size(&self) -> usize {
163        0
164    }
165
166    /// SKU identity.
167    #[inline]
168    pub fn sku(&self) -> KernelSku {
169        self.sku
170    }
171
172    /// Numerical guarantees.
173    #[inline]
174    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
175        self.sku.precision_guarantee
176    }
177
178    /// Launch the selected per-row sampler.
179    pub fn run(
180        &self,
181        stream: &Stream,
182        _workspace: Workspace<'_>,
183        args: PerRowSamplingArgs<'_>,
184    ) -> Result<()> {
185        self.can_implement(&args)?;
186        #[cfg(not(feature = "flashinfer"))]
187        {
188            let _ = (stream, &args);
189            Err(Error::Unsupported(
190                "PerRowSamplingPlan: `flashinfer` cargo feature is not enabled",
191            ))
192        }
193        #[cfg(feature = "flashinfer")]
194        {
195            let stream_ptr = stream.as_raw() as *mut c_void;
196            let probs_ptr = args.probs.data.as_raw().0 as *const c_void;
197            let output_ptr = args.output.data.as_raw().0 as *mut c_void;
198            let valid_ptr = match &args.valid {
199                Some(v) => v.data.as_raw().0 as *mut c_void,
200                None => core::ptr::null_mut::<c_void>(),
201            };
202            let det = if self.desc.deterministic { 1 } else { 0 };
203            let k_ptr = args
204                .top_k_arr
205                .as_ref()
206                .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
207            let p_ptr = args
208                .top_p_arr
209                .as_ref()
210                .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
211            let mp_ptr = args
212                .min_p_arr
213                .as_ref()
214                .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
215
216            let status = match self.desc.sampler {
217                PerRowSampler::TopK => unsafe {
218                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_sampling_f32_arr_run(
219                        self.desc.batch_size, self.desc.vocab_size, k_ptr, det,
220                        args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
221                    )
222                },
223                PerRowSampler::TopP => unsafe {
224                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_p_sampling_f32_arr_run(
225                        self.desc.batch_size, self.desc.vocab_size, p_ptr, det,
226                        args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
227                    )
228                },
229                PerRowSampler::MinP => unsafe {
230                    baracuda_kernels_sys::baracuda_kernels_flashinfer_min_p_sampling_f32_arr_run(
231                        self.desc.batch_size, self.desc.vocab_size, mp_ptr, det,
232                        args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
233                    )
234                },
235                PerRowSampler::TopKTopP => unsafe {
236                    baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_arr_run(
237                        self.desc.batch_size, self.desc.vocab_size, k_ptr, p_ptr, det,
238                        args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
239                    )
240                },
241            };
242            map_status(status)
243        }
244    }
245}
246
247// =========================================================================
248// Speculative-decode verification.
249// =========================================================================
250
/// Descriptor for a speculative-decode verification op.
///
/// Implements `PartialEq`/`Eq`/`Hash` so descriptors can be compared and
/// used as map or set keys (e.g. for plan reuse).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SpeculativeSamplingDescriptor {
    /// Number of requests. Must be strictly positive.
    pub batch_size: i32,
    /// Draft tokens proposed per request. Must be strictly positive.
    pub num_speculative_tokens: i32,
    /// Vocabulary size. Must be strictly positive.
    pub vocab_size: i32,
    /// Sort-based tiebreak on ambiguous cells. Forwarded to the kernel as a
    /// 0/1 flag and reported in the plan's precision guarantee.
    pub deterministic: bool,
}
263
264/// Args bundle for speculative verification.
265pub struct SpeculativeSamplingArgs<'a> {
266    /// Draft probabilities `[batch, num_spec, vocab]` f32.
267    pub draft_probs: TensorRef<'a, f32, 3>,
268    /// Draft sampled token ids `[batch, num_spec]` i32.
269    pub draft_token_ids: TensorRef<'a, i32, 2>,
270    /// Target probabilities `[batch, num_spec + 1, vocab]` f32.
271    pub target_probs: TensorRef<'a, f32, 3>,
272    /// Accepted/corrected token ids `[batch, num_spec + 1]` i32 (written).
273    pub output_token_ids: TensorMut<'a, i32, 2>,
274    /// Per-request accepted-token count `[batch]` i32 (written).
275    pub output_accepted_token_num: TensorMut<'a, i32, 1>,
276    /// Per-request emitted-draft count `[batch]` i32 (written).
277    pub output_emitted_draft_token_num: TensorMut<'a, i32, 1>,
278    /// RNG seed.
279    pub seed_val: u64,
280    /// RNG philox offset.
281    pub offset_val: u64,
282}
283
284/// Speculative-decode verification plan (FlashInfer `ChainSpeculativeSampling`).
285pub struct SpeculativeSamplingPlan {
286    desc: SpeculativeSamplingDescriptor,
287    sku: KernelSku,
288}
289
290impl SpeculativeSamplingPlan {
291    /// Validate the descriptor.
292    pub fn select(
293        _stream: &Stream,
294        desc: &SpeculativeSamplingDescriptor,
295        _pref: PlanPreference,
296    ) -> Result<Self> {
297        if desc.batch_size <= 0 || desc.num_speculative_tokens <= 0 || desc.vocab_size <= 0 {
298            return Err(Error::InvalidProblem(
299                "SpeculativeSamplingPlan: extents must be positive",
300            ));
301        }
302        let precision_guarantee = PrecisionGuarantee {
303            math_precision: MathPrecision::F32,
304            accumulator: ElementKind::F32,
305            bit_stable_on_same_hardware: true,
306            deterministic: desc.deterministic,
307        };
308        let sku = KernelSku {
309            category: OpCategory::Random,
310            op: RandomKind::Multinomial as u16,
311            element: ElementKind::F32,
312            aux_element: Some(ElementKind::I32),
313            layout: None,
314            epilogue: None,
315            arch: ArchSku::Sm80,
316            backend: BackendKind::FlashInfer,
317            precision_guarantee,
318        };
319        Ok(Self { desc: *desc, sku })
320    }
321
322    /// Validate args against the descriptor.
323    pub fn can_implement(&self, args: &SpeculativeSamplingArgs<'_>) -> Result<()> {
324        let d = &self.desc;
325        if args.draft_probs.shape != [d.batch_size, d.num_speculative_tokens, d.vocab_size] {
326            return Err(Error::InvalidProblem("SpeculativeSamplingPlan: draft_probs shape"));
327        }
328        if args.draft_token_ids.shape != [d.batch_size, d.num_speculative_tokens] {
329            return Err(Error::InvalidProblem("SpeculativeSamplingPlan: draft_token_ids shape"));
330        }
331        if args.target_probs.shape != [d.batch_size, d.num_speculative_tokens + 1, d.vocab_size] {
332            return Err(Error::InvalidProblem("SpeculativeSamplingPlan: target_probs shape"));
333        }
334        if args.output_token_ids.shape != [d.batch_size, d.num_speculative_tokens + 1] {
335            return Err(Error::InvalidProblem("SpeculativeSamplingPlan: output_token_ids shape"));
336        }
337        if args.output_accepted_token_num.shape != [d.batch_size]
338            || args.output_emitted_draft_token_num.shape != [d.batch_size]
339        {
340            return Err(Error::InvalidProblem(
341                "SpeculativeSamplingPlan: output count arrays must be [batch]",
342            ));
343        }
344        if !args.draft_probs.is_contiguous()
345            || !args.draft_token_ids.is_contiguous()
346            || !args.target_probs.is_contiguous()
347            || !args.output_token_ids.is_contiguous()
348        {
349            return Err(Error::Unsupported("SpeculativeSamplingPlan: tensors must be contiguous"));
350        }
351        Ok(())
352    }
353
354    /// Workspace bytes — always 0.
355    #[inline]
356    pub fn workspace_size(&self) -> usize {
357        0
358    }
359
360    /// SKU identity.
361    #[inline]
362    pub fn sku(&self) -> KernelSku {
363        self.sku
364    }
365
366    /// Numerical guarantees.
367    #[inline]
368    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
369        self.sku.precision_guarantee
370    }
371
372    /// Run the accept/reject verification.
373    pub fn run(
374        &self,
375        stream: &Stream,
376        _workspace: Workspace<'_>,
377        args: SpeculativeSamplingArgs<'_>,
378    ) -> Result<()> {
379        self.can_implement(&args)?;
380        #[cfg(not(feature = "flashinfer"))]
381        {
382            let _ = (stream, &args);
383            Err(Error::Unsupported(
384                "SpeculativeSamplingPlan: `flashinfer` cargo feature is not enabled",
385            ))
386        }
387        #[cfg(feature = "flashinfer")]
388        {
389            let stream_ptr = stream.as_raw() as *mut c_void;
390            let status = unsafe {
391                baracuda_kernels_sys::baracuda_kernels_flashinfer_chain_speculative_sampling_f32_run(
392                    self.desc.batch_size,
393                    self.desc.num_speculative_tokens,
394                    self.desc.vocab_size,
395                    if self.desc.deterministic { 1 } else { 0 },
396                    args.seed_val,
397                    args.offset_val,
398                    args.draft_probs.data.as_raw().0 as *const c_void,
399                    args.draft_token_ids.data.as_raw().0 as *const c_void,
400                    args.target_probs.data.as_raw().0 as *const c_void,
401                    args.output_token_ids.data.as_raw().0 as *mut c_void,
402                    args.output_accepted_token_num.data.as_raw().0 as *mut c_void,
403                    args.output_emitted_draft_token_num.data.as_raw().0 as *mut c_void,
404                    stream_ptr,
405                )
406            };
407            map_status(status)
408        }
409    }
410}