#![allow(unused_imports)]
use super::types::{ActivationMode, ConvolutionMode, NanPropagation, PoolingMode};
use crate::cuda::error::{CudaError, CudaResult};
use torsh_core::DType;
#[cfg(feature = "cudnn")]
use cudnn_sys::*;
#[cfg(feature = "cudnn")]
use super::compat::{
cudnnActivationDescriptor_t, cudnnCreateActivationDescriptor, cudnnDestroyActivationDescriptor,
cudnnMathType_t, cudnnSetActivationDescriptor, cudnnSetConvolutionMathType,
};
#[cfg(feature = "cudnn")]
/// Translates the crate-local compat pooling-mode enum into the raw
/// `cudnn_sys` pooling mode expected by the FFI calls.
fn to_sys_pooling_mode(mode: super::compat::cudnnPoolingMode_t) -> cudnnPoolingMode_t {
    match mode {
        super::compat::cudnnPoolingMode_t::CUDNN_POOLING_MAX => {
            cudnnPoolingMode_t::CUDNN_POOLING_MAX
        }
        super::compat::cudnnPoolingMode_t::CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING => {
            cudnnPoolingMode_t::CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
        }
        super::compat::cudnnPoolingMode_t::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING => {
            cudnnPoolingMode_t::CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
        }
        // NOTE(review): MAX_DETERMINISTIC is lowered to plain MAX —
        // presumably the linked cudnn_sys enum lacks
        // CUDNN_POOLING_MAX_DETERMINISTIC. Confirm; this silently drops
        // the determinism guarantee the caller asked for.
        super::compat::cudnnPoolingMode_t::CUDNN_POOLING_MAX_DETERMINISTIC => {
            cudnnPoolingMode_t::CUDNN_POOLING_MAX
        }
    }
}
/// RAII wrapper around a cuDNN tensor descriptor (`cudnnTensorDescriptor_t`).
///
/// The raw handle is created by [`TensorDescriptor::new`] and destroyed in
/// `Drop`. Without the `cudnn` feature this is a zero-sized placeholder so
/// downstream code can still name the type.
pub struct TensorDescriptor {
    /// Raw cuDNN handle; owned exclusively by this wrapper.
    #[cfg(feature = "cudnn")]
    desc: cudnnTensorDescriptor_t,
    #[cfg(not(feature = "cudnn"))]
    _phantom: std::marker::PhantomData<()>,
}
impl TensorDescriptor {
    /// Maps a torsh `DType` to the matching cuDNN data type.
    ///
    /// Only F32/F64/F16 have cuDNN equivalents; any other dtype is rejected
    /// with a `CudnnError`. Shared by all `set_*` methods so the supported
    /// dtype set lives in one place.
    #[cfg(feature = "cudnn")]
    fn dtype_to_cudnn(dtype: DType) -> CudaResult<cudnnDataType_t> {
        match dtype {
            DType::F32 => Ok(cudnnDataType_t::CUDNN_DATA_FLOAT),
            DType::F64 => Ok(cudnnDataType_t::CUDNN_DATA_DOUBLE),
            DType::F16 => Ok(cudnnDataType_t::CUDNN_DATA_HALF),
            _ => Err(CudaError::CudnnError(format!(
                "Unsupported dtype for cuDNN: {:?}",
                dtype
            ))),
        }
    }

