1use core::fmt;
2
3#[derive(Debug)]
5#[non_exhaustive]
6pub enum GpuError {
7 #[cfg(feature = "cuda")]
9 Driver(cudarc::driver::DriverError),
10
11 #[cfg(not(feature = "cuda"))]
13 NoCudaFeature,
14
15 InvalidDevice { ordinal: usize, count: usize },
17
18 DeviceMismatch { expected: usize, got: usize },
20
21 OutOfMemory {
24 requested_bytes: usize,
25 free_bytes: usize,
26 },
27
28 BudgetExceeded {
31 requested_bytes: usize,
32 budget_bytes: usize,
33 used_bytes: usize,
34 },
35
36 LengthMismatch { a: usize, b: usize },
38
39 ShapeMismatch {
41 op: &'static str,
42 expected: Vec<usize>,
43 got: Vec<usize>,
44 },
45
46 #[cfg(feature = "cuda")]
48 Blas(cudarc::cublas::result::CublasError),
49
50 #[cfg(feature = "cuda")]
52 Solver(cudarc::cusolver::result::CusolverError),
53
54 PtxCompileFailed { kernel: &'static str },
56
57 InvalidState { message: String },
60}
61
62impl fmt::Display for GpuError {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 match self {
65 #[cfg(feature = "cuda")]
66 GpuError::Driver(e) => write!(f, "CUDA driver error: {e}"),
67
68 #[cfg(not(feature = "cuda"))]
69 GpuError::NoCudaFeature => {
70 write!(f, "GPU operations require the `cuda` feature")
71 }
72
73 GpuError::InvalidDevice { ordinal, count } => {
74 write!(
75 f,
76 "invalid device ordinal {ordinal} (only {count} devices available)"
77 )
78 }
79
80 GpuError::DeviceMismatch { expected, got } => {
81 write!(
82 f,
83 "device mismatch: expected cuda:{expected}, got cuda:{got}"
84 )
85 }
86
87 GpuError::OutOfMemory {
88 requested_bytes,
89 free_bytes,
90 } => {
91 write!(
92 f,
93 "GPU out of memory: requested {requested_bytes} bytes but only \
94 {free_bytes} bytes free"
95 )
96 }
97
98 GpuError::BudgetExceeded {
99 requested_bytes,
100 budget_bytes,
101 used_bytes,
102 } => {
103 write!(
104 f,
105 "memory budget exceeded: requested {requested_bytes} bytes, \
106 budget is {budget_bytes} bytes with {used_bytes} bytes already used"
107 )
108 }
109
110 GpuError::LengthMismatch { a, b } => {
111 write!(f, "buffer length mismatch: {a} vs {b}")
112 }
113
114 GpuError::ShapeMismatch { op, expected, got } => {
115 write!(
116 f,
117 "{op}: shape mismatch, expected {expected:?}, got {got:?}"
118 )
119 }
120
121 #[cfg(feature = "cuda")]
122 GpuError::Blas(e) => write!(f, "cuBLAS error: {e}"),
123
124 #[cfg(feature = "cuda")]
125 GpuError::Solver(e) => write!(f, "cuSOLVER error: {e}"),
126
127 GpuError::PtxCompileFailed { kernel } => {
128 write!(f, "PTX kernel compilation failed: {kernel}")
129 }
130
131 GpuError::InvalidState { message } => {
132 write!(f, "invalid state: {message}")
133 }
134 }
135 }
136}
137
138impl std::error::Error for GpuError {
139 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
140 match self {
141 #[cfg(feature = "cuda")]
142 GpuError::Driver(e) => Some(e),
143 #[cfg(feature = "cuda")]
144 GpuError::Blas(e) => Some(e),
145 #[cfg(feature = "cuda")]
146 GpuError::Solver(e) => Some(e),
147 _ => None,
148 }
149 }
150}
151
152#[cfg(feature = "cuda")]
153impl From<cudarc::driver::DriverError> for GpuError {
154 fn from(e: cudarc::driver::DriverError) -> Self {
155 GpuError::Driver(e)
156 }
157}
158
159#[cfg(feature = "cuda")]
160impl From<cudarc::cublas::result::CublasError> for GpuError {
161 fn from(e: cudarc::cublas::result::CublasError) -> Self {
162 GpuError::Blas(e)
163 }
164}
165
166#[cfg(feature = "cuda")]
167impl From<cudarc::cusolver::result::CusolverError> for GpuError {
168 fn from(e: cudarc::cusolver::result::CusolverError) -> Self {
169 GpuError::Solver(e)
170 }
171}
172
173pub type GpuResult<T> = Result<T, GpuError>;