use super::Context;
use super::PointerMode;
use crate::ffi::*;
use crate::{Error, API};
use lazy_static::lazy_static;
use log::debug;
use std::collections::HashSet;
use std::convert::AsRef;
use std::convert::TryFrom;
use std::ptr;
use std::ptr::NonNull;
use std::sync::{Arc, Mutex};
/// Hashable, comparable wrapper around a non-null cuBLAS handle pointer.
///
/// It serves as the element type of the global handle tracker set.
#[derive(Hash, Eq, PartialEq)]
struct Cookie(NonNull<cublasContext>);

// `NonNull` is not `Send` on its own; this impl is what lets a set of cookies
// sit behind the `Mutex` of the global tracker static.
// SAFETY(review): presumes a cuBLAS handle may be moved to, and destroyed
// from, a thread other than the one that created it - TODO confirm against the
// cuBLAS threading documentation.
unsafe impl std::marker::Send for Cookie {}

impl Cookie {
    /// Returns the wrapped handle as the raw pointer form used by the FFI calls.
    fn as_ptr(&self) -> *mut cublasContext {
        let Cookie(inner) = self;
        inner.as_ptr()
    }
}
impl TryFrom<cublasHandle_t> for Cookie {
type Error = Error;
fn try_from(handle: *mut cublasContext) -> std::result::Result<Self, Self::Error> {
if let Some(nn) = NonNull::new(handle) {
Ok(Cookie(nn))
} else {
Err(Error::Unknown("cublasHandle is a nullptr", 0))
}
}
}
// Global, lazily initialised set of cuBLAS handles.
// `track` inserts into it, `untrack` removes from it, and `cleanup` destroys
// every handle it still contains. The set is created with capacity for three
// entries and every access goes through the surrounding `Mutex`.
lazy_static! {
static ref TRACKER: Arc<Mutex<HashSet<Cookie>>> =
Arc::new(Mutex::new(HashSet::with_capacity(3)));
}
/// Registers `handle` in the global tracker and logs the resulting handle count.
///
/// # Panics
///
/// Panics if `handle` is null or if the tracker mutex is poisoned.
fn track(handle: cublasHandle_t) {
    // Validate before taking the lock: a panic raised while the guard is held
    // would poison the global tracker mutex for every later caller.
    let cookie = Cookie::try_from(handle).expect("a tracked cuBLAS handle must not be null");
    let mut guard = TRACKER.as_ref().lock().unwrap();
    let _ = guard.insert(cookie);
    debug!("Added handle {:?}, total of {}", handle, guard.len());
}
/// Removes `handle` from the global tracker and logs the remaining handle count.
///
/// Removing a handle that is not in the set leaves the set unchanged.
///
/// # Panics
///
/// Panics if `handle` is null or if the tracker mutex is poisoned.
fn untrack(handle: cublasHandle_t) {
    // Validate before taking the lock: a panic raised while the guard is held
    // would poison the global tracker mutex for every later caller.
    let cookie = Cookie::try_from(handle).expect("an untracked cuBLAS handle must not be null");
    let mut guard = TRACKER.as_ref().lock().unwrap();
    let _ = guard.remove(&cookie);
    // Log only after the removal so the reported total describes the updated set.
    debug!("Removed handle {:?}, total of {}", handle, guard.len());
}
/// Destroys every cuBLAS handle still present in the global tracker and
/// empties the tracker.
///
/// # Panics
///
/// Panics if the tracker mutex is poisoned or if destroying a handle fails.
fn cleanup() {
    let mut guard = TRACKER.lock().unwrap();
    // Drain instead of iterating by reference: a destroyed handle must not stay
    // in the set, otherwise a second `cleanup` would destroy it again.
    for cookie in guard.drain() {
        // SAFETY(review): in the code visible here cookies are only inserted by
        // `track`, called from `API::create` with a freshly created handle, and
        // `API::destroy` untracks a handle before destroying it. Presumes no
        // other path destroys a handle while it is still tracked - confirm.
        unsafe {
            API::ffi_destroy(cookie.as_ptr()).unwrap();
        }
    }
}
impl API {
    /// Creates a new cuBLAS handle, registers it with the handle tracker and
    /// wraps it in a `Context`.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cublasCreate_v2` if the handle cannot be
    /// created.
    ///
    /// # Panics
    ///
    /// Panics inside `track` if the returned handle is null or the tracker
    /// mutex is poisoned.
    pub fn create() -> Result<Context, Error> {
        let handle = unsafe { API::ffi_create() }?;
        track(handle);
        Ok(Context::from_c(handle))
    }

