#[allow(unused_imports)]
use crate::{dense::*, svd::*};
use std::{path::Path, ptr, sync::Arc};
use singe_core::path_to_cstring;
use singe_cuda::{context::Context as CudaContext, stream::Stream, types::EmulationStrategy};
use singe_cuda_sys::runtime;
use crate::{
error::{Error, Result},
sys, try_ffi,
types::{DeterministicMode, MathMode},
};
/// A cuSOLVER dense (`cusolverDn`) handle paired with the CUDA context it was
/// created in.
///
/// Build one with [`Context::create`]. The underlying handle is destroyed when
/// the value is dropped.
#[derive(Debug)]
pub struct Context {
    /// Owns the raw cuSOLVER handle and keeps its CUDA context alive.
    handle: Handle,
}
/// Owning wrapper around a raw `cusolverDnHandle_t`.
///
/// The CUDA context is stored next to the raw handle so that the `Drop` impl
/// can bind it before destroying the handle.
#[derive(Debug)]
struct Handle {
    /// Handle produced by `cusolverDnCreate` (checked non-null in `Context::create`).
    raw: sys::cusolverDnHandle_t,
    /// CUDA context the handle was created in.
    cuda_ctx: Arc<CudaContext>,
}

// SAFETY: `Handle` is the sole owner of `raw`. It is created in
// `Context::create` and destroyed only in `Drop`, so moving a `Handle` to
// another thread hands over exclusive ownership rather than creating shared
// access. No `Sync` impl is provided here, so the handle cannot be used from
// two threads at once through shared references.
// NOTE(review): this impl also asserts that `Arc<CudaContext>` may be sent
// across threads, which requires `CudaContext: Send + Sync`. Confirm.
unsafe impl Send for Handle {}
/// The stream a cuSOLVER handle is currently using, as reported by
/// [`Context::stream`].
#[derive(Debug, Clone)]
pub enum StreamBinding {
    /// cuSOLVER reported a null stream, i.e. the default stream.
    Default,
    /// A non-null stream, recorded without taking ownership of it.
    Borrowed(BorrowedStream),
}
/// A raw CUDA stream pointer observed on a cuSOLVER handle, tagged with the
/// CUDA context of that handle.
///
/// Only the raw pointer is recorded. This type does not keep the stream alive,
/// so the pointer may dangle if whoever owns the stream destroys it.
#[derive(Debug, Clone)]
pub struct BorrowedStream {
    /// Raw stream pointer returned by `cusolverDnGetStream`.
    handle: runtime::cudaStream_t,
    /// CUDA context of the handle the stream was read from.
    cuda_ctx: Arc<CudaContext>,
}
impl BorrowedStream {
    /// Returns the raw CUDA stream pointer.
    ///
    /// The pointer is not owned by `self`; nothing here keeps the stream alive.
    pub const fn as_raw(&self) -> runtime::cudaStream_t {
        self.handle
    }

    /// Returns the CUDA context of the handle this stream was read from.
    pub fn context(&self) -> &CudaContext {
        // `&Arc<CudaContext>` deref-coerces to `&CudaContext` in return position.
        &self.cuda_ctx
    }
}
impl Context {
    /// Creates a cuSOLVER dense handle in `cuda_ctx`.
    ///
    /// `cuda_ctx` is bound before `cusolverDnCreate` is called. A clone of the
    /// `Arc` is stored alongside the handle, so the CUDA context lives at least
    /// as long as the handle.
    ///
    /// # Errors
    ///
    /// Returns an error if binding `cuda_ctx` fails, if `cusolverDnCreate`
    /// reports a failure status, or [`Error::NullHandle`] if it reports success
    /// but leaves the handle null.
    pub fn create(cuda_ctx: &Arc<CudaContext>) -> Result<Self> {
        cuda_ctx.bind()?;
        let mut handle = ptr::null_mut();
        // SAFETY: `handle` is a live local, so `&raw mut handle` is a valid
        // out-pointer for the duration of the call.
        unsafe {
            try_ffi!(sys::cusolverDnCreate(&raw mut handle))?;
        }
        if handle.is_null() {
            return Err(Error::NullHandle);
        }
        Ok(Self {
            handle: Handle {
                raw: handle,
                cuda_ctx: Arc::clone(cuda_ctx),
            },
        })
    }

