#![allow(non_snake_case)]
use crate::algebra::utils::sortperm_by;
use crate::algebra::{permute, MatrixTriangle, TriangularMatrixChecks};
use crate::algebra::{Adjoint, MatrixShape, ShapedMatrix, SparseFormatError, Symmetric};
use num_traits::Num;
use std::iter::{repeat, zip};
#[cfg(feature = "serde")]
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// Sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` occupy positions `colptr[j]..colptr[j + 1]`
/// of both `rowval` (row indices) and `nzval` (values).
/// [`CscMatrix::check_format`] additionally requires the row indices of each
/// column to be strictly increasing and smaller than `m`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(bound = "T: Serialize + DeserializeOwned"))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CscMatrix<T = f64> {
    /// Number of rows.
    pub m: usize,
    /// Number of columns.
    pub n: usize,
    /// Column pointers: `n + 1` entries, with `colptr[n]` equal to the number
    /// of stored entries.
    pub colptr: Vec<usize>,
    /// Row index of each stored entry.
    pub rowval: Vec<usize>,
    /// Value of each stored entry.
    pub nzval: Vec<T>,
}
impl<'a, I, J, T> From<I> for CscMatrix<T>
where
    I: IntoIterator<Item = J>,
    J: IntoIterator<Item = &'a T>,
    T: Num + Copy + 'a,
{
    /// Converts a dense, row-major collection of rows into a [`CscMatrix`].
    ///
    /// Each inner iterable is one row. Entries equal to `T::zero()` are not
    /// stored; the rest are laid out column by column with increasing row
    /// indices.
    ///
    /// # Panics
    /// Panics if the rows do not all have the same length.
    fn from(rows: I) -> CscMatrix<T> {
        // Materialize the rows so they can be walked column by column.
        let rows: Vec<Vec<T>> = rows
            .into_iter()
            .map(|row| row.into_iter().copied().collect::<Vec<T>>())
            .collect();
        let m = rows.len();
        // The column count comes from the first row (zero when there are no rows).
        let n = rows.first().map_or(0, |first| first.len());
        assert!(rows.iter().all(|r| r.len() == n));

        let nnz = rows
            .iter()
            .flatten()
            .filter(|&&v| v != T::zero())
            .count();

        let mut colptr = Vec::with_capacity(n + 1);
        let mut rowval = Vec::with_capacity(nnz);
        let mut nzval: Vec<T> = Vec::with_capacity(nnz);
        colptr.push(0);
        for c in 0..n {
            let column = rows.iter().map(|row| row[c]).enumerate();
            for (r, value) in column.filter(|&(_, v)| v != T::zero()) {
                rowval.push(r);
                nzval.push(value);
            }
            // Every column closes with the running count of stored entries.
            colptr.push(nzval.len());
        }

        CscMatrix {
            m,
            n,
            colptr,
            rowval,
            nzval,
        }
    }
}
impl<T> CscMatrix<T>
where
    T: Num + Copy,
{
    /// Creates an `m x n` matrix from raw CSC arrays.
    ///
    /// Only the array lengths are validated here; use
    /// [`CscMatrix::check_format`] for a full structural check.
    ///
    /// # Panics
    /// Panics if `rowval.len() != nzval.len()`, if `colptr.len() != n + 1`,
    /// or if `colptr[n] != rowval.len()`.
    pub fn new(m: usize, n: usize, colptr: Vec<usize>, rowval: Vec<usize>, nzval: Vec<T>) -> Self {
        assert_eq!(rowval.len(), nzval.len());
        assert_eq!(colptr.len(), n + 1);
        assert_eq!(colptr[n], rowval.len());
        CscMatrix {
            m,
            n,
            colptr,
            rowval,
            nzval,
        }
    }

    /// Builds an `m x n` matrix from coordinate (triplet) data.
    ///
    /// Triplet `k` places `V[k]` at row `I[k]`, column `J[k]`. Triplets that
    /// share a `(row, col)` position are summed. Row indices within each
    /// column of the result are in increasing order.
    ///
    /// # Panics
    /// Panics if `I`, `J` and `V` differ in length, or if any row index is
    /// `>= m` or any column index is `>= n`.
    pub fn new_from_triplets(m: usize, n: usize, I: Vec<usize>, J: Vec<usize>, V: Vec<T>) -> Self {
        assert_eq!(I.len(), J.len());
        assert_eq!(I.len(), V.len());
        // Without these checks a column index >= n is silently dropped (or
        // panics deep inside), and a row index >= m yields an invalid matrix.
        assert!(I.iter().all(|&r| r < m), "row index out of range");
        assert!(J.iter().all(|&c| c < n), "column index out of range");

        let mut M = CscMatrix::spalloc((m, n), V.len());

        // Order the triplets by (column, row). M.rowval is borrowed as scratch
        // space holding 0..nnz for sortperm_by; p receives the ordering.
        let mut p = vec![0; V.len()];
        M.rowval.iter_mut().enumerate().for_each(|(i, x)| *x = i);
        sortperm_by(&mut p, &M.rowval, |&a, &b| {
            J[a].cmp(&J[b]).then(I[a].cmp(&I[b]))
        });
        permute(&mut M.rowval, &I, &p);
        permute(&mut M.nzval, &V, &p);

        // colptr temporarily holds the per-column triplet counts.
        for &c in J.iter() {
            M.colptr[c] += 1;
        }

        // Merge duplicates in place. Within a column, repeated rows are now
        // adjacent, so each entry either starts a new row (copied down to the
        // write cursor) or repeats the previous row (added to the last written
        // value, shrinking that column's count by one). rowval[readidx - 1]
        // is never overwritten before it is compared, because the write
        // cursor always trails the read cursor.
        let mut readidx = 0;
        let mut writeidx = 0;
        for col in 0..n {
            let nentries = M.colptr[col];
            for j in 0..nentries {
                let is_repeat = j > 0 && M.rowval[readidx] == M.rowval[readidx - 1];
                if is_repeat {
                    M.nzval[writeidx - 1] = M.nzval[writeidx - 1] + M.nzval[readidx];
                    M.colptr[col] -= 1;
                } else {
                    if writeidx != readidx {
                        M.rowval[writeidx] = M.rowval[readidx];
                        M.nzval[writeidx] = M.nzval[readidx];
                    }
                    writeidx += 1;
                }
                readidx += 1;
            }
        }
        M.rowval.truncate(writeidx);
        M.nzval.truncate(writeidx);

        // Turn the deduplicated column counts back into column pointers.
        M.colcount_to_colptr();
        M
    }

    /// Allocates an `m x n` matrix with room for `nnz` entries.
    ///
    /// Every column pointer except `colptr[n]` (set to `nnz`) is zero and
    /// `rowval`/`nzval` are zero-filled, so the result is a buffer to be
    /// filled in (as `new_from_triplets` and `select_rows` do).
    pub fn spalloc(size: (usize, usize), nnz: usize) -> Self {
        let (m, n) = size;
        let mut colptr = vec![0; n + 1];
        let rowval = vec![0; nnz];
        let nzval = vec![T::zero(); nnz];
        colptr[n] = nnz;
        CscMatrix::new(m, n, colptr, rowval, nzval)
    }

    /// An `m x n` matrix with no stored entries.
    pub fn zeros(size: (usize, usize)) -> Self {
        Self::spalloc(size, 0)
    }

    /// The `n x n` identity matrix, storing exactly the `n` diagonal ones.
    pub fn identity(n: usize) -> Self {
        let colptr = (0usize..=n).collect();
        let rowval = (0usize..n).collect();
        let nzval = vec![T::one(); n];
        CscMatrix::new(n, n, colptr, rowval, nzval)
    }

    /// Removes stored entries whose value equals zero, in place.
    ///
    /// Returns the number of entries removed.
    pub fn dropzeros(&mut self) -> usize {
        let mut writeidx: usize = 0;
        let mut first: usize = 0;
        for col in 0..self.ncols() {
            let last = self.colptr[col + 1];
            for readidx in first..last {
                let val = self.nzval[readidx];
                let row = self.rowval[readidx];
                if val != T::zero() {
                    if writeidx != readidx {
                        self.nzval[writeidx] = val;
                        self.rowval[writeidx] = row;
                    }
                    writeidx += 1;
                }
            }
            // Remember the old end of this column before overwriting it.
            first = self.colptr[col + 1];
            self.colptr[col + 1] = writeidx;
        }
        let dropcount = self.nzval.len() - writeidx;
        self.rowval.resize(writeidx, 0);
        self.nzval.resize(writeidx, T::zero());
        dropcount
    }

    /// Returns the stored entries as `(rows, cols, values)` triplets in
    /// storage (column-major) order.
    #[cfg_attr(not(feature = "sdp"), allow(dead_code))]
    pub(crate) fn findnz(&self) -> (Vec<usize>, Vec<usize>, Vec<T>) {
        let I = self.rowval.clone();
        let mut J = Vec::with_capacity(self.nnz());
        let V = self.nzval.clone();
        for c in 0..self.ncols() {
            let times = self.colptr[c + 1] - self.colptr[c];
            J.extend(repeat(c).take(times));
        }
        (I, J, V)
    }

    /// Number of stored entries (`colptr[n]`).
    pub fn nnz(&self) -> usize {
        self.colptr[self.n]
    }

    /// Returns an [`Adjoint`] (transposed) view borrowing this matrix.
    pub fn t(&self) -> Adjoint<'_, Self> {
        Adjoint { src: self }
    }

    /// Returns a [`Symmetric`] view that treats the stored `uplo` triangle as
    /// one half of a symmetric matrix.
    ///
    /// Debug builds assert that only that triangle is populated.
    pub fn sym(&self, uplo: MatrixTriangle) -> Symmetric<'_, Self> {
        match uplo {
            MatrixTriangle::Triu => {
                debug_assert!(self.is_triu());
            }
            MatrixTriangle::Tril => {
                debug_assert!(self.is_tril());
            }
        }
        Symmetric { src: self, uplo }
    }

    /// Shorthand for `self.sym(MatrixTriangle::Triu)`.
    pub fn sym_up(&self) -> Symmetric<'_, Self> {
        self.sym(MatrixTriangle::Triu)
    }

    /// Shorthand for `self.sym(MatrixTriangle::Tril)`.
    pub fn sym_lo(&self) -> Symmetric<'_, Self> {
        self.sym(MatrixTriangle::Tril)
    }

    /// Validates the CSC structure.
    ///
    /// # Errors
    /// * `IncompatibleDimension` / `BadColptr` from the array-length and
    ///   column-pointer checks.
    /// * `BadRowval` if row indices are not strictly increasing within a
    ///   column, or any row index is `>= m`.
    pub fn check_format(&self) -> Result<(), SparseFormatError> {
        self.check_dimensions()?;
        for col in 0..self.n {
            let rng = self.colptr[col]..self.colptr[col + 1];
            if self.rowval[rng].windows(2).any(|c| c[0] >= c[1]) {
                return Err(SparseFormatError::BadRowval);
            }
        }
        if !self.rowval.iter().all(|r| r < &self.m) {
            return Err(SparseFormatError::BadRowval);
        }
        Ok(())
    }

    /// Sorts row indices within each column and sums duplicate entries.
    ///
    /// Explicit zeros, including those produced by cancelling duplicates, are
    /// kept. Row indices are not checked against `m`.
    ///
    /// # Errors
    /// Returns the errors of the dimension/column-pointer checks, in which case
    /// the matrix is left unmodified.
    pub fn canonicalize(&mut self) -> Result<(), SparseFormatError> {
        self.check_dimensions()?;
        self.sort_indices()?;
        self.deduplicate()
    }

    /// Sorts each column's entries by row index.
    ///
    /// The sort is stable, so duplicate rows keep their relative order.
    fn sort_indices(&mut self) -> Result<(), SparseFormatError> {
        let mut tempdata: Vec<(usize, T)> = Vec::new();
        for col in 0..self.n {
            let start = self.colptr[col];
            let stop = self.colptr[col + 1];
            let nzval = &mut self.nzval[start..stop];
            let rowval = &mut self.rowval[start..stop];
            tempdata.resize(stop - start, (0, T::zero()));
            for (i, (r, v)) in zip(rowval.iter(), nzval.iter()).enumerate() {
                tempdata[i] = (*r, *v);
            }
            tempdata.sort_by_key(|&(r, _)| r);
            for (i, (r, v)) in tempdata.iter().enumerate() {
                rowval[i] = *r;
                nzval[i] = *v;
            }
        }
        Ok(())
    }

    /// Sums runs of equal, adjacent row indices within each column, in place.
    ///
    /// Only adjacent duplicates are merged, so columns should be sorted first.
    fn deduplicate(&mut self) -> Result<(), SparseFormatError> {
        let mut nnz = 0;
        let mut stop = 0;
        for col in 0..self.n {
            // `stop` still holds the *original* end of the previous column.
            let mut ptr = stop;
            stop = self.colptr[col + 1];
            while ptr < stop {
                let thisrow = self.rowval[ptr];
                let mut accum = self.nzval[ptr];
                ptr += 1;
                while (ptr < stop) && (self.rowval[ptr] == thisrow) {
                    accum = accum + self.nzval[ptr];
                    ptr += 1;
                }
                self.rowval[nnz] = thisrow;
                self.nzval[nnz] = accum;
                nnz += 1;
            }
            self.colptr[col + 1] = nnz;
        }
        self.rowval.truncate(nnz);
        self.nzval.truncate(nnz);
        Ok(())
    }

    /// Checks array lengths and column-pointer consistency.
    ///
    /// # Errors
    /// * `IncompatibleDimension` if `rowval`/`nzval` lengths differ,
    ///   `colptr.len() != n + 1`, or `colptr[n] != rowval.len()`.
    /// * `BadColptr` if `colptr[0] != 0` or `colptr` ever decreases.
    fn check_dimensions(&self) -> Result<(), SparseFormatError> {
        if self.rowval.len() != self.nzval.len() {
            return Err(SparseFormatError::IncompatibleDimension);
        }
        if self.colptr.is_empty()
            || (self.colptr.len() - 1) != self.n
            || self.colptr[self.n] != self.rowval.len()
        {
            return Err(SparseFormatError::IncompatibleDimension);
        }
        // A leading offset would leave entries before colptr[0] orphaned;
        // dropzeros and deduplicate both assume the first column starts at 0.
        if self.colptr[0] != 0 || self.colptr.windows(2).any(|c| c[0] > c[1]) {
            return Err(SparseFormatError::BadColptr);
        }
        Ok(())
    }

    /// `true` if both matrices have the same size and identical `colptr` and
    /// `rowval` arrays.
    pub fn is_equal_sparsity(&self, other: &Self) -> bool {
        self.size() == other.size() && self.colptr == other.colptr && self.rowval == other.rowval
    }

    /// Like [`CscMatrix::is_equal_sparsity`], but reports why the check failed.
    ///
    /// # Errors
    /// `IncompatibleDimension` for differing sizes, `SparsityMismatch` for
    /// differing `colptr`/`rowval`.
    pub fn check_equal_sparsity(&self, other: &Self) -> Result<(), SparseFormatError> {
        if self.size() != other.size() {
            Err(SparseFormatError::IncompatibleDimension)
        } else if self.colptr != other.colptr || self.rowval != other.rowval {
            Err(SparseFormatError::SparsityMismatch)
        } else {
            Ok(())
        }
    }

    /// Returns a new matrix keeping only rows `r` with `rowidx[r] == true`.
    ///
    /// Kept rows are renumbered consecutively in their original order.
    ///
    /// # Panics
    /// Panics if `rowidx.len() != m`.
    pub fn select_rows(&self, rowidx: &[bool]) -> Self {
        assert_eq!(rowidx.len(), self.m);

        // Map each kept row to its index in the reduced matrix.
        let mut rridx = vec![0; self.m];
        let mut mred = 0;
        for (r, &is_used) in zip(&mut rridx, rowidx) {
            if is_used {
                *r = mred;
                mred += 1;
            }
        }

        let nzred = self.rowval.iter().filter(|&&r| rowidx[r]).count();
        let mut Ared = CscMatrix::spalloc((mred, self.n), nzred);
        let mut ptrred = 0;
        for col in 0..self.n {
            Ared.colptr[col] = ptrred;
            for ptr in self.colptr[col]..self.colptr[col + 1] {
                let thisrow = self.rowval[ptr];
                if rowidx[thisrow] {
                    Ared.rowval[ptrred] = rridx[thisrow];
                    Ared.nzval[ptrred] = self.nzval[ptr];
                    ptrred += 1;
                }
            }
        }
        Ared.colptr[self.n] = ptrred;
        Ared
    }

    /// Returns a copy holding only the entries on or above the diagonal.
    ///
    /// Entries are selected by row index, so columns need not be sorted;
    /// within each column the original storage order is preserved.
    ///
    /// # Panics
    /// Panics if the matrix is not square.
    pub fn to_triu(&self) -> Self {
        assert_eq!(self.m, self.n);
        let (m, n) = (self.m, self.n);

        // First pass: count the kept entries so storage is sized exactly.
        let nnz: usize = (0..n)
            .map(|col| {
                self.rowval[self.colptr[col]..self.colptr[col + 1]]
                    .iter()
                    .filter(|&&row| row <= col)
                    .count()
            })
            .sum();

        let mut colptr = Vec::with_capacity(n + 1);
        let mut rowval = Vec::with_capacity(nnz);
        let mut nzval = Vec::with_capacity(nnz);
        colptr.push(0);
        for col in 0..n {
            for ptr in self.colptr[col]..self.colptr[col + 1] {
                let row = self.rowval[ptr];
                if row <= col {
                    rowval.push(row);
                    nzval.push(self.nzval[ptr]);
                }
            }
            colptr.push(rowval.len());
        }
        CscMatrix::new(m, n, colptr, rowval, nzval)
    }

    /// Returns the stored value at `(row, col)`, or `None` if nothing is stored
    /// there.
    ///
    /// Uses binary search, so the column's row indices must be sorted.
    ///
    /// # Panics
    /// Panics if `(row, col)` is outside the matrix.
    pub fn get_entry(&self, idx: (usize, usize)) -> Option<T> {
        let (row, col) = idx;
        assert!(row < self.nrows() && col < self.ncols());
        let first = self.colptr[col];
        let last = self.colptr[col + 1];
        let rows_in_this_column = &self.rowval[first..last];
        match rows_in_this_column.binary_search(&row) {
            Ok(idx) => Some(self.nzval[first + idx]),
            Err(_) => None,
        }
    }

    /// Sets the value at `(row, col)`.
    ///
    /// An existing entry is overwritten (even with zero). A missing entry is
    /// inserted in row order unless `value` is zero, in which case nothing
    /// changes. Insertion shifts all later entries, so it costs O(nnz + n).
    ///
    /// # Panics
    /// Panics if `(row, col)` is outside the matrix.
    pub fn set_entry(&mut self, idx: (usize, usize), value: T) {
        let (row, col) = idx;
        assert!(row < self.nrows() && col < self.ncols());
        let first = self.colptr[col];
        let last = self.colptr[col + 1];
        let rows_in_this_column = &self.rowval[first..last];
        let i = rows_in_this_column.partition_point(|&x| x < row);
        if i == rows_in_this_column.len() || rows_in_this_column[i] != row {
            if value.is_zero() {
                return;
            }
            self.rowval.insert(first + i, row);
            self.nzval.insert(first + i, value);
            // Grow this column's count by one and rebuild the pointers.
            self.colptr_to_colcount();
            self.colptr[col] += 1;
            self.colcount_to_colptr();
        } else {
            self.nzval[first + i] = value;
        }
    }

    /// Maps a storage position `idx` (into `rowval`/`nzval`) to `(row, col)`.
    ///
    /// # Panics
    /// Panics if `idx >= nnz()`.
    pub fn index_to_coord(&self, idx: usize) -> (usize, usize) {
        assert!(idx < self.nnz());
        let row = self.rowval[idx];
        // The owning column is the last one whose start pointer is <= idx.
        let col = self.colptr.partition_point(|&c| c <= idx) - 1;
        (row, col)
    }
}

