use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement, Zero};
use std::fmt::{self, Debug};
use std::ops::{Add, Div, Mul, Sub};
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::{SparseArray, SparseSum};
/// Sparse matrix stored in coordinate (COO) form.
///
/// Three parallel arrays describe the stored entries: entry `k` has value
/// `data[k]` at position `(row[k], col[k])`. The same position may be stored
/// more than once; readers in this file such as `get` and `to_array` add the
/// repeated values together.
#[derive(Clone)]
pub struct CooArray<T>
where
T: SparseElement + Div<Output = T> + PartialOrd + 'static,
{
/// Row index of each stored entry.
row: Array1<usize>,
/// Column index of each stored entry.
col: Array1<usize>,
/// Value of each stored entry.
data: Array1<T>,
/// Logical dimensions as `(rows, cols)`.
shape: (usize, usize),
/// Flag recording that the entries are sorted by `(row, col)`.
/// It is taken on trust from the constructor, set by `canonical_format`,
/// and cleared by `set` when it appends an entry.
has_canonical_format: bool,
}
impl<T> CooArray<T>
where
T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
/// Builds a COO array from parallel value, row-index and column-index arrays.
///
/// `has_canonical_format` is stored exactly as given; this constructor does
/// not verify that the entries are sorted.
///
/// # Errors
///
/// * `SparseError::InconsistentData` when the three arrays differ in length.
/// * `SparseError::IndexOutOfBounds` when the largest row index is not below
///   `shape.0` (reported as `(row, 0)`), or the largest column index is not
///   below `shape.1` (reported as `(0, col)`). Rows are checked before columns.
pub fn new(
    data: Array1<T>,
    row: Array1<usize>,
    col: Array1<usize>,
    shape: (usize, usize),
    has_canonical_format: bool,
) -> SparseResult<Self> {
    let entries = data.len();
    if row.len() != entries || col.len() != entries {
        return Err(SparseError::InconsistentData {
            reason: "data, row, and col must have the same length".to_string(),
        });
    }

    let (n_rows, n_cols) = shape;

    // Only the largest index needs checking: if it fits, every index fits.
    let offending_row = row.iter().copied().max().filter(|&r| r >= n_rows);
    if let Some(r) = offending_row {
        return Err(SparseError::IndexOutOfBounds {
            index: (r, 0),
            shape,
        });
    }

    let offending_col = col.iter().copied().max().filter(|&c| c >= n_cols);
    if let Some(c) = offending_col {
        return Err(SparseError::IndexOutOfBounds {
            index: (0, c),
            shape,
        });
    }

    Ok(Self {
        data,
        row,
        col,
        shape,
        has_canonical_format,
    })
}
/// Builds a COO array from slices of row indices, column indices and values.
///
/// The slices are copied into owned arrays and handed to `new`, which performs
/// all validation. `sorted` is forwarded as the `has_canonical_format` flag.
///
/// # Errors
///
/// Returns whatever error `new` reports for the copied data.
pub fn from_triplets(
    row: &[usize],
    col: &[usize],
    data: &[T],
    shape: (usize, usize),
    sorted: bool,
) -> SparseResult<Self> {
    Self::new(
        Array1::from_vec(data.to_vec()),
        Array1::from_vec(row.to_vec()),
        Array1::from_vec(col.to_vec()),
        shape,
        sorted,
    )
}
/// Borrows the row index of every stored entry, in storage order.
pub fn get_rows(&self) -> &Array1<usize> {
&self.row
}
/// Borrows the column index of every stored entry, in storage order.
pub fn get_cols(&self) -> &Array1<usize> {
&self.col
}
/// Borrows the value of every stored entry, in storage order.
pub fn get_data(&self) -> &Array1<T> {
&self.data
}
pub fn canonical_format(&mut self) {
if self.has_canonical_format {
return;
}
let n = self.data.len();
let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(n);
for i in 0..n {
triplets.push((self.row[i], self.col[i], self.data[i]));
}
triplets.sort_by_key(|&(r1, c1_, _)| (r1, c1_));
for (i, &(r, c, v)) in triplets.iter().enumerate() {
self.row[i] = r;
self.col[i] = c;
self.data[i] = v;
}
self.has_canonical_format = true;
}
/// Merges entries that share a position by adding their values.
///
/// The array is first put in `(row, col)` order via `canonical_format`, so
/// equal positions are adjacent. Each run of equal positions collapses to one
/// entry; a run whose total equals `T::sparse_zero()` is dropped entirely.
pub fn sum_duplicates(&mut self) {
    self.canonical_format();
    if self.data.is_empty() {
        return;
    }

    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();

    // Running group: position and accumulated value of the current run.
    let mut group = (self.row[0], self.col[0], self.data[0]);

    let rest = self
        .row
        .iter()
        .zip(self.col.iter())
        .zip(self.data.iter())
        .skip(1);
    for ((&r, &c), &v) in rest {
        if r == group.0 && c == group.1 {
            group.2 = group.2 + v;
            continue;
        }
        // Position changed: flush the finished run unless it cancelled out.
        if group.2 != T::sparse_zero() {
            rows.push(group.0);
            cols.push(group.1);
            vals.push(group.2);
        }
        group = (r, c, v);
    }

    // Flush the final run.
    if group.2 != T::sparse_zero() {
        rows.push(group.0);
        cols.push(group.1);
        vals.push(group.2);
    }

    self.data = Array1::from_vec(vals);
    self.row = Array1::from_vec(rows);
    self.col = Array1::from_vec(cols);
}
}
impl<T> SparseArray<T> for CooArray<T>
where
T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
/// Returns the logical dimensions as `(rows, cols)`.
fn shape(&self) -> (usize, usize) {
self.shape
}
/// Returns the number of stored entries.
///
/// This is the length of the value array, so duplicate positions and
/// explicitly stored zeros are each counted.
fn nnz(&self) -> usize {
self.data.len()
}
/// Returns a label for the element type.
///
/// NOTE(review): this always yields the literal `"float"`, whatever `T` is.
fn dtype(&self) -> &str {
"float" }
/// Expands the sparse array into a dense `shape.0 x shape.1` matrix.
///
/// Cells without a stored entry are zero; values stored at the same position
/// are added together.
fn to_array(&self) -> Array2<T> {
    let mut dense: Array2<T> = Array2::zeros(self.shape);
    let entries = self.row.iter().zip(self.col.iter()).zip(self.data.iter());
    for ((&r, &c), &v) in entries {
        // Accumulate rather than overwrite so duplicates sum up.
        let cell = &mut dense[[r, c]];
        *cell = *cell + v;
    }
    dense
}
/// Alias for `to_array`: returns the dense form of this array.
fn toarray(&self) -> Array2<T> {
self.to_array()
}
/// Returns a boxed COO copy of this array with duplicates merged.
///
/// The copy goes through `sum_duplicates`, so it is sorted by `(row, col)` and
/// positions whose values sum to zero are removed. `self` is left untouched.
fn to_coo(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    let mut merged = self.clone();
    merged.sum_duplicates();
    let boxed: Box<dyn SparseArray<T>> = Box::new(merged);
    Ok(boxed)
}
/// Converts this array to CSR form.
///
/// The entries are copied, stably sorted by `(row, col)` and handed to
/// `CsrArray::from_triplets` with its sorted flag set. Duplicate positions are
/// passed through as they are; how they are treated depends on the CSR
/// constructor.
///
/// # Errors
///
/// Propagates any error from `CsrArray::from_triplets`.
fn to_csr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    let mut entries: Vec<(usize, usize, T)> = self
        .row
        .iter()
        .zip(self.col.iter())
        .zip(self.data.iter())
        .map(|((&r, &c), &v)| (r, c, v))
        .collect();
    entries.sort_by_key(|&(r, c, _)| (r, c));

    let count = entries.len();
    let mut rows = Vec::with_capacity(count);
    let mut cols = Vec::with_capacity(count);
    let mut vals = Vec::with_capacity(count);
    for (r, c, v) in entries {
        rows.push(r);
        cols.push(c);
        vals.push(v);
    }

    CsrArray::from_triplets(&rows, &cols, &vals, self.shape, true)
        .map(|csr| Box::new(csr) as Box<dyn SparseArray<T>>)
}
/// Converts to a compressed array representing the same logical matrix.
///
/// The previous implementation returned `to_csr()?.transpose()`, i.e. the
/// transposed matrix (swapped shape and swapped values), which is not a format
/// conversion. The result is now built as
/// `transpose(to_csr(transpose(self)))`: the two transposes cancel, so shape
/// and values match `self`.
///
/// NOTE(review): the concrete storage type of the result is whatever the CSR
/// array's `transpose` returns — confirm that this is CSC.
///
/// # Errors
///
/// Propagates any error from the transposes or the CSR conversion.
fn to_csc(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    // Compress the transpose, then flip back to the original orientation.
    let csr_of_transpose = self.transpose()?.to_csr()?;
    csr_of_transpose.transpose()
}
/// Placeholder DOK conversion: returns a boxed clone of this COO array.
///
/// No dictionary-of-keys structure is built here; the result holds the same
/// entries in the same COO representation.
fn to_dok(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    Ok(self.copy())
}
/// Placeholder LIL conversion: returns a boxed clone of this COO array.
///
/// No list-of-lists structure is built here; the result holds the same
/// entries in the same COO representation.
fn to_lil(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    Ok(self.copy())
}
/// Placeholder DIA conversion: returns a boxed clone of this COO array.
///
/// No diagonal storage is built here; the result holds the same entries in
/// the same COO representation.
fn to_dia(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    Ok(self.copy())
}
/// Placeholder BSR conversion: returns a boxed clone of this COO array.
///
/// No block structure is built here; the result holds the same entries in
/// the same COO representation.
fn to_bsr(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    Ok(self.copy())
}
/// Adds `other` to this array by converting `self` to CSR and delegating to
/// the CSR array's `add`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `add`.
fn add(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
    self.to_csr().and_then(|csr| csr.add(other))
}
/// Subtracts `other` from this array by converting `self` to CSR and
/// delegating to the CSR array's `sub`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `sub`.
fn sub(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
    self.to_csr().and_then(|csr| csr.sub(other))
}
/// Multiplies this array with `other` by converting `self` to CSR and
/// delegating to the CSR array's `mul`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `mul`.
fn mul(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
    self.to_csr().and_then(|csr| csr.mul(other))
}
/// Divides this array by `other` by converting `self` to CSR and delegating
/// to the CSR array's `div`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `div`.
fn div(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
    self.to_csr().and_then(|csr| csr.div(other))
}
/// Computes the product with another sparse array by converting `self` to CSR
/// and delegating to the CSR array's `dot`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `dot`.
fn dot(&self, other: &dyn SparseArray<T>) -> SparseResult<Box<dyn SparseArray<T>>> {
    self.to_csr().and_then(|csr| csr.dot(other))
}
/// Multiplies this array by a dense vector, returning a vector of length
/// `shape.0`.
///
/// Each stored entry contributes `value * other[col]` to output element
/// `row`, so duplicate positions add up.
///
/// # Errors
///
/// `SparseError::DimensionMismatch` when `other.len()` differs from the number
/// of columns.
fn dot_vector(&self, other: &ArrayView1<T>) -> SparseResult<Array1<T>> {
    let (n_rows, n_cols) = self.shape();
    let given = other.len();
    if given != n_cols {
        return Err(SparseError::DimensionMismatch {
            expected: n_cols,
            found: given,
        });
    }

    let mut product: Array1<T> = Array1::zeros(n_rows);
    let entries = self.row.iter().zip(self.col.iter()).zip(self.data.iter());
    for ((&r, &c), &v) in entries {
        product[r] = product[r] + v * other[c];
    }
    Ok(product)
}
/// Returns the transpose as a new boxed COO array.
///
/// Row and column indices trade places and the shape is reversed; values are
/// copied unchanged and keep their storage order.
///
/// The result is always marked as non-canonical. Entries sorted by
/// `(row, col)` are generally not sorted by `(col, row)`, so carrying the
/// source flag over (as this method used to do) could label unsorted data as
/// sorted and make `canonical_format`, `sorted_indices` and `find` skip the
/// sort they need.
///
/// # Errors
///
/// Propagates any validation error from `CooArray::new`.
fn transpose(&self) -> SparseResult<Box<dyn SparseArray<T>>> {
    let (n_rows, n_cols) = self.shape;
    CooArray::new(
        self.data.clone(),
        // Old column indices become the new row indices, and vice versa.
        self.col.clone(),
        self.row.clone(),
        (n_cols, n_rows),
        // Swapping the index arrays invalidates (row, col) ordering.
        false,
    )
    .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
}
/// Returns a boxed clone of this array as a `SparseArray` trait object.
fn copy(&self) -> Box<dyn SparseArray<T>> {
Box::new(self.clone())
}
/// Returns the value at `(i, j)`.
///
/// Scans every stored entry and adds up all values stored at that position,
/// so duplicates are summed. A position with no stored entry, or one outside
/// the array's shape, yields `T::sparse_zero()`.
fn get(&self, i: usize, j: usize) -> T {
    let (n_rows, n_cols) = self.shape;
    if i >= n_rows || j >= n_cols {
        return T::sparse_zero();
    }

    self.row
        .iter()
        .zip(self.col.iter())
        .zip(self.data.iter())
        .filter(|&((&r, &c), _)| r == i && c == j)
        .fold(T::sparse_zero(), |total, (_, &v)| total + v)
}
/// Assigns `value` to position `(i, j)`.
///
/// Every entry currently stored at `(i, j)` is removed. When `value` is not
/// equal to `T::sparse_zero()`, a single new entry is appended at the end and
/// the canonical flag is cleared; assigning zero only removes entries and
/// leaves the flag as it was.
///
/// # Errors
///
/// `SparseError::IndexOutOfBounds` when `(i, j)` lies outside the shape.
fn set(&mut self, i: usize, j: usize, value: T) -> SparseResult<()> {
    let (n_rows, n_cols) = self.shape;
    if i >= n_rows || j >= n_cols {
        return Err(SparseError::IndexOutOfBounds {
            index: (i, j),
            shape: self.shape,
        });
    }

    // Keep every entry that is not at the target position.
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();
    let entries = self.row.iter().zip(self.col.iter()).zip(self.data.iter());
    for ((&r, &c), &v) in entries {
        if r != i || c != j {
            rows.push(r);
            cols.push(c);
            vals.push(v);
        }
    }

    let clears_position = value == T::sparse_zero();
    if !clears_position {
        // Appending at the end may break (row, col) ordering.
        rows.push(i);
        cols.push(j);
        vals.push(value);
        self.has_canonical_format = false;
    }

    self.data = Array1::from_vec(vals);
    self.row = Array1::from_vec(rows);
    self.col = Array1::from_vec(cols);
    Ok(())
}
/// Removes every stored entry whose value is zero according to
/// `SparseElement::is_zero`.
///
/// The surviving entries keep their relative order. Duplicates are not merged
/// first, so entries that merely sum to zero are kept.
fn eliminate_zeros(&mut self) {
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();

    let entries = self.row.iter().zip(self.col.iter()).zip(self.data.iter());
    for ((&r, &c), v) in entries {
        if SparseElement::is_zero(v) {
            continue;
        }
        rows.push(r);
        cols.push(c);
        vals.push(*v);
    }

    self.data = Array1::from_vec(vals);
    self.row = Array1::from_vec(rows);
    self.col = Array1::from_vec(cols);
}
/// Sorts the stored entries in place by delegating to `canonical_format`.
fn sort_indices(&mut self) {
self.canonical_format();
}
/// Returns a boxed copy of this array with its entries sorted by `(row, col)`.
///
/// `self` is not modified. If the array is already flagged as canonical the
/// copy is returned as it is.
fn sorted_indices(&self) -> Box<dyn SparseArray<T>> {
    let mut ordered = self.clone();
    // `canonical_format` returns immediately when the flag is already set,
    // so the already-sorted case needs no separate branch.
    ordered.canonical_format();
    Box::new(ordered)
}
/// Reports the canonical flag: `true` when the entries are recorded as being
/// sorted by `(row, col)`. The flag is not re-verified against the data here.
fn has_sorted_indices(&self) -> bool {
self.has_canonical_format
}
/// Sums the array, optionally along `axis`, by converting `self` to CSR and
/// delegating to the CSR array's `sum`.
///
/// # Errors
///
/// Propagates errors from the CSR conversion or from the delegated `sum`.
fn sum(&self, axis: Option<usize>) -> SparseResult<SparseSum<T>> {
    self.to_csr().and_then(|csr| csr.sum(axis))
}
/// Returns the largest element of the array, counting implicit zeros.
///
/// With no stored entries the result is `T::sparse_zero()`. Otherwise the
/// largest stored value is found; if it is negative and fewer entries are
/// stored than the array has cells, some cell is an implicit zero and zero is
/// returned instead. Stored values are compared one by one — duplicates are
/// not summed first.
///
/// The cell count uses a saturating multiplication so that very large logical
/// shapes cannot overflow `usize` (which would panic in debug builds and give
/// a wrapped, too-small count in release builds).
fn max(&self) -> T {
    let zero = T::sparse_zero();

    let mut stored = self.data.iter().copied();
    let first = match stored.next() {
        Some(v) => v,
        None => return zero,
    };
    let largest = stored.fold(first, |best, v| if v > best { v } else { best });

    let cell_count = self.shape.0.saturating_mul(self.shape.1);
    if largest < zero && self.nnz() < cell_count {
        zero
    } else {
        largest
    }
}
/// Returns the smallest element of the array, counting implicit zeros.
///
/// With no stored entries the result is `T::sparse_zero()`. Otherwise the
/// smallest stored value is found; if it is positive and fewer entries are
/// stored than the array has cells, some cell is an implicit zero and zero is
/// returned instead. Stored values are compared one by one — duplicates are
/// not summed first.
///
/// The cell count uses a saturating multiplication so that very large logical
/// shapes cannot overflow `usize` (which would panic in debug builds and give
/// a wrapped, too-small count in release builds).
fn min(&self) -> T {
    let zero = T::sparse_zero();

    let mut stored = self.data.iter().copied();
    let first = match stored.next() {
        Some(v) => v,
        None => return zero,
    };
    let smallest = stored.fold(first, |best, v| if v < best { v } else { best });

    let cell_count = self.shape.0.saturating_mul(self.shape.1);
    if smallest > zero && self.nnz() < cell_count {
        zero
    } else {
        smallest
    }
}
fn find(&self) -> (Array1<usize>, Array1<usize>, Array1<T>) {
let data_vec = self.data.to_vec();
let row_vec = self.row.to_vec();
let col_vec = self.col.to_vec();
if self.has_canonical_format {
(self.row.clone(), self.col.clone(), self.data.clone())
} else {
let mut triplets: Vec<(usize, usize, T)> = Vec::with_capacity(data_vec.len());
for i in 0..data_vec.len() {
triplets.push((row_vec[i], col_vec[i], data_vec[i]));
}
triplets.sort_by_key(|&(r1, c1_, _)| (r1, c1_));
let mut result_row = Vec::new();
let mut result_col = Vec::new();
let mut result_data = Vec::new();
if !triplets.is_empty() {
let mut curr_row = triplets[0].0;
let mut curr_col = triplets[0].1;
let mut curr_sum = triplets[0].2;
for &(r, c, v) in triplets.iter().skip(1) {
if r == curr_row && c == curr_col {
curr_sum = curr_sum + v;
} else {
if curr_sum != T::sparse_zero() {
result_row.push(curr_row);
result_col.push(curr_col);
result_data.push(curr_sum);
}
curr_row = r;
curr_col = c;
curr_sum = v;
}
}
if curr_sum != T::sparse_zero() {
result_row.push(curr_row);
result_col.push(curr_col);
result_data.push(curr_sum);
}
}
(
Array1::from_vec(result_row),
Array1::from_vec(result_col),
Array1::from_vec(result_data),
)
}
}
/// Extracts the sub-array covering rows `row_range.0..row_range.1` and
/// columns `col_range.0..col_range.1` (end bounds exclusive).
///
/// Entries inside the window are copied with their indices shifted so the
/// window's top-left corner becomes `(0, 0)`. The result is always marked as
/// non-canonical.
///
/// # Errors
///
/// `SparseError::InvalidSliceRange` when a start is at or past the matching
/// dimension, an end exceeds it, or either range is empty. Errors from
/// `CooArray::new` are propagated.
fn slice(
    &self,
    row_range: (usize, usize),
    col_range: (usize, usize),
) -> SparseResult<Box<dyn SparseArray<T>>> {
    let (row_lo, row_hi) = row_range;
    let (col_lo, col_hi) = col_range;
    let (n_rows, n_cols) = self.shape;

    let outside = row_lo >= n_rows || row_hi > n_rows || col_lo >= n_cols || col_hi > n_cols;
    let empty = row_lo >= row_hi || col_lo >= col_hi;
    if outside || empty {
        return Err(SparseError::InvalidSliceRange);
    }

    let row_window = row_lo..row_hi;
    let col_window = col_lo..col_hi;

    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut vals = Vec::new();
    let entries = self.row.iter().zip(self.col.iter()).zip(self.data.iter());
    for ((&r, &c), &v) in entries {
        if row_window.contains(&r) && col_window.contains(&c) {
            vals.push(v);
            // Re-base indices onto the window's origin.
            rows.push(r - row_lo);
            cols.push(c - col_lo);
        }
    }

    CooArray::new(
        Array1::from_vec(vals),
        Array1::from_vec(rows),
        Array1::from_vec(cols),
        (row_hi - row_lo, col_hi - col_lo),
        false,
    )
    .map(|array| Box::new(array) as Box<dyn SparseArray<T>>)
}
/// Exposes this array as `&dyn Any`.
/// Presumably used by callers to downcast a `SparseArray` trait object back
/// to `CooArray` — confirm against call sites.
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
impl<T> fmt::Debug for CooArray<T>
where
    T: SparseElement + Div<Output = T> + PartialOrd + Zero + 'static,
{
    /// Writes a one-line summary of the form `CooArray<RxC, nnz=N>`; the
    /// stored entries themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (n_rows, n_cols) = self.shape;
        let stored = self.nnz();
        write!(f, "CooArray<{}x{}, nnz={}>", n_rows, n_cols, stored)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3x3 fixture shared by several tests:
    ///
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    fn sample_3x3() -> CooArray<f64> {
        let row = Array1::from_vec(vec![0, 0, 1, 2, 2]);
        let col = Array1::from_vec(vec![0, 2, 1, 0, 2]);
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        CooArray::new(data, row, col, (3, 3), false).expect("Operation failed")
    }

    /// Positions and values checked against the fixture, including one
    /// position that has no stored entry.
    const SAMPLE_CELLS: [((usize, usize), f64); 6] = [
        ((0, 0), 1.0),
        ((0, 2), 2.0),
        ((1, 1), 3.0),
        ((2, 0), 4.0),
        ((2, 2), 5.0),
        ((0, 1), 0.0),
    ];

    // Constructing through `new` preserves shape, entry count and values.
    #[test]
    fn test_coo_array_construction() {
        let coo = sample_3x3();
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        for &((r, c), want) in SAMPLE_CELLS.iter() {
            assert_eq!(coo.get(r, c), want);
        }
    }

    // Constructing through `from_triplets` gives the same array as `new`.
    #[test]
    fn test_coo_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let coo = CooArray::from_triplets(&rows, &cols, &data, (3, 3), false)
            .expect("Operation failed");
        assert_eq!(coo.shape(), (3, 3));
        assert_eq!(coo.nnz(), 5);
        for &((r, c), want) in SAMPLE_CELLS.iter() {
            assert_eq!(coo.get(r, c), want);
        }
    }

    // The dense expansion matches the fixture cell by cell.
    #[test]
    fn test_coo_array_to_array() {
        let dense = sample_3x3().to_array();
        assert_eq!(dense.shape(), &[3, 3]);
        let expected = [[1.0, 0.0, 2.0], [0.0, 3.0, 0.0], [4.0, 0.0, 5.0]];
        for (r, expected_row) in expected.iter().enumerate() {
            for (c, &want) in expected_row.iter().enumerate() {
                assert_eq!(dense[[r, c]], want);
            }
        }
    }

    // `sum_duplicates` collapses repeated positions into their sums.
    #[test]
    fn test_coo_array_duplicate_entries() {
        let row = Array1::from_vec(vec![0, 0, 0, 1, 1]);
        let col = Array1::from_vec(vec![0, 0, 1, 0, 0]);
        let data = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let mut coo = CooArray::new(data, row, col, (2, 2), false).expect("Operation failed");
        coo.sum_duplicates();
        assert_eq!(coo.nnz(), 3);
        assert_eq!(coo.get(0, 0), 3.0);
        assert_eq!(coo.get(0, 1), 3.0);
        assert_eq!(coo.get(1, 0), 9.0);
    }

    // `set` inserts, overwrites and (for zero) removes entries.
    #[test]
    fn test_coo_set_get() {
        let row = Array1::from_vec(vec![0, 1]);
        let col = Array1::from_vec(vec![0, 1]);
        let data = Array1::from_vec(vec![1.0, 2.0]);
        let mut coo = CooArray::new(data, row, col, (2, 2), false).expect("Operation failed");

        coo.set(0, 1, 3.0).expect("Operation failed");
        assert_eq!(coo.get(0, 1), 3.0);

        coo.set(0, 0, 4.0).expect("Operation failed");
        assert_eq!(coo.get(0, 0), 4.0);

        coo.set(0, 0, 0.0).expect("Operation failed");
        assert_eq!(coo.get(0, 0), 0.0);
        assert_eq!(coo.nnz(), 2);
    }

    // `canonical_format` orders entries by (row, col) and sets the flag.
    #[test]
    fn test_coo_canonical_format() {
        let row = Array1::from_vec(vec![1, 0, 2, 0]);
        let col = Array1::from_vec(vec![1, 0, 2, 2]);
        let data = Array1::from_vec(vec![3.0, 1.0, 5.0, 2.0]);
        let mut coo = CooArray::new(data, row, col, (3, 3), false).expect("Operation failed");
        assert!(!coo.has_canonical_format);

        coo.canonical_format();
        assert!(coo.has_canonical_format);

        let expected = [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 2, 5.0)];
        for (k, &(r, c, v)) in expected.iter().enumerate() {
            assert_eq!(coo.row[k], r);
            assert_eq!(coo.col[k], c);
            assert_eq!(coo.data[k], v);
        }
    }

    // Converting to CSR keeps every stored value at its position.
    #[test]
    fn test_coo_to_csr() {
        let csr = sample_3x3().to_csr().expect("Operation failed");
        let dense = csr.to_array();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[0, 2]], 2.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[2, 0]], 4.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }

    // Transposing mirrors each entry across the main diagonal.
    #[test]
    fn test_coo_transpose() {
        let transposed = sample_3x3().transpose().expect("Operation failed");
        assert_eq!(transposed.shape(), (3, 3));
        let dense = transposed.to_array();
        assert_eq!(dense[[0, 0]], 1.0);
        assert_eq!(dense[[2, 0]], 2.0);
        assert_eq!(dense[[1, 1]], 3.0);
        assert_eq!(dense[[0, 2]], 4.0);
        assert_eq!(dense[[2, 2]], 5.0);
    }
}