    /// Returns the CUDA context this handle was created in.
    pub fn cuda_context(&self) -> &Arc<CudaContext> {
        &self.handle.cuda_ctx
    }

    /// Binds this handle's CUDA context.
    ///
    /// The handle-level methods in this impl call this before issuing a
    /// cuSOLVER call.
    ///
    /// # Errors
    ///
    /// Propagates the error from binding the CUDA context, converted into
    /// [`Error`].
    pub fn bind(&self) -> Result<()> {
        Ok(self.cuda_context().bind()?)
    }

    /// Checks that `stream` belongs to this handle's CUDA context, then binds
    /// that context.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StreamContextMismatch`] if `stream` is associated with a
    /// different CUDA context, or the error from [`Context::bind`].
    pub fn ensure_stream(&self, stream: &Stream) -> Result<()> {
        if self.cuda_context().as_ref() != stream.context() {
            return Err(Error::StreamContextMismatch);
        }
        self.bind()
    }

    /// Returns the stream currently set on the handle.
    ///
    /// A null stream is reported as [`StreamBinding::Default`]. Any other value
    /// is wrapped in a [`BorrowedStream`] tagged with this handle's CUDA
    /// context. The wrapper records only the raw pointer and keeps no `Stream`
    /// alive.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnGetStream` reports a failure status.
    pub fn stream(&self) -> Result<StreamBinding> {
        self.bind()?;
        let mut stream = ptr::null_mut();
        // SAFETY: the handle is live while `self` is borrowed (it is destroyed
        // only on drop), and `stream` is a valid out-pointer.
        unsafe {
            try_ffi!(sys::cusolverDnGetStream(self.as_raw(), &raw mut stream))?;
        }
        Ok(if stream.is_null() {
            StreamBinding::Default
        } else {
            StreamBinding::Borrowed(BorrowedStream {
                handle: stream,
                cuda_ctx: Arc::clone(self.cuda_context()),
            })
        })
    }

    /// Sets the stream used by subsequent cuSOLVER calls on this handle.
    ///
    /// `Some(stream)` must belong to this handle's CUDA context (see
    /// [`Context::ensure_stream`]). `None` resets the handle to the null
    /// (default) stream.
    ///
    /// NOTE(review): only the raw stream pointer is handed to cuSOLVER;
    /// `Context` does not keep `stream` alive. If dropping a `Stream` destroys
    /// the underlying CUDA stream (presumably it does; confirm), the caller
    /// must keep it alive for as long as it stays set on this handle.
    ///
    /// # Errors
    ///
    /// Returns [`Error::StreamContextMismatch`] for a stream from another
    /// context, an error if binding the CUDA context fails, or an error if
    /// `cusolverDnSetStream` reports a failure status.
    pub fn set_stream(&self, stream: Option<&Stream>) -> Result<()> {
        let raw_stream = match stream {
            Some(stream) => {
                self.ensure_stream(stream)?;
                stream.as_raw()
            }
            None => {
                self.bind()?;
                ptr::null_mut()
            }
        };
        // SAFETY: the handle is live while `self` is borrowed, and
        // `raw_stream` is either null or a stream checked to belong to this
        // handle's CUDA context.
        unsafe {
            try_ffi!(sys::cusolverDnSetStream(self.as_raw(), raw_stream))?;
        }
        Ok(())
    }

    /// Returns the handle's deterministic mode.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnGetDeterministicMode` reports a failure status.
    pub fn deterministic_mode(&self) -> Result<DeterministicMode> {
        self.bind()?;
        // The initial value is only a placeholder: on failure `?` returns
        // before it is read.
        let mut mode = sys::cusolverDeterministicMode_t::CUSOLVER_DETERMINISTIC_RESULTS;
        // SAFETY: the handle is live while `self` is borrowed, and `mode` is a
        // valid out-pointer.
        unsafe {
            try_ffi!(sys::cusolverDnGetDeterministicMode(
                self.as_raw(),
                &raw mut mode,
            ))?;
        }
        Ok(mode.into())
    }