    /// Removes the context's handle from the tracker and destroys it.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the context must hold a live handle that is not
    /// used or destroyed again afterwards - confirm against the callers.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cublasDestroy_v2`. The handle has already
    /// been removed from the tracker at that point.
    ///
    /// # Panics
    ///
    /// Panics inside `untrack` if the handle is null or the tracker mutex is
    /// poisoned.
    pub unsafe fn destroy(context: &mut Context) -> Result<(), Error> {
        let handle = *context.id_c();
        untrack(handle);
        API::ffi_destroy(handle)
    }

    /// Returns the version number reported by cuBLAS for the given context.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cublasGetVersion_v2`.
    pub fn get_version(context: &Context) -> Result<i32, Error> {
        unsafe { API::ffi_get_version(*context.id_c()) }
    }

    /// Calls `cublasGetVersion_v2` and maps its status code to a `Result`.
    ///
    /// # Safety
    ///
    /// `handle` is passed to cuBLAS unchecked. NOTE(review): presumably it must
    /// be a live handle obtained from `cublasCreate_v2` - confirm.
    unsafe fn ffi_get_version(handle: cublasHandle_t) -> Result<i32, Error> {
        // Out-parameter filled in by cuBLAS on success.
        let mut version: i32 = 0;
        match cublasGetVersion_v2(handle, &mut version) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(version),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            status => Err(Error::Unknown(
                "Other Unknown Error with CUBLAS Get Version",
                status as i32 as u64,
            )),
        }
    }

    /// Reads the pointer mode currently set on the context's handle.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cublasGetPointerMode_v2`.
    pub fn get_pointer_mode(context: &Context) -> Result<PointerMode, Error> {
        let raw_mode = unsafe { API::ffi_get_pointer_mode(*context.id_c()) }?;
        Ok(PointerMode::from_c(raw_mode))
    }

    /// Sets the pointer mode on the context's handle.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `cublasSetPointerMode_v2`.
    pub fn set_pointer_mode(context: &mut Context, pointer_mode: PointerMode) -> Result<(), Error> {
        unsafe { API::ffi_set_pointer_mode(*context.id_c(), pointer_mode.as_c()) }
    }

    /// Calls `cublasCreate_v2` and maps its status code to a `Result` holding
    /// the new handle.
    ///
    /// # Safety
    ///
    /// NOTE(review): no caller-side precondition is visible in this file; the
    /// function is `unsafe` because it calls straight into the cuBLAS FFI.
    unsafe fn ffi_create() -> Result<cublasHandle_t, Error> {
        // Out-parameter filled in by cuBLAS on success.
        let mut handle: cublasHandle_t = ptr::null_mut();
        match cublasCreate_v2(&mut handle) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(handle),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            cublasStatus_t::CUBLAS_STATUS_ARCH_MISMATCH => Err(Error::ArchMismatch),
            cublasStatus_t::CUBLAS_STATUS_ALLOC_FAILED => Err(Error::AllocFailed),
            status => Err(Error::Unknown(
                "Unable to create the cuBLAS context/resources.",
                status as i32 as u64,
            )),
        }
    }

    /// Calls `cublasDestroy_v2` on `handle` and maps its status code to a
    /// `Result`.
    ///
    /// # Safety
    ///
    /// `handle` is passed to cuBLAS unchecked. NOTE(review): presumably it must
    /// be a live handle that is not used again afterwards - confirm.
    unsafe fn ffi_destroy(handle: cublasHandle_t) -> Result<(), Error> {
        match cublasDestroy_v2(handle) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            status => Err(Error::Unknown(
                "Unable to destroy the cuBLAS context/resources.",
                status as i32 as u64,
            )),
        }
    }

    /// Calls `cublasGetPointerMode_v2` and maps its status code to a `Result`
    /// holding the raw pointer mode.
    ///
    /// # Safety
    ///
    /// `handle` is passed to cuBLAS unchecked. NOTE(review): presumably it must
    /// be a live handle obtained from `cublasCreate_v2` - confirm.
    unsafe fn ffi_get_pointer_mode(handle: cublasHandle_t) -> Result<cublasPointerMode_t, Error> {
        // Out-parameter; the initial value is overwritten by cuBLAS on success.
        let mut pointer_mode = cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST;
        match cublasGetPointerMode_v2(handle, &mut pointer_mode) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(pointer_mode),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            status => Err(Error::Unknown(
                "Unable to get cuBLAS pointer mode.",
                status as i32 as u64,
            )),
        }
    }

    /// Calls `cublasSetPointerMode_v2` and maps its status code to a `Result`.
    ///
    /// # Safety
    ///
    /// `handle` is passed to cuBLAS unchecked. NOTE(review): presumably it must
    /// be a live handle obtained from `cublasCreate_v2` - confirm.
    unsafe fn ffi_set_pointer_mode(
        handle: cublasHandle_t,
        pointer_mode: cublasPointerMode_t,
    ) -> Result<(), Error> {
        match cublasSetPointerMode_v2(handle, pointer_mode) {
            cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
            cublasStatus_t::CUBLAS_STATUS_NOT_INITIALIZED => Err(Error::NotInitialized),
            status => Err(Error::Unknown(
                "Unable to set cuBLAS pointer mode.",
                status as i32 as u64,
            )),
        }
    }
}
#[cfg(test)]
mod test {
    use crate::ffi::cublasPointerMode_t;
    use crate::{Context, API};

