kn_cuda_sys/wrapper/
status.rs

1use std::ffi::CStr;
2use std::ptr::null;
3
4use crate::bindings::{
5    cuGetErrorName, cublasStatus_t, cudaError, cudaGetErrorString, nvrtcGetErrorString, nvrtcResult, CUresult,
6};
7use crate::bindings::{cudnnGetErrorString, cudnnStatus_t};
8
/// Common interface over the status/error codes returned by the CUDA family of libraries.
///
/// An implementor names its success value and can describe any value as a static string,
/// which lets callers check results uniformly via [`Status::unwrap`].
pub trait Status: Copy + PartialEq {
    /// The value that represents a successful call.
    const SUCCESS: Self;

    /// Describes this status as a static string.
    fn as_string(&self) -> &'static str;

    /// Returns `true` when this status compares equal to [`Status::SUCCESS`].
    fn is_success(&self) -> bool {
        let success = Self::SUCCESS;
        *self == success
    }

    /// Panics with the text from [`Status::as_string`] unless this status is a success.
    ///
    /// `#[track_caller]` makes the panic location point at the caller, not at this method.
    #[track_caller]
    fn unwrap(&self) {
        if self.is_success() {
            return;
        }
        panic!("Operation returned error {:?}", self.as_string());
    }

    /// Like [`Status::unwrap`], but does nothing at all while the current thread is already
    /// panicking. Use it from [`Drop`] implementations to avoid a double panic, which would
    /// abort the process.
    #[track_caller]
    fn unwrap_in_drop(&self) {
        let already_panicking = std::thread::panicking();
        if !already_panicking {
            self.unwrap();
        }
    }
}
34
35impl Status for cudaError {
36    const SUCCESS: Self = cudaError::cudaSuccess;
37
38    fn as_string(&self) -> &'static str {
39        unsafe { CStr::from_ptr(cudaGetErrorString(*self)) }.to_str().unwrap()
40    }
41}
42
43impl Status for cudnnStatus_t {
44    const SUCCESS: Self = cudnnStatus_t::CUDNN_STATUS_SUCCESS;
45
46    fn as_string(&self) -> &'static str {
47        unsafe { CStr::from_ptr(cudnnGetErrorString(*self)) }.to_str().unwrap()
48    }
49}
50
51impl Status for cublasStatus_t {
52    const SUCCESS: Self = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
53
54    fn as_string(&self) -> &'static str {
55        match self {
56            cublasStatus_t::CUBLAS_STATUS_SUCCESS => "CUBLAS_STATUS_SUCCESS",
57            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => "CUBLAS_STATUS_NOT_INITIALIZED",
58            cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => "CUBLAS_STATUS_ALLOC_FAILED",
59            cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => "CUBLAS_STATUS_INVALID_VALUE",
60            cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => "CUBLAS_STATUS_ARCH_MISMATCH",
61            cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR => "CUBLAS_STATUS_MAPPING_ERROR",
62            cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => "CUBLAS_STATUS_EXECUTION_FAILED",
63            cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR => "CUBLAS_STATUS_INTERNAL_ERROR",
64            cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED => "CUBLAS_STATUS_NOT_SUPPORTED",
65            cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR => "CUBLAS_STATUS_LICENSE_ERROR",
66        }
67    }
68}
69
70impl Status for nvrtcResult {
71    const SUCCESS: Self = nvrtcResult::NVRTC_SUCCESS;
72
73    fn as_string(&self) -> &'static str {
74        unsafe { CStr::from_ptr(nvrtcGetErrorString(*self)) }.to_str().unwrap()
75    }
76}
77
78impl Status for CUresult {
79    const SUCCESS: Self = CUresult::CUDA_SUCCESS;
80
81    fn as_string(&self) -> &'static str {
82        unsafe {
83            let mut ptr = null();
84            let result = cuGetErrorName(*self, &mut ptr as *mut _);
85            if result != CUresult::CUDA_SUCCESS {
86                panic!("Error '{:?}' while getting name of error '{:?}'", result, self);
87            }
88            CStr::from_ptr(ptr).to_str().unwrap()
89        }
90    }
91}