#[allow(unused_imports)]
use lazy_static::lazy_static;
#[cfg(feature = "cuda")]
pub extern crate cust;
pub use ulib_derive::UniversalCopy;
#[cfg(feature = "cuda")]
use cust::memory::{ DeviceCopy, DeviceSlice };
/// Maximum number of CUDA devices this crate will address.
///
/// Detected devices beyond this count are dropped from the device list, and
/// CUDA device ordinals must be strictly below this value to be converted
/// into a device id.
#[cfg(feature = "cuda")]
pub const MAX_NUM_CUDA_DEVICES: usize = 4;
/// Total number of distinct device ids when CUDA support is compiled in:
/// one slot for the CPU plus one per addressable CUDA device.
#[cfg(feature = "cuda")]
pub const MAX_DEVICES: usize = MAX_NUM_CUDA_DEVICES + 1;
/// Total number of distinct device ids when CUDA support is compiled out:
/// only the CPU exists.
#[cfg(not(feature = "cuda"))]
pub const MAX_DEVICES: usize = 1;
/// A compute device that data can be associated with.
///
/// `CPU` always exists. When the `cuda` feature is enabled, `CUDA(n)` names
/// the CUDA device with ordinal `n`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Device {
    /// The host CPU.
    CPU,
    /// A CUDA device identified by its ordinal. The ordinal must be below
    /// `MAX_NUM_CUDA_DEVICES` for the device to be convertible into an id.
    #[cfg(feature = "cuda")]
    CUDA(u8),
}

/// Value handed out by [`Device::get_context`].
///
/// With the `cuda` feature it carries an optional `cust` context (`None` for
/// the CPU, `Some` for a CUDA device); without the feature it has no fields.
pub struct DeviceContext {
    // The field is never read, hence the lint exemption. Presumably it is
    // held only so the context lives as long as this value -- confirm.
    #[cfg(feature = "cuda")]
    #[allow(dead_code)]
    cuda_context: Option<cust::context::Context>,
}

impl Device {
    /// Converts this device into a dense index: `CPU` maps to `0` and
    /// `CUDA(n)` maps to `n + 1`.
    ///
    /// # Panics
    ///
    /// Panics with `"invalid cuda device id"` when a CUDA ordinal is not
    /// below `MAX_NUM_CUDA_DEVICES`.
    #[inline]
    fn to_id(self) -> usize {
        match self {
            Device::CPU => 0,
            #[cfg(feature = "cuda")]
            Device::CUDA(ordinal) => {
                let index = usize::from(ordinal);
                assert!(index < MAX_NUM_CUDA_DEVICES, "invalid cuda device id");
                index + 1
            }
        }
    }

    /// Inverse of [`Device::to_id`]: turns a dense index back into a device.
    ///
    /// # Panics
    ///
    /// Panics when `id` does not correspond to any device in the current
    /// build configuration.
    #[inline]
    fn from_id(id: usize) -> Device {
        match id {
            0 => Device::CPU,
            #[cfg(feature = "cuda")]
            slot @ 1..=MAX_NUM_CUDA_DEVICES => Device::CUDA((slot - 1) as u8),
            invalid => panic!("device id {} is invalid.", invalid),
        }
    }

    /// Builds a [`DeviceContext`] for this device.
    ///
    /// For `CPU` the result holds no CUDA context. For `CUDA(n)` a new `cust`
    /// context is created on the `n`-th entry of the detected device list.
    ///
    /// # Panics
    ///
    /// For CUDA devices, panics if the ordinal is outside the detected
    /// device list or if creating the context fails.
    #[inline]
    pub fn get_context(self) -> DeviceContext {
        match self {
            Device::CPU => DeviceContext {
                #[cfg(feature = "cuda")]
                cuda_context: None,
            },
            #[cfg(feature = "cuda")]
            Device::CUDA(ordinal) => {
                let cuda_device = CUDA_DEVICES[usize::from(ordinal)].0;
                let context = cust::context::Context::new(cuda_device).unwrap();
                DeviceContext {
                    cuda_context: Some(context),
                }
            }
        }
    }

    /// Waits for outstanding work on this device.
    ///
    /// This is a no-op for `CPU`. For `CUDA(n)` a context is created on the
    /// device and the current context is synchronized.
    ///
    /// # Panics
    ///
    /// For CUDA devices, panics if the ordinal is outside the detected
    /// device list, or if context creation or synchronization fails.
    #[inline]
    pub fn synchronize(self) {
        match self {
            Device::CPU => {}
            #[cfg(feature = "cuda")]
            Device::CUDA(ordinal) => {
                let cuda_device = CUDA_DEVICES[usize::from(ordinal)].0;
                // Bound to a name so it is dropped only after synchronizing.
                // NOTE(review): this relies on `Context::new` making the new
                // context current -- confirm for the `cust` version in use.
                let _context = cust::context::Context::new(cuda_device).unwrap();
                cust::context::CurrentContext::synchronize().unwrap();
            }
        }
    }
}
/// Marker trait for element types usable on every supported device.
///
/// With the `cuda` feature a type must be both `Copy` and `DeviceCopy`.
/// The trait has no methods and is implemented automatically (see the
/// blanket impl below), so it never needs to be implemented by hand.
#[cfg(feature = "cuda")]
pub trait UniversalCopy: Copy + DeviceCopy {}

