use libloading::Library;
use ndarray::{Array2, ArrayBase, Data, Ix2};
use std::borrow::Cow;
use std::path::Path;
use std::sync::OnceLock;
use super::error::GpuError;
/// Raw status code returned by CUDA driver entry points.
///
/// `check_cuda` treats `0` as success and every other value as an error code.
pub type CuResult = i32;
/// Signature used for the `cuInit` symbol; `CudaWorkingState::init` passes `0`.
type CuInit = unsafe extern "C" fn(u32) -> CuResult;
/// Signature used for the `cuDeviceGet` symbol: writes the device for the given
/// ordinal through the out-pointer.
type CuDeviceGet = unsafe extern "C" fn(*mut i32, i32) -> CuResult;
/// Signature used for the `cuCtxCreate_v2` symbol: (out context handle,
/// presumably flags — `0` at the visible call site, device). The context
/// handle is carried as a `usize`.
type CuCtxCreate = unsafe extern "C" fn(*mut usize, u32, i32) -> CuResult;
/// Signature used for the `cuCtxSetCurrent` symbol, taking a context handle.
type CuCtxSetCurrent = unsafe extern "C" fn(usize) -> CuResult;
/// Signature used for the `cuCtxDestroy_v2` symbol, taking a context handle.
type CuCtxDestroy = unsafe extern "C" fn(usize) -> CuResult;
/// Signature used for the `cuMemAlloc_v2` symbol: (out device pointer, size in
/// bytes). Device pointers are carried as `u64`.
type CuMemAlloc = unsafe extern "C" fn(*mut u64, usize) -> CuResult;
/// Signature used for the `cuMemFree_v2` symbol, taking a device pointer.
type CuMemFree = unsafe extern "C" fn(u64) -> CuResult;
/// Signature used for the `cuMemcpyHtoD_v2` symbol: (device destination, host
/// source, length — presumably in bytes per the driver API; confirm).
type CuMemcpyHtoD = unsafe extern "C" fn(u64, *const std::ffi::c_void, usize) -> CuResult;
/// Signature used for the `cuMemcpyDtoH_v2` symbol: (host destination, device
/// source, length — presumably in bytes per the driver API; confirm).
type CuMemcpyDtoH = unsafe extern "C" fn(*mut std::ffi::c_void, u64, usize) -> CuResult;
/// Function pointers into the CUDA driver library, resolved by `DriverApi::load`.
///
/// Each field holds the entry point exported under the symbol named in its doc.
pub struct DriverApi {
    /// Resolved from `cuInit`.
pub cu_init: CuInit,
    /// Resolved from `cuDeviceGet`.
pub cu_device_get: CuDeviceGet,
    /// Resolved from `cuCtxCreate_v2`.
pub cu_ctx_create: CuCtxCreate,
    /// Resolved from `cuCtxSetCurrent`.
pub cu_ctx_set_current: CuCtxSetCurrent,
    /// Resolved from `cuCtxDestroy_v2`.
pub cu_ctx_destroy: CuCtxDestroy,
    /// Resolved from `cuMemAlloc_v2`.
pub cu_mem_alloc: CuMemAlloc,
    /// Resolved from `cuMemFree_v2`.
pub cu_mem_free: CuMemFree,
    /// Resolved from `cuMemcpyHtoD_v2`.
pub cu_memcpy_htod: CuMemcpyHtoD,
    /// Resolved from `cuMemcpyDtoH_v2`.
pub cu_memcpy_dtoh: CuMemcpyDtoH,
}
impl DriverApi {
    /// Resolves every driver entry point this module uses from `library`.
    ///
    /// The library must be `'static` because the function pointers are copied
    /// out of their `Symbol` guards and have to stay valid for as long as the
    /// returned `DriverApi` lives.
    ///
    /// # Errors
    /// Returns `GpuError::DriverSymbolMissing` for the first symbol, in field
    /// order, that `library` does not export.
    pub fn load(library: &'static Library) -> Result<Self, GpuError> {
        /// Looks `name` up in `library` and copies the pointer out of its guard.
        ///
        /// # Safety
        /// `T` must match the real type of the exported symbol.
        unsafe fn resolve<T: Copy>(library: &'static Library, name: &[u8]) -> Result<T, GpuError> {
            // SAFETY: forwarded to the caller, who guarantees `T` is correct.
            let symbol = unsafe { library.get::<T>(name) }.map_err(|err| {
                GpuError::DriverSymbolMissing {
                    reason: err.to_string(),
                }
            })?;
            Ok(*symbol)
        }

        // SAFETY: each alias type mirrors the prototype of the symbol it is
        // resolved from, and `library` outlives the copied pointers.
        unsafe {
            Ok(Self {
                cu_init: resolve(library, b"cuInit\0")?,
                cu_device_get: resolve(library, b"cuDeviceGet\0")?,
                cu_ctx_create: resolve(library, b"cuCtxCreate_v2\0")?,
                cu_ctx_set_current: resolve(library, b"cuCtxSetCurrent\0")?,
                cu_ctx_destroy: resolve(library, b"cuCtxDestroy_v2\0")?,
                cu_mem_alloc: resolve(library, b"cuMemAlloc_v2\0")?,
                cu_mem_free: resolve(library, b"cuMemFree_v2\0")?,
                cu_memcpy_htod: resolve(library, b"cuMemcpyHtoD_v2\0")?,
                cu_memcpy_dtoh: resolve(library, b"cuMemcpyDtoH_v2\0")?,
            })
        }
    }
}
/// A CUDA driver context on one device, together with the driver entry points
/// used to manage it.
///
/// Built by `CudaWorkingState::init`; the context is passed to
/// `cuCtxDestroy_v2` when this value is dropped.
pub struct CudaWorkingState {
    /// Driver entry points resolved from the process-lifetime driver library.
pub api: DriverApi,
    /// Raw context handle written by `cuCtxCreate_v2`, stored as an integer.
pub context: usize,
}
impl CudaWorkingState {
    /// Loads the driver, runs `cuInit`, and creates a context on device
    /// `device_ordinal`.
    ///
    /// Returns `None` if the ordinal does not fit in an `i32`, if the driver
    /// library or any required symbol cannot be loaded, or if any of the three
    /// driver calls reports a non-zero status.
    pub fn init(device_ordinal: usize) -> Option<Self> {
        let ordinal = to_i32(device_ordinal)?;
        let api = load_static_cuda_driver_library()
            .and_then(DriverApi::load)
            .ok()?;

        let mut device: i32 = 0;
        let mut context: usize = 0;
        // SAFETY: the function pointers come from `DriverApi::load` on a
        // `'static` library, and every out-pointer is a live local.
        unsafe {
            check_cuda((api.cu_init)(0), "cuInit").ok()?;
            check_cuda((api.cu_device_get)(&mut device, ordinal), "cuDeviceGet").ok()?;
            check_cuda((api.cu_ctx_create)(&mut context, 0, device), "cuCtxCreate").ok()?;
        }
        Some(Self { api, context })
    }

