1use std::fmt;
4
/// Errors reported by the CUDA layer of this crate.
///
/// Each variant carries one of three kinds of payload:
/// - nothing;
/// - a human-readable message (`String`);
/// - a raw integer status code, whose meaning is defined by the library that
///   reported it.
// `PartialEq`/`Eq` let callers and tests compare results directly,
// e.g. `check_cuda_error(0) == Ok(())`; all payloads (`i32`, `String`) support them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CudaError {
    /// No CUDA device could be found.
    DeviceNotFound,
    /// CUDA reported that it ran out of memory.
    OutOfMemory,
    /// The given device identifier is not a valid CUDA device.
    InvalidDevice(i32),
    /// A kernel launch failed; the payload describes the failure.
    LaunchFailed(String),
    /// An argument had an invalid value; the payload describes the problem.
    InvalidValue(String),
    /// Raw status code from a CUDA driver call.
    ///
    /// `check_cuda_error` produces this variant for every non-zero code.
    DriverError(i32),
    /// Raw status code reported by cuBLAS.
    CublasError(i32),
    /// Raw status code reported by cuDNN.
    CudnnError(i32),
    /// An operation required CUDA to be initialized first.
    NotInitialized,
    /// Initialization was requested although CUDA was already initialized.
    AlreadyInitialized,
    /// Synchronization with the device failed.
    SyncError,
    /// A memory copy failed.
    MemcpyError,
}
33
34impl fmt::Display for CudaError {
35 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36 match self {
37 CudaError::DeviceNotFound => write!(f, "No CUDA device found"),
38 CudaError::OutOfMemory => write!(f, "CUDA out of memory"),
39 CudaError::InvalidDevice(id) => write!(f, "Invalid CUDA device: {}", id),
40 CudaError::LaunchFailed(msg) => write!(f, "Kernel launch failed: {}", msg),
41 CudaError::InvalidValue(msg) => write!(f, "Invalid value: {}", msg),
42 CudaError::DriverError(code) => write!(f, "CUDA driver error: {}", code),
43 CudaError::CublasError(code) => write!(f, "cuBLAS error: {}", code),
44 CudaError::CudnnError(code) => write!(f, "cuDNN error: {}", code),
45 CudaError::NotInitialized => write!(f, "CUDA not initialized"),
46 CudaError::AlreadyInitialized => write!(f, "CUDA already initialized"),
47 CudaError::SyncError => write!(f, "CUDA synchronization error"),
48 CudaError::MemcpyError => write!(f, "CUDA memory copy error"),
49 }
50 }
51}
52
53impl std::error::Error for CudaError {}
54
55pub type CudaResult<T> = Result<T, CudaError>;
57
58pub fn check_cuda_error(code: i32) -> CudaResult<()> {
60 if code == 0 {
61 Ok(())
62 } else {
63 Err(CudaError::DriverError(code))
64 }
65}