use super::core::*;
use crate::{
cudnn::{result, result::CudnnError, sys},
driver::{DevicePtr, DevicePtrMut},
};
use crate::cudnn::safe::activation::ActivationDescriptor;
use std::{marker::PhantomData, sync::Arc};
/// Owning wrapper around a cuDNN filter descriptor (`cudnnFilterDescriptor_t`) configured
/// for element type `T`.
///
/// Built by the `Cudnn::create_*_filter` constructors; the raw descriptor is destroyed when
/// this value is dropped.
#[derive(Debug)]
pub struct FilterDescriptor<T> {
    /// Raw cuDNN filter descriptor handle.
    pub(crate) desc: sys::cudnnFilterDescriptor_t,
    // Held rather than read (hence `allow(unused)`): owning this `Arc` keeps the `Cudnn`
    // instance alive for at least as long as the descriptor.
    #[allow(unused)]
    pub(crate) handle: Arc<Cudnn>,
    /// Ties the descriptor to the Rust element type `T` without storing a `T`.
    pub(crate) marker: PhantomData<T>,
}
impl Cudnn {
    /// Creates a 4-dimensional filter descriptor with element type `T`.
    ///
    /// `format` and `dims` are forwarded unchanged, together with `T::DATA_TYPE`, to
    /// `result::set_filter4d_descriptor`. If that call fails, the freshly created
    /// descriptor is dropped (and thus destroyed) and the error is returned.
    pub fn create_4d_filter<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        format: sys::cudnnTensorFormat_t,
        dims: [std::ffi::c_int; 4],
    ) -> Result<FilterDescriptor<T>, CudnnError> {
        // Wrap the raw handle right away so `Drop` releases it on any later error.
        let filter = FilterDescriptor {
            desc: result::create_filter_descriptor()?,
            handle: Arc::clone(self),
            marker: PhantomData,
        };
        unsafe { result::set_filter4d_descriptor(filter.desc, T::DATA_TYPE, format, dims) }
            .map(|_| filter)
    }

    /// Creates an N-dimensional filter descriptor with element type `T`.
    ///
    /// The rank handed to cuDNN is `dims.len()`, and `dims` supplies the extents.
    /// On failure, the created descriptor is dropped before the error is returned.
    pub fn create_nd_filter<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        format: sys::cudnnTensorFormat_t,
        dims: &[std::ffi::c_int],
    ) -> Result<FilterDescriptor<T>, CudnnError> {
        let filter = FilterDescriptor {
            desc: result::create_filter_descriptor()?,
            handle: Arc::clone(self),
            marker: PhantomData,
        };
        let rank = dims.len() as std::ffi::c_int;
        unsafe {
            result::set_filternd_descriptor(filter.desc, T::DATA_TYPE, format, rank, dims.as_ptr())
        }
        .map(|_| filter)
    }

    /// Creates a 3-dimensional filter descriptor; delegates to [`Cudnn::create_nd_filter`].
    pub fn create_3d_filter<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        format: sys::cudnnTensorFormat_t,
        dims: [std::ffi::c_int; 3],
    ) -> Result<FilterDescriptor<T>, CudnnError> {
        self.create_nd_filter::<T>(format, dims.as_slice())
    }

    /// Creates a 5-dimensional filter descriptor; delegates to [`Cudnn::create_nd_filter`].
    pub fn create_5d_filter<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        format: sys::cudnnTensorFormat_t,
        dims: [std::ffi::c_int; 5],
    ) -> Result<FilterDescriptor<T>, CudnnError> {
        self.create_nd_filter::<T>(format, dims.as_slice())
    }
}
impl<T> Drop for FilterDescriptor<T> {
fn drop(&mut self) {
let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
if !desc.is_null() {
unsafe { result::destroy_filter_descriptor(desc) }.unwrap()
}
}
}
/// Owning wrapper around a cuDNN convolution descriptor (`cudnnConvolutionDescriptor_t`).
///
/// `T` is the data type the descriptor was configured with (its `T::DATA_TYPE` is passed
/// when the descriptor is set). Built by `Cudnn::create_conv2d` / `Cudnn::create_convnd`;
/// the raw descriptor is destroyed when this value is dropped.
#[derive(Debug)]
pub struct ConvDescriptor<T> {
    /// Raw cuDNN convolution descriptor handle.
    pub(crate) desc: sys::cudnnConvolutionDescriptor_t,
    /// Owning reference to the cuDNN context; the convolution operations in this module
    /// pass `handle.handle` to cuDNN.
    pub(crate) handle: Arc<Cudnn>,
    /// Records the configured data type `T` at the type level.
    pub(crate) marker: PhantomData<T>,
}
/// Former name of [`ConvDescriptor`], kept only for backwards compatibility.
#[deprecated(note = "use ConvDescriptor instead. This will be removed in future versions")]
pub type Conv2dDescriptor<T> = ConvDescriptor<T>;
impl Cudnn {
    /// Creates a 2D convolution descriptor configured with data type `T`.
    ///
    /// `pad`, `stride` and `dilation` are each given as `[height, width]` and are passed,
    /// together with `mode` and `T::DATA_TYPE`, to `result::set_convolution2d_descriptor`.
    /// If setting the descriptor fails, the created descriptor is dropped (destroyed)
    /// before the error is returned.
    pub fn create_conv2d<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        pad: [std::ffi::c_int; 2],
        stride: [std::ffi::c_int; 2],
        dilation: [std::ffi::c_int; 2],
        mode: sys::cudnnConvolutionMode_t,
    ) -> Result<ConvDescriptor<T>, CudnnError> {
        let [pad_h, pad_w] = pad;
        let [stride_h, stride_w] = stride;
        let [dilation_h, dilation_w] = dilation;
        let desc = result::create_convolution_descriptor()?;
        // Wrap immediately so `Drop` cleans up if the set call below fails.
        let desc = ConvDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        unsafe {
            result::set_convolution2d_descriptor(
                desc.desc,
                pad_h,
                pad_w,
                stride_h,
                stride_w,
                dilation_h,
                dilation_w,
                mode,
                T::DATA_TYPE,
            )
        }?;
        Ok(desc)
    }

    /// Creates an N-dimensional convolution descriptor configured with data type `T`.
    ///
    /// `pads`, `strides` and `dilations` each hold one entry per spatial dimension. The
    /// number of spatial dimensions passed to cuDNN is `pads.len()`.
    ///
    /// # Errors
    ///
    /// Returns the `CudnnError` for `CUDNN_STATUS_BAD_PARAM` if `strides` or `dilations`
    /// does not have the same length as `pads`. Otherwise, returns any error from creating
    /// or setting the descriptor.
    pub fn create_convnd<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        pads: &[std::ffi::c_int],
        strides: &[std::ffi::c_int],
        dilations: &[std::ffi::c_int],
        mode: sys::cudnnConvolutionMode_t,
    ) -> Result<ConvDescriptor<T>, CudnnError> {
        // cuDNN reads `pads.len()` elements from each of the three arrays. A shorter
        // `strides` or `dilations` slice would therefore be read out of bounds from this safe
        // API, so reject it before any descriptor is allocated.
        if strides.len() != pads.len() || dilations.len() != pads.len() {
            sys::cudnnStatus_t::CUDNN_STATUS_BAD_PARAM.result()?;
        }
        let desc = result::create_convolution_descriptor()?;
        // Wrap immediately so `Drop` cleans up if the set call below fails.
        let desc = ConvDescriptor {
            desc,
            handle: self.clone(),
            marker: PhantomData,
        };
        unsafe {
            result::set_convolutionnd_descriptor(
                desc.desc,
                pads.len() as std::ffi::c_int,
                pads.as_ptr(),
                strides.as_ptr(),
                dilations.as_ptr(),
                mode,
                T::DATA_TYPE,
            )
        }?;
        Ok(desc)
    }
}
impl<T> ConvDescriptor<T> {
    /// Sets the math type of this convolution descriptor.
    ///
    /// Thin wrapper over `result::set_convolution_math_type`; any cuDNN error is returned.
    pub fn set_math_type(&mut self, math_type: sys::cudnnMathType_t) -> Result<(), CudnnError> {
        let raw = self.desc;
        unsafe { result::set_convolution_math_type(raw, math_type) }
    }