    /// Makes this state's context current on the calling thread.
    ///
    /// # Errors
    /// Returns `GpuError::DriverCallFailed` if `cuCtxSetCurrent` fails.
    #[inline]
    pub fn set_current(&self) -> Result<(), GpuError> {
        // SAFETY: `self.context` was produced by a successful `cuCtxCreate_v2`.
        let status = unsafe { (self.api.cu_ctx_set_current)(self.context) };
        check_cuda(status, "cuCtxSetCurrent")
    }
}
impl Drop for CudaWorkingState {
    /// Destroys the context created by `init`.
    ///
    /// The driver status is deliberately discarded: `drop` cannot report it.
    fn drop(&mut self) {
        let destroy = self.api.cu_ctx_destroy;
        // SAFETY: `self.context` came from a successful `cuCtxCreate_v2` and is
        // destroyed exactly once, here.
        let _ = unsafe { destroy(self.context) };
    }
}
/// A device memory allocation owned by this value and released on drop.
///
/// Borrows the `CudaWorkingState` whose entry points allocate and free it, so
/// that state cannot be dropped while the allocation is alive.
pub struct DeviceAllocation<'a> {
    // Owning state; its context is made current before freeing in `Drop`.
state: &'a CudaWorkingState,
    /// Device pointer written by a successful `cuMemAlloc_v2`.
