use core::ffi::{c_int, c_void};
use std::ffi::CString;
use baracuda_cuda_sys::runtime::runtime;
use crate::error::{check, Error, Result};
/// Outcome of a driver entry-point lookup performed through the CUDA runtime.
///
/// Produced by [`driver_entry_point`], which copies the pointer and status
/// exactly as the runtime reported them. A value can therefore describe a
/// failed lookup; call [`DriverEntryPoint::is_success`] before using `fn_ptr`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct DriverEntryPoint {
    /// Raw function pointer written by the runtime. May be null.
    pub fn_ptr: *mut c_void,
    /// Driver-side query status reported alongside the pointer; `0` is treated
    /// as success. Presumably the raw `cudaDriverEntryPointQueryResult` value —
    /// confirm against the runtime bindings.
    pub status: i32,
}

impl DriverEntryPoint {
    /// Returns `true` when the lookup succeeded: the status is `0` AND the
    /// function pointer is non-null.
    ///
    /// Both are checked so that a zero status paired with a null pointer is
    /// still reported as a failure.
    #[inline]
    pub fn is_success(&self) -> bool {
        self.status == 0 && !self.fn_ptr.is_null()
    }
}
pub fn driver_entry_point(symbol: &str, flags: u64) -> Result<DriverEntryPoint> {
let c_sym = CString::new(symbol).map_err(|_| {
Error::Loader(baracuda_core::LoaderError::SymbolNotFound {
library: "cuda-runtime",
symbol: "cudaGetDriverEntryPoint(symbol contained a NUL byte)",
})
})?;
let r = runtime()?;
let cu = r.cuda_get_driver_entry_point()?;
let mut fn_ptr: *mut c_void = core::ptr::null_mut();
let mut driver_status: c_int = 0;
check(unsafe { cu(c_sym.as_ptr(), &mut fn_ptr, flags, &mut driver_status) })?;
Ok(DriverEntryPoint {
fn_ptr,
status: driver_status,
})
}