use super::core::*;
use crate::{
cudnn::{result, result::CudnnError, sys},
driver::{DevicePtr, DevicePtrMut},
};
use std::{marker::PhantomData, sync::Arc};
/// Marker type: the reduction also produces the flattened (32-bit) indices
/// of the reduced elements, e.g. for arg-max / arg-min style reductions.
#[derive(Debug, Default, Copy, Clone)]
pub struct FlatIndices;
/// Marker type: the reduction produces reduced values only, no indices.
#[derive(Debug, Default, Copy, Clone)]
pub struct NoIndices;
/// Owning wrapper around a raw `cudnnReduceTensorDescriptor_t`.
///
/// The raw descriptor is destroyed when this value is dropped. `T` is the
/// tensor element type; `Idx` is a marker ([`FlatIndices`] or [`NoIndices`])
/// recording at the type level whether the reduction was configured to
/// compute indices.
#[derive(Debug)]
pub struct ReductionDescriptor<T, Idx> {
    pub(crate) desc: sys::cudnnReduceTensorDescriptor_t,
    // Marker value selecting the indices behavior; only its type matters.
    #[allow(unused)]
    pub(crate) indices: Idx,
    // Keeps the cudnn handle alive at least as long as this descriptor.
    #[allow(unused)]
    pub(crate) handle: Arc<Cudnn>,
    pub(crate) marker: PhantomData<T>,
}
impl Cudnn {
    /// Creates a reduction descriptor that also computes the flattened
    /// 32-bit indices of the reduced elements.
    ///
    /// `op` selects the reduction operation and `nan_opt` controls NaN
    /// propagation. Returns an error if descriptor creation or
    /// configuration fails.
    pub fn create_reduction_flat_indices<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        op: sys::cudnnReduceTensorOp_t,
        nan_opt: sys::cudnnNanPropagation_t,
    ) -> Result<ReductionDescriptor<T, FlatIndices>, CudnnError> {
        self.create_reduction(
            op,
            nan_opt,
            FlatIndices,
            sys::cudnnReduceTensorIndices_t::CUDNN_REDUCE_TENSOR_FLATTENED_INDICES,
        )
    }
    /// Creates a reduction descriptor that computes reduced values only,
    /// without indices.
    ///
    /// `op` selects the reduction operation and `nan_opt` controls NaN
    /// propagation. Returns an error if descriptor creation or
    /// configuration fails.
    pub fn create_reduction_no_indices<T: CudnnDataType>(
        self: &Arc<Cudnn>,
        op: sys::cudnnReduceTensorOp_t,
        nan_opt: sys::cudnnNanPropagation_t,
    ) -> Result<ReductionDescriptor<T, NoIndices>, CudnnError> {
        self.create_reduction(
            op,
            nan_opt,
            NoIndices,
            sys::cudnnReduceTensorIndices_t::CUDNN_REDUCE_TENSOR_NO_INDICES,
        )
    }
    /// Shared implementation for the two public constructors: allocates the
    /// raw descriptor, wraps it, then configures it.
    fn create_reduction<T: CudnnDataType, Idx>(
        self: &Arc<Cudnn>,
        op: sys::cudnnReduceTensorOp_t,
        nan_opt: sys::cudnnNanPropagation_t,
        indices: Idx,
        indices_mode: sys::cudnnReduceTensorIndices_t,
    ) -> Result<ReductionDescriptor<T, Idx>, CudnnError> {
        let desc = result::create_reduce_tensor_descriptor()?;
        // Wrap the raw descriptor *before* configuring it so that if
        // configuration fails, Drop still destroys it and nothing leaks.
        let desc = ReductionDescriptor {
            desc,
            indices,
            handle: self.clone(),
            marker: PhantomData,
        };
        // SAFETY: `desc.desc` is a live descriptor created just above and
        // not yet shared anywhere else.
        unsafe {
            result::set_reduce_tensor_descriptor(
                desc.desc,
                op,
                T::DATA_TYPE,
                nan_opt,
                indices_mode,
                sys::cudnnIndicesType_t::CUDNN_32BIT_INDICES,
            )
        }?;
        Ok(desc)
    }
}
impl<T, Idx> Drop for ReductionDescriptor<T, Idx> {
fn drop(&mut self) {
let desc = std::mem::replace(&mut self.desc, std::ptr::null_mut());
if !desc.is_null() {
unsafe { result::destroy_reduce_tensor_descriptor(desc) }.unwrap()
}
}
}
/// Borrowed bundle of everything needed to run a tensor reduction: the
/// configured reduction descriptor plus the input (`a`) and output (`c`)
/// tensor descriptors.
pub struct ReduceTensor<'a, T: CudnnDataType, Idx> {
    pub reduce: &'a ReductionDescriptor<T, Idx>,
    // Input tensor descriptor.
    pub a: &'a TensorDescriptor<T>,
    // Output (reduced) tensor descriptor.
    pub c: &'a TensorDescriptor<T>,
}
impl<T: CudnnDataType> ReduceTensor<'_, T, FlatIndices> {
    /// Queries cuDNN for the required size in bytes of the indices buffer
    /// for this reduction. Only available when indices were requested
    /// ([`FlatIndices`]).
    pub fn get_indices_size(&self) -> Result<usize, CudnnError> {
        let handle = self.reduce.handle.handle;
        // SAFETY: all descriptors are owned by values borrowed for the
        // lifetime of `self`, so they are valid for this call.
        unsafe {
            result::get_reduction_indices_size(handle, self.reduce.desc, self.a.desc, self.c.desc)
        }
    }
}
impl<T: CudnnDataType, Idx> ReduceTensor<'_, T, Idx> {
    /// Queries cuDNN for the required size in bytes of the scratch
    /// workspace for this reduction.
    pub fn get_workspace_size(&self) -> Result<usize, CudnnError> {
        let handle = self.reduce.handle.handle;
        // SAFETY: all descriptors are owned by values borrowed for the
        // lifetime of `self`, so they are valid for this call.
        unsafe {
            result::get_reduction_workspace_size(handle, self.reduce.desc, self.a.desc, self.c.desc)
        }
    }
}
impl<T: CudnnDataType> ReduceTensor<'_, T, FlatIndices> {
    /// Launches the reduction, writing reduced values into `c` and the
    /// flattened indices of the reduced elements into `indices`.
    ///
    /// `indices` and `workspace` must be at least as large as reported by
    /// [`Self::get_indices_size`] / [`Self::get_workspace_size`].
    /// NOTE(review): `alpha`/`beta` are the usual cuDNN blend scalars for
    /// `cudnnReduceTensor` — confirm exact semantics against the cuDNN docs.
    ///
    /// # Safety
    /// The device buffers behind `indices`, `workspace`, `a`, and `c` must
    /// be valid allocations of sufficient size for the configured tensor
    /// descriptors, and must remain valid until the launch completes on the
    /// stream.
    pub unsafe fn launch<Indices, Workspace, A, C>(
        &self,
        indices: &mut Indices,
        workspace: &mut Workspace,
        (alpha, beta): (T, T),
        a: &A,
        c: &mut C,
    ) -> Result<(), CudnnError>
    where
        Indices: DevicePtrMut<u32>,
        Workspace: DevicePtrMut<u8>,
        A: DevicePtr<T>,
        C: DevicePtrMut<T>,
    {
        // NOTE(review): the stream is taken from `a`'s handle; presumably it
        // matches `self.reduce.handle` — verify against how callers build these.
        let stream = &self.a.handle.stream;
        // Capture byte sizes before the mutable borrows below take the buffers.
        let workspace_size_in_bytes = workspace.num_bytes();
        let indices_size_in_bytes = indices.num_bytes();
        // The `_record_*` guards must stay alive across the FFI call so the
        // buffers remain registered with the stream for its duration.
        let (indices, _record_i) = indices.device_ptr_mut(stream);
        let (workspace, _record_w) = workspace.device_ptr_mut(stream);
        let (a, _record_a) = a.device_ptr(stream);
        let (c, _record_c) = c.device_ptr_mut(stream);
        result::reduce_tensor(
            self.reduce.handle.handle,
            self.reduce.desc,
            indices as *mut std::ffi::c_void,
            indices_size_in_bytes,
            workspace as *mut std::ffi::c_void,
            workspace_size_in_bytes,
            // alpha/beta are passed by host pointer, as cuDNN expects.
            (&alpha) as *const T as *const std::ffi::c_void,
            self.a.desc,
            a as *const _,
            (&beta) as *const T as *const std::ffi::c_void,
            self.c.desc,
            c as *mut _,
        )
    }
}
impl<T: CudnnDataType> ReduceTensor<'_, T, NoIndices> {
    /// Launches the reduction, writing reduced values into `c`. No indices
    /// are produced, so cuDNN is given a null indices buffer of size 0.
    ///
    /// `workspace` must be at least as large as reported by
    /// [`Self::get_workspace_size`].
    /// NOTE(review): `alpha`/`beta` are the usual cuDNN blend scalars for
    /// `cudnnReduceTensor` — confirm exact semantics against the cuDNN docs.
    ///
    /// # Safety
    /// The device buffers behind `workspace`, `a`, and `c` must be valid
    /// allocations of sufficient size for the configured tensor
    /// descriptors, and must remain valid until the launch completes on the
    /// stream.
    pub unsafe fn launch<Workspace, A, C>(
        &self,
        workspace: &mut Workspace,
        (alpha, beta): (T, T),
        a: &A,
        c: &mut C,
    ) -> Result<(), CudnnError>
    where
        Workspace: DevicePtrMut<u8>,
        A: DevicePtr<T>,
        C: DevicePtrMut<T>,
    {
        // NOTE(review): the stream is taken from `a`'s handle; presumably it
        // matches `self.reduce.handle` — verify against how callers build these.
        let stream = &self.a.handle.stream;
        // Capture the byte size before the mutable borrow below takes the buffer.
        let workspace_size_in_bytes = workspace.num_bytes();
        // The `_record_*` guards must stay alive across the FFI call so the
        // buffers remain registered with the stream for its duration.
        let (workspace, _record_w) = workspace.device_ptr_mut(stream);
        let (a, _record_a) = a.device_ptr(stream);
        let (c, _record_c) = c.device_ptr_mut(stream);
        result::reduce_tensor(
            self.reduce.handle.handle,
            self.reduce.desc,
            // No indices requested: null pointer, zero size.
            std::ptr::null_mut(),
            0,
            workspace as *mut std::ffi::c_void,
            workspace_size_in_bytes,
            // alpha/beta are passed by host pointer, as cuDNN expects.
            (&alpha) as *const T as *const std::ffi::c_void,
            self.a.desc,
            a as *const _,
            (&beta) as *const T as *const std::ffi::c_void,
            self.c.desc,
            c as *mut _,
        )
    }
}