use crate::Scalar;
use crate::error::{CoreError, Result};
use crate::tensor::Tensor;
/// Sparse matrix in coordinate (COO / triplet) format.
///
/// Entries are stored as three parallel vectors: the `i`-th stored entry is
/// `values[i]` at position `(rows[i], cols[i])`. Duplicate coordinates are
/// permitted; the conversion methods sum them. COO is convenient for
/// incremental construction (see `push`); convert to CSR/CSC for arithmetic.
///
/// NOTE(review): the derived serde `Deserialize` builds the struct directly and
/// does not re-check the length/bounds invariants that `from_triplets` and
/// `push` enforce.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct CooMatrix<T: Scalar> {
/// Row index of each stored entry; every index is `< nrows`.
rows: Vec<usize>,
/// Column index of each stored entry; every index is `< ncols`.
cols: Vec<usize>,
/// Value of each stored entry; same length as `rows` and `cols`.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
}
impl<T: Scalar> CooMatrix<T> {
    /// Creates an empty `nrows x ncols` matrix with no stored entries.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            rows: Vec::new(),
            cols: Vec::new(),
            values: Vec::new(),
            nrows,
            ncols,
        }
    }

    /// Builds a COO matrix from parallel `(row, col, value)` vectors.
    ///
    /// Duplicate coordinates are allowed; they are summed by
    /// [`to_dense`](Self::to_dense), [`to_csr`](Self::to_csr) and
    /// [`to_csc`](Self::to_csc).
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if the three vectors differ in
    /// length, or if any `(row, col)` pair lies outside `nrows x ncols`.
    pub fn from_triplets(
        nrows: usize,
        ncols: usize,
        rows: Vec<usize>,
        cols: Vec<usize>,
        values: Vec<T>,
    ) -> Result<Self> {
        if rows.len() != cols.len() || rows.len() != values.len() {
            return Err(CoreError::InvalidArgument {
                reason: "rows, cols, and values must have the same length",
            });
        }
        if rows
            .iter()
            .zip(&cols)
            .any(|(&r, &c)| r >= nrows || c >= ncols)
        {
            return Err(CoreError::InvalidArgument {
                reason: "index out of bounds for matrix dimensions",
            });
        }
        Ok(Self {
            rows,
            cols,
            values,
            nrows,
            ncols,
        })
    }

    /// Appends one entry at `(row, col)`.
    ///
    /// An entry at an already-stored coordinate is kept as a duplicate and
    /// summed later on conversion.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if `(row, col)` lies outside the
    /// matrix; the matrix is left unchanged in that case.
    pub fn push(&mut self, row: usize, col: usize, value: T) -> Result<()> {
        if row >= self.nrows || col >= self.ncols {
            return Err(CoreError::InvalidArgument {
                reason: "index out of bounds for matrix dimensions",
            });
        }
        self.rows.push(row);
        self.cols.push(col);
        self.values.push(value);
        Ok(())
    }

    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Number of stored entries, counting duplicates separately.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// `(nrows, ncols)`.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Expands to a dense `[nrows, ncols]` tensor. Duplicate entries are summed
    /// in insertion order.
    ///
    /// # Panics
    ///
    /// Panics if `nrows * ncols` overflows `usize`.
    pub fn to_dense(&self) -> Tensor<T> {
        // Sparse dimensions can be huge. Check the product explicitly so an
        // overflow fails loudly here instead of wrapping to a short buffer.
        let len = self
            .nrows
            .checked_mul(self.ncols)
            .expect("nrows * ncols overflows usize; matrix is too large to densify");
        let mut data = vec![T::zero(); len];
        for ((&r, &c), &v) in self.rows.iter().zip(&self.cols).zip(&self.values) {
            // `+=` rather than `=` so duplicate coordinates accumulate.
            data[r * self.ncols + c] += v;
        }
        Tensor::from_vec(data, vec![self.nrows, self.ncols])
            .expect("dense data length equals nrows*ncols by construction")
    }

    /// Converts to CSR. Column indices are sorted within each row and
    /// duplicate coordinates are summed.
    pub fn to_csr(&self) -> CsrMatrix<T> {
        let (row_ptr, col_idx, values) = self.bucket_by(&self.rows, &self.cols, self.nrows);
        let mut result = CsrMatrix {
            row_ptr,
            col_idx,
            values,
            nrows: self.nrows,
            ncols: self.ncols,
        };
        result.sort_and_sum_duplicates();
        result
    }

    /// Converts to CSC. Row indices are sorted within each column and
    /// duplicate coordinates are summed.
    pub fn to_csc(&self) -> CscMatrix<T> {
        let (col_ptr, row_idx, values) = self.bucket_by(&self.cols, &self.rows, self.ncols);
        let mut result = CscMatrix {
            col_ptr,
            row_idx,
            values,
            nrows: self.nrows,
            ncols: self.ncols,
        };
        result.sort_and_sum_duplicates();
        result
    }

    /// Counting-sorts the stored entries into `n_major` buckets keyed by
    /// `major`, and returns `(ptr, minor, values)`.
    ///
    /// Bucket `b` occupies `ptr[b]..ptr[b + 1]` of `minor` and `values`.
    /// Within a bucket, entries keep their insertion order. Every `major`
    /// index must be `< n_major`.
    fn bucket_by(
        &self,
        major: &[usize],
        minor: &[usize],
        n_major: usize,
    ) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        // First count each bucket's size into ptr[b + 1]. The running sum
        // then turns those counts into start offsets.
        let mut ptr = vec![0usize; n_major + 1];
        for &m in major {
            ptr[m + 1] += 1;
        }
        let mut running = 0;
        for p in &mut ptr {
            running += *p;
            *p = running;
        }

        let nnz = self.values.len();
        let mut out_minor = vec![0usize; nnz];
        let mut out_values = vec![T::zero(); nnz];
        // `next[b]` is the next free slot inside bucket `b`.
        let mut next = ptr.clone();
        for ((&m, &mn), &v) in major.iter().zip(minor).zip(&self.values) {
            let slot = next[m];
            out_minor[slot] = mn;
            out_values[slot] = v;
            next[m] += 1;
        }
        (ptr, out_minor, out_values)
    }
}
/// Sparse matrix in compressed sparse row (CSR) format.
///
/// The stored entries of row `r` occupy `row_ptr[r]..row_ptr[r + 1]` of
/// `col_idx` / `values`. Every constructor in this module leaves column
/// indices sorted and unique within each row. `get` relies on this: it
/// binary-searches the row.
///
/// NOTE(review): the derived serde `Deserialize` builds the struct directly and
/// does not validate these invariants.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct CsrMatrix<T: Scalar> {
/// Row offsets, length `nrows + 1`, non-decreasing, `row_ptr[nrows] == nnz`.
row_ptr: Vec<usize>,
/// Column index of each stored entry, grouped by row.
col_idx: Vec<usize>,
/// Value of each stored entry, parallel to `col_idx`.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
}
impl<T: Scalar> CsrMatrix<T> {
    /// Creates an empty `nrows x ncols` matrix with no stored entries.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            row_ptr: vec![0; nrows + 1],
            col_idx: Vec::new(),
            values: Vec::new(),
            nrows,
            ncols,
        }
    }

    /// Builds a CSR matrix from COO triplets, summing duplicate coordinates.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CooMatrix::from_triplets`] when the vectors
    /// differ in length or an index is out of bounds.
    pub fn from_triplets(
        nrows: usize,
        ncols: usize,
        rows: Vec<usize>,
        cols: Vec<usize>,
        values: Vec<T>,
    ) -> Result<Self> {
        let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
        Ok(coo.to_csr())
    }

    /// Converts a 2-D tensor to CSR. Every entry not equal to `T::zero()` is
    /// stored.
    ///
    /// The tensor's slice is read as `data[r * ncols + c]`, i.e. as contiguous
    /// row-major data. NOTE(review): confirm `Tensor::as_slice` guarantees
    /// that layout.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::InvalidArgument`] if the tensor is not 2-D.
    pub fn from_dense(tensor: &Tensor<T>) -> Result<Self> {
        if tensor.ndim() != 2 {
            return Err(CoreError::InvalidArgument {
                reason: "from_dense requires a 2-D tensor",
            });
        }
        let nrows = tensor.shape()[0];
        let ncols = tensor.shape()[1];
        let data = tensor.as_slice();
        let mut row_ptr = vec![0usize; nrows + 1];
        let mut col_idx = Vec::new();
        let mut values = Vec::new();
        for r in 0..nrows {
            for c in 0..ncols {
                let v = data[r * ncols + c];
                if v != T::zero() {
                    col_idx.push(c);
                    values.push(v);
                }
            }
            // Columns were visited in increasing order, so each row is
            // already sorted and duplicate-free.
            row_ptr[r + 1] = values.len();
        }
        Ok(Self {
            row_ptr,
            col_idx,
            values,
            nrows,
            ncols,
        })
    }

    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Number of stored entries.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// `(nrows, ncols)`.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Returns the stored value at `(row, col)`.
    ///
    /// Returns `None` if the position is out of bounds or holds no stored
    /// entry (a structural zero). Relies on column indices being sorted
    /// within each row.
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        if row >= self.nrows || col >= self.ncols {
            return None;
        }
        let start = self.row_ptr[row];
        let end = self.row_ptr[row + 1];
        self.col_idx[start..end]
            .binary_search(&col)
            .ok()
            .map(|pos| &self.values[start + pos])
    }

    /// Expands to a dense `[nrows, ncols]` tensor.
    pub fn to_dense(&self) -> Tensor<T> {
        let mut data = vec![T::zero(); self.nrows * self.ncols];
        for r in 0..self.nrows {
            let start = self.row_ptr[r];
            let end = self.row_ptr[r + 1];
            for idx in start..end {
                let c = self.col_idx[idx];
                data[r * self.ncols + c] = self.values[idx];
            }
        }
        Tensor::from_vec(data, vec![self.nrows, self.ncols])
            .expect("dense data length equals nrows*ncols by construction")
    }

    /// Computes `y = A x` for a 1-D `x` of length `ncols`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] if `x` is not 1-D with exactly
    /// `ncols` elements.
    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
        if x.ndim() != 1 || x.numel() != self.ncols {
            return Err(CoreError::DimensionMismatch {
                expected: vec![self.ncols],
                got: x.shape().to_vec(),
            });
        }
        let xdata = x.as_slice();
        let mut result = vec![T::zero(); self.nrows];
        for (r, dest) in result.iter_mut().enumerate() {
            let start = self.row_ptr[r];
            let end = self.row_ptr[r + 1];
            let mut acc = T::zero();
            for idx in start..end {
                acc += self.values[idx] * xdata[self.col_idx[idx]];
            }
            *dest = acc;
        }
        Tensor::from_vec(result, vec![self.nrows])
    }

    /// Returns the transpose as a CSC matrix of shape `(ncols, nrows)`.
    ///
    /// The CSR arrays of `A` are exactly the CSC arrays of `Aᵀ`, so this only
    /// clones them and swaps the dimensions. No re-sorting is needed.
    pub fn transpose(&self) -> CscMatrix<T> {
        CscMatrix {
            col_ptr: self.row_ptr.clone(),
            row_idx: self.col_idx.clone(),
            values: self.values.clone(),
            nrows: self.ncols,
            ncols: self.nrows,
        }
    }

    /// Converts to COO, emitting entries row by row.
    pub fn to_coo(&self) -> CooMatrix<T> {
        let mut rows = Vec::with_capacity(self.nnz());
        let mut cols = Vec::with_capacity(self.nnz());
        let mut values = Vec::with_capacity(self.nnz());
        for r in 0..self.nrows {
            let start = self.row_ptr[r];
            let end = self.row_ptr[r + 1];
            for idx in start..end {
                rows.push(r);
                cols.push(self.col_idx[idx]);
                values.push(self.values[idx]);
            }
        }
        CooMatrix {
            rows,
            cols,
            values,
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }

    /// Converts to CSC of the same shape, going through COO.
    pub fn to_csc(&self) -> CscMatrix<T> {
        self.to_coo().to_csc()
    }

    /// Sorts column indices within each row and merges duplicate columns by
    /// summing their values, in one forward pass.
    ///
    /// Compacted entries are written to a cursor that never overtakes the read
    /// position. Each row is first copied into a reusable buffer, so
    /// overwriting in place is safe. The whole pass costs
    /// O(nnz log max_row_len), instead of shifting the array tail once per
    /// row that has duplicates. The sort is stable, so duplicates are summed
    /// in insertion order.
    fn sort_and_sum_duplicates(&mut self) {
        let mut write = self.row_ptr[0];
        // Original (pre-compaction) start of the row being processed. It must
        // be captured before row_ptr[r + 1] is overwritten.
        let mut next_start = write;
        let mut row_buf: Vec<(usize, T)> = Vec::new();
        for r in 0..self.nrows {
            let start = next_start;
            let end = self.row_ptr[r + 1];
            next_start = end;

            row_buf.clear();
            row_buf.extend(
                self.col_idx[start..end]
                    .iter()
                    .copied()
                    .zip(self.values[start..end].iter().copied()),
            );
            row_buf.sort_by_key(|&(c, _)| c);

            let row_first = write;
            for (c, v) in row_buf.drain(..) {
                if write > row_first && self.col_idx[write - 1] == c {
                    self.values[write - 1] += v;
                } else {
                    self.col_idx[write] = c;
                    self.values[write] = v;
                    write += 1;
                }
            }
            self.row_ptr[r + 1] = write;
        }
        self.col_idx.truncate(write);
        self.values.truncate(write);
    }
}
/// Sparse matrix in compressed sparse column (CSC) format.
///
/// The stored entries of column `c` occupy `col_ptr[c]..col_ptr[c + 1]` of
/// `row_idx` / `values`. Every constructor in this module leaves row indices
/// sorted and unique within each column.
///
/// NOTE(review): the derived serde `Deserialize` builds the struct directly and
/// does not validate these invariants.
#[cfg_attr(
feature = "serde-support",
derive(serde::Serialize, serde::Deserialize)
)]
#[derive(Debug, Clone)]
pub struct CscMatrix<T: Scalar> {
/// Column offsets, length `ncols + 1`, non-decreasing, `col_ptr[ncols] == nnz`.
col_ptr: Vec<usize>,
/// Row index of each stored entry, grouped by column.
row_idx: Vec<usize>,
/// Value of each stored entry, parallel to `row_idx`.
values: Vec<T>,
/// Number of rows of the logical matrix.
nrows: usize,
/// Number of columns of the logical matrix.
ncols: usize,
}
impl<T: Scalar> CscMatrix<T> {
    /// Creates an empty `nrows x ncols` matrix with no stored entries.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            col_ptr: vec![0; ncols + 1],
            row_idx: Vec::new(),
            values: Vec::new(),
            nrows,
            ncols,
        }
    }

    /// Builds a CSC matrix from COO triplets, summing duplicate coordinates.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CooMatrix::from_triplets`] when the vectors
    /// differ in length or an index is out of bounds.
    pub fn from_triplets(
        nrows: usize,
        ncols: usize,
        rows: Vec<usize>,
        cols: Vec<usize>,
        values: Vec<T>,
    ) -> Result<Self> {
        let coo = CooMatrix::from_triplets(nrows, ncols, rows, cols, values)?;
        Ok(coo.to_csc())
    }

    /// Number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Number of stored entries.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// `(nrows, ncols)`.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Returns the stored value at `(row, col)`.
    ///
    /// Returns `None` if the position is out of bounds or holds no stored
    /// entry (a structural zero). This mirrors `CsrMatrix::get` and relies on
    /// row indices being sorted within each column.
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        if row >= self.nrows || col >= self.ncols {
            return None;
        }
        let start = self.col_ptr[col];
        let end = self.col_ptr[col + 1];
        self.row_idx[start..end]
            .binary_search(&row)
            .ok()
            .map(|pos| &self.values[start + pos])
    }

    /// Expands to a dense row-major `[nrows, ncols]` tensor.
    pub fn to_dense(&self) -> Tensor<T> {
        let mut data = vec![T::zero(); self.nrows * self.ncols];
        for c in 0..self.ncols {
            let start = self.col_ptr[c];
            let end = self.col_ptr[c + 1];
            for idx in start..end {
                let r = self.row_idx[idx];
                data[r * self.ncols + c] = self.values[idx];
            }
        }
        Tensor::from_vec(data, vec![self.nrows, self.ncols])
            .expect("dense data length equals nrows*ncols by construction")
    }

    /// Computes `y = A x` for a 1-D `x` of length `ncols`. It does this by
    /// scattering each column scaled by the matching element of `x`.
    ///
    /// # Errors
    ///
    /// Returns [`CoreError::DimensionMismatch`] if `x` is not 1-D with exactly
    /// `ncols` elements.
    pub fn matvec(&self, x: &Tensor<T>) -> Result<Tensor<T>> {
        if x.ndim() != 1 || x.numel() != self.ncols {
            return Err(CoreError::DimensionMismatch {
                expected: vec![self.ncols],
                got: x.shape().to_vec(),
            });
        }
        let xdata = x.as_slice();
        let mut result = vec![T::zero(); self.nrows];
        for (c, &xc) in xdata.iter().enumerate().take(self.ncols) {
            let start = self.col_ptr[c];
            let end = self.col_ptr[c + 1];
            for idx in start..end {
                result[self.row_idx[idx]] += self.values[idx] * xc;
            }
        }
        Tensor::from_vec(result, vec![self.nrows])
    }

    /// Returns the transpose as a CSR matrix of shape `(ncols, nrows)`.
    ///
    /// The CSC arrays of `A` are exactly the CSR arrays of `Aᵀ`, so this only
    /// clones them and swaps the dimensions.
    pub fn transpose(&self) -> CsrMatrix<T> {
        CsrMatrix {
            row_ptr: self.col_ptr.clone(),
            col_idx: self.row_idx.clone(),
            values: self.values.clone(),
            nrows: self.ncols,
            ncols: self.nrows,
        }
    }

    /// Converts to COO, emitting entries column by column.
    pub fn to_coo(&self) -> CooMatrix<T> {
        let mut rows = Vec::with_capacity(self.nnz());
        let mut cols = Vec::with_capacity(self.nnz());
        let mut values = Vec::with_capacity(self.nnz());
        for c in 0..self.ncols {
            let start = self.col_ptr[c];
            let end = self.col_ptr[c + 1];
            for idx in start..end {
                rows.push(self.row_idx[idx]);
                cols.push(c);
                values.push(self.values[idx]);
            }
        }
        CooMatrix {
            rows,
            cols,
            values,
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }

    /// Converts to CSR of the same shape, going through COO.
    pub fn to_csr(&self) -> CsrMatrix<T> {
        self.to_coo().to_csr()
    }

    /// Sorts row indices within each column and merges duplicate rows by
    /// summing their values, in one forward pass.
    ///
    /// Compacted entries are written to a cursor that never overtakes the read
    /// position. Each column is first copied into a reusable buffer, so
    /// overwriting in place is safe. The whole pass costs
    /// O(nnz log max_col_len), instead of shifting the array tail once per
    /// column that has duplicates. The sort is stable, so duplicates are
    /// summed in insertion order.
    fn sort_and_sum_duplicates(&mut self) {
        let mut write = self.col_ptr[0];
        // Original (pre-compaction) start of the column being processed. It
        // must be captured before col_ptr[c + 1] is overwritten.
        let mut next_start = write;
        let mut col_buf: Vec<(usize, T)> = Vec::new();
        for c in 0..self.ncols {
            let start = next_start;
            let end = self.col_ptr[c + 1];
            next_start = end;

            col_buf.clear();
            col_buf.extend(
                self.row_idx[start..end]
                    .iter()
                    .copied()
                    .zip(self.values[start..end].iter().copied()),
            );
            col_buf.sort_by_key(|&(r, _)| r);

            let col_first = write;
            for (r, v) in col_buf.drain(..) {
                if write > col_first && self.row_idx[write - 1] == r {
                    self.values[write - 1] += v;
                } else {
                    self.row_idx[write] = r;
                    self.values[write] = v;
                    write += 1;
                }
            }
            self.col_ptr[c + 1] = write;
        }
        self.row_idx.truncate(write);
        self.values.truncate(write);
    }
}
#[cfg(test)]
// Exact float comparison is intentional: every expected value below is a small
// integer, which f64 represents exactly.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;

    /// 3x3 fixture:
    /// ```text
    /// [1 0 2]
    /// [0 3 0]
    /// [4 0 5]
    /// ```
    fn sample_coo() -> CooMatrix<f64> {
        CooMatrix::from_triplets(
            3,
            3,
            vec![0, 0, 1, 2, 2],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
        )
        .unwrap()
    }

    #[test]
    fn test_coo_to_dense() {
        let coo = sample_coo();
        let dense = coo.to_dense();
        assert_eq!(dense.shape(), &[3, 3]);
        assert_eq!(
            dense.as_slice(),
            &[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0]
        );
    }

    #[test]
    fn test_coo_from_triplets_rejects_bad_input() {
        // Mismatched lengths.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![0], vec![0, 1], vec![1.0]).is_err());
        // Column index out of bounds.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![0], vec![2], vec![1.0]).is_err());
        // Row index out of bounds.
        assert!(CooMatrix::<f64>::from_triplets(2, 2, vec![2], vec![0], vec![1.0]).is_err());
    }

    #[test]
    fn test_csr_from_dense_roundtrip() {
        let dense = Tensor::from_vec(
            vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0],
            vec![3, 3],
        )
        .unwrap();
        let csr = CsrMatrix::from_dense(&dense).unwrap();
        assert_eq!(csr.nnz(), 5);
        let back = csr.to_dense();
        assert_eq!(dense, back);
    }

    #[test]
    fn test_csr_matvec() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csr.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_csc_matvec() {
        let csc = sample_coo().to_csc();
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]).unwrap();
        let y = csc.matvec(&x).unwrap();
        assert_eq!(y.as_slice(), &[7.0, 6.0, 19.0]);
    }

    #[test]
    fn test_coo_csr_csc_dense_roundtrip() {
        let coo = sample_coo();
        let expected = coo.to_dense();
        let csr = coo.to_csr();
        assert_eq!(csr.to_dense(), expected);
        let csc = csr.to_csc();
        assert_eq!(csc.to_dense(), expected);
        let coo2 = csc.to_coo();
        assert_eq!(coo2.to_dense(), expected);
    }

    #[test]
    fn test_identity_matrix() {
        let csr = CsrMatrix::from_dense(&Tensor::<f64>::eye(4)).unwrap();
        assert_eq!(csr.nnz(), 4);
        let x = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
        let y = csr.matvec(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let csr = CsrMatrix::<f64>::new(3, 3);
        assert_eq!(csr.nnz(), 0);
        let dense = csr.to_dense();
        assert_eq!(dense, Tensor::<f64>::zeros(vec![3, 3]));
    }

    #[test]
    fn test_dimension_mismatch() {
        let csr = sample_coo().to_csr();
        let x = Tensor::from_vec(vec![1.0, 2.0], vec![2]).unwrap();
        assert!(csr.matvec(&x).is_err());
    }

    #[test]
    fn test_duplicate_coo_entries_summed() {
        let coo = CooMatrix::from_triplets(2, 2, vec![0, 0, 1], vec![0, 0, 1], vec![1.0, 2.0, 5.0])
            .unwrap();
        let csr = coo.to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 3.0);
        assert_eq!(*csr.get(1, 1).unwrap(), 5.0);
        assert_eq!(csr.nnz(), 2);
    }

    #[test]
    fn test_duplicates_across_rows_and_cols_summed() {
        // Unsorted duplicates in several rows, with an empty row (1) in
        // between. This exercises the compaction that moves later rows and
        // columns forward.
        // Dense result:
        //   [ 2  0  4]
        //   [ 0  0  0]
        //   [ 0  9  0]
        //   [14  0  7]
        let coo = CooMatrix::from_triplets(
            4,
            3,
            vec![0, 0, 0, 2, 2, 3, 3, 3],
            vec![2, 0, 2, 1, 1, 0, 2, 0],
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap();
        let expected = coo.to_dense();
        assert_eq!(
            expected.as_slice(),
            &[2.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 14.0, 0.0, 7.0]
        );

        let csr = coo.to_csr();
        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.to_dense(), expected);
        assert_eq!(*csr.get(0, 0).unwrap(), 2.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 4.0);
        assert_eq!(*csr.get(2, 1).unwrap(), 9.0);
        assert_eq!(*csr.get(3, 0).unwrap(), 14.0);
        assert_eq!(*csr.get(3, 2).unwrap(), 7.0);
        assert!(csr.get(1, 0).is_none());

        let csc = coo.to_csc();
        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.to_dense(), expected);
        // Transposing CSC yields CSR of Aᵀ. `get` binary-searches, so this
        // also checks that CSC row indices came out sorted.
        let at = csc.transpose();
        assert_eq!(*at.get(0, 3).unwrap(), 14.0);
        assert_eq!(*at.get(2, 3).unwrap(), 7.0);
        assert_eq!(*at.get(2, 0).unwrap(), 4.0);
    }

    #[test]
    fn test_csr_transpose() {
        let csr = sample_coo().to_csr();
        let csc = csr.transpose();
        assert_eq!(csc.nrows(), 3);
        assert_eq!(csc.ncols(), 3);
        let orig = csr.to_dense();
        let trans = csc.to_dense();
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(*trans.get(&[i, j]).unwrap(), *orig.get(&[j, i]).unwrap());
            }
        }
    }

    #[test]
    fn test_csr_get() {
        let csr = sample_coo().to_csr();
        assert_eq!(*csr.get(0, 0).unwrap(), 1.0);
        assert_eq!(*csr.get(0, 2).unwrap(), 2.0);
        // Structural zero.
        assert!(csr.get(0, 1).is_none());
        // Out of bounds.
        assert!(csr.get(5, 0).is_none());
    }

    #[test]
    fn test_coo_push() {
        let mut coo = CooMatrix::<f64>::new(2, 2);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 1, 2.0).unwrap();
        assert_eq!(coo.nnz(), 2);
        assert!(coo.push(2, 0, 1.0).is_err());
        // A rejected push must not modify the matrix.
        assert_eq!(coo.nnz(), 2);
    }
}