use core::cell::Cell;
use core::ffi::c_void;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
cudnnCreate, cudnnCreatePoolingDescriptor, cudnnCreateTensorDescriptor, cudnnDestroy,
cudnnDestroyPoolingDescriptor, cudnnDestroyTensorDescriptor, cudnnHandle_t,
cudnnPoolingBackward, cudnnPoolingDescriptor_t, cudnnPoolingForward,
cudnnSetPoolingNdDescriptor, cudnnSetStream, cudnnSetTensorNdDescriptor,
cudnnTensorDescriptor_t, CUDNN_NOT_PROPAGATE_NAN,
};
use baracuda_kernels_types::{Element, ElementKind};
use super::max_pool2d::{cudnn_dtype, cudnn_pool_mode, is_double_compute, PoolMode};
/// Computes the pooled extent along a single axis.
///
/// Evaluates `(in_dim + 2 * pad - window) / stride + 1` using `i32`
/// arithmetic (the division truncates toward zero).
///
/// * `in_dim` - input extent along the axis.
/// * `pad` - padding value, counted twice in the formula.
/// * `window` - pooling window extent.
/// * `stride` - step between windows; a value of `0` panics on the division.
#[inline]
pub(crate) fn out_dim(in_dim: i32, pad: i32, window: i32, stride: i32) -> i32 {
    let padded_extent = in_dim + 2 * pad;
    let travel = padded_extent - window;
    travel / stride + 1
}
pub(crate) fn ensure_handle(handle: &Cell<cudnnHandle_t>) -> Result<cudnnHandle_t> {
let h = handle.get();
if !h.is_null() {
return Ok(h);
}
let mut new_h: cudnnHandle_t = core::ptr::null_mut();
let status = unsafe { cudnnCreate(&mut new_h as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
handle.set(new_h);
Ok(new_h)
}
pub(crate) fn bind_stream(h: cudnnHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cudnnSetStream(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn create_tensor_nd<T: Element>(
desc_cell: &Cell<cudnnTensorDescriptor_t>,
dims: &[i32],
) -> Result<()> {
if !desc_cell.get().is_null() {
return Ok(());
}
let mut td: cudnnTensorDescriptor_t = core::ptr::null_mut();
let status = unsafe { cudnnCreateTensorDescriptor(&mut td as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let mut padded: [i32; 5] = [1; 5];
let nb_dims = if dims.len() < 4 { 4 } else { dims.len() };
for (i, &d) in dims.iter().enumerate() {
padded[i] = d;
}
let mut strides: [i32; 5] = [1; 5];
let mut acc: i64 = 1;
let mut i = nb_dims;
while i > 0 {
i -= 1;
strides[i] = acc as i32;
acc = acc.saturating_mul(padded[i] as i64);
}
let dt = cudnn_dtype::<T>();
let status = unsafe {
cudnnSetTensorNdDescriptor(
td,
dt,
nb_dims as i32,
padded.as_ptr(),
strides.as_ptr(),
)
};
if status != 0 {
unsafe {
let _ = cudnnDestroyTensorDescriptor(td);
}
return Err(Error::CutlassInternal(-status));
}
desc_cell.set(td);
Ok(())
}
fn create_pool_nd(
pool_desc: &Cell<cudnnPoolingDescriptor_t>,
mode: PoolMode,
window: &[i32],
padding: &[i32],
stride: &[i32],
) -> Result<()> {
if !pool_desc.get().is_null() {
return Ok(());
}
let mut pd: cudnnPoolingDescriptor_t = core::ptr::null_mut();
let status = unsafe { cudnnCreatePoolingDescriptor(&mut pd as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
let mut win: [i32; 3] = [1; 3];
let mut pad: [i32; 3] = [0; 3];
let mut str_: [i32; 3] = [1; 3];
for (i, (&w, (&p, &s))) in window
.iter()
.zip(padding.iter().zip(stride.iter()))
.enumerate()
{
win[i] = w;
pad[i] = p;
str_[i] = s;
}
let nb_dims = if window.len() < 2 { 2 } else { window.len() };
let status = unsafe {
cudnnSetPoolingNdDescriptor(
pd,
cudnn_pool_mode(mode),
CUDNN_NOT_PROPAGATE_NAN,
nb_dims as i32,
win.as_ptr(),
pad.as_ptr(),
str_.as_ptr(),
)
};
if status != 0 {
unsafe {
let _ = cudnnDestroyPoolingDescriptor(pd);
}
return Err(Error::CutlassInternal(-status));
}
pool_desc.set(pd);
Ok(())
}
/// Makes sure the input tensor, output tensor and pooling descriptors exist.
///
/// The three descriptors are set up in that order via `create_tensor_nd`
/// (for `x_desc` / `y_desc`) and `create_pool_nd` (for `pool_desc`); the
/// first failure is returned and later steps are not attempted.
///
/// # Errors
///
/// Propagates whatever error the failing creation step produced.
#[allow(clippy::too_many_arguments)]
pub(crate) fn ensure_descriptors_nd<T: Element>(
    x_dims: &[i32],
    y_dims: &[i32],
    window: &[i32],
    padding: &[i32],
    stride: &[i32],
    mode: PoolMode,
    x_desc: &Cell<cudnnTensorDescriptor_t>,
    y_desc: &Cell<cudnnTensorDescriptor_t>,
    pool_desc: &Cell<cudnnPoolingDescriptor_t>,
) -> Result<()> {
    create_tensor_nd::<T>(x_desc, x_dims)
        .and_then(|()| create_tensor_nd::<T>(y_desc, y_dims))
        .and_then(|()| create_pool_nd(pool_desc, mode, window, padding, stride))
}
/// Releases every non-null cuDNN object held in the given cells and resets
/// the cells to null.
///
/// Destruction order is: pooling descriptor, output tensor descriptor, input
/// tensor descriptor, then the handle. Status codes returned by the destroy
/// calls are deliberately ignored (best-effort teardown). Cells that are
/// already null are skipped, so calling this repeatedly is harmless.
pub(crate) fn drop_descriptors_nd(
    x_desc: &Cell<cudnnTensorDescriptor_t>,
    y_desc: &Cell<cudnnTensorDescriptor_t>,
    pool_desc: &Cell<cudnnPoolingDescriptor_t>,
    handle: &Cell<cudnnHandle_t>,
) {
    // Each step swaps null into the cell and destroys whatever was there.
    let pooling = pool_desc.replace(core::ptr::null_mut());
    if !pooling.is_null() {
        // SAFETY: non-null value taken out of the cell; it is no longer
        // reachable through the cell after the swap.
        unsafe {
            let _ = cudnnDestroyPoolingDescriptor(pooling);
        }
    }

    let output = y_desc.replace(core::ptr::null_mut());
    if !output.is_null() {
        // SAFETY: as above.
        unsafe {
            let _ = cudnnDestroyTensorDescriptor(output);
        }
    }

    let input = x_desc.replace(core::ptr::null_mut());
    if !input.is_null() {
        // SAFETY: as above.
        unsafe {
            let _ = cudnnDestroyTensorDescriptor(input);
        }
    }

    let library = handle.replace(core::ptr::null_mut());
    if !library.is_null() {
        // SAFETY: as above.
        unsafe {
            let _ = cudnnDestroy(library);
        }
    }
}
pub(crate) fn run_fw_nd<T: Element>(
h: cudnnHandle_t,
pool_desc: cudnnPoolingDescriptor_t,
x_desc: cudnnTensorDescriptor_t,
y_desc: cudnnTensorDescriptor_t,
x_ptr: u64,
y_ptr: u64,
) -> Result<()> {
let status = if is_double_compute::<T>() {
let alpha: f64 = 1.0;
let beta: f64 = 0.0;
unsafe {
cudnnPoolingForward(
h,
pool_desc,
&alpha as *const f64 as *const c_void,
x_desc,
x_ptr as *const c_void,
&beta as *const f64 as *const c_void,
y_desc,
y_ptr as *mut c_void,
)
}
} else {
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
unsafe {
cudnnPoolingForward(
h,
pool_desc,
&alpha as *const f32 as *const c_void,
x_desc,
x_ptr as *const c_void,
&beta as *const f32 as *const c_void,
y_desc,
y_ptr as *mut c_void,
)
}
};
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
pub(crate) fn run_bw_nd<T: Element>(
h: cudnnHandle_t,
pool_desc: cudnnPoolingDescriptor_t,
x_desc: cudnnTensorDescriptor_t,
y_desc: cudnnTensorDescriptor_t,
y_ptr: u64,
dy_ptr: u64,
x_ptr: u64,
dx_ptr: u64,
) -> Result<()> {
let status = if is_double_compute::<T>() {
let alpha: f64 = 1.0;
let beta: f64 = 0.0;
unsafe {
cudnnPoolingBackward(
h,
pool_desc,
&alpha as *const f64 as *const c_void,
y_desc,
y_ptr as *const c_void,
y_desc,
dy_ptr as *const c_void,
x_desc,
x_ptr as *const c_void,
&beta as *const f64 as *const c_void,
x_desc,
dx_ptr as *mut c_void,
)
}
} else {
let alpha: f32 = 1.0;
let beta: f32 = 0.0;
unsafe {
cudnnPoolingBackward(
h,
pool_desc,
&alpha as *const f32 as *const c_void,
y_desc,
y_ptr as *const c_void,
y_desc,
dy_ptr as *const c_void,
x_desc,
x_ptr as *const c_void,
&beta as *const f32 as *const c_void,
x_desc,
dx_ptr as *mut c_void,
)
}
};
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
pub(crate) fn validate_dtype<T: Element>() -> Result<()> {
if !matches!(
T::KIND,
ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
) {
return Err(Error::Unsupported(
"baracuda-kernels::PoolNdPlan: cuDNN pooling supports f32 / f64 / f16 / bf16",
));
}
Ok(())
}
/// Derives a `(kernel, stride)` pair for mapping `in_dim` onto `out_dim`.
///
/// `stride` is `in_dim / out_dim` (truncating) and `kernel` is
/// `in_dim / out_dim` rounded up. Both arguments are expected to be positive;
/// this is only enforced by a `debug_assert!`.
///
/// NOTE(review): when `out_dim` does not divide `in_dim` the pair does not
/// always cover exactly `out_dim` windows (e.g. `in_dim = 11`, `out_dim = 4`
/// gives kernel 3 / stride 2), and `out_dim > in_dim` yields a stride of 0 —
/// confirm callers restrict themselves to shapes where this is acceptable.
#[inline]
#[allow(dead_code)]
pub(crate) fn adaptive_kernel_stride(in_dim: i32, out_dim: i32) -> (i32, i32) {
    debug_assert!(in_dim > 0 && out_dim > 0);
    // Ceiling division for the window, floor division for the step.
    let window = (in_dim + out_dim - 1) / out_dim;
    let step = in_dim / out_dim;
    (window, step)
}