use libloading::Library;
use ndarray::{Array2, ArrayBase, Data, Ix2};
/// Status code returned by every driver entry point declared below.
/// `check_cuda` treats `0` as success and any other value as an error.
pub type CuResult = i32;
// Function-pointer signatures for the driver symbols resolved in `DriverApi::load`.
// NOTE(review): these are assumed to match the C prototypes of the corresponding
// CUDA driver functions; that cannot be verified from this file alone.
/// Signature used for the `cuInit` symbol.
type CuInit = unsafe extern "C" fn(u32) -> CuResult;
/// Signature used for the `cuDeviceGet` symbol (writes through the `*mut i32`).
type CuDeviceGet = unsafe extern "C" fn(*mut i32, i32) -> CuResult;
/// Signature used for the `cuCtxCreate_v2` symbol; the context handle is stored as a `usize`.
type CuCtxCreate = unsafe extern "C" fn(*mut usize, u32, i32) -> CuResult;
/// Signature used for the `cuCtxSetCurrent` symbol.
type CuCtxSetCurrent = unsafe extern "C" fn(usize) -> CuResult;
/// Signature used for the `cuCtxDestroy_v2` symbol.
type CuCtxDestroy = unsafe extern "C" fn(usize) -> CuResult;
/// Signature used for the `cuMemAlloc_v2` symbol; the device pointer is stored as a `u64`.
type CuMemAlloc = unsafe extern "C" fn(*mut u64, usize) -> CuResult;
/// Signature used for the `cuMemFree_v2` symbol.
type CuMemFree = unsafe extern "C" fn(u64) -> CuResult;
/// Signature used for the `cuMemcpyHtoD_v2` symbol (device pointer, host source, byte count —
/// argument meaning presumed from the symbol name; confirm against the driver docs).
type CuMemcpyHtoD = unsafe extern "C" fn(u64, *const std::ffi::c_void, usize) -> CuResult;
/// Signature used for the `cuMemcpyDtoH_v2` symbol (host destination, device pointer, byte count —
/// argument meaning presumed from the symbol name; confirm against the driver docs).
type CuMemcpyDtoH = unsafe extern "C" fn(*mut std::ffi::c_void, u64, usize) -> CuResult;
/// Table of raw driver entry points, filled in by `DriverApi::load`.
///
/// Each field is a plain function pointer copied out of a loaded library, so the
/// pointers are only meaningful while that library stays loaded. Calling any of
/// them requires `unsafe`.
pub struct DriverApi {
/// Resolved from the `cuInit` symbol.
pub cu_init: CuInit,
/// Resolved from the `cuDeviceGet` symbol.
pub cu_device_get: CuDeviceGet,
/// Resolved from the `cuCtxCreate_v2` symbol.
pub cu_ctx_create: CuCtxCreate,
/// Resolved from the `cuCtxSetCurrent` symbol.
pub cu_ctx_set_current: CuCtxSetCurrent,
/// Resolved from the `cuCtxDestroy_v2` symbol.
pub cu_ctx_destroy: CuCtxDestroy,
/// Resolved from the `cuMemAlloc_v2` symbol.
pub cu_mem_alloc: CuMemAlloc,
/// Resolved from the `cuMemFree_v2` symbol.
pub cu_mem_free: CuMemFree,
/// Resolved from the `cuMemcpyHtoD_v2` symbol.
pub cu_memcpy_htod: CuMemcpyHtoD,
/// Resolved from the `cuMemcpyDtoH_v2` symbol.
pub cu_memcpy_dtoh: CuMemcpyDtoH,
}
impl DriverApi {
pub fn load(library: &Library) -> Result<Self, String> {
unsafe {
Ok(Self {
cu_init: *library.get(b"cuInit\0").map_err(|e| e.to_string())?,
cu_device_get: *library.get(b"cuDeviceGet\0").map_err(|e| e.to_string())?,
cu_ctx_create: *library
.get(b"cuCtxCreate_v2\0")
.map_err(|e| e.to_string())?,
cu_ctx_set_current: *library
.get(b"cuCtxSetCurrent\0")
.map_err(|e| e.to_string())?,
cu_ctx_destroy: *library
.get(b"cuCtxDestroy_v2\0")
.map_err(|e| e.to_string())?,
cu_mem_alloc: *library.get(b"cuMemAlloc_v2\0").map_err(|e| e.to_string())?,
cu_mem_free: *library.get(b"cuMemFree_v2\0").map_err(|e| e.to_string())?,
cu_memcpy_htod: *library
.get(b"cuMemcpyHtoD_v2\0")
.map_err(|e| e.to_string())?,
cu_memcpy_dtoh: *library
.get(b"cuMemcpyDtoH_v2\0")
.map_err(|e| e.to_string())?,
})
}
}
}
/// A loaded driver table together with one context handle created through it.
///
/// Built by `CudaWorkingState::init`; the `Drop` impl passes `context` to
/// `cu_ctx_destroy`.
pub struct CudaWorkingState {
/// Driver entry points used for every call made on behalf of this state.
pub api: DriverApi,
/// Handle value written by the `cu_ctx_create` call in `init`.
pub context: usize,
}
impl CudaWorkingState {
    /// Loads the driver library, initialises it and creates a context on the
    /// device selected by `device_ordinal`.
    ///
    /// Returns `None` when the ordinal does not fit in an `i32`, when no
    /// candidate library can be loaded, when a symbol is missing, or when any of
    /// `cuInit`, `cuDeviceGet`, `cuCtxCreate` reports a non-zero status.
    /// The loaded library is intentionally leaked (see `load_static_library`),
    /// so the function pointers in `api` stay valid.
    pub fn init(device_ordinal: usize) -> Option<Self> {
        let ordinal = to_i32(device_ordinal)?;
        let api = load_static_library(cuda_library_candidates())
            .and_then(DriverApi::load)
            .ok()?;

        // SAFETY: `api` was resolved from a library that is never unloaded; the
        // call matches the signature the pointer was declared with.
        let status = unsafe { (api.cu_init)(0) };
        check_cuda(status, "cuInit").ok()?;

        let mut device: i32 = 0;
        // SAFETY: as above; `device` is a live, writable `i32`.
        let status = unsafe { (api.cu_device_get)(&mut device, ordinal) };
        check_cuda(status, "cuDeviceGet").ok()?;

        let mut context: usize = 0;
        // SAFETY: as above; `context` is a live, writable handle slot.
        let status = unsafe { (api.cu_ctx_create)(&mut context, 0, device) };
        check_cuda(status, "cuCtxCreate").ok()?;

        // `Self` is only constructed after the context exists, so `Drop` never
        // sees a handle that was not produced by `cu_ctx_create`.
        Some(Self { api, context })
    }

