#[allow(unused_imports)]
use crate::error::Status;
use std::{ptr, slice};
use singe_cuda::{data_type::DataTypeLike, memory::DeviceMemory};
use crate::{
context::Context,
error::{Error, Result},
layout::{MatrixMut, MatrixRef},
sys, try_ffi,
types::{IrsRefinement, PrecisionType},
utility::{to_i32, to_u64},
};
/// Owning wrapper around a cuSOLVER `cusolverDnIRSParams_t` handle.
///
/// Alongside the raw handle it remembers the main and lowest precisions that
/// were successfully set through this wrapper, so later calls can check them.
#[derive(Debug)]
pub struct IrsParams {
// Raw library handle; obtained in `create` and released in `Drop`.
handle: sys::cusolverDnIRSParams_t,
// Main precision last set through this wrapper, `None` until a setter succeeds.
main_precision: Option<PrecisionType>,
// Lowest precision last set through this wrapper, `None` until a setter succeeds.
lowest_precision: Option<PrecisionType>,
}
/// Owning wrapper around a cuSOLVER `cusolverDnIRSInfos_t` handle.
///
/// NOTE(review): `Default` is derived, so a value can exist that never went
/// through `create`; its handle is presumably null — confirm this is intended.
#[derive(Debug, Default)]
pub struct IrsInfos {
// Raw library handle; obtained in `create` and released in `Drop`.
handle: sys::cusolverDnIRSInfos_t,
// Set once `request_residual_history` has succeeded; gates the history getters.
residual_history_requested: bool,
}
/// One row of the residual history copied out of the library's buffer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ResidualHistoryEntry<T> {
/// Value read from the first column of the history buffer for this row.
pub total_iterations: T,
/// Value read from the second column of the history buffer for this row.
pub residual_norm: T,
}
/// Host-side copy of the residual history reported by an IRS solve.
#[derive(Debug, Clone, PartialEq)]
pub struct ResidualHistory<T> {
/// The rows that were copied, in buffer order.
pub rows: Vec<ResidualHistoryEntry<T>>,
/// Column stride (in elements) of the buffer the rows were copied from.
pub leading_dimension: usize,
}
/// Problem dimensions for an IRS `gesv` solve.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct IrsSolve {
/// Order of the square coefficient matrix (forwarded as `n`).
pub n: usize,
/// Number of right-hand-side columns (forwarded as `nrhs`).
pub right_hand_sides: usize,
}
impl IrsSolve {
pub fn new(n: usize, right_hand_sides: usize) -> Self {
Self {
n,
right_hand_sides,
}
}
pub fn workspace_size<T: DataTypeLike>(
self,
ctx: &Context,
params: &mut IrsParams,
) -> Result<usize> {
xgesv_buffer_size::<T>(ctx, params, self.n, self.right_hand_sides)
}
pub fn execute<T: DataTypeLike>(
self,
ctx: &Context,
params: &mut IrsParams,
infos: &IrsInfos,
bindings: IrsSolveBindings<'_, T>,
) -> Result<i32> {
xgesv(
ctx,
params,
infos,
self.n,
self.right_hand_sides,
bindings.a,
bindings.b,
bindings.x,
bindings.device_workspace,
bindings.dev_info,
)
}
}
/// Buffers handed to [`IrsSolve::execute`], which forwards them to `xgesv`.
#[derive(Debug)]
pub struct IrsSolveBindings<'a, T> {
/// Coefficient matrix; validated as `n x n` and passed to the solver mutably.
pub a: MatrixMut<'a, T>,
/// Right-hand sides; validated as `n x nrhs` and passed read-only.
pub b: MatrixRef<'a, T>,
/// Solution matrix; validated as `n x nrhs` and passed to the solver mutably.
pub x: MatrixMut<'a, T>,
/// Scratch buffer; must hold at least the byte count the size query reports.
pub device_workspace: &'a mut DeviceMemory<u8>,
/// Status buffer handed to the solver; must not be empty.
pub dev_info: &'a mut DeviceMemory<i32>,
}
// NOTE(review): these impls assert that the raw cuSOLVER handles may be moved
// to and shared between threads. No justification is visible in this file —
// confirm against the library's thread-safety guarantees. In particular,
// `xgesv`/`xgels` hand the infos handle to the solver through `&IrsInfos`.
unsafe impl Send for IrsParams {}
unsafe impl Sync for IrsParams {}
unsafe impl Send for IrsInfos {}
unsafe impl Sync for IrsInfos {}
impl IrsParams {
    /// Creates a parameter handle through `cusolverDnIRSParamsCreate`.
    ///
    /// The refinement solver is set to [`IrsRefinement::None`] right away; no
    /// precision is recorded yet.
    ///
    /// # Errors
    ///
    /// Propagates a failing create or refinement-solver call and returns
    /// [`Error::NullHandle`] when the create call succeeds but leaves the
    /// handle null.
    pub fn create() -> Result<Self> {
        let mut created = ptr::null_mut();
        unsafe { try_ffi!(sys::cusolverDnIRSParamsCreate(&raw mut created)) }?;
        if created.is_null() {
            return Err(Error::NullHandle);
        }

        // Wrap the handle first so `Drop` releases it if the next call fails.
        let mut this = Self {
            handle: created,
            main_precision: None,
            lowest_precision: None,
        };
        this.set_refinement_solver(IrsRefinement::None)?;
        Ok(this)
    }

