use std::sync::Arc;
use crate::driver::CudaStream;
use super::{result, sys};
pub use super::result::CusolverError;
/// Owning wrapper around a raw cuSOLVER dense (`cusolverDn`) handle.
///
/// The wrapper stores the raw handle together with the [`CudaStream`] that was
/// supplied to it. The raw handle is destroyed when the wrapper is dropped.
#[derive(Debug)]
pub struct DnHandle {
/// Raw library handle; reset to null by `Drop` once it has been destroyed.
handle: sys::cusolverDnHandle_t,
/// Stream most recently supplied through `new` or `set_stream`.
/// Presumably held so the stream outlives the handle's use of it — confirm.
stream: Arc<CudaStream>,
}
// NOTE(review): these marker impls assert that the raw handle may be moved to
// and shared between threads. SOURCE does not show the justification;
// presumably it relies on cuSOLVER's documented thread-safety — confirm.
unsafe impl Send for DnHandle {}
unsafe impl Sync for DnHandle {}
impl Drop for DnHandle {
    /// Destroys the underlying dense handle, if one is still held.
    ///
    /// # Panics
    /// Panics if `result::dn_destroy` reports an error.
    fn drop(&mut self) {
        // Swap null into the field first so the wrapper never keeps a pointer
        // to a handle that has already been handed to the destroy call.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        unsafe { result::dn_destroy(raw) }.unwrap();
    }
}
impl DnHandle {
    /// Creates a dense cuSOLVER handle and binds it to `stream`.
    ///
    /// # Errors
    /// Returns the error reported by handle creation or by binding the stream.
    /// When binding fails, the freshly created handle is destroyed (best
    /// effort) before the error is returned, so it is not leaked.
    pub fn new(stream: Arc<CudaStream>) -> Result<Self, CusolverError> {
        let handle = result::dn_create()?;
        // SAFETY: `handle` comes from the successful `dn_create` call above and
        // `stream` is owned by this function for the duration of the call.
        let bound = unsafe { result::dn_set_stream(handle, stream.cu_stream() as _) };
        bound.map_err(|err| {
            // No `DnHandle` owns `handle` yet, so `Drop` would never run for
            // it. Destroy it here; a secondary failure is ignored so that the
            // original error is the one the caller sees.
            let _ = unsafe { result::dn_destroy(handle) };
            err
        })?;
        Ok(Self { handle, stream })
    }

    /// Returns the raw handle. Ownership stays with `self`.
    pub fn cu(&self) -> sys::cusolverDnHandle_t {
        self.handle
    }