    /// Sets the handle's deterministic mode.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnSetDeterministicMode` reports a failure status.
    pub fn set_deterministic_mode(&self, mode: DeterministicMode) -> Result<()> {
        self.bind()?;
        // SAFETY: the handle is live while `self` is borrowed; `mode` is passed
        // by value.
        unsafe {
            try_ffi!(sys::cusolverDnSetDeterministicMode(
                self.as_raw(),
                mode.into(),
            ))?;
        }
        Ok(())
    }

    /// Returns the handle's math mode.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnGetMathMode` reports a failure status.
    pub fn math_mode(&self) -> Result<MathMode> {
        self.bind()?;
        // Placeholder; only read after a successful call.
        let mut mode = sys::cusolverMathMode_t::CUSOLVER_DEFAULT_MATH;
        // SAFETY: the handle is live while `self` is borrowed, and `mode` is a
        // valid out-pointer.
        unsafe {
            try_ffi!(sys::cusolverDnGetMathMode(self.as_raw(), &raw mut mode))?;
        }
        Ok(mode.into())
    }

    /// Sets the handle's math mode.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnSetMathMode` reports a failure status.
    pub fn set_math_mode(&self, mode: MathMode) -> Result<()> {
        self.bind()?;
        // SAFETY: the handle is live while `self` is borrowed; `mode` is passed
        // by value.
        unsafe {
            try_ffi!(sys::cusolverDnSetMathMode(self.as_raw(), mode.into()))?;
        }
        Ok(())
    }

    /// Returns the handle's emulation strategy.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnGetEmulationStrategy` reports a failure status.
    pub fn emulation_strategy(&self) -> Result<EmulationStrategy> {
        self.bind()?;
        // Placeholder; only read after a successful call.
        let mut strategy = EmulationStrategy::Default.into();
        // SAFETY: the handle is live while `self` is borrowed, and `strategy`
        // is a valid out-pointer.
        unsafe {
            try_ffi!(sys::cusolverDnGetEmulationStrategy(
                self.as_raw(),
                &raw mut strategy,
            ))?;
        }
        Ok(strategy.into())
    }

    /// Sets the handle's emulation strategy.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the CUDA context fails or
    /// `cusolverDnSetEmulationStrategy` reports a failure status.
    pub fn set_emulation_strategy(&self, strategy: EmulationStrategy) -> Result<()> {
        self.bind()?;
        // SAFETY: the handle is live while `self` is borrowed; `strategy` is
        // passed by value.
        unsafe {
            try_ffi!(sys::cusolverDnSetEmulationStrategy(
                self.as_raw(),
                strategy.into(),
            ))?;
        }
        Ok(())
    }

