use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement, Zero};
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
use crate::error::{SparseError, SparseResult};
/// Interface for two-dimensional sparse arrays holding elements of type `T`.
///
/// The trait groups shape queries, format conversions, arithmetic, element
/// access, reductions and slicing. Operations that produce a new array hand
/// it back as `Box<dyn SparseArray<T>>`, so the trait is used as a trait
/// object throughout.
///
/// NOTE(review): the per-method summaries below are derived from the method
/// names and signatures only; the exact semantics are defined by each
/// implementor.
pub trait SparseArray<T>: std::any::Any
where
T: SparseElement + Div<Output = T> + 'static,
{
/// Returns the two dimensions of the array (presumably `(rows, cols)`).
fn shape(&self) -> (usize, usize);
/// Returns a count of stored elements (presumably the non-zero ones).
fn nnz(&self) -> usize;
/// Returns a textual name for the element type.
fn dtype(&self) -> &str;
/// Returns the contents as an owned dense `Array2<T>`.
fn to_array(&self) -> Array2<T>;
/// Same signature as [`SparseArray::to_array`]; presumably an alias.
fn toarray(&self) -> Array2<T>;
/// Converts to COO format (per the method name); may fail with a `SparseError`.
fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to CSR format (per the method name); may fail with a `SparseError`.
fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to CSC format (per the method name); may fail with a `SparseError`.
fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to DOK format (per the method name); may fail with a `SparseError`.
fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to LIL format (per the method name); may fail with a `SparseError`.
fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to DIA format (per the method name); may fail with a `SparseError`.
fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Converts to BSR format (per the method name); may fail with a `SparseError`.
fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Combines `self` with `other` by addition and returns the result as a new array.
fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Combines `self` with `other` by subtraction and returns the result as a new array.
fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Combines `self` with `other` by multiplication and returns the result as a new array.
/// NOTE(review): presumably element-wise, since `dot` exists separately -- confirm.
fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Combines `self` with `other` by division and returns the result as a new array.
fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Product of `self` with another sparse array (presumably the matrix product).
fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Product of `self` with a one-dimensional vector view, returned as a dense vector.
fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>>;
/// Returns the transpose as a new array.
fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Returns a boxed copy of this array.
fn copy(&self) -> Box<dyn SparseArray<T>>;
/// Returns the element at position `(i, j)`.
/// The signature has no error channel, so out-of-range handling is up to the implementor.
fn get(&self, i: usize, j: usize) -> T;
/// Stores `value` at position `(i, j)`; may fail with a `SparseError`.
fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()>;
/// Drops explicitly stored zero entries in place (per the method name).
fn eliminate_zeros(&mut self);
/// Sorts the stored indices in place (per the method name).
fn sort_indices(&mut self);
/// Returns a new array whose indices are sorted (per the method name).
fn sorted_indices(&self) -> Box<dyn SparseArray<T>>;
/// Reports whether the stored indices are sorted.
fn has_sorted_indices(&self) -> bool;
/// Sums the elements, over the whole array or along `axis`.
/// The result is a [`SparseSum`], which can carry either an array or a scalar.
fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>>;
/// Returns the largest element.
fn max(&self) -> T;
/// Returns the smallest element.
fn min(&self) -> T;
/// Returns three parallel arrays: two index arrays and one value array
/// (presumably row indices, column indices and the non-zero values).
fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>);
/// Extracts the sub-array selected by `row_range` and `col_range`.
/// NOTE(review): whether the range ends are inclusive is not fixed by the
/// signature -- confirm against the implementors.
fn slice(
&self,
row_range: (usize, usize),
col_range: (usize, usize),
) -> SparseResult<Box<dyn SparseArray<T>>>;
/// Exposes `self` as `&dyn Any`, which allows downcasting to the concrete type.
fn as_any(&self) -> &dyn std::any::Any;
/// Optional access to an index-pointer array; the default implementation returns `None`.
fn get_indptr(&self) -> Option<&Array1<usize>> {
None
}
/// Same signature and default (`None`) as [`SparseArray::get_indptr`].
fn indptr(&self) -> Option<&Array1<usize>> {
None
}
}
/// Return type of [`SparseArray::sum`]: the reduction yields either an array
/// or a single scalar value.
pub enum SparseSum<T>
where
T: SparseElement + Div<Output = T> + 'static,
{
/// The sum is itself an array, held as a boxed trait object.
SparseArray(Box<dyn SparseArray<T>>),
/// The sum is a single value.
Scalar(T),
}
/// Debug formatting for [`SparseSum`].
///
/// `dyn SparseArray<T>` carries no `Debug` bound, so the array variant is
/// rendered as a fixed placeholder; the scalar variant prints its value.
impl<T> Debug for SparseSum<T>
where
    T: SparseElement + Div<Output = T> + 'static,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Scalar(value) => write!(f, "SparseSum::Scalar({value:?})"),
            // Nothing to interpolate here, so write the placeholder directly.
            Self::SparseArray(_) => f.write_str("SparseSum::SparseArray(...)"),
        }
    }
}
/// Cloning for [`SparseSum`].
///
/// A boxed trait object cannot be cloned generically, so the array variant is
/// duplicated through [`SparseArray::copy`]; the scalar variant is copied by value.
impl<T> Clone for SparseSum<T>
where
    T: SparseElement + Div<Output = T> + 'static,
{
    fn clone(&self) -> Self {
        match self {
            Self::Scalar(value) => Self::Scalar(*value),
            Self::SparseArray(inner) => Self::SparseArray(inner.copy()),
        }
    }
}
/// Reports whether `obj` is a sparse array.
///
/// Anything that can be passed as `&dyn SparseArray<T>` implements the trait
/// already, so the answer is always `true`.
// Kept as a public convenience even when nothing inside the crate calls it.
#[allow(dead_code)]
pub fn is_sparse<T>(obj: &dyn SparseArray<T>) -> bool
where
    T: SparseElement + Div<Output = T> + 'static,
{
    // The argument only matters for its type; discard it explicitly so the
    // compiler does not report an unused variable.
    let _ = obj;
    true
}
/// A [`SparseArray`] implementation backed by a dense two-dimensional array.
///
/// All elements, zeros included, are held in a single `Array2<T>`.
pub struct SparseArrayBase<T>
where
T: SparseElement + Div<Output = T> + 'static,
{
// Dense backing storage for every element of the array.
data: Array2<T>,
}
impl<T> SparseArrayBase<T>
where
T: SparseElement + Div<Output = T> + Zero + 'static,
{
pub fn new(data: Array2<T>) -> Self {
Self { data }
}
}
/// [`SparseArray`] implementation for the dense-backed [`SparseArrayBase`].
///
/// Because every element is stored densely, format conversions simply return
/// a copy, and the index-maintenance methods have nothing to do.
impl<T> SparseArray<T> for SparseArrayBase<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Returns `(rows, cols)` of the backing array.
    fn shape(&self) -> (usize, usize) {
        self.data.dim()
    }

    /// Counts the elements that differ from `T::sparse_zero()`.
    fn nnz(&self) -> usize {
        let zero = T::sparse_zero();
        self.data.iter().filter(|&&x| x != zero).count()
    }

    /// Returns the fixed label `"float"`, independent of the actual `T`.
    fn dtype(&self) -> &str {
        "float"
    }

    /// Returns a clone of the dense backing array.
    fn to_array(&self) -> Array2<T> {
        self.data.clone()
    }

    /// Returns a clone of the dense backing array (same as `to_array`).
    fn toarray(&self) -> Array2<T> {
        self.data.clone()
    }

    /// Returns a copy of `self`; this type has no distinct COO representation.
    fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct CSR representation.
    fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct CSC representation.
    fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct DOK representation.
    fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct LIL representation.
    fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct DIA representation.
    fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Returns a copy of `self`; this type has no distinct BSR representation.
    fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(self.copy())
    }

    /// Element-wise sum of `self` and `other`.
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the two shapes cannot be
    /// combined (an axis differs and neither side has length 1 on it).
    fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
        let rhs = other.to_array();
        // Validate first: the ndarray operator panics on incompatible shapes.
        ensure_broadcast_compatible(self.data.dim(), rhs.dim())?;
        Ok(Box::new(SparseArrayBase::new(&self.data + &rhs)))
    }

    /// Element-wise difference `self - other`.
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the two shapes cannot be
    /// combined (an axis differs and neither side has length 1 on it).
    fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
        let rhs = other.to_array();
        ensure_broadcast_compatible(self.data.dim(), rhs.dim())?;
        Ok(Box::new(SparseArrayBase::new(&self.data - &rhs)))
    }

    /// Element-wise product of `self` and `other` (see `dot` for the matrix product).
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the two shapes cannot be
    /// combined (an axis differs and neither side has length 1 on it).
    fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
        let rhs = other.to_array();
        ensure_broadcast_compatible(self.data.dim(), rhs.dim())?;
        Ok(Box::new(SparseArrayBase::new(&self.data * &rhs)))
    }

    /// Element-wise quotient `self / other`.
    ///
    /// Zero divisors are not checked; each division yields whatever `T`'s
    /// `Div` implementation produces.
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the two shapes cannot be
    /// combined (an axis differs and neither side has length 1 on it).
    fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
        let rhs = other.to_array();
        ensure_broadcast_compatible(self.data.dim(), rhs.dim())?;
        Ok(Box::new(SparseArrayBase::new(&self.data / &rhs)))
    }

    /// Matrix product `self * other`.
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the column count of
    /// `self` differs from the row count of `other`.
    fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
        let (m, n) = self.shape();
        let (p, q) = other.shape();
        if n != p {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: p,
            });
        }
        // Densify only after the shapes are known to be compatible.
        let other_array = other.to_array();
        let mut result = Array2::zeros((m, q));
        for i in 0..m {
            for j in 0..q {
                let mut acc = T::sparse_zero();
                for k in 0..n {
                    acc = acc + self.data[[i, k]] * other_array[[k, j]];
                }
                result[[i, j]] = acc;
            }
        }
        Ok(Box::new(SparseArrayBase::new(result)))
    }

    /// Matrix-vector product `self * other`.
    ///
    /// # Errors
    /// Returns `SparseError::DimensionMismatch` when the column count of
    /// `self` differs from the length of `other`.
    fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
        let (m, n) = self.shape();
        if n != other.len() {
            return Err(SparseError::DimensionMismatch {
                expected: n,
                found: other.len(),
            });
        }
        let mut result = Array1::zeros(m);
        for i in 0..m {
            // Each row has exactly `n` entries, the same as `other`, so the zip is exact.
            result[i] = self
                .data
                .row(i)
                .iter()
                .zip(other.iter())
                .fold(T::sparse_zero(), |acc, (&a, &b)| acc + a * b);
        }
        Ok(result)
    }

    /// Returns the transpose as a new, owned array.
    fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
        Ok(Box::new(SparseArrayBase::new(self.data.t().to_owned())))
    }

    /// Returns a boxed clone of `self`.
    fn copy(&self) -> Box<dyn SparseArray<T>> {
        Box::new(self.clone())
    }

    /// Returns the element at `(i, j)`.
    ///
    /// # Panics
    /// Panics if `(i, j)` lies outside the array; the signature has no way to
    /// report an error.
    fn get(&self, i: usize, j: usize) -> T {
        self.data[[i, j]]
    }

    /// Stores `value` at `(i, j)`.
    ///
    /// # Errors
    /// Returns `SparseError::IndexOutOfBounds` if `(i, j)` lies outside the array.
    fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
        let (m, n) = self.shape();
        if i >= m || j >= n {
            return Err(SparseError::IndexOutOfBounds {
                index: (i, j),
                shape: (m, n),
            });
        }
        self.data[[i, j]] = value;
        Ok(())
    }

    /// No-op: the dense backing array stores every element, zeros included.
    fn eliminate_zeros(&mut self) {}

    /// No-op: the dense backing array has no separate index structure to sort.
    fn sort_indices(&mut self) {}

    /// Returns a copy of `self`; there are no indices to sort.
    fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
        self.copy()
    }

    /// Always `true` for this dense-backed type.
    fn has_sorted_indices(&self) -> bool {
        true
    }

    /// Sums the elements.
    ///
    /// * `None` returns the total as `SparseSum::Scalar`.
    /// * `Some(0)` returns the column sums as a `1 x cols` array.
    /// * `Some(1)` returns the row sums as a `rows x 1` array.
    ///
    /// # Errors
    /// Returns `SparseError::InvalidAxis` for any other axis.
    fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
        let (rows, cols) = self.shape();
        match axis {
            None => {
                let total = self
                    .data
                    .iter()
                    .fold(T::sparse_zero(), |acc, &v| acc + v);
                Ok(SparseSum::Scalar(total))
            }
            Some(0) => {
                let mut result = Array2::zeros((1, cols));
                for j in 0..cols {
                    result[[0, j]] = self
                        .data
                        .column(j)
                        .iter()
                        .fold(T::sparse_zero(), |acc, &v| acc + v);
                }
                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
                    result,
                ))))
            }
            Some(1) => {
                let mut result = Array2::zeros((rows, 1));
                for i in 0..rows {
                    result[[i, 0]] = self
                        .data
                        .row(i)
                        .iter()
                        .fold(T::sparse_zero(), |acc, &v| acc + v);
                }
                Ok(SparseSum::SparseArray(Box::new(SparseArrayBase::new(
                    result,
                ))))
            }
            _ => Err(SparseError::InvalidAxis),
        }
    }

    /// Returns the largest element, or `T::sparse_zero()` for an empty array.
    fn max(&self) -> T {
        let mut values = self.data.iter().copied();
        match values.next() {
            None => T::sparse_zero(),
            Some(first) => values.fold(first, |best, v| if v > best { v } else { best }),
        }
    }

    /// Returns the smallest element, or `T::sparse_zero()` for an empty array.
    fn min(&self) -> T {
        let mut values = self.data.iter().copied();
        match values.next() {
            None => T::sparse_zero(),
            Some(first) => values.fold(first, |best, v| if v < best { v } else { best }),
        }
    }

    /// Returns `(row indices, column indices, values)` of all non-zero
    /// elements, in row-major order.
    fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
        let (m, n) = self.shape();
        let zero = T::sparse_zero();
        // Pre-count so that each output vector is allocated exactly once.
        let nnz = self.nnz();
        let mut rows = Vec::with_capacity(nnz);
        let mut cols = Vec::with_capacity(nnz);
        let mut values = Vec::with_capacity(nnz);
        for i in 0..m {
            for j in 0..n {
                let value = self.data[[i, j]];
                if value != zero {
                    rows.push(i);
                    cols.push(j);
                    values.push(value);
                }
            }
        }
        (
            Array1::from_vec(rows),
            Array1::from_vec(cols),
            Array1::from_vec(values),
        )
    }

    /// Copies out the sub-array covering the half-open ranges
    /// `row_range.0..row_range.1` and `col_range.0..col_range.1`.
    ///
    /// # Errors
    /// Returns `SparseError::InvalidSliceRange` when a range is empty,
    /// reversed, or reaches outside the array.
    fn slice(
        &self,
        row_range: (usize, usize),
        col_range: (usize, usize),
    ) -> SparseResult<Box<dyn SparseArray<T>>> {
        let (start_row, end_row) = row_range;
        let (start_col, end_col) = col_range;
        let (m, n) = self.shape();
        if start_row >= m
            || end_row > m
            || start_col >= n
            || end_col > n
            || start_row >= end_row
            || start_col >= end_col
        {
            return Err(SparseError::InvalidSliceRange);
        }
        let view = self.data.slice(scirs2_core::ndarray::s![
            start_row..end_row,
            start_col..end_col
        ]);
        Ok(Box::new(SparseArrayBase::new(view.to_owned())))
    }

    /// Exposes `self` as `&dyn Any` for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// Checks that two 2-D shapes can be combined element-wise under ndarray's
