1use core::fmt;
2
/// Error type for the GPU operations in this crate (see also [`GpuResult`]).
///
/// The set of variants depends on the `cuda` feature:
/// - with `cuda` enabled, errors from `cudarc` can be wrapped
///   ([`GpuError::Driver`], [`GpuError::Blas`]);
/// - without it, [`GpuError::NoCudaFeature`] exists instead.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a
/// wildcard arm when matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum GpuError {
    /// An error reported by the CUDA driver, as surfaced by `cudarc`.
    ///
    /// Created via `From<cudarc::driver::DriverError>` (so `?` works) and
    /// returned from `std::error::Error::source`.
    #[cfg(feature = "cuda")]
    Driver(cudarc::driver::DriverError),

    /// GPU functionality is unavailable because this build was compiled
    /// without the `cuda` feature.
    #[cfg(not(feature = "cuda"))]
    NoCudaFeature,

    /// A device ordinal was rejected as invalid.
    InvalidDevice {
        /// The requested device ordinal.
        ordinal: usize,
        /// The number of devices available.
        count: usize,
    },

    /// Something was on a different device than expected. Both fields are
    /// CUDA device ordinals (rendered as `cuda:N`).
    DeviceMismatch {
        /// The device ordinal that was expected.
        expected: usize,
        /// The device ordinal that was actually found.
        got: usize,
    },

    /// A GPU memory request could not be satisfied.
    OutOfMemory {
        /// Size of the request, in bytes.
        requested_bytes: usize,
        /// Free memory at the time of the request, in bytes.
        free_bytes: usize,
    },

    /// A memory request would exceed the memory budget.
    BudgetExceeded {
        /// Size of the request, in bytes.
        requested_bytes: usize,
        /// The budget, in bytes.
        budget_bytes: usize,
        /// Memory already counted against the budget, in bytes.
        used_bytes: usize,
    },

    /// Two buffers had different lengths.
    ///
    /// NOTE(review): whether `a`/`b` count elements or bytes is not visible
    /// here; confirm against the call sites.
    LengthMismatch { a: usize, b: usize },

    /// An operation received an operand whose shape did not match.
    ShapeMismatch {
        /// Name of the operation that detected the mismatch.
        op: &'static str,
        /// The shape the operation expected.
        expected: Vec<usize>,
        /// The shape it actually received.
        got: Vec<usize>,
    },

    /// An error reported by cuBLAS, as surfaced by `cudarc`.
    ///
    /// Created via `From<cudarc::cublas::result::CublasError>` (so `?` works)
    /// and returned from `std::error::Error::source`.
    #[cfg(feature = "cuda")]
    Blas(cudarc::cublas::result::CublasError),

    /// Compiling a PTX kernel failed. Only the kernel name is recorded; no
    /// compiler diagnostics are carried by this variant.
    PtxCompileFailed {
        /// Name of the kernel that failed to compile.
        kernel: &'static str,
    },

    /// An operation hit an invalid state, described by a free-form message.
    InvalidState {
        /// Human-readable description of the problem.
        message: String,
    },
}
57
58impl fmt::Display for GpuError {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 match self {
61 #[cfg(feature = "cuda")]
62 GpuError::Driver(e) => write!(f, "CUDA driver error: {e}"),
63
64 #[cfg(not(feature = "cuda"))]
65 GpuError::NoCudaFeature => {
66 write!(f, "GPU operations require the `cuda` feature")
67 }
68
69 GpuError::InvalidDevice { ordinal, count } => {
70 write!(
71 f,
72 "invalid device ordinal {ordinal} (only {count} devices available)"
73 )
74 }
75
76 GpuError::DeviceMismatch { expected, got } => {
77 write!(
78 f,
79 "device mismatch: expected cuda:{expected}, got cuda:{got}"
80 )
81 }
82
83 GpuError::OutOfMemory {
84 requested_bytes,
85 free_bytes,
86 } => {
87 write!(
88 f,
89 "GPU out of memory: requested {requested_bytes} bytes but only \
90 {free_bytes} bytes free"
91 )
92 }
93
94 GpuError::BudgetExceeded {
95 requested_bytes,
96 budget_bytes,
97 used_bytes,
98 } => {
99 write!(
100 f,
101 "memory budget exceeded: requested {requested_bytes} bytes, \
102 budget is {budget_bytes} bytes with {used_bytes} bytes already used"
103 )
104 }
105
106 GpuError::LengthMismatch { a, b } => {
107 write!(f, "buffer length mismatch: {a} vs {b}")
108 }
109
110 GpuError::ShapeMismatch { op, expected, got } => {
111 write!(
112 f,
113 "{op}: shape mismatch, expected {expected:?}, got {got:?}"
114 )
115 }
116
117 #[cfg(feature = "cuda")]
118 GpuError::Blas(e) => write!(f, "cuBLAS error: {e}"),
119
120 GpuError::PtxCompileFailed { kernel } => {
121 write!(f, "PTX kernel compilation failed: {kernel}")
122 }
123
124 GpuError::InvalidState { message } => {
125 write!(f, "invalid state: {message}")
126 }
127 }
128 }
129}
130
131impl std::error::Error for GpuError {
132 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
133 match self {
134 #[cfg(feature = "cuda")]
135 GpuError::Driver(e) => Some(e),
136 #[cfg(feature = "cuda")]
137 GpuError::Blas(e) => Some(e),
138 _ => None,
139 }
140 }
141}
142
/// Wraps a CUDA driver error as [`GpuError::Driver`], which lets `?`
/// propagate driver failures from functions returning [`GpuResult`].
#[cfg(feature = "cuda")]
impl From<cudarc::driver::DriverError> for GpuError {
    fn from(err: cudarc::driver::DriverError) -> Self {
        Self::Driver(err)
    }
}
149
/// Wraps a cuBLAS error as [`GpuError::Blas`], which lets `?` propagate
/// cuBLAS failures from functions returning [`GpuResult`].
#[cfg(feature = "cuda")]
impl From<cudarc::cublas::result::CublasError> for GpuError {
    fn from(err: cudarc::cublas::result::CublasError) -> Self {
        Self::Blas(err)
    }
}
156
157pub type GpuResult<T> = Result<T, GpuError>;