    /// Forwards `refinement` to `cusolverDnIRSParamsSetRefinementSolver`.
    pub fn set_refinement_solver(&mut self, refinement: IrsRefinement) -> Result<()> {
        let handle = self.as_raw();
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsSetRefinementSolver(
                handle,
                refinement.into(),
            ))
        }?;
        Ok(())
    }

    /// Forwards `precision` to `cusolverDnIRSParamsSetSolverMainPrecision`
    /// and, on success, records it as the main precision.
    pub fn set_main_precision(&mut self, precision: PrecisionType) -> Result<()> {
        let handle = self.as_raw();
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsSetSolverMainPrecision(
                handle,
                precision.into(),
            ))
        }?;
        self.main_precision = Some(precision);
        Ok(())
    }

    /// Forwards `precision` to `cusolverDnIRSParamsSetSolverLowestPrecision`
    /// and, on success, records it as the lowest precision.
    pub fn set_lowest_precision(&mut self, precision: PrecisionType) -> Result<()> {
        let handle = self.as_raw();
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsSetSolverLowestPrecision(
                handle,
                precision.into(),
            ))
        }?;
        self.lowest_precision = Some(precision);
        Ok(())
    }

    /// Sets both precisions in one `cusolverDnIRSParamsSetSolverPrecisions`
    /// call; both are recorded only after the call succeeds.
    pub fn set_solver_precisions(
        &mut self,
        main_precision: PrecisionType,
        lowest_precision: PrecisionType,
    ) -> Result<()> {
        let handle = self.as_raw();
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsSetSolverPrecisions(
                handle,
                main_precision.into(),
                lowest_precision.into(),
            ))
        }?;
        self.main_precision = Some(main_precision);
        self.lowest_precision = Some(lowest_precision);
        Ok(())
    }

    /// Forwards `tolerance` to `cusolverDnIRSParamsSetTol`.
    pub fn set_tolerance(&mut self, tolerance: f64) -> Result<()> {
        let handle = self.as_raw();
        unsafe { try_ffi!(sys::cusolverDnIRSParamsSetTol(handle, tolerance)) }?;
        Ok(())
    }

    /// Forwards `tolerance` to `cusolverDnIRSParamsSetTolInner`.
    pub fn set_inner_tolerance(&mut self, tolerance: f64) -> Result<()> {
        let handle = self.as_raw();
        unsafe { try_ffi!(sys::cusolverDnIRSParamsSetTolInner(handle, tolerance)) }?;
        Ok(())
    }

    /// Forwards `max_iterations` to `cusolverDnIRSParamsSetMaxIters`.
    pub fn set_max_iterations(&mut self, max_iterations: i32) -> Result<()> {
        let handle = self.as_raw();
        unsafe { try_ffi!(sys::cusolverDnIRSParamsSetMaxIters(handle, max_iterations)) }?;
        Ok(())
    }

    /// Forwards `max_iterations` to `cusolverDnIRSParamsSetMaxItersInner`.
    pub fn set_max_inner_iterations(&mut self, max_iterations: i32) -> Result<()> {
        let handle = self.as_raw();
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsSetMaxItersInner(
                handle,
                max_iterations,
            ))
        }?;
        Ok(())
    }

    /// Reads the value reported by `cusolverDnIRSParamsGetMaxIters`.
    pub fn max_iterations(&self) -> Result<i32> {
        let mut reported = 0;
        unsafe {
            try_ffi!(sys::cusolverDnIRSParamsGetMaxIters(
                self.as_raw(),
                &raw mut reported,
            ))
        }?;
        Ok(reported)
    }

    /// Calls `cusolverDnIRSParamsEnableFallback` on this handle.
    pub fn enable_fallback(&mut self) -> Result<()> {
        unsafe { try_ffi!(sys::cusolverDnIRSParamsEnableFallback(self.as_raw())) }?;
        Ok(())
    }

    /// Calls `cusolverDnIRSParamsDisableFallback` on this handle.
    pub fn disable_fallback(&mut self) -> Result<()> {
        unsafe { try_ffi!(sys::cusolverDnIRSParamsDisableFallback(self.as_raw())) }?;
        Ok(())
    }

    /// Makes the recorded precisions compatible with element type `T`.
    ///
    /// * `T` must map to a [`PrecisionType`], otherwise
    ///   [`Error::InvalidPrecisionConfiguration`] is returned.
    /// * A previously recorded main precision must equal that mapping
    ///   (same error otherwise); a missing one is set to it.
    /// * A missing lowest precision is set to the same value; an existing
    ///   lowest precision is left untouched.
    fn ensure_type_precision<T: DataTypeLike>(&mut self) -> Result<()> {
        let Some(required) = PrecisionType::from_data_type(T::data_type()) else {
            return Err(Error::InvalidPrecisionConfiguration);
        };

        if let Some(current) = self.main_precision {
            if current != required {
                return Err(Error::InvalidPrecisionConfiguration);
            }
        } else {
            self.set_main_precision(required)?;
        }

        if self.lowest_precision.is_none() {
            self.set_lowest_precision(required)?;
        }
        Ok(())
    }

    /// Returns the raw handle without transferring ownership.
    pub fn as_raw(&self) -> sys::cusolverDnIRSParams_t {
        self.handle
    }
}
impl Drop for IrsParams {
    /// Releases the handle with `cusolverDnIRSParamsDestroy`.
    ///
    /// A failure cannot be propagated from `drop`; it is only reported on
    /// stderr, and only in builds with debug assertions enabled.
    fn drop(&mut self) {
        let outcome = unsafe { try_ffi!(sys::cusolverDnIRSParamsDestroy(self.handle)) };
        if let Err(err) = outcome {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cusolver irs params: {err}");
        }
    }
}
impl IrsInfos {
pub fn create() -> Result<Self> {
let mut handle = ptr::null_mut();
unsafe {
try_ffi!(sys::cusolverDnIRSInfosCreate(&raw mut handle))?;
}
if handle.is_null() {
return Err(Error::NullHandle);
}
Ok(Self {
handle,
residual_history_requested: false,
})
}
pub fn niters(&self) -> Result<i32> {
let mut value = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSInfosGetNiters(
self.as_raw(),
&raw mut value,
))?;
}
Ok(value)
}
pub fn outer_niters(&self) -> Result<i32> {
let mut value = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSInfosGetOuterNiters(
self.as_raw(),
&raw mut value,
))?;
}
Ok(value)
}
pub fn max_iterations(&self) -> Result<i32> {
let mut value = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSInfosGetMaxIters(
self.as_raw(),
&raw mut value,
))?;
}
Ok(value)
}
pub fn request_residual_history(&mut self) -> Result<()> {
unsafe {
try_ffi!(sys::cusolverDnIRSInfosRequestResidual(self.as_raw()))?;
}
self.residual_history_requested = true;
Ok(())
}
pub fn residual_history_f32(&self) -> Result<ResidualHistory<f32>> {
if !self.residual_history_requested {
return Err(Error::InvalidPrecisionConfiguration);
}
let (leading_dimension, valid_rows) = self.residual_history_layout()?;
let mut history = ptr::null_mut();
unsafe {
try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
self.as_raw(),
&raw mut history,
))?;
Ok(copy_residual_history(
history.cast::<f32>(),
leading_dimension,
valid_rows,
))
}
}
pub fn residual_history_f64(&self) -> Result<ResidualHistory<f64>> {
if !self.residual_history_requested {
return Err(Error::InvalidPrecisionConfiguration);
}
let (leading_dimension, valid_rows) = self.residual_history_layout()?;
let mut history = ptr::null_mut();
unsafe {
try_ffi!(sys::cusolverDnIRSInfosGetResidualHistory(
self.as_raw(),
&raw mut history,
))?;
Ok(copy_residual_history(
history.cast::<f64>(),
leading_dimension,
valid_rows,
))
}
}
pub fn as_raw(&self) -> sys::cusolverDnIRSInfos_t {
self.handle
}
fn residual_history_layout(&self) -> Result<(usize, usize)> {
let leading_dimension = self
.max_iterations()?
.checked_add(1)
.ok_or(Error::InvalidResidualHistory)
.and_then(|value| {
usize::try_from(value).map_err(|_| Error::OutOfRange {
name: "residual history leading dimension".into(),
})
})?;
let valid_rows = self
.outer_niters()?
.checked_add(1)
.ok_or(Error::InvalidResidualHistory)
.and_then(|value| {
usize::try_from(value).map_err(|_| Error::OutOfRange {
name: "residual history rows".into(),
})
})?;
if valid_rows > leading_dimension {
return Err(Error::InvalidResidualHistory);
}
Ok((leading_dimension, valid_rows))
}
}
impl Drop for IrsInfos {
    /// Releases the handle with `cusolverDnIRSInfosDestroy`.
    ///
    /// A failure cannot be propagated from `drop`; it is only reported on
    /// stderr, and only in builds with debug assertions enabled.
    fn drop(&mut self) {
        // A `Default`-constructed value never went through `create` and may
        // carry a null handle; there is nothing to destroy in that case.
        if self.handle.is_null() {
            return;
        }
        let outcome = unsafe { try_ffi!(sys::cusolverDnIRSInfosDestroy(self.handle)) };
        if let Err(err) = outcome {
            #[cfg(debug_assertions)]
            eprintln!("failed to destroy cusolver irs infos: {err}");
            // Keeps release builds free of an unused-variable warning.
            #[cfg(not(debug_assertions))]
            let _ = err;
        }
    }
}
pub fn xgesv_buffer_size<T: DataTypeLike>(
ctx: &Context,
params: &mut IrsParams,
n: usize,
nrhs: usize,
) -> Result<usize> {
ctx.bind()?;
if n == 0 || nrhs == 0 {
return Err(Error::InvalidMatrixShape);
}
params.ensure_type_precision::<T>()?;
let mut workspace_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSXgesv_bufferSize(
ctx.as_raw(),
params.as_raw(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
&raw mut workspace_bytes,
))?;
}
Ok(workspace_bytes as usize)
}
pub fn xgesv<T: DataTypeLike>(
ctx: &Context,
params: &mut IrsParams,
infos: &IrsInfos,
n: usize,
nrhs: usize,
a: MatrixMut<'_, T>,
b: MatrixRef<'_, T>,
x: MatrixMut<'_, T>,
device_workspace: &mut DeviceMemory<u8>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<i32> {
ctx.bind()?;
validate_matrix(n, n, a.data.len(), a.leading_dimension)?;
validate_matrix(n, nrhs, b.data.len(), b.leading_dimension)?;
validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
require_info_buffer(dev_info)?;
let workspace_bytes = xgesv_buffer_size::<T>(ctx, params, n, nrhs)?;
require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
let mut niters = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSXgesv(
ctx.as_raw(),
params.as_raw(),
infos.as_raw(),
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.data.as_mut_ptr() as _,
to_i32(a.leading_dimension, "ldda")?,
b.data.as_ptr() as _,
to_i32(b.leading_dimension, "lddb")?,
x.data.as_mut_ptr() as _,
to_i32(x.leading_dimension, "lddx")?,
device_workspace.as_mut_ptr() as _,
to_u64(workspace_bytes, "lwork_bytes")?,
&raw mut niters,
dev_info.as_mut_ptr() as _,
))?;
}
Ok(niters)
}
pub fn xgels_buffer_size<T: DataTypeLike>(
ctx: &Context,
params: &mut IrsParams,
m: usize,
n: usize,
nrhs: usize,
) -> Result<usize> {
ctx.bind()?;
if m == 0 || n == 0 || nrhs == 0 || n > m {
return Err(Error::InvalidMatrixShape);
}
params.ensure_type_precision::<T>()?;
let mut workspace_bytes = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSXgels_bufferSize(
ctx.as_raw(),
params.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
&raw mut workspace_bytes,
))?;
}
Ok(workspace_bytes as usize)
}
pub fn xgels<T: DataTypeLike>(
ctx: &Context,
params: &mut IrsParams,
infos: &IrsInfos,
m: usize,
n: usize,
nrhs: usize,
a: MatrixMut<'_, T>,
b: MatrixRef<'_, T>,
x: MatrixMut<'_, T>,
device_workspace: &mut DeviceMemory<u8>,
dev_info: &mut DeviceMemory<i32>,
) -> Result<i32> {
ctx.bind()?;
if n > m {
return Err(Error::InvalidMatrixShape);
}
validate_matrix(m, n, a.data.len(), a.leading_dimension)?;
validate_matrix(m, nrhs, b.data.len(), b.leading_dimension)?;
validate_matrix(n, nrhs, x.data.len(), x.leading_dimension)?;
require_info_buffer(dev_info)?;
let workspace_bytes = xgels_buffer_size::<T>(ctx, params, m, n, nrhs)?;
require_workspace_bytes(device_workspace.byte_len(), workspace_bytes)?;
let mut niters = 0;
unsafe {
try_ffi!(sys::cusolverDnIRSXgels(
ctx.as_raw(),
params.as_raw(),
infos.as_raw(),
to_i32(m, "m")?,
to_i32(n, "n")?,
to_i32(nrhs, "nrhs")?,
a.data.as_mut_ptr() as _,
to_i32(a.leading_dimension, "ldda")?,
b.data.as_ptr() as _,
to_i32(b.leading_dimension, "lddb")?,
x.data.as_mut_ptr() as _,
to_i32(x.leading_dimension, "lddx")?,
device_workspace.as_mut_ptr() as _,
to_u64(workspace_bytes, "lwork_bytes")?,
&raw mut niters,
dev_info.as_mut_ptr() as _,
))?;
}
Ok(niters)
}
fn require_info_buffer(dev_info: &DeviceMemory<i32>) -> Result<()> {
if dev_info.is_empty() {
return Err(Error::InvalidVectorShape);
}
Ok(())
}
fn require_workspace_bytes(actual: usize, required: usize) -> Result<()> {
if actual < required {
return Err(Error::InsufficientWorkspaceSize { required, actual });
}
Ok(())
}
/// Copies the first `valid_rows` rows of a two-column history buffer.
///
/// Row `i` is assembled from element `i` (first column) and element
/// `i + leading_dimension` (second column) of the buffer.
///
/// # Safety
///
/// `history` must be non-null, properly aligned, and valid for reads of
/// `2 * leading_dimension` elements of type `T`.
///
/// # Panics
///
/// Panics on an out-of-bounds index, which happens when `valid_rows`
/// exceeds `leading_dimension`.
unsafe fn copy_residual_history<T: Copy>(
    history: *const T,
    leading_dimension: usize,
    valid_rows: usize,
) -> ResidualHistory<T> {
    let element_count = leading_dimension.saturating_mul(2);
    let buffer = unsafe { slice::from_raw_parts(history, element_count) };

    let rows: Vec<_> = (0..valid_rows)
        .map(|row| ResidualHistoryEntry {
            total_iterations: buffer[row],
            residual_norm: buffer[row + leading_dimension],
        })
        .collect();

    ResidualHistory {
        rows,
        leading_dimension,
    }
}
fn validate_matrix(rows: usize, cols: usize, len: usize, lda: usize) -> Result<()> {
if rows == 0 || cols == 0 {
return Err(Error::InvalidMatrixShape);
}
if lda < rows {
return Err(Error::InvalidLeadingDimension);
}
let required = lda.checked_mul(cols).ok_or(Error::InvalidMatrixShape)?;
if len < required {
return Err(Error::InvalidMatrixShape);
}
Ok(())
}
#[cfg(all(test, feature = "testing"))]
mod tests {
    use singe_cuda::memory::DeviceMemory;

