use crate::error::{SparseError, SparseResult};
use scirs2_core::numeric::{SparseElement, Zero};
use std::cmp::PartialEq;
/// A sparse matrix stored in coordinate (COO / triplet) format.
///
/// Entry `k` holds the value `data[k]` at position `(row_indices[k], col_indices[k])`.
/// The constructors keep the three vectors the same length and keep every index
/// inside `rows x cols`. Entries are stored in insertion order and may repeat a
/// coordinate until `sum_duplicates` is called.
#[derive(Debug, Clone)]
pub struct CooMatrix<T> {
    /// Number of rows in the matrix.
    rows: usize,
    /// Number of columns in the matrix.
    cols: usize,
    /// Row coordinate of each stored entry (parallel to `data`).
    row_indices: Vec<usize>,
    /// Column coordinate of each stored entry (parallel to `data`).
    col_indices: Vec<usize>,
    /// Stored values, one per entry.
    data: Vec<T>,
}
impl<T> CooMatrix<T>
where
    T: Clone + Copy + Zero + PartialEq + SparseElement,
{
    /// Returns owned copies of the stored triplets as `(row_indices, col_indices, data)`.
    ///
    /// Entries come back in storage order; duplicate coordinates are not merged.
    pub fn get_triplets(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        (
            self.row_indices.clone(),
            self.col_indices.clone(),
            self.data.clone(),
        )
    }

    /// Builds a COO matrix from parallel triplet arrays.
    ///
    /// Entry `k` places `data[k]` at `(row_indices[k], col_indices[k])`. Repeated
    /// coordinates are accepted and stored as separate entries.
    ///
    /// # Errors
    ///
    /// * [`SparseError::DimensionMismatch`] if `row_indices` or `col_indices` differ in
    ///   length from `data`. `expected` is `data.len()` and `found` is the length of the
    ///   offending index array (the row array is reported when both differ).
    /// * [`SparseError::ValueError`] if any row index is `>= shape.0` or any column
    ///   index is `>= shape.1`.
    pub fn new(
        data: Vec<T>,
        row_indices: Vec<usize>,
        col_indices: Vec<usize>,
        shape: (usize, usize),
    ) -> SparseResult<Self> {
        let nnz = data.len();
        if row_indices.len() != nnz || col_indices.len() != nnz {
            // Report the array that actually disagrees with `data`. Taking the minimum of
            // the two lengths could yield `found == expected` when one array is too long.
            let found = if row_indices.len() != nnz {
                row_indices.len()
            } else {
                col_indices.len()
            };
            return Err(SparseError::DimensionMismatch {
                expected: nnz,
                found,
            });
        }
        let (rows, cols) = shape;
        if row_indices.iter().any(|&i| i >= rows) {
            return Err(SparseError::ValueError(
                "Row index out of bounds".to_string(),
            ));
        }
        if col_indices.iter().any(|&i| i >= cols) {
            return Err(SparseError::ValueError(
                "Column index out of bounds".to_string(),
            ));
        }
        Ok(CooMatrix {
            rows,
            cols,
            row_indices,
            col_indices,
            data,
        })
    }

    /// Creates a `shape.0 x shape.1` matrix with no stored entries.
    pub fn empty(shape: (usize, usize)) -> Self {
        let (rows, cols) = shape;
        CooMatrix {
            rows,
            cols,
            row_indices: Vec::new(),
            col_indices: Vec::new(),
            data: Vec::new(),
        }
    }

    /// Appends one entry `value` at `(row, col)`.
    ///
    /// No duplicate check is made: appending to an existing coordinate adds another entry.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::ValueError`] if `row >= self.rows()` or `col >= self.cols()`.
    pub fn add_element(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
        if row >= self.rows || col >= self.cols {
            return Err(SparseError::ValueError(
                "Row or column index out of bounds".to_string(),
            ));
        }
        self.row_indices.push(row);
        self.col_indices.push(col);
        self.data.push(value);
        Ok(())
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Matrix dimensions as `(rows, cols)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Number of stored entries, counting duplicate coordinates separately.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Row coordinate of each stored entry.
    pub fn row_indices(&self) -> &[usize] {
        &self.row_indices
    }

    /// Column coordinate of each stored entry.
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values, parallel to [`Self::row_indices`] and [`Self::col_indices`].
    pub fn data(&self) -> &[T] {
        &self.data
    }

    /// Expands the matrix into a dense row-major `Vec<Vec<T>>`.
    ///
    /// Entries that share a coordinate are summed, which matches how [`dot`](Self::dot)
    /// treats them. Cells with no stored entry hold `T::sparse_zero()`.
    pub fn to_dense(&self) -> Vec<Vec<T>>
    where
        T: Zero + Copy + SparseElement + std::ops::Add<Output = T>,
    {
        let mut dense = vec![vec![T::sparse_zero(); self.cols]; self.rows];
        for ((&r, &c), &v) in self
            .row_indices
            .iter()
            .zip(&self.col_indices)
            .zip(&self.data)
        {
            let cell = &mut dense[r][c];
            *cell = *cell + v;
        }
        dense
    }

    /// Converts to CSR by passing the raw triplets to `CsrMatrix::new`.
    ///
    /// The triplets are passed as stored (not sorted, duplicates not merged).
    ///
    /// # Panics
    ///
    /// Panics if `CsrMatrix::new` rejects the triplets. This type's length and bounds
    /// invariants are expected to rule that out; NOTE(review): confirm against
    /// `CsrMatrix::new`'s checks.
    pub fn to_csr(&self) -> crate::csr::CsrMatrix<T> {
        crate::csr::CsrMatrix::new(
            self.data.clone(),
            self.row_indices.clone(),
            self.col_indices.clone(),
            (self.rows, self.cols),
        )
        .expect("Operation failed")
    }

    /// Converts to CSC by passing the raw triplets to `CscMatrix::new`.
    ///
    /// The triplets are passed as stored (not sorted, duplicates not merged).
    ///
    /// # Panics
    ///
    /// Panics if `CscMatrix::new` rejects the triplets. This type's length and bounds
    /// invariants are expected to rule that out; NOTE(review): confirm against
    /// `CscMatrix::new`'s checks.
    pub fn to_csc(&self) -> crate::csc::CscMatrix<T> {
        crate::csc::CscMatrix::new(
            self.data.clone(),
            self.row_indices.clone(),
            self.col_indices.clone(),
            (self.rows, self.cols),
        )
        .expect("Operation failed")
    }

    /// Returns the transpose: the shape is swapped and each entry's row and column trade places.
    ///
    /// Entry order is preserved.
    pub fn transpose(&self) -> Self {
        CooMatrix {
            rows: self.cols,
            cols: self.rows,
            row_indices: self.col_indices.clone(),
            col_indices: self.row_indices.clone(),
            data: self.data.clone(),
        }
    }

    /// Reorders the stored entries so that new entry `k` is the old entry `order[k]`.
    ///
    /// `order` must be a permutation of `0..self.nnz()`.
    fn apply_permutation(&mut self, order: &[usize]) {
        self.row_indices = order.iter().map(|&k| self.row_indices[k]).collect();
        self.col_indices = order.iter().map(|&k| self.col_indices[k]).collect();
        self.data = order.iter().map(|&k| self.data[k]).collect();
    }

    /// Sorts the stored entries by `(row, col)` in ascending order.
    ///
    /// The sort is stable: entries sharing a coordinate keep their relative order.
    pub fn sort_by_row_col(&mut self) {
        let mut order: Vec<usize> = (0..self.data.len()).collect();
        // Stable sort keeps duplicates in insertion order. This makes the summation order
        // in `sum_duplicates` deterministic.
        order.sort_by_key(|&k| (self.row_indices[k], self.col_indices[k]));
        self.apply_permutation(&order);
    }

    /// Sorts the stored entries by `(col, row)` in ascending order.
    ///
    /// The sort is stable: entries sharing a coordinate keep their relative order.
    pub fn sort_by_col_row(&mut self) {
        let mut order: Vec<usize> = (0..self.data.len()).collect();
        order.sort_by_key(|&k| (self.col_indices[k], self.row_indices[k]));
        self.apply_permutation(&order);
    }

    /// Returns the value at `(row, col)`.
    ///
    /// If several stored entries share the coordinate, their sum is returned, consistent
    /// with [`to_dense`](Self::to_dense) and [`dot`](Self::dot). Returns `T::sparse_zero()`
    /// when nothing is stored there (including out-of-range coordinates). Runs a linear
    /// scan over all stored entries.
    pub fn get(&self, row: usize, col: usize) -> T
    where
        T: Zero + SparseElement + std::ops::Add<Output = T>,
    {
        self.row_indices
            .iter()
            .zip(&self.col_indices)
            .zip(&self.data)
            .filter(|&((&r, &c), _)| r == row && c == col)
            .map(|(_, &v)| v)
            // `reduce` returns a single stored value unchanged rather than `zero + v`.
            .reduce(|acc, v| acc + v)
            .unwrap_or_else(T::sparse_zero)
    }

    /// Merges entries that share a coordinate by adding their values.
    ///
    /// Afterwards the entries are sorted by `(row, col)` and each coordinate appears at
    /// most once. Values for one coordinate are added left to right, in their original
    /// insertion order.
    pub fn sum_duplicates(&mut self)
    where
        T: std::ops::Add<Output = T>,
    {
        self.sort_by_row_col();
        let nnz = self.data.len();
        let mut rows: Vec<usize> = Vec::with_capacity(nnz);
        let mut cols: Vec<usize> = Vec::with_capacity(nnz);
        let mut vals: Vec<T> = Vec::with_capacity(nnz);
        for ((&r, &c), &v) in self
            .row_indices
            .iter()
            .zip(&self.col_indices)
            .zip(&self.data)
        {
            // After sorting, duplicates are adjacent: extend the current run or start a new one.
            let continues_run = rows.last() == Some(&r) && cols.last() == Some(&c);
            match vals.last_mut() {
                Some(acc) if continues_run => *acc = *acc + v,
                _ => {
                    rows.push(r);
                    cols.push(c);
                    vals.push(v);
                }
            }
        }
        self.row_indices = rows;
        self.col_indices = cols;
        self.data = vals;
    }
}
impl CooMatrix<f64> {
    /// Computes the matrix-vector product `self * vec`.
    ///
    /// Each stored entry `(r, c, v)` adds `v * vec[c]` to output slot `r`, so entries that
    /// share a coordinate add up. The result has `self.rows()` elements.
    ///
    /// # Errors
    ///
    /// Returns [`SparseError::DimensionMismatch`] when `vec.len()` is not equal to the
    /// number of columns. `expected` is the column count and `found` is `vec.len()`.
    pub fn dot(&self, vec: &[f64]) -> SparseResult<Vec<f64>> {
        let n_cols = self.cols;
        if vec.len() == n_cols {
            let mut product = vec![0.0; self.rows];
            let entries = self
                .row_indices
                .iter()
                .zip(&self.col_indices)
                .zip(&self.data);
            for ((&r, &c), &v) in entries {
                product[r] += v * vec[c];
            }
            Ok(product)
        } else {
            Err(SparseError::DimensionMismatch {
                expected: n_cols,
                found: vec.len(),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Shared 3x3 fixture:
    //   [1 0 2]
    //   [0 0 3]
    //   [4 5 0]
    fn sample_matrix() -> CooMatrix<f64> {
        CooMatrix::new(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 2, 0, 1],
            (3, 3),
        )
        .expect("Operation failed")
    }

    #[test]
    fn test_coo_create() {
        let matrix = sample_matrix();
        assert_eq!(matrix.shape(), (3, 3));
        assert_eq!(matrix.nnz(), 5);
    }

    #[test]
    fn test_coo_new_rejects_mismatched_lengths() {
        // Column array too short.
        assert!(CooMatrix::<f64>::new(vec![1.0, 2.0], vec![0, 1], vec![0], (2, 2)).is_err());
        // Row array too long.
        assert!(
            CooMatrix::<f64>::new(vec![1.0, 2.0], vec![0, 1, 1], vec![0, 1], (2, 2)).is_err()
        );
    }

    #[test]
    fn test_coo_new_rejects_out_of_bounds() {
        assert!(CooMatrix::<f64>::new(vec![1.0], vec![3], vec![0], (3, 3)).is_err());
        assert!(CooMatrix::<f64>::new(vec![1.0], vec![0], vec![3], (3, 3)).is_err());
    }

    #[test]
    fn test_coo_add_element() {
        let mut matrix = CooMatrix::<f64>::empty((3, 3));
        matrix.add_element(0, 0, 1.0).expect("Operation failed");
        matrix.add_element(0, 2, 2.0).expect("Operation failed");
        matrix.add_element(1, 2, 3.0).expect("Operation failed");
        matrix.add_element(2, 0, 4.0).expect("Operation failed");
        matrix.add_element(2, 1, 5.0).expect("Operation failed");
        assert_eq!(matrix.nnz(), 5);
        assert!(matrix.add_element(3, 0, 6.0).is_err());
        assert!(matrix.add_element(0, 3, 6.0).is_err());
    }

    #[test]
    fn test_coo_to_dense() {
        let dense = sample_matrix().to_dense();
        let expected = vec![
            vec![1.0, 0.0, 2.0],
            vec![0.0, 0.0, 3.0],
            vec![4.0, 5.0, 0.0],
        ];
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_get() {
        let matrix = sample_matrix();
        assert_eq!(matrix.get(0, 2), 2.0);
        assert_eq!(matrix.get(2, 1), 5.0);
        assert_eq!(matrix.get(1, 1), 0.0);
    }

    #[test]
    fn test_coo_dot() {
        let matrix = sample_matrix();
        let vec = vec![1.0, 2.0, 3.0];
        let result = matrix.dot(&vec).expect("Operation failed");
        let expected = [7.0, 9.0, 14.0];
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 1e-10);
        }
        assert!(matrix.dot(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_coo_transpose() {
        let transposed = sample_matrix().transpose();
        assert_eq!(transposed.shape(), (3, 3));
        assert_eq!(transposed.nnz(), 5);
        let dense = transposed.to_dense();
        let expected = vec![
            vec![1.0, 0.0, 4.0],
            vec![0.0, 0.0, 5.0],
            vec![2.0, 3.0, 0.0],
        ];
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_sort_by_col_row() {
        let mut matrix = sample_matrix();
        matrix.sort_by_col_row();
        assert_eq!(matrix.col_indices(), &[0, 0, 1, 2, 2]);
        assert_eq!(matrix.row_indices(), &[0, 2, 2, 0, 1]);
        assert_eq!(matrix.data(), &[1.0, 4.0, 5.0, 2.0, 3.0]);
    }

    #[test]
    fn test_coo_sort_and_sum_duplicates() {
        let rows = vec![0, 0, 0, 1, 1, 2];
        let cols = vec![0, 0, 1, 0, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = (3, 2);
        let mut matrix = CooMatrix::new(data, rows, cols, shape).expect("Operation failed");
        matrix.sum_duplicates();
        assert_eq!(matrix.nnz(), 4);
        let dense = matrix.to_dense();
        let expected = vec![vec![3.0, 3.0], vec![9.0, 0.0], vec![0.0, 6.0]];
        assert_eq!(dense, expected);
    }

    #[test]
    fn test_coo_to_csr_to_csc() {
        let coo_matrix = sample_matrix();
        let csr_matrix = coo_matrix.to_csr();
        let csc_matrix = coo_matrix.to_csc();
        let dense_from_coo = coo_matrix.to_dense();
        let dense_from_csr = csr_matrix.to_dense();
        let dense_from_csc = csc_matrix.to_dense();
        assert_eq!(dense_from_coo, dense_from_csr);
        assert_eq!(dense_from_coo, dense_from_csc);
    }
}