use std::mem::MaybeUninit;
use super::sys;
/// Error wrapper around a raw cuSOLVER status code.
///
/// Within this file it is produced by `cusolverStatus_t::result` for every
/// status other than `CUSOLVER_STATUS_SUCCESS`; the wrapped status is public
/// so callers can inspect it directly.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct CusolverError(pub sys::cusolverStatus_t);
impl sys::cusolverStatus_t {
    /// Turns a raw status code into a `Result`.
    ///
    /// # Errors
    ///
    /// Returns `Err(CusolverError(self))` for every status other than
    /// `CUSOLVER_STATUS_SUCCESS`.
    pub fn result(self) -> Result<(), CusolverError> {
        let succeeded = matches!(self, sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS);
        if succeeded {
            Ok(())
        } else {
            Err(CusolverError(self))
        }
    }
}
#[cfg(feature = "std")]
impl std::fmt::Display for CusolverError {
    /// Renders the error using its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately goes through `write_fmt` (what `write!` expands to) rather than
        // calling `Debug::fmt` directly, so the caller's formatter flags are not forwarded.
        f.write_fmt(format_args!("{:?}", self))
    }
}
// Marks `CusolverError` as a standard error type; like the `Display` impl it is
// only compiled when the `std` feature is enabled. No methods are overridden.
#[cfg(feature = "std")]
impl std::error::Error for CusolverError {}
/// Creates a new cuSOLVER dense handle via `cusolverDnCreate`.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnCreate`.
pub fn dn_create() -> Result<sys::cusolverDnHandle_t, CusolverError> {
    let mut slot = MaybeUninit::<sys::cusolverDnHandle_t>::uninit();
    // SAFETY: `slot` is a live local, so the out-pointer is valid for a write.
    let status = unsafe { sys::cusolverDnCreate(slot.as_mut_ptr()) };
    match status.result() {
        // SAFETY: reached only after a success status, on which the out-parameter
        // is relied upon to have been written.
        Ok(()) => Ok(unsafe { slot.assume_init() }),
        Err(err) => Err(err),
    }
}
/// Destroys a cuSOLVER dense handle via `cusolverDnDestroy`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid dense handle that has not already been destroyed.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnDestroy`.
pub unsafe fn dn_destroy(handle: sys::cusolverDnHandle_t) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation (see `# Safety`).
    let status = unsafe { sys::cusolverDnDestroy(handle) };
    status.result()
}
/// Associates `stream` with a dense handle via `cusolverDnSetStream`.
///
/// # Safety
///
/// Both arguments are passed straight to the C API; the caller must ensure
/// `handle` and `stream` are valid for that call.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnSetStream`.
pub unsafe fn dn_set_stream(
    handle: sys::cusolverDnHandle_t,
    stream: sys::cudaStream_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` and `stream` is the caller's obligation.
    let status = unsafe { sys::cusolverDnSetStream(handle, stream) };
    status.result()
}
/// Reads the deterministic mode of a dense handle via
/// `cusolverDnGetDeterministicMode`.
///
/// Only compiled for the CUDA version features listed in the `cfg` below.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid dense handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnGetDeterministicMode`.
#[cfg(any(
    feature = "cuda-12020",
    feature = "cuda-12030",
    feature = "cuda-12040",
    feature = "cuda-12050",
    feature = "cuda-12060",
    feature = "cuda-12080",
    feature = "cuda-12090",
    feature = "cuda-13000",
))]
pub unsafe fn dn_get_deterministic_mode(
    handle: sys::cusolverDnHandle_t,
) -> Result<sys::cusolverDeterministicMode_t, CusolverError> {
    let mut out = MaybeUninit::<sys::cusolverDeterministicMode_t>::uninit();
    // SAFETY: `out` is a live local; validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverDnGetDeterministicMode(handle, out.as_mut_ptr()) };
    match status.result() {
        // SAFETY: reached only after a success status, on which the out-parameter
        // is relied upon to have been written.
        Ok(()) => Ok(unsafe { out.assume_init() }),
        Err(err) => Err(err),
    }
}
/// Sets the deterministic mode of a dense handle via
/// `cusolverDnSetDeterministicMode`.
///
/// Only compiled for the CUDA version features listed in the `cfg` below.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid dense handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnSetDeterministicMode`.
#[cfg(any(
    feature = "cuda-12020",
    feature = "cuda-12030",
    feature = "cuda-12040",
    feature = "cuda-12050",
    feature = "cuda-12060",
    feature = "cuda-12080",
    feature = "cuda-12090",
    feature = "cuda-13000",
))]
pub unsafe fn dn_set_deterministic_mode(
    handle: sys::cusolverDnHandle_t,
    mode: sys::cusolverDeterministicMode_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverDnSetDeterministicMode(handle, mode) };
    status.result()
}
/// Creates a new dense params object via `cusolverDnCreateParams`.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnCreateParams`.
pub fn dn_create_params() -> Result<sys::cusolverDnParams_t, CusolverError> {
    let mut slot = MaybeUninit::<sys::cusolverDnParams_t>::uninit();
    // SAFETY: `slot` is a live local, so the out-pointer is valid for a write.
    let status = unsafe { sys::cusolverDnCreateParams(slot.as_mut_ptr()) };
    if let Err(err) = status.result() {
        return Err(err);
    }
    // SAFETY: reached only after a success status, on which the out-parameter
    // is relied upon to have been written.
    Ok(unsafe { slot.assume_init() })
}
/// Forwards `function` and `algo` to `cusolverDnSetAdvOptions` for `params`.
///
/// # Safety
///
/// `params` is passed straight to the C API; the caller must ensure it is a
/// valid params object.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnSetAdvOptions`.
pub unsafe fn dn_set_adv_options(
    params: sys::cusolverDnParams_t,
    function: sys::cusolverDnFunction_t,
    algo: sys::cusolverAlgMode_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `params` is the caller's obligation.
    let status = unsafe { sys::cusolverDnSetAdvOptions(params, function, algo) };
    status.result()
}
/// Destroys a dense params object via `cusolverDnDestroyParams`.
///
/// # Safety
///
/// `params` is passed straight to the C API; the caller must ensure it is a
/// valid params object that has not already been destroyed.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverDnDestroyParams`.
pub unsafe fn dn_destroy_params(params: sys::cusolverDnParams_t) -> Result<(), CusolverError> {
    // SAFETY: validity of `params` is the caller's obligation.
    let status = unsafe { sys::cusolverDnDestroyParams(params) };
    status.result()
}
/// Creates a new cuSOLVER sparse handle via `cusolverSpCreate`.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverSpCreate`.
pub fn sp_create() -> Result<sys::cusolverSpHandle_t, CusolverError> {
    let mut slot = MaybeUninit::<sys::cusolverSpHandle_t>::uninit();
    // SAFETY: `slot` is a live local, so the out-pointer is valid for a write.
    let status = unsafe { sys::cusolverSpCreate(slot.as_mut_ptr()) };
    status.result().map(|()| {
        // SAFETY: this closure runs only on a success status, on which the
        // out-parameter is relied upon to have been written.
        unsafe { slot.assume_init() }
    })
}
/// Destroys a cuSOLVER sparse handle via `cusolverSpDestroy`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid sparse handle that has not already been destroyed.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverSpDestroy`.
pub unsafe fn sp_destroy(handle: sys::cusolverSpHandle_t) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverSpDestroy(handle) };
    status.result()
}
/// Associates `stream` with a sparse handle via `cusolverSpSetStream`.
///
/// # Safety
///
/// Both arguments are passed straight to the C API; the caller must ensure
/// `handle` and `stream` are valid for that call.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverSpSetStream`.
pub unsafe fn sp_set_stream(
    handle: sys::cusolverSpHandle_t,
    stream: sys::cudaStream_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` and `stream` is the caller's obligation.
    let status = unsafe { sys::cusolverSpSetStream(handle, stream) };
    status.result()
}
/// Creates a new cuSOLVER refactorization handle via `cusolverRfCreate`.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfCreate`.
pub fn rf_create() -> Result<sys::cusolverRfHandle_t, CusolverError> {
    let mut slot = MaybeUninit::<sys::cusolverRfHandle_t>::uninit();
    // SAFETY: `slot` is a live local, so the out-pointer is valid for a write.
    let status = unsafe { sys::cusolverRfCreate(slot.as_mut_ptr()) };
    match status.result() {
        Err(err) => Err(err),
        // SAFETY: reached only after a success status, on which the out-parameter
        // is relied upon to have been written.
        Ok(()) => Ok(unsafe { slot.assume_init() }),
    }
}
/// Destroys a cuSOLVER refactorization handle via `cusolverRfDestroy`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid refactorization handle that has not already been destroyed.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfDestroy`.
pub unsafe fn rf_destroy(handle: sys::cusolverRfHandle_t) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverRfDestroy(handle) };
    status.result()
}
/// Forwards `format` and `diag` to `cusolverRfSetMatrixFormat` for `handle`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid refactorization handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfSetMatrixFormat`.
pub unsafe fn rf_set_matrix_format(
    handle: sys::cusolverRfHandle_t,
    format: sys::cusolverRfMatrixFormat_t,
    diag: sys::cusolverRfUnitDiagonal_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverRfSetMatrixFormat(handle, format, diag) };
    status.result()
}
/// Forwards `zero` and `boost` to `cusolverRfSetNumericProperties` for `handle`.
///
/// The meaning of the two values is defined by the cuSOLVER RF API; this
/// wrapper passes them through unchanged.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid refactorization handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfSetNumericProperties`.
pub unsafe fn rf_set_numeric_properties(
    handle: sys::cusolverRfHandle_t,
    zero: f64,
    boost: f64,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverRfSetNumericProperties(handle, zero, boost) };
    status.result()
}
/// Forwards `fast_mode` to `cusolverRfSetResetValuesFastMode` for `handle`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid refactorization handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfSetResetValuesFastMode`.
pub unsafe fn rf_set_reset_values_fast_mode(
    handle: sys::cusolverRfHandle_t,
    fast_mode: sys::cusolverRfResetValuesFastMode_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverRfSetResetValuesFastMode(handle, fast_mode) };
    status.result()
}
/// Forwards `fact_alg` and `alg` to `cusolverRfSetAlgs` for `handle`.
///
/// # Safety
///
/// `handle` is passed straight to the C API; the caller must ensure it is a
/// valid refactorization handle.
///
/// # Errors
///
/// Returns the non-success status reported by `cusolverRfSetAlgs`.
pub unsafe fn rf_set_algs(
    handle: sys::cusolverRfHandle_t,
    fact_alg: sys::cusolverRfFactorization_t,
    alg: sys::cusolverRfTriangularSolve_t,
) -> Result<(), CusolverError> {
    // SAFETY: validity of `handle` is the caller's obligation.
    let status = unsafe { sys::cusolverRfSetAlgs(handle, fact_alg, alg) };
    status.result()
}