kn_cuda_sys/wrapper/
status.rs1use std::ffi::CStr;
2use std::ptr::null;
3
4use crate::bindings::{
5 cuGetErrorName, cublasStatus_t, cudaError, cudaGetErrorString, nvrtcGetErrorString, nvrtcResult, CUresult,
6};
7use crate::bindings::{cudnnGetErrorString, cudnnStatus_t};
8
/// Uniform interface over the status / error codes returned by the CUDA family of libraries
/// (runtime, driver, cuDNN, cuBLAS, NVRTC).
pub trait Status: Copy + PartialEq {
    /// The value of this status type that signals success.
    const SUCCESS: Self;

    /// A static, human-readable string for this status (a symbolic name or a description,
    /// depending on what the underlying library exposes).
    fn as_string(&self) -> &'static str;

    /// Returns `true` when this status equals [`Self::SUCCESS`].
    fn is_success(&self) -> bool {
        self.eq(&Self::SUCCESS)
    }

    /// Panics with the status string unless this status is [`Self::SUCCESS`].
    ///
    /// # Panics
    /// Panics on any non-success status; the reported location is the caller's.
    #[track_caller]
    fn unwrap(&self) {
        if self.is_success() {
            return;
        }
        panic!("Operation returned error {:?}", self.as_string());
    }

    /// Like [`Status::unwrap`], but does nothing while the current thread is already panicking.
    ///
    /// Meant for `Drop` impls: panicking again during unwinding would abort the process, so
    /// errors observed while unwinding are ignored instead.
    #[track_caller]
    fn unwrap_in_drop(&self) {
        if std::thread::panicking() {
            return;
        }
        self.unwrap();
    }
}
34
35impl Status for cudaError {
36 const SUCCESS: Self = cudaError::cudaSuccess;
37
38 fn as_string(&self) -> &'static str {
39 unsafe { CStr::from_ptr(cudaGetErrorString(*self)) }.to_str().unwrap()
40 }
41}
42
43impl Status for cudnnStatus_t {
44 const SUCCESS: Self = cudnnStatus_t::CUDNN_STATUS_SUCCESS;
45
46 fn as_string(&self) -> &'static str {
47 unsafe { CStr::from_ptr(cudnnGetErrorString(*self)) }.to_str().unwrap()
48 }
49}
50
51impl Status for cublasStatus_t {
52 const SUCCESS: Self = cublasStatus_t::CUBLAS_STATUS_SUCCESS;
53
54 fn as_string(&self) -> &'static str {
55 match self {
56 cublasStatus_t::CUBLAS_STATUS_SUCCESS => "CUBLAS_STATUS_SUCCESS",
57 cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => "CUBLAS_STATUS_NOT_INITIALIZED",
58 cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => "CUBLAS_STATUS_ALLOC_FAILED",
59 cublasStatus_t::CUBLAS_STATUS_INVALID_VALUE => "CUBLAS_STATUS_INVALID_VALUE",
60 cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => "CUBLAS_STATUS_ARCH_MISMATCH",
61 cublasStatus_t::CUBLAS_STATUS_MAPPING_ERROR => "CUBLAS_STATUS_MAPPING_ERROR",
62 cublasStatus_t::CUBLAS_STATUS_EXECUTION_FAILED => "CUBLAS_STATUS_EXECUTION_FAILED",
63 cublasStatus_t::CUBLAS_STATUS_INTERNAL_ERROR => "CUBLAS_STATUS_INTERNAL_ERROR",
64 cublasStatus_t::CUBLAS_STATUS_NOT_SUPPORTED => "CUBLAS_STATUS_NOT_SUPPORTED",
65 cublasStatus_t::CUBLAS_STATUS_LICENSE_ERROR => "CUBLAS_STATUS_LICENSE_ERROR",
66 }
67 }
68}
69
70impl Status for nvrtcResult {
71 const SUCCESS: Self = nvrtcResult::NVRTC_SUCCESS;
72
73 fn as_string(&self) -> &'static str {
74 unsafe { CStr::from_ptr(nvrtcGetErrorString(*self)) }.to_str().unwrap()
75 }
76}
77
78impl Status for CUresult {
79 const SUCCESS: Self = CUresult::CUDA_SUCCESS;
80
81 fn as_string(&self) -> &'static str {
82 unsafe {
83 let mut ptr = null();
84 let result = cuGetErrorName(*self, &mut ptr as *mut _);
85 if result != CUresult::CUDA_SUCCESS {
86 panic!("Error '{:?}' while getting name of error '{:?}'", result, self);
87 }
88 CStr::from_ptr(ptr).to_str().unwrap()
89 }
90 }
91}