#[test]
fn test_to_triu_unsorted_rows() {
    // Column 1 stores its rows out of order: (2, 1) before (0, 1).
    let A = CscMatrix {
        m: 3,
        n: 3,
        colptr: vec![0, 1, 3, 4],
        rowval: vec![0, 2, 0, 2],
        nzval: vec![1.0, 3.0, 2.0, 4.0],
    };
    let U = A.to_triu();
    assert_eq!(U.colptr, vec![0, 1, 2, 3]);
    assert_eq!(U.rowval, vec![0, 0, 2]);
    assert_eq!(U.nzval, vec![1.0, 2.0, 4.0]);
    assert!(U.is_triu());
}

#[test]
fn test_check_format_rejects_nonzero_colptr_start() {
    let A = CscMatrix {
        m: 2,
        n: 1,
        colptr: vec![1, 2],
        rowval: vec![0, 1],
        nzval: vec![1.0, 2.0],
    };
    assert!(matches!(A.check_format(), Err(SparseFormatError::BadColptr)));
}

#[test]
#[should_panic]
fn test_triplets_row_out_of_range() {
    let _ = CscMatrix::new_from_triplets(2, 2, vec![2], vec![0], vec![1.0]);
}

#[test]
#[should_panic]
fn test_triplets_col_out_of_range() {
    let _ = CscMatrix::new_from_triplets(2, 2, vec![0], vec![2], vec![1.0]);
}
impl<T> TriangularMatrixChecks for CscMatrix<T> {
    /// `true` when no stored entry lies strictly below the diagonal.
    fn is_triu(&self) -> bool {
        (0..self.ncols()).all(|col| {
            let span = self.colptr[col]..self.colptr[col + 1];
            self.rowval[span].iter().all(|&row| row <= col)
        })
    }

