rustsim 0.0.1

High-performance agent-based modelling engine - top-level orchestration crate
Documentation
//! Shared CUDA context initialization helpers.

use cudarc::driver::CudaContext;
use std::any::Any;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::Arc;

/// Open a CUDA context on device ordinal `device`.
///
/// Both driver errors and panics raised while loading the driver are
/// surfaced as a descriptive `String` rather than propagating.
pub(crate) fn new_context(device: usize) -> Result<Arc<CudaContext>, String> {
    let init = move || CudaContext::new(device);
    context_init_result(init)
}

/// Run a fallible initializer and normalize every failure mode into a `String`.
///
/// # Errors
///
/// - If `init` returns `Err(e)`, the result is `"CUDA context init failed: {e}"`.
/// - If `init` panics and unwinds, the result is `"CUDA context init panicked: {msg}"`.
///   `msg` is the panic's string payload, or `"unknown panic payload"` for
///   non-string payloads.
///
/// Panics that abort the process (e.g. built with `panic = "abort"`) cannot be
/// intercepted here.
pub(crate) fn context_init_result<T, E>(init: impl FnOnce() -> Result<T, E>) -> Result<T, String>
where
    E: std::fmt::Display,
{
    // First translate an unwinding panic into an error string, then handle the
    // initializer's own `Result`.
    let outcome = catch_unwind(AssertUnwindSafe(init)).map_err(|payload| {
        format!(
            "CUDA context init panicked: {}",
            panic_payload_message(payload.as_ref())
        )
    })?;
    outcome.map_err(|err| format!("CUDA context init failed: {err}"))
}

/// Extract a human-readable message from a caught panic payload.
///
/// `&'static str` payloads are checked first, then `String`. Any other
/// payload type yields the placeholder `"unknown panic payload"`.
fn panic_payload_message(payload: &(dyn Any + Send)) -> &str {
    let literal: Option<&str> = payload.downcast_ref::<&'static str>().copied();
    literal
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
        .unwrap_or("unknown panic payload")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn context_init_result_passes_through_success() {
        let value = context_init_result::<u32, &'static str>(|| Ok(42));
        assert_eq!(value, Ok(42));
    }

    #[test]
    fn context_init_result_converts_errors() {
        let err = context_init_result::<(), _>(|| Err("driver unavailable")).unwrap_err();
        assert_eq!(err, "CUDA context init failed: driver unavailable");
    }

    #[test]
    fn context_init_result_converts_panics() {
        let err =
            context_init_result::<(), &'static str>(|| panic!("loader exploded")).unwrap_err();
        assert_eq!(err, "CUDA context init panicked: loader exploded");
    }

    #[test]
    fn context_init_result_converts_formatted_panics() {
        // Interpolated panic messages exercise the `String` payload branch.
        let library = "libcuda.so";
        let err = context_init_result::<(), &'static str>(|| panic!("failed to load {library}"))
            .unwrap_err();
        assert_eq!(err, "CUDA context init panicked: failed to load libcuda.so");
    }

    #[test]
    fn context_init_result_reports_unknown_payloads() {
        // Non-string payloads fall back to the placeholder message.
        let err = context_init_result::<(), &'static str>(|| std::panic::panic_any(7_u32))
            .unwrap_err();
        assert_eq!(err, "CUDA context init panicked: unknown panic payload");
    }
}