pub ptr: u64,
}
impl<'a> DeviceAllocation<'a> {
    /// Allocates `bytes` bytes of device memory using `state`'s driver.
    ///
    /// `state`'s context is made current on the calling thread first. `Drop`
    /// does the same before freeing, so allocation and release both happen
    /// against the same context whichever context was current beforehand.
    ///
    /// Returns `None` if making the context current fails or `cuMemAlloc`
    /// reports an error.
    ///
    /// # Safety
    /// NOTE(review): the original contract for callers is not visible here.
    /// Presumably `state` must remain usable from the calling thread and the
    /// caller must not overwrite the public `ptr` field — TODO confirm and
    /// document.
    pub unsafe fn new(state: &'a CudaWorkingState, bytes: usize) -> Option<Self> {
        // cuMemAlloc allocates in the thread's current context; make sure it
        // is ours, mirroring what `Drop` does before `cuMemFree`.
        state.set_current().ok()?;
        let mut ptr = 0_u64;
        // SAFETY: `ptr` is a live local used as the out-pointer for the call.
        let status = unsafe { (state.api.cu_mem_alloc)(&mut ptr, bytes) };
        check_cuda(status, "cuMemAlloc").ok()?;
        Some(Self { state, ptr })
    }
}
impl Drop for DeviceAllocation<'_> {
    /// Makes the owning context current, then frees the device pointer.
    ///
    /// Both driver statuses are ignored (best effort): `drop` cannot report
    /// failure, and the free is attempted even if switching context failed.
    fn drop(&mut self) {
        let api = &self.state.api;
        // SAFETY: the entry points come from a `'static` library. `self.ptr`
        // was produced by a successful `cuMemAlloc_v2` in `new` (assuming the
        // public field was not overwritten) and is freed only here.
        unsafe {
            let _ = (api.cu_ctx_set_current)(self.state.context);
            let _ = (api.cu_mem_free)(self.ptr);
        }
    }
}
/// Converts a raw driver status into a `Result`.
///
/// `0` is success; any other value becomes `GpuError::DriverCallFailed` whose
/// reason names the failing call (`name`) and the numeric status.
#[inline]
pub fn check_cuda(result: CuResult, name: &str) -> Result<(), GpuError> {
    match result {
        0 => Ok(()),
        _ => Err(GpuError::DriverCallFailed {
            reason: format!("{name} failed with CUDA driver error {result}"),
        }),
    }
}
/// Reports whether any candidate CUDA driver library can be opened.
///
/// The probe handle is dropped immediately; nothing is kept loaded.
#[must_use]
pub fn cuda_driver_library_present() -> bool {
    let candidates = cuda_library_candidate_names();
    matches!(load_library_names(&candidates), Ok(_))
}
/// Opens the first candidate in `candidates` that loads successfully.
///
/// Candidates are tried in order and the search stops at the first success.
/// Individual load errors are not kept.
///
/// # Errors
/// Returns `GpuError::DriverLibraryUnavailable` listing every candidate when
/// none of them can be opened.
fn load_library_names(candidates: &[String]) -> Result<Library, GpuError> {
    candidates
        .iter()
        .find_map(|candidate| {
            // SAFETY: loading runs the library's initialisers; the candidates
            // are assumed to name a trusted system CUDA driver.
            let loaded = unsafe { Library::new(candidate) };
            loaded.ok()
        })
        .ok_or_else(|| GpuError::DriverLibraryUnavailable {
            reason: format!("could not load any of: {}", candidates.join(", ")),
        })
}
fn load_static_cuda_driver_library() -> Result<&'static Library, GpuError> {
let candidates = cuda_library_candidate_names();
let raw = Box::into_raw(Box::new(load_library_names(&candidates)?));
Ok(unsafe { &*raw })
}
/// Loads the CUDA driver library once per process and memoises the outcome.
///
/// The first call decides the result, success or the error's display text.
/// Every later call returns a clone of that same result without retrying.
pub fn preload_cuda_driver() -> Result<(), String> {
    static PRELOAD: OnceLock<Result<(), String>> = OnceLock::new();
    let outcome = PRELOAD.get_or_init(|| match load_static_cuda_driver_library() {
        Ok(_) => Ok(()),
        Err(err) => Err(err.to_string()),
    });
    outcome.clone()
}
/// Builds the ordered list of names/paths tried when loading the driver.
///
/// The fixed `cuda_library_candidates` list comes first. On Linux only, any
/// versioned `libcuda.so.*` files found in the directories below are then
/// appended, without duplicates.
fn cuda_library_candidate_names() -> Vec<String> {
    const VERSIONED_SEARCH_DIRS: [&str; 6] = [
        "/usr/local/nvidia/lib64",
        "/usr/local/nvidia/lib",
        "/usr/local/cuda/compat",
        "/usr/lib/x86_64-linux-gnu",
        "/usr/lib64",
        "/usr/lib/wsl/lib",
    ];

    let mut names: Vec<String> = cuda_library_candidates()
        .iter()
        .copied()
        .map(String::from)
        .collect();
    if !cfg!(target_os = "linux") {
        return names;
    }
    VERSIONED_SEARCH_DIRS
        .iter()
        .for_each(|dir| append_versioned_linux_libcuda_candidates(&mut names, Path::new(dir)));
    names
}
/// Appends every `libcuda.so.<version>` entry of `dir` to `out`, in sorted path
/// order, skipping paths already present in `out`.
///
/// `libcuda.so.1` is excluded because the fixed candidate list already covers
/// it. A missing or unreadable directory, unreadable entries, and non-UTF-8
/// file names are silently skipped.
fn append_versioned_linux_libcuda_candidates(out: &mut Vec<String>, dir: &Path) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };
    let mut matches: Vec<_> = entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .is_some_and(|name| name.starts_with("libcuda.so.") && name != "libcuda.so.1")
        })
        .collect();
    matches.sort();
    for path in matches {
        let candidate = path.to_string_lossy().into_owned();
        if !out.contains(&candidate) {
            out.push(candidate);
        }
    }
}
/// Returns the fixed, ordered list of driver library names/paths for the
/// target OS.
///
/// Windows gets `nvcuda.dll`, macOS two `libcuda.dylib` locations, and every
/// other target a list of common Linux/WSL `libcuda.so(.1)` paths followed by
/// the bare sonames.
pub fn cuda_library_candidates() -> &'static [&'static str] {
    const WINDOWS: &[&str] = &["nvcuda.dll"];
    const MACOS: &[&str] = &["/usr/local/cuda/lib/libcuda.dylib", "libcuda.dylib"];
    const OTHER: &[&str] = &[
        "/usr/local/nvidia/lib64/libcuda.so.1",
        "/usr/local/nvidia/lib64/libcuda.so",
        "/usr/local/nvidia/lib/libcuda.so.1",
        "/usr/local/nvidia/lib/libcuda.so",
        "/usr/local/cuda/compat/libcuda.so.1",
        "/usr/local/cuda/compat/libcuda.so",
        "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        "/usr/lib/x86_64-linux-gnu/libcuda.so",
        "/usr/lib64/libcuda.so.1",
        "/usr/lib64/libcuda.so",
        "/usr/lib/wsl/lib/libcuda.so.1",
        "/usr/lib/wsl/lib/libcuda.so",
        "libcuda.so.1",
        "libcuda.so",
    ];

    if cfg!(target_os = "windows") {
        WINDOWS
    } else if cfg!(target_os = "macos") {
        MACOS
    } else {
        OTHER
    }
}
/// Size in bytes of `len` values of type `T`, or `None` on `usize` overflow.
#[inline]
pub fn bytes_len<T>(len: usize) -> Option<usize> {
    let element_size = std::mem::size_of::<T>();
    element_size.checked_mul(len)
}
/// Converts `value` to `i32`, returning `None` if it does not fit.
#[inline]
pub fn to_i32(value: usize) -> Option<i32> {
    value.try_into().ok()
}
/// Converts `value` to `i64`, returning `None` if it does not fit.
#[inline]
pub fn to_i64(value: usize) -> Option<i64> {
    value.try_into().ok()
}
/// Returns the elements of `a` in column-major order.
///
/// Borrows `a`'s storage when it is non-empty with strides `(1, rows)` and
/// contiguous, i.e. already a dense column-major buffer. Otherwise the
/// elements are copied column by column into a new `Vec`.
pub fn to_col_major<'a, S: Data<Elem = f64>>(a: &'a ArrayBase<S, Ix2>) -> Cow<'a, [f64]> {
    let (rows, cols) = a.dim();
    let strides = a.strides();
    let dense_column_major =
        rows > 0 && cols > 0 && strides[0] == 1 && strides[1] == rows as isize;
    let borrowed = if dense_column_major {
        a.as_slice_memory_order()
    } else {
        None
    };
    match borrowed {
        Some(slice) => Cow::Borrowed(slice),
        // Walking the transposed view in logical order visits a[0,0], a[1,0],
        // ..., a[rows-1,0], a[0,1], ... — exactly column-major order.
        None => Cow::Owned(a.t().iter().copied().collect()),
    }
}
pub fn from_col_major_inplace(values: &[f64], out: &mut Array2<f64>) -> Option<()> {
let (rows, cols) = out.dim();
if values.len() != rows.checked_mul(cols)? {
return None;
}
for col in 0..cols {
let src = ndarray::ArrayView1::from(&values[col * rows..(col + 1) * rows]);
out.column_mut(col).assign(&src);
}
Some(())
}
/// Builds a new `rows x cols` array from `values` given in column-major order.
///
/// The result is allocated with `Array2::zeros`, i.e. in standard (row-major)
/// layout. Returns `None` when `rows * cols` overflows or `values.len()` does
/// not equal it.
pub fn from_col_major(values: &[f64], rows: usize, cols: usize) -> Option<Array2<f64>> {
    // Validate before allocating. Otherwise an overflowing `rows * cols` would
    // panic inside `Array2::zeros` instead of yielding `None`, and a length
    // mismatch would allocate a full buffer only to throw it away.
    if values.len() != rows.checked_mul(cols)? {
        return None;
    }
    let mut out = Array2::<f64>::zeros((rows, cols));
    from_col_major_inplace(values, &mut out)?;
    Some(out)
}