use cudarc::driver::CudaContext;
use std::any::Any;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::Arc;
/// Creates a CUDA context for device `device` via `CudaContext::new`.
///
/// Construction is routed through `context_init_result`, so a driver error and a
/// panic raised during initialization both come back as a descriptive `String` error.
pub(crate) fn new_context(device: usize) -> Result<Arc<CudaContext>, String> {
    let index = device;
    context_init_result(move || CudaContext::new(index))
}
/// Runs the fallible initializer `init` and flattens both of its failure modes into a `String`.
///
/// A successful value is returned unchanged.
///
/// # Errors
///
/// - If `init` returns `Err(err)`, the result is `"CUDA context init failed: {err}"`.
/// - If `init` panics, the panic is caught and reported as
///   `"CUDA context init panicked: <message>"`.
///
/// Catching relies on `catch_unwind`, which only intercepts unwinding panics. A build
/// using `panic = "abort"` will still abort.
pub(crate) fn context_init_result<T, E>(init: impl FnOnce() -> Result<T, E>) -> Result<T, String>
where
    E: std::fmt::Display,
{
    // `init` is an arbitrary closure that is not known to be `UnwindSafe`, so the
    // assertion wrapper is required to call it under `catch_unwind`.
    let returned = catch_unwind(AssertUnwindSafe(init)).map_err(|payload| {
        format!(
            "CUDA context init panicked: {}",
            panic_payload_message(payload.as_ref())
        )
    })?;

    returned.map_err(|err| format!("CUDA context init failed: {err}"))
}
/// Extracts a human-readable message from a caught panic payload.
///
/// A `panic!` with a plain string literal carries a `&'static str`, while a formatted
/// `panic!` carries a `String`; both are recognised. Any other payload type yields
/// the fixed text `"unknown panic payload"`.
fn panic_payload_message(payload: &(dyn Any + Send)) -> &str {
    if let Some(literal) = payload.downcast_ref::<&'static str>() {
        return *literal;
    }
    match payload.downcast_ref::<String>() {
        Some(formatted) => formatted.as_str(),
        None => "unknown panic payload",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful initializer result must pass through untouched.
    #[test]
    fn context_init_result_passes_through_success() {
        let value = context_init_result::<_, &'static str>(|| Ok(7_u32)).unwrap();
        assert_eq!(value, 7);
    }

    #[test]
    fn context_init_result_converts_errors() {
        let err = context_init_result::<(), _>(|| Err("driver unavailable")).unwrap_err();
        assert_eq!(err, "CUDA context init failed: driver unavailable");
    }

    /// A literal `panic!` carries a `&'static str` payload.
    #[test]
    fn context_init_result_converts_panics() {
        let err =
            context_init_result::<(), &'static str>(|| panic!("loader exploded")).unwrap_err();
        assert_eq!(err, "CUDA context init panicked: loader exploded");
    }

    /// A formatted `panic!` carries a `String` payload, which takes a different
    /// downcast branch than the literal case above.
    #[test]
    fn context_init_result_converts_formatted_panics() {
        let stage = "loader";
        let err = context_init_result::<(), &'static str>(|| panic!("{} exploded", stage))
            .unwrap_err();
        assert_eq!(err, "CUDA context init panicked: loader exploded");
    }

    /// Payloads that are neither `&str` nor `String` fall back to a fixed message.
    #[test]
    fn context_init_result_reports_unknown_panic_payloads() {
        let err = context_init_result::<(), &'static str>(|| std::panic::panic_any(42_i32))
            .unwrap_err();
        assert_eq!(err, "CUDA context init panicked: unknown panic payload");
    }
}