use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{One, Zero};
use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::{Add, Div, Mul, Sub};
/// An N-dimensional array that stores only its non-zero elements.
///
/// Elements live in a hash map keyed by their full multi-index; any index that
/// is absent from `data` reads as `T::zero()`. The constructors and setters in
/// this module skip zero values, but both fields are public, so code that writes
/// `data` directly can store explicit zeros or out-of-range keys.
#[derive(Clone, Debug)]
pub struct SparseArray<T> {
/// Non-zero elements keyed by multi-index (one coordinate per dimension).
pub data: HashMap<Vec<usize>, T>,
/// Extent of each dimension.
pub shape: Vec<usize>,
}
impl<T> SparseArray<T>
where
    T: Clone + PartialEq + Zero,
{
    /// Creates an empty sparse array (every element implicitly zero) of the given shape.
    pub fn new(shape: &[usize]) -> Self {
        Self {
            shape: shape.to_vec(),
            data: HashMap::new(),
        }
    }

    /// Collects the non-zero elements of a dense array.
    ///
    /// The dense buffer is read in row-major order (last axis varies fastest),
    /// and each flat offset is unravelled into its multi-index.
    pub fn from_array(array: &Array<T>) -> Self {
        let shape = array.shape();
        let total: usize = shape.iter().product();
        let dense = array.to_vec();
        let mut data = HashMap::new();
        for (flat, element) in dense.iter().enumerate().take(total) {
            if *element == T::zero() {
                continue;
            }
            // Peel coordinates off from the fastest-varying (last) axis backwards.
            let mut coords = vec![0; shape.len()];
            let mut rest = flat;
            for (slot, &extent) in coords.iter_mut().zip(shape.iter()).rev() {
                *slot = rest % extent;
                rest /= extent;
            }
            data.insert(coords, element.clone());
        }
        SparseArray { data, shape }
    }

    /// Materialises the array densely, filling unstored positions with `T::zero()`.
    pub fn to_array(&self) -> Array<T> {
        let total: usize = self.shape.iter().product();
        let mut dense = vec![T::zero(); total];
        for (coords, value) in &self.data {
            // Row-major linear offset: walk axes from last to first while the
            // stride grows by each visited extent (axis 0's extent is never needed).
            let (flat, _) = coords.iter().enumerate().rev().fold(
                (0usize, 1usize),
                |(offset, stride), (axis, &coord)| {
                    let offset = offset + coord * stride;
                    let stride = if axis > 0 {
                        stride * self.shape[axis]
                    } else {
                        stride
                    };
                    (offset, stride)
                },
            );
            dense[flat] = value.clone();
        }
        Array::from_vec(dense).reshape(&self.shape)
    }

    /// Number of stored (non-zero) elements.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Fraction of positions that are stored; `0.0` for an array with no positions.
    pub fn density(&self) -> f64 {
        match self.shape.iter().product::<usize>() {
            0 => 0.0,
            total => self.nnz() as f64 / total as f64,
        }
    }

    /// Extent of each dimension.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Validates that `indices` has one in-range coordinate per dimension.
    fn check_index(&self, indices: &[usize]) -> Result<()> {
        let expected = self.shape.len();
        if indices.len() != expected {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Index has wrong number of dimensions: expected {}, got {}",
                expected,
                indices.len()
            )));
        }
        let first_bad_axis = indices
            .iter()
            .zip(&self.shape)
            .position(|(idx, extent)| idx >= extent);
        if let Some(axis) = first_bad_axis {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Index {} is out of bounds for dimension {} with size {}",
                indices[axis], axis, self.shape[axis]
            )));
        }
        Ok(())
    }

    /// Reads the element at `indices`, returning `T::zero()` for unstored positions.
    ///
    /// # Errors
    /// Fails when `indices` has the wrong rank or an out-of-range coordinate.
    pub fn get(&self, indices: &[usize]) -> Result<T> {
        self.check_index(indices)?;
        let value = match self.data.get(indices) {
            Some(stored) => stored.clone(),
            None => T::zero(),
        };
        Ok(value)
    }

    /// Writes `value` at `indices`; writing zero removes the stored entry.
    ///
    /// # Errors
    /// Fails when `indices` has the wrong rank or an out-of-range coordinate.
    pub fn set(&mut self, indices: &[usize], value: T) -> Result<()> {
        self.check_index(indices)?;
        if value == T::zero() {
            self.data.remove(indices);
            return Ok(());
        }
        self.data.insert(indices.to_vec(), value);
        Ok(())
    }
}
impl<T> SparseArray<T>
where
T: Clone + PartialEq + Zero + Add<Output = T>,
{
pub fn add(&self, other: &SparseArray<T>) -> Result<SparseArray<T>> {
if self.shape != other.shape {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape.clone(),
actual: other.shape.clone(),
});
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
result.data.insert(indices.clone(), value.clone());
}
for (indices, value) in &other.data {
if let Some(existing) = result.data.get_mut(indices) {
*existing = existing.clone() + value.clone();
if *existing == T::zero() {
result.data.remove(indices);
}
} else {
result.data.insert(indices.clone(), value.clone());
}
}
Ok(result)
}
}
impl<T> SparseArray<T>
where
T: Clone + PartialEq + Zero + Sub<Output = T>,
{
pub fn subtract(&self, other: &SparseArray<T>) -> Result<SparseArray<T>> {
if self.shape != other.shape {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape.clone(),
actual: other.shape.clone(),
});
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
result.data.insert(indices.clone(), value.clone());
}
for (indices, value) in &other.data {
if let Some(existing) = result.data.get_mut(indices) {
*existing = existing.clone() - value.clone();
if *existing == T::zero() {
result.data.remove(indices);
}
} else {
let neg_value = T::zero() - value.clone();
if neg_value != T::zero() {
result.data.insert(indices.clone(), neg_value);
}
}
}
Ok(result)
}
}
impl<T> SparseArray<T>
where
T: Clone + PartialEq + Zero + Mul<Output = T>,
{
pub fn multiply(&self, other: &SparseArray<T>) -> Result<SparseArray<T>> {
if self.shape != other.shape {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape.clone(),
actual: other.shape.clone(),
});
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
if let Some(other_value) = other.data.get(indices) {
let product = value.clone() * other_value.clone();
if product != T::zero() {
result.data.insert(indices.clone(), product);
}
}
}
Ok(result)
}
pub fn multiply_scalar(&self, scalar: T) -> SparseArray<T> {
if scalar == T::zero() {
return SparseArray::new(&self.shape);
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
let product = value.clone() * scalar.clone();
if product != T::zero() {
result.data.insert(indices.clone(), product);
}
}
result
}
}
impl<T> SparseArray<T>
where
T: Clone + PartialEq + Zero + Div<Output = T>,
{
pub fn divide(&self, other: &SparseArray<T>) -> Result<SparseArray<T>> {
if self.shape != other.shape {
return Err(NumRs2Error::ShapeMismatch {
expected: self.shape.clone(),
actual: other.shape.clone(),
});
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
let other_value = other.data.get(indices).cloned().unwrap_or_else(T::zero);
if other_value == T::zero() {
return Err(NumRs2Error::InvalidOperation(
"Division by zero in sparse array".to_string(),
));
}
let quotient = value.clone() / other_value;
if quotient != T::zero() {
result.data.insert(indices.clone(), quotient);
}
}
Ok(result)
}
pub fn divide_scalar(&self, scalar: T) -> Result<SparseArray<T>> {
if scalar == T::zero() {
return Err(NumRs2Error::InvalidOperation(
"Division by zero scalar".to_string(),
));
}
let mut result = SparseArray::new(&self.shape);
for (indices, value) in &self.data {
let quotient = value.clone() / scalar.clone();
if quotient != T::zero() {
result.data.insert(indices.clone(), quotient);
}
}
Ok(result)
}
}
/// Which layout metadata a [`SparseMatrix`] currently carries.
///
/// The stored entries themselves always live in the matrix's `array`; the
/// format only says which auxiliary fields are populated. The type is `Copy`
/// and comparable, so callers can write `m.format == SparseMatrixFormat::CSR`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SparseMatrixFormat {
    /// Coordinate form: no auxiliary layout metadata.
    COO,
    /// Compressed sparse row: `indptr` is per row, `indices` holds column indices.
    CSR,
    /// Compressed sparse column: `indptr` is per column, `indices` holds row indices.
    CSC,
    /// Diagonal form: `diag_offsets` lists occupied diagonals (`col - row`).
    DIA,
}
/// A two-dimensional sparse matrix built on [`SparseArray`].
///
/// The stored entries always live in `array`; `format` records which optional
/// layout metadata has been computed alongside them:
/// - `COO`: `indices`, `indptr` and `diag_offsets` are `None`.
/// - `CSR` / `CSC`: `indices` and `indptr` describe the row- or column-compressed layout.
/// - `DIA`: `diag_offsets` lists the occupied diagonals (`col - row`), ascending.
#[derive(Clone, Debug)]
pub struct SparseMatrix<T> {
/// Stored entries, keyed by `[row, col]`.
pub array: SparseArray<T>,
/// Which layout metadata below is currently populated.
pub format: SparseMatrixFormat,
/// CSR: column index of each stored entry; CSC: row index of each stored entry.
pub indices: Option<Vec<usize>>,
/// CSR/CSC pointer array: entries of row/column `i` are `indices[indptr[i]..indptr[i + 1]]`.
pub indptr: Option<Vec<usize>>,
/// DIA: sorted, de-duplicated `col - row` offsets of the occupied diagonals.
pub diag_offsets: Option<Vec<isize>>,
}
impl<T> SparseMatrix<T>
where
    T: Clone + PartialEq + Zero + Debug,
{
    /// Creates an empty (all-zero) matrix in COO format.
    ///
    /// # Errors
    /// Returns `NumRs2Error::DimensionMismatch` if `shape` is not two-dimensional.
    pub fn new(shape: &[usize]) -> Result<Self> {
        if shape.len() != 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "SparseMatrix requires a 2D shape".to_string(),
            ));
        }
        Ok(SparseMatrix {
            array: SparseArray::new(shape),
            format: SparseMatrixFormat::COO,
            indices: None,
            indptr: None,
            diag_offsets: None,
        })
    }

    /// Builds a COO matrix from the non-zero elements of a dense 2-D array.
    ///
    /// # Errors
    /// Returns `NumRs2Error::DimensionMismatch` if `array` is not two-dimensional.
    pub fn from_array(array: &Array<T>) -> Result<Self> {
        let shape = array.shape();
        if shape.len() != 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "SparseMatrix requires a 2D array".to_string(),
            ));
        }
        Ok(SparseMatrix {
            array: SparseArray::from_array(array),
            format: SparseMatrixFormat::COO,
            indices: None,
            indptr: None,
            diag_offsets: None,
        })
    }

    /// `n × n` identity matrix.
    pub fn eye(n: usize) -> Result<Self>
    where
        T: One,
    {
        let mut matrix = SparseMatrix::new(&[n, n])?;
        for i in 0..n {
            matrix.array.set(&[i, i], T::one())?;
        }
        Ok(matrix)
    }

    /// Square matrix with `diagonal` on its main diagonal; zero values are not stored.
    pub fn diag(diagonal: &[T]) -> Result<Self> {
        let n = diagonal.len();
        let mut matrix = SparseMatrix::new(&[n, n])?;
        for (i, value) in diagonal.iter().enumerate() {
            // `SparseArray::set` already drops zero values.
            matrix.array.set(&[i, i], value.clone())?;
        }
        Ok(matrix)
    }

    /// Computes the `(indices, indptr)` arrays of a compressed layout.
    ///
    /// `major` is the compressed axis and must be `0` or `1`. `0` gives CSR: one
    /// `indptr` slot per row, and `indices` holds columns. `1` gives CSC: one
    /// slot per column, and `indices` holds rows.
    /// Within each major slot the minor indices ascend. Entries holding an
    /// explicit zero are skipped.
    ///
    /// # Errors
    /// Fails if a stored key is not a valid `[row, col]` for this shape (possible
    /// only when `array.data` was written directly).
    fn compressed_structure(&self, major: usize) -> Result<(Vec<usize>, Vec<usize>)> {
        let minor = 1 - major;
        let n_major = self.array.shape[major];
        let mut coords = Vec::with_capacity(self.array.nnz());
        for (idx, value) in &self.array.data {
            if *value == T::zero() {
                continue;
            }
            self.array.check_index(idx)?;
            coords.push((idx[major], idx[minor]));
        }
        coords.sort_unstable();

        // Count entries per major slot at `m + 1`, then prefix-sum into pointers.
        let mut indptr = vec![0; n_major + 1];
        for &(m, _) in &coords {
            indptr[m + 1] += 1;
        }
        let mut running = 0;
        for slot in indptr.iter_mut() {
            running += *slot;
            *slot = running;
        }
        let indices = coords.into_iter().map(|(_, minor_idx)| minor_idx).collect();
        Ok((indices, indptr))
    }

    /// Switches to CSR and populates `indices` (column of each entry, row by row,
    /// ascending columns) and `indptr` (length `rows + 1`).
    ///
    /// Explicit zeros in `array` are dropped. If the matrix is already CSR and its
    /// metadata is present, this is a no-op.
    ///
    /// # Errors
    /// Fails if a stored key is not a valid `[row, col]` for this shape.
    pub fn to_csr(&mut self) -> Result<()> {
        if matches!(self.format, SparseMatrixFormat::CSR)
            && self.indices.is_some()
            && self.indptr.is_some()
        {
            return Ok(());
        }
        let (indices, indptr) = self.compressed_structure(0)?;
        self.array.data.retain(|_, v| *v != T::zero());
        self.format = SparseMatrixFormat::CSR;
        self.indices = Some(indices);
        self.indptr = Some(indptr);
        self.diag_offsets = None;
        Ok(())
    }

    /// Switches to CSC and populates `indices` (row of each entry, column by
    /// column, ascending rows) and `indptr` (length `cols + 1`).
    ///
    /// Explicit zeros in `array` are dropped. If the matrix is already CSC and its
    /// metadata is present, this is a no-op.
    ///
    /// # Errors
    /// Fails if a stored key is not a valid `[row, col]` for this shape.
    pub fn to_csc(&mut self) -> Result<()> {
        if matches!(self.format, SparseMatrixFormat::CSC)
            && self.indices.is_some()
            && self.indptr.is_some()
        {
            return Ok(());
        }
        let (indices, indptr) = self.compressed_structure(1)?;
        self.array.data.retain(|_, v| *v != T::zero());
        self.format = SparseMatrixFormat::CSC;
        self.indices = Some(indices);
        self.indptr = Some(indptr);
        self.diag_offsets = None;
        Ok(())
    }

    /// Switches to DIA and records the sorted, de-duplicated diagonal offsets
    /// (`col - row`) of all non-zero entries in `diag_offsets`.
    ///
    /// Explicit zeros in `array` are dropped. If the matrix is already DIA and its
    /// offsets are present, this is a no-op.
    ///
    /// # Errors
    /// Fails if a stored key is not a valid `[row, col]` for this shape.
    pub fn to_dia(&mut self) -> Result<()> {
        if matches!(self.format, SparseMatrixFormat::DIA) && self.diag_offsets.is_some() {
            return Ok(());
        }
        let mut offsets = Vec::with_capacity(self.array.nnz());
        for (idx, value) in &self.array.data {
            if *value == T::zero() {
                continue;
            }
            self.array.check_index(idx)?;
            offsets.push(idx[1] as isize - idx[0] as isize);
        }
        offsets.sort_unstable();
        offsets.dedup();
        self.array.data.retain(|_, v| *v != T::zero());
        self.format = SparseMatrixFormat::DIA;
        self.indices = None;
        self.indptr = None;
        self.diag_offsets = Some(offsets);
        Ok(())
    }

    /// Dense copy of the matrix.
    pub fn to_array(&self) -> Array<T> {
        self.array.to_array()
    }

    /// Number of stored (non-zero) entries.
    pub fn nnz(&self) -> usize {
        self.array.nnz()
    }

    /// Fraction of positions that are stored.
    pub fn density(&self) -> f64 {
        self.array.density()
    }

    /// `[rows, cols]`.
    pub fn shape(&self) -> &[usize] {
        &self.array.shape
    }

    /// Reads entry `(row, col)`, returning zero for unstored positions.
    ///
    /// # Errors
    /// Fails when `(row, col)` is out of bounds.
    pub fn get(&self, row: usize, col: usize) -> Result<T> {
        self.array.get(&[row, col])
    }

    /// Writes entry `(row, col)`; writing zero removes it.
    ///
    /// If the write changes the sparsity pattern (an entry appears or
    /// disappears), any CSR/CSC/DIA metadata no longer describes the entries.
    /// In that case the matrix falls back to COO with the metadata cleared, and
    /// the caller must convert again. Overwriting an existing non-zero with
    /// another non-zero keeps the current format.
    ///
    /// # Errors
    /// Fails when `(row, col)` is out of bounds; the matrix is then unchanged.
    pub fn set(&mut self, row: usize, col: usize, value: T) -> Result<()> {
        let key = [row, col];
        let was_stored = self.array.data.contains_key(&key[..]);
        self.array.set(&key, value)?;
        if was_stored != self.array.data.contains_key(&key[..]) {
            self.format = SparseMatrixFormat::COO;
            self.indices = None;
            self.indptr = None;
            self.diag_offsets = None;
        }
        Ok(())
    }
}
impl<T> SparseMatrix<T>
where
T: Clone + PartialEq + Zero + Add<Output = T> + Mul<Output = T> + Debug,
{
pub fn matmul(&self, other: &SparseMatrix<T>) -> Result<SparseMatrix<T>> {
let self_shape = self.array.shape();
let other_shape = other.array.shape();
if self_shape[1] != other_shape[0] {
return Err(NumRs2Error::ShapeMismatch {
expected: vec![self_shape[0], other_shape[1]],
actual: vec![self_shape[0], self_shape[1]],
});
}
let mut self_matrix = self.clone();
let mut other_matrix = other.clone();
self_matrix.to_csr()?;
other_matrix.to_csc()?;
let n_rows = self_shape[0];
let n_cols = other_shape[1];
let _k = self_shape[1];
let mut result = SparseMatrix::new(&[n_rows, n_cols])?;
let self_indptr = self_matrix.indptr.as_ref().ok_or_else(|| {
NumRs2Error::ComputationError("CSR format should have indptr".to_string())
})?;
let self_indices = self_matrix.indices.as_ref().ok_or_else(|| {
NumRs2Error::ComputationError("CSR format should have indices".to_string())
})?;
let other_indptr = other_matrix.indptr.as_ref().ok_or_else(|| {
NumRs2Error::ComputationError("CSC format should have indptr".to_string())
})?;
let other_indices = other_matrix.indices.as_ref().ok_or_else(|| {
NumRs2Error::ComputationError("CSC format should have indices".to_string())
})?;
for i in 0..n_rows {
let row_start = self_indptr[i];
let row_end = self_indptr[i + 1];
for j in 0..n_cols {
let col_start = other_indptr[j];
let col_end = other_indptr[j + 1];
let mut sum = T::zero();
let mut added = false;
let mut row_idx = row_start;
let mut col_idx = col_start;
while row_idx < row_end && col_idx < col_end {
let row_k = self_indices[row_idx];
let col_k = other_indices[col_idx];
match row_k.cmp(&col_k) {
std::cmp::Ordering::Equal => {
let a_val = self_matrix.array.get(&[i, row_k])?;
let b_val = other_matrix.array.get(&[col_k, j])?;
sum = sum + a_val * b_val;
added = true;
row_idx += 1;
col_idx += 1;
}
std::cmp::Ordering::Less => {
row_idx += 1;
}
std::cmp::Ordering::Greater => {
col_idx += 1;
}
}
}
if added && sum != T::zero() {
result.array.set(&[i, j], sum)?;
}
}
}
Ok(result)
}
pub fn transpose(&self) -> Result<SparseMatrix<T>> {
let n_rows = self.array.shape[0];
let n_cols = self.array.shape[1];
let mut result = SparseMatrix::new(&[n_cols, n_rows])?;
for (indices, value) in &self.array.data {
result.array.set(&[indices[1], indices[0]], value.clone())?;
}
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_sparse_array_creation() -> Result<()> {
        let mut sparse = SparseArray::new(&[3, 3]);
        sparse.set(&[0, 0], 1.0)?;
        sparse.set(&[1, 1], 2.0)?;
        sparse.set(&[2, 2], 3.0)?;
        assert_eq!(sparse.nnz(), 3);
        assert_relative_eq!(sparse.density(), 3.0 / 9.0);
        assert_relative_eq!(sparse.get(&[0, 0])?, 1.0);
        assert_relative_eq!(sparse.get(&[1, 1])?, 2.0);
        assert_relative_eq!(sparse.get(&[2, 2])?, 3.0);
        assert_relative_eq!(sparse.get(&[0, 1])?, 0.0);
        Ok(())
    }

    #[test]
    fn test_sparse_array_from_dense() -> Result<()> {
        let dense =
            Array::from_vec(vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]).reshape(&[3, 3]);
        let sparse = SparseArray::from_array(&dense);
        assert_eq!(sparse.nnz(), 3);
        assert_relative_eq!(sparse.get(&[0, 0])?, 1.0);
        assert_relative_eq!(sparse.get(&[1, 1])?, 2.0);
        assert_relative_eq!(sparse.get(&[2, 2])?, 3.0);
        Ok(())
    }

    #[test]
    fn test_sparse_array_to_dense() -> Result<()> {
        let mut sparse = SparseArray::new(&[3, 3]);
        sparse.set(&[0, 0], 1.0)?;
        sparse.set(&[1, 1], 2.0)?;
        sparse.set(&[2, 2], 3.0)?;
        let dense = sparse.to_array();
        let dense_data = dense.to_vec();
        assert_relative_eq!(dense_data[0], 1.0);
        assert_relative_eq!(dense_data[4], 2.0);
        assert_relative_eq!(dense_data[8], 3.0);
        for i in [1, 2, 3, 5, 6, 7] {
            assert_relative_eq!(dense_data[i], 0.0);
        }
        Ok(())
    }

    /// Row-major layout must hold for non-square shapes too (strides differ per axis).
    #[test]
    fn test_sparse_array_round_trip_non_square() -> Result<()> {
        let dense = Array::from_vec(vec![0.0, 1.0, 0.0, 2.0, 0.0, 3.0]).reshape(&[2, 3]);
        let sparse = SparseArray::from_array(&dense);
        assert_eq!(sparse.nnz(), 3);
        assert_relative_eq!(sparse.get(&[0, 1])?, 1.0);
        assert_relative_eq!(sparse.get(&[1, 0])?, 2.0);
        assert_relative_eq!(sparse.get(&[1, 2])?, 3.0);
        assert_eq!(sparse.to_array().to_vec(), dense.to_vec());
        Ok(())
    }

    #[test]
    fn test_density_of_empty_shape() {
        let empty: SparseArray<f64> = SparseArray::new(&[0, 3]);
        assert_eq!(empty.nnz(), 0);
        assert_relative_eq!(empty.density(), 0.0);
    }

    #[test]
    fn test_sparse_array_arithmetic() -> Result<()> {
        let mut a = SparseArray::new(&[3, 3]);
        let mut b = SparseArray::new(&[3, 3]);
        a.set(&[0, 0], 1.0)?;
        a.set(&[1, 1], 2.0)?;
        a.set(&[2, 2], 3.0)?;
        b.set(&[0, 0], 2.0)?;
        b.set(&[1, 1], 1.0)?;
        b.set(&[0, 2], 4.0)?;
        let sum = a.add(&b)?;
        assert_relative_eq!(sum.get(&[0, 0])?, 3.0);
        assert_relative_eq!(sum.get(&[1, 1])?, 3.0);
        assert_relative_eq!(sum.get(&[2, 2])?, 3.0);
        assert_relative_eq!(sum.get(&[0, 2])?, 4.0);
        let diff = a.subtract(&b)?;
        assert_relative_eq!(diff.get(&[0, 0])?, -1.0);
        assert_relative_eq!(diff.get(&[1, 1])?, 1.0);
        assert_relative_eq!(diff.get(&[2, 2])?, 3.0);
        assert_relative_eq!(diff.get(&[0, 2])?, -4.0);
        let prod = a.multiply(&b)?;
        assert_relative_eq!(prod.get(&[0, 0])?, 2.0);
        assert_relative_eq!(prod.get(&[1, 1])?, 2.0);
        assert_relative_eq!(prod.get(&[2, 2])?, 0.0);
        assert_relative_eq!(prod.get(&[0, 2])?, 0.0);
        let scaled = a.multiply_scalar(2.0);
        assert_relative_eq!(scaled.get(&[0, 0])?, 2.0);
        assert_relative_eq!(scaled.get(&[1, 1])?, 4.0);
        assert_relative_eq!(scaled.get(&[2, 2])?, 6.0);
        Ok(())
    }

    #[test]
    fn test_sparse_array_division() -> Result<()> {
        let mut a = SparseArray::new(&[2, 2]);
        let mut b = SparseArray::new(&[2, 2]);
        a.set(&[0, 0], 6.0)?;
        b.set(&[0, 0], 3.0)?;
        b.set(&[1, 1], 5.0)?;
        let q = a.divide(&b)?;
        assert_relative_eq!(q.get(&[0, 0])?, 2.0);
        // Unstored numerator stays unstored regardless of the denominator.
        assert_relative_eq!(q.get(&[1, 1])?, 0.0);
        assert_eq!(q.nnz(), 1);
        // A stored numerator over an implicit zero is an error.
        a.set(&[0, 1], 1.0)?;
        assert!(a.divide(&b).is_err());
        assert!(a.divide_scalar(0.0).is_err());
        let halved = a.divide_scalar(2.0)?;
        assert_relative_eq!(halved.get(&[0, 0])?, 3.0);
        assert_relative_eq!(halved.get(&[0, 1])?, 0.5);
        Ok(())
    }

    #[test]
    fn test_shape_and_index_errors() -> Result<()> {
        let a: SparseArray<f64> = SparseArray::new(&[2, 2]);
        let b: SparseArray<f64> = SparseArray::new(&[2, 3]);
        assert!(a.add(&b).is_err());
        assert!(a.subtract(&b).is_err());
        assert!(a.multiply(&b).is_err());
        assert!(a.divide(&b).is_err());
        assert!(a.get(&[2, 0]).is_err());
        assert!(a.get(&[0]).is_err());
        assert!(SparseMatrix::<f64>::new(&[2, 2, 2]).is_err());
        let m: SparseMatrix<f64> = SparseMatrix::new(&[2, 3])?;
        let n: SparseMatrix<f64> = SparseMatrix::new(&[2, 3])?;
        assert!(m.matmul(&n).is_err());
        Ok(())
    }

    #[test]
    fn test_sparse_matrix_creation() -> Result<()> {
        let mut sparse = SparseMatrix::new(&[3, 3])?;
        sparse.set(0, 0, 1.0)?;
        sparse.set(1, 1, 2.0)?;
        sparse.set(2, 2, 3.0)?;
        assert_eq!(sparse.nnz(), 3);
        assert_relative_eq!(sparse.get(0, 0)?, 1.0);
        assert_relative_eq!(sparse.get(1, 1)?, 2.0);
        assert_relative_eq!(sparse.get(2, 2)?, 3.0);
        assert_relative_eq!(sparse.get(0, 1)?, 0.0);
        Ok(())
    }

    #[test]
    fn test_sparse_matrix_special_constructors() -> Result<()> {
        let eye: SparseMatrix<f64> = SparseMatrix::eye(3)?;
        assert_relative_eq!(eye.get(0, 0)?, 1.0);
        assert_relative_eq!(eye.get(1, 1)?, 1.0);
        assert_relative_eq!(eye.get(2, 2)?, 1.0);
        assert_relative_eq!(eye.get(0, 1)?, 0.0);
        assert_relative_eq!(eye.get(1, 2)?, 0.0);
        let diag = SparseMatrix::diag(&[1.0, 2.0, 3.0])?;
        assert_relative_eq!(diag.get(0, 0)?, 1.0);
        assert_relative_eq!(diag.get(1, 1)?, 2.0);
        assert_relative_eq!(diag.get(2, 2)?, 3.0);
        assert_relative_eq!(diag.get(0, 1)?, 0.0);
        assert_relative_eq!(diag.get(1, 2)?, 0.0);
        Ok(())
    }

    #[test]
    fn test_sparse_matrix_format_conversion() -> Result<()> {
        let mut sparse = SparseMatrix::new(&[3, 3])?;
        sparse.set(0, 0, 1.0)?;
        sparse.set(0, 2, 2.0)?;
        sparse.set(1, 1, 3.0)?;
        sparse.set(2, 0, 4.0)?;
        sparse.to_csr()?;
        if let SparseMatrixFormat::CSR = sparse.format {
        } else {
            panic!("Format should be CSR");
        }
        assert_relative_eq!(sparse.get(0, 0)?, 1.0);
        assert_relative_eq!(sparse.get(0, 2)?, 2.0);
        assert_relative_eq!(sparse.get(1, 1)?, 3.0);
        assert_relative_eq!(sparse.get(2, 0)?, 4.0);
        sparse.to_csc()?;
        if let SparseMatrixFormat::CSC = sparse.format {
        } else {
            panic!("Format should be CSC");
        }
        assert_relative_eq!(sparse.get(0, 0)?, 1.0);
        assert_relative_eq!(sparse.get(0, 2)?, 2.0);
        assert_relative_eq!(sparse.get(1, 1)?, 3.0);
        assert_relative_eq!(sparse.get(2, 0)?, 4.0);
        Ok(())
    }

    /// Pins the exact CSR/CSC/DIA metadata for a small matrix.
    #[test]
    fn test_compressed_structure_contents() -> Result<()> {
        let mut m = SparseMatrix::new(&[3, 3])?;
        m.set(0, 0, 1.0)?;
        m.set(0, 2, 2.0)?;
        m.set(1, 1, 3.0)?;
        m.set(2, 0, 4.0)?;
        m.to_csr()?;
        // Rows: 0 -> cols [0, 2], 1 -> [1], 2 -> [0].
        assert_eq!(m.indptr, Some(vec![0, 2, 3, 4]));
        assert_eq!(m.indices, Some(vec![0, 2, 1, 0]));
        m.to_csc()?;
        // Cols: 0 -> rows [0, 2], 1 -> [1], 2 -> [0].
        assert_eq!(m.indptr, Some(vec![0, 2, 3, 4]));
        assert_eq!(m.indices, Some(vec![0, 2, 1, 0]));
        m.to_dia()?;
        assert_eq!(m.diag_offsets, Some(vec![-2, 0, 2]));
        assert!(m.indptr.is_none());
        assert!(m.indices.is_none());
        Ok(())
    }

    #[test]
    fn test_sparse_matrix_operations() -> Result<()> {
        let mut a = SparseMatrix::new(&[3, 3])?;
        let mut b = SparseMatrix::new(&[3, 2])?;
        a.set(0, 0, 1.0)?;
        a.set(0, 1, 2.0)?;
        a.set(1, 0, 3.0)?;
        a.set(1, 1, 4.0)?;
        a.set(2, 0, 5.0)?;
        a.set(2, 1, 6.0)?;
        b.set(0, 0, 7.0)?;
        b.set(0, 1, 8.0)?;
        b.set(1, 0, 9.0)?;
        b.set(1, 1, 10.0)?;
        let c = a.matmul(&b)?;
        assert_relative_eq!(c.get(0, 0)?, 1.0 * 7.0 + 2.0 * 9.0);
        assert_relative_eq!(c.get(0, 1)?, 1.0 * 8.0 + 2.0 * 10.0);
        assert_relative_eq!(c.get(1, 0)?, 3.0 * 7.0 + 4.0 * 9.0);
        assert_relative_eq!(c.get(1, 1)?, 3.0 * 8.0 + 4.0 * 10.0);
        assert_relative_eq!(c.get(2, 0)?, 5.0 * 7.0 + 6.0 * 9.0);
        assert_relative_eq!(c.get(2, 1)?, 5.0 * 8.0 + 6.0 * 10.0);
        let at = a.transpose()?;
        assert_relative_eq!(at.get(0, 0)?, 1.0);
        assert_relative_eq!(at.get(0, 1)?, 3.0);
        assert_relative_eq!(at.get(0, 2)?, 5.0);
        assert_relative_eq!(at.get(1, 0)?, 2.0);
        assert_relative_eq!(at.get(1, 1)?, 4.0);
        assert_relative_eq!(at.get(1, 2)?, 6.0);
        Ok(())
    }

    /// Products that cancel to exactly zero must not be stored in the result.
    #[test]
    fn test_matmul_cancellation_not_stored() -> Result<()> {
        let mut a = SparseMatrix::new(&[1, 2])?;
        let mut b = SparseMatrix::new(&[2, 1])?;
        a.set(0, 0, 1.0)?;
        a.set(0, 1, -1.0)?;
        b.set(0, 0, 2.0)?;
        b.set(1, 0, 2.0)?;
        let c = a.matmul(&b)?;
        assert_eq!(c.shape().to_vec(), vec![1, 1]);
        assert_eq!(c.nnz(), 0);
        assert_relative_eq!(c.get(0, 0)?, 0.0);
        Ok(())
    }
}