    /// A handle created through the raw FFI wrapper can be destroyed again.
    #[test]
    fn manual_context_creation() {
        crate::chore::test_setup();
        unsafe {
            let raw_handle = API::ffi_create().unwrap();
            API::ffi_destroy(raw_handle).unwrap();
        }
    }

    /// A freshly created context reports the host pointer mode.
    #[test]
    fn default_pointer_mode_is_host() {
        crate::chore::test_setup();
        // The context lives only inside this block, so it is dropped before
        // the teardown call below runs.
        unsafe {
            dbg!("Pointer Mode Test - Initialises New CUBLAS");
            let ctx = Context::new().unwrap();
            let observed = API::ffi_get_pointer_mode(*ctx.id_c()).unwrap();
            assert_eq!(cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST, observed);
        }
        crate::chore::test_teardown();
    }

    /// The pointer mode can be switched to device and back to host, and each
    /// change is visible when the mode is read back.
    #[test]
    fn can_set_pointer_mode() {
        crate::chore::test_setup();
        // The context lives only inside this block, so it is dropped before
        // the teardown call below runs.
        unsafe {
            let ctx = Context::new().unwrap();
            // Device first, then host again; read the mode back after each set.
            let sequence = [
                cublasPointerMode_t::CUBLAS_POINTER_MODE_DEVICE,
                cublasPointerMode_t::CUBLAS_POINTER_MODE_HOST,
            ];
            for &wanted in sequence.iter() {
                API::ffi_set_pointer_mode(*ctx.id_c(), wanted).unwrap();
                let observed = API::ffi_get_pointer_mode(*ctx.id_c()).unwrap();
                assert_eq!(wanted, observed);
            }
        }
        crate::chore::test_teardown();
    }
}