    /// Installs a cuSOLVER logger callback.
    ///
    /// This is not tied to any handle: no `Context` is required.
    ///
    /// # Safety
    ///
    /// `callback` will be invoked by cuSOLVER from native code. It must be
    /// sound to call with whatever arguments cuSOLVER passes it and must not
    /// unwind across the FFI boundary.
    /// NOTE(review): cuSOLVER may invoke it from whichever host thread issues a
    /// logged call. Confirm any thread-safety requirements.
    ///
    /// # Errors
    ///
    /// Returns an error if `cusolverDnLoggerSetCallback` reports a failure
    /// status.
    pub unsafe fn set_logger_callback(callback: sys::cusolverDnLoggerCallback_t) -> Result<()> {
        // SAFETY: the caller upholds the callback contract documented above.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerSetCallback(callback))?;
        }
        Ok(())
    }

    /// Sets the cuSOLVER logger level.
    ///
    /// `level` is passed through unchecked. Valid values are defined by
    /// cuSOLVER's logging documentation.
    ///
    /// # Errors
    ///
    /// Returns an error if `cusolverDnLoggerSetLevel` reports a failure status.
    pub fn set_logger_level(level: i32) -> Result<()> {
        // SAFETY: plain integer argument; no pointers are involved.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerSetLevel(level))?;
        }
        Ok(())
    }

    /// Sets the cuSOLVER logger mask.
    ///
    /// `mask` is passed through unchecked. Valid bits are defined by cuSOLVER's
    /// logging documentation.
    ///
    /// # Errors
    ///
    /// Returns an error if `cusolverDnLoggerSetMask` reports a failure status.
    pub fn set_logger_mask(mask: i32) -> Result<()> {
        // SAFETY: plain integer argument; no pointers are involved.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerSetMask(mask))?;
        }
        Ok(())
    }

    /// Directs cuSOLVER log output to an already-open C `FILE`.
    ///
    /// # Safety
    ///
    /// `file` must be a pointer cuSOLVER accepts for this call (a valid, open C
    /// `FILE*`). It must stay open and valid for as long as cuSOLVER may write
    /// to it, i.e. until logging is redirected or disabled.
    ///
    /// # Errors
    ///
    /// Returns an error if `cusolverDnLoggerSetFile` reports a failure status.
    pub unsafe fn set_logger_file(file: *mut sys::FILE) -> Result<()> {
        // SAFETY: the caller guarantees `file` satisfies the contract above.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerSetFile(file))?;
        }
        Ok(())
    }

    /// Has cuSOLVER open `path` and write its log output there.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` cannot be converted to a C string
    /// (presumably because it contains an interior NUL byte or is not
    /// representable; see `path_to_cstring`), or if `cusolverDnLoggerOpenFile`
    /// reports a failure status.
    pub fn set_logger_path(path: impl AsRef<Path>) -> Result<()> {
        let path = path_to_cstring(path.as_ref())?;
        // SAFETY: `path` is a NUL-terminated C string that outlives the call.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerOpenFile(path.as_ptr()))?;
        }
        Ok(())
    }

    /// Forces cuSOLVER logging off via `cusolverDnLoggerForceDisable`.
    ///
    /// # Errors
    ///
    /// Returns an error if `cusolverDnLoggerForceDisable` reports a failure
    /// status.
    pub fn disable_logger() -> Result<()> {
        // SAFETY: no arguments; no pointers are involved.
        unsafe {
            try_ffi!(sys::cusolverDnLoggerForceDisable())?;
        }
        Ok(())
    }

    /// Returns the raw cuSOLVER handle.
    ///
    /// The handle remains owned by this `Context` and is destroyed when the
    /// `Context` is dropped. Callers must not destroy it or use it after that.
    pub fn as_raw(&self) -> sys::cusolverDnHandle_t {
        self.handle.raw
    }
}
impl Drop for Handle {
    /// Destroys the cuSOLVER handle on a best-effort basis.
    ///
    /// The owning CUDA context is bound first, but the handle is destroyed even
    /// if that bind fails. `drop` cannot propagate errors, so failures are
    /// printed to stderr in debug builds and ignored otherwise.
    fn drop(&mut self) {
        /// Prints a teardown failure to stderr, in debug builds only.
        ///
        /// Uses `cfg!` rather than `#[cfg]` on the print. That keeps the error
        /// value referenced in release builds (avoiding an `unused_variables`
        /// warning), where the condition is simply a constant `false`.
        fn report(what: &str, err: &dyn std::fmt::Display) {
            if cfg!(debug_assertions) {
                eprintln!("{what}: {err}");
            }
        }

        if let Err(err) = self.cuda_ctx.bind() {
            report(
                "failed to bind cuda context before destroying cusolver handle",
                &err,
            );
        }
        // SAFETY: `self.raw` is the non-null handle produced by
        // `cusolverDnCreate` in `Context::create`. `drop` runs at most once,
        // so the handle is destroyed exactly once.
        let destroyed = unsafe { try_ffi!(sys::cusolverDnDestroy(self.raw)) };
        if let Err(err) = destroyed {
            report("failed to destroy cusolver context", &err);
        }
    }
}