    /// Returns the stream currently stored alongside the handle.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Rebinds the handle to `stream`.
    ///
    /// The stored stream is replaced only after the library call succeeds, so
    /// on error `self.stream()` still reports the stream the handle was bound
    /// to before the call, and the previous stream is not released early.
    ///
    /// # Safety
    /// NOTE(review): the exact contract is not visible in this file.
    /// Presumably the caller must ensure that work already submitted through
    /// this handle is correctly ordered with respect to the new stream —
    /// confirm against callers.
    ///
    /// # Errors
    /// Returns the error reported by `result::dn_set_stream`.
    pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), CusolverError> {
        result::dn_set_stream(self.handle, stream.cu_stream() as _)?;
        self.stream = stream;
        Ok(())
    }

    /// Sets the deterministic mode of the handle.
    ///
    /// Only compiled for the CUDA version features listed in the `cfg` below.
    ///
    /// # Errors
    /// Returns the error reported by `result::dn_set_deterministic_mode`.
    #[cfg(any(
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
    ))]
    pub fn set_deterministic_mode(
        &self,
        mode: sys::cusolverDeterministicMode_t,
    ) -> Result<(), CusolverError> {
        // SAFETY: `self.handle` is the handle owned by this wrapper.
        unsafe { result::dn_set_deterministic_mode(self.handle, mode) }
    }

    /// Returns the deterministic mode currently set on the handle.
    ///
    /// Only compiled for the CUDA version features listed in the `cfg` below.
    ///
    /// # Panics
    /// Panics if `result::dn_get_deterministic_mode` reports an error.
    #[cfg(any(
        feature = "cuda-12020",
        feature = "cuda-12030",
        feature = "cuda-12040",
        feature = "cuda-12050",
        feature = "cuda-12060",
        feature = "cuda-12080",
        feature = "cuda-12090",
        feature = "cuda-13000",
    ))]
    pub fn get_deterministic_mode(&self) -> sys::cusolverDeterministicMode_t {
        // SAFETY: `self.handle` is the handle owned by this wrapper.
        unsafe { result::dn_get_deterministic_mode(self.handle) }.unwrap()
    }
}
/// Owning wrapper around a raw `cusolverDnParams_t` object.
///
/// The raw object is destroyed when the wrapper is dropped.
#[derive(Debug)]
pub struct DnParams {
/// Raw params pointer; reset to null by `Drop` once it has been destroyed.
params: sys::cusolverDnParams_t,
}
impl Drop for DnParams {
    /// Destroys the underlying params object, if one is still held.
    ///
    /// # Panics
    /// Panics if `result::dn_destroy_params` reports an error.
    fn drop(&mut self) {
        // Leave null behind so the field never refers to a destroyed object.
        let raw = std::mem::replace(&mut self.params, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        unsafe { result::dn_destroy_params(raw) }.unwrap();
    }
}
impl DnParams {
    /// Creates a params object and applies the advanced option `algo` for
    /// `function` to it.
    ///
    /// # Errors
    /// Returns the error reported by creation or by setting the option. When
    /// setting the option fails, the freshly created params object is
    /// destroyed (best effort) before the error is returned, so it is not
    /// leaked.
    pub fn new(
        function: sys::cusolverDnFunction_t,
        algo: sys::cusolverAlgMode_t,
    ) -> Result<Self, CusolverError> {
        let params = result::dn_create_params()?;
        // SAFETY: `params` comes from the successful `dn_create_params` call
        // above.
        let configured = unsafe { result::dn_set_adv_options(params, function, algo) };
        configured.map_err(|err| {
            // No `DnParams` owns `params` yet, so `Drop` would never run for
            // it. Destroy it here; a secondary failure is ignored so that the
            // original error is the one the caller sees.
            let _ = unsafe { result::dn_destroy_params(params) };
            err
        })?;
        Ok(Self { params })
    }

    /// Returns the raw params pointer. Ownership stays with `self`.
    pub fn cu(&self) -> sys::cusolverDnParams_t {
        self.params
    }
}
/// Owning wrapper around a raw cuSOLVER sparse (`cusolverSp`) handle.
///
/// The wrapper stores the raw handle together with the [`CudaStream`] that was
/// supplied to it. The raw handle is destroyed when the wrapper is dropped.
#[derive(Debug)]
pub struct SpHandle {
/// Raw library handle; reset to null by `Drop` once it has been destroyed.
handle: sys::cusolverSpHandle_t,
/// Stream most recently supplied through `new` or `set_stream`.
/// Presumably held so the stream outlives the handle's use of it — confirm.
stream: Arc<CudaStream>,
}
// NOTE(review): these marker impls assert that the raw handle may be moved to
// and shared between threads. SOURCE does not show the justification;
// presumably it relies on cuSOLVER's documented thread-safety — confirm.
unsafe impl Send for SpHandle {}
unsafe impl Sync for SpHandle {}
impl Drop for SpHandle {
    /// Destroys the underlying sparse handle, if one is still held.
    ///
    /// # Panics
    /// Panics if `result::sp_destroy` reports an error.
    fn drop(&mut self) {
        // Leave null behind so the field never refers to a destroyed handle.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        unsafe { result::sp_destroy(raw) }.unwrap();
    }
}
impl SpHandle {
    /// Creates a sparse cuSOLVER handle and binds it to `stream`.
    ///
    /// # Errors
    /// Returns the error reported by handle creation or by binding the stream.
    /// When binding fails, the freshly created handle is destroyed (best
    /// effort) before the error is returned, so it is not leaked.
    pub fn new(stream: Arc<CudaStream>) -> Result<Self, CusolverError> {
        let handle = result::sp_create()?;
        // SAFETY: `handle` comes from the successful `sp_create` call above and
        // `stream` is owned by this function for the duration of the call.
        let bound = unsafe { result::sp_set_stream(handle, stream.cu_stream() as _) };
        bound.map_err(|err| {
            // No `SpHandle` owns `handle` yet, so `Drop` would never run for
            // it. Destroy it here; a secondary failure is ignored so that the
            // original error is the one the caller sees.
            let _ = unsafe { result::sp_destroy(handle) };
            err
        })?;
        Ok(Self { handle, stream })
    }