    /// Calls `cuCtxSetCurrent` with this state's context handle.
    ///
    /// # Errors
    ///
    /// Returns the `check_cuda` message when the driver reports a non-zero status.
    #[inline]
    pub fn set_current(&self) -> Result<(), String> {
        // SAFETY: the pointer comes from the driver table built in `init` and
        // `self.context` is the handle created there.
        let status = unsafe { (self.api.cu_ctx_set_current)(self.context) };
        check_cuda(status, "cuCtxSetCurrent")
    }
}
impl Drop for CudaWorkingState {
    /// Hands the context handle to `cu_ctx_destroy`. The returned status is
    /// deliberately discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        let destroy = self.api.cu_ctx_destroy;
        // SAFETY: `self.context` is the handle stored when this value was built.
        let _status = unsafe { destroy(self.context) };
    }
}
/// A device pointer paired with the driver table that produced it.
///
/// The `Drop` impl passes `ptr` to `cu_mem_free` through `driver`.
pub struct DeviceAllocation<'a> {
// Borrowed driver table; used again on drop to release `ptr`.
driver: &'a DriverApi,
/// Device pointer value written by `cu_mem_alloc` in `DeviceAllocation::new`.
pub ptr: u64,
}
impl<'a> DeviceAllocation<'a> {
    /// Requests `bytes` bytes from `driver.cu_mem_alloc` and wraps the resulting
    /// device pointer. Returns `None` when the driver reports a non-zero status.
    ///
    /// # Safety
    ///
    /// `driver.cu_mem_alloc` and `driver.cu_mem_free` must be valid entry points
    /// of a driver library that is still loaded when this value is dropped.
    /// NOTE(review): the driver presumably also requires a current context on
    /// the calling thread — confirm against the CUDA driver API documentation.
    pub unsafe fn new(driver: &'a DriverApi, bytes: usize) -> Option<Self> {
        let mut device_ptr: u64 = 0;
        // SAFETY: upheld by the caller per the contract above; `device_ptr` is a
        // live, writable slot.
        let status = unsafe { (driver.cu_mem_alloc)(&mut device_ptr, bytes) };
        match check_cuda(status, "cuMemAlloc") {
            Ok(()) => Some(Self {
                driver,
                ptr: device_ptr,
            }),
            Err(_) => None,
        }
    }
}
impl Drop for DeviceAllocation<'_> {
    /// Passes the stored device pointer to `cu_mem_free`. The returned status is
    /// deliberately discarded because `drop` has no way to report it.
    fn drop(&mut self) {
        let release = self.driver.cu_mem_free;
        // SAFETY: `self.ptr` is whatever the field currently holds; it is
        // assumed to still be the pointer produced by `cu_mem_alloc`.
        let _status = unsafe { release(self.ptr) };
    }
}
/// Converts a driver status code into a `Result`.
///
/// `0` maps to `Ok(())`; any other value becomes an error message naming the
/// call (`name`) and the numeric status.
#[inline]
pub fn check_cuda(result: CuResult, name: &str) -> Result<(), String> {
    match result {
        0 => Ok(()),
        _ => Err(format!("{name} failed with CUDA driver error {result}")),
    }
}
/// Tries each name in `candidates` in order and returns the first library that loads.
///
/// # Errors
///
/// When none loads, returns a message listing every candidate, comma-separated.
/// The individual load errors are not reported.
fn load_library(candidates: &[&str]) -> Result<Library, String> {
    candidates
        .iter()
        .find_map(|name| {
            // SAFETY: loading a shared library runs its initialisers; the
            // candidates are assumed to name trusted driver libraries.
            let attempt = unsafe { Library::new(*name) };
            attempt.ok()
        })
        .ok_or_else(|| format!("could not load any of: {}", candidates.join(", ")))
}
/// Loads the first available library from `candidates` and leaks it, yielding a
/// `'static` reference.
///
/// The library is moved into a `Box` that is never freed, so it is not unloaded
/// for the rest of the process. Every successful call leaks one handle.
///
/// # Errors
///
/// Propagates the message from `load_library` when no candidate loads.
pub fn load_static_library(candidates: &[&str]) -> Result<&'static Library, String> {
    let library = load_library(candidates)?;
    let leaked: &'static Library = Box::leak(Box::new(library));
    Ok(leaked)
}
/// Returns the driver library names to try on the compilation target, in
/// order of preference.
pub fn cuda_library_candidates() -> &'static [&'static str] {
    const WINDOWS: &[&str] = &["nvcuda.dll"];
    const MACOS: &[&str] = &["/usr/local/cuda/lib/libcuda.dylib", "libcuda.dylib"];
    const OTHER: &[&str] = &["libcuda.so.1", "libcuda.so"];

    if cfg!(target_os = "windows") {
        return WINDOWS;
    }
    if cfg!(target_os = "macos") {
        return MACOS;
    }
    // Every remaining target falls back to the ELF shared-object names.
    OTHER
}
/// Byte size of `len` elements of type `T`, or `None` if the product overflows `usize`.
#[inline]
pub fn bytes_len<T>(len: usize) -> Option<usize> {
    let element_size = std::mem::size_of::<T>();
    element_size.checked_mul(len)
}
/// Converts `value` to `i32`, returning `None` when it does not fit.
#[inline]
pub fn to_i32(value: usize) -> Option<i32> {
    match i32::try_from(value) {
        Ok(converted) => Some(converted),
        Err(_) => None,
    }
}
/// Converts `value` to `i64`, returning `None` when it does not fit.
#[inline]
pub fn to_i64(value: usize) -> Option<i64> {
    match i64::try_from(value) {
        Ok(converted) => Some(converted),
        Err(_) => None,
    }
}
/// Flattens a 2-D array into a `Vec` ordered column by column: all rows of
/// column 0, then all rows of column 1, and so on.
pub fn to_col_major<S: Data<Elem = f64>>(a: &ArrayBase<S, Ix2>) -> Vec<f64> {
    let (n_rows, n_cols) = a.dim();
    let mut flat = Vec::with_capacity(n_rows.saturating_mul(n_cols));
    // Walking the transposed view in logical order visits the original
    // elements one column at a time.
    flat.extend(a.t().iter().copied());
    flat
}
/// Fills `out` from `values`, reading `values` as consecutive columns of
/// `out.nrows()` elements each.
///
/// # Panics
///
/// In debug builds, panics when `values.len()` differs from the element count
/// of `out`. In all builds, panics on the first column whose slice runs past
/// the end of `values`; columns before it have already been written. Extra
/// trailing elements in `values` are ignored in release builds.
pub fn from_col_major_inplace(values: &[f64], out: &mut Array2<f64>) {
    let (n_rows, n_cols) = out.dim();
    debug_assert_eq!(values.len(), n_rows.saturating_mul(n_cols));
    let mut start = 0;
    for mut column in out.axis_iter_mut(ndarray::Axis(1)) {
        let end = start + n_rows;
        let source = ndarray::ArrayView1::from(&values[start..end]);
        column.assign(&source);
        start = end;
    }
}
/// Builds a `rows` × `cols` array from `values`, which holds the data column by
/// column.
///
/// The array starts zero-filled and is populated by `from_col_major_inplace`,
/// so the length requirements and panics of that function apply here too.
pub fn from_col_major(values: &[f64], rows: usize, cols: usize) -> Array2<f64> {
    let shape = (rows, cols);
    let mut matrix: Array2<f64> = Array2::zeros(shape);
    from_col_major_inplace(values, &mut matrix);
    matrix
}