    /// Sets the group count of this convolution descriptor (for grouped convolutions).
    ///
    /// Thin wrapper over `result::set_convolution_group_count`; any cuDNN error is returned.
    pub fn set_group_count(&mut self, group_count: i32) -> Result<(), CudnnError> {
        let raw = self.desc;
        unsafe { result::set_convolution_group_count(raw, group_count) }
    }
}
impl<T> Drop for ConvDescriptor<T> {
fn drop(&mut self) {
let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
if !desc.is_null() {
unsafe { result::destroy_convolution_descriptor(desc) }.unwrap()
}
}
}
/// Bundle of descriptors describing a forward convolution.
///
/// `X` is the element type of the input and the filter, `C` is the convolution descriptor's
/// type, and `Y` is the output element type (also the type of the `alpha`/`beta` scaling
/// factors passed to `launch`).
#[derive(Debug)]
pub struct ConvForward<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> {
    /// Convolution descriptor (padding, stride, dilation, mode).
    pub conv: &'a ConvDescriptor<C>,
    /// Descriptor of the input tensor.
    pub x: &'a TensorDescriptor<X>,
    /// Descriptor of the filter.
    pub w: &'a FilterDescriptor<X>,
    /// Descriptor of the output tensor.
    pub y: &'a TensorDescriptor<Y>,
}
/// Former name of [`ConvForward`], kept only for backwards compatibility.
#[deprecated(note = "use ConvForward instead. This will be removed in future versions")]
pub type Conv2dForward<'a, X, C, Y> = ConvForward<'a, X, C, Y>;
impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvForward<'_, X, C, Y> {
    /// Queries cuDNN for forward-convolution algorithms and returns the first one reported.
    ///
    /// Requests up to `NUM_ALGOS` results via `result::get_convolution_forward_algorithm`.
    /// The status of the entry at index 0 is checked, and that entry's algorithm is returned.
    ///
    /// # Panics
    ///
    /// Panics if cuDNN reports zero results. In debug builds, it also panics if the bindings'
    /// `CUDNN_CONVOLUTION_FWD_ALGO_COUNT` differs from the hard-coded slot count.
    pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionFwdAlgo_t, CudnnError> {
        const NUM_ALGOS: usize = 8;
        debug_assert_eq!(
            sys::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_COUNT as u32,
            NUM_ALGOS as u32
        );
        let mut found: std::ffi::c_int = 0;
        let mut candidates = [unsafe { std::mem::zeroed() }; NUM_ALGOS];
        unsafe {
            result::get_convolution_forward_algorithm(
                self.conv.handle.handle,
                self.x.desc,
                self.w.desc,
                self.conv.desc,
                self.y.desc,
                NUM_ALGOS as std::ffi::c_int,
                &mut found,
                candidates.as_mut_ptr(),
            )
        }?;
        assert!(found > 0);
        let best = &candidates[0];
        best.status.result()?;
        Ok(best.algo)
    }

    /// Returns the workspace size in bytes that cuDNN reports for running `algo` with these
    /// descriptors.
    pub fn get_workspace_size(
        &self,
        algo: sys::cudnnConvolutionFwdAlgo_t,
    ) -> Result<usize, CudnnError> {
        let handle = self.conv.handle.handle;
        unsafe {
            result::get_convolution_forward_workspace_size(
                handle,
                self.x.desc,
                self.w.desc,
                self.conv.desc,
                self.y.desc,
                algo,
            )
        }
    }

    /// Launches the forward convolution of `src` with `filter`, writing into `y`.
    ///
    /// `alpha` and `beta` are converted with `into_scaling_parameter` and handed to cuDNN as
    /// its blending factors for the output (see `cudnnConvolutionForward`). Device pointers
    /// are obtained on the stream of the `x` descriptor's handle. A missing workspace is
    /// passed as a null pointer with size 0.
    ///
    /// # Safety
    ///
    /// Nothing here checks buffer sizes. The caller must ensure that `src`, `filter` and `y`
    /// match the `x`, `w` and `y` descriptors, and that `workspace` (if any) holds at least
    /// [`Self::get_workspace_size`] bytes for `algo`.
    pub unsafe fn launch<Workspace, Src, Filter, Dst>(
        &self,
        algo: sys::cudnnConvolutionFwdAlgo_t,
        workspace: Option<&mut Workspace>,
        (alpha, beta): (Y, Y),
        src: &Src,
        filter: &Filter,
        y: &mut Dst,
    ) -> Result<(), CudnnError>
    where
        Workspace: DevicePtrMut<u8>,
        Src: DevicePtr<X>,
        Filter: DevicePtr<X>,
        Dst: DevicePtrMut<Y>,
    {
        use std::ffi::c_void;

        let stream = &self.x.handle.stream;
        let alpha = alpha.into_scaling_parameter();
        let beta = beta.into_scaling_parameter();
        // The `_record_*` values are bound to names (not `_`) so they are not dropped until
        // after the cuDNN call below has been issued.
        let (src, _record_src) = src.device_ptr(stream);
        let (filter, _record_f) = filter.device_ptr(stream);
        let (y, _record_y) = y.device_ptr_mut(stream);
        let workspace_size_in_bytes = workspace.as_ref().map_or(0, |ws| ws.num_bytes());
        let (ws_ptr, _record_w) = workspace.map(|ws| ws.device_ptr_mut(stream)).unzip();
        result::convolution_forward(
            self.conv.handle.handle,
            (&alpha) as *const Y::Scalar as *const c_void,
            self.x.desc,
            src as *const X as *const c_void,
            self.w.desc,
            filter as *const X as *const c_void,
            self.conv.desc,
            algo,
            ws_ptr.map_or(std::ptr::null_mut(), |p| p as _),
            workspace_size_in_bytes,
            (&beta) as *const Y::Scalar as *const c_void,
            self.y.desc,
            y as *mut Y as *mut c_void,
        )
    }
}
/// Bundle of descriptors describing a convolution backward pass with respect to the input
/// data.
///
/// `X` is the element type of the input gradient and the filter, `C` is the convolution
/// descriptor's type, and `Y` is the element type of the output gradient (also the type of
/// the `alpha`/`beta` scaling factors passed to `launch`).
#[derive(Debug)]
pub struct ConvBackwardData<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> {
    /// Convolution descriptor (padding, stride, dilation, mode).
    pub conv: &'a ConvDescriptor<C>,
    /// Descriptor of the input-gradient tensor (written by `launch`).
    pub dx: &'a TensorDescriptor<X>,
    /// Descriptor of the filter.
    pub w: &'a FilterDescriptor<X>,
    /// Descriptor of the output-gradient tensor (read by `launch`).
    pub dy: &'a TensorDescriptor<Y>,
}
/// Former name of [`ConvBackwardData`], kept only for backwards compatibility.
#[deprecated(note = "use ConvBackwardData instead. This will be removed in future versions")]
pub type Conv2dBackwardData<'a, X, C, Y> = ConvBackwardData<'a, X, C, Y>;
impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardData<'_, X, C, Y> {
    /// Queries cuDNN for backward-data algorithms and returns the first one reported.
    ///
    /// Requests up to `NUM_ALGOS` results via
    /// `result::get_convolution_backward_data_algorithm`. The status of the entry at index 0
    /// is checked, and that entry's algorithm is returned.
    ///
    /// # Panics
    ///
    /// Panics if cuDNN reports zero results. In debug builds, it also panics if the bindings'
    /// `CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT` differs from the hard-coded slot count.
    pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionBwdDataAlgo_t, CudnnError> {
        const NUM_ALGOS: usize = 6;
        debug_assert_eq!(
            sys::cudnnConvolutionBwdDataAlgo_t::CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT as u32,
            NUM_ALGOS as u32
        );
        let mut found: std::ffi::c_int = 0;
        let mut candidates = [unsafe { std::mem::zeroed() }; NUM_ALGOS];
        unsafe {
            result::get_convolution_backward_data_algorithm(
                self.conv.handle.handle,
                self.w.desc,
                self.dy.desc,
                self.conv.desc,
                self.dx.desc,
                NUM_ALGOS as std::ffi::c_int,
                &mut found,
                candidates.as_mut_ptr(),
            )
        }?;
        assert!(found > 0);
        let best = &candidates[0];
        best.status.result()?;
        Ok(best.algo)
    }

    /// Returns the workspace size in bytes that cuDNN reports for running `algo` with these
    /// descriptors.
    pub fn get_workspace_size(
        &self,
        algo: sys::cudnnConvolutionBwdDataAlgo_t,
    ) -> Result<usize, CudnnError> {
        let handle = self.conv.handle.handle;
        unsafe {
            result::get_convolution_backward_data_workspace_size(
                handle,
                self.w.desc,
                self.dy.desc,
                self.conv.desc,
                self.dx.desc,
                algo,
            )
        }
    }

    /// Launches the backward-data convolution, reading `filter` and `dy` and writing `dx`.
    ///
    /// `alpha` and `beta` are converted with `into_scaling_parameter` and handed to cuDNN as
    /// its blending factors for `dx` (see `cudnnConvolutionBackwardData`). Device pointers
    /// are obtained on the stream of the `dx` descriptor's handle. A missing workspace is
    /// passed as a null pointer with size 0.
    ///
    /// # Safety
    ///
    /// Nothing here checks buffer sizes. The caller must ensure that `dx`, `filter` and `dy`
    /// match the `dx`, `w` and `dy` descriptors, and that `workspace` (if any) holds at least
    /// [`Self::get_workspace_size`] bytes for `algo`.
    pub unsafe fn launch<Workspace, Src, Filter, Dst>(
        &self,
        algo: sys::cudnnConvolutionBwdDataAlgo_t,
        workspace: Option<&mut Workspace>,
        (alpha, beta): (Y, Y),
        dx: &mut Src,
        filter: &Filter,
        dy: &Dst,
    ) -> Result<(), CudnnError>
    where
        Workspace: DevicePtrMut<u8>,
        Src: DevicePtrMut<X>,
        Filter: DevicePtr<X>,
        Dst: DevicePtr<Y>,
    {
        use std::ffi::c_void;

        let stream = &self.dx.handle.stream;
        let alpha = alpha.into_scaling_parameter();
        let beta = beta.into_scaling_parameter();
        // The `_record_*` values are bound to names (not `_`) so they are not dropped until
        // after the cuDNN call below has been issued.
        let (dx, _record_dx) = dx.device_ptr_mut(stream);
        let (filter, _record_f) = filter.device_ptr(stream);
        let (dy, _record_dy) = dy.device_ptr(stream);
        let workspace_size_in_bytes = workspace.as_ref().map_or(0, |ws| ws.num_bytes());
        let (ws_ptr, _record_w) = workspace.map(|ws| ws.device_ptr_mut(stream)).unzip();
        result::convolution_backward_data(
            self.conv.handle.handle,
            (&alpha) as *const Y::Scalar as *const c_void,
            self.w.desc,
            filter as *const X as *const c_void,
            self.dy.desc,
            dy as *const Y as *const c_void,
            self.conv.desc,
            algo,
            ws_ptr.map_or(std::ptr::null_mut(), |p| p as _),
            workspace_size_in_bytes,
            (&beta) as *const Y::Scalar as *const c_void,
            self.dx.desc,
            dx as *mut X as *mut c_void,
        )
    }
}
/// Bundle of descriptors describing a convolution backward pass with respect to the filter.
///
/// `X` is the element type of the input and the filter gradient, `C` is the convolution
/// descriptor's type, and `Y` is the element type of the output gradient (also the type of
/// the `alpha`/`beta` scaling factors passed to `launch`).
#[derive(Debug)]
pub struct ConvBackwardFilter<'a, X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> {
    /// Convolution descriptor (padding, stride, dilation, mode).
    pub conv: &'a ConvDescriptor<C>,
    /// Descriptor of the input tensor (read by `launch`).
    pub x: &'a TensorDescriptor<X>,
    /// Descriptor of the filter gradient (written by `launch`).
    pub dw: &'a FilterDescriptor<X>,
    /// Descriptor of the output-gradient tensor (read by `launch`).
    pub dy: &'a TensorDescriptor<Y>,
}
/// Former name of [`ConvBackwardFilter`], kept only for backwards compatibility.
#[deprecated(note = "use ConvBackwardFilter instead. This will be removed in future versions")]
pub type Conv2dBackwardFilter<'a, X, C, Y> = ConvBackwardFilter<'a, X, C, Y>;
impl<X: CudnnDataType, C: CudnnDataType, Y: CudnnDataType> ConvBackwardFilter<'_, X, C, Y> {
    /// Queries cuDNN for backward-filter algorithms and returns the first one reported.
    ///
    /// Requests up to `NUM_ALGOS` results via
    /// `result::get_convolution_backward_filter_algorithm`. The status of the entry at
    /// index 0 is checked, and that entry's algorithm is returned.
    ///
    /// # Panics
    ///
    /// Panics if cuDNN reports zero results. In debug builds, it also panics if the bindings'
    /// `CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT` differs from the hard-coded slot count.
    pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionBwdFilterAlgo_t, CudnnError> {
        const NUM_ALGOS: usize = 7;
        debug_assert_eq!(
            sys::cudnnConvolutionBwdFilterAlgo_t::CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT as u32,
            NUM_ALGOS as u32
        );
        let mut found: std::ffi::c_int = 0;
        let mut candidates = [unsafe { std::mem::zeroed() }; NUM_ALGOS];
        unsafe {
            result::get_convolution_backward_filter_algorithm(
                self.conv.handle.handle,
                self.x.desc,
                self.dy.desc,
                self.conv.desc,
                self.dw.desc,
                NUM_ALGOS as std::ffi::c_int,
                &mut found,
                candidates.as_mut_ptr(),
            )
        }?;
        assert!(found > 0);
        let best = &candidates[0];
        best.status.result()?;
        Ok(best.algo)
    }

    /// Returns the workspace size in bytes that cuDNN reports for running `algo` with these
    /// descriptors.
    pub fn get_workspace_size(
        &self,
        algo: sys::cudnnConvolutionBwdFilterAlgo_t,
    ) -> Result<usize, CudnnError> {
        let handle = self.conv.handle.handle;
        unsafe {
            result::get_convolution_backward_filter_workspace_size(
                handle,
                self.x.desc,
                self.dy.desc,
                self.conv.desc,
                self.dw.desc,
                algo,
            )
        }
    }

    /// Launches the backward-filter convolution, reading `x` and `dy` and writing `dfilter`.
    ///
    /// `alpha` and `beta` are converted with `into_scaling_parameter` and handed to cuDNN as
    /// its blending factors for the filter gradient (see `cudnnConvolutionBackwardFilter`).
    /// Device pointers are obtained on the stream of the `x` descriptor's handle. A missing
    /// workspace is passed as a null pointer with size 0.
    ///
    /// # Safety
    ///
    /// Nothing here checks buffer sizes. The caller must ensure that `x`, `dfilter` and `dy`
    /// match the `x`, `dw` and `dy` descriptors, and that `workspace` (if any) holds at least
    /// [`Self::get_workspace_size`] bytes for `algo`.
    pub unsafe fn launch<Workspace, Src, Filter, Dst>(
        &self,
        algo: sys::cudnnConvolutionBwdFilterAlgo_t,
        workspace: Option<&mut Workspace>,
        (alpha, beta): (Y, Y),
        x: &Src,
        dfilter: &mut Filter,
        dy: &Dst,
    ) -> Result<(), CudnnError>
    where
        Workspace: DevicePtrMut<u8>,
        Src: DevicePtr<X>,
        Filter: DevicePtrMut<X>,
        Dst: DevicePtr<Y>,
    {
        use std::ffi::c_void;

        let stream = &self.x.handle.stream;
        let alpha = alpha.into_scaling_parameter();
        let beta = beta.into_scaling_parameter();
        // The `_record_*` values are bound to names (not `_`) so they are not dropped until
        // after the cuDNN call below has been issued.
        let (x, _record_x) = x.device_ptr(stream);
        let (dfilter, _record_f) = dfilter.device_ptr_mut(stream);
        let (dy, _record_dy) = dy.device_ptr(stream);
        let workspace_size_in_bytes = workspace.as_ref().map_or(0, |ws| ws.num_bytes());
        let (ws_ptr, _record_w) = workspace.map(|ws| ws.device_ptr_mut(stream)).unzip();
        result::convolution_backward_filter(
            self.conv.handle.handle,
            (&alpha) as *const Y::Scalar as *const c_void,
            self.x.desc,
            x as *const _,
            self.dy.desc,
            dy as *const _,
            self.conv.desc,
            algo,
            ws_ptr.map_or(std::ptr::null_mut(), |p| p as _),
            workspace_size_in_bytes,
            (&beta) as *const Y::Scalar as *const c_void,
            self.dw.desc,
            dfilter as *mut _,
        )
    }
}
/// Bundle of descriptors for cuDNN's fused convolution + bias + activation forward op.
///
/// `X` is the element type of the input, filter, `z` and bias tensors, `C` is the
/// convolution descriptor's type, `A` is the activation descriptor's type, and `Y` is the
/// output element type (also the type of the `alpha1`/`alpha2` factors passed to `launch`).
#[derive(Debug)]
pub struct ConvBiasActivationForward<
    'a,
    X: CudnnDataType,
    C: CudnnDataType,
    A: CudnnDataType,
    Y: CudnnDataType,
