1use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17 ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
18 PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
19};
20
21
/// Selects which per-row sampling kernel a `PerRowSamplingPlan` dispatches to.
///
/// The variant also decides which per-row parameter arrays of
/// `PerRowSamplingArgs` are mandatory.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PerRowSampler {
    /// Dispatches to the `top_k_sampling` kernel; requires `top_k_arr`.
    TopK,
    /// Dispatches to the `top_p_sampling` kernel; requires `top_p_arr`.
    TopP,
    /// Dispatches to the `min_p_sampling` kernel; requires `min_p_arr`.
    MinP,
    /// Dispatches to the `top_k_top_p_sampling` kernel; requires both
    /// `top_k_arr` and `top_p_arr`.
    TopKTopP,
}
36
37#[derive(Copy, Clone, Debug)]
39pub struct PerRowSamplingDescriptor {
40 pub batch_size: i32,
42 pub vocab_size: i32,
44 pub sampler: PerRowSampler,
46 pub deterministic: bool,
48}
49
50pub struct PerRowSamplingArgs<'a> {
53 pub probs: TensorRef<'a, f32, 2>,
55 pub top_k_arr: Option<TensorRef<'a, i32, 1>>,
57 pub top_p_arr: Option<TensorRef<'a, f32, 1>>,
59 pub min_p_arr: Option<TensorRef<'a, f32, 1>>,
61 pub output: TensorMut<'a, i32, 1>,
63 pub valid: Option<TensorMut<'a, u8, 1>>,
65 pub seed_val: u64,
67 pub offset_val: u64,
69}
70
71pub struct PerRowSamplingPlan {
73 desc: PerRowSamplingDescriptor,
74 sku: KernelSku,
75}
76
77impl PerRowSamplingPlan {
78 pub fn select(
80 _stream: &Stream,
81 desc: &PerRowSamplingDescriptor,
82 _pref: PlanPreference,
83 ) -> Result<Self> {
84 if desc.batch_size <= 0 || desc.vocab_size <= 0 {
85 return Err(Error::InvalidProblem(
86 "PerRowSamplingPlan: batch_size / vocab_size must be positive",
87 ));
88 }
89 let precision_guarantee = PrecisionGuarantee {
90 math_precision: MathPrecision::F32,
91 accumulator: ElementKind::F32,
92 bit_stable_on_same_hardware: true,
93 deterministic: desc.deterministic,
94 };
95 let sku = KernelSku {
96 category: OpCategory::Random,
97 op: RandomKind::Multinomial as u16,
98 element: ElementKind::F32,
99 aux_element: Some(ElementKind::I32),
100 layout: None,
101 epilogue: None,
102 arch: ArchSku::Sm80,
103 backend: BackendKind::FlashInfer,
104 precision_guarantee,
105 };
106 Ok(Self { desc: *desc, sku })
107 }
108
109 pub fn can_implement(&self, args: &PerRowSamplingArgs<'_>) -> Result<()> {
111 let b = self.desc.batch_size;
112 if args.probs.shape != [b, self.desc.vocab_size] {
113 return Err(Error::InvalidProblem(
114 "PerRowSamplingPlan: probs shape must be [batch, vocab]",
115 ));
116 }
117 if args.output.shape != [b] {
118 return Err(Error::InvalidProblem("PerRowSamplingPlan: output shape must be [batch]"));
119 }
120 let need_k = matches!(self.desc.sampler, PerRowSampler::TopK | PerRowSampler::TopKTopP);
121 let need_p = matches!(self.desc.sampler, PerRowSampler::TopP | PerRowSampler::TopKTopP);
122 let need_minp = matches!(self.desc.sampler, PerRowSampler::MinP);
123 if need_k && args.top_k_arr.is_none() {
124 return Err(Error::InvalidProblem("PerRowSamplingPlan: top_k_arr required"));
125 }
126 if need_p && args.top_p_arr.is_none() {
127 return Err(Error::InvalidProblem("PerRowSamplingPlan: top_p_arr required"));
128 }
129 if need_minp && args.min_p_arr.is_none() {
130 return Err(Error::InvalidProblem("PerRowSamplingPlan: min_p_arr required"));
131 }
132 if let Some(t) = &args.top_k_arr {
133 if t.shape != [b] {
134 return Err(Error::InvalidProblem("PerRowSamplingPlan: top_k_arr must be [batch]"));
135 }
136 }
137 if let Some(t) = &args.top_p_arr {
138 if t.shape != [b] {
139 return Err(Error::InvalidProblem("PerRowSamplingPlan: top_p_arr must be [batch]"));
140 }
141 }
142 if let Some(t) = &args.min_p_arr {
143 if t.shape != [b] {
144 return Err(Error::InvalidProblem("PerRowSamplingPlan: min_p_arr must be [batch]"));
145 }
146 }
147 if let Some(v) = &args.valid {
148 if v.shape != [b] {
149 return Err(Error::InvalidProblem("PerRowSamplingPlan: valid must be [batch]"));
150 }
151 }
152 if !args.probs.is_contiguous() || !args.output.is_contiguous() {
153 return Err(Error::Unsupported(
154 "PerRowSamplingPlan: probs / output must be contiguous",
155 ));
156 }
157 Ok(())
158 }
159
160 #[inline]
162 pub fn workspace_size(&self) -> usize {
163 0
164 }
165
166 #[inline]
168 pub fn sku(&self) -> KernelSku {
169 self.sku
170 }
171
172 #[inline]
174 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
175 self.sku.precision_guarantee
176 }
177
178 pub fn run(
180 &self,
181 stream: &Stream,
182 _workspace: Workspace<'_>,
183 args: PerRowSamplingArgs<'_>,
184 ) -> Result<()> {
185 self.can_implement(&args)?;
186 #[cfg(not(feature = "flashinfer"))]
187 {
188 let _ = (stream, &args);
189 Err(Error::Unsupported(
190 "PerRowSamplingPlan: `flashinfer` cargo feature is not enabled",
191 ))
192 }
193 #[cfg(feature = "flashinfer")]
194 {
195 let stream_ptr = stream.as_raw() as *mut c_void;
196 let probs_ptr = args.probs.data.as_raw().0 as *const c_void;
197 let output_ptr = args.output.data.as_raw().0 as *mut c_void;
198 let valid_ptr = match &args.valid {
199 Some(v) => v.data.as_raw().0 as *mut c_void,
200 None => core::ptr::null_mut::<c_void>(),
201 };
202 let det = if self.desc.deterministic { 1 } else { 0 };
203 let k_ptr = args
204 .top_k_arr
205 .as_ref()
206 .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
207 let p_ptr = args
208 .top_p_arr
209 .as_ref()
210 .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
211 let mp_ptr = args
212 .min_p_arr
213 .as_ref()
214 .map_or(core::ptr::null::<c_void>(), |t| t.data.as_raw().0 as *const c_void);
215
216 let status = match self.desc.sampler {
217 PerRowSampler::TopK => unsafe {
218 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_sampling_f32_arr_run(
219 self.desc.batch_size, self.desc.vocab_size, k_ptr, det,
220 args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
221 )
222 },
223 PerRowSampler::TopP => unsafe {
224 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_p_sampling_f32_arr_run(
225 self.desc.batch_size, self.desc.vocab_size, p_ptr, det,
226 args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
227 )
228 },
229 PerRowSampler::MinP => unsafe {
230 baracuda_kernels_sys::baracuda_kernels_flashinfer_min_p_sampling_f32_arr_run(
231 self.desc.batch_size, self.desc.vocab_size, mp_ptr, det,
232 args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
233 )
234 },
235 PerRowSampler::TopKTopP => unsafe {
236 baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_arr_run(
237 self.desc.batch_size, self.desc.vocab_size, k_ptr, p_ptr, det,
238 args.seed_val, args.offset_val, probs_ptr, output_ptr, valid_ptr, stream_ptr,
239 )
240 },
241 };
242 map_status(status)
243 }
244 }
245}
246
/// Problem description from which a `SpeculativeSamplingPlan` is selected.
///
/// All three extents must be strictly positive.
#[derive(Clone, Copy, Debug)]
pub struct SpeculativeSamplingDescriptor {
    /// Leading extent of every tensor passed to the plan.
    pub batch_size: i32,
    /// Second extent of the draft tensors; the target-probability and
    /// output-token tensors use `num_speculative_tokens + 1` instead.
    pub num_speculative_tokens: i32,
    /// Trailing extent of the draft and target probability tensors.
    pub vocab_size: i32,
    /// Recorded in the plan's precision guarantee and forwarded to the kernel
    /// as a 0/1 flag.
    pub deterministic: bool,
}
263
264pub struct SpeculativeSamplingArgs<'a> {
266 pub draft_probs: TensorRef<'a, f32, 3>,
268 pub draft_token_ids: TensorRef<'a, i32, 2>,
270 pub target_probs: TensorRef<'a, f32, 3>,
272 pub output_token_ids: TensorMut<'a, i32, 2>,
274 pub output_accepted_token_num: TensorMut<'a, i32, 1>,
276 pub output_emitted_draft_token_num: TensorMut<'a, i32, 1>,
278 pub seed_val: u64,
280 pub offset_val: u64,
282}
283
284pub struct SpeculativeSamplingPlan {
286 desc: SpeculativeSamplingDescriptor,
287 sku: KernelSku,
288}
289
290impl SpeculativeSamplingPlan {
291 pub fn select(
293 _stream: &Stream,
294 desc: &SpeculativeSamplingDescriptor,
295 _pref: PlanPreference,
296 ) -> Result<Self> {
297 if desc.batch_size <= 0 || desc.num_speculative_tokens <= 0 || desc.vocab_size <= 0 {
298 return Err(Error::InvalidProblem(
299 "SpeculativeSamplingPlan: extents must be positive",
300 ));
301 }
302 let precision_guarantee = PrecisionGuarantee {
303 math_precision: MathPrecision::F32,
304 accumulator: ElementKind::F32,
305 bit_stable_on_same_hardware: true,
306 deterministic: desc.deterministic,
307 };
308 let sku = KernelSku {
309 category: OpCategory::Random,
310 op: RandomKind::Multinomial as u16,
311 element: ElementKind::F32,
312 aux_element: Some(ElementKind::I32),
313 layout: None,
314 epilogue: None,
315 arch: ArchSku::Sm80,
316 backend: BackendKind::FlashInfer,
317 precision_guarantee,
318 };
319 Ok(Self { desc: *desc, sku })
320 }
321
322 pub fn can_implement(&self, args: &SpeculativeSamplingArgs<'_>) -> Result<()> {
324 let d = &self.desc;
325 if args.draft_probs.shape != [d.batch_size, d.num_speculative_tokens, d.vocab_size] {
326 return Err(Error::InvalidProblem("SpeculativeSamplingPlan: draft_probs shape"));
327 }
328 if args.draft_token_ids.shape != [d.batch_size, d.num_speculative_tokens] {
329 return Err(Error::InvalidProblem("SpeculativeSamplingPlan: draft_token_ids shape"));
330 }
331 if args.target_probs.shape != [d.batch_size, d.num_speculative_tokens + 1, d.vocab_size] {
332 return Err(Error::InvalidProblem("SpeculativeSamplingPlan: target_probs shape"));
333 }
334 if args.output_token_ids.shape != [d.batch_size, d.num_speculative_tokens + 1] {
335 return Err(Error::InvalidProblem("SpeculativeSamplingPlan: output_token_ids shape"));
336 }
337 if args.output_accepted_token_num.shape != [d.batch_size]
338 || args.output_emitted_draft_token_num.shape != [d.batch_size]
339 {
340 return Err(Error::InvalidProblem(
341 "SpeculativeSamplingPlan: output count arrays must be [batch]",
342 ));
343 }
344 if !args.draft_probs.is_contiguous()
345 || !args.draft_token_ids.is_contiguous()
346 || !args.target_probs.is_contiguous()
347 || !args.output_token_ids.is_contiguous()
348 {
349 return Err(Error::Unsupported("SpeculativeSamplingPlan: tensors must be contiguous"));
350 }
351 Ok(())
352 }
353
354 #[inline]
356 pub fn workspace_size(&self) -> usize {
357 0
358 }
359
360 #[inline]
362 pub fn sku(&self) -> KernelSku {
363 self.sku
364 }
365
366 #[inline]
368 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
369 self.sku.precision_guarantee
370 }
371
372 pub fn run(
374 &self,
375 stream: &Stream,
376 _workspace: Workspace<'_>,
377 args: SpeculativeSamplingArgs<'_>,
378 ) -> Result<()> {
379 self.can_implement(&args)?;
380 #[cfg(not(feature = "flashinfer"))]
381 {
382 let _ = (stream, &args);
383 Err(Error::Unsupported(
384 "SpeculativeSamplingPlan: `flashinfer` cargo feature is not enabled",
385 ))
386 }
387 #[cfg(feature = "flashinfer")]
388 {
389 let stream_ptr = stream.as_raw() as *mut c_void;
390 let status = unsafe {
391 baracuda_kernels_sys::baracuda_kernels_flashinfer_chain_speculative_sampling_f32_run(
392 self.desc.batch_size,
393 self.desc.num_speculative_tokens,
394 self.desc.vocab_size,
395 if self.desc.deterministic { 1 } else { 0 },
396 args.seed_val,
397 args.offset_val,
398 args.draft_probs.data.as_raw().0 as *const c_void,
399 args.draft_token_ids.data.as_raw().0 as *const c_void,
400 args.target_probs.data.as_raw().0 as *const c_void,
401 args.output_token_ids.data.as_raw().0 as *mut c_void,
402 args.output_accepted_token_num.data.as_raw().0 as *mut c_void,
403 args.output_emitted_draft_token_num.data.as_raw().0 as *mut c_void,
404 stream_ptr,
405 )
406 };
407 map_status(status)
408 }
409 }
410}