#![allow(unused_imports)]
use cust::prelude::DevicePointer;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::cuda::error::{CudaError, CudaResult};
use torsh_core::DType;
use super::descriptors::{
ActivationDescriptor, ConvolutionDescriptor, FilterDescriptor, PoolingDescriptor,
TensorDescriptor,
};
use super::handle::CudnnHandle;
use super::rnn::{RNNDataDescriptor, RNNDescriptor, RNNForwardMode};
use super::types::{
ActivationMode, ConvolutionForwardAlgorithm, ConvolutionForwardAlgorithmPerformance,
ConvolutionMode, NanPropagation, PoolingMode,
};
#[cfg(feature = "cudnn")]
use cudnn_sys::*;
#[cfg(feature = "cudnn")]
use super::compat::{
cudnnBatchNormMode_t, cudnnBatchNormalizationForwardInference,
cudnnBatchNormalizationForwardTraining, cudnnForwardMode_t, cudnnNormAlgo_t, cudnnNormMode_t,
cudnnNormOps_t, cudnnNormalizationForwardInference, cudnnRNNForward,
cudnnSetConvolutionGroupCount,
};
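/// Translates a compat-layer convolution algorithm value into the raw
/// `cudnn_sys` enum expected by the direct FFI calls below.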
#[cfg(feature = "cudnn")]
fn to_sys_conv_fwd_algo(
algo: super::compat::cudnnConvolutionFwdAlgo_t,
) -> cudnnConvolutionFwdAlgo_t {
match algo {
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
}
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
}
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_GEMM => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_GEMM
}
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_DIRECT
}
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT
}
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
}
// Variants with no direct counterpart in the raw bindings (Winograd and the
// count sentinel) fall back to plain GEMM.
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD |
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED |
super::compat::cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_COUNT => {
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_GEMM
}
}
}
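/// Thin wrapper around a shared cuDNN handle exposing forward-pass primitives
/// (convolution, batch norm, activation, pooling, layer norm, LSTM).
///
/// The handle lives behind an `Arc<Mutex<_>>` so one `CudnnOps` can be shared
/// across threads with calls serialized. `cache` is a generic slot intended
/// for memoising descriptors or algorithm choices; the current methods do not
/// use it yet. When the crate is built without the `cudnn` feature, every
/// method returns `CudaError::CudnnError`.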
pub struct CudnnOps {
handle: Arc<Mutex<CudnnHandle>>,
cache: Mutex<HashMap<String, Box<dyn std::any::Any + Send + Sync>>>,
}
impl CudnnOps {
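/// Creates a new cuDNN context; fails if the cuDNN library cannot be initialised.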
pub fn new() -> CudaResult<Self> {
let handle = CudnnHandle::new()?;
Ok(Self {
handle: Arc::new(Mutex::new(handle)),
cache: Mutex::new(HashMap::new()),
})
}
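/// 2D convolution forward pass over NCHW input with a KCRS filter, using the
/// implicit-GEMM algorithm and no extra workspace. When `bias` is provided it
/// is broadcast-added over the channel dimension afterwards.
///
/// A usage sketch (not compiled; assumes the `DevicePointer` buffers were
/// allocated elsewhere with `cust` and that the shapes are consistent):
///
/// ```ignore
/// let ops = CudnnOps::new()?;
/// ops.conv2d_forward(
///     input, weight, Some(bias), output,
///     (1, 3, 224, 224),  // input  (N, C, H, W)
///     (64, 3, 7, 7),     // weight (K, C, R, S)
///     (1, 64, 112, 112), // output (N, K, H_out, W_out)
///     (3, 3),            // padding
///     (2, 2),            // stride
///     (1, 1),            // dilation
/// )?;
/// ```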
pub fn conv2d_forward(
&self,
input: DevicePointer<f32>,
weight: DevicePointer<f32>,
bias: Option<DevicePointer<f32>>,
output: DevicePointer<f32>,
input_shape: (i32, i32, i32, i32),
weight_shape: (i32, i32, i32, i32),
output_shape: (i32, i32, i32, i32),
padding: (i32, i32),
stride: (i32, i32),
dilation: (i32, i32),
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let mut input_desc = TensorDescriptor::new()?;
input_desc.set_4d(
DType::F32,
input_shape.0,
input_shape.1,
input_shape.2,
input_shape.3,
)?;
let mut weight_desc = FilterDescriptor::new()?;
weight_desc.set_4d(
DType::F32,
weight_shape.0,
weight_shape.1,
weight_shape.2,
weight_shape.3,
)?;
let mut output_desc = TensorDescriptor::new()?;
output_desc.set_4d(
DType::F32,
output_shape.0,
output_shape.1,
output_shape.2,
output_shape.3,
)?;
let mut conv_desc = ConvolutionDescriptor::new()?;
conv_desc.set_2d(
padding.0,
padding.1,
stride.0,
stride.1,
dilation.0,
dilation.1,
ConvolutionMode::CrossCorrelation,
)?;
let alpha = 1.0f32;
let beta = 0.0f32;
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = unsafe {
cudnnConvolutionForward(
handle.raw(),
&alpha as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
weight_desc.raw(),
weight.as_raw() as *const std::ffi::c_void,
conv_desc.raw(),
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
std::ptr::null_mut(),
0,
&beta as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN convolution forward failed: {:?}",
status
)));
}
if let Some(bias_ptr) = bias {
let mut bias_desc = TensorDescriptor::new()?;
bias_desc.set_4d(DType::F32, 1, output_shape.1, 1, 1)?;
let status = unsafe {
cudnnAddTensor(
handle.raw(),
cudnnAddMode_t::CUDNN_ADD_SAME_C,
&alpha as *const f32 as *const std::ffi::c_void,
bias_desc.raw(),
bias_ptr.as_raw() as *const std::ffi::c_void,
&alpha as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN bias addition failed: {:?}",
status
)));
}
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
input,
weight,
bias,
output,
input_shape,
weight_shape,
output_shape,
padding,
stride,
dilation,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
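/// Spatial batch normalization over an NCHW tensor. In training mode the
/// running statistics are updated in place using `exponential_average_factor`;
/// in inference mode they are treated as read-only estimates.
///
/// Sketch (not compiled; `scale`, `bias`, `running_mean` and `running_var`
/// are assumed to be device buffers of length C):
///
/// ```ignore
/// ops.batchnorm_forward(
///     input, output, scale, bias,
///     running_mean, running_var,
///     1e-5,            // epsilon
///     0.1,             // exponential average factor
///     (8, 64, 32, 32), // (N, C, H, W)
///     true,            // training
/// )?;
/// ```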
pub fn batchnorm_forward(
&self,
input: DevicePointer<f32>,
output: DevicePointer<f32>,
scale: DevicePointer<f32>,
bias: DevicePointer<f32>,
running_mean: DevicePointer<f32>,
running_var: DevicePointer<f32>,
epsilon: f64,
exponential_average_factor: f64,
shape: (i32, i32, i32, i32),
training: bool,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let mut input_desc = TensorDescriptor::new()?;
input_desc.set_4d(DType::F32, shape.0, shape.1, shape.2, shape.3)?;
let mut output_desc = TensorDescriptor::new()?;
output_desc.set_4d(DType::F32, shape.0, shape.1, shape.2, shape.3)?;
let mut scale_bias_desc = TensorDescriptor::new()?;
scale_bias_desc.set_4d(DType::F32, 1, shape.1, 1, 1)?;
let alpha = 1.0f32;
let beta = 0.0f32;
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = if training {
unsafe {
cudnnBatchNormalizationForwardTraining(
handle.raw(),
cudnnBatchNormMode_t::CUDNN_BATCHNORM_SPATIAL,
&alpha as *const f32 as *const std::ffi::c_void,
&beta as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
scale_bias_desc.raw(),
scale.as_raw() as *const std::ffi::c_void,
bias.as_raw() as *const std::ffi::c_void,
exponential_average_factor,
running_mean.as_raw() as *mut std::ffi::c_void,
running_var.as_raw() as *mut std::ffi::c_void,
epsilon,
std::ptr::null_mut(),
std::ptr::null_mut(),
)
}
} else {
unsafe {
cudnnBatchNormalizationForwardInference(
handle.raw(),
cudnnBatchNormMode_t::CUDNN_BATCHNORM_SPATIAL,
&alpha as *const f32 as *const std::ffi::c_void,
&beta as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
scale_bias_desc.raw(),
scale.as_raw() as *const std::ffi::c_void,
bias.as_raw() as *const std::ffi::c_void,
running_mean.as_raw() as *const std::ffi::c_void,
running_var.as_raw() as *const std::ffi::c_void,
epsilon,
)
}
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN batch normalization forward failed: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
input,
output,
scale,
bias,
running_mean,
running_var,
epsilon,
exponential_average_factor,
shape,
training,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
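/// Elementwise activation forward pass. Sigmoid, ReLU and Tanh are mapped
/// directly to their cuDNN modes; any other `ActivationMode` currently falls
/// back to ReLU.
///
/// ```ignore
/// ops.activation_forward(ActivationMode::Relu, input, output, (8, 64, 32, 32))?;
/// ```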
pub fn activation_forward(
&self,
mode: ActivationMode,
input: DevicePointer<f32>,
output: DevicePointer<f32>,
shape: (i32, i32, i32, i32),
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let mut input_desc = TensorDescriptor::new()?;
input_desc.set_4d(DType::F32, shape.0, shape.1, shape.2, shape.3)?;
let mut output_desc = TensorDescriptor::new()?;
output_desc.set_4d(DType::F32, shape.0, shape.1, shape.2, shape.3)?;
let cudnn_mode = match mode {
ActivationMode::Sigmoid => cudnnActivationMode_t::CUDNN_ACTIVATION_SIGMOID,
ActivationMode::Relu => cudnnActivationMode_t::CUDNN_ACTIVATION_RELU,
ActivationMode::Tanh => cudnnActivationMode_t::CUDNN_ACTIVATION_TANH,
// Remaining activation modes are not mapped yet; fall back to ReLU.
_ => cudnnActivationMode_t::CUDNN_ACTIVATION_RELU,
};
let alpha = 1.0f32;
let beta = 0.0f32;
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = unsafe {
cudnnActivationForward(
handle.raw(),
cudnn_mode,
&alpha as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
&beta as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN activation forward failed: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (mode, input, output, shape);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
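/// 2D pooling forward pass with NaNs not propagated. `window_size`, `padding`
/// and `stride` follow the cuDNN (height, width) convention.
///
/// Sketch (not compiled; the `PoolingMode` variant name is illustrative):
///
/// ```ignore
/// ops.pooling2d_forward(
///     PoolingMode::Max,
///     input, output,
///     (1, 64, 112, 112), // input  (N, C, H, W)
///     (1, 64, 56, 56),   // output (N, C, H_out, W_out)
///     (3, 3),            // window
///     (1, 1),            // padding
///     (2, 2),            // stride
/// )?;
/// ```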
pub fn pooling2d_forward(
&self,
mode: PoolingMode,
input: DevicePointer<f32>,
output: DevicePointer<f32>,
input_shape: (i32, i32, i32, i32),
output_shape: (i32, i32, i32, i32),
window_size: (i32, i32),
padding: (i32, i32),
stride: (i32, i32),
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let mut input_desc = TensorDescriptor::new()?;
input_desc.set_4d(
DType::F32,
input_shape.0,
input_shape.1,
input_shape.2,
input_shape.3,
)?;
let mut output_desc = TensorDescriptor::new()?;
output_desc.set_4d(
DType::F32,
output_shape.0,
output_shape.1,
output_shape.2,
output_shape.3,
)?;
let mut pool_desc = PoolingDescriptor::new()?;
pool_desc.set_2d(
mode,
NanPropagation::NotPropagate,
window_size.0,
window_size.1,
padding.0,
padding.1,
stride.0,
stride.1,
)?;
let alpha = 1.0f32;
let beta = 0.0f32;
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = unsafe {
cudnnPoolingForward(
handle.raw(),
pool_desc.raw(),
&alpha as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
&beta as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN pooling forward failed: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
mode,
input,
output,
input_shape,
output_shape,
window_size,
padding,
stride,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
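/// Layer normalization forward pass (inference path), routed through the
/// compat shim for `cudnnNormalizationForwardInference`. Unlike the other
/// entry points, the tensor and scale/bias descriptors are supplied by the
/// caller rather than rebuilt from shape tuples.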
pub fn layer_norm_forward(
&self,
x: DevicePointer<f32>,
scale: DevicePointer<f32>,
bias: DevicePointer<f32>,
epsilon: f64,
x_desc: &TensorDescriptor,
scale_bias_desc: &TensorDescriptor,
y: DevicePointer<f32>,
mean: DevicePointer<f32>,
inv_variance: DevicePointer<f32>,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let handle = self.handle.lock().expect("lock should not be poisoned");
let alpha = 1.0f32;
let beta = 0.0f32;
let status = unsafe {
cudnnNormalizationForwardInference(
handle.raw(),
cudnnNormMode_t::CUDNN_LAYER_NORM,
cudnnNormOps_t::CUDNN_NORM_OPS_NORM_ACTIVATION,
cudnnNormAlgo_t::CUDNN_NORM_ALGO_STANDARD,
&alpha as *const f32 as *const std::ffi::c_void,
&beta as *const f32 as *const std::ffi::c_void,
x_desc.raw(),
x.as_raw() as *const std::ffi::c_void,
scale_bias_desc.raw(),
scale.as_raw() as *const std::ffi::c_void,
bias.as_raw() as *const std::ffi::c_void,
epsilon,
x_desc.raw(),
y.as_raw() as *mut std::ffi::c_void,
scale_bias_desc.raw(),
mean.as_raw() as *mut std::ffi::c_void,
inv_variance.as_raw() as *mut std::ffi::c_void,
std::ptr::null_mut(),
std::ptr::null_mut(),
0,
std::ptr::null_mut(),
0,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN layer normalization forward failed: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
x,
scale,
bias,
epsilon,
x_desc,
scale_bias_desc,
y,
mean,
inv_variance,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
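/// Grouped 2D convolution forward pass: sets `groups` on the supplied
/// convolution descriptor, runs an implicit-GEMM convolution, then optionally
/// adds a channel bias when both `bias` and `bias_desc` are provided.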
pub fn grouped_conv2d_forward(
&self,
input: DevicePointer<f32>,
filter: DevicePointer<f32>,
bias: Option<DevicePointer<f32>>,
output: DevicePointer<f32>,
input_desc: &TensorDescriptor,
filter_desc: &FilterDescriptor,
bias_desc: Option<&TensorDescriptor>,
output_desc: &TensorDescriptor,
conv_desc: &ConvolutionDescriptor,
groups: i32,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let status = unsafe { cudnnSetConvolutionGroupCount(conv_desc.raw(), groups) };
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to set convolution group count: {:?}",
status
)));
}
let alpha = 1.0f32;
let beta = 0.0f32;
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = unsafe {
cudnnConvolutionForward(
handle.raw(),
&alpha as *const f32 as *const std::ffi::c_void,
input_desc.raw(),
input.as_raw() as *const std::ffi::c_void,
filter_desc.raw(),
filter.as_raw() as *const std::ffi::c_void,
conv_desc.raw(),
cudnnConvolutionFwdAlgo_t::CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
std::ptr::null_mut(),
0,
&beta as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN grouped convolution forward failed: {:?}",
status
)));
}
if let (Some(bias_ptr), Some(bias_desc)) = (bias, bias_desc) {
let status = unsafe {
cudnnAddTensor(
handle.raw(),
cudnnAddMode_t::CUDNN_ADD_SAME_C,
&alpha as *const f32 as *const std::ffi::c_void,
bias_desc.raw(),
bias_ptr.as_raw() as *const std::ffi::c_void,
&alpha as *const f32 as *const std::ffi::c_void,
output_desc.raw(),
output.as_raw() as *mut std::ffi::c_void,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN grouped convolution bias addition failed: {:?}",
status
)));
}
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
input,
filter,
bias,
output,
input_desc,
filter_desc,
bias_desc,
output_desc,
conv_desc,
groups,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
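/// Queries how many bytes of scratch workspace cuDNN needs to run the given
/// forward-convolution `algorithm` on the described tensors.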
pub fn get_convolution_forward_workspace_size(
&self,
input_desc: &TensorDescriptor,
filter_desc: &FilterDescriptor,
conv_desc: &ConvolutionDescriptor,
output_desc: &TensorDescriptor,
algorithm: ConvolutionForwardAlgorithm,
) -> CudaResult<usize> {
#[cfg(feature = "cudnn")]
{
let handle = self.handle.lock().expect("lock should not be poisoned");
let mut size: usize = 0;
let status = unsafe {
cudnnGetConvolutionForwardWorkspaceSize(
handle.raw(),
input_desc.raw(),
filter_desc.raw(),
conv_desc.raw(),
output_desc.raw(),
to_sys_conv_fwd_algo(algorithm.to_cudnn()),
&mut size,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to get convolution workspace size: {:?}",
status
)));
}
Ok(size)
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (input_desc, filter_desc, conv_desc, output_desc, algorithm);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
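/// Benchmarks the available forward-convolution algorithms for the described
/// problem and returns up to `request_algo_count` performance records. cuDNN
/// orders the results fastest-first, so the first entry is the recommended
/// choice.
///
/// ```ignore
/// let perf = ops.find_convolution_forward_algorithm(&x_desc, &w_desc, &conv_desc, &y_desc, 8)?;
/// let best = &perf[0]; // fastest measured algorithm
/// ```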
pub fn find_convolution_forward_algorithm(
&self,
input_desc: &TensorDescriptor,
filter_desc: &FilterDescriptor,
conv_desc: &ConvolutionDescriptor,
output_desc: &TensorDescriptor,
request_algo_count: i32,
) -> CudaResult<Vec<ConvolutionForwardAlgorithmPerformance>> {
#[cfg(feature = "cudnn")]
{
let handle = self.handle.lock().expect("lock should not be poisoned");
let mut returned_algo_count: i32 = 0;
let mut perf_results = vec![
ConvolutionForwardAlgorithmPerformance::default();
request_algo_count as usize
];
let status = unsafe {
cudnnFindConvolutionForwardAlgorithm(
handle.raw(),
input_desc.raw(),
filter_desc.raw(),
conv_desc.raw(),
output_desc.raw(),
request_algo_count,
&mut returned_algo_count,
perf_results.as_mut_ptr() as *mut cudnnConvolutionFwdAlgoPerf_t,
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"Failed to find convolution algorithm: {:?}",
status
)));
}
perf_results.truncate(returned_algo_count as usize);
Ok(perf_results)
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
input_desc,
filter_desc,
conv_desc,
output_desc,
request_algo_count,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
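/// LSTM/RNN forward pass through the cuDNN v8 `cudnnRNNForward` entry point
/// (via the compat shim). Optional hidden/cell-state and reserve-space
/// pointers are passed as null when absent; a reserve space is only needed
/// for training-mode forward passes.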
pub fn lstm_forward(
&self,
rnn_desc: &RNNDescriptor,
forward_mode: RNNForwardMode,
dev_seq_lengths: Option<DevicePointer<i32>>,
x_desc: &RNNDataDescriptor,
x: DevicePointer<f32>,
y_desc: &RNNDataDescriptor,
y: DevicePointer<f32>,
h_desc: &TensorDescriptor,
hx: Option<DevicePointer<f32>>,
hy: Option<DevicePointer<f32>>,
c_desc: &TensorDescriptor,
cx: Option<DevicePointer<f32>>,
cy: Option<DevicePointer<f32>>,
weight_space_size: usize,
weight_space: DevicePointer<u8>,
work_space_size: usize,
work_space: DevicePointer<u8>,
reserve_space_size: usize,
reserve_space: Option<DevicePointer<u8>>,
) -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
let handle = self.handle.lock().expect("lock should not be poisoned");
let status = unsafe {
cudnnRNNForward(
handle.raw(),
rnn_desc.raw(),
forward_mode.to_cudnn(),
dev_seq_lengths
.map(|p| p.as_raw() as *const i32)
.unwrap_or(std::ptr::null()),
x_desc.raw(),
x.as_raw() as *const std::ffi::c_void,
y_desc.raw(),
y.as_raw() as *mut std::ffi::c_void,
h_desc.raw(),
hx.map(|p| p.as_raw() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null()),
hy.map(|p| p.as_raw() as *mut std::ffi::c_void)
.unwrap_or(std::ptr::null_mut()),
c_desc.raw(),
cx.map(|p| p.as_raw() as *const std::ffi::c_void)
.unwrap_or(std::ptr::null()),
cy.map(|p| p.as_raw() as *mut std::ffi::c_void)
.unwrap_or(std::ptr::null_mut()),
weight_space_size,
weight_space.as_raw() as *const std::ffi::c_void,
work_space_size,
work_space.as_raw() as *mut std::ffi::c_void,
reserve_space_size,
reserve_space
.map(|p| p.as_raw() as *mut std::ffi::c_void)
.unwrap_or(std::ptr::null_mut()),
)
};
if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
return Err(CudaError::CudnnError(format!(
"cuDNN LSTM forward failed: {:?}",
status
)));
}
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
let _ = (
rnn_desc,
forward_mode,
dev_seq_lengths,
x_desc,
x,
y_desc,
y,
h_desc,
hx,
hy,
c_desc,
cx,
cy,
weight_space_size,
weight_space,
work_space_size,
work_space,
reserve_space_size,
reserve_space,
);
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
}
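// SAFETY: the raw cudnnHandle_t is only reached through the Mutex, and cuDNN
// allows a handle to be used from multiple host threads as long as calls on it
// are serialized, which the lock guarantees.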
unsafe impl Send for CudnnOps {}
unsafe impl Sync for CudnnOps {}
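/// Reports whether cuDNN support was compiled in; actual handle creation is
/// deferred to `CudnnOps::new`.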
pub fn init() -> CudaResult<()> {
#[cfg(feature = "cudnn")]
{
Ok(())
}
#[cfg(not(feature = "cudnn"))]
{
Err(CudaError::CudnnError(
"cuDNN not available - feature not enabled".to_string(),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cudnn_ops_creation() {
#[cfg(feature = "cudnn")]
{
// Handle creation may legitimately fail on machines without a cuDNN runtime,
// so only check that a successful handle can be constructed and dropped cleanly.
if let Ok(ops) = CudnnOps::new() {
drop(ops);
}
}
#[cfg(not(feature = "cudnn"))]
{
let result = CudnnOps::new();
assert!(result.is_err());
}
}
#[test]
fn test_init() {
#[cfg(feature = "cudnn")]
{
assert!(init().is_ok());
}
#[cfg(not(feature = "cudnn"))]
{
let result = init();
assert!(result.is_err());
}
}
#[test]
fn test_send_sync() {
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}
assert_send::<CudnnOps>();
assert_sync::<CudnnOps>();
}
}