use crate::{
cudnn::{result::CudnnError, sys},
driver::{DevicePtr, DevicePtrMut},
};
use crate::cudnn::{result, Cudnn, CudnnDataType, TensorDescriptor};
use std::{marker::PhantomData, sync::Arc};
/// A cuDNN N-dimensional pooling descriptor, e.g. as returned by [`Cudnn::create_poolingnd`].
///
/// `T` is only a type-level tag (held through [`PhantomData`]); no value of type `T` is stored.
pub struct PoolingDescriptor<T> {
    /// Raw cuDNN pooling descriptor.
    desc: sys::cudnnPoolingDescriptor_t,
    /// The cuDNN context this descriptor was created from. Holding the `Arc` keeps the context
    /// alive for as long as the descriptor; its raw handle is also what pooling launches use.
    handle: Arc<Cudnn>,
    marker: PhantomData<T>,
}
impl Cudnn {
    /// Creates an N-dimensional pooling descriptor bound to this cuDNN context.
    ///
    /// `window`, `pads` and `strides` each hold one value per pooled dimension. The number of
    /// pooled dimensions is `window.len()`.
    ///
    /// # Panics
    ///
    /// Panics if `pads` or `strides` does not have the same length as `window`. The dimension
    /// count handed to cuDNN is `window.len()`, and `cudnnSetPoolingNdDescriptor` expects each
    /// of the three arrays to hold that many entries. A shorter slice would therefore be read
    /// out of bounds.
    ///
    /// Also panics if `window.len()` does not fit in a C `int`.
    ///
    /// # Errors
    ///
    /// Returns any error reported by cuDNN while creating or configuring the descriptor.
    pub fn create_poolingnd<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        window: &[std::ffi::c_int],
        pads: &[std::ffi::c_int],
        strides: &[std::ffi::c_int],
        mode: sys::cudnnPoolingMode_t,
        nan_propagation: sys::cudnnNanPropagation_t,
    ) -> Result<PoolingDescriptor<T>, CudnnError> {
        // Validate before touching cuDNN so a bad call never allocates a descriptor.
        assert_eq!(
            pads.len(),
            window.len(),
            "pooling `pads` must have one entry per `window` dimension"
        );
        assert_eq!(
            strides.len(),
            window.len(),
            "pooling `strides` must have one entry per `window` dimension"
        );
        let nb_dims = std::ffi::c_int::try_from(window.len())
            .expect("number of pooling dimensions must fit in a C int");

        let desc = result::create_pooling_descriptor()?;
        // Wrap the raw descriptor right away so it is owned by `PoolingDescriptor` even if the
        // configuration call below fails.
        let desc = PoolingDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        // SAFETY: `desc.desc` was just created by cuDNN, and all three slices were checked above
        // to hold exactly `nb_dims` entries.
        unsafe {
            result::set_pooling_descriptor(
                desc.desc,
                mode,
                nan_propagation,
                nb_dims,
                window,
                pads,
                strides,
            )
        }?;
        Ok(desc)
    }
}
/// Borrowed set of descriptors describing one pooling forward pass (`x` -> `y`).
///
/// Construct it from existing descriptors, then execute it with [`PoolingForward::launch`].
pub struct PoolingForward<'a, P, X, Y> {
    /// Pooling configuration (window, padding, strides, mode).
    pub pooling: &'a PoolingDescriptor<P>,
    /// Layout descriptor of the input tensor.
    pub x: &'a TensorDescriptor<X>,
    /// Layout descriptor of the output tensor.
    pub y: &'a TensorDescriptor<Y>,
}
impl<P, X, Y> PoolingForward<'_, P, X, Y>
where
    P: CudnnDataType,
    X: CudnnDataType,
    Y: CudnnDataType,
{
    /// Enqueues a cuDNN pooling forward pass reading `src` and writing `y`.
    ///
    /// `(alpha, beta)` are the cuDNN scaling factors. Per the cuDNN API the output is blended as
    /// `y = alpha * pool(src) + beta * y`. Each factor is first converted with
    /// `into_scaling_parameter` before a pointer to it is handed to cuDNN.
    ///
    /// # Safety
    ///
    /// The caller must ensure `src` and `y` are device buffers whose sizes and layouts match
    /// `self.x` and `self.y` respectively, and that the descriptors are consistent with
    /// `self.pooling`. None of this is checked here.
    ///
    /// # Errors
    ///
    /// Returns the error reported by cuDNN if the pooling call fails.
    pub unsafe fn launch<Src, Dst>(
        &self,
        (alpha, beta): (Y, Y),
        src: &Src,
        y: &mut Dst,
    ) -> Result<(), CudnnError>
    where
        Src: DevicePtr<X>,
        Dst: DevicePtrMut<Y>,
    {
        // Use one context for both the cuDNN call and the stream on which buffer accesses are
        // recorded. If the handle and the stream came from different descriptors, the recorded
        // accesses could land on a different stream from the one the work runs on.
        let cudnn = &self.pooling.handle;
        let stream = &cudnn.stream;
        let alpha = alpha.into_scaling_parameter();
        let beta = beta.into_scaling_parameter();
        // The guards are bound to named (not `_`) variables so they live until the end of this
        // function, i.e. until after the cuDNN call has been issued.
        let (src, _record_src) = src.device_ptr(stream);
        let (y, _record_y) = y.device_ptr_mut(stream);
        result::pooling_forward(
            cudnn.handle,
            self.pooling.desc,
            (&alpha) as *const Y::Scalar as *const std::ffi::c_void,
            self.x.desc,
            src as *const X as *const std::ffi::c_void,
            (&beta) as *const Y::Scalar as *const std::ffi::c_void,
            self.y.desc,
            y as *mut Y as *mut std::ffi::c_void,
        )
    }
}