use crate::{error::*, *};
use cuda::*;
use std::sync::{Arc, Once};
/// Handle to a single CUDA device.
///
/// Thin wrapper around the driver's raw `CUdevice` value; instances are
/// obtained through [`Device::nth`].
#[derive(Debug, PartialEq, PartialOrd)]
pub struct Device {
    /// Raw device handle as returned by `cuDeviceGet`.
    device: CUdevice,
}
impl Device {
    /// Initializes the CUDA Driver API, at most once per process.
    ///
    /// Every entry point that talks to the driver goes through this first.
    ///
    /// # Panics
    ///
    /// Panics if `cuInit` reports an error.
    fn init() {
        static DRIVER_API_INIT: Once = Once::new();
        DRIVER_API_INIT.call_once(|| unsafe {
            ffi_call!(cuda::cuInit, 0).expect("Initialization of CUDA Driver API failed");
        });
    }

    /// Returns the number of devices reported by `cuDeviceGetCount`.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by the underlying driver call.
    pub fn get_count() -> Result<usize> {
        Self::init();
        let mut count: i32 = 0;
        unsafe {
            ffi_call!(cuDeviceGetCount, &mut count as *mut i32)?;
        }
        Ok(count as usize)
    }

    /// Returns the device with ordinal `id`.
    ///
    /// # Errors
    ///
    /// Returns [`AccelError::DeviceNotFound`] when `id` is not smaller than
    /// the device count, and propagates any error from the driver calls.
    pub fn nth(id: usize) -> Result<Self> {
        let count = Self::get_count()?;
        if id >= count {
            return Err(AccelError::DeviceNotFound { id, count });
        }
        // `id < count`, and `count` originated from an `i32`, so for any
        // non-negative count this cast cannot truncate.
        let device = unsafe { ffi_new!(cuDeviceGet, id as i32)? };
        Ok(Device { device })
    }

    /// Returns the amount of memory reported by `cuDeviceTotalMem_v2` for
    /// this device.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by the underlying driver call.
    pub fn total_memory(&self) -> Result<usize> {
        let mut mem = 0;
        unsafe {
            ffi_call!(cuDeviceTotalMem_v2, &mut mem as *mut _, self.device)?;
        }
        Ok(mem)
    }

    /// Returns the device name as reported by `cuDeviceGetName`.
    ///
    /// The driver writes a NUL-terminated C string into a fixed-size buffer;
    /// only the bytes before the first NUL are returned, so the result does
    /// not carry the buffer's zero padding.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by the underlying driver call.
    ///
    /// # Panics
    ///
    /// Panics if the reported name is not valid UTF-8.
    pub fn get_name(&self) -> Result<String> {
        /// Size in bytes of the buffer handed to the driver.
        const NAME_BUF_LEN: std::os::raw::c_int = 1024;

        let mut bytes: Vec<u8> = vec![0_u8; NAME_BUF_LEN as usize];
        unsafe {
            ffi_call!(
                cuDeviceGetName,
                // `c_char` is `i8` or `u8` depending on the target, so cast to
                // the alias rather than to a concrete integer type.
                bytes.as_mut_ptr() as *mut std::os::raw::c_char,
                NAME_BUF_LEN,
                self.device
            )?;
        }
        // Drop the terminator and the untouched zero padding behind it. The
        // buffer started zero-filled, so a missing terminator simply keeps
        // every byte.
        let name_len = bytes
            .iter()
            .position(|&byte| byte == 0)
            .unwrap_or(bytes.len());
        bytes.truncate(name_len);
        Ok(String::from_utf8(bytes).expect("GPU name is not UTF8"))
    }

    /// Creates a new context on this device and wraps it in an [`Arc`].
    ///
    /// # Panics
    ///
    /// Panics if the context cannot be created (see `Context::create`).
    pub fn create_context(&self) -> Arc<Context> {
        Arc::new(Context::create(self.device))
    }
}
pub struct ContextGuard {
ctx: Arc<Context>,
}
impl ContextGuard {
pub fn guard_context(ctx: Arc<Context>) -> Self {
ctx.push();
Self { ctx }
}
}
impl Drop for ContextGuard {
fn drop(&mut self) {
self.ctx.pop();
}
}
/// Implemented by objects that are tied to a CUDA [`Context`].
pub trait Contexted {
    /// Returns a shared handle to the context this object belongs to.
    fn get_context(&self) -> Arc<Context>;

    /// Pushes this object's context and returns a guard that pops it again
    /// when dropped.
    fn guard_context(&self) -> ContextGuard {
        ContextGuard::guard_context(self.get_context())
    }

