use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
cudnnCTCLoss, cudnnCTCLossDescriptor_t, cudnnCreate, cudnnCreateCTCLossDescriptor,
cudnnCreateTensorDescriptor, cudnnDestroy, cudnnDestroyCTCLossDescriptor,
cudnnDestroyTensorDescriptor, cudnnGetCTCLossWorkspaceSize, cudnnHandle_t,
cudnnSetCTCLossDescriptorEx, cudnnSetStream, cudnnSetTensorNdDescriptor,
cudnnTensorDescriptor_t, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC, CUDNN_DATA_DOUBLE, CUDNN_DATA_FLOAT,
CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_NOT_PROPAGATE_NAN,
};
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, LossKind, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
/// Problem description for a cuDNN-backed CTC loss plan.
///
/// All fields are validated by `CtcLossCudnnPlan::select`.
#[derive(Copy, Clone, Debug)]
pub struct CtcLossCudnnDescriptor {
    /// Batch size; must be non-negative. `label_lengths`, `input_lengths` and
    /// `costs` each have this many entries.
    pub batch: i32,
    /// Number of time steps (leading dimension of `log_probs` / `grads`); every
    /// entry of `input_lengths` must lie in `[0, max_input_length]`.
    pub max_input_length: i32,
    /// Number of classes (trailing dimension of `log_probs` / `grads`); every
    /// label must lie in `[0, num_classes)`.
    pub num_classes: i32,
    /// Index of the CTC blank class; checked by `CtcLossCudnnPlan::select`.
    pub blank_index: i32,
    /// Element type; must equal the plan's `T::KIND` and be `F32` or `F64`.
    pub element: ElementKind,
    /// `true` selects `CUDNN_CTC_LOSS_ALGO_DETERMINISTIC`, `false` the
    /// non-deterministic algorithm.
    pub deterministic: bool,
}
/// Per-call operands for `CtcLossCudnnPlan::run`.
///
/// The three `&[i32]` slices live in host memory and are handed to cuDNN by
/// raw pointer.
pub struct CtcLossCudnnArgs<'a, T: Element> {
    /// Input scores, shape `[max_input_length, batch, num_classes]`. Passed to
    /// cuDNN as `probs` with `CUDNN_LOSS_NORMALIZATION_SOFTMAX`.
    pub log_probs: TensorRef<'a, T, 3>,
    /// All target sequences concatenated; its length must equal the sum of
    /// `label_lengths`, and each value must lie in `[0, num_classes)`.
    pub labels: &'a [i32],
    /// Target length for each batch entry (`batch` entries, each >= 0).
    pub label_lengths: &'a [i32],
    /// Valid time steps for each batch entry (`batch` entries, each in
    /// `[0, max_input_length]`).
    pub input_lengths: &'a [i32],
    /// Output: one loss value per batch entry, shape `[batch]`.
    pub costs: TensorMut<'a, T, 1>,
    /// Output: gradients, same shape as `log_probs`.
    pub grads: TensorMut<'a, T, 3>,
}
/// cuDNN-backed CTC loss plan.
///
/// cuDNN objects are created lazily on first use and cached in `Cell`s, so the
/// plan mutates itself through `&self`. Because of the `Cell`s the plan is not
/// `Sync`.
pub struct CtcLossCudnnPlan<T: Element> {
    // Problem description, copied from the one validated by `select`.
    desc: CtcLossCudnnDescriptor,
    // Kernel identity and precision metadata returned by `sku()`.
    sku: KernelSku,
    // cuDNN handle; null until first created.
    handle: Cell<cudnnHandle_t>,
    // Tensor descriptor for `log_probs`; null until first created.
    probs_desc: Cell<cudnnTensorDescriptor_t>,
    // Tensor descriptor for `grads`; null until first created.
    grads_desc: Cell<cudnnTensorDescriptor_t>,
    // CTC-loss descriptor. It is the last one created, so non-null means every
    // descriptor is ready.
    ctc_desc: Cell<cudnnCTCLossDescriptor_t>,
    // Workspace bytes from the last successful cuDNN size query.
    workspace_bytes: Cell<usize>,
    // Whether any workspace size query has succeeded.
    workspace_queried: Cell<bool>,
    _marker: PhantomData<T>,
}
impl<T: Element> CtcLossCudnnPlan<T> {
pub fn select(
_stream: &Stream,
desc: &CtcLossCudnnDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::CtcLossCudnnPlan: descriptor.element != T::KIND",
));
}
if !matches!(T::KIND, ElementKind::F32 | ElementKind::F64) {
return Err(Error::Unsupported(
"baracuda-kernels::CtcLossCudnnPlan: cuDNN CTC supports f32 / f64 only \
(f16 / bf16 are bespoke-plan-only)",
));
}
if desc.batch < 0 || desc.max_input_length < 0 || desc.num_classes < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: dimensions must be non-negative",
));
}
if desc.num_classes > 0
&& (desc.blank_index < 0 || desc.blank_index >= desc.num_classes)
{
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: blank_index must be in [0, num_classes)",
));
}
let math_precision = match T::KIND {
ElementKind::F64 => MathPrecision::F64,
_ => MathPrecision::F32,
};
let accumulator = match T::KIND {
ElementKind::F64 => ElementKind::F64,
_ => ElementKind::F32,
};
let precision_guarantee = PrecisionGuarantee {
math_precision,
accumulator,
bit_stable_on_same_hardware: false,
deterministic: desc.deterministic,
};
let sku = KernelSku {
category: OpCategory::Loss,
op: LossKind::Ctc as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Cudnn,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(core::ptr::null_mut()),
probs_desc: Cell::new(core::ptr::null_mut()),
grads_desc: Cell::new(core::ptr::null_mut()),
ctc_desc: Cell::new(core::ptr::null_mut()),
workspace_bytes: Cell::new(0),
workspace_queried: Cell::new(false),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
self.workspace_bytes.get()
}
#[inline]
pub fn workspace_size_queried(&self) -> bool {
self.workspace_queried.get()
}
pub fn query_workspace_size(
&self,
stream: &Stream,
labels: &[i32],
label_lengths: &[i32],
input_lengths: &[i32],
) -> Result<usize> {
self.check_host_arrays(labels, label_lengths, input_lengths)?;
let h = self.ensure_handle()?;
self.bind_stream(h, stream)?;
self.ensure_descriptors()?;
let algo = if self.desc.deterministic {
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
} else {
CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC
};
let mut bytes: usize = 0;
let status = unsafe {
cudnnGetCTCLossWorkspaceSize(
h,
self.probs_desc.get(),
self.grads_desc.get(),
labels.as_ptr(),
label_lengths.as_ptr(),
input_lengths.as_ptr(),
algo,
self.ctc_desc.get(),
&mut bytes as *mut usize,
)
};
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
self.workspace_bytes.set(bytes);
self.workspace_queried.set(true);
Ok(bytes)
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: CtcLossCudnnArgs<'_, T>,
) -> Result<()> {
self.check_args(&args)?;
let h = self.ensure_handle()?;
self.bind_stream(h, stream)?;
self.ensure_descriptors()?;
let needed = if self.workspace_queried.get() {
self.workspace_bytes.get()
} else {
self.query_workspace_size(
stream,
args.labels,
args.label_lengths,
args.input_lengths,
)?
};
let (ws_ptr, _ws_bytes) = unpack_workspace(workspace, needed)?;
let algo = if self.desc.deterministic {
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC
} else {
CUDNN_CTC_LOSS_ALGO_NON_DETERMINISTIC
};
let status = unsafe {
cudnnCTCLoss(
h,
self.probs_desc.get(),
args.log_probs.data.as_raw().0 as *const c_void,
args.labels.as_ptr(),
args.label_lengths.as_ptr(),
args.input_lengths.as_ptr(),
args.costs.data.as_raw().0 as *mut c_void,
self.grads_desc.get(),
args.grads.data.as_raw().0 as *mut c_void,
algo,
self.ctc_desc.get(),
ws_ptr,
needed,
)
};
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn ensure_handle(&self) -> Result<cudnnHandle_t> {
let h = self.handle.get();
if !h.is_null() {
return Ok(h);
}
let mut last_status = 0;
for attempt in 0..5 {
let mut handle: cudnnHandle_t = core::ptr::null_mut();
let status = unsafe { cudnnCreate(&mut handle as *mut _) };
if status == 0 {
self.handle.set(handle);
return Ok(handle);
}
last_status = status;
std::thread::sleep(std::time::Duration::from_millis(
50 * (attempt as u64 + 1),
));
}
Err(Error::CutlassInternal(-last_status))
}
fn bind_stream(&self, h: cudnnHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cudnnSetStream(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn ensure_descriptors(&self) -> Result<()> {
if !self.ctc_desc.get().is_null() {
return Ok(());
}
let dt = cudnn_dtype::<T>();
let comp_dt = if matches!(T::KIND, ElementKind::F64) {
CUDNN_DATA_DOUBLE
} else {
CUDNN_DATA_FLOAT
};
let dims: [i32; 3] = [
self.desc.max_input_length,
self.desc.batch,
self.desc.num_classes,
];
let strides: [i32; 3] = [
self.desc.batch * self.desc.num_classes,
self.desc.num_classes,
1,
];
let mut pd: cudnnTensorDescriptor_t = core::ptr::null_mut();
let status = unsafe { cudnnCreateTensorDescriptor(&mut pd as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe {
cudnnSetTensorNdDescriptor(pd, dt, 3, dims.as_ptr(), strides.as_ptr())
};
if status != 0 {
unsafe {
let _ = cudnnDestroyTensorDescriptor(pd);
}
return Err(Error::CutlassInternal(-status));
}
self.probs_desc.set(pd);
let mut gd: cudnnTensorDescriptor_t = core::ptr::null_mut();
let status = unsafe { cudnnCreateTensorDescriptor(&mut gd as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe {
cudnnSetTensorNdDescriptor(gd, dt, 3, dims.as_ptr(), strides.as_ptr())
};
if status != 0 {
unsafe {
let _ = cudnnDestroyTensorDescriptor(gd);
}
return Err(Error::CutlassInternal(-status));
}
self.grads_desc.set(gd);
let mut cd: cudnnCTCLossDescriptor_t = core::ptr::null_mut();
let status = unsafe { cudnnCreateCTCLossDescriptor(&mut cd as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let status = unsafe {
cudnnSetCTCLossDescriptorEx(
cd,
comp_dt,
CUDNN_LOSS_NORMALIZATION_SOFTMAX,
CUDNN_NOT_PROPAGATE_NAN,
)
};
if status != 0 {
unsafe {
let _ = cudnnDestroyCTCLossDescriptor(cd);
}
return Err(Error::CutlassInternal(-status));
}
self.ctc_desc.set(cd);
Ok(())
}
fn check_args(&self, args: &CtcLossCudnnArgs<'_, T>) -> Result<()> {
let probs_shape = [
self.desc.max_input_length,
self.desc.batch,
self.desc.num_classes,
];
if args.log_probs.shape != probs_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: log_probs shape != \
[max_input_length, batch, num_classes]",
));
}
if args.grads.shape != probs_shape {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: grads shape != log_probs shape",
));
}
if args.costs.shape != [self.desc.batch] {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: costs shape != [batch]",
));
}
self.check_host_arrays(args.labels, args.label_lengths, args.input_lengths)
}
fn check_host_arrays(
&self,
labels: &[i32],
label_lengths: &[i32],
input_lengths: &[i32],
) -> Result<()> {
let b = self.desc.batch as usize;
if label_lengths.len() != b {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: label_lengths.len() != batch",
));
}
if input_lengths.len() != b {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: input_lengths.len() != batch",
));
}
let total: i64 = label_lengths.iter().map(|&v| v as i64).sum();
if (labels.len() as i64) != total {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: labels.len() != Σ label_lengths",
));
}
for &v in input_lengths {
if v < 0 || v > self.desc.max_input_length {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: input_lengths[i] out of \
[0, max_input_length]",
));
}
}
for &v in label_lengths {
if v < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: label_lengths[i] < 0",
));
}
}
for &v in labels {
if v < 0 || v >= self.desc.num_classes {
return Err(Error::InvalidProblem(
"baracuda-kernels::CtcLossCudnnPlan: labels[i] out of [0, num_classes)",
));
}
}
Ok(())
}
}
impl<T: Element> Drop for CtcLossCudnnPlan<T> {
    /// Releases every cuDNN object this plan created. The order is: the CTC
    /// descriptor, then the gradient and probability tensor descriptors, and
    /// finally the handle.
    ///
    /// Each cell is swapped to null before its object is destroyed, so no
    /// object is released twice. Destroy statuses are intentionally discarded
    /// because `drop` has no way to report them.
    fn drop(&mut self) {
        let ctc = self.ctc_desc.replace(core::ptr::null_mut());
        if !ctc.is_null() {
            // SAFETY: non-null cell contents only ever come from a successful
            // cudnnCreate* call, and the cell now holds null.
            unsafe {
                let _ = cudnnDestroyCTCLossDescriptor(ctc);
            }
        }

        let grads = self.grads_desc.replace(core::ptr::null_mut());
        if !grads.is_null() {
            // SAFETY: as above, this descriptor is released exactly once.
            unsafe {
                let _ = cudnnDestroyTensorDescriptor(grads);
            }
        }

        let probs = self.probs_desc.replace(core::ptr::null_mut());
        if !probs.is_null() {
            // SAFETY: as above, this descriptor is released exactly once.
            unsafe {
                let _ = cudnnDestroyTensorDescriptor(probs);
            }
        }

        let handle = self.handle.replace(core::ptr::null_mut());
        if !handle.is_null() {
            // SAFETY: the handle is released last, after every descriptor.
            unsafe {
                let _ = cudnnDestroy(handle);
            }
        }
    }
}
/// Maps the element type `T` to the matching cuDNN data-type constant.
///
/// Only `F32` and `F64` are reachable, because `CtcLossCudnnPlan::select`
/// rejects every other element kind before any descriptor is built.
#[inline]
fn cudnn_dtype<T: Element>() -> i32 {
    if matches!(T::KIND, ElementKind::F32) {
        CUDNN_DATA_FLOAT
    } else if matches!(T::KIND, ElementKind::F64) {
        CUDNN_DATA_DOUBLE
    } else {
        unreachable!("CtcLossCudnnPlan::select gates on F32 / F64")
    }
}
fn unpack_workspace(workspace: Workspace<'_>, needed: usize) -> Result<(*mut c_void, usize)> {
match workspace {
Workspace::None => {
if needed == 0 {
Ok((core::ptr::null_mut(), 0))
} else {
Err(Error::WorkspaceTooSmall { needed, got: 0 })
}
}
Workspace::Borrowed(slice) => {
let got = slice.len();
if got < needed {
return Err(Error::WorkspaceTooSmall { needed, got });
}
Ok((slice.as_raw().0 as *mut c_void, got))
}
}
}