> {
    /// Convolution descriptor (padding, stride, dilation, mode).
    pub conv: &'a ConvDescriptor<C>,
    /// Activation descriptor applied by the fused op.
    pub act: &'a ActivationDescriptor<A>,
    /// Descriptor of the input tensor.
    pub x: &'a TensorDescriptor<X>,
    /// Descriptor of the filter.
    pub w: &'a FilterDescriptor<X>,
    /// Descriptor of the `z` tensor (passed to cuDNN alongside the `alpha2` factor).
    pub z: &'a TensorDescriptor<X>,
    /// Descriptor of the bias tensor.
    pub bias: &'a TensorDescriptor<X>,
    /// Descriptor of the output tensor.
    pub y: &'a TensorDescriptor<Y>,
}
impl<X, C, A, Y> ConvBiasActivationForward<'_, X, C, A, Y>
where
X: CudnnDataType,
C: CudnnDataType,
A: CudnnDataType,
Y: CudnnDataType,
{
pub fn pick_algorithm(&self) -> Result<sys::cudnnConvolutionFwdAlgo_t, CudnnError> {
let conv = ConvForward {
conv: self.conv,
x: self.x,
w: self.w,
y: self.y,
};
conv.pick_algorithm()
}
pub fn get_workspace_size(
&self,
algo: sys::cudnnConvolutionFwdAlgo_t,
) -> Result<usize, CudnnError> {
let conv = ConvForward {
conv: self.conv,
x: self.x,
w: self.w,
y: self.y,
};
conv.get_workspace_size(algo)
}
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch<Workspace, Src, Filter, Dst>(
&self,
algo: sys::cudnnConvolutionFwdAlgo_t,
workspace: Option<&mut Workspace>,
(alpha1, alpha2): (Y, Y),
src: &Src,
filter: &Filter,
z: &Src,
bias: &Src,
y: &mut Dst,
) -> Result<(), CudnnError>
where
Workspace: DevicePtrMut<u8>,
Src: DevicePtr<X>,
Filter: DevicePtr<X>,
Dst: DevicePtrMut<Y>,
{
let alpha1 = alpha1.into_scaling_parameter();
let alpha2 = alpha2.into_scaling_parameter();
let stream = &self.x.handle.stream;
let (src, _record_src) = src.device_ptr(stream);
let (filter, _record_f) = filter.device_ptr(stream);
let (z, _record_z) = z.device_ptr(stream);
let (bias, _record_bias) = bias.device_ptr(stream);
let (y, _record_y) = y.device_ptr_mut(stream);
let workspace_size_in_bytes = workspace.as_ref().map(|w| w.num_bytes()).unwrap_or(0);
let (w, _record_w) = workspace.map(|w| w.device_ptr_mut(stream)).unzip();
result::convolution_bias_activation_forward(
self.conv.handle.handle,
(&alpha1) as *const Y::Scalar as *const std::ffi::c_void,
self.x.desc,
src as *const X as *const std::ffi::c_void,
self.w.desc,
filter as *const X as *const std::ffi::c_void,
self.conv.desc,
algo,
w.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut()),
workspace_size_in_bytes,
(&alpha2) as *const Y::Scalar as *const std::ffi::c_void,
self.z.desc,
z as *const X as *const std::ffi::c_void,
self.bias.desc,
bias as *const X as *const std::ffi::c_void,
self.act.desc,
self.y.desc,
y as *mut Y as *mut std::ffi::c_void,
)
}
}