use std::ffi::CStr;
use std::mem::MaybeUninit;
use std::{
error,
fmt::{self, Display, Formatter},
};
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DriverError(pub cuda_bindings::CUresult);
impl DriverError {
fn _fmt(&self, formatter: &mut Formatter) -> fmt::Result {
if self.0 == cuda_bindings::cudaError_enum_CUDA_ERROR_SHARED_OBJECT_INIT_FAILED {
if let Some(load_error) = cuda_bindings::cuda_driver_load_error() {
return formatter
.debug_tuple("DriverError")
.field(&self.0)
.field(&format!("CUDA driver library unavailable: {load_error}"))
.finish();
}
}
match self.error_string() {
Ok(err_str) => formatter
.debug_tuple("DriverError")
.field(&self.0)
.field(&err_str)
.finish(),
Err(_) => formatter
.debug_tuple("DriverError")
.field(&self.0)
.field(&"<Failure when calling cuGetErrorString()>")
.finish(),
}
}
}
impl Display for DriverError {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
self._fmt(formatter)
}
}
impl std::fmt::Debug for DriverError {
fn fmt(&self, formatter: &mut std::fmt::Formatter) -> fmt::Result {
self._fmt(formatter)
}
}
impl error::Error for DriverError {}
pub trait IntoResult<T> {
fn result(self) -> Result<T, DriverError>
where
Self: Sized;
}
impl IntoResult<()> for cuda_bindings::CUresult {
fn result(self) -> Result<(), DriverError> {
match self {
cuda_bindings::cudaError_enum_CUDA_SUCCESS => Ok(()),
_ => Err(DriverError(self)),
}
}
}
impl<T> IntoResult<T> for (cuda_bindings::CUresult, T) {
fn result(self) -> Result<T, DriverError> {
match self.0 {
cuda_bindings::cudaError_enum_CUDA_SUCCESS => Ok(self.1),
_ => Err(DriverError(self.0)),
}
}
}
impl<T> IntoResult<T> for (cuda_bindings::CUresult, MaybeUninit<T>) {
fn result(self) -> Result<T, DriverError> {
match self.0 {
cuda_bindings::cudaError_enum_CUDA_SUCCESS => Ok(unsafe { self.1.assume_init() }),
_ => Err(DriverError(self.0)),
}
}
}
impl DriverError {
pub fn error_name(&self) -> Result<&CStr, DriverError> {
let mut err_str = MaybeUninit::uninit();
unsafe {
cuda_bindings::cuGetErrorName(self.0, err_str.as_mut_ptr()).result()?;
Ok(CStr::from_ptr(err_str.assume_init()))
}
}
pub fn error_string(&self) -> Result<&CStr, DriverError> {
let mut err_str = MaybeUninit::uninit();
unsafe {
cuda_bindings::cuGetErrorString(self.0, err_str.as_mut_ptr()).result()?;
Ok(CStr::from_ptr(err_str.assume_init()))
}
}
}
#[cfg(test)]
mod tests {
use super::DriverError;
#[test]
fn shared_object_init_failed_includes_loader_error_when_driver_is_missing() {
if cuda_bindings::cuda_driver_load_error().is_none() {
return;
}
let formatted =
DriverError(cuda_bindings::cudaError_enum_CUDA_ERROR_SHARED_OBJECT_INIT_FAILED)
.to_string();
assert!(
formatted.contains("CUDA driver library unavailable"),
"expected a human-readable loader hint, got: {formatted}"
);
assert!(
formatted.contains("failed to load any of"),
"expected the cached loader failure context to be surfaced, got: {formatted}"
);
assert!(
formatted.contains("libcuda"),
"expected the driver library name to be surfaced, got: {formatted}"
);
}
}