    /// Returns the raw handle. Ownership stays with `self`.
    pub fn cu(&self) -> sys::cusolverSpHandle_t {
        self.handle
    }

    /// Returns the stream currently stored alongside the handle.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }

    /// Rebinds the handle to `stream`.
    ///
    /// The stored stream is replaced only after the library call succeeds, so
    /// on error `self.stream()` still reports the stream the handle was bound
    /// to before the call, and the previous stream is not released early.
    ///
    /// # Safety
    /// NOTE(review): the exact contract is not visible in this file.
    /// Presumably the caller must ensure that work already submitted through
    /// this handle is correctly ordered with respect to the new stream —
    /// confirm against callers.
    ///
    /// # Errors
    /// Returns the error reported by `result::sp_set_stream`.
    pub unsafe fn set_stream(&mut self, stream: Arc<CudaStream>) -> Result<(), CusolverError> {
        result::sp_set_stream(self.handle, stream.cu_stream() as _)?;
        self.stream = stream;
        Ok(())
    }
}
/// Owning wrapper around a raw cuSOLVER refactorization (`cusolverRf`) handle.
///
/// The raw handle is destroyed when the wrapper is dropped.
#[derive(Debug)]
pub struct RfHandle {
/// Raw library handle; reset to null by `Drop` once it has been destroyed.
handle: sys::cusolverRfHandle_t,
}
// NOTE(review): these marker impls assert that the raw handle may be moved to
// and shared between threads. SOURCE does not show the justification;
// presumably it relies on cuSOLVER's documented thread-safety — confirm.
unsafe impl Send for RfHandle {}
unsafe impl Sync for RfHandle {}
impl Drop for RfHandle {
    /// Destroys the underlying refactorization handle, if one is still held.
    ///
    /// # Panics
    /// Panics if `result::rf_destroy` reports an error.
    fn drop(&mut self) {
        // Leave null behind so the field never refers to a destroyed handle.
        let raw = std::mem::replace(&mut self.handle, std::ptr::null_mut());
        if raw.is_null() {
            return;
        }
        unsafe { result::rf_destroy(raw) }.unwrap();
    }
}
impl RfHandle {
    /// Creates a new refactorization handle.
    ///
    /// # Errors
    /// Returns the error reported by `result::rf_create`.
    pub fn new() -> Result<Self, CusolverError> {
        Ok(Self {
            handle: result::rf_create()?,
        })
    }

    /// Returns the raw handle. Ownership stays with `self`.
    pub fn cu(&self) -> sys::cusolverRfHandle_t {
        self.handle
    }

    /// Forwards `format` and `diag` to `result::rf_set_matrix_format` for
    /// this handle.
    ///
    /// # Panics
    /// Panics if the underlying call reports an error.
    pub fn set_matrix_format(
        &self,
        format: sys::cusolverRfMatrixFormat_t,
        diag: sys::cusolverRfUnitDiagonal_t,
    ) {
        let status = unsafe { result::rf_set_matrix_format(self.handle, format, diag) };
        status.unwrap();
    }

    /// Forwards `zero` and `boost` to `result::rf_set_numeric_properties` for
    /// this handle.
    ///
    /// # Panics
    /// Panics if the underlying call reports an error.
    pub fn set_numeric_properties(&self, zero: f64, boost: f64) {
        let status = unsafe { result::rf_set_numeric_properties(self.handle, zero, boost) };
        status.unwrap();
    }

    /// Forwards `fast_mode` to `result::rf_set_reset_values_fast_mode` for
    /// this handle.
    ///
    /// # Panics
    /// Panics if the underlying call reports an error.
    pub fn set_reset_values_fast_mode(&self, fast_mode: sys::cusolverRfResetValuesFastMode_t) {
        let status = unsafe { result::rf_set_reset_values_fast_mode(self.handle, fast_mode) };
        status.unwrap();
    }

    /// Forwards the factorization algorithm `fact_alg` and the triangular
    /// solve algorithm `alg` to `result::rf_set_algs` for this handle.
    ///
    /// # Panics
    /// Panics if the underlying call reports an error.
    pub fn set_algs(
        &self,
        fact_alg: sys::cusolverRfFactorization_t,
        alg: sys::cusolverRfTriangularSolve_t,
    ) {
        let status = unsafe { result::rf_set_algs(self.handle, fact_alg, alg) };
        status.unwrap();
    }
}