    /// Allocates a new, unconfigured tensor descriptor.
    ///
    /// With the `cudnn` feature this calls `cudnnCreateTensorDescriptor`;
    /// without it a zero-cost placeholder is returned so the type can still
    /// be constructed.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut desc: cudnnTensorDescriptor_t = std::ptr::null_mut();
            let status = unsafe { cudnnCreateTensorDescriptor(&mut desc) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create tensor descriptor: {:?}",
                    status
                )));
            }
            Ok(Self { desc })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Ok(Self {
                _phantom: std::marker::PhantomData,
            })
        }
    }

    /// Configures the descriptor as a 4-D tensor in NCHW layout.
    ///
    /// # Errors
    /// Fails for dtypes other than F32/F64/F16, when cuDNN rejects the
    /// arguments, and always when the `cudnn` feature is disabled.
    pub fn set_4d(&mut self, dtype: DType, n: i32, c: i32, h: i32, w: i32) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            let cudnn_type = Self::dtype_to_cudnn(dtype)?;
            let status = unsafe {
                cudnnSetTensor4dDescriptor(
                    self.desc,
                    cudnnTensorFormat_t::CUDNN_TENSOR_NCHW,
                    cudnn_type,
                    n,
                    c,
                    h,
                    w,
                )
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set tensor descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (dtype, n, c, h, w);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Configures the descriptor as a 4-D tensor in NHWC layout.
    ///
    /// Note: `cudnnSetTensor4dDescriptor` always takes the logical
    /// dimensions in (n, c, h, w) order — the NHWC format flag only changes
    /// the memory layout — which is why `c` is passed before `h`/`w` here
    /// even though the Rust signature orders the parameters NHWC.
    ///
    /// # Errors
    /// Same failure modes as [`TensorDescriptor::set_4d`].
    pub fn set_4d_nhwc(&mut self, dtype: DType, n: i32, h: i32, w: i32, c: i32) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            let cudnn_type = Self::dtype_to_cudnn(dtype)?;
            let status = unsafe {
                cudnnSetTensor4dDescriptor(
                    self.desc,
                    cudnnTensorFormat_t::CUDNN_TENSOR_NHWC,
                    cudnn_type,
                    n,
                    c,
                    h,
                    w,
                )
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set tensor descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (dtype, n, h, w, c);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Configures the descriptor as an N-dimensional tensor with explicit
    /// per-dimension strides.
    ///
    /// # Errors
    /// Fails when `dims` and `strides` differ in length, for unsupported
    /// dtypes, on cuDNN failure, and always when the `cudnn` feature is
    /// disabled.
    pub fn set_nd(&mut self, dtype: DType, dims: &[i32], strides: &[i32]) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            if dims.len() != strides.len() {
                return Err(CudaError::CudnnError(
                    "Dimensions and strides must have the same length".to_string(),
                ));
            }
            let cudnn_type = Self::dtype_to_cudnn(dtype)?;
            let status = unsafe {
                cudnnSetTensorNdDescriptor(
                    self.desc,
                    cudnn_type,
                    dims.len() as i32,
                    dims.as_ptr(),
                    strides.as_ptr(),
                )
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set tensor descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (dtype, dims, strides);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Returns the raw cuDNN handle for use in FFI calls.
    /// The handle stays owned by this descriptor; callers must not destroy it.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnTensorDescriptor_t {
        self.desc
    }
}
impl Drop for TensorDescriptor {
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            // Nothing was allocated — nothing to release.
            if self.desc.is_null() {
                return;
            }
            // SAFETY: `desc` came from cudnnCreateTensorDescriptor and is
            // owned exclusively by this wrapper. Destruction failures cannot
            // be reported from drop, so the status is deliberately ignored.
            unsafe {
                let _ = cudnnDestroyTensorDescriptor(self.desc);
            }
        }
    }
}
/// RAII wrapper around a cuDNN filter (weight) descriptor
/// (`cudnnFilterDescriptor_t`).
///
/// Created by [`FilterDescriptor::new`], destroyed in `Drop`. Without the
/// `cudnn` feature this is a zero-sized placeholder.
pub struct FilterDescriptor {
    /// Raw cuDNN handle; owned exclusively by this wrapper.
    #[cfg(feature = "cudnn")]
    desc: cudnnFilterDescriptor_t,
    #[cfg(not(feature = "cudnn"))]
    _phantom: std::marker::PhantomData<()>,
}
impl FilterDescriptor {
    /// Maps a torsh `DType` to the matching cuDNN data type.
    ///
    /// Only F32/F64/F16 have cuDNN equivalents; any other dtype is rejected
    /// with a `CudnnError`. Shared by both `set_*` methods to avoid
    /// duplicating the supported-dtype list.
    #[cfg(feature = "cudnn")]
    fn dtype_to_cudnn(dtype: DType) -> CudaResult<cudnnDataType_t> {
        match dtype {
            DType::F32 => Ok(cudnnDataType_t::CUDNN_DATA_FLOAT),
            DType::F64 => Ok(cudnnDataType_t::CUDNN_DATA_DOUBLE),
            DType::F16 => Ok(cudnnDataType_t::CUDNN_DATA_HALF),
            _ => Err(CudaError::CudnnError(format!(
                "Unsupported dtype for cuDNN: {:?}",
                dtype
            ))),
        }
    }

