use crate::cuda::error::{CudaError, CudaResult};
use crate::cuda::stream::CudaStream;
#[cfg(feature = "cudnn")]
use cudnn_sys::*;
#[cfg(feature = "cudnn")]
use super::compat::{cudnnGetMathType, cudnnMathType_t, cudnnSetMathType};
/// Wrapper around a cuDNN library handle.
///
/// With the `cudnn` feature enabled the struct stores the raw `cudnnHandle_t`
/// obtained from `cudnnCreate` and destroys it in `Drop`. Without the feature
/// the struct only carries a zero-sized placeholder so the type still exists,
/// but `new` never succeeds in that configuration.
pub struct CudnnHandle {
/// Raw cuDNN handle; compared against null by `is_valid` and `Drop`.
#[cfg(feature = "cudnn")]
handle: cudnnHandle_t,
/// Zero-sized placeholder used when cuDNN support is compiled out.
#[cfg(not(feature = "cudnn"))]
_phantom: std::marker::PhantomData<()>,
}
// In `cudnn` builds the struct stores a raw handle pointer, so `Send` and
// `Sync` are asserted by hand here instead of being derived by the compiler.
// NOTE(review): soundness relies on the same handle never being driven by two
// threads at once (including through `raw()`, which hands the pointer out via
// `&self`) — confirm against the cuDNN thread-safety guarantees.
unsafe impl Send for CudnnHandle {}
unsafe impl Sync for CudnnHandle {}
impl CudnnHandle {
    /// Creates a new cuDNN handle by calling `cudnnCreate`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::CudnnError`] when `cudnnCreate` reports a status
    /// other than `CUDNN_STATUS_SUCCESS`, and unconditionally when the crate is
    /// built without the `cudnn` feature.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cudnn")]
        {
            let mut handle: cudnnHandle_t = std::ptr::null_mut();
            // SAFETY: `handle` is a live local, so the out-pointer handed to
            // `cudnnCreate` is valid and writable for the duration of the call.
            let status = unsafe { cudnnCreate(&mut handle) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to create cuDNN handle: {:?}",
                    status
                )));
            }
            Ok(Self { handle })
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Associates this handle with `stream` via `cudnnSetStream`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::CudnnError`] when `cudnnSetStream` reports a
    /// non-success status, and unconditionally when the `cudnn` feature is
    /// disabled.
    pub fn set_stream(&mut self, stream: &CudaStream) -> CudaResult<()> {
        #[cfg(feature = "cudnn")]
        {
            // SAFETY: `self.handle` is the handle stored by `new`; the stream
            // value is whatever `CudaStream::stream` exposes, cast to the type
            // the cuDNN binding expects.
            let status = unsafe { cudnnSetStream(self.handle, stream.stream() as cudaStream_t) };
            if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
                return Err(CudaError::CudnnError(format!(
                    "Failed to set cuDNN stream: {:?}",
                    status
                )));
            }
            Ok(())
        }
        #[cfg(not(feature = "cudnn"))]
        {
            // Touch the parameter so the feature-less build has no unused-variable warning.
            let _ = stream;
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Returns the raw `cudnnHandle_t` stored in this wrapper.
    ///
    /// The wrapper keeps ownership: the handle is still destroyed when the
    /// `CudnnHandle` is dropped.
    #[cfg(feature = "cudnn")]
    pub fn raw(&self) -> cudnnHandle_t {
        self.handle
    }

    /// Reports whether the stored handle is non-null.
    ///
    /// Always `false` when the `cudnn` feature is disabled.
    pub fn is_valid(&self) -> bool {
        #[cfg(feature = "cudnn")]
        {
            !self.handle.is_null()
        }
        #[cfg(not(feature = "cudnn"))]
        {
            false
        }
    }

    /// Returns the cuDNN library version as `(major, minor, patch)`.
    ///
    /// The value reported by `cudnnGetVersion` is a single packed integer.
    /// Releases before cuDNN 9 pack it as `major * 1000 + minor * 100 + patch`
    /// (e.g. `8907` for 8.9.7), while cuDNN 9 and later use
    /// `major * 10000 + minor * 100 + patch` (e.g. `90100` for 9.1.0). Both
    /// layouts are decoded here.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::CudnnError`] when the `cudnn` feature is disabled.
    pub fn get_version() -> CudaResult<(i32, i32, i32)> {
        #[cfg(feature = "cudnn")]
        {
            // SAFETY: `cudnnGetVersion` takes no arguments and only returns a number.
            let version = unsafe { cudnnGetVersion() };
            // The old layout cannot reach 90_000 (that would need major >= 90),
            // and the new layout starts at 9.0.0 == 90_000, so the magnitude
            // alone tells the two encodings apart.
            let major_scale = if version >= 90_000 { 10_000 } else { 1_000 };
            let major = (version / major_scale) as i32;
            let minor = ((version % major_scale) / 100) as i32;
            let patch = (version % 100) as i32;
            Ok((major, minor, patch))
        }
        #[cfg(not(feature = "cudnn"))]
        {
            Err(CudaError::CudnnError(
                "cuDNN not available - feature not enabled".to_string(),
            ))
        }
    }

    /// Converts a cuDNN status code into an owned, human-readable message.
    ///
    /// Falls back to `"Unknown cuDNN error"` when `cudnnGetErrorString`
    /// returns a null pointer. Invalid UTF-8 is replaced lossily.
    #[cfg(feature = "cudnn")]
    pub fn get_error_string(status: cudnnStatus_t) -> String {
        // SAFETY: the pointer is checked for null before use; a non-null result
        // is assumed to reference a NUL-terminated C string, as
        // `CStr::from_ptr` requires.
        unsafe {
            let ptr = cudnnGetErrorString(status);
            if ptr.is_null() {
                "Unknown cuDNN error".to_string()
            } else {
                std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
            }
        }
    }

    /// Sets the math type used by this handle through the `compat` shim's
    /// `cudnnSetMathType`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::CudnnError`] carrying the cuDNN error string when
    /// the call reports a non-success status.
    #[cfg(feature = "cudnn")]
    pub fn set_math_type(&mut self, math_type: cudnnMathType_t) -> CudaResult<()> {
        // SAFETY: `self.handle` is the handle stored by `new`; `math_type` is passed by value.
        let status = unsafe { cudnnSetMathType(self.handle, math_type) };
        if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
            return Err(CudaError::CudnnError(format!(
                "Failed to set math type: {}",
                Self::get_error_string(status)
            )));
        }
        Ok(())
    }

    /// Reads back the math type currently configured on this handle through
    /// the `compat` shim's `cudnnGetMathType`.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::CudnnError`] carrying the cuDNN error string when
    /// the call reports a non-success status.
    #[cfg(feature = "cudnn")]
    pub fn get_math_type(&self) -> CudaResult<cudnnMathType_t> {
        // Pre-initialised so the out-parameter always holds a valid enum value.
        let mut math_type: cudnnMathType_t = cudnnMathType_t::CUDNN_DEFAULT_MATH;
        // SAFETY: `math_type` is a live local, so the out-pointer is valid and
        // writable for the duration of the call.
        let status = unsafe { cudnnGetMathType(self.handle, &mut math_type) };
        if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS {
            return Err(CudaError::CudnnError(format!(
                "Failed to get math type: {}",
                Self::get_error_string(status)
            )));
        }
        Ok(math_type)
    }
}
impl Drop for CudnnHandle {
    /// Releases the wrapped cuDNN handle, if there is one.
    ///
    /// A null handle is left untouched. The status returned by `cudnnDestroy`
    /// is discarded because `drop` has no way to report it. Without the
    /// `cudnn` feature this is a no-op.
    fn drop(&mut self) {
        #[cfg(feature = "cudnn")]
        {
            if self.handle.is_null() {
                return;
            }
            // SAFETY: the handle is non-null and is destroyed only here, once,
            // when the wrapper itself is dropped.
            let _ = unsafe { cudnnDestroy(self.handle) };
        }
    }
}
impl Default for CudnnHandle {
fn default() -> Self {
Self::new().expect("Failed to create default cuDNN handle")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A successfully created handle must report itself as valid; without the
    /// `cudnn` feature creation must fail.
    #[test]
    fn test_handle_creation() {
        #[cfg(feature = "cudnn")]
        {
            // Creation failure is tolerated here; only a created handle is inspected.
            if let Ok(created) = CudnnHandle::new() {
                assert!(created.is_valid());
            }
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let outcome = CudnnHandle::new();
            assert!(outcome.is_err());
        }
    }

    /// `is_valid` agrees with the build configuration for any handle that
    /// could be created.
    #[test]
    fn test_handle_validity() {
        #[cfg(feature = "cudnn")]
        {
            if let Ok(created) = CudnnHandle::new() {
                assert!(created.is_valid());
            }
        }
        #[cfg(not(feature = "cudnn"))]
        {
            if let Ok(created) = CudnnHandle::new() {
                assert!(!created.is_valid());
            }
        }
    }

    /// The decoded version components are in a plausible range; without the
    /// `cudnn` feature the query must fail.
    #[test]
    fn test_version_info() {
        #[cfg(feature = "cudnn")]
        {
            // An error from the version query is tolerated here.
            if let Ok((major, minor, patch)) = CudnnHandle::get_version() {
                assert!(major >= 7);
                assert!(minor >= 0);
                assert!(patch >= 0);
            }
        }
        #[cfg(not(feature = "cudnn"))]
        {
            let outcome = CudnnHandle::get_version();
            assert!(outcome.is_err());
        }
    }

    /// Status codes translate to non-empty messages, and the bad-parameter
    /// message mentions the parameter.
    #[test]
    fn test_error_string() {
        #[cfg(feature = "cudnn")]
        {
            let success_text = CudnnHandle::get_error_string(cudnnStatus_t::CUDNN_STATUS_SUCCESS);
            assert!(!success_text.is_empty());

            let bad_param_text =
                CudnnHandle::get_error_string(cudnnStatus_t::CUDNN_STATUS_BAD_PARAM);
            assert!(!bad_param_text.is_empty());
            let lowered = bad_param_text.to_lowercase();
            assert!(lowered.contains("param") || lowered.contains("parameter"));
        }
    }

    /// A math type that was set successfully is read back unchanged.
    #[test]
    fn test_math_type_operations() {
        #[cfg(feature = "cudnn")]
        {
            let mut handle = match CudnnHandle::new() {
                Ok(created) => created,
                Err(_) => return,
            };
            let math_type = cudnnMathType_t::CUDNN_DEFAULT_MATH;
            if handle.set_math_type(math_type).is_err() {
                return;
            }
            if let Ok(observed) = handle.get_math_type() {
                assert_eq!(observed, math_type);
            }
        }
    }

    /// Compile-time check that the handle type is `Send` and `Sync`.
    #[test]
    fn test_send_sync() {
        fn require_send<T: Send>() {}
        fn require_sync<T: Sync>() {}
        require_send::<CudnnHandle>();
        require_sync::<CudnnHandle>();
    }
}