use std::{cell::RefCell, mem::MaybeUninit};
use crate::{
error::{DiffsolError, LinearSolverError},
linear_solver_error, CudaContext, CudaMat, CudaVec, LinearSolver, Matrix, NonLinearOpJacobian,
ScalarCuda,
};
use cudarc::{
    cusolver::sys::{
        cublasOperation_t, cusolverDnCreate, cusolverDnDestroy, cusolverDnDgetrf,
        cusolverDnDgetrf_bufferSize, cusolverDnDgetrs, cusolverDnHandle_t, cusolverDnSetStream,
        cusolverStatus_t,
    },
    driver::{CudaSlice, DevicePtr, DevicePtrMut},
};
/// Dense LU linear solver backed by cuSOLVER (`getrf` to factorise, `getrs` to solve).
///
/// Usage order: `set_problem` allocates the matrix and device buffers,
/// `set_linearisation` fills the matrix with a Jacobian and factorises it in place,
/// and `solve_in_place` then solves against the stored factors.
///
/// NOTE(review): the double-precision (`D`) cuSOLVER routines and `f64` pointer
/// casts are used for every `T: ScalarCuda`; presumably `T` is always `f64` — confirm.
pub struct CudaLU<T>
where
    T: ScalarCuda,
{
    /// Device workspace for `cusolverDnDgetrf`, sized by `cusolverDnDgetrf_bufferSize`.
    work: Option<CudaSlice<T>>,
    /// Device buffer that receives the row pivots from `getrf` and is read by `getrs`.
    pivots: Option<CudaSlice<i32>>,
    /// One-element device buffer passed as the `info` output of `getrf`/`getrs`.
    /// Wrapped in a `RefCell` because `solve_in_place` only has `&self` but needs a
    /// mutable device pointer to it.
    nfo: Option<RefCell<CudaSlice<i32>>>,
    /// Jacobian storage; overwritten in place by its LU factors in `set_linearisation`.
    matrix: Option<CudaMat<T>>,
    /// cuSOLVER dense handle, created in `Default::default` and destroyed in `Drop`.
    handle: cusolverDnHandle_t,
    /// Set after a factorisation in `set_linearisation`; cleared by `set_problem`.
    /// `solve_in_place` refuses to run while it is false.
    linearisation_set: bool,
}
impl<T> Default for CudaLU<T>
where
T: ScalarCuda,
{
fn default() -> Self {
let handle = {
let mut handle = MaybeUninit::uninit();
unsafe {
let stat = cusolverDnCreate(handle.as_mut_ptr());
assert_eq!(stat, cusolverStatus_t::CUSOLVER_STATUS_SUCCESS);
handle.assume_init()
}
};
Self {
matrix: None,
work: None,
pivots: None,
nfo: None,
handle,
linearisation_set: false,
}
}
}
impl<T: ScalarCuda> Drop for CudaLU<T> {
    /// Releases the cuSOLVER dense handle owned by this solver.
    fn drop(&mut self) {
        // SAFETY: `self.handle` was obtained from a successful `cusolverDnCreate`
        // (see `Default`) and is released only here.
        // The returned status is intentionally discarded: `drop` has no way to
        // report a failure.
        let _ = unsafe { cusolverDnDestroy(self.handle) };
    }
}
impl<T: ScalarCuda> LinearSolver<CudaMat<T>> for CudaLU<T> {
fn set_linearisation<
C: NonLinearOpJacobian<T = T, V = CudaVec<T>, M = CudaMat<T>, C = CudaContext>,
>(
&mut self,
op: &C,
x: &CudaVec<T>,
t: T,
) {
let matrix = self.matrix.as_mut().expect("Matrix not set");
let work = self.work.as_mut().expect("Work space not set");
let pivots = self.pivots.as_mut().expect("Pivots not set");
let mut nfo = self.nfo.as_mut().expect("NFO not set").borrow_mut();
op.jacobian_inplace(x, t, matrix);
{
let m = i32::try_from(matrix.nrows()).unwrap();
let n = i32::try_from(matrix.ncols()).unwrap();
let lda = i32::try_from(matrix.nrows()).unwrap();
let stream = &op.context().stream;
let (a, _syn) = matrix.data.device_ptr_mut(stream);
let (workspace, _ws_syn) = work.device_ptr_mut(stream);
let (pivots, _pivots_syn) = pivots.device_ptr_mut(stream);
let (nfo, _nfo_syn) = nfo.device_ptr_mut(stream);
unsafe {
cusolverDnDgetrf(
self.handle,
m,
n,
a as *mut f64,
lda,
workspace as *mut f64,
pivots as *mut i32,
nfo as *mut i32,
)
};
}
self.linearisation_set = true;
}
fn solve_in_place(&self, x: &mut CudaVec<T>) -> Result<(), DiffsolError> {
let matrix = if let Some(ref matrix) = self.matrix {
if matrix.nrows() != matrix.ncols() {
return Err(linear_solver_error!(LinearSolverMatrixNotSquare))?;
}
matrix
} else {
return Err(linear_solver_error!(LinearSolverNotSetup))?;
};
if !self.linearisation_set {
return Err(linear_solver_error!(LinearSolverNotSetup))?;
}
if x.data.len() != matrix.nrows() {
return Err(linear_solver_error!(LinearSolverMatrixVectorNotCompatible))?;
}
let mut nfo = self.nfo.as_ref().expect("NFO not set").borrow_mut();
{
let stream = matrix.data.stream();
let n = i32::try_from(x.data.len()).unwrap();
let lda = i32::try_from(self.matrix.as_ref().unwrap().nrows()).unwrap();
let (a, _syn) = matrix.data.device_ptr(stream);
let (x_data, _x_syn) = x.data.device_ptr_mut(stream);
let (pivots, _pivots_syn) = self.pivots.as_ref().unwrap().device_ptr(stream);
let (nfo, _nfo_syn) = nfo.device_ptr_mut(stream);
unsafe {
cusolverDnDgetrs(
self.handle,
cublasOperation_t::CUBLAS_OP_N,
n,
1,
a as *mut f64,
lda,
pivots as *mut i32,
x_data as *mut f64,
n,
nfo as *mut i32,
)
};
}
Ok(())
}
fn set_problem<
C: NonLinearOpJacobian<T = T, V = CudaVec<T>, M = CudaMat<T>, C = CudaContext>,
>(
&mut self,
op: &C,
) {
let ncols = op.nstates();
let nrows = op.nout();
let matrix =
C::M::new_from_sparsity(nrows, ncols, op.jacobian_sparsity(), op.context().clone());
self.matrix = Some(matrix);
let stream = &op.context().stream;
let lwork = {
let mut lwork = 0;
let (a, _syn) = self.matrix.as_mut().unwrap().data.device_ptr_mut(stream);
let m = i32::try_from(nrows).unwrap();
let n = i32::try_from(ncols).unwrap();
let lda = i32::try_from(nrows).unwrap();
unsafe {
cusolverDnDgetrf_bufferSize(self.handle, m, n, a as *mut f64, lda, &mut lwork);
}
lwork
};
unsafe {
self.work = Some(
stream
.alloc(lwork as usize)
.expect("Failed to allocate work space"),
);
self.pivots = Some(stream.alloc(nrows).expect("Failed to allocate pivots"));
self.nfo = Some(RefCell::new(
stream.alloc(1).expect("Failed to allocate NFO"),
));
}
self.linearisation_set = false;
}
}