    /// Allocates a new, unconfigured filter descriptor.
    ///
    /// With the `cudnn` feature this calls `cudnnCreateFilterDescriptor`;
    /// without it a zero-cost placeholder is returned.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut desc: cudnnFilterDescriptor_t = std::ptr::null_mut();
            let status = unsafe { cudnnCreateFilterDescriptor(&mut desc) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create filter descriptor: {:?}",
                    status
                )));
            }
            Ok(Self { desc })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Ok(Self {
                _phantom: std::marker::PhantomData,
            })
        }
    }

    /// Configures a 4-D filter: `k` output channels, `c` input channels,
    /// spatial extent `h` x `w`.
    ///
    /// # Errors
    /// Fails for dtypes other than F32/F64/F16, when cuDNN rejects the
    /// arguments, and always when the `cudnn` feature is disabled.
    pub fn set_4d(&mut self, dtype: DType, k: i32, c: i32, h: i32, w: i32) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            let cudnn_type = Self::dtype_to_cudnn(dtype)?;
            let status = unsafe {
                cudnnSetFilter4dDescriptor(self.desc, cudnn_type, k, c, h, w)
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set filter descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (dtype, k, c, h, w);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Configures an N-dimensional filter from an explicit dimension list.
    ///
    /// # Errors
    /// Fails for unsupported dtypes, on cuDNN failure, and always when the
    /// `cudnn` feature is disabled.
    pub fn set_nd(&mut self, dtype: DType, dims: &[i32]) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            let cudnn_type = Self::dtype_to_cudnn(dtype)?;
            let status = unsafe {
                cudnnSetFilterNdDescriptor(self.desc, cudnn_type, dims.len() as i32, dims.as_ptr())
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set filter descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (dtype, dims);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Returns the raw cuDNN handle for use in FFI calls.
    /// The handle stays owned by this descriptor; callers must not destroy it.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnFilterDescriptor_t {
        self.desc
    }
}
impl Drop for FilterDescriptor {
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            // Nothing was allocated — nothing to release.
            if self.desc.is_null() {
                return;
            }
            // SAFETY: `desc` came from cudnnCreateFilterDescriptor and is
            // owned exclusively by this wrapper. Destruction failures cannot
            // be reported from drop, so the status is deliberately ignored.
            unsafe {
                let _ = cudnnDestroyFilterDescriptor(self.desc);
            }
        }
    }
}
/// RAII wrapper around a cuDNN convolution descriptor
/// (`cudnnConvolutionDescriptor_t`).
///
/// Created by [`ConvolutionDescriptor::new`], destroyed in `Drop`. Without
/// the `cudnn` feature this is a zero-sized placeholder.
pub struct ConvolutionDescriptor {
    /// Raw cuDNN handle; owned exclusively by this wrapper.
    #[cfg(feature = "cudnn")]
    desc: cudnnConvolutionDescriptor_t,
    #[cfg(not(feature = "cudnn"))]
    _phantom: std::marker::PhantomData<()>,
}
impl ConvolutionDescriptor {
    /// Allocates a new, unconfigured convolution descriptor via
    /// `cudnnCreateConvolutionDescriptor`, or a placeholder when the
    /// `cudnn` feature is disabled.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut desc: cudnnConvolutionDescriptor_t = std::ptr::null_mut();
            let status = unsafe { cudnnCreateConvolutionDescriptor(&mut desc) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create convolution descriptor: {:?}",
                    status
                )));
            }
            Ok(Self { desc })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Ok(Self {
                _phantom: std::marker::PhantomData,
            })
        }
    }
    /// Configures a 2-D convolution: per-axis padding, stride and dilation,
    /// plus convolution vs. cross-correlation mode.
    ///
    /// # Errors
    /// Fails when cuDNN rejects the arguments, and always when the `cudnn`
    /// feature is disabled.
    pub fn set_2d(
        &mut self,
        pad_h: i32,
        pad_w: i32,
        stride_h: i32,
        stride_w: i32,
        dilation_h: i32,
        dilation_w: i32,
        mode: ConvolutionMode,
    ) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            // Translate the crate-local mode through the compat enum into
            // the raw cudnn_sys enum the FFI call expects.
            let cudnn_mode = mode.to_cudnn();
            let mode_value = match cudnn_mode {
                crate::cuda::cudnn::compat::cudnnConvolutionMode_t::CUDNN_CONVOLUTION => {
                    cudnn_sys::cudnnConvolutionMode_t::CUDNN_CONVOLUTION
                }
                crate::cuda::cudnn::compat::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION => {
                    cudnn_sys::cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION
                }
            };
            // NOTE(review): cuDNN v6+ `cudnnSetConvolution2dDescriptor` also
            // takes a trailing computeType parameter; this 8-argument call
            // presumably matches the linked binding's signature — confirm.
            let status = unsafe {
                cudnnSetConvolution2dDescriptor(
                    self.desc, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, mode_value,
                )
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set convolution descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (
                pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, mode,
            );
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }
    /// Selects the math mode (e.g. whether Tensor Cores may be used) for
    /// this convolution via `cudnnSetConvolutionMathType`.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    #[cfg(feature = "cudnn")]
    pub fn set_math_type(&mut self, math_type: cudnnMathType_t) -> CudaResult<()> {
        let status = unsafe { cudnnSetConvolutionMathType(self.desc, math_type) };
        if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
            return Err(CudaError::CudnnError(format!(
                "Failed to set convolution math type: {:?}",
                status
            )));
        }
        Ok(())
    }
    /// Returns the raw cuDNN handle; remains owned by this descriptor.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnConvolutionDescriptor_t {
        self.desc
    }
}
impl Drop for ConvolutionDescriptor {
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            // Nothing was allocated — nothing to release.
            if self.desc.is_null() {
                return;
            }
            // SAFETY: `desc` came from cudnnCreateConvolutionDescriptor and
            // is owned exclusively by this wrapper. Destruction failures
            // cannot be reported from drop, so the status is ignored.
            unsafe {
                let _ = cudnnDestroyConvolutionDescriptor(self.desc);
            }
        }
    }
}
/// RAII wrapper around a cuDNN activation descriptor
/// (`cudnnActivationDescriptor_t`), using the crate's compat-layer bindings.
///
/// Created by [`ActivationDescriptor::new`], destroyed in `Drop`. Without
/// the `cudnn` feature this is a zero-sized placeholder.
pub struct ActivationDescriptor {
    /// Raw cuDNN handle; owned exclusively by this wrapper.
    #[cfg(feature = "cudnn")]
    desc: cudnnActivationDescriptor_t,
    #[cfg(not(feature = "cudnn"))]
    _phantom: std::marker::PhantomData<()>,
}
impl ActivationDescriptor {
    /// Allocates a new, unconfigured activation descriptor via
    /// `cudnnCreateActivationDescriptor`, or a placeholder when the
    /// `cudnn` feature is disabled.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut desc: cudnnActivationDescriptor_t = std::ptr::null_mut();
            let status = unsafe { cudnnCreateActivationDescriptor(&mut desc) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create activation descriptor: {:?}",
                    status
                )));
            }
            Ok(Self { desc })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Ok(Self {
                _phantom: std::marker::PhantomData,
            })
        }
    }
    /// Configures the activation function, NaN-propagation policy and
    /// coefficient (`coef` semantics depend on `mode` — e.g. clipping
    /// threshold or ELU alpha; see cuDNN docs).
    ///
    /// # Errors
    /// Fails when cuDNN rejects the arguments, and always when the `cudnn`
    /// feature is disabled.
    pub fn set(
        &mut self,
        mode: ActivationMode,
        nan_opt: NanPropagation,
        coef: f64,
    ) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            // Convert crate-local enums to the compat-layer cuDNN enums.
            let cudnn_mode = mode.to_cudnn();
            let cudnn_nan = nan_opt.to_cudnn();
            let status =
                unsafe { cudnnSetActivationDescriptor(self.desc, cudnn_mode, cudnn_nan, coef) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set activation descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (mode, nan_opt, coef);
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }
    /// Returns the raw cuDNN handle; remains owned by this descriptor.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnActivationDescriptor_t {
        self.desc
    }
}
impl Drop for ActivationDescriptor {
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            // Nothing was allocated — nothing to release.
            if self.desc.is_null() {
                return;
            }
            // SAFETY: `desc` came from cudnnCreateActivationDescriptor and
            // is owned exclusively by this wrapper. Destruction failures
            // cannot be reported from drop, so the status is ignored.
            unsafe {
                let _ = cudnnDestroyActivationDescriptor(self.desc);
            }
        }
    }
}
/// RAII wrapper around a cuDNN pooling descriptor
/// (`cudnnPoolingDescriptor_t`).
///
/// Created by [`PoolingDescriptor::new`], destroyed in `Drop`. Without the
/// `cudnn` feature this is a zero-sized placeholder.
pub struct PoolingDescriptor {
    /// Raw cuDNN handle; owned exclusively by this wrapper.
    #[cfg(feature = "cudnn")]
    desc: cudnnPoolingDescriptor_t,
    #[cfg(not(feature = "cudnn"))]
    _phantom: std::marker::PhantomData<()>,
}
impl PoolingDescriptor {
    /// Allocates a new, unconfigured pooling descriptor via
    /// `cudnnCreatePoolingDescriptor`, or a placeholder when the `cudnn`
    /// feature is disabled.
    ///
    /// # Errors
    /// Returns `CudaError::CudnnError` if the cuDNN call fails.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut desc: cudnnPoolingDescriptor_t = std::ptr::null_mut();
            let status = unsafe { cudnnCreatePoolingDescriptor(&mut desc) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create pooling descriptor: {:?}",
                    status
                )));
            }
            Ok(Self { desc })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Ok(Self {
                _phantom: std::marker::PhantomData,
            })
        }
    }
    /// Configures 2-D pooling: mode, window extent, padding and stride per
    /// spatial axis.
    ///
    /// # Errors
    /// Fails when cuDNN rejects the arguments, and always when the `cudnn`
    /// feature is disabled.
    pub fn set_2d(
        &mut self,
        mode: PoolingMode,
        nan_opt: NanPropagation,
        window_h: i32,
        window_w: i32,
        pad_h: i32,
        pad_w: i32,
        stride_h: i32,
        stride_w: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            // Crate enum -> compat enum -> raw cudnn_sys enum.
            let compat_mode = mode.to_cudnn();
            let cudnn_mode = to_sys_pooling_mode(compat_mode);
            // NOTE(review): nan_opt is discarded — cuDNN's canonical
            // cudnnSetPooling2dDescriptor takes a maxpoolingNanOpt argument,
            // but this 8-argument call does not pass one. Presumably the
            // linked binding omits it; confirm and plumb nan_opt through if
            // the binding supports it.
            let _ = nan_opt;
            let status = unsafe {
                cudnnSetPooling2dDescriptor(
                    self.desc, cudnn_mode, window_h, window_w, pad_h, pad_w, stride_h, stride_w,
                )
            };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set pooling descriptor: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = (
                mode, nan_opt, window_h, window_w, pad_h, pad_w, stride_h, stride_w,
            );
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }
    /// Returns the raw cuDNN handle; remains owned by this descriptor.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnPoolingDescriptor_t {
        self.desc
    }
}
impl Drop for PoolingDescriptor {
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            // Nothing was allocated — nothing to release.
            if self.desc.is_null() {
                return;
            }
            // SAFETY: `desc` came from cudnnCreatePoolingDescriptor and is
            // owned exclusively by this wrapper. Destruction failures cannot
            // be reported from drop, so the status is ignored.
            unsafe {
                let _ = cudnnDestroyPoolingDescriptor(self.desc);
            }
        }
    }
}
impl Default for TensorDescriptor {
    /// Delegates to [`TensorDescriptor::new`].
    ///
    /// # Panics
    /// Panics if descriptor creation fails; prefer `new()` when failure
    /// must be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create default tensor descriptor")
    }
}
impl Default for FilterDescriptor {
    /// Delegates to [`FilterDescriptor::new`].
    ///
    /// # Panics
    /// Panics if descriptor creation fails; prefer `new()` when failure
    /// must be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create default filter descriptor")
    }
}
impl Default for ConvolutionDescriptor {
    /// Delegates to [`ConvolutionDescriptor::new`].
    ///
    /// # Panics
    /// Panics if descriptor creation fails; prefer `new()` when failure
    /// must be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create default convolution descriptor")
    }
}
impl Default for ActivationDescriptor {
    /// Delegates to [`ActivationDescriptor::new`].
    ///
    /// # Panics
    /// Panics if descriptor creation fails; prefer `new()` when failure
    /// must be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create default activation descriptor")
    }
}
impl Default for PoolingDescriptor {
    /// Delegates to [`PoolingDescriptor::new`].
    ///
    /// # Panics
    /// Panics if descriptor creation fails; prefer `new()` when failure
    /// must be handled.
    fn default() -> Self {
        Self::new().expect("Failed to create default pooling descriptor")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Without the `cudnn` feature the constructors are pure-Rust stubs and
    // must always succeed, while the setters must always report the missing
    // feature. With the feature enabled the outcome depends on the runtime
    // environment (driver/GPU availability), so those builds only check
    // that nothing panics.

    #[test]
    fn test_tensor_descriptor_creation() {
        let result = TensorDescriptor::new();
        #[cfg(not(feature = "cudnn"))]
        assert!(result.is_ok());
        let _ = result;
    }

    #[test]
    fn test_filter_descriptor_creation() {
        let result = FilterDescriptor::new();
        #[cfg(not(feature = "cudnn"))]
        assert!(result.is_ok());
        let _ = result;
    }

    #[test]
    fn test_convolution_descriptor_creation() {
        let result = ConvolutionDescriptor::new();
        #[cfg(not(feature = "cudnn"))]
        assert!(result.is_ok());
        let _ = result;
    }

    #[test]
    fn test_activation_descriptor_creation() {
        let result = ActivationDescriptor::new();
        #[cfg(not(feature = "cudnn"))]
        assert!(result.is_ok());
        let _ = result;
    }

    #[test]
    fn test_pooling_descriptor_creation() {
        let result = PoolingDescriptor::new();
        #[cfg(not(feature = "cudnn"))]
        assert!(result.is_ok());
        let _ = result;
    }

    #[test]
    fn test_tensor_descriptor_4d() {
        if let Ok(mut desc) = TensorDescriptor::new() {
            let result = desc.set_4d(DType::F32, 1, 3, 224, 224);
            // Without cuDNN the setter is a stub that must report the
            // missing feature as an error.
            #[cfg(not(feature = "cudnn"))]
            assert!(result.is_err());
            let _ = result;
        }
    }

    #[test]
    fn test_default_implementations() {
        // Default delegates to new(); without cuDNN creation cannot fail,
        // so exercising it here guards against accidental panics.
        #[cfg(not(feature = "cudnn"))]
        {
            let _ = TensorDescriptor::default();
            let _ = FilterDescriptor::default();
            let _ = ConvolutionDescriptor::default();
            let _ = ActivationDescriptor::default();
            let _ = PoolingDescriptor::default();
        }
    }
}