/// Blanket impl: every `Copy + DeviceCopy` type is `UniversalCopy`.
#[cfg(feature = "cuda")]
impl<T> UniversalCopy for T where T: Copy + DeviceCopy {}

/// Marker trait for element types usable on every supported device.
///
/// Without the `cuda` feature only `Copy` is required. The trait has no
/// methods and is implemented automatically (see the blanket impl below).
#[cfg(not(feature = "cuda"))]
pub trait UniversalCopy: Copy {}

/// Blanket impl: every `Copy` type is `UniversalCopy`.
#[cfg(not(feature = "cuda"))]
impl<T> UniversalCopy for T where T: Copy {}
#[cfg(feature = "cuda")]
lazy_static! {
    /// Detected CUDA devices, each paired with a context created on it.
    ///
    /// Built on first access: initializes `cust`, enumerates the devices,
    /// creates one context per device, and keeps at most
    /// `MAX_NUM_CUDA_DEVICES` entries (logging a warning if more exist).
    /// Any failure along the way panics.
    static ref CUDA_DEVICES: Vec<(cust::device::Device, cust::context::Context)> = {
        // Initialization has to happen before devices can be enumerated.
        cust::init(cust::CudaFlags::empty()).unwrap();
        let mut devices: Vec<_> = cust::device::Device::devices()
            .unwrap()
            .map(|entry| {
                let device = entry.unwrap();
                let context = cust::context::Context::new(device).unwrap();
                (device, context)
            })
            .collect();
        let available = devices.len();
        if available > MAX_NUM_CUDA_DEVICES {
            clilog::warn!(
                ULIB_CUDA_TRUNC,
                "the number of available cuda gpus {} \
                 exceed max supported {}, truncated.",
                available,
                MAX_NUM_CUDA_DEVICES
            );
            devices.truncate(MAX_NUM_CUDA_DEVICES);
        }
        devices
    };

    /// Number of usable CUDA devices, i.e. the length of the (possibly
    /// truncated) detected device list.
    pub static ref NUM_CUDA_DEVICES: usize = CUDA_DEVICES.len();
}
/// Types that can expose their contents as a `cust` [`DeviceSlice`].
///
/// Only available with the `cuda` feature.
#[cfg(feature = "cuda")]
pub trait AsCUDASlice<T>
where
    T: UniversalCopy,
{
    /// Returns a `DeviceSlice<T>` for the contents on `cuda_device`.
    ///
    /// NOTE(review): presumably `cuda_device` must be a `Device::CUDA(_)`;
    /// behavior for `Device::CPU` is up to the implementor -- confirm.
    fn as_cuda_slice(&self, cuda_device: Device) -> DeviceSlice<T>;
}
/// Mutable-access counterpart of `AsCUDASlice`: types that can expose their
/// contents as a `cust` [`DeviceSlice`] through a `&mut self` borrow.
///
/// Only available with the `cuda` feature.
#[cfg(feature = "cuda")]
pub trait AsCUDASliceMut<T>
where
    T: UniversalCopy,
{
    /// Returns a `DeviceSlice<T>` for the contents on `cuda_device`, taking
    /// `self` mutably.
    ///
    /// NOTE(review): presumably `cuda_device` must be a `Device::CUDA(_)`;
    /// behavior for `Device::CPU` is up to the implementor -- confirm.
    fn as_cuda_slice_mut(&mut self, cuda_device: Device) -> DeviceSlice<T>;
}
/// Types that can hand out a read-only raw pointer to their elements for a
/// chosen [`Device`].
pub trait AsUPtr<T>
where
    T: UniversalCopy,
{
    /// Returns a `*const T` for the contents as seen from `device`.
    ///
    /// NOTE(review): presumably the pointer is only meaningful in the
    /// address space of `device`; validity and lifetime rules are defined
    /// by the implementor -- confirm against the implementations.
    fn as_uptr(&self, device: Device) -> *const T;
}
/// Types that can hand out a mutable raw pointer to their elements for a
/// chosen [`Device`].
pub trait AsUPtrMut<T>
where
    T: UniversalCopy,
{
    /// Returns a `*mut T` for the contents as seen from `device`, taking
    /// `self` mutably.
    ///
    /// NOTE(review): presumably the pointer is only meaningful in the
    /// address space of `device`; validity and lifetime rules are defined
    /// by the implementor -- confirm against the implementations.
    fn as_mut_uptr(&mut self, device: Device) -> *mut T;
}
mod uvec;
pub use uvec::UVec;