use core::ffi::c_void;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, ElementKind, KernelSku, MathPrecision, OpCategory, PlanPreference,
PrecisionGuarantee, RandomKind, TensorMut, TensorRef, Workspace,
};
use crate::attention::map_status;
/// Problem description for the token-penalty operation.
///
/// The two extents are validated by `TokenPenaltyPlan::select` and later used
/// by `TokenPenaltyPlan::can_implement` as the expected `[batch, vocab]` tensor
/// shape. The three penalty values are forwarded to the native entry point
/// unchanged; this file performs no validation on them.
#[derive(Copy, Clone, Debug)]
pub struct TokenPenaltyDescriptor {
/// Number of rows (first tensor dimension). Must be strictly positive.
pub batch_size: i32,
/// Number of columns (second tensor dimension). Must be strictly positive.
pub vocab_size: i32,
/// Presumably the repetition penalty; the exact formula lives in the native
/// kernel and is not visible here. TODO confirm.
pub rep_penalty: f32,
/// Presumably the frequency penalty (scaled by the per-token count).
/// TODO confirm against the kernel.
pub freq_penalty: f32,
/// Presumably the presence penalty (applied when the count is non-zero).
/// TODO confirm against the kernel.
pub pres_penalty: f32,
}
/// Tensor arguments for one `TokenPenaltyPlan::run` call.
///
/// `TokenPenaltyPlan::can_implement` requires both tensors to have shape
/// `[batch_size, vocab_size]` and to be contiguous.
pub struct TokenPenaltyArgs<'a> {
/// Rank-2 `f32` tensor handed to the native entry point through a mutable
/// pointer.
pub logits: TensorMut<'a, f32, 2>,
/// Rank-2 `i32` tensor handed to the native entry point through a const
/// pointer. Presumably per-token occurrence counts; TODO confirm.
pub counts: TensorRef<'a, i32, 2>,
}
/// A validated token-penalty plan produced by `TokenPenaltyPlan::select`.
pub struct TokenPenaltyPlan {
// Copy of the descriptor the plan was selected for.
desc: TokenPenaltyDescriptor,
// Kernel identity record built in `select`; returned by `sku()` and the
// source of `precision_guarantee()`.
sku: KernelSku,
}
impl TokenPenaltyPlan {
pub fn select(
_stream: &Stream,
desc: &TokenPenaltyDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.batch_size <= 0 || desc.vocab_size <= 0 {
return Err(Error::InvalidProblem(
"TokenPenaltyPlan: batch_size / vocab_size must be positive",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Random,
op: RandomKind::Multinomial as u16,
element: ElementKind::F32,
aux_element: Some(ElementKind::I32),
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self { desc: *desc, sku })
}
pub fn can_implement(&self, args: &TokenPenaltyArgs<'_>) -> Result<()> {
let shape = [self.desc.batch_size, self.desc.vocab_size];
if args.logits.shape != shape || args.counts.shape != shape {
return Err(Error::InvalidProblem(
"TokenPenaltyPlan: logits / counts shape must be [batch, vocab]",
));
}
if !args.logits.is_contiguous() || !args.counts.is_contiguous() {
return Err(Error::Unsupported(
"TokenPenaltyPlan: logits / counts must be contiguous",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: TokenPenaltyArgs<'_>,
) -> Result<()> {
self.can_implement(&args)?;
let stream_ptr = stream.as_raw() as *mut c_void;
let logits_ptr = args.logits.data.as_raw().0 as *mut c_void;
let counts_ptr = args.counts.data.as_raw().0 as *const c_void;
let status = unsafe {
baracuda_kernels_sys::baracuda_kernels_apply_token_penalty_f32_run(
self.desc.batch_size,
self.desc.vocab_size,
self.desc.rep_penalty,
self.desc.freq_penalty,
self.desc.pres_penalty,
logits_ptr,
counts_ptr,
stream_ptr,
)
};
map_status(status)
}
}