#[allow(unused_imports)]
use crate::error::Status;
use std::{ptr, sync::Arc};
use singe_core::LibraryVersion;
use singe_cublas_sys as sys;
use singe_cuda::{context::Context as CudaContext, stream::Stream};
use singe_cuda_sys::runtime;
use crate::{
error::{Error, Result},
lt::version,
try_ffi,
};
/// Owning wrapper around a cuBLASLt library handle (`cublasLtHandle_t`).
///
/// Created with [`Context::create`]; the underlying handle is released with
/// `cublasLtDestroy` when the value is dropped.
#[derive(Debug)]
pub struct Context {
/// Owner of the raw handle and its CUDA context; its `Drop` impl performs the teardown.
handle: Handle,
}
/// Raw cuBLASLt handle paired with the CUDA context it was created in.
#[derive(Debug)]
struct Handle {
/// Handle produced by `cublasLtCreate` (checked to be non-null in `Context::create`).
raw: sys::cublasLtHandle_t,
/// Owning CUDA context. `Drop::drop` runs before the fields are dropped, so this `Arc` keeps
/// the context alive while the handle is being destroyed; it is also re-bound right before
/// `cublasLtDestroy` is called.
cuda_ctx: Arc<CudaContext>,
}
// SAFETY: NOTE(review): this asserts that moving the raw cuBLASLt handle (and the
// `Arc<CudaContext>` next to it) to another thread is sound. The only thread-sensitive step
// visible here is teardown, and `Drop` re-binds `cuda_ctx` on the dropping thread before calling
// `cublasLtDestroy`. This file does not establish two things:
// - that cuBLASLt allows a handle to be used or destroyed from a thread other than its creator;
// - that `CudaContext` is itself safe to send.
// Confirm both against the cuBLASLt thread-safety notes. `Sync` is not implemented here, and the
// raw pointer field keeps `Handle` `!Sync` unless an impl elsewhere says otherwise.
unsafe impl Send for Handle {}
impl Context {
pub fn create(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
cuda_ctx.bind()?;
let mut handle = ptr::null_mut();
unsafe {
try_ffi!(sys::cublasLtCreate(&raw mut handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Self {
handle: Handle {
raw: handle,
cuda_ctx: Arc::clone(cuda_ctx),
},
})
}
pub fn cuda_context(&self) -> &Arc<CudaContext> {
&self.handle.cuda_ctx
}
pub fn bind(&self) -> Result<()> {
Ok(self.cuda_context().bind()?)
}
pub fn ensure_stream(&self, stream: &Stream) -> Result<()> {
if self.cuda_context().as_ref() != stream.context() {
return Err(Error::StreamContextMismatch);
}
self.bind()
}
pub fn version(&self) -> Result<LibraryVersion> {
self.bind()?;
version()
}
pub(crate) unsafe fn stream_raw(
&self,
stream: Option<&Stream>,
) -> Result<runtime::cudaStream_t> {
if let Some(stream) = stream {
self.ensure_stream(stream)?;
Ok(stream.as_raw())
} else {
self.bind()?;
Ok(ptr::null_mut())
}
}
pub fn as_raw(&self) -> sys::cublasLtHandle_t {
self.handle.raw
}
}
impl Drop for Handle {
    /// Destroys the cuBLASLt handle on a best-effort basis.
    ///
    /// The owning CUDA context is bound first. If binding fails, the destroy call is still
    /// attempted. Errors cannot be propagated out of `drop`, so they are printed to stderr in
    /// debug builds and otherwise ignored.
    // Each `err` binding is only read by a debug-only `eprintln!`. Without this allow, release
    // builds (no `debug_assertions`) emit `unused_variables` warnings for both bindings.
    #[cfg_attr(not(debug_assertions), allow(unused_variables))]
    fn drop(&mut self) {
        if let Err(err) = self.cuda_ctx.bind() {
            #[cfg(debug_assertions)]
            eprintln!("failed to bind cuda context before destroying cublasLt handle: {err}");
        }

        // SAFETY: in the visible code, `Handle` is only constructed by `Context::create` from a
        // successful, non-null `cublasLtCreate` result. It is not `Clone`, so this is the single
        // destroy call for `self.raw`.
        let destroyed = unsafe { try_ffi!(sys::cublasLtDestroy(self.raw)) };
        if let Err(err) = destroyed {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cublasLt context: {err}");
        }
    }
}