use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
};
/// Selects which sampling kernel a [`TopKTopPSamplingPlan`] dispatches to,
/// together with that kernel's parameters.
///
/// Parameter ranges are checked by `TopKTopPSamplingPlan::select`, not at
/// construction time, so any value can be stored here.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub enum SamplerKind {
    /// Top-k sampling. `top_k` is accepted when it lies in `[1, vocab_size]`.
    TopK { top_k: i32 },
    /// Top-p sampling. `top_p` is accepted when it lies in `(0, 1]`.
    TopP { top_p: f32 },
    /// Min-p sampling. `min_p` is accepted when it lies in `(0, 1]`.
    MinP { min_p: f32 },
    /// Combined top-k / top-p sampling; both parameters obey the ranges of
    /// the individual variants.
    TopKTopP { top_k: i32, top_p: f32 },
}
/// Problem description from which a `TopKTopPSamplingPlan` is built.
///
/// `TopKTopPSamplingPlan::select` rejects descriptors whose sizes are not
/// positive or whose sampler parameters are out of range.
#[derive(Copy, Clone, Debug)]
pub struct TopKTopPSamplingDescriptor {
/// First dimension of `probs` and the length of `output` / `valid`; must be positive.
pub batch_size: i32,
/// Second dimension of `probs`; must be positive. Also the upper bound for `top_k`.
pub vocab_size: i32,
/// Which sampling kernel to run and its parameters.
pub sampler: SamplerKind,
/// Recorded in the plan's precision guarantee and forwarded to the kernel
/// as an integer flag (1 when `true`, 0 when `false`).
pub deterministic: bool,
}
/// Per-launch arguments for `TopKTopPSamplingPlan::run`.
pub struct TopKTopPSamplingArgs<'a> {
/// Input tensor; must have shape `[batch_size, vocab_size]` and be contiguous.
pub probs: TensorRef<'a, f32, 2>,
/// Output tensor; must have shape `[batch_size]` and be contiguous.
/// NOTE(review): presumably receives one sampled token index per row — confirm against the kernel.
pub output: TensorMut<'a, i32, 1>,
/// Optional tensor of shape `[batch_size]`; a null pointer is handed to the kernel when `None`.
/// NOTE(review): presumably per-row success flags written by the kernel — confirm.
pub valid: Option<TensorMut<'a, u8, 1>>,
/// Forwarded unchanged to the kernel. NOTE(review): presumably the RNG seed — confirm.
pub seed_val: u64,
/// Forwarded unchanged to the kernel. NOTE(review): presumably the RNG stream offset — confirm.
pub offset_val: u64,
}
/// A validated sampling plan: the descriptor it was built from plus the
/// kernel SKU assembled by `select`.
pub struct TopKTopPSamplingPlan {
/// Copy of the descriptor that passed validation in `select`.
desc: TopKTopPSamplingDescriptor,
/// Kernel identity (category, element types, arch, backend, precision guarantee).
sku: KernelSku,
}
impl TopKTopPSamplingPlan {
pub fn select(
_stream: &Stream,
desc: &TopKTopPSamplingDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.batch_size <= 0 || desc.vocab_size <= 0 {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: batch_size / vocab_size must be positive",
));
}
match desc.sampler {
SamplerKind::TopK { top_k } => {
if top_k <= 0 || top_k > desc.vocab_size {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
));
}
}
SamplerKind::TopP { top_p } => {
if !(top_p > 0.0 && top_p <= 1.0) {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: top_p must be in (0, 1]",
));
}
}
SamplerKind::MinP { min_p } => {
if !(min_p > 0.0 && min_p <= 1.0) {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: min_p must be in (0, 1]",
));
}
}
SamplerKind::TopKTopP { top_k, top_p } => {
if top_k <= 0 || top_k > desc.vocab_size {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: top_k must be in [1, vocab_size]",
));
}
if !(top_p > 0.0 && top_p <= 1.0) {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: top_p must be in (0, 1]",
));
}
}
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: desc.deterministic,
};
let sku = KernelSku {
category: OpCategory::Random,
op: RandomKind::Multinomial as u16,
element: ElementKind::F32,
aux_element: Some(ElementKind::I32),
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::FlashInfer,
precision_guarantee,
};
Ok(Self { desc: *desc, sku })
}
pub fn can_implement(&self, args: &TopKTopPSamplingArgs<'_>) -> Result<()> {
if args.probs.shape != [self.desc.batch_size, self.desc.vocab_size] {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: probs shape must be [batch_size, vocab_size]",
));
}
if args.output.shape != [self.desc.batch_size] {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: output shape must be [batch_size]",
));
}
if let Some(v) = &args.valid {
if v.shape != [self.desc.batch_size] {
return Err(Error::InvalidProblem(
"TopKTopPSamplingPlan: valid shape must be [batch_size]",
));
}
}
if !args.probs.is_contiguous() || !args.output.is_contiguous() {
return Err(Error::Unsupported(
"TopKTopPSamplingPlan: probs / output must be contiguous",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: TopKTopPSamplingArgs<'_>,
) -> Result<()> {
self.can_implement(&args)?;
#[cfg(not(feature = "flashinfer"))]
{
let _ = (stream, &args);
Err(Error::Unsupported(
"TopKTopPSamplingPlan: `flashinfer` cargo feature is not enabled",
))
}
#[cfg(feature = "flashinfer")]
{
let stream_ptr = stream.as_raw() as *mut c_void;
let probs_ptr = args.probs.data.as_raw().0 as *const c_void;
let output_ptr = args.output.data.as_raw().0 as *mut c_void;
let valid_ptr = match &args.valid {
Some(v) => v.data.as_raw().0 as *mut c_void,
None => core::ptr::null_mut::<c_void>(),
};
let det_flag = if self.desc.deterministic { 1 } else { 0 };
let status = match self.desc.sampler {
SamplerKind::TopK { top_k } => unsafe {
baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_sampling_f32_run(
self.desc.batch_size,
self.desc.vocab_size,
top_k,
det_flag,
args.seed_val,
args.offset_val,
probs_ptr,
output_ptr,
valid_ptr,
stream_ptr,
)
},
SamplerKind::TopP { top_p } => unsafe {
baracuda_kernels_sys::baracuda_kernels_flashinfer_top_p_sampling_f32_run(
self.desc.batch_size,
self.desc.vocab_size,
top_p,
det_flag,
args.seed_val,
args.offset_val,
probs_ptr,
output_ptr,
valid_ptr,
stream_ptr,
)
},
SamplerKind::MinP { min_p } => unsafe {
baracuda_kernels_sys::baracuda_kernels_flashinfer_min_p_sampling_f32_run(
self.desc.batch_size,
self.desc.vocab_size,
min_p,
det_flag,
args.seed_val,
args.offset_val,
probs_ptr,
output_ptr,
valid_ptr,
stream_ptr,
)
},
SamplerKind::TopKTopP { top_k, top_p } => unsafe {
baracuda_kernels_sys::baracuda_kernels_flashinfer_top_k_top_p_sampling_f32_run(
self.desc.batch_size,
self.desc.vocab_size,
top_k,
top_p,
det_flag,
args.seed_val,
args.offset_val,
probs_ptr,
output_ptr,
valid_ptr,
stream_ptr,
)
},
};
map_status(status)
}
}
}