use core::cell::Cell;
use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_sys::{
cudnnCreate, cudnnHandle_t, cudnnPoolingDescriptor_t, cudnnSetStream,
cudnnTensorDescriptor_t,
};
use baracuda_kernels_types::{
Element, KernelSku, PlanPreference, PoolKind, PrecisionGuarantee, Workspace,
};
use super::max_pool2d::{
build_sku, check_bw_args, check_fw_args, drop_pool_descriptors, ensure_pool_descriptors,
run_bw_inner, run_fw_inner, validate_descriptor, Pool2dBwArgs, Pool2dDescriptor, Pool2dFwArgs,
PoolMode,
};
/// Plan for a 2-D average-pooling operation over elements of type `T`.
///
/// Built by `select`, which only accepts descriptors whose mode is
/// `PoolMode::AvgIncludePad` or `PoolMode::AvgExcludePad`. The cuDNN objects
/// are stored in `Cell`s so the `&self` run methods can fill them in after
/// construction; `select` initialises every one of them to null.
pub struct AvgPool2dPlan<T: Element> {
/// Copy of the pooling descriptor the plan was selected for.
desc: Pool2dDescriptor,
/// SKU produced by `build_sku` for the selected average-pooling kind.
sku: KernelSku,
/// cuDNN handle; null until `ensure_handle` creates and caches it.
handle: Cell<cudnnHandle_t>,
/// Tensor-descriptor slot passed to `ensure_pool_descriptors` as `x_desc`.
x_desc: Cell<cudnnTensorDescriptor_t>,
/// Tensor-descriptor slot passed to `ensure_pool_descriptors` as `y_desc`.
y_desc: Cell<cudnnTensorDescriptor_t>,
/// Pooling-descriptor slot passed to `ensure_pool_descriptors`.
pool_desc: Cell<cudnnPoolingDescriptor_t>,
/// Ties the plan to the element type `T` without storing a value of it.
_marker: PhantomData<T>,
}
impl<T: Element> AvgPool2dPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Pool2dDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
validate_descriptor::<T>(desc)?;
let op = match desc.mode {
PoolMode::AvgIncludePad => PoolKind::AvgPool2dIncludePad,
PoolMode::AvgExcludePad => PoolKind::AvgPool2dExcludePad,
PoolMode::Max => {
return Err(Error::Unsupported(
"baracuda-kernels::AvgPool2dPlan: descriptor.mode must be one of \
PoolMode::AvgIncludePad | AvgExcludePad — use MaxPool2dPlan for max",
));
}
};
let sku = build_sku::<T>(op);
Ok(Self {
desc: *desc,
sku,
handle: Cell::new(core::ptr::null_mut()),
x_desc: Cell::new(core::ptr::null_mut()),
y_desc: Cell::new(core::ptr::null_mut()),
pool_desc: Cell::new(core::ptr::null_mut()),
_marker: PhantomData,
})
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn output_dims(&self) -> (i32, i32) {
super::max_pool2d::compute_output_dims(&self.desc)
}
pub fn run_fw(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Pool2dFwArgs<'_, T>,
) -> Result<()> {
check_fw_args(&self.desc, &args)?;
let h = self.ensure_handle()?;
self.bind_stream(h, stream)?;
self.ensure_descriptors()?;
run_fw_inner::<T>(
h,
self.pool_desc.get(),
self.x_desc.get(),
self.y_desc.get(),
args.x.data.as_raw().0,
args.y.data.as_raw().0,
)
}
pub fn run_bw(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Pool2dBwArgs<'_, T>,
) -> Result<()> {
check_bw_args(&self.desc, &args)?;
let h = self.ensure_handle()?;
self.bind_stream(h, stream)?;
self.ensure_descriptors()?;
run_bw_inner::<T>(
h,
self.pool_desc.get(),
self.x_desc.get(),
self.y_desc.get(),
args.y.data.as_raw().0,
args.dy.data.as_raw().0,
args.x.data.as_raw().0,
args.dx.data.as_raw().0,
)
}
fn ensure_handle(&self) -> Result<cudnnHandle_t> {
let h = self.handle.get();
if !h.is_null() {
return Ok(h);
}
let mut handle: cudnnHandle_t = core::ptr::null_mut();
let status = unsafe { cudnnCreate(&mut handle as *mut _) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
self.handle.set(handle);
Ok(handle)
}
fn bind_stream(&self, h: cudnnHandle_t, stream: &Stream) -> Result<()> {
let status = unsafe { cudnnSetStream(h, stream.as_raw() as *mut c_void) };
if status != 0 {
return Err(Error::CutlassInternal(-status));
}
Ok(())
}
fn ensure_descriptors(&self) -> Result<()> {
ensure_pool_descriptors::<T>(
&self.desc,
&self.x_desc,
&self.y_desc,
&self.pool_desc,
)
}
}
/// Cleanup hook for the plan's cuDNN handle and descriptor slots.
impl<T: Element> Drop for AvgPool2dPlan<T> {
    /// Passes the three descriptor slots and the handle slot to
    /// `drop_pool_descriptors`.
    ///
    /// NOTE(review): that helper presumably destroys whichever objects are
    /// non-null — confirm in `max_pool2d`.
    fn drop(&mut self) {
        // Shared borrows are sufficient: every slot is a `Cell`.
        let Self {
            x_desc,
            y_desc,
            pool_desc,
            handle,
            ..
        } = &*self;
        drop_pool_descriptors(x_desc, y_desc, pool_desc, handle);
    }
}