/// broadcasting rule: on each axis the lengths are equal or one of them is 1.
///
/// Returns `SparseError::DimensionMismatch` for the first incompatible axis,
/// so callers can report an error where the ndarray operator would panic.
fn ensure_broadcast_compatible(lhs: (usize, usize), rhs: (usize, usize)) -> SparseResult<()> {
    let axes = [(lhs.0, rhs.0), (lhs.1, rhs.1)];
    for &(expected, found) in axes.iter() {
        if expected != found && expected != 1 && found != 1 {
            return Err(SparseError::DimensionMismatch { expected, found });
        }
    }
    Ok(())
}
impl<T> Clone for SparseArrayBase<T>
where
T: SparseElement + Div<Output = T> + 'static,
{
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Shape, non-zero count and element lookup on a 3x3 array.
    #[test]
    fn test_sparse_array_base() {
        let values = vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0];
        let dense = Array::from_shape_vec((3, 3), values).expect("Operation failed");
        let sparse = SparseArrayBase::new(dense);

        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 5);

        let expected = [((0, 0), 1.0), ((1, 1), 3.0), ((2, 2), 5.0), ((0, 1), 0.0)];
        for &((row, col), want) in expected.iter() {
            assert_eq!(sparse.get(row, col), want);
        }
    }

    /// Element-wise addition and the matrix product on 2x2 inputs.
    #[test]
    fn test_sparse_array_operations() {
        let sparse1 = SparseArrayBase::new(
            Array::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed"),
        );
        let sparse2 = SparseArrayBase::new(
            Array::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).expect("Operation failed"),
        );

        let added = sparse1.add(&sparse2).expect("Operation failed").to_array();
        let expected_sum = [((0, 0), 6.0), ((0, 1), 8.0), ((1, 0), 10.0), ((1, 1), 12.0)];
        for &((row, col), want) in expected_sum.iter() {
            assert_eq!(added[[row, col]], want);
        }

        let product = sparse1.dot(&sparse2).expect("Operation failed").to_array();
        let expected_dot = [((0, 0), 19.0), ((0, 1), 22.0), ((1, 0), 43.0), ((1, 1), 50.0)];
        for &((row, col), want) in expected_dot.iter() {
            assert_eq!(product[[row, col]], want);
        }
    }
}