    /// Synchronizes this object's context via [`Context::sync`].
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Context::sync`] reports.
    fn sync_context(&self) -> Result<()> {
        self.get_context().sync()
    }
}
/// Owning wrapper around a raw CUDA context handle.
///
/// The handle is created by `cuCtxCreate_v2` and destroyed with
/// `cuCtxDestroy_v2` when the value is dropped.
#[derive(Debug, PartialEq)]
pub struct Context {
    /// Raw driver handle; `Context::create` rejects a null value.
    context_ptr: CUcontext,
}
impl Drop for Context {
    /// Destroys the underlying CUDA context.
    ///
    /// A failure is only logged, because `drop` has no way to report it to
    /// the caller.
    fn drop(&mut self) {
        let destroyed = unsafe { ffi_call!(cuCtxDestroy_v2, self.context_ptr) };
        if let Err(e) = destroyed {
            log::error!("Context remove failed: {:?}", e);
        }
    }
}
// The raw `CUcontext` handle is not automatically `Send`/`Sync`, so both are
// asserted by hand to allow sharing a `Context` through `Arc` across threads.
// NOTE(review): soundness presumably relies on the CUDA driver permitting a
// context handle to be used from any thread — confirm against the driver docs.
unsafe impl Send for Context {}
unsafe impl Sync for Context {}
impl Context {
    /// Pushes this context with `cuCtxPushCurrent_v2`.
    ///
    /// # Panics
    ///
    /// Panics if the driver call fails.
    fn push(&self) {
        unsafe {
            ffi_call!(cuCtxPushCurrent_v2, self.context_ptr).expect("Failed to push context");
        }
    }

    /// Pops the current context with `cuCtxPopCurrent_v2` and checks that it
    /// is this one.
    ///
    /// # Panics
    ///
    /// Panics if the driver call fails, if the popped handle is null, or if
    /// it differs from this context's handle.
    fn pop(&self) {
        let context_ptr =
            unsafe { ffi_new!(cuCtxPopCurrent_v2).expect("Failed to pop current context") };
        if context_ptr.is_null() {
            panic!("No current context");
        }
        assert!(
            context_ptr == self.context_ptr,
            "Pop must return same pointer"
        );
    }

    /// Creates a new context on `device` using `CU_CTX_SCHED_AUTO`.
    ///
    /// The freshly created context is popped again before returning, so the
    /// caller has to push it (for example through a `ContextGuard`) before
    /// using it.
    ///
    /// # Panics
    ///
    /// Panics if the driver call fails, if it yields a null handle, or if the
    /// subsequent pop does not return the new context.
    fn create(device: CUdevice) -> Self {
        let context_ptr = unsafe {
            ffi_new!(
                cuCtxCreate_v2,
                CUctx_flags_enum::CU_CTX_SCHED_AUTO as u32,
                device
            )
        }
        .expect("Failed to create a new context");
        if context_ptr.is_null() {
            panic!("Cannot crate a new context");
        }
        let ctx = Context { context_ptr };
        // Wrap the handle before popping so it is destroyed by `Drop` even if
        // `pop` panics.
        ctx.pop();
        ctx
    }

    /// Returns the API version reported by `cuCtxGetApiVersion` for this
    /// context.
    ///
    /// # Panics
    ///
    /// Panics if the driver call fails.
    pub fn version(&self) -> u32 {
        let mut version: u32 = 0;
        unsafe { ffi_call!(cuCtxGetApiVersion, self.context_ptr, &mut version as *mut _) }
            .expect("Failed to get Driver API version");
        version
    }

    /// Pushes this context, calls `cuCtxSynchronize`, and pops it again.
    ///
    /// The context is popped whether or not the synchronization succeeds, so
    /// a failed sync does not leave it on the calling thread's context stack.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cuCtxSynchronize`.
    ///
    /// # Panics
    ///
    /// Panics if pushing or popping the context fails.
    pub fn sync(&self) -> Result<()> {
        self.push();
        // Hold on to the outcome instead of using `?` right away: returning
        // early here would skip the matching `pop`.
        let outcome = unsafe { ffi_call!(cuCtxSynchronize) };
        self.pop();
        outcome?;
        Ok(())
    }
}
/// A shared context handle is trivially tied to itself.
impl Contexted for Arc<Context> {
    /// Returns another reference-counted handle to the same context.
    fn get_context(&self) -> Arc<Context> {
        Arc::clone(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Querying the device count must succeed.
    #[test]
    fn get_count() -> Result<()> {
        let _count = Device::get_count()?;
        Ok(())
    }

    /// Device 0 must be obtainable.
    #[test]
    fn get_zeroth() -> Result<()> {
        let _device = Device::nth(0)?;
        Ok(())
    }

    /// Requesting ordinal 129 is expected to be rejected.
    #[test]
    fn out_of_range() -> Result<()> {
        let lookup = Device::nth(129);
        assert!(lookup.is_err());
        Ok(())
    }

    /// A context can be created on device 0.
    #[test]
    fn create() -> Result<()> {
        let ctx = Device::nth(0)?.create_context();
        dbg!(&ctx);
        Ok(())
    }
}