    /// `true` when no stored entry lies strictly above the diagonal.
    fn is_tril(&self) -> bool {
        (0..self.ncols()).all(|col| {
            let span = self.colptr[col]..self.colptr[col + 1];
            self.rowval[span].iter().all(|&row| row >= col)
        })
    }
}
impl<T> ShapedMatrix for CscMatrix<T> {
    /// Number of rows (`m`).
    fn nrows(&self) -> usize {
        self.m
    }

    /// Number of columns (`n`).
    fn ncols(&self) -> usize {
        self.n
    }

    /// `(rows, columns)` pair.
    fn size(&self) -> (usize, usize) {
        (self.nrows(), self.ncols())
    }

    /// A stored CSC matrix always reports the non-transposed shape.
    fn shape(&self) -> MatrixShape {
        MatrixShape::N
    }

    /// `true` when the row and column counts agree.
    fn is_square(&self) -> bool {
        let (rows, cols) = self.size();
        rows == cols
    }
}
impl<'a, T> From<Adjoint<'a, CscMatrix<T>>> for CscMatrix<T>
where
    T: Num + Copy,
{
    /// Materializes a transposed view into a new, owned CSC matrix.
    fn from(M: Adjoint<'a, CscMatrix<T>>) -> CscMatrix<T> {
        let original = M.src;
        let count = original.nnz();

        // Scratch map required by fill_block; not used afterwards.
        let mut entry_map = vec![0usize; count];

        // The transpose swaps the row and column dimensions.
        let mut transposed = CscMatrix::spalloc((original.n, original.m), count);
        transposed.colcount_block(original, 0, MatrixShape::T);
        transposed.colcount_to_colptr();
        transposed.fill_block(original, &mut entry_map, 0, 0, MatrixShape::T);
        transposed.backshift_colptrs();
        transposed
    }
}
#[test]
#[rustfmt::skip]
fn test_matrix_istriu_istril() {
    // strictly upper triangular pattern
    let upper = CscMatrix::from(&[
        [1., 2., 3.],
        [0., 2., 0.],
        [0., 0., 1.]]);
    assert!(upper.is_triu());
    assert!(!upper.is_tril());
    let upper_view = upper.sym_up();
    assert!(upper_view.is_triu_src());
    assert!(!upper_view.is_tril_src());

    // entries on both sides of the diagonal: neither triangle
    let mixed = CscMatrix::from(&[
        [1., 2., 3.],
        [0., 2., 0.],
        [1., 0., 1.]]);
    assert!(!mixed.is_triu());
    assert!(!mixed.is_tril());

    // lower triangular pattern
    let lower = CscMatrix::from(&[
        [1., 0., 0.],
        [0., 2., 0.],
        [1., 1., 1.]]);
    assert!(!lower.is_triu());
    assert!(lower.is_tril());
    let lower_view = lower.sym_lo();
    assert!(!lower_view.is_triu_src());
    assert!(lower_view.is_tril_src());
}
#[test]
fn test_csc_from_slice_of_arrays() {
    // reference matrix assembled directly from CSC arrays
    let expected = CscMatrix::new(3, 2, vec![0, 2, 4], vec![0, 1, 0, 2], vec![1., 3., 2., 4.]);

    let dense = [[1., 2.], [3., 0.], [0., 4.]];
    let via_from = CscMatrix::from(&dense);
    let via_into: CscMatrix = (&dense).into();

    assert_eq!(expected, via_from);
    assert_eq!(expected, via_into);
}
#[test]
fn test_csc_get_entry() {
    let A = CscMatrix::from(&[
        [0.0, 4.0, 0.0, 0.0, 12.0],
        [1.0, 5.0, 0.0, 0.0, 0.0],
        [0.0, 6.0, 0.0, 0.0, 13.0],
        [2.0, 7.0, 10.0, 0.0, 0.0],
        [0.0, 8.0, 11.0, 0.0, 14.0],
        [3.0, 9.0, 0.0, 0.0, 0.0],
    ]);

    // positions that hold a stored value
    let hits = [
        ((1, 0), 1.),
        ((5, 0), 3.),
        ((0, 1), 4.),
        ((3, 1), 7.),
        ((5, 1), 9.),
        ((3, 2), 10.),
        ((4, 2), 11.),
        ((4, 4), 14.),
    ];
    for &(idx, val) in hits.iter() {
        assert_eq!(A.get_entry(idx), Some(val));
    }

    // positions with no stored entry
    let misses = [(0, 0), (4, 0), (2, 2), (1, 3), (2, 3), (4, 3), (3, 4)];
    for &idx in misses.iter() {
        assert_eq!(A.get_entry(idx), None);
    }
}
#[test]
fn test_csc_set_entry() {
    let mut A = CscMatrix::from(&[
        [0.0, 3.0, 6.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 7.0, 8.0],
        [2.0, 5.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]);
    let expected = CscMatrix::from(&[
        [0.0, 3.0, -6.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 7.0, -8.0],
        [2.0, 5.0, 10.0, 0.0],
        [0.0, 0.0, 0.0, 11.0],
    ]);

    // two overwrites of existing entries, then two insertions
    let updates = [((0, 2), -6.0), ((2, 3), -8.0), ((3, 2), 10.0), ((4, 3), 11.0)];
    for &(idx, val) in updates.iter() {
        A.set_entry(idx, val);
    }
    assert_eq!(A, expected);
}
#[test]
fn test_csc_index_to_coord() {
    let A = CscMatrix::from(&[
        [0.0, 4.0, 0.0, 0.0, 12.0],
        [1.0, 5.0, 0.0, 0.0, 0.0],
        [0.0, 6.0, 0.0, 0.0, 13.0],
        [2.0, 7.0, 10.0, 0.0, 0.0],
        [0.0, 8.0, 11.0, 0.0, 14.0],
        [3.0, 9.0, 0.0, 0.0, 0.0],
    ]);

    // coordinates of every stored entry, in storage order (column 3 is empty)
    let expected = [
        (1, 0), (3, 0), (5, 0),
        (0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1),
        (3, 2), (4, 2),
        (0, 4), (2, 4), (4, 4),
    ];
    for (k, &coord) in expected.iter().enumerate() {
        assert_eq!(A.index_to_coord(k), coord);
    }
}
#[test]
fn test_adjoint_into() {
    let lower: CscMatrix = (&[[1., 0., 0.], [2., 4., 0.], [3., 5., 6.]]).into();
    let upper: CscMatrix = (&[[1., 2., 3.], [0., 4., 5.], [0., 0., 6.]]).into();

    // materializing the transposed view of `lower` must give `upper`
    let transposed: CscMatrix = lower.t().into();
    assert_eq!(transposed, upper);
}
#[test]
fn test_triplets() {
    let dense: CscMatrix = (&[
        [1., 0., 0., 5.], [0., 0., 3., 0.], [2., 0., 4., 0.],
    ])
    .into();

    // findnz reports entries in column-major storage order
    let expected_rows = vec![0, 2, 1, 2, 0];
    let expected_cols = vec![0, 0, 2, 2, 3];
    let expected_vals = vec![1., 2., 3., 4., 5.];
    let (I, J, V) = dense.findnz();
    assert_eq!(I, expected_rows);
    assert_eq!(J, expected_cols);
    assert_eq!(V, expected_vals);

    // sorted triplets rebuild the original matrix
    let rebuilt: CscMatrix =
        CscMatrix::new_from_triplets(3, 4, expected_rows, expected_cols, expected_vals);
    assert_eq!(dense, rebuilt);

    // triplet order must not matter
    let shuffled: CscMatrix = CscMatrix::new_from_triplets(
        3,
        4,
        vec![2, 2, 1, 0, 0],
        vec![2, 0, 2, 0, 3],
        vec![4., 2., 3., 1., 5.],
    );
    assert_eq!(dense, shuffled);

    // repeated (row, col) pairs are summed
    let summed_expected: CscMatrix<isize> = (&[
        [0, 0, 0], [-20, 0, 0], [-20, -20, 0],
    ])
    .into();
    let summed = CscMatrix::new_from_triplets(
        3,
        3,
        vec![1, 2, 2, 1, 2, 2],
        vec![0, 0, 1, 0, 0, 1],
        vec![-10, -10, -10, -10, -10, -10],
    );
    assert_eq!(summed_expected, summed);
}
#[test]
fn test_drop_zeros() {
    let mut with_zeros = CscMatrix::from(&[
        [0.0, 3.0, 6.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 7.0, 8.0],
        [2.0, 5.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]);
    let mut expected = CscMatrix::from(&[
        [0.0, 3.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 0.0, 0.0],
        [0.0, 5.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0],
    ]);

    // turn the stored values 2, 6, 7 and 8 into explicit zeros
    for pos in [1, 5, 6, 7] {
        with_zeros.nzval[pos] = 0.0;
    }

    assert_eq!(with_zeros.dropzeros(), 4);
    assert_eq!(with_zeros, expected);
    // nothing left to drop in an already clean matrix
    assert_eq!(expected.dropzeros(), 0);
}
#[test]
fn test_sort_indices() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 3,
        colptr: vec![0, 2, 4, 5],
        rowval: vec![3, 1, 4, 2, 2],
        nzval: vec![2.0, 3.0, 1.0, 4.0, 5.0],
    };
    A.sort_indices().unwrap();
    assert_eq!(A.rowval, vec![1, 3, 2, 4, 2]);
    assert_eq!(A.nzval, vec![3.0, 2.0, 4.0, 1.0, 5.0]);
    assert!(A.check_format().is_ok());

    // sorting an already sorted matrix changes nothing
    A.sort_indices().unwrap();
    assert_eq!(A.rowval, vec![1, 3, 2, 4, 2]);
    assert_eq!(A.nzval, vec![3.0, 2.0, 4.0, 1.0, 5.0]);
}
#[test]
fn test_sort_indices_with_duplicates() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 2,
        colptr: vec![0, 3, 5],
        rowval: vec![3, 3, 1, 2, 4],
        nzval: vec![2.0, 3.0, 1.0, 1.0, 4.0],
    };
    A.sort_indices().unwrap();
    // duplicates of row 3 keep their original relative order
    assert_eq!(A.rowval, vec![1, 3, 3, 2, 4]);
    assert_eq!(A.nzval, vec![1.0, 2.0, 3.0, 1.0, 4.0]);
}
#[test]
fn test_deduplicate() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 2,
        colptr: vec![0, 2, 4],
        rowval: vec![1, 1, 2, 4],
        nzval: vec![3.0, 2.0, 1.0, 4.0],
    };
    A.deduplicate().unwrap();
    assert_eq!(A.colptr, vec![0, 1, 3]);
    assert_eq!(A.rowval, vec![1, 2, 4]);
    assert_eq!(A.nzval, vec![5.0, 1.0, 4.0]);
    assert!(A.check_format().is_ok());

    // deduplicating again is a no-op
    A.deduplicate().unwrap();
    assert_eq!(A.colptr, vec![0, 1, 3]);
    assert_eq!(A.rowval, vec![1, 2, 4]);
    assert_eq!(A.nzval, vec![5.0, 1.0, 4.0]);
}
#[test]
fn test_deduplicate_multiple_columns() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 3,
        colptr: vec![0, 2, 4, 6],
        rowval: vec![1, 1, 2, 4, 3, 3],
        nzval: vec![3.0, 2.0, 1.0, 4.0, 5.0, 6.0],
    };
    A.deduplicate().unwrap();
    assert_eq!(A.colptr, vec![0, 1, 3, 4]);
    assert_eq!(A.rowval, vec![1, 2, 4, 3]);
    assert_eq!(A.nzval, vec![5.0, 1.0, 4.0, 11.0]);
    assert!(A.check_format().is_ok());
}
#[test]
fn test_deduplicate_1col() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 1,
        colptr: vec![0, 3],
        rowval: vec![1, 1, 4],
        nzval: vec![2.0, 3.0, 4.0],
    };
    A.deduplicate().unwrap();
    assert_eq!(A.colptr, vec![0, 2]);
    assert_eq!(A.rowval, vec![1, 4]);
    assert_eq!(A.nzval, vec![5.0, 4.0]);
    assert!(A.check_format().is_ok());
}
#[test]
fn test_canonicalize() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 3,
        colptr: vec![0, 3, 4, 7],
        rowval: vec![2, 1, 1, 4, 3, 4, 3],
        nzval: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
    };
    A.canonicalize().unwrap();
    assert_eq!(A.colptr, vec![0, 2, 3, 5]);
    assert_eq!(A.rowval, vec![1, 2, 4, 3, 4]);
    assert_eq!(A.nzval, vec![5.0, 1.0, 4.0, 12.0, 6.0]);
    // a canonicalized matrix must pass the full structural check
    assert!(A.check_format().is_ok());
}
#[test]
fn test_canonicalize_structural_zeros() {
    // m must exceed the largest stored row index (4) for a valid matrix
    let mut A = CscMatrix {
        m: 5,
        n: 3,
        colptr: vec![0, 3, 4, 7],
        rowval: vec![2, 1, 1, 4, 3, 4, 3],
        nzval: vec![1.0, 2.0, 3.0, 0.0, 5.0, 6.0, -5.0],
    };
    A.canonicalize().unwrap();
    // explicit zeros, including the cancelled (3, 2) pair, are retained
    assert_eq!(A.colptr, vec![0, 2, 3, 5]);
    assert_eq!(A.rowval, vec![1, 2, 4, 3, 4]);
    assert_eq!(A.nzval, vec![5.0, 1.0, 0.0, 0.0, 6.0]);
    assert!(A.check_format().is_ok());
}
#[test]
fn test_canonicalize_empty() {
    // a 0 x 0 matrix with the single mandatory column pointer
    let mut empty = CscMatrix::<f64> {
        m: 0,
        n: 0,
        colptr: vec![0],
        rowval: Vec::new(),
        nzval: Vec::new(),
    };
    empty.canonicalize().unwrap();
    assert!(empty.rowval.is_empty());
    assert!(empty.nzval.is_empty());
}
#[test]
fn test_canonicalize_singleton() {
    let mut A = CscMatrix {
        m: 4,
        n: 1,
        colptr: vec![0, 1],
        rowval: vec![2],
        nzval: vec![5.0],
    };
    // exercise the full canonicalize path (sort + deduplicate), not just sorting
    A.canonicalize().unwrap();
    assert_eq!(A.colptr, vec![0, 1]);
    assert_eq!(A.rowval, vec![2]);
    assert_eq!(A.nzval, vec![5.0]);
}