    use super::*;
    use crate::testing::setup_context_if_available;

    /// Solves `diag(2, 4) * x = (6, 8)` and expects `x = (3, 2)`.
    ///
    /// The test returns early when no context is available.
    #[test]
    fn test_xgesv_solves_diagonal_system() -> Result<()> {
        let Some(ctx) = setup_context_if_available()? else {
            return Ok(());
        };

        let mut params = IrsParams::create()?;
        let infos = IrsInfos::create()?;

        // 2 x 2 diagonal matrix with leading dimension 2.
        let mut coefficients = DeviceMemory::from_slice(&[2.0_f32, 0.0, 0.0, 4.0])?;
        let rhs = DeviceMemory::from_slice(&[6.0_f32, 8.0])?;
        let mut solution = DeviceMemory::create(2)?;

        let needed_bytes = xgesv_buffer_size::<f32>(&ctx, &mut params, 2, 1)?;
        // Allocate at least one byte even when the query reports zero.
        let mut scratch = DeviceMemory::create(needed_bytes.max(1))?;
        let mut status = DeviceMemory::create(1)?;

        let _ = xgesv(
            &ctx,
            &mut params,
            &infos,
            2,
            1,
            MatrixMut::new(&mut coefficients, 2),
            MatrixRef::new(&rhs, 2),
            MatrixMut::new(&mut solution, 2),
            &mut scratch,
            &mut status,
        )?;

        assert_eq!(status.copy_to_host_vec()?, vec![0]);
        assert_eq!(solution.copy_to_host_vec()?, vec![3.0, 2.0]);
        Ok(())
    }
}