use crate::blas::{blas_order, blas_transpose};
use crate::bridge;
use crate::error::{Error, Result};
use core::ffi::c_void;
use core::ptr;
/// Integer type used for sparse-vector indices and matrix coordinates handed to the bridge.
pub type SparseIndex = i64;
/// Structural property flags for a sparse matrix; each flag occupies its own bit.
pub mod sparse_matrix_property {
    /// Marks the matrix as upper triangular.
    pub const UPPER_TRIANGULAR: i32 = 1 << 0;
    /// Marks the matrix as lower triangular.
    pub const LOWER_TRIANGULAR: i32 = 1 << 1;
    /// Marks the matrix as upper symmetric.
    pub const UPPER_SYMMETRIC: i32 = 1 << 2;
    /// Marks the matrix as lower symmetric.
    pub const LOWER_SYMMETRIC: i32 = 1 << 3;
}
/// Status codes compared against the bridge's sparse return values.
///
/// `SUCCESS` is zero; the failure codes form a descending run starting at `-1000`.
pub mod sparse_status {
    /// The bridge reported success.
    pub const SUCCESS: i32 = 0;
    /// An argument was rejected.
    pub const ILLEGAL_PARAMETER: i32 = -1000;
    /// A matrix property could not be applied.
    pub const CANNOT_SET_PROPERTY: i32 = ILLEGAL_PARAMETER - 1;
    /// A system-level failure was reported.
    pub const SYSTEM_ERROR: i32 = ILLEGAL_PARAMETER - 2;
}
fn u64_len(value: usize) -> Result<u64> {
u64::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension overflowed"))
}
fn i64_index(value: usize) -> Result<i64> {
i64::try_from(value).map_err(|_| Error::OperationFailed("sparse index overflowed"))
}
fn usize_dimension(value: u64) -> Result<usize> {
usize::try_from(value).map_err(|_| Error::OperationFailed("sparse dimension exceeds usize"))
}
fn usize_count(value: i64) -> Result<usize> {
if value < 0 {
return Err(Error::SparseStatus(
i32::try_from(value).unwrap_or(sparse_status::SYSTEM_ERROR),
));
}
usize::try_from(value).map_err(|_| Error::OperationFailed("sparse count exceeds usize"))
}
fn sparse_result(status: i32) -> Result<()> {
if status == sparse_status::SUCCESS {
Ok(())
} else {
Err(Error::SparseStatus(status))
}
}
fn validate_sparse_entries(values: &[f32], indices: &[SparseIndex]) -> Result<()> {
if values.len() != indices.len() {
return Err(Error::InvalidLength {
expected: values.len(),
actual: indices.len(),
});
}
for window in indices.windows(2) {
if window[0] >= window[1] {
return Err(Error::InvalidValue(
"sparse indices must be strictly increasing and unique",
));
}
}
Ok(())
}
fn validate_dense_coverage(indices: &[SparseIndex], dense_len: usize) -> Result<()> {
if let Some(&max_index) = indices.last() {
let max_index = usize::try_from(max_index)
.map_err(|_| Error::InvalidValue("sparse indices must be non-negative"))?;
if max_index >= dense_len {
return Err(Error::InvalidLength {
expected: max_index + 1,
actual: dense_len,
});
}
}
Ok(())
}
/// Owning wrapper around a single-precision sparse matrix handle created by the bridge.
///
/// The handle is handed back to the bridge when the wrapper is dropped.
pub struct SparseMatrixF32 {
    /// Opaque bridge handle; `SparseMatrixF32::new` only ever stores a non-null value.
    ptr: *mut c_void,
}
// SAFETY: NOTE(review): these impls assert two things. First, that the bridge handle may be
// moved across threads. Second, that concurrent calls through `&SparseMatrixF32` (e.g. `rows`,
// `triangular_solve_vector`) are safe on the native side. Nothing in this file synchronizes such
// calls — confirm against the bridge / Accelerate documentation.
unsafe impl Send for SparseMatrixF32 {}
unsafe impl Sync for SparseMatrixF32 {}
impl Drop for SparseMatrixF32 {
    fn drop(&mut self) {
        // Swap the handle out so the field is null from here on.
        let handle = core::mem::replace(&mut self.ptr, ptr::null_mut());
        if handle.is_null() {
            return;
        }
        // SAFETY: `handle` came from the bridge's create call and, having been swapped out of
        // `self.ptr`, cannot be released a second time by this wrapper.
        unsafe { bridge::acc_release_handle(handle) };
    }
}
impl SparseMatrixF32 {
#[must_use]
pub fn new(rows: usize, columns: usize) -> Option<Self> {
if rows == 0 || columns == 0 {
return None;
}
let rows = u64::try_from(rows).ok()?;
let columns = u64::try_from(columns).ok()?;
let ptr = unsafe { bridge::acc_sparse_matrix_f32_create(rows, columns) };
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
pub fn set_property(&mut self, property: i32) -> Result<()> {
sparse_result(unsafe { bridge::acc_sparse_matrix_f32_set_property(self.ptr, property) })
}
pub fn insert_entry(&mut self, row: usize, column: usize, value: f32) -> Result<()> {
let rows = self.rows()?;
let columns = self.columns()?;
if row >= rows || column >= columns {
return Err(Error::InvalidValue(
"sparse entry coordinates must be within matrix bounds",
));
}
let row = i64_index(row)?;
let column = i64_index(column)?;
sparse_result(unsafe { bridge::acc_sparse_matrix_f32_insert_entry(self.ptr, value, row, column) })
}
pub fn commit(&mut self) -> Result<()> {
sparse_result(unsafe { bridge::acc_sparse_matrix_f32_commit(self.ptr) })
}
pub fn rows(&self) -> Result<usize> {
usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_rows(self.ptr) })
}
pub fn columns(&self) -> Result<usize> {
usize_dimension(unsafe { bridge::acc_sparse_matrix_f32_columns(self.ptr) })
}
pub fn nonzero_count(&self) -> Result<usize> {
usize_count(unsafe { bridge::acc_sparse_matrix_f32_nonzero_count(self.ptr) })
}
pub fn triangular_solve_vector(
&self,
transpose: i32,
alpha: f32,
values: &mut [f32],
) -> Result<()> {
let rows = self.rows()?;
let columns = self.columns()?;
if rows != columns {
return Err(Error::InvalidValue(
"sparse triangular solve requires a square matrix",
));
}
if values.len() != rows {
return Err(Error::InvalidLength {
expected: rows,
actual: values.len(),
});
}
let len = u64_len(values.len())?;
sparse_result(unsafe {
bridge::acc_sparse_matrix_f32_triangular_solve_vector(
self.ptr,
transpose,
alpha,
values.as_mut_ptr(),
len,
)
})
}
pub fn triangular_solve_matrix_row_major(
&self,
transpose: i32,
rhs_columns: usize,
alpha: f32,
values: &mut [f32],
) -> Result<()> {
let rows = self.rows()?;
let columns = self.columns()?;
if rows != columns {
return Err(Error::InvalidValue(
"sparse triangular solve requires a square matrix",
));
}
let expected = rows
.checked_mul(rhs_columns)
.ok_or(Error::OperationFailed("sparse rhs dimensions overflowed"))?;
if values.len() != expected {
return Err(Error::InvalidLength {
expected,
actual: values.len(),
});
}
if rhs_columns == 0 {
return Ok(());
}
let rhs_count = u64_len(rhs_columns)?;
let ldb = u64_len(rhs_columns)?;
sparse_result(unsafe {
bridge::acc_sparse_matrix_f32_triangular_solve_matrix(
self.ptr,
blas_order::ROW_MAJOR,
transpose,
rhs_count,
alpha,
values.as_mut_ptr(),
ldb,
)
})
}
}
/// Dot product of a sparse vector (`values` stored at `indices`) with the dense vector `dense`.
///
/// Returns `0.0` without calling the bridge when the sparse vector has no stored entries.
///
/// # Errors
/// - Validation errors from `validate_sparse_entries` and `validate_dense_coverage`.
/// - `Error::OperationFailed` if the entry count does not fit in `u64`.
pub fn dot_dense_f32(values: &[f32], indices: &[SparseIndex], dense: &[f32]) -> Result<f32> {
    validate_sparse_entries(values, indices)?;
    validate_dense_coverage(indices, dense.len())?;
    match values.len() {
        0 => Ok(0.0),
        count => {
            let nz = u64_len(count)?;
            // SAFETY: `values` and `indices` are live slices of `nz` elements each. Range
            // checking of the indices against `dense` is delegated to the validation above.
            let product = unsafe {
                bridge::acc_sparse_dot_dense_f32(
                    nz,
                    values.as_ptr(),
                    indices.as_ptr(),
                    dense.as_ptr(),
                )
            };
            Ok(product)
        }
    }
}
/// Dot product of two sparse vectors, each given as parallel value / index slices.
///
/// Returns `0.0` without calling the bridge when either side has no stored entries.
///
/// # Errors
/// - Validation errors from `validate_sparse_entries` (left operand checked first).
/// - `Error::OperationFailed` if an entry count does not fit in `u64`.
pub fn dot_sparse_f32(
    lhs_values: &[f32],
    lhs_indices: &[SparseIndex],
    rhs_values: &[f32],
    rhs_indices: &[SparseIndex],
) -> Result<f32> {
    validate_sparse_entries(lhs_values, lhs_indices)?;
    validate_sparse_entries(rhs_values, rhs_indices)?;
    let either_empty = lhs_values.is_empty() || rhs_values.is_empty();
    if either_empty {
        return Ok(0.0);
    }
    let (lhs_count, rhs_count) = (u64_len(lhs_values.len())?, u64_len(rhs_values.len())?);
    // SAFETY: every pointer comes from a live slice whose length matches the count passed
    // alongside it (lengths were paired by `validate_sparse_entries`).
    let product = unsafe {
        bridge::acc_sparse_dot_sparse_f32(
            lhs_count,
            lhs_values.as_ptr(),
            lhs_indices.as_ptr(),
            rhs_count,
            rhs_values.as_ptr(),
            rhs_indices.as_ptr(),
        )
    };
    Ok(product)
}
pub fn add_to_dense_f32(
values: &[f32],
indices: &[SparseIndex],
alpha: f32,
dense: &mut [f32],
) -> Result<()> {
validate_sparse_entries(values, indices)?;
validate_dense_coverage(indices, dense.len())?;
if values.is_empty() {
return Ok(());
}
let nz = u64_len(values.len())?;
let ok = unsafe {
bridge::acc_sparse_add_to_dense_f32(
nz,
alpha,
values.as_ptr(),
indices.as_ptr(),
dense.as_mut_ptr(),
)
};
if ok {
Ok(())
} else {
Err(Error::SparseStatus(-1))
}
}
// Keeps the `blas_transpose` import in use so it does not raise an `unused_imports` warning;
// no other code in this section names it. An unnamed `const _` is never reported by the
// `dead_code` lint, so no `#[allow(dead_code)]` attribute is needed here.
const _: i32 = blas_transpose::NO_TRANS;