#![allow(clippy::too_many_arguments)]
use super::{result, result::CublasError, sys};
use crate::driver::CudaStream;
use std::sync::Arc;
mod asum;
mod gemm;
mod gemv;
mod grouped_gemm;
pub use asum::*;
pub use gemm::*;
pub use gemv::*;
pub use grouped_gemm::*;
/// Owns a cuBLAS handle together with the [`CudaStream`] the handle is set to use.
///
/// Created with [`CudaBlas::new`]; the handle is destroyed when the value is dropped.
#[derive(Debug)]
pub struct CudaBlas {
// Raw cuBLAS handle; replaced with a null pointer in `Drop` before it is destroyed.
pub(crate) handle: sys::cublasHandle_t,
// Stream whose raw handle is passed to `result::set_stream` for `handle`.
pub(crate) stream: Arc<CudaStream>,
}
// NOTE(review): these impls assert that `CudaBlas` may be moved to and shared between
// threads even though it stores a raw `sys::cublasHandle_t`. The justification is not
// visible in this file — presumably cuBLAS handles may be used from any thread; confirm.
unsafe impl Send for CudaBlas {}
unsafe impl Sync for CudaBlas {}
impl CudaBlas {
    /// Creates a new cuBLAS handle and sets it to use `stream`.
    ///
    /// The stream's context is bound to the calling thread first; the outcome of
    /// that bind is handed to the context's `record_err` instead of being returned.
    ///
    /// # Errors
    /// Returns a [`CublasError`] if creating the handle or assigning the stream
    /// fails. If assigning the stream fails, the freshly created handle is
    /// destroyed again when the partially set-up value is dropped.
    pub fn new(stream: Arc<CudaStream>) -> Result<Self, CublasError> {
        let ctx = stream.context();
        ctx.record_err(ctx.bind_to_thread());
        let handle = result::create_handle()?;
        // Hand the handle to `Self` *before* the next fallible call so that `Drop`
        // destroys it if `set_stream` fails; otherwise the handle would leak.
        let blas = Self { handle, stream };
        // SAFETY: `blas.handle` was just returned by `result::create_handle`, and the
        // stream whose raw handle is passed here is kept alive by `blas.stream`.
        unsafe { result::set_stream(blas.handle, blas.stream.cu_stream() as _) }?;
        Ok(blas)
    }

    /// Returns a reference to the underlying raw cuBLAS handle.
    pub fn handle(&self) -> &sys::cublasHandle_t {
        &self.handle
    }

    /// Switches this handle over to `stream`.
    ///
    /// The stored stream is only replaced once cuBLAS has accepted the new one, so
    /// on error `self` keeps holding the stream it held before the call.
    ///
    /// # Safety
    /// NOTE(review): the precise caller contract is not visible in this file —
    /// presumably the caller must account for work already queued through this
    /// handle on the previous stream; confirm against the callers.
    ///
    /// # Errors
    /// Returns a [`CublasError`] if cuBLAS rejects the new stream.
    pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), CublasError> {
        // Re-target the handle first. Overwriting `self.stream` up front would release
        // the previous stream even when the call fails and the handle was not switched.
        unsafe { result::set_stream(self.handle, stream.cu_stream() as _) }?;
        self.stream = stream;
        Ok(())
    }

    /// Sets the pointer mode of this handle via `cublasSetPointerMode_v2`.
    ///
    /// # Errors
    /// Returns a [`CublasError`] if the cuBLAS call reports a failure status.
    pub fn set_pointer_mode(
        &self,
        pointer_mode: sys::cublasPointerMode_t,
    ) -> Result<(), CublasError> {
        // SAFETY: `self.handle` is the handle owned by this value; it is only nulled
        // out inside `Drop`.
        unsafe { sys::cublasSetPointerMode_v2(self.handle, pointer_mode) }.result()?;
        Ok(())
    }

    /// Queries the current pointer mode of this handle via `cublasGetPointerMode_v2`.
    ///
    /// # Errors
    /// Returns a [`CublasError`] if the cuBLAS call reports a failure status.
    pub fn get_pointer_mode(&self) -> Result<sys::cublasPointerMode_t, CublasError> {
        let mut mode = ::core::mem::MaybeUninit::uninit();
        // SAFETY: `mode.as_mut_ptr()` points to writable storage for the out-parameter,
        // and `self.handle` is the handle owned by this value.
        unsafe { sys::cublasGetPointerMode_v2(self.handle, mode.as_mut_ptr()) }.result()?;
        // SAFETY: the call above reported success, so `mode` is expected to have been
        // written by cuBLAS.
        Ok(unsafe { mode.assume_init() })
    }
}
impl Drop for CudaBlas {
    /// Destroys the cuBLAS handle, leaving a null pointer in its place.
    ///
    /// # Panics
    /// Panics if `result::destroy_handle` returns an error.
    fn drop(&mut self) {
        // Swap the stored handle for null so it cannot be destroyed a second time.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        unsafe { result::destroy_handle(raw) }.unwrap();
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::needless_range_loop)]
    use super::*;
    use crate::driver::CudaContext;

    /// Checks that a freshly created handle reports host pointer mode and that
    /// switching to device pointer mode is reflected by `get_pointer_mode`.
    #[test]
    fn cublas_pointer_mode() {
        type Mode = sys::cublasPointerMode_t;

        let context = CudaContext::new(0).unwrap();
        let default_stream = context.default_stream();
        let cublas = CudaBlas::new(default_stream.clone()).unwrap();

        let initial_mode = cublas.get_pointer_mode().unwrap();
        assert_eq!(
            initial_mode,
            Mode::CUBLAS_POINTER_MODE_HOST,
            "The default pointer mode uses host pointers"
        );

        cublas
            .set_pointer_mode(Mode::CUBLAS_POINTER_MODE_DEVICE)
            .unwrap();

        let updated_mode = cublas.get_pointer_mode().unwrap();
        assert_eq!(
            updated_mode,
            Mode::CUBLAS_POINTER_MODE_DEVICE,
            "We have set the mode to use device pointers"
        );
    }
}