#![warn(missing_debug_implementations)]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![allow(clippy::not_unsafe_ptr_arg_deref)]
#![allow(clippy::type_complexity)]
use std::ffi::CString;
use std::marker::PhantomData;
pub use aocl_error::{Error, Result};
use aocl_sparse_sys as sys;
use aocl_types::sealed::Sealed;
pub use aocl_types::{Complex32, Complex64, Trans};
pub mod complex;
/// Translates a [`Trans`] flag into the `aoclsparse_operation` constant the C API expects.
///
/// `Trans::No` maps to "none", `Trans::T` to "transpose" and `Trans::C` to
/// "conjugate transpose".
fn trans_raw(op: Trans) -> sys::aoclsparse_operation {
    match op {
        Trans::C => sys::aoclsparse_operation__aoclsparse_operation_conjugate_transpose,
        Trans::T => sys::aoclsparse_operation__aoclsparse_operation_transpose,
        Trans::No => sys::aoclsparse_operation__aoclsparse_operation_none,
    }
}
fn check_status(component: &'static str, status: sys::aoclsparse_status) -> Result<()> {
if status == sys::aoclsparse_status__aoclsparse_status_success {
return Ok(());
}
let message = match status {
s if s == sys::aoclsparse_status__aoclsparse_status_not_implemented => "not implemented",
s if s == sys::aoclsparse_status__aoclsparse_status_invalid_pointer => "invalid pointer",
s if s == sys::aoclsparse_status__aoclsparse_status_invalid_size => "invalid size",
s if s == sys::aoclsparse_status__aoclsparse_status_internal_error => "internal error",
s if s == sys::aoclsparse_status__aoclsparse_status_invalid_value => "invalid value",
s if s == sys::aoclsparse_status__aoclsparse_status_invalid_index_value => {
"invalid index value"
}
s if s == sys::aoclsparse_status__aoclsparse_status_maxit => "max iterations reached",
s if s == sys::aoclsparse_status__aoclsparse_status_user_stop => "user stop",
s if s == sys::aoclsparse_status__aoclsparse_status_wrong_type => "wrong type",
s if s == sys::aoclsparse_status__aoclsparse_status_memory_error => "memory error",
_ => "unknown sparse status",
}
.to_string();
Err(Error::Status {
component,
code: status as i64,
message,
})
}
pub struct MatDescr {
raw: sys::aoclsparse_mat_descr,
}
impl MatDescr {
pub fn new() -> Result<Self> {
let mut raw: sys::aoclsparse_mat_descr = std::ptr::null_mut();
let status = unsafe { sys::aoclsparse_create_mat_descr(&mut raw) };
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(MatDescr { raw })
}
pub fn as_raw(&self) -> sys::aoclsparse_mat_descr {
self.raw
}
pub fn set_type(&mut self, ty: MatType) -> Result<()> {
let status = unsafe { sys::aoclsparse_set_mat_type(self.raw, ty.raw()) };
check_status("sparse", status)
}
pub fn set_index_base(&mut self, base: IndexBase) -> Result<()> {
let status = unsafe { sys::aoclsparse_set_mat_index_base(self.raw, base.raw()) };
check_status("sparse", status)
}
pub fn set_fill_mode(&mut self, fill: FillMode) -> Result<()> {
let status = unsafe { sys::aoclsparse_set_mat_fill_mode(self.raw, fill.raw()) };
check_status("sparse", status)
}
pub fn set_diag_type(&mut self, diag: DiagType) -> Result<()> {
let status = unsafe { sys::aoclsparse_set_mat_diag_type(self.raw, diag.raw()) };
check_status("sparse", status)
}
pub fn ty(&self) -> MatType {
let raw = unsafe { sys::aoclsparse_get_mat_type(self.raw) };
MatType::from_raw(raw).unwrap_or(MatType::General)
}
pub fn index_base(&self) -> IndexBase {
let raw = unsafe { sys::aoclsparse_get_mat_index_base(self.raw) };
if raw == sys::aoclsparse_index_base__aoclsparse_index_base_one {
IndexBase::One
} else {
IndexBase::Zero
}
}
pub fn fill_mode(&self) -> FillMode {
let raw = unsafe { sys::aoclsparse_get_mat_fill_mode(self.raw) };
FillMode::from_raw(raw).unwrap_or(FillMode::Lower)
}
pub fn diag_type(&self) -> DiagType {
let raw = unsafe { sys::aoclsparse_get_mat_diag_type(self.raw) };
DiagType::from_raw(raw).unwrap_or(DiagType::NonUnit)
}
}
/// Structural matrix type stored in a [`MatDescr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatType {
    General,
    Symmetric,
    Hermitian,
    Triangular,
}
impl MatType {
    /// Every variant, in declaration order; used to invert [`MatType::raw`].
    const ALL: [MatType; 4] = [
        MatType::General,
        MatType::Symmetric,
        MatType::Hermitian,
        MatType::Triangular,
    ];
    /// Returns the matching `aoclsparse_matrix_type` constant.
    fn raw(self) -> sys::aoclsparse_matrix_type {
        match self {
            Self::General => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_general,
            Self::Symmetric => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
            Self::Hermitian => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_hermitian,
            Self::Triangular => sys::aoclsparse_matrix_type__aoclsparse_matrix_type_triangular,
        }
    }
    /// Maps a raw `aoclsparse_matrix_type` back to a variant, or `None` if it is not one
    /// of the four known constants.
    fn from_raw(raw: sys::aoclsparse_matrix_type) -> Option<Self> {
        Self::ALL.iter().copied().find(|candidate| candidate.raw() == raw)
    }
}
/// Which triangle of a matrix a [`MatDescr`] refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FillMode {
    Lower,
    Upper,
}
impl FillMode {
    /// Every variant, in declaration order; used to invert [`FillMode::raw`].
    const ALL: [FillMode; 2] = [FillMode::Lower, FillMode::Upper];
    /// Returns the matching `aoclsparse_fill_mode` constant.
    fn raw(self) -> sys::aoclsparse_fill_mode {
        match self {
            Self::Lower => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_lower,
            Self::Upper => sys::aoclsparse_fill_mode__aoclsparse_fill_mode_upper,
        }
    }
    /// Maps a raw `aoclsparse_fill_mode` back to a variant, or `None` if unrecognized.
    fn from_raw(raw: sys::aoclsparse_fill_mode) -> Option<Self> {
        Self::ALL.iter().copied().find(|candidate| candidate.raw() == raw)
    }
}
/// Whether a matrix's diagonal is implicitly all ones (`Unit`) or stored (`NonUnit`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagType {
    Unit,
    NonUnit,
}
impl DiagType {
    /// Every variant, in declaration order; used to invert [`DiagType::raw`].
    const ALL: [DiagType; 2] = [DiagType::Unit, DiagType::NonUnit];
    /// Returns the matching `aoclsparse_diag_type` constant.
    fn raw(self) -> sys::aoclsparse_diag_type {
        match self {
            Self::Unit => sys::aoclsparse_diag_type__aoclsparse_diag_type_unit,
            Self::NonUnit => sys::aoclsparse_diag_type__aoclsparse_diag_type_non_unit,
        }
    }
    /// Maps a raw `aoclsparse_diag_type` back to a variant, or `None` if unrecognized.
    fn from_raw(raw: sys::aoclsparse_diag_type) -> Option<Self> {
        Self::ALL.iter().copied().find(|candidate| candidate.raw() == raw)
    }
}
/// Creates a new descriptor and copies the settings of `src` into it via
/// `aoclsparse_copy_mat_descr`.
///
/// # Errors
/// Any error from [`MatDescr::new`], or the copy routine's status. On failure the freshly
/// created descriptor is dropped (and therefore destroyed).
pub fn copy_mat_descr(src: &MatDescr) -> Result<MatDescr> {
    let copy = MatDescr::new()?;
    // SAFETY: both handles are live descriptors owned by `copy` and `src` respectively.
    let status = unsafe { sys::aoclsparse_copy_mat_descr(copy.as_raw(), src.as_raw()) };
    check_status("sparse", status).map(|()| copy)
}
/// Calls `aoclsparse_optimize` on `mat`.
///
/// Presumably this lets the library specialise its internal representation for the
/// operations previously registered with the `SparseMatrix::set_*_hint` methods.
/// TODO confirm against the AOCL-Sparse reference.
pub fn optimize<T: Scalar>(mat: &mut SparseMatrix<T>) -> Result<()> {
    let handle = mat.as_raw();
    // SAFETY: `handle` is the live matrix owned by `mat`, which is exclusively borrowed.
    check_status("sparse", unsafe { sys::aoclsparse_optimize(handle) })
}
/// Returns the library's version string, or `""` if the library returns a null pointer
/// or the string is not valid UTF-8.
pub fn version() -> &'static str {
    // SAFETY: plain FFI call with no arguments.
    let ptr = unsafe { sys::aoclsparse_get_version() };
    if ptr.is_null() {
        return "";
    }
    // SAFETY: `ptr` is non-null. Presumably it points to a NUL-terminated string with
    // static storage inside the library; the `'static` lifetime relies on that (TODO confirm).
    let text: &'static std::ffi::CStr = unsafe { std::ffi::CStr::from_ptr(ptr) };
    text.to_str().unwrap_or("")
}
impl Drop for MatDescr {
fn drop(&mut self) {
if !self.raw.is_null() {
unsafe {
let _ = sys::aoclsparse_destroy_mat_descr(self.raw);
}
self.raw = std::ptr::null_mut();
}
}
}
impl std::fmt::Debug for MatDescr {
    /// Formats as `MatDescr { .. }`; the raw handle is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("MatDescr");
        builder.finish_non_exhaustive()
    }
}
/// Element types supported by the AOCL-Sparse bindings in this crate.
///
/// Each method is a type-dispatched bridge to a precision-specific C routine. In this file,
/// `impl_scalar!` binds `f32` to the `aoclsparse_s*` routines and `f64` to the
/// `aoclsparse_d*` routines (the `complex` module presumably covers complex types — TODO
/// confirm). The `Sealed` supertrait presumably restricts implementations to the aocl crates.
///
/// Dimensions are taken as `usize`. The implementations in this file cast them to
/// `sys::aoclsparse_int` with `as`, without an overflow check.
pub trait Scalar: Copy + Sized + Sealed {
    /// CSR sparse matrix × dense vector via `aoclsparse_?csrmv`.
    ///
    /// Presumably computes `y = alpha * op(A) * x + beta * y` (TODO confirm). The
    /// implementations in this file require `csr_row_ptr.len() == m + 1` and
    /// `csr_col_ind.len() == csr_val.len()`. They also require `x`/`y` to be at least `n`/`m`
    /// long (`m`/`n` when `op` is `T`/`C`), and return `Error::InvalidArgument` otherwise.
    #[allow(clippy::too_many_arguments)]
    fn csrmv(
        op: Trans,
        alpha: Self,
        m: usize,
        n: usize,
        csr_val: &[Self],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;
    /// Sparse AXPY via `aoclsparse_?axpyi`. `x` holds `x.len()` packed values and `indx`
    /// their positions in the dense vector `y` (presumably `y[indx[i]] += alpha * x[i]`).
    fn axpyi(alpha: Self, x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;
    /// Gather via `aoclsparse_?gthr`: fills the packed `x` (`x.len()` entries) from the dense
    /// `y` at positions `indx` (presumably `x[i] = y[indx[i]]`).
    fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()>;
    /// Scatter via `aoclsparse_?sctr`: writes the packed `x` into the dense `y` at positions
    /// `indx` (presumably `y[indx[i]] = x[i]`).
    fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()>;
    /// CSR triangular solve via `aoclsparse_?csrsv` (presumably `op(A) * y = alpha * x`;
    /// TODO confirm). The implementations in this file require `csr_row_ptr.len() == m + 1`
    /// and `x`, `y` of at least `m` elements.
    #[allow(clippy::too_many_arguments)]
    fn csrsv(
        op: Trans,
        alpha: Self,
        m: usize,
        csr_val: &[Self],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        y: &mut [Self],
    ) -> Result<()>;
    /// Expands an `m × n` CSR matrix into the dense buffer `a` (leading dimension `ld`,
    /// layout `order`) via `aoclsparse_?csr2dense`. The implementations here perform no
    /// length checks; the free function [`csr_to_dense`] does.
    #[allow(clippy::too_many_arguments)]
    fn csr_to_dense(
        m: usize,
        n: usize,
        descr: &MatDescr,
        csr_val: &[Self],
        csr_row_ptr: &[sys::aoclsparse_int],
        csr_col_ind: &[sys::aoclsparse_int],
        a: &mut [Self],
        ld: usize,
        order: Order,
    ) -> Result<()>;
    /// Converts CSR to CSC via `aoclsparse_?csr2csc`, writing the CSC arrays with index base
    /// `base_csc`. `nnz` is taken as `csr_val.len()`. No length checks are done here; the
    /// free function [`csr_to_csc`] does them.
    #[allow(clippy::too_many_arguments)]
    fn csr_to_csc(
        m: usize,
        n: usize,
        descr: &MatDescr,
        base_csc: IndexBase,
        csr_row_ptr: &[sys::aoclsparse_int],
        csr_col_ind: &[sys::aoclsparse_int],
        csr_val: &[Self],
        csc_row_ind: &mut [sys::aoclsparse_int],
        csc_col_ptr: &mut [sys::aoclsparse_int],
        csc_val: &mut [Self],
    ) -> Result<()>;
    /// ELLPACK matrix × vector via `aoclsparse_?ellmv`.
    ///
    /// `ell_val` and `ell_col_ind` must be the same length and hold at least
    /// `m * ell_width` entries. `x`/`y` must be sized as for [`Scalar::csrmv`].
    #[allow(clippy::too_many_arguments)]
    fn ellmv(
        op: Trans,
        alpha: Self,
        m: usize,
        n: usize,
        ell_val: &[Self],
        ell_col_ind: &[sys::aoclsparse_int],
        ell_width: usize,
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;
    /// Block-CSR matrix × vector via `aoclsparse_?bsrmv`.
    ///
    /// Uses square blocks of side `bsr_dim`; `mb`/`nb` count block rows/columns.
    /// `bsr_row_ptr` must have `mb + 1` entries and `bsr_val` at least
    /// `bsr_col_ind.len() * bsr_dim²` entries.
    #[allow(clippy::too_many_arguments)]
    fn bsrmv(
        op: Trans,
        alpha: Self,
        mb: usize,
        nb: usize,
        bsr_dim: usize,
        bsr_val: &[Self],
        bsr_col_ind: &[sys::aoclsparse_int],
        bsr_row_ptr: &[sys::aoclsparse_int],
        descr: &MatDescr,
        x: &[Self],
        beta: Self,
        y: &mut [Self],
    ) -> Result<()>;
    /// Creates a library CSR matrix handle over the caller's raw arrays via
    /// `aoclsparse_create_?csr`. This wrapper makes no copy.
    ///
    /// NOTE(review): the library presumably keeps these pointers. `SparseMatrix::from_csr`
    /// keeps the backing `Vec`s alive next to the handle for that reason — confirm that
    /// any other caller does the same.
    #[allow(clippy::too_many_arguments)]
    fn create_csr(
        base: IndexBase,
        m: usize,
        n: usize,
        nnz: usize,
        row_ptr: *mut sys::aoclsparse_int,
        col_idx: *mut sys::aoclsparse_int,
        val: *mut Self,
    ) -> Result<sys::aoclsparse_matrix>;
    /// Reads back `(base, m, n, nnz, row_ptr, col_ind, val)` from a CSR matrix handle via
    /// `aoclsparse_export_?csr`.
    ///
    /// The pointers are returned as the library reports them and may be null — presumably
    /// borrowed from the handle and not to be freed (TODO confirm).
    fn export_csr(
        mat: sys::aoclsparse_matrix,
    ) -> Result<(
        IndexBase,
        usize,
        usize,
        usize,
        *mut sys::aoclsparse_int,
        *mut sys::aoclsparse_int,
        *mut Self,
    )>;
    /// Applies `aoclsparse_?ilu_smoother` to `x` with right-hand side `b`.
    ///
    /// The implementations here discard the preconditioner-values output pointer and pass
    /// a null pointer for the argument between it and `x`. No length checks are performed.
    fn ilu_smoother(
        op: Trans,
        a: sys::aoclsparse_matrix,
        descr: &MatDescr,
        x: &mut [Self],
        b: &[Self],
    ) -> Result<()>;
    /// Initializes an iterative-solver handle via `aoclsparse_itsol_?_init`.
    fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()>;
    /// Runs `aoclsparse_itsol_?_solve` on an `n`-dimensional system with right-hand side `b`,
    /// updating `x` and writing solver information into `rinfo`.
    ///
    /// Requires `b` and `x` of at least `n` elements. The implementations pass `None`, `None`
    /// and a null pointer for the trailing arguments (presumably callback hooks and user
    /// data — TODO confirm).
    #[allow(clippy::too_many_arguments)]
    fn itsol_solve(
        handle: sys::aoclsparse_itsol_handle,
        n: usize,
        mat: sys::aoclsparse_matrix,
        descr: &MatDescr,
        b: &[Self],
        x: &mut [Self],
        rinfo: &mut [Self; 100],
    ) -> Result<()>;
    /// Unchecked pass-through to `aoclsparse_?csr2m` (sparse × sparse into a new handle).
    ///
    /// # Safety
    /// All handles and descriptors must be valid for the C routine, and `out` must be
    /// valid for writes of one `aoclsparse_matrix`.
    #[allow(clippy::too_many_arguments)]
    unsafe fn csr2m_ffi(
        op_a: sys::aoclsparse_operation,
        descr_a: sys::aoclsparse_mat_descr,
        a: sys::aoclsparse_matrix,
        op_b: sys::aoclsparse_operation,
        descr_b: sys::aoclsparse_mat_descr,
        b: sys::aoclsparse_matrix,
        request: sys::aoclsparse_request,
        out: *mut sys::aoclsparse_matrix,
    ) -> sys::aoclsparse_status;
    /// Sparse matrix `a` × dense `b` (`n` columns, leading dimension `ldb`) into dense `c`
    /// (leading dimension `ldc`), layout `order`, via `aoclsparse_?csrmm`. No length checks
    /// are performed here.
    #[allow(clippy::too_many_arguments)]
    fn csrmm(
        op: Trans,
        alpha: Self,
        a: sys::aoclsparse_matrix,
        descr: &MatDescr,
        order: Order,
        b: &[Self],
        n: usize,
        ldb: usize,
        beta: Self,
        c: &mut [Self],
        ldc: usize,
    ) -> Result<()>;
    /// Sparse × sparse product written into the dense `c` (`layout`, leading dimension
    /// `ldc`) via `aoclsparse_?spmmd`. No length checks are performed here.
    #[allow(clippy::too_many_arguments)]
    fn spmmd(
        op: Trans,
        a: sys::aoclsparse_matrix,
        b: sys::aoclsparse_matrix,
        layout: Order,
        c: &mut [Self],
        ldc: usize,
    ) -> Result<()>;
    /// Scaled sparse × sparse product into the dense `c` via `aoclsparse_?sp2md`.
    ///
    /// Presumably computes `c = alpha * op_a(A) * op_b(B) + beta * c` (TODO confirm). No
    /// length checks are performed here.
    #[allow(clippy::too_many_arguments)]
    fn sp2md(
        op_a: Trans,
        descr_a: &MatDescr,
        a: sys::aoclsparse_matrix,
        op_b: Trans,
        descr_b: &MatDescr,
        b: sys::aoclsparse_matrix,
        alpha: Self,
        beta: Self,
        c: &mut [Self],
        layout: Order,
        ldc: usize,
    ) -> Result<()>;
    /// Unchecked pass-through to `aoclsparse_?add` (sparse matrix sum into a new handle).
    ///
    /// # Safety
    /// `a` and `b` must be valid matrix handles and `out` must be valid for writes of one
    /// `aoclsparse_matrix`.
    unsafe fn add_ffi(
        op: sys::aoclsparse_operation,
        a: sys::aoclsparse_matrix,
        alpha: Self,
        b: sys::aoclsparse_matrix,
        out: *mut sys::aoclsparse_matrix,
    ) -> sys::aoclsparse_status;
    /// SOR sweep of kind `sor_type` with relaxation `omega` via `aoclsparse_?sorv`, updating
    /// `x` from right-hand side `b`. No length checks are performed here.
    #[allow(clippy::too_many_arguments)]
    fn sorv(
        sor_type: SorType,
        descr: &MatDescr,
        a: sys::aoclsparse_matrix,
        omega: Self,
        alpha: Self,
        x: &mut [Self],
        b: &[Self],
    ) -> Result<()>;
}
/// Sweep variant for the SOR smoother ([`Scalar::sorv`], `SparseMatrix::set_sorv_hint`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SorType {
    /// Maps to `aoclsparse_sor_forward`.
    Forward,
    /// Maps to `aoclsparse_sor_backward`.
    Backward,
    /// Maps to `aoclsparse_sor_symmetric`.
    Symmetric,
}
impl SorType {
    /// Returns the matching `aoclsparse_sor_type` constant.
    pub(crate) fn raw(self) -> sys::aoclsparse_sor_type {
        match self {
            Self::Forward => sys::aoclsparse_sor_type__aoclsparse_sor_forward,
            Self::Backward => sys::aoclsparse_sor_type__aoclsparse_sor_backward,
            Self::Symmetric => sys::aoclsparse_sor_type__aoclsparse_sor_symmetric,
        }
    }
}
macro_rules! impl_scalar {
(
$t:ty,
csrmv = $csrmv:ident,
axpyi = $axpyi:ident,
gthr = $gthr:ident,
sctr = $sctr:ident,
csrsv = $csrsv:ident,
csr2dense = $csr2dense:ident,
csr2csc = $csr2csc:ident,
ellmv = $ellmv:ident,
bsrmv = $bsrmv:ident,
create_csr = $create_csr:ident,
export_csr = $export_csr:ident,
ilu_smoother = $ilu_smoother:ident,
itsol_init = $itsol_init:ident,
itsol_solve = $itsol_solve:ident,
csr2m = $csr2m:ident,
csrmm = $csrmm:ident,
spmmd = $spmmd:ident,
sp2md = $sp2md:ident,
add = $add:ident,
sorv = $sorv:ident
) => {
impl Scalar for $t {
fn csrmv(
op: Trans,
alpha: Self,
m: usize,
n: usize,
csr_val: &[Self],
csr_col_ind: &[sys::aoclsparse_int],
csr_row_ptr: &[sys::aoclsparse_int],
descr: &MatDescr,
x: &[Self],
beta: Self,
y: &mut [Self],
) -> Result<()> {
if csr_row_ptr.len() != m + 1 {
return Err(Error::InvalidArgument(format!(
"csrmv: csr_row_ptr length {} != m+1 = {}",
csr_row_ptr.len(),
m + 1
)));
}
let nnz = csr_val.len();
if csr_col_ind.len() != nnz {
return Err(Error::InvalidArgument(format!(
"csrmv: csr_col_ind length {} != csr_val length {}",
csr_col_ind.len(),
nnz
)));
}
let (x_len, y_len) = match op {
Trans::No => (n, m),
Trans::T | Trans::C => (m, n),
};
if x.len() < x_len {
return Err(Error::InvalidArgument(format!(
"csrmv: x length {} < expected {x_len}",
x.len()
)));
}
if y.len() < y_len {
return Err(Error::InvalidArgument(format!(
"csrmv: y length {} < expected {y_len}",
y.len()
)));
}
let status = unsafe {
sys::$csrmv(
trans_raw(op),
&alpha,
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
csr_val.as_ptr(),
csr_col_ind.as_ptr(),
csr_row_ptr.as_ptr(),
descr.as_raw(),
x.as_ptr(),
&beta,
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
fn axpyi(
alpha: Self,
x: &[Self],
indx: &[sys::aoclsparse_int],
y: &mut [Self],
) -> Result<()> {
let status = unsafe {
sys::$axpyi(
x.len() as sys::aoclsparse_int,
alpha,
x.as_ptr(),
indx.as_ptr(),
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
fn gthr(y: &[Self], indx: &[sys::aoclsparse_int], x: &mut [Self]) -> Result<()> {
let status = unsafe {
sys::$gthr(
x.len() as sys::aoclsparse_int,
y.as_ptr(),
x.as_mut_ptr(),
indx.as_ptr(),
)
};
check_status("sparse", status)
}
fn sctr(x: &[Self], indx: &[sys::aoclsparse_int], y: &mut [Self]) -> Result<()> {
let status = unsafe {
sys::$sctr(
x.len() as sys::aoclsparse_int,
x.as_ptr(),
indx.as_ptr(),
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn csrsv(
op: Trans,
alpha: Self,
m: usize,
csr_val: &[Self],
csr_col_ind: &[sys::aoclsparse_int],
csr_row_ptr: &[sys::aoclsparse_int],
descr: &MatDescr,
x: &[Self],
y: &mut [Self],
) -> Result<()> {
if csr_row_ptr.len() != m + 1 {
return Err(Error::InvalidArgument(format!(
"csrsv: csr_row_ptr length {} != m+1 = {}",
csr_row_ptr.len(),
m + 1
)));
}
if x.len() < m || y.len() < m {
return Err(Error::InvalidArgument(format!(
"csrsv: x.len()={}, y.len()={}, m={m}",
x.len(),
y.len()
)));
}
let status = unsafe {
sys::$csrsv(
trans_raw(op),
&alpha,
m as sys::aoclsparse_int,
csr_val.as_ptr(),
csr_col_ind.as_ptr(),
csr_row_ptr.as_ptr(),
descr.as_raw(),
x.as_ptr(),
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn csr_to_dense(
m: usize,
n: usize,
descr: &MatDescr,
csr_val: &[Self],
csr_row_ptr: &[sys::aoclsparse_int],
csr_col_ind: &[sys::aoclsparse_int],
a: &mut [Self],
ld: usize,
order: Order,
) -> Result<()> {
let status = unsafe {
sys::$csr2dense(
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
descr.as_raw(),
csr_val.as_ptr(),
csr_row_ptr.as_ptr(),
csr_col_ind.as_ptr(),
a.as_mut_ptr(),
ld as sys::aoclsparse_int,
order.raw(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn csr_to_csc(
m: usize,
n: usize,
descr: &MatDescr,
base_csc: IndexBase,
csr_row_ptr: &[sys::aoclsparse_int],
csr_col_ind: &[sys::aoclsparse_int],
csr_val: &[Self],
csc_row_ind: &mut [sys::aoclsparse_int],
csc_col_ptr: &mut [sys::aoclsparse_int],
csc_val: &mut [Self],
) -> Result<()> {
let nnz = csr_val.len();
let status = unsafe {
sys::$csr2csc(
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
descr.as_raw(),
base_csc.raw(),
csr_row_ptr.as_ptr(),
csr_col_ind.as_ptr(),
csr_val.as_ptr(),
csc_row_ind.as_mut_ptr(),
csc_col_ptr.as_mut_ptr(),
csc_val.as_mut_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn ellmv(
op: Trans,
alpha: Self,
m: usize,
n: usize,
ell_val: &[Self],
ell_col_ind: &[sys::aoclsparse_int],
ell_width: usize,
descr: &MatDescr,
x: &[Self],
beta: Self,
y: &mut [Self],
) -> Result<()> {
let nnz = ell_val.len();
if ell_col_ind.len() != nnz {
return Err(Error::InvalidArgument(format!(
"ellmv: ell_col_ind length {} != ell_val length {nnz}",
ell_col_ind.len()
)));
}
let needed = m.checked_mul(ell_width).ok_or_else(|| {
Error::InvalidArgument("ellmv: m * ell_width overflows".into())
})?;
if nnz < needed {
return Err(Error::InvalidArgument(format!(
"ellmv: ell_val length {nnz} < m*ell_width = {needed}"
)));
}
let (x_len, y_len) = match op {
Trans::No => (n, m),
Trans::T | Trans::C => (m, n),
};
if x.len() < x_len || y.len() < y_len {
return Err(Error::InvalidArgument(format!(
"ellmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
x.len(),
y.len()
)));
}
let status = unsafe {
sys::$ellmv(
trans_raw(op),
&alpha,
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
ell_val.as_ptr(),
ell_col_ind.as_ptr(),
ell_width as sys::aoclsparse_int,
descr.as_raw(),
x.as_ptr(),
&beta,
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn bsrmv(
op: Trans,
alpha: Self,
mb: usize,
nb: usize,
bsr_dim: usize,
bsr_val: &[Self],
bsr_col_ind: &[sys::aoclsparse_int],
bsr_row_ptr: &[sys::aoclsparse_int],
descr: &MatDescr,
x: &[Self],
beta: Self,
y: &mut [Self],
) -> Result<()> {
if bsr_row_ptr.len() != mb + 1 {
return Err(Error::InvalidArgument(format!(
"bsrmv: bsr_row_ptr length {} != mb+1 = {}",
bsr_row_ptr.len(),
mb + 1
)));
}
let block_area = bsr_dim.checked_mul(bsr_dim).ok_or_else(|| {
Error::InvalidArgument("bsrmv: bsr_dim*bsr_dim overflows".into())
})?;
let nnzb = bsr_col_ind.len();
if bsr_val.len() < nnzb * block_area {
return Err(Error::InvalidArgument(format!(
"bsrmv: bsr_val length {} < nnzb*bsr_dim^2 = {}",
bsr_val.len(),
nnzb * block_area
)));
}
let (x_len, y_len) = match op {
Trans::No => (nb * bsr_dim, mb * bsr_dim),
Trans::T | Trans::C => (mb * bsr_dim, nb * bsr_dim),
};
if x.len() < x_len || y.len() < y_len {
return Err(Error::InvalidArgument(format!(
"bsrmv: x.len()={}, y.len()={}, expected ({x_len}, {y_len})",
x.len(),
y.len()
)));
}
let status = unsafe {
sys::$bsrmv(
trans_raw(op),
&alpha,
mb as sys::aoclsparse_int,
nb as sys::aoclsparse_int,
bsr_dim as sys::aoclsparse_int,
bsr_val.as_ptr(),
bsr_col_ind.as_ptr(),
bsr_row_ptr.as_ptr(),
descr.as_raw(),
x.as_ptr(),
&beta,
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
fn create_csr(
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
row_ptr: *mut sys::aoclsparse_int,
col_idx: *mut sys::aoclsparse_int,
val: *mut Self,
) -> Result<sys::aoclsparse_matrix> {
let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
let status = unsafe {
sys::$create_csr(
&mut raw,
base.raw(),
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
row_ptr,
col_idx,
val,
)
};
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(raw)
}
fn export_csr(
mat: sys::aoclsparse_matrix,
) -> Result<(
IndexBase,
usize,
usize,
usize,
*mut sys::aoclsparse_int,
*mut sys::aoclsparse_int,
*mut Self,
)> {
let mut base: sys::aoclsparse_index_base = 0;
let mut m: sys::aoclsparse_int = 0;
let mut n: sys::aoclsparse_int = 0;
let mut nnz: sys::aoclsparse_int = 0;
let mut row_ptr: *mut sys::aoclsparse_int = std::ptr::null_mut();
let mut col_ind: *mut sys::aoclsparse_int = std::ptr::null_mut();
let mut val: *mut Self = std::ptr::null_mut();
let status = unsafe {
sys::$export_csr(
mat,
&mut base,
&mut m,
&mut n,
&mut nnz,
&mut row_ptr,
&mut col_ind,
&mut val,
)
};
check_status("sparse", status)?;
let base_e = if base == sys::aoclsparse_index_base__aoclsparse_index_base_one {
IndexBase::One
} else {
IndexBase::Zero
};
Ok((
base_e,
m as usize,
n as usize,
nnz as usize,
row_ptr,
col_ind,
val,
))
}
fn ilu_smoother(
op: Trans,
a: sys::aoclsparse_matrix,
descr: &MatDescr,
x: &mut [Self],
b: &[Self],
) -> Result<()> {
let mut precond_csr_val: *mut Self = std::ptr::null_mut();
let status = unsafe {
sys::$ilu_smoother(
trans_raw(op),
a,
descr.as_raw(),
&mut precond_csr_val,
std::ptr::null(),
x.as_mut_ptr(),
b.as_ptr(),
)
};
check_status("sparse", status)
}
fn itsol_init(handle: &mut sys::aoclsparse_itsol_handle) -> Result<()> {
let status = unsafe { sys::$itsol_init(handle) };
check_status("sparse", status)
}
fn itsol_solve(
handle: sys::aoclsparse_itsol_handle,
n: usize,
mat: sys::aoclsparse_matrix,
descr: &MatDescr,
b: &[Self],
x: &mut [Self],
rinfo: &mut [Self; 100],
) -> Result<()> {
if b.len() < n || x.len() < n {
return Err(Error::InvalidArgument(format!(
"itsol_solve: b.len()={}, x.len()={}, n={n}",
b.len(),
x.len()
)));
}
let status = unsafe {
sys::$itsol_solve(
handle,
n as sys::aoclsparse_int,
mat,
descr.as_raw(),
b.as_ptr(),
x.as_mut_ptr(),
rinfo.as_mut_ptr(),
None,
None,
std::ptr::null_mut(),
)
};
check_status("sparse", status)
}
unsafe fn csr2m_ffi(
op_a: sys::aoclsparse_operation,
descr_a: sys::aoclsparse_mat_descr,
a: sys::aoclsparse_matrix,
op_b: sys::aoclsparse_operation,
descr_b: sys::aoclsparse_mat_descr,
b: sys::aoclsparse_matrix,
request: sys::aoclsparse_request,
out: *mut sys::aoclsparse_matrix,
) -> sys::aoclsparse_status {
sys::$csr2m(op_a, descr_a, a, op_b, descr_b, b, request, out)
}
#[allow(clippy::too_many_arguments)]
fn csrmm(
op: Trans,
alpha: Self,
a: sys::aoclsparse_matrix,
descr: &MatDescr,
order: Order,
b: &[Self],
n: usize,
ldb: usize,
beta: Self,
c: &mut [Self],
ldc: usize,
) -> Result<()> {
let status = unsafe {
sys::$csrmm(
trans_raw(op),
alpha,
a,
descr.as_raw(),
order.raw(),
b.as_ptr(),
n as sys::aoclsparse_int,
ldb as sys::aoclsparse_int,
beta,
c.as_mut_ptr(),
ldc as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
fn spmmd(
op: Trans,
a: sys::aoclsparse_matrix,
b: sys::aoclsparse_matrix,
layout: Order,
c: &mut [Self],
ldc: usize,
) -> Result<()> {
let status = unsafe {
sys::$spmmd(
trans_raw(op),
a,
b,
layout.raw(),
c.as_mut_ptr(),
ldc as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
fn sp2md(
op_a: Trans,
descr_a: &MatDescr,
a: sys::aoclsparse_matrix,
op_b: Trans,
descr_b: &MatDescr,
b: sys::aoclsparse_matrix,
alpha: Self,
beta: Self,
c: &mut [Self],
layout: Order,
ldc: usize,
) -> Result<()> {
let status = unsafe {
sys::$sp2md(
trans_raw(op_a),
descr_a.as_raw(),
a,
trans_raw(op_b),
descr_b.as_raw(),
b,
alpha,
beta,
c.as_mut_ptr(),
layout.raw(),
ldc as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
unsafe fn add_ffi(
op: sys::aoclsparse_operation,
a: sys::aoclsparse_matrix,
alpha: Self,
b: sys::aoclsparse_matrix,
out: *mut sys::aoclsparse_matrix,
) -> sys::aoclsparse_status {
sys::$add(op, a, alpha, b, out)
}
#[allow(clippy::too_many_arguments)]
fn sorv(
sor_type: SorType,
descr: &MatDescr,
a: sys::aoclsparse_matrix,
omega: Self,
alpha: Self,
x: &mut [Self],
b: &[Self],
) -> Result<()> {
let status = unsafe {
sys::$sorv(
sor_type.raw(),
descr.as_raw(),
a,
omega,
alpha,
x.as_mut_ptr(),
b.as_ptr(),
)
};
check_status("sparse", status)
}
}
};
}
// `Scalar` for `f32`: every method binds to the single-precision C routine
// (`aoclsparse_s*`, `aoclsparse_create_scsr`, `aoclsparse_itsol_s_*`, ...).
impl_scalar!(
    f32,
    csrmv = aoclsparse_scsrmv,
    axpyi = aoclsparse_saxpyi,
    gthr = aoclsparse_sgthr,
    sctr = aoclsparse_ssctr,
    csrsv = aoclsparse_scsrsv,
    csr2dense = aoclsparse_scsr2dense,
    csr2csc = aoclsparse_scsr2csc,
    ellmv = aoclsparse_sellmv,
    bsrmv = aoclsparse_sbsrmv,
    create_csr = aoclsparse_create_scsr,
    export_csr = aoclsparse_export_scsr,
    ilu_smoother = aoclsparse_silu_smoother,
    itsol_init = aoclsparse_itsol_s_init,
    itsol_solve = aoclsparse_itsol_s_solve,
    csr2m = aoclsparse_scsr2m,
    csrmm = aoclsparse_scsrmm,
    spmmd = aoclsparse_sspmmd,
    sp2md = aoclsparse_ssp2md,
    add = aoclsparse_sadd,
    sorv = aoclsparse_ssorv
);
// `Scalar` for `f64`: every method binds to the double-precision C routine
// (`aoclsparse_d*`, `aoclsparse_create_dcsr`, `aoclsparse_itsol_d_*`, ...).
impl_scalar!(
    f64,
    csrmv = aoclsparse_dcsrmv,
    axpyi = aoclsparse_daxpyi,
    gthr = aoclsparse_dgthr,
    sctr = aoclsparse_dsctr,
    csrsv = aoclsparse_dcsrsv,
    csr2dense = aoclsparse_dcsr2dense,
    csr2csc = aoclsparse_dcsr2csc,
    ellmv = aoclsparse_dellmv,
    bsrmv = aoclsparse_dbsrmv,
    create_csr = aoclsparse_create_dcsr,
    export_csr = aoclsparse_export_dcsr,
    ilu_smoother = aoclsparse_dilu_smoother,
    itsol_init = aoclsparse_itsol_d_init,
    itsol_solve = aoclsparse_itsol_d_solve,
    csr2m = aoclsparse_dcsr2m,
    csrmm = aoclsparse_dcsrmm,
    spmmd = aoclsparse_dspmmd,
    sp2md = aoclsparse_dsp2md,
    add = aoclsparse_dadd,
    sorv = aoclsparse_dsorv
);
/// Non-transposed CSR sparse matrix × dense vector product.
///
/// Convenience front end to [`Scalar::csrmv`] with the operation fixed to [`Trans::No`].
/// Shape validation (row-pointer length, nnz agreement, `x`/`y` lengths) happens in the
/// per-type implementation.
#[allow(clippy::too_many_arguments)]
pub fn csrmv<T: Scalar>(
    alpha: T,
    m: usize,
    n: usize,
    csr_val: &[T],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_row_ptr: &[sys::aoclsparse_int],
    descr: &MatDescr,
    x: &[T],
    beta: T,
    y: &mut [T],
) -> Result<()> {
    let op = Trans::No;
    <T as Scalar>::csrmv(
        op,
        alpha,
        m,
        n,
        csr_val,
        csr_col_ind,
        csr_row_ptr,
        descr,
        x,
        beta,
        y,
    )
}
pub fn axpyi<T: Scalar>(
alpha: T,
x: &[T],
indx: &[sys::aoclsparse_int],
y: &mut [T],
) -> Result<()> {
if x.len() != indx.len() {
return Err(Error::InvalidArgument(format!(
"axpyi: x.len()={}, indx.len()={}",
x.len(),
indx.len()
)));
}
T::axpyi(alpha, x, indx, y)
}
pub fn gthr<T: Scalar>(y: &[T], indx: &[sys::aoclsparse_int], x: &mut [T]) -> Result<()> {
if x.len() != indx.len() {
return Err(Error::InvalidArgument(format!(
"gthr: x.len()={}, indx.len()={}",
x.len(),
indx.len()
)));
}
T::gthr(y, indx, x)
}
pub fn sctr<T: Scalar>(x: &[T], indx: &[sys::aoclsparse_int], y: &mut [T]) -> Result<()> {
if x.len() != indx.len() {
return Err(Error::InvalidArgument(format!(
"sctr: x.len()={}, indx.len()={}",
x.len(),
indx.len()
)));
}
T::sctr(x, indx, y)
}
/// CSR sparse triangular solve (presumably `op(A) * y = alpha * x`; see the AOCL-Sparse
/// `?csrsv` reference).
///
/// Forwards to [`Scalar::csrsv`], which rejects a `csr_row_ptr` that is not `m + 1` long
/// and `x`/`y` shorter than `m` with `Error::InvalidArgument`.
#[allow(clippy::too_many_arguments)]
pub fn csrsv<T: Scalar>(
    op: Trans,
    alpha: T,
    m: usize,
    csr_val: &[T],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_row_ptr: &[sys::aoclsparse_int],
    descr: &MatDescr,
    x: &[T],
    y: &mut [T],
) -> Result<()> {
    <T as Scalar>::csrsv(
        op,
        alpha,
        m,
        csr_val,
        csr_col_ind,
        csr_row_ptr,
        descr,
        x,
        y,
    )
}
/// Memory layout of a dense operand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Order {
    /// Maps to `aoclsparse_order_row`.
    RowMajor,
    /// Maps to `aoclsparse_order_column`.
    ColMajor,
}
impl Order {
    /// Returns the matching `aoclsparse_order` constant.
    pub(crate) fn raw(self) -> sys::aoclsparse_order {
        match self {
            Self::RowMajor => sys::aoclsparse_order__aoclsparse_order_row,
            Self::ColMajor => sys::aoclsparse_order__aoclsparse_order_column,
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn csr_to_dense<T: Scalar>(
m: usize,
n: usize,
descr: &MatDescr,
csr_val: &[T],
csr_row_ptr: &[sys::aoclsparse_int],
csr_col_ind: &[sys::aoclsparse_int],
a: &mut [T],
ld: usize,
order: Order,
) -> Result<()> {
if csr_row_ptr.len() != m + 1 {
return Err(Error::InvalidArgument(format!(
"csr_to_dense: csr_row_ptr length {} != m+1 = {}",
csr_row_ptr.len(),
m + 1
)));
}
let needed = match order {
Order::RowMajor => m.saturating_sub(1) * ld + n,
Order::ColMajor => n.saturating_sub(1) * ld + m,
};
if a.len() < needed {
return Err(Error::InvalidArgument(format!(
"csr_to_dense: A length {} < needed {needed}",
a.len()
)));
}
T::csr_to_dense(m, n, descr, csr_val, csr_row_ptr, csr_col_ind, a, ld, order)
}
/// Converts an `m × n` CSR matrix to CSC form, writing indices with base `base_csc`.
///
/// `nnz` is `csr_val.len()`. `csr_col_ind` must match it exactly; `csc_row_ind` and
/// `csc_val` must hold at least `nnz` entries. `csr_row_ptr` must be `m + 1` long and
/// `csc_col_ptr` `n + 1` long.
///
/// # Errors
/// `Error::InvalidArgument` when any of the above fails; otherwise whatever the library
/// reports.
#[allow(clippy::too_many_arguments)]
pub fn csr_to_csc<T: Scalar>(
    m: usize,
    n: usize,
    descr: &MatDescr,
    base_csc: IndexBase,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_val: &[T],
    csc_row_ind: &mut [sys::aoclsparse_int],
    csc_col_ptr: &mut [sys::aoclsparse_int],
    csc_val: &mut [T],
) -> Result<()> {
    let nnz = csr_val.len();
    let nnz_consistent =
        csr_col_ind.len() == nnz && csc_row_ind.len() >= nnz && csc_val.len() >= nnz;
    if !nnz_consistent {
        return Err(Error::InvalidArgument(format!(
            "csr_to_csc: nnz mismatch (csr_val={}, csr_col_ind={}, csc_row_ind={}, csc_val={})",
            nnz,
            csr_col_ind.len(),
            csc_row_ind.len(),
            csc_val.len()
        )));
    }
    let pointers_consistent = csr_row_ptr.len() == m + 1 && csc_col_ptr.len() == n + 1;
    if !pointers_consistent {
        return Err(Error::InvalidArgument(format!(
            "csr_to_csc: row_ptr length {} != m+1 = {} or col_ptr length {} != n+1 = {}",
            csr_row_ptr.len(),
            m + 1,
            csc_col_ptr.len(),
            n + 1
        )));
    }
    <T as Scalar>::csr_to_csc(
        m,
        n,
        descr,
        base_csc,
        csr_row_ptr,
        csr_col_ind,
        csr_val,
        csc_row_ind,
        csc_col_ptr,
        csc_val,
    )
}
/// Whether sparse index arrays count from 0 or from 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IndexBase {
    /// Maps to `aoclsparse_index_base_zero`.
    Zero,
    /// Maps to `aoclsparse_index_base_one`.
    One,
}
impl IndexBase {
    /// Returns the matching `aoclsparse_index_base` constant.
    fn raw(self) -> sys::aoclsparse_index_base {
        match self {
            Self::Zero => sys::aoclsparse_index_base__aoclsparse_index_base_zero,
            Self::One => sys::aoclsparse_index_base__aoclsparse_index_base_one,
        }
    }
}
/// ELLPACK sparse matrix × dense vector product.
///
/// Forwards to [`Scalar::ellmv`]. That implementation requires `ell_val` and `ell_col_ind`
/// to be equal in length with at least `m * ell_width` entries, and `x`/`y` to be sized for
/// `op`; otherwise it returns `Error::InvalidArgument`.
#[allow(clippy::too_many_arguments)]
pub fn ellmv<T: Scalar>(
    op: Trans,
    alpha: T,
    m: usize,
    n: usize,
    ell_val: &[T],
    ell_col_ind: &[sys::aoclsparse_int],
    ell_width: usize,
    descr: &MatDescr,
    x: &[T],
    beta: T,
    y: &mut [T],
) -> Result<()> {
    <T as Scalar>::ellmv(
        op,
        alpha,
        m,
        n,
        ell_val,
        ell_col_ind,
        ell_width,
        descr,
        x,
        beta,
        y,
    )
}
/// Block-CSR sparse matrix × dense vector product with `bsr_dim × bsr_dim` blocks.
///
/// Forwards to [`Scalar::bsrmv`]. That implementation requires `bsr_row_ptr.len() == mb + 1`,
/// at least `bsr_col_ind.len() * bsr_dim²` values, and `x`/`y` sized in units of `bsr_dim`
/// for `op`.
#[allow(clippy::too_many_arguments)]
pub fn bsrmv<T: Scalar>(
    op: Trans,
    alpha: T,
    mb: usize,
    nb: usize,
    bsr_dim: usize,
    bsr_val: &[T],
    bsr_col_ind: &[sys::aoclsparse_int],
    bsr_row_ptr: &[sys::aoclsparse_int],
    descr: &MatDescr,
    x: &[T],
    beta: T,
    y: &mut [T],
) -> Result<()> {
    <T as Scalar>::bsrmv(
        op,
        alpha,
        mb,
        nb,
        bsr_dim,
        bsr_val,
        bsr_col_ind,
        bsr_row_ptr,
        descr,
        x,
        beta,
        y,
    )
}
/// Request stage passed to the staged sparse × sparse product (`csr2m`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Stage {
    /// Maps to `aoclsparse_stage_nnz_count`.
    NnzCount,
    /// Maps to `aoclsparse_stage_finalize`.
    Finalize,
    /// Maps to `aoclsparse_stage_full_computation`.
    FullComputation,
}
impl Stage {
    /// Returns the matching `aoclsparse_request` constant.
    fn raw(self) -> sys::aoclsparse_request {
        match self {
            Self::NnzCount => sys::aoclsparse_request__aoclsparse_stage_nnz_count,
            Self::Finalize => sys::aoclsparse_request__aoclsparse_stage_finalize,
            Self::FullComputation => sys::aoclsparse_request__aoclsparse_stage_full_computation,
        }
    }
}
/// Records who owns the CSR arrays behind a `SparseMatrix` handle.
enum CsrStorage<T: Scalar> {
    /// Built by `SparseMatrix::from_csr`: the arrays were copied into these `Vec`s and their
    /// pointers handed to the library. They are kept here so the allocations live as long
    /// as the `SparseMatrix` value.
    Owned {
        _row_ptr: Vec<sys::aoclsparse_int>,
        _col_ind: Vec<sys::aoclsparse_int>,
        _val: Vec<T>,
    },
    /// Built by `SparseMatrix::from_library_owned`: no Rust-side buffers are held.
    LibraryOwned,
}
pub struct SparseMatrix<T: Scalar> {
raw: sys::aoclsparse_matrix,
#[allow(dead_code)] storage: CsrStorage<T>,
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
}
impl<T: Scalar> SparseMatrix<T> {
pub fn from_csr(
base: IndexBase,
m: usize,
n: usize,
row_ptr: &[sys::aoclsparse_int],
col_ind: &[sys::aoclsparse_int],
val: &[T],
) -> Result<Self> {
if row_ptr.len() != m + 1 {
return Err(Error::InvalidArgument(format!(
"from_csr: row_ptr length {} != m+1 = {}",
row_ptr.len(),
m + 1
)));
}
let nnz = val.len();
if col_ind.len() != nnz {
return Err(Error::InvalidArgument(format!(
"from_csr: col_ind length {} != val length {nnz}",
col_ind.len()
)));
}
let mut row_ptr = row_ptr.to_vec();
let mut col_ind = col_ind.to_vec();
let mut val = val.to_vec();
let raw = T::create_csr(
base,
m,
n,
nnz,
row_ptr.as_mut_ptr(),
col_ind.as_mut_ptr(),
val.as_mut_ptr(),
)?;
Ok(Self {
raw,
storage: CsrStorage::Owned {
_row_ptr: row_ptr,
_col_ind: col_ind,
_val: val,
},
base,
m,
n,
nnz,
})
}
pub unsafe fn from_library_owned(raw: sys::aoclsparse_matrix) -> Result<Self> {
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
let (base, m, n, nnz, _, _, _) = T::export_csr(raw)?;
Ok(Self {
raw,
storage: CsrStorage::LibraryOwned,
base,
m,
n,
nnz,
})
}
pub fn dims(&self) -> (usize, usize) {
(self.m, self.n)
}
pub fn nnz(&self) -> usize {
self.nnz
}
pub fn base(&self) -> IndexBase {
self.base
}
pub fn as_raw(&self) -> sys::aoclsparse_matrix {
self.raw
}
pub fn export_csr(
&self,
) -> Result<(
IndexBase,
Vec<sys::aoclsparse_int>,
Vec<sys::aoclsparse_int>,
Vec<T>,
)> {
let (base, m, _, nnz, row_ptr, col_ind, val) = T::export_csr(self.raw)?;
let row_ptr = unsafe { std::slice::from_raw_parts(row_ptr, m + 1).to_vec() };
let col_ind = unsafe { std::slice::from_raw_parts(col_ind, nnz).to_vec() };
let val = unsafe { std::slice::from_raw_parts(val, nnz).to_vec() };
Ok((base, row_ptr, col_ind, val))
}
pub fn set_mv_hint(
&mut self,
op: Trans,
descr: &MatDescr,
expected_calls: usize,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_set_mv_hint(
self.raw,
trans_raw(op),
descr.as_raw(),
expected_calls as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
pub fn set_sv_hint(
&mut self,
op: Trans,
descr: &MatDescr,
expected_calls: usize,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_set_sv_hint(
self.raw,
trans_raw(op),
descr.as_raw(),
expected_calls as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
pub fn set_mm_hint(
&mut self,
op: Trans,
descr: &MatDescr,
expected_calls: usize,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_set_mm_hint(
self.raw,
trans_raw(op),
descr.as_raw(),
expected_calls as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
pub fn set_2m_hint(
&mut self,
op: Trans,
descr: &MatDescr,
expected_calls: usize,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_set_2m_hint(
self.raw,
trans_raw(op),
descr.as_raw(),
expected_calls as sys::aoclsparse_int,
)
};
check_status("sparse", status)
}
/// Registers a hint that `expected_calls` triangular solves with multiple
/// right-hand sides (laid out per `order`) are to be expected.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_sm_hint`.
pub fn set_sm_hint(
    &mut self,
    op: Trans,
    descr: &MatDescr,
    order: Order,
    expected_calls: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let layout = order.raw();
    let calls = expected_calls as sys::aoclsparse_int;
    // SAFETY: `self.raw` is this matrix's own handle. `descr` is a live descriptor.
    let rc = unsafe {
        sys::aoclsparse_set_sm_hint(self.raw, op_raw, descr.as_raw(), layout, calls)
    };
    check_status("sparse", rc)
}
/// Registers a hint that `expected_calls` LU-smoother applications with
/// `op(A)` under `descr` are to be expected.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_lu_smoother_hint`.
pub fn set_lu_smoother_hint(
    &mut self,
    op: Trans,
    descr: &MatDescr,
    expected_calls: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let calls = expected_calls as sys::aoclsparse_int;
    // SAFETY: `self.raw` is this matrix's own handle. `descr` is a live descriptor.
    let rc = unsafe {
        sys::aoclsparse_set_lu_smoother_hint(self.raw, op_raw, descr.as_raw(), calls)
    };
    check_status("sparse", rc)
}
/// Registers a hint that `expected_calls` symmetric Gauss–Seidel sweeps with
/// `op(A)` under `descr` are to be expected.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_symgs_hint`.
pub fn set_symgs_hint(
    &mut self,
    op: Trans,
    descr: &MatDescr,
    expected_calls: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let calls = expected_calls as sys::aoclsparse_int;
    // SAFETY: `self.raw` is this matrix's own handle. `descr` is a live descriptor.
    let rc =
        unsafe { sys::aoclsparse_set_symgs_hint(self.raw, op_raw, descr.as_raw(), calls) };
    check_status("sparse", rc)
}
/// Registers a hint that `expected_calls` fused matrix–vector/dot products with
/// `op(A)` under `descr` are to be expected.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_dotmv_hint`.
pub fn set_dotmv_hint(
    &mut self,
    op: Trans,
    descr: &MatDescr,
    expected_calls: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let calls = expected_calls as sys::aoclsparse_int;
    // SAFETY: `self.raw` is this matrix's own handle. `descr` is a live descriptor.
    let rc =
        unsafe { sys::aoclsparse_set_dotmv_hint(self.raw, op_raw, descr.as_raw(), calls) };
    check_status("sparse", rc)
}
/// Registers a hint that `expected_calls` SOR sweeps of kind `sor_type` under
/// `descr` are to be expected.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_sorv_hint`.
pub fn set_sorv_hint(
    &mut self,
    sor_type: SorType,
    descr: &MatDescr,
    expected_calls: usize,
) -> Result<()> {
    let sweep = sor_type.raw();
    let calls = expected_calls as sys::aoclsparse_int;
    // SAFETY: `self.raw` is this matrix's own handle. `descr` is a live descriptor.
    // Note that the FFI takes the descriptor before the SOR kind.
    let rc = unsafe { sys::aoclsparse_set_sorv_hint(self.raw, descr.as_raw(), sweep, calls) };
    check_status("sparse", rc)
}
}
impl<T: Scalar> Drop for SparseMatrix<T> {
    /// Destroys the library handle if one is still held, then nulls it.
    fn drop(&mut self) {
        if self.raw.is_null() {
            return;
        }
        // SAFETY: the handle is non-null and owned by this value. It is nulled
        // right after, so it is never destroyed twice. The destroy status is
        // deliberately discarded because `drop` cannot report errors.
        unsafe {
            let _ = sys::aoclsparse_destroy(&mut self.raw);
        }
        self.raw = std::ptr::null_mut();
    }
}
impl<T: Scalar> std::fmt::Debug for SparseMatrix<T> {
    /// Shows shape, stored-entry count and index base. The raw handle and
    /// storage details are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { m, n, nnz, base, .. } = self;
        let mut out = f.debug_struct("SparseMatrix");
        out.field("m", m);
        out.field("n", n);
        out.field("nnz", nnz);
        out.field("base", base);
        out.finish()
    }
}
/// Computes the sparse product `C = op_a(A) * op_b(B)` for the requested `stage`.
///
/// The resulting handle is adopted by the returned [`SparseMatrix`].
///
/// # Errors
/// Propagates a non-success library status, or any failure while wrapping the
/// result handle.
#[allow(clippy::too_many_arguments)]
pub fn csr2m<T: Scalar>(
    op_a: Trans,
    descr_a: &MatDescr,
    a: &SparseMatrix<T>,
    op_b: Trans,
    descr_b: &MatDescr,
    b: &SparseMatrix<T>,
    stage: Stage,
) -> Result<SparseMatrix<T>> {
    let (lhs_op, rhs_op) = (trans_raw(op_a), trans_raw(op_b));
    let mut product: sys::aoclsparse_matrix = std::ptr::null_mut();
    // SAFETY: both operand handles and descriptors are borrowed for the call.
    // `product` is a valid out-pointer.
    let rc = unsafe {
        T::csr2m_ffi(
            lhs_op,
            descr_a.as_raw(),
            a.raw,
            rhs_op,
            descr_b.as_raw(),
            b.raw,
            stage.raw(),
            &mut product,
        )
    };
    check_status("sparse", rc)?;
    unsafe { SparseMatrix::from_library_owned(product) }
}
/// Sparse×dense product `C := alpha * op(A) * B + beta * C`.
///
/// Dispatches to the scalar-specific implementation with `A`'s raw handle.
#[allow(clippy::too_many_arguments)]
pub fn csrmm<T: Scalar>(
    op: Trans,
    alpha: T,
    a: &SparseMatrix<T>,
    descr: &MatDescr,
    order: Order,
    b: &[T],
    n: usize,
    ldb: usize,
    beta: T,
    c: &mut [T],
    ldc: usize,
) -> Result<()> {
    let handle = a.as_raw();
    T::csrmm(op, alpha, handle, descr, order, b, n, ldb, beta, c, ldc)
}
/// Sparse×sparse product written into the dense buffer `c` (leading dimension `ldc`).
///
/// Dispatches to the scalar-specific implementation with both raw handles.
pub fn spmmd<T: Scalar>(
    op: Trans,
    a: &SparseMatrix<T>,
    b: &SparseMatrix<T>,
    layout: Order,
    c: &mut [T],
    ldc: usize,
) -> Result<()> {
    let (lhs, rhs) = (a.as_raw(), b.as_raw());
    T::spmmd(op, lhs, rhs, layout, c, ldc)
}
/// Sparse×sparse product with scaling, written into the dense buffer `c`.
///
/// Dispatches to the scalar-specific implementation with both raw handles.
#[allow(clippy::too_many_arguments)]
pub fn sp2md<T: Scalar>(
    op_a: Trans,
    descr_a: &MatDescr,
    a: &SparseMatrix<T>,
    op_b: Trans,
    descr_b: &MatDescr,
    b: &SparseMatrix<T>,
    alpha: T,
    beta: T,
    c: &mut [T],
    layout: Order,
    ldc: usize,
) -> Result<()> {
    let lhs = a.as_raw();
    let rhs = b.as_raw();
    T::sp2md(
        op_a, descr_a, lhs, op_b, descr_b, rhs, alpha, beta, c, layout, ldc,
    )
}
/// Computes the sparse sum `C = alpha * op(A) + B` into a newly owned matrix.
///
/// # Errors
/// Propagates a non-success library status, or any failure while wrapping the
/// result handle.
pub fn add<T: Scalar>(
    op: Trans,
    a: &SparseMatrix<T>,
    alpha: T,
    b: &SparseMatrix<T>,
) -> Result<SparseMatrix<T>> {
    let op_raw = trans_raw(op);
    let mut sum: sys::aoclsparse_matrix = std::ptr::null_mut();
    // SAFETY: both operand handles are borrowed for the call. `sum` is a valid
    // out-pointer.
    let rc = unsafe { T::add_ffi(op_raw, a.as_raw(), alpha, b.as_raw(), &mut sum) };
    check_status("sparse", rc)?;
    unsafe { SparseMatrix::from_library_owned(sum) }
}
/// Performs one SOR sweep of kind `sor_type` on `x` for the system `A x = b`.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` is shorter than `A`'s column
/// count or `b` is shorter than its row count. Otherwise propagates the
/// scalar-specific implementation's result.
pub fn sorv<T: Scalar>(
    sor_type: SorType,
    descr: &MatDescr,
    a: &SparseMatrix<T>,
    omega: T,
    alpha: T,
    x: &mut [T],
    b: &[T],
) -> Result<()> {
    let (rows, cols) = a.dims();
    let lengths_ok = x.len() >= cols && b.len() >= rows;
    if !lengths_ok {
        return Err(Error::InvalidArgument(format!(
            "sorv: x.len()={}, b.len()={}, dims=({}, {})",
            x.len(),
            b.len(),
            rows,
            cols
        )));
    }
    T::sorv(sor_type, descr, a.as_raw(), omega, alpha, x, b)
}
#[allow(clippy::too_many_arguments)]
pub fn mv_f64(
op: Trans,
alpha: f64,
a: &SparseMatrix<f64>,
descr: &MatDescr,
x: &[f64],
beta: f64,
y: &mut [f64],
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_dmv(
trans_raw(op),
&alpha,
a.as_raw(),
descr.as_raw(),
x.as_ptr(),
&beta,
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
pub fn mv_f32(
op: Trans,
alpha: f32,
a: &SparseMatrix<f32>,
descr: &MatDescr,
x: &[f32],
beta: f32,
y: &mut [f32],
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_smv(
trans_raw(op),
&alpha,
a.as_raw(),
descr.as_raw(),
x.as_ptr(),
&beta,
y.as_mut_ptr(),
)
};
check_status("sparse", status)
}
pub fn trsv_f64(
op: Trans,
alpha: f64,
a: &SparseMatrix<f64>,
descr: &MatDescr,
b: &[f64],
x: &mut [f64],
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_dtrsv(
trans_raw(op),
alpha,
a.as_raw(),
descr.as_raw(),
b.as_ptr(),
x.as_mut_ptr(),
)
};
check_status("sparse", status)
}
pub fn trsv_f32(
op: Trans,
alpha: f32,
a: &SparseMatrix<f32>,
descr: &MatDescr,
b: &[f32],
x: &mut [f32],
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_strsv(
trans_raw(op),
alpha,
a.as_raw(),
descr.as_raw(),
b.as_ptr(),
x.as_mut_ptr(),
)
};
check_status("sparse", status)
}
/// Sparse triangular solve with `n_rhs` dense right-hand sides (f64).
///
/// `b` and `x` are dense blocks laid out per `order` with leading dimensions
/// `ldb` and `ldx`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dtrsm`.
#[allow(clippy::too_many_arguments)]
pub fn trsm_f64(
    op: Trans,
    alpha: f64,
    a: &SparseMatrix<f64>,
    descr: &MatDescr,
    order: Order,
    b: &[f64],
    n_rhs: usize,
    ldb: usize,
    x: &mut [f64],
    ldx: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let layout = order.raw();
    let rhs_count = n_rhs as sys::aoclsparse_int;
    let (ld_b, ld_x) = (ldb as sys::aoclsparse_int, ldx as sys::aoclsparse_int);
    // SAFETY: the handles are live borrows. The pointers come from live slices.
    // Dense-block sizing against `ldb`/`ldx` is left to the caller (not checked here).
    let rc = unsafe {
        sys::aoclsparse_dtrsm(
            op_raw,
            alpha,
            a.as_raw(),
            descr.as_raw(),
            layout,
            b.as_ptr(),
            rhs_count,
            ld_b,
            x.as_mut_ptr(),
            ld_x,
        )
    };
    check_status("sparse", rc)
}
/// Sparse triangular solve with `n_rhs` dense right-hand sides (f32).
///
/// `b` and `x` are dense blocks laid out per `order` with leading dimensions
/// `ldb` and `ldx`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_strsm`.
#[allow(clippy::too_many_arguments)]
pub fn trsm_f32(
    op: Trans,
    alpha: f32,
    a: &SparseMatrix<f32>,
    descr: &MatDescr,
    order: Order,
    b: &[f32],
    n_rhs: usize,
    ldb: usize,
    x: &mut [f32],
    ldx: usize,
) -> Result<()> {
    let op_raw = trans_raw(op);
    let layout = order.raw();
    let rhs_count = n_rhs as sys::aoclsparse_int;
    let (ld_b, ld_x) = (ldb as sys::aoclsparse_int, ldx as sys::aoclsparse_int);
    // SAFETY: the handles are live borrows. The pointers come from live slices.
    // Dense-block sizing against `ldb`/`ldx` is left to the caller (not checked here).
    let rc = unsafe {
        sys::aoclsparse_strsm(
            op_raw,
            alpha,
            a.as_raw(),
            descr.as_raw(),
            layout,
            b.as_ptr(),
            rhs_count,
            ld_b,
            x.as_mut_ptr(),
            ld_x,
        )
    };
    check_status("sparse", rc)
}
/// Dot product of the sparse vector `(x, indx)` with the dense vector `y` (f64).
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` and `indx` differ in length.
pub fn doti_f64(x: &[f64], indx: &[sys::aoclsparse_int], y: &[f64]) -> Result<f64> {
    let nnz = x.len();
    if nnz != indx.len() {
        return Err(Error::InvalidArgument(format!(
            "doti: x.len()={} != indx.len()={}",
            x.len(),
            indx.len()
        )));
    }
    // SAFETY: `x` and `indx` both hold `nnz` entries. Pointers come from live
    // slices. Indices are not checked against `y.len()` here.
    let dot = unsafe {
        sys::aoclsparse_ddoti(
            nnz as sys::aoclsparse_int,
            x.as_ptr(),
            indx.as_ptr(),
            y.as_ptr(),
        )
    };
    Ok(dot)
}
/// Dot product of the sparse vector `(x, indx)` with the dense vector `y` (f32).
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` and `indx` differ in length.
pub fn doti_f32(x: &[f32], indx: &[sys::aoclsparse_int], y: &[f32]) -> Result<f32> {
    let nnz = x.len();
    if nnz != indx.len() {
        return Err(Error::InvalidArgument(format!(
            "doti: x.len()={} != indx.len()={}",
            x.len(),
            indx.len()
        )));
    }
    // SAFETY: `x` and `indx` both hold `nnz` entries. Pointers come from live
    // slices. Indices are not checked against `y.len()` here.
    let dot = unsafe {
        sys::aoclsparse_sdoti(
            nnz as sys::aoclsparse_int,
            x.as_ptr(),
            indx.as_ptr(),
            y.as_ptr(),
        )
    };
    Ok(dot)
}
/// Converts CSR arrays into ELL arrays of width `ell_width` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dcsr2ell`.
#[allow(clippy::too_many_arguments)]
pub fn csr2ell_f64(
    m: usize,
    descr: &MatDescr,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_val: &[f64],
    ell_col_ind: &mut [sys::aoclsparse_int],
    ell_val: &mut [f64],
    ell_width: usize,
) -> Result<()> {
    let rows = m as sys::aoclsparse_int;
    let width = ell_width as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_dcsr2ell(
            rows,
            descr.as_raw(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            csr_val.as_ptr(),
            ell_col_ind.as_mut_ptr(),
            ell_val.as_mut_ptr(),
            width,
        )
    };
    check_status("sparse", rc)
}
/// Converts CSR arrays into ELL arrays of width `ell_width` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_scsr2ell`.
#[allow(clippy::too_many_arguments)]
pub fn csr2ell_f32(
    m: usize,
    descr: &MatDescr,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_val: &[f32],
    ell_col_ind: &mut [sys::aoclsparse_int],
    ell_val: &mut [f32],
    ell_width: usize,
) -> Result<()> {
    let rows = m as sys::aoclsparse_int;
    let width = ell_width as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_scsr2ell(
            rows,
            descr.as_raw(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            csr_val.as_ptr(),
            ell_col_ind.as_mut_ptr(),
            ell_val.as_mut_ptr(),
            width,
        )
    };
    check_status("sparse", rc)
}
/// Converts CSR arrays into DIA arrays with `dia_num_diag` diagonals (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dcsr2dia`.
#[allow(clippy::too_many_arguments)]
pub fn csr2dia_f64(
    m: usize,
    n: usize,
    descr: &MatDescr,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_val: &[f64],
    dia_num_diag: usize,
    dia_offset: &mut [sys::aoclsparse_int],
    dia_val: &mut [f64],
) -> Result<()> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let diagonals = dia_num_diag as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_dcsr2dia(
            rows,
            cols,
            descr.as_raw(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            csr_val.as_ptr(),
            diagonals,
            dia_offset.as_mut_ptr(),
            dia_val.as_mut_ptr(),
        )
    };
    check_status("sparse", rc)
}
/// Converts CSR arrays into DIA arrays with `dia_num_diag` diagonals (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_scsr2dia`.
#[allow(clippy::too_many_arguments)]
pub fn csr2dia_f32(
    m: usize,
    n: usize,
    descr: &MatDescr,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_val: &[f32],
    dia_num_diag: usize,
    dia_offset: &mut [sys::aoclsparse_int],
    dia_val: &mut [f32],
) -> Result<()> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let diagonals = dia_num_diag as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_scsr2dia(
            rows,
            cols,
            descr.as_raw(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            csr_val.as_ptr(),
            diagonals,
            dia_offset.as_mut_ptr(),
            dia_val.as_mut_ptr(),
        )
    };
    check_status("sparse", rc)
}
/// First phase of CSR→BSR conversion: fills `bsr_row_ptr` and returns the
/// number of non-zero blocks reported by the library.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_csr2bsr_nnz`.
pub fn csr2bsr_nnz(
    m: usize,
    n: usize,
    descr: &MatDescr,
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    block_dim: usize,
    bsr_row_ptr: &mut [sys::aoclsparse_int],
) -> Result<usize> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let block = block_dim as sys::aoclsparse_int;
    let mut block_count: sys::aoclsparse_int = 0;
    // SAFETY: all pointers come from live slices or the local `block_count`.
    // Buffer sizing is the caller's responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_csr2bsr_nnz(
            rows,
            cols,
            descr.as_raw(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            block,
            bsr_row_ptr.as_mut_ptr(),
            &mut block_count,
        )
    };
    check_status("sparse", rc).map(|()| block_count as usize)
}
/// Second phase of CSR→BSR conversion: fills the BSR value/index arrays (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dcsr2bsr`.
#[allow(clippy::too_many_arguments)]
pub fn csr2bsr_f64(
    m: usize,
    n: usize,
    descr: &MatDescr,
    csr_val: &[f64],
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    block_dim: usize,
    bsr_val: &mut [f64],
    bsr_row_ptr: &mut [sys::aoclsparse_int],
    bsr_col_ind: &mut [sys::aoclsparse_int],
) -> Result<()> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let block = block_dim as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_dcsr2bsr(
            rows,
            cols,
            descr.as_raw(),
            csr_val.as_ptr(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            block,
            bsr_val.as_mut_ptr(),
            bsr_row_ptr.as_mut_ptr(),
            bsr_col_ind.as_mut_ptr(),
        )
    };
    check_status("sparse", rc)
}
/// Second phase of CSR→BSR conversion: fills the BSR value/index arrays (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_scsr2bsr`.
#[allow(clippy::too_many_arguments)]
pub fn csr2bsr_f32(
    m: usize,
    n: usize,
    descr: &MatDescr,
    csr_val: &[f32],
    csr_row_ptr: &[sys::aoclsparse_int],
    csr_col_ind: &[sys::aoclsparse_int],
    block_dim: usize,
    bsr_val: &mut [f32],
    bsr_row_ptr: &mut [sys::aoclsparse_int],
    bsr_col_ind: &mut [sys::aoclsparse_int],
) -> Result<()> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let block = block_dim as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices. Buffer sizing is the caller's
    // responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_scsr2bsr(
            rows,
            cols,
            descr.as_raw(),
            csr_val.as_ptr(),
            csr_row_ptr.as_ptr(),
            csr_col_ind.as_ptr(),
            block,
            bsr_val.as_mut_ptr(),
            bsr_row_ptr.as_mut_ptr(),
            bsr_col_ind.as_mut_ptr(),
        )
    };
    check_status("sparse", rc)
}
/// Blocked-CSR matrix–vector product `y := alpha * op(A) * x + beta * y` (f64).
///
/// The stored-entry count passed to the library is `csr_val.len()`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dblkcsrmv`.
#[allow(clippy::too_many_arguments)]
pub fn blkcsrmv_f64(
    op: Trans,
    alpha: f64,
    m: usize,
    n: usize,
    masks: &[u8],
    csr_val: &[f64],
    csr_col_ind: &[sys::aoclsparse_int],
    csr_row_ptr: &[sys::aoclsparse_int],
    descr: &MatDescr,
    x: &[f64],
    beta: f64,
    y: &mut [f64],
    n_rows_blk: usize,
) -> Result<()> {
    let (rows, cols) = (m as sys::aoclsparse_int, n as sys::aoclsparse_int);
    let stored = csr_val.len() as sys::aoclsparse_int;
    let block_rows = n_rows_blk as sys::aoclsparse_int;
    // SAFETY: all pointers come from live slices or locals. Buffer sizing is the
    // caller's responsibility (not checked here).
    let rc = unsafe {
        sys::aoclsparse_dblkcsrmv(
            trans_raw(op),
            &alpha,
            rows,
            cols,
            stored,
            masks.as_ptr(),
            csr_val.as_ptr(),
            csr_col_ind.as_ptr(),
            csr_row_ptr.as_ptr(),
            descr.as_raw(),
            x.as_ptr(),
            &beta,
            y.as_mut_ptr(),
            block_rows,
        )
    };
    check_status("sparse", rc)
}
/// Symmetric Gauss–Seidel sweep on `x` for `op(A) x = alpha * b` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dsymgs`.
pub fn symgs_f64(
    op: Trans,
    a: &SparseMatrix<f64>,
    descr: &MatDescr,
    alpha: f64,
    b: &[f64],
    x: &mut [f64],
) -> Result<()> {
    let op_raw = trans_raw(op);
    let (rhs, sol) = (b.as_ptr(), x.as_mut_ptr());
    // SAFETY: the handles are live borrows. Pointers come from live slices.
    let rc = unsafe { sys::aoclsparse_dsymgs(op_raw, a.as_raw(), descr.as_raw(), alpha, rhs, sol) };
    check_status("sparse", rc)
}
/// Symmetric Gauss–Seidel sweep on `x` for `op(A) x = alpha * b` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_ssymgs`.
pub fn symgs_f32(
    op: Trans,
    a: &SparseMatrix<f32>,
    descr: &MatDescr,
    alpha: f32,
    b: &[f32],
    x: &mut [f32],
) -> Result<()> {
    let op_raw = trans_raw(op);
    let (rhs, sol) = (b.as_ptr(), x.as_mut_ptr());
    // SAFETY: the handles are live borrows. Pointers come from live slices.
    let rc = unsafe { sys::aoclsparse_ssymgs(op_raw, a.as_raw(), descr.as_raw(), alpha, rhs, sol) };
    check_status("sparse", rc)
}
/// Symmetric Gauss–Seidel sweep on `x` that also fills `y` via the library's
/// fused matrix–vector variant (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dsymgs_mv`.
#[allow(clippy::too_many_arguments)]
pub fn symgs_mv_f64(
    op: Trans,
    a: &SparseMatrix<f64>,
    descr: &MatDescr,
    alpha: f64,
    b: &[f64],
    x: &mut [f64],
    y: &mut [f64],
) -> Result<()> {
    let op_raw = trans_raw(op);
    let rhs = b.as_ptr();
    let sol = x.as_mut_ptr();
    let prod = y.as_mut_ptr();
    // SAFETY: the handles are live borrows. Pointers come from live, distinct slices.
    let rc = unsafe {
        sys::aoclsparse_dsymgs_mv(op_raw, a.as_raw(), descr.as_raw(), alpha, rhs, sol, prod)
    };
    check_status("sparse", rc)
}
/// Symmetric Gauss–Seidel sweep on `x` that also fills `y` via the library's
/// fused matrix–vector variant (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_ssymgs_mv`.
#[allow(clippy::too_many_arguments)]
pub fn symgs_mv_f32(
    op: Trans,
    a: &SparseMatrix<f32>,
    descr: &MatDescr,
    alpha: f32,
    b: &[f32],
    x: &mut [f32],
    y: &mut [f32],
) -> Result<()> {
    let op_raw = trans_raw(op);
    let rhs = b.as_ptr();
    let sol = x.as_mut_ptr();
    let prod = y.as_mut_ptr();
    // SAFETY: the handles are live borrows. Pointers come from live, distinct slices.
    let rc = unsafe {
        sys::aoclsparse_ssymgs_mv(op_raw, a.as_raw(), descr.as_raw(), alpha, rhs, sol, prod)
    };
    check_status("sparse", rc)
}
/// Overwrites the stored entry at `(row_idx, col_idx)` with `val` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dset_value`.
pub fn set_value_f64(
    a: &mut SparseMatrix<f64>,
    row_idx: i32,
    col_idx: i32,
    val: f64,
) -> Result<()> {
    let (row, col) = (row_idx as sys::aoclsparse_int, col_idx as sys::aoclsparse_int);
    // SAFETY: `a`'s handle is live and exclusively borrowed for the call.
    let rc = unsafe { sys::aoclsparse_dset_value(a.as_raw(), row, col, val) };
    check_status("sparse", rc)
}
/// Overwrites the stored entry at `(row_idx, col_idx)` with `val` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_sset_value`.
pub fn set_value_f32(
    a: &mut SparseMatrix<f32>,
    row_idx: i32,
    col_idx: i32,
    val: f32,
) -> Result<()> {
    let (row, col) = (row_idx as sys::aoclsparse_int, col_idx as sys::aoclsparse_int);
    // SAFETY: `a`'s handle is live and exclusively borrowed for the call.
    let rc = unsafe { sys::aoclsparse_sset_value(a.as_raw(), row, col, val) };
    check_status("sparse", rc)
}
/// Replaces the stored values of `a` with `val` (f64); the count passed is `val.len()`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dupdate_values`.
pub fn update_values_f64(a: &mut SparseMatrix<f64>, val: &mut [f64]) -> Result<()> {
    let count = val.len() as sys::aoclsparse_int;
    let data = val.as_mut_ptr();
    // SAFETY: `a`'s handle is live. `data` points at `count` values of a live slice.
    let rc = unsafe { sys::aoclsparse_dupdate_values(a.as_raw(), count, data) };
    check_status("sparse", rc)
}
/// Replaces the stored values of `a` with `val` (f32); the count passed is `val.len()`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_supdate_values`.
pub fn update_values_f32(a: &mut SparseMatrix<f32>, val: &mut [f32]) -> Result<()> {
    let count = val.len() as sys::aoclsparse_int;
    let data = val.as_mut_ptr();
    // SAFETY: `a`'s handle is live. `data` points at `count` values of a live slice.
    let rc = unsafe { sys::aoclsparse_supdate_values(a.as_raw(), count, data) };
    check_status("sparse", rc)
}
#[allow(clippy::too_many_arguments)]
pub fn dotmv_f64(
op: Trans,
alpha: f64,
a: &SparseMatrix<f64>,
descr: &MatDescr,
x: &[f64],
beta: f64,
y: &mut [f64],
d: &mut f64,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_ddotmv(
trans_raw(op),
alpha,
a.as_raw(),
descr.as_raw(),
x.as_ptr(),
beta,
y.as_mut_ptr(),
d,
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
pub fn dotmv_f32(
op: Trans,
alpha: f32,
a: &SparseMatrix<f32>,
descr: &MatDescr,
x: &[f32],
beta: f32,
y: &mut [f32],
d: &mut f32,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_sdotmv(
trans_raw(op),
alpha,
a.as_raw(),
descr.as_raw(),
x.as_ptr(),
beta,
y.as_mut_ptr(),
d,
)
};
check_status("sparse", status)
}
/// Sparse rank-k style product of `op_a(A)` with itself, accumulated into dense `c` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dsyrkd`.
#[allow(clippy::too_many_arguments)]
pub fn syrkd_f64(
    op_a: Trans,
    a: &SparseMatrix<f64>,
    alpha: f64,
    beta: f64,
    c: &mut [f64],
    order_c: Order,
    ldc: usize,
) -> Result<()> {
    let op_raw = trans_raw(op_a);
    let layout = order_c.raw();
    let ld = ldc as sys::aoclsparse_int;
    // SAFETY: `a`'s handle is live. `c` is a live slice (sizing not checked here).
    let rc = unsafe { sys::aoclsparse_dsyrkd(op_raw, a.as_raw(), alpha, beta, c.as_mut_ptr(), layout, ld) };
    check_status("sparse", rc)
}
/// Sparse rank-k style product of `op_a(A)` with itself, accumulated into dense `c` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_ssyrkd`.
#[allow(clippy::too_many_arguments)]
pub fn syrkd_f32(
    op_a: Trans,
    a: &SparseMatrix<f32>,
    alpha: f32,
    beta: f32,
    c: &mut [f32],
    order_c: Order,
    ldc: usize,
) -> Result<()> {
    let op_raw = trans_raw(op_a);
    let layout = order_c.raw();
    let ld = ldc as sys::aoclsparse_int;
    // SAFETY: `a`'s handle is live. `c` is a live slice (sizing not checked here).
    let rc = unsafe { sys::aoclsparse_ssyrkd(op_raw, a.as_raw(), alpha, beta, c.as_mut_ptr(), layout, ld) };
    check_status("sparse", rc)
}
/// Symmetric triple product of sparse `op_a(A)` with dense `b`, accumulated into dense `c` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dsyprd`.
#[allow(clippy::too_many_arguments)]
pub fn syprd_f64(
    op_a: Trans,
    a: &SparseMatrix<f64>,
    b: &[f64],
    order_b: Order,
    ldb: usize,
    alpha: f64,
    beta: f64,
    c: &mut [f64],
    order_c: Order,
    ldc: usize,
) -> Result<()> {
    let op_raw = trans_raw(op_a);
    let (layout_b, layout_c) = (order_b.raw(), order_c.raw());
    let (ld_b, ld_c) = (ldb as sys::aoclsparse_int, ldc as sys::aoclsparse_int);
    // SAFETY: `a`'s handle is live. `b`/`c` are live slices (sizing not checked here).
    let rc = unsafe {
        sys::aoclsparse_dsyprd(
            op_raw,
            a.as_raw(),
            b.as_ptr(),
            layout_b,
            ld_b,
            alpha,
            beta,
            c.as_mut_ptr(),
            layout_c,
            ld_c,
        )
    };
    check_status("sparse", rc)
}
/// Symmetric triple product of sparse `op_a(A)` with dense `b`, accumulated into dense `c` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_ssyprd`.
#[allow(clippy::too_many_arguments)]
pub fn syprd_f32(
    op_a: Trans,
    a: &SparseMatrix<f32>,
    b: &[f32],
    order_b: Order,
    ldb: usize,
    alpha: f32,
    beta: f32,
    c: &mut [f32],
    order_c: Order,
    ldc: usize,
) -> Result<()> {
    let op_raw = trans_raw(op_a);
    let (layout_b, layout_c) = (order_b.raw(), order_c.raw());
    let (ld_b, ld_c) = (ldb as sys::aoclsparse_int, ldc as sys::aoclsparse_int);
    // SAFETY: `a`'s handle is live. `b`/`c` are live slices (sizing not checked here).
    let rc = unsafe {
        sys::aoclsparse_ssyprd(
            op_raw,
            a.as_raw(),
            b.as_ptr(),
            layout_b,
            ld_b,
            alpha,
            beta,
            c.as_mut_ptr(),
            layout_c,
            ld_c,
        )
    };
    check_status("sparse", rc)
}
pub fn roti_f64(
x: &mut [f64],
indx: &[sys::aoclsparse_int],
y: &mut [f64],
c: f64,
s: f64,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_droti(
x.len() as sys::aoclsparse_int,
x.as_mut_ptr(),
indx.as_ptr(),
y.as_mut_ptr(),
c,
s,
)
};
check_status("sparse", status)
}
pub fn roti_f32(
x: &mut [f32],
indx: &[sys::aoclsparse_int],
y: &mut [f32],
c: f32,
s: f32,
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_sroti(
x.len() as sys::aoclsparse_int,
x.as_mut_ptr(),
indx.as_ptr(),
y.as_mut_ptr(),
c,
s,
)
};
check_status("sparse", status)
}
/// Strided gather of `x.len()` entries from `y` into `x` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dgthrs`.
pub fn gthrs_f64(y: &[f64], x: &mut [f64], stride: i32) -> Result<()> {
    let count = x.len() as sys::aoclsparse_int;
    let step = stride as sys::aoclsparse_int;
    // SAFETY: pointers come from live slices. Whether `y` spans `count` strided
    // entries is not checked here.
    let rc = unsafe { sys::aoclsparse_dgthrs(count, y.as_ptr(), x.as_mut_ptr(), step) };
    check_status("sparse", rc)
}
/// Strided gather of `x.len()` entries from `y` into `x` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_sgthrs`.
pub fn gthrs_f32(y: &[f32], x: &mut [f32], stride: i32) -> Result<()> {
    let count = x.len() as sys::aoclsparse_int;
    let step = stride as sys::aoclsparse_int;
    // SAFETY: pointers come from live slices. Whether `y` spans `count` strided
    // entries is not checked here.
    let rc = unsafe { sys::aoclsparse_sgthrs(count, y.as_ptr(), x.as_mut_ptr(), step) };
    check_status("sparse", rc)
}
/// Strided scatter of the `x.len()` entries of `x` into `y` (f64).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_dsctrs`.
pub fn sctrs_f64(x: &[f64], y: &mut [f64], stride: i32) -> Result<()> {
    let count = x.len() as sys::aoclsparse_int;
    let step = stride as sys::aoclsparse_int;
    // SAFETY: pointers come from live slices. Whether `y` spans `count` strided
    // entries is not checked here.
    let rc = unsafe { sys::aoclsparse_dsctrs(count, x.as_ptr(), step, y.as_mut_ptr()) };
    check_status("sparse", rc)
}
/// Strided scatter of the `x.len()` entries of `x` into `y` (f32).
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_ssctrs`.
pub fn sctrs_f32(x: &[f32], y: &mut [f32], stride: i32) -> Result<()> {
    let count = x.len() as sys::aoclsparse_int;
    let step = stride as sys::aoclsparse_int;
    // SAFETY: pointers come from live slices. Whether `y` spans `count` strided
    // entries is not checked here.
    let rc = unsafe { sys::aoclsparse_ssctrs(count, x.as_ptr(), step, y.as_mut_ptr()) };
    check_status("sparse", rc)
}
pub fn gthrz_f64(y: &mut [f64], indx: &[sys::aoclsparse_int], x: &mut [f64]) -> Result<()> {
let status = unsafe {
sys::aoclsparse_dgthrz(
x.len() as sys::aoclsparse_int,
y.as_mut_ptr(),
x.as_mut_ptr(),
indx.as_ptr(),
)
};
check_status("sparse", status)
}
pub fn gthrz_f32(y: &mut [f32], indx: &[sys::aoclsparse_int], x: &mut [f32]) -> Result<()> {
let status = unsafe {
sys::aoclsparse_sgthrz(
x.len() as sys::aoclsparse_int,
y.as_mut_ptr(),
x.as_mut_ptr(),
indx.as_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
pub fn sorv_f32(
sor_type: SorType,
descr: &MatDescr,
a: &SparseMatrix<f32>,
omega: f32,
alpha: f32,
x: &mut [f32],
b: &[f32],
) -> Result<()> {
let status = unsafe {
sys::aoclsparse_ssorv(
sor_type.raw(),
descr.as_raw(),
a.as_raw(),
omega,
alpha,
x.as_mut_ptr(),
b.as_ptr(),
)
};
check_status("sparse", status)
}
#[allow(clippy::too_many_arguments)]
pub fn create_tcsr_f64(
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
row_ptr_l: &mut [sys::aoclsparse_int],
row_ptr_u: &mut [sys::aoclsparse_int],
col_idx_l: &mut [sys::aoclsparse_int],
col_idx_u: &mut [sys::aoclsparse_int],
val_l: &mut [f64],
val_u: &mut [f64],
) -> Result<sys::aoclsparse_matrix> {
let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
let status = unsafe {
sys::aoclsparse_create_dtcsr(
&mut raw,
base.raw(),
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
row_ptr_l.as_mut_ptr(),
row_ptr_u.as_mut_ptr(),
col_idx_l.as_mut_ptr(),
col_idx_u.as_mut_ptr(),
val_l.as_mut_ptr(),
val_u.as_mut_ptr(),
)
};
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(raw)
}
#[allow(clippy::too_many_arguments)]
pub fn create_tcsr_f32(
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
row_ptr_l: &mut [sys::aoclsparse_int],
row_ptr_u: &mut [sys::aoclsparse_int],
col_idx_l: &mut [sys::aoclsparse_int],
col_idx_u: &mut [sys::aoclsparse_int],
val_l: &mut [f32],
val_u: &mut [f32],
) -> Result<sys::aoclsparse_matrix> {
let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
let status = unsafe {
sys::aoclsparse_create_stcsr(
&mut raw,
base.raw(),
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
row_ptr_l.as_mut_ptr(),
row_ptr_u.as_mut_ptr(),
col_idx_l.as_mut_ptr(),
col_idx_u.as_mut_ptr(),
val_l.as_mut_ptr(),
val_u.as_mut_ptr(),
)
};
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(raw)
}
#[allow(clippy::too_many_arguments)]
pub fn create_csc_f64(
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
col_ptr: &mut [sys::aoclsparse_int],
row_idx: &mut [sys::aoclsparse_int],
val: &mut [f64],
) -> Result<sys::aoclsparse_matrix> {
let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
let status = unsafe {
sys::aoclsparse_create_dcsc(
&mut raw,
base.raw(),
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
col_ptr.as_mut_ptr(),
row_idx.as_mut_ptr(),
val.as_mut_ptr(),
)
};
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(raw)
}
#[allow(clippy::too_many_arguments)]
pub fn create_csc_f32(
base: IndexBase,
m: usize,
n: usize,
nnz: usize,
col_ptr: &mut [sys::aoclsparse_int],
row_idx: &mut [sys::aoclsparse_int],
val: &mut [f32],
) -> Result<sys::aoclsparse_matrix> {
let mut raw: sys::aoclsparse_matrix = std::ptr::null_mut();
let status = unsafe {
sys::aoclsparse_create_scsc(
&mut raw,
base.raw(),
m as sys::aoclsparse_int,
n as sys::aoclsparse_int,
nnz as sys::aoclsparse_int,
col_ptr.as_mut_ptr(),
row_idx.as_mut_ptr(),
val.as_mut_ptr(),
)
};
check_status("sparse", status)?;
if raw.is_null() {
return Err(Error::AllocationFailed("sparse"));
}
Ok(raw)
}
/// Memory-usage policy passed to `set_memory_hint`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryUsage {
    /// Maps to the library's `aoclsparse_memory_usage_minimal` value.
    Minimal,
    /// Maps to the library's `aoclsparse_memory_usage_unrestricted` value.
    Unrestricted,
}
impl MemoryUsage {
    /// Converts the policy into its raw FFI enumeration value.
    fn raw(self) -> sys::aoclsparse_memory_usage {
        if let Self::Minimal = self {
            sys::aoclsparse_memory_usage__aoclsparse_memory_usage_minimal
        } else {
            sys::aoclsparse_memory_usage__aoclsparse_memory_usage_unrestricted
        }
    }
}
/// Registers a memory-usage policy on `mat`.
///
/// # Errors
/// Propagates any non-success status from `aoclsparse_set_memory_hint`.
pub fn set_memory_hint<T: Scalar>(mat: &mut SparseMatrix<T>, policy: MemoryUsage) -> Result<()> {
    let handle = mat.as_raw();
    let policy_raw = policy.raw();
    // SAFETY: `handle` is the live handle owned by `mat`, exclusively borrowed here.
    let rc = unsafe { sys::aoclsparse_set_memory_hint(handle, policy_raw) };
    check_status("sparse", rc)
}
/// Applies the library's ILU smoother to `x` for the system `op(A) x = b`.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `x` is shorter than `A`'s column
/// count or `b` is shorter than its row count. Otherwise propagates the
/// scalar-specific implementation's result.
pub fn ilu_smoother<T: Scalar>(
    op: Trans,
    a: &SparseMatrix<T>,
    descr: &MatDescr,
    x: &mut [T],
    b: &[T],
) -> Result<()> {
    let (rows, cols) = (a.m, a.n);
    let fits = x.len() >= cols && b.len() >= rows;
    if !fits {
        return Err(Error::InvalidArgument(format!(
            "ilu_smoother: x.len()={}, b.len()={}, dims=({}, {})",
            x.len(),
            b.len(),
            rows,
            cols
        )));
    }
    T::ilu_smoother(op, a.raw, descr, x, b)
}
/// Owning wrapper around an AOCL-Sparse iterative-solver handle for scalar type `T`.
///
/// The handle is destroyed when the value is dropped.
pub struct IterSolver<T: Scalar> {
    // Raw solver handle. `new` rejects a null handle, and `Drop` nulls it after destroying.
    handle: sys::aoclsparse_itsol_handle,
    // Ties the solver to its scalar type without storing a `T`.
    _marker: PhantomData<T>,
}
impl<T: Scalar> IterSolver<T> {
    /// Creates a new iterative-solver handle for scalar type `T`.
    ///
    /// # Errors
    /// Propagates initialisation failures. Returns [`Error::AllocationFailed`] if
    /// initialisation reports success but leaves the handle null.
    pub fn new() -> Result<Self> {
        let mut handle: sys::aoclsparse_itsol_handle = std::ptr::null_mut();
        T::itsol_init(&mut handle)?;
        if handle.is_null() {
            return Err(Error::AllocationFailed("sparse"));
        }
        Ok(Self {
            handle,
            _marker: PhantomData,
        })
    }

    /// Sets the solver option `name` to `value` (both passed as C strings).
    ///
    /// # Errors
    /// Returns [`Error::InvalidArgument`] if either string contains an interior
    /// NUL. Otherwise propagates the library status.
    pub fn set_option(&mut self, name: &str, value: &str) -> Result<()> {
        let c_name = CString::new(name)
            .map_err(|_| Error::InvalidArgument("set_option: name has interior NUL".into()))?;
        let c_value = CString::new(value)
            .map_err(|_| Error::InvalidArgument("set_option: value has interior NUL".into()))?;
        // SAFETY: the handle is live, and both C strings outlive the call.
        let status = unsafe {
            sys::aoclsparse_itsol_option_set(self.handle, c_name.as_ptr(), c_value.as_ptr())
        };
        check_status("sparse", status)
    }

    /// Solves `A x = b` for square `mat`, using `x` as the initial guess and output.
    ///
    /// Returns the 100-entry `rinfo` array filled by the solver.
    ///
    /// # Errors
    /// Returns [`Error::InvalidArgument`] if `mat` is not square or if `b`/`x`
    /// hold fewer than `n` entries. Otherwise propagates solver failures.
    pub fn solve(
        &mut self,
        mat: &SparseMatrix<T>,
        descr: &MatDescr,
        b: &[T],
        x: &mut [T],
    ) -> Result<Box<[T; 100]>>
    where
        T: Default,
    {
        let n = mat.n;
        if mat.m != mat.n {
            return Err(Error::InvalidArgument(format!(
                "iterative solve requires square matrix; got ({}, {})",
                mat.m, mat.n
            )));
        }
        // `n` is handed to the solver as the vector length, so undersized
        // buffers must be rejected before they reach the FFI layer.
        if b.len() < n || x.len() < n {
            return Err(Error::InvalidArgument(format!(
                "iterative solve: b.len()={}, x.len()={}, need at least {}",
                b.len(),
                x.len(),
                n
            )));
        }
        let mut rinfo: Box<[T; 100]> = Box::new([T::default(); 100]);
        T::itsol_solve(self.handle, n, mat.raw, descr, b, x, &mut rinfo)?;
        Ok(rinfo)
    }
}
impl<T: Scalar> Drop for IterSolver<T> {
    /// Destroys the solver handle if one is still held, then nulls it.
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        // SAFETY: the handle is non-null and owned by this value. It is nulled
        // right after, so it is never destroyed twice.
        unsafe {
            sys::aoclsparse_itsol_destroy(&mut self.handle);
        }
        self.handle = std::ptr::null_mut();
    }
}
impl<T: Scalar> std::fmt::Debug for IterSolver<T> {
    /// Prints the type name only; the raw handle is intentionally not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("IterSolver");
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn csrmv_2x2_identity_f64() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let x = [3.0_f64, 4.0];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 2, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 4.0).abs() < 1e-12);
    }
    #[test]
    fn csrmv_simple_2x3() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rowptr: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let x = [1.0_f64; 3];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        csrmv(1.0_f64, 2, 3, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap();
        assert!((y[0] - 3.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 3.0).abs() < 1e-12, "got {}", y[1]);
    }
    #[test]
    fn dim_mismatch_is_error() {
        let val = [1.0_f64];
        let col: [sys::aoclsparse_int; 1] = [0];
        let rowptr: [sys::aoclsparse_int; 2] = [0, 1];
        let x = [1.0_f64];
        let mut y = [0.0_f64; 2];
        let descr = MatDescr::new().unwrap();
        let err = csrmv(1.0_f64, 2, 1, &val, &col, &rowptr, &descr, &x, 0.0, &mut y).unwrap_err();
        // `matches!` only yields a bool; without `assert!` the variant was never checked.
        assert!(
            matches!(err, Error::InvalidArgument(_)),
            "expected InvalidArgument, got {err:?}"
        );
    }
    #[test]
    fn axpyi_scatter() {
        let mut y = [10.0_f64, 20.0, 30.0, 40.0];
        let x = [1.0_f64, 2.0];
        let indx: [sys::aoclsparse_int; 2] = [0, 2];
        axpyi(3.0_f64, &x, &indx, &mut y).unwrap();
        assert_eq!(y, [13.0, 20.0, 36.0, 40.0]);
    }
    #[test]
    fn gthr_scatter_round_trip() {
        let y = [10.0_f64, 20.0, 30.0, 40.0];
        let indx: [sys::aoclsparse_int; 2] = [1, 3];
        let mut x = [0.0_f64; 2];
        gthr(&y, &indx, &mut x).unwrap();
        assert_eq!(x, [20.0, 40.0]);
        let mut y2 = [0.0_f64; 4];
        sctr(&x, &indx, &mut y2).unwrap();
        assert_eq!(y2, [0.0, 20.0, 0.0, 40.0]);
    }
    #[test]
    fn add_identity_plus_identity_is_2_diag() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let c = add(Trans::No, &a, 1.0, &b).unwrap();
        let (_, _, _, val_c) = c.export_csr().unwrap();
        assert_eq!(val_c.len(), 2);
        for v in &val_c {
            assert!((v - 2.0).abs() < 1e-12, "got {v}, want 2.0");
        }
    }
    #[test]
    fn csrmm_2x2_identity_against_2x3_dense() {
        let val = [1.0_f64, 1.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 2, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let b: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut c = [0.0_f64; 6];
        csrmm(
            Trans::No,
            1.0,
            &a,
            &descr,
            Order::RowMajor,
            &b,
            3,
            3,
            0.0,
            &mut c,
            3,
        )
        .unwrap();
        for (got, want) in c.iter().zip(b.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }
    #[test]
    fn spmmd_identity_squared_yields_identity_dense() {
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let mut c = [0.0_f64; 9];
        spmmd(Trans::No, &a, &b, Order::RowMajor, &mut c, 3).unwrap();
        let expected = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        for (got, want) in c.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "got {got}, want {want}");
        }
    }
    #[test]
    fn ellmv_2x3_f64() {
        let val: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        let col: [sys::aoclsparse_int; 4] = [0, 2, 0, 1];
        let descr = MatDescr::new().unwrap();
        let x = [10.0_f64, 20.0, 30.0];
        let mut y = [0.0_f64; 2];
        ellmv(
            Trans::No,
            1.0_f64,
            2,
            3,
            &val,
            &col,
            2,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        assert!((y[0] - 70.0).abs() < 1e-12, "got {}", y[0]);
        assert!((y[1] - 110.0).abs() < 1e-12, "got {}", y[1]);
    }
    #[test]
    fn bsrmv_2x2_blocks_f64() {
        let val: [f64; 8] = [1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0];
        let col: [sys::aoclsparse_int; 2] = [0, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let descr = MatDescr::new().unwrap();
        let x = [1.0_f64, 2.0, 3.0, 4.0];
        let mut y = [0.0_f64; 4];
        bsrmv(
            Trans::No,
            1.0_f64,
            2,
            2,
            2,
            &val,
            &col,
            &rp,
            &descr,
            &x,
            0.0,
            &mut y,
        )
        .unwrap();
        assert!((y[0] - 1.0).abs() < 1e-12);
        assert!((y[1] - 2.0).abs() < 1e-12);
        assert!((y[2] - 6.0).abs() < 1e-12);
        assert!((y[3] - 8.0).abs() < 1e-12);
    }
    #[test]
    fn sparse_matrix_round_trip() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 2, 3, &rp, &col, &val).unwrap();
        assert_eq!(mat.dims(), (2, 3));
        assert_eq!(mat.nnz(), 3);
        assert_eq!(mat.base(), IndexBase::Zero);
        let (base, rp2, col2, val2) = mat.export_csr().unwrap();
        assert_eq!(base, IndexBase::Zero);
        assert_eq!(rp2, [0, 2, 3]);
        assert_eq!(col2, [0, 2, 1]);
        assert_eq!(val2, [1.0, 2.0, 3.0]);
    }
    #[test]
    fn csr2m_identity_squared_is_identity() {
        let val = [1.0_f64; 3];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let a = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let b = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        let c = csr2m(
            Trans::No,
            &descr,
            &a,
            Trans::No,
            &descr,
            &b,
            Stage::FullComputation,
        )
        .unwrap();
        assert_eq!(c.dims(), (3, 3));
        let (_, rp_c, col_c, val_c) = c.export_csr().unwrap();
        assert_eq!(rp_c, [0, 1, 2, 3]);
        assert_eq!(col_c, [0, 1, 2]);
        for v in &val_c {
            assert!((v - 1.0).abs() < 1e-12);
        }
    }
    #[test]
    fn iter_solver_cg_diagonal_3x3() {
        let val = [2.0_f64, 2.0, 2.0];
        let col: [sys::aoclsparse_int; 3] = [0, 1, 2];
        let rp: [sys::aoclsparse_int; 4] = [0, 1, 2, 3];
        let mat = SparseMatrix::<f64>::from_csr(IndexBase::Zero, 3, 3, &rp, &col, &val).unwrap();
        let descr = MatDescr::new().unwrap();
        // SAFETY: `descr` owns a live descriptor handle for the duration of this call.
        unsafe {
            sys::aoclsparse_set_mat_type(
                descr.as_raw(),
                sys::aoclsparse_matrix_type__aoclsparse_matrix_type_symmetric,
            );
        }
        let b = [4.0_f64, 6.0, 10.0];
        let mut x = [0.0_f64; 3];
        let mut solver = IterSolver::<f64>::new().unwrap();
        solver.set_option("iterative method", "cg").unwrap();
        solver.set_option("cg rel tolerance", "1e-10").unwrap();
        solver.set_option("cg iteration limit", "200").unwrap();
        solver.solve(&mat, &descr, &b, &mut x).unwrap();
        assert!((x[0] - 2.0).abs() < 1e-6, "x[0] = {}", x[0]);
        assert!((x[1] - 3.0).abs() < 1e-6, "x[1] = {}", x[1]);
        assert!((x[2] - 5.0).abs() < 1e-6, "x[2] = {}", x[2]);
    }
    #[test]
    fn csr_to_dense_round_trip() {
        let val = [1.0_f64, 2.0, 3.0];
        let col: [sys::aoclsparse_int; 3] = [0, 2, 1];
        let rp: [sys::aoclsparse_int; 3] = [0, 2, 3];
        let descr = MatDescr::new().unwrap();
        let mut dense = [0.0_f64; 6];
        csr_to_dense::<f64>(
            2,
            3,
            &descr,
            &val,
            &rp,
            &col,
            &mut dense,
            3,
            Order::RowMajor,
        )
        .unwrap();
        assert_eq!(dense, [1.0, 0.0, 2.0, 0.0, 3.0, 0.0]);
    }
}