use ndarray::ArrayView;
use num_traits::{Float, Num, Signed, Zero};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::cmp;
use std::default::Default;
use std::iter::{Enumerate, Zip};
use std::mem;
use std::ops::{Add, Deref, DerefMut, Index, IndexMut, Mul, MulAssign};
use std::slice::Iter;
use crate::{Ix1, Ix2, Shape};
use ndarray::linalg::Dot;
use ndarray::{self, Array, ArrayBase, ShapeBuilder};
use crate::indexing::SpIndex;
use crate::errors::StructureError;
use crate::sparse::binop;
use crate::sparse::permutation::PermViewI;
use crate::sparse::prelude::*;
use crate::sparse::prod;
use crate::sparse::smmp;
use crate::sparse::to_dense::assign_to_dense;
use crate::sparse::utils;
use crate::sparse::vec;
/// Memory layout of a compressed sparse matrix.
///
/// With `CSR` the rows form the outer (compressed) dimension; with `CSC` the
/// columns do.
#[allow(clippy::upper_case_acronyms)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompressedStorage {
    /// Compressed Sparse Row storage.
    CSR,
    /// Compressed Sparse Column storage.
    CSC,
}
impl CompressedStorage {
pub fn other_storage(self) -> Self {
match self {
CSR => CSC,
CSC => CSR,
}
}
}
pub fn outer_dimension(
storage: CompressedStorage,
rows: usize,
cols: usize,
) -> usize {
match storage {
CSR => rows,
CSC => cols,
}
}
pub fn inner_dimension(
storage: CompressedStorage,
rows: usize,
cols: usize,
) -> usize {
match storage {
CSR => cols,
CSC => rows,
}
}
pub use self::CompressedStorage::{CSC, CSR};
/// Position of a stored entry inside a matrix's `indices`/`data` arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NnzIndex(pub usize);
/// Iterator over the stored entries of a compressed matrix, yielding
/// `(&value, (row, col))` in storage order.
pub struct CsIter<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex = I> {
    /// Storage order, used to turn `(outer, inner)` back into `(row, col)`.
    storage: CompressedStorage,
    /// Outer dimension currently being traversed.
    cur_outer: I,
    /// Index pointer used to detect where each outer dimension ends.
    indptr: crate::IndPtrView<'a, Iptr>,
    /// Enumerated `(inner index, value)` pairs in storage order.
    inner_iter: Enumerate<Zip<Iter<'a, I>, Iter<'a, N>>>,
}
impl<'a, N, I, Iptr> Iterator for CsIter<'a, N, I, Iptr>
where
    N: 'a,
    I: SpIndex,
    Iptr: SpIndex,
{
    type Item = (&'a N, (I, I));

    /// Yields the next stored value together with its `(row, col)` position.
    fn next(&mut self) -> Option<Self::Item> {
        let (nnz_index, (&inner_ind, val)) = self.inner_iter.next()?;
        // Step over every outer dimension whose range ends exactly at this
        // entry; this also skips empty outer dimensions.
        while nnz_index
            == self
                .indptr
                .outer_inds_sz(self.cur_outer.index_unchecked())
                .end
        {
            self.cur_outer += I::one();
        }
        let outer = self.cur_outer;
        let position = match self.storage {
            CSC => (inner_ind, outer),
            CSR => (outer, inner_ind),
        };
        Some((val, position))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner_iter.size_hint()
    }
}
impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
where
    IptrStorage: Deref<Target = [Iptr]>,
    IStorage: Deref<Target = [I]>,
    DStorage: Deref<Target = [N]>,
{
    /// Validates the raw parts and builds a matrix with the given `storage`.
    ///
    /// On failure the three storages are handed back alongside the error so
    /// the caller can reuse its buffers.
    pub(crate) fn new_checked(
        storage: CompressedStorage,
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
        let (nrows, ncols) = shape;
        let (inner, outer) = match storage {
            CSR => (ncols, nrows),
            CSC => (nrows, ncols),
        };
        if data.len() != indices.len() {
            return Err((
                indptr,
                indices,
                data,
                StructureError::SizeMismatch(
                    "data and indices have different sizes",
                ),
            ));
        }
        match crate::sparse::utils::check_compressed_structure(
            inner,
            outer,
            indptr.as_ref(),
            indices.as_ref(),
        ) {
            Err(e) => Err((indptr, indices, data, e)),
            Ok(_) => Ok(Self {
                storage,
                nrows,
                ncols,
                indptr: crate::IndPtrBase::new_trusted(indptr),
                indices,
                data,
            }),
        }
    }

    /// Creates a CSR matrix from its raw parts.
    ///
    /// # Panics
    ///
    /// Panics if the parts do not form a valid CSR structure; use
    /// [`CsMatBase::try_new`] for a non-panicking variant.
    pub fn new(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
            .map_err(|(_, _, _, e)| e)
            .unwrap()
    }

    /// Creates a CSC matrix from its raw parts.
    ///
    /// # Panics
    ///
    /// Panics if the parts do not form a valid CSC structure; use
    /// [`CsMatBase::try_new_csc`] for a non-panicking variant.
    pub fn new_csc(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
            .map_err(|(_, _, _, e)| e)
            .unwrap()
    }

    /// Fallible version of [`CsMatBase::new`].
    ///
    /// # Errors
    ///
    /// Returns the untouched storages together with a `StructureError` when
    /// the parts do not form a valid CSR structure.
    pub fn try_new(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
        Self::new_checked(CompressedStorage::CSR, shape, indptr, indices, data)
    }

    /// Fallible version of [`CsMatBase::new_csc`].
    ///
    /// # Errors
    ///
    /// Returns the untouched storages together with a `StructureError` when
    /// the parts do not form a valid CSC structure.
    pub fn try_new_csc(
        shape: (usize, usize),
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)> {
        Self::new_checked(CompressedStorage::CSC, shape, indptr, indices, data)
    }

    /// Builds a matrix from raw parts without any validation.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `indptr`, `indices` and `data` describe
    /// a valid compressed matrix of the given `storage` and `shape`. That is,
    /// they must pass the checks of `check_compressed_structure`, and
    /// `indices` and `data` must have equal length. Unsafe code such as
    /// `outer_iterator_mut` relies on these invariants.
    pub unsafe fn new_unchecked(
        storage: CompressedStorage,
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        Self::new_trusted(storage, shape, indptr, indices, data)
    }

    /// Crate-internal constructor for parts already known to be valid.
    pub(crate) fn new_trusted(
        storage: CompressedStorage,
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Self {
        let (nrows, ncols) = shape;
        Self {
            storage,
            nrows,
            ncols,
            indptr: crate::IndPtrBase::new_trusted(indptr),
            indices,
            data,
        }
    }
}
impl<N, I: SpIndex, Iptr: SpIndex, IptrStorage, IStorage, DStorage>
    CsMatBase<N, I, IptrStorage, IStorage, DStorage, Iptr>
where
    IptrStorage: Deref<Target = [Iptr]>,
    IStorage: DerefMut<Target = [I]>,
    DStorage: DerefMut<Target = [N]>,
{
    /// Sorts the indices (with their data) of every outer dimension, then
    /// validates the resulting compressed structure.
    ///
    /// Only well-formed outer ranges are sorted. Malformed `indptr` entries
    /// (decreasing, not representable as `usize`, or pointing past the end
    /// of `indices`) are left for the structure check to report, so invalid
    /// input yields an `Err` instead of a panic.
    fn new_from_unsorted_checked(
        storage: CompressedStorage,
        shape: (usize, usize),
        indptr: IptrStorage,
        mut indices: IStorage,
        mut data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
    where
        N: Clone,
    {
        let (nrows, ncols) = shape;
        let (inner, outer) = match storage {
            CSR => (ncols, nrows),
            CSC => (nrows, ncols),
        };
        if data.len() != indices.len() {
            return Err((
                indptr,
                indices,
                data,
                StructureError::SizeMismatch(
                    "data and indices have different sizes",
                ),
            ));
        }
        // `indices`/`data` are addressed relative to the first indptr entry,
        // the same convention `outer_inds_sz` follows elsewhere.
        if let Some(base) = indptr.first().and_then(|&p| p.to_usize()) {
            let mut buf = Vec::new();
            for window in indptr.windows(2) {
                let range = match (window[0].to_usize(), window[1].to_usize())
                {
                    (Some(start), Some(stop))
                        if base <= start
                            && start <= stop
                            && stop - base <= indices.len() =>
                    {
                        (start - base)..(stop - base)
                    }
                    // Malformed entry: rejected by the structure check below.
                    _ => continue,
                };
                let outer_indices = &mut indices[range.clone()];
                if utils::sorted_indices(outer_indices) {
                    continue;
                }
                let outer_data = &mut data[range];
                utils::sort_indices_data_slices(
                    outer_indices,
                    outer_data,
                    &mut buf,
                );
            }
        }
        match crate::sparse::utils::check_compressed_structure(
            inner,
            outer,
            indptr.as_ref(),
            indices.as_ref(),
        ) {
            Err(e) => Err((indptr, indices, data, e)),
            Ok(_) => Ok(Self {
                storage,
                nrows,
                ncols,
                indptr: crate::IndPtrBase::new_trusted(indptr),
                indices,
                data,
            }),
        }
    }

    /// Creates a CSR matrix whose per-row indices may be unsorted; they are
    /// sorted in place (together with `data`) before validation.
    ///
    /// # Errors
    ///
    /// Returns the storages and a `StructureError` if the parts are invalid.
    pub fn new_from_unsorted(
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
    where
        N: Clone,
    {
        Self::new_from_unsorted_checked(CSR, shape, indptr, indices, data)
    }

    /// Creates a CSC matrix whose per-column indices may be unsorted; they
    /// are sorted in place (together with `data`) before validation.
    ///
    /// # Errors
    ///
    /// Returns the storages and a `StructureError` if the parts are invalid.
    pub fn new_from_unsorted_csc(
        shape: Shape,
        indptr: IptrStorage,
        indices: IStorage,
        data: DStorage,
    ) -> Result<Self, (IptrStorage, IStorage, DStorage, StructureError)>
    where
        N: Clone,
    {
        Self::new_from_unsorted_checked(CSC, shape, indptr, indices, data)
    }
}
impl<N, I: SpIndex, Iptr: SpIndex> CsMatI<N, I, Iptr> {
    /// Creates the `dim x dim` identity matrix in CSR storage.
    ///
    /// # Panics
    ///
    /// Panics if `dim` cannot be represented by `I` or `Iptr` (the check is
    /// delegated to `SpIndex::from_usize`).
    pub fn eye(dim: usize) -> Self
    where
        N: Num + Clone,
    {
        // Evaluated only for the range check done by `from_usize` (the values
        // are discarded); every later conversion stays within `0..=dim`.
        let _ = (I::from_usize(dim), Iptr::from_usize(dim));
        let n = dim;
        let indptr = (0..=n).map(Iptr::from_usize_unchecked).collect();
        let indices = (0..n).map(I::from_usize_unchecked).collect();
        let data = vec![N::one(); n];
        Self::new_trusted(CSR, (n, n), indptr, indices, data)
    }

    /// Creates the `dim x dim` identity matrix in CSC storage.
    ///
    /// # Panics
    ///
    /// Panics if `dim` cannot be represented by `I` or `Iptr`.
    pub fn eye_csc(dim: usize) -> Self
    where
        N: Num + Clone,
    {
        // The identity is symmetric, so the CSR arrays are also valid CSC
        // arrays; only the storage tag changes.
        Self::eye(dim).transpose_into()
    }

    /// Creates a matrix with no outer dimension and the given inner size.
    pub fn empty(storage: CompressedStorage, inner_size: usize) -> Self {
        let shape = match storage {
            CSR => (0, inner_size),
            CSC => (inner_size, 0),
        };
        Self::new_trusted(
            storage,
            shape,
            vec![Iptr::zero(); 1],
            Vec::new(),
            Vec::new(),
        )
    }

    /// Creates a CSR matrix of the given shape with no stored entries.
    pub fn zero(shape: Shape) -> Self {
        let (nrows, _ncols) = shape;
        Self::new_trusted(
            CSR,
            shape,
            vec![Iptr::zero(); nrows + 1],
            Vec::new(),
            Vec::new(),
        )
    }

    /// Reserves room for at least `outer_dim_additional` more indptr entries.
    pub fn reserve_outer_dim(&mut self, outer_dim_additional: usize) {
        self.indptr.reserve(outer_dim_additional);
    }

    /// Reserves room for at least `nnz_additional` more stored entries.
    pub fn reserve_nnz(&mut self, nnz_additional: usize) {
        self.indices.reserve(nnz_additional);
        self.data.reserve(nnz_additional);
    }

    /// Reserves exactly `outer_dim_lim + 1` additional indptr entries.
    pub fn reserve_outer_dim_exact(&mut self, outer_dim_lim: usize) {
        self.indptr.reserve_exact(outer_dim_lim + 1);
    }

    /// Reserves exactly `nnz_lim` additional stored entries.
    pub fn reserve_nnz_exact(&mut self, nnz_lim: usize) {
        self.indices.reserve_exact(nnz_lim);
        self.data.reserve_exact(nnz_lim);
    }

    /// Builds a CSR matrix from a dense 2D view, keeping only entries whose
    /// absolute value is strictly greater than `epsilon` (negative `epsilon`
    /// is treated as zero).
    pub fn csr_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
    where
        N: Num + Clone + cmp::PartialOrd + Signed,
    {
        let epsilon = if epsilon > N::zero() {
            epsilon
        } else {
            N::zero()
        };
        let nrows = m.shape()[0];
        let ncols = m.shape()[1];
        // First pass: per-row counts, accumulated straight into indptr.
        let mut indptr = vec![Iptr::zero(); nrows + 1];
        let mut nnz = 0;
        for (row, row_count) in m.outer_iter().zip(&mut indptr[1..]) {
            nnz += row.iter().filter(|&x| x.abs() > epsilon).count();
            *row_count = Iptr::from_usize(nnz);
        }
        // Second pass: collect the kept entries into preallocated buffers.
        let mut indices = Vec::with_capacity(nnz);
        let mut data = Vec::with_capacity(nnz);
        for row in m.outer_iter() {
            for (col_ind, x) in row.iter().enumerate() {
                if x.abs() > epsilon {
                    indices.push(I::from_usize(col_ind));
                    data.push(x.clone());
                }
            }
        }
        Self {
            storage: CompressedStorage::CSR,
            nrows,
            ncols,
            indptr: crate::IndPtr::new_trusted(indptr),
            indices,
            data,
        }
    }

    /// Builds a CSC matrix from a dense 2D view; see
    /// [`CsMatBase::csr_from_dense`] for the meaning of `epsilon`.
    pub fn csc_from_dense(m: ArrayView<N, Ix2>, epsilon: N) -> Self
    where
        N: Num + Clone + cmp::PartialOrd + Signed,
    {
        Self::csr_from_dense(m.reversed_axes(), epsilon).transpose_into()
    }

    /// Appends one outer dimension (a row for CSR, a column for CSC) given
    /// as a dense slice; only non-zero entries are stored.
    ///
    /// `data` may be shorter than the inner dimension (missing trailing
    /// entries are zero).
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than the inner dimension, since the extra
    /// entries would be stored with out-of-range inner indices.
    pub fn append_outer(mut self, data: &[N]) -> Self
    where
        N: Clone + Num,
    {
        assert!(
            data.len() <= self.inner_dims(),
            "appended data is longer than the inner dimension"
        );
        let mut nnz = self.nnz();
        for (inner_ind, val) in data.iter().enumerate() {
            if *val != N::zero() {
                self.indices.push(I::from_usize(inner_ind));
                self.data.push(val.clone());
                nnz += 1;
            }
        }
        match self.storage {
            CSR => self.nrows += 1,
            CSC => self.ncols += 1,
        }
        self.indptr.push(Iptr::from_usize(nnz));
        self
    }

    /// Appends one outer dimension given as a sparse vector.
    ///
    /// # Panics
    ///
    /// Panics if `vec.dim()` differs from the inner dimension.
    pub fn append_outer_csvec(mut self, vec: CsVecViewI<N, I>) -> Self
    where
        N: Clone,
    {
        assert_eq!(self.inner_dims(), vec.dim());
        for (ind, val) in vec.indices().iter().zip(vec.data()) {
            self.indices.push(*ind);
            self.data.push(val.clone());
        }
        match self.storage {
            CSR => self.nrows += 1,
            CSC => self.ncols += 1,
        }
        let nnz = Iptr::from_usize(self.indptr.nnz() + vec.nnz());
        self.indptr.push(nnz);
        self
    }

    /// Stores `val` at `(row, col)`, overwriting an existing entry. If the
    /// position lies outside the current shape, the matrix grows to fit it.
    pub fn insert(&mut self, row: usize, col: usize, val: N) {
        match self.storage() {
            CSR => self.insert_outer_inner(row, col, val),
            CSC => self.insert_outer_inner(col, row, val),
        }
    }

    /// `insert` expressed in storage coordinates.
    fn insert_outer_inner(
        &mut self,
        outer_ind: usize,
        inner_ind: usize,
        val: N,
    ) {
        let outer_dims = self.outer_dims();
        let inner_ind_idx = I::from_usize(inner_ind);
        if outer_ind >= outer_dims {
            // Pad with empty outer dimensions up to `outer_ind`, then append
            // a new outer dimension holding the single entry.
            let last_nnz = self.indptr.nnz_i();
            self.indptr.resize(outer_ind + 1, last_nnz);
            self.set_outer_dims(outer_ind + 1);
            self.indptr.push(last_nnz + Iptr::one());
            self.indices.push(inner_ind_idx);
            self.data.push(val);
        } else {
            let range = self.indptr.outer_inds_sz(outer_ind);
            let location =
                self.indices[range.clone()].binary_search(&inner_ind_idx);
            match location {
                Ok(ind) => {
                    // Already stored: overwrite; the shape cannot change.
                    let ind = range.start + ind.index_unchecked();
                    self.data[ind] = val;
                    return;
                }
                Err(ind) => {
                    // Insert at the sorted position and shift later indptrs.
                    let ind = range.start + ind.index_unchecked();
                    self.indices.insert(ind, inner_ind_idx);
                    self.data.insert(ind, val);
                    self.indptr.record_new_element(outer_ind);
                }
            }
        }
        if inner_ind >= self.inner_dims() {
            self.set_inner_dims(inner_ind + 1);
        }
    }

    /// Sets the size of the outer dimension (rows for CSR, cols for CSC).
    fn set_outer_dims(&mut self, outer_dims: usize) {
        match self.storage() {
            CSR => self.nrows = outer_dims,
            CSC => self.ncols = outer_dims,
        }
    }

    /// Sets the size of the inner dimension (cols for CSR, rows for CSC).
    fn set_inner_dims(&mut self, inner_dims: usize) {
        match self.storage() {
            CSR => self.ncols = inner_dims,
            CSC => self.nrows = inner_dims,
        }
    }
}
impl<'a, N: 'a, I: 'a + SpIndex, Iptr: 'a + SpIndex>
    CsMatViewI<'a, N, I, Iptr>
{
    /// View of the `count` outer dimensions starting at `i`, borrowing the
    /// underlying data for the full lifetime `'a`.
    ///
    /// # Panics
    ///
    /// Panics if `i + count` overflows.
    #[deprecated(
        since = "0.10.0",
        note = "Please use the `slice_outer` method instead"
    )]
    pub fn middle_outer_views(
        &self,
        i: usize,
        count: usize,
    ) -> CsMatViewI<'a, N, I, Iptr> {
        let end = i.checked_add(count).unwrap();
        let (nrows, ncols) = match self.storage {
            CSC => (self.rows(), count),
            CSR => (count, self.cols()),
        };
        let nnz_range = self.indptr.outer_inds_slice(i, end);
        CsMatViewI {
            storage: self.storage,
            nrows,
            ncols,
            indptr: self.indptr.middle_slice_rbr(i..end),
            indices: &self.indices[nnz_range.clone()],
            data: &self.data[nnz_range],
        }
    }

    /// Iterator over `(&value, (row, col))` borrowing for the full `'a`.
    pub fn iter_rbr(&self) -> CsIter<'a, N, I, Iptr> {
        let entries = self.indices.iter().zip(self.data.iter());
        CsIter {
            storage: self.storage,
            cur_outer: I::zero(),
            indptr: self.indptr.reborrow(),
            inner_iter: entries.enumerate(),
        }
    }
}
impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: Deref<Target = [Iptr]>,
    IndStorage: Deref<Target = [I]>,
    DataStorage: Deref<Target = [N]>,
{
    /// Storage order (CSR or CSC) of this matrix.
    pub fn storage(&self) -> CompressedStorage {
        self.storage
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.ncols
    }

    /// `(rows, cols)` shape.
    pub fn shape(&self) -> Shape {
        (self.nrows, self.ncols)
    }

    /// Number of stored entries, as recorded by `indptr`.
    pub fn nnz(&self) -> usize {
        self.indptr.nnz()
    }

    /// Fraction of the `rows * cols` positions holding a stored entry.
    ///
    /// A matrix with zero rows or zero columns yields `NaN` (0 / 0).
    pub fn density(&self) -> f64 {
        let rows = self.nrows as f64;
        let cols = self.ncols as f64;
        let nnz = self.nnz() as f64;
        nnz / (rows * cols)
    }

    /// Size of the outer dimension (rows for CSR, cols for CSC).
    pub fn outer_dims(&self) -> usize {
        outer_dimension(self.storage, self.nrows, self.ncols)
    }

    /// Size of the inner dimension (cols for CSR, rows for CSC).
    pub fn inner_dims(&self) -> usize {
        inner_dimension(self.storage, self.nrows, self.ncols)
    }

    /// Reference to the stored value at `(i, j)`, or `None` if not stored.
    pub fn get(&self, i: usize, j: usize) -> Option<&N> {
        match self.storage {
            CSR => self.get_outer_inner(i, j),
            CSC => self.get_outer_inner(j, i),
        }
    }

    /// Borrowed view of the index pointer.
    pub fn indptr(&self) -> crate::IndPtrView<Iptr> {
        crate::IndPtrView::new_trusted(self.indptr.raw_storage())
    }

    /// Index pointer as returned by `IndPtrBase::to_proper` (presumably
    /// rebased to start at zero, borrowing when no copy is needed; TODO
    /// confirm).
    pub fn proper_indptr(&self) -> std::borrow::Cow<[Iptr]> {
        self.indptr.to_proper()
    }

    /// Inner indices of all stored entries, in storage order.
    pub fn indices(&self) -> &[I] {
        &self.indices[..]
    }

    /// Values of all stored entries, in storage order.
    pub fn data(&self) -> &[N] {
        &self.data[..]
    }

    /// Decomposes the matrix into its `(indptr, indices, data)` storages.
    pub fn into_raw_storage(self) -> (IptrStorage, IndStorage, DataStorage) {
        let Self {
            indptr,
            indices,
            data,
            ..
        } = self;
        (indptr.into_raw_storage(), indices, data)
    }

    /// `true` if stored in CSC order.
    pub fn is_csc(&self) -> bool {
        self.storage == CSC
    }

    /// `true` if stored in CSR order.
    pub fn is_csr(&self) -> bool {
        self.storage == CSR
    }

    /// Transposes in place by swapping the shape and flipping the storage
    /// order; the arrays themselves are untouched.
    pub fn transpose_mut(&mut self) {
        mem::swap(&mut self.nrows, &mut self.ncols);
        self.storage = self.storage.other_storage();
    }

    /// Consuming version of [`CsMatBase::transpose_mut`].
    pub fn transpose_into(mut self) -> Self {
        self.transpose_mut();
        self
    }

    /// Transposed view sharing this matrix's arrays.
    pub fn transpose_view(&self) -> CsMatViewI<N, I, Iptr> {
        CsMatViewI {
            storage: self.storage.other_storage(),
            nrows: self.ncols,
            ncols: self.nrows,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: &self.data[..],
        }
    }

    /// Owned copy of this matrix.
    pub fn to_owned(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone,
    {
        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: self.indptr.to_owned(),
            indices: self.indices.to_vec(),
            data: self.data.to_vec(),
        }
    }

    /// For each outer dimension, keeps a single `1` at the position of its
    /// largest non-NaN value (the last one on ties); outer dimensions with no
    /// such value stay empty.
    pub fn to_inner_onehot(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone + Float + PartialOrd,
    {
        let mut indptr_counter = 0_usize;
        let mut indptr: Vec<Iptr> = Vec::with_capacity(self.indptr.len());
        let max_data_len = self.indptr.len().min(self.data.len());
        let mut indices: Vec<I> = Vec::with_capacity(max_data_len);
        let mut data = Vec::with_capacity(max_data_len);
        for inner_vec in self.outer_iterator() {
            let hot_element = inner_vec
                .iter()
                .filter(|e| !e.1.is_nan())
                .max_by(|a, b| {
                    a.1.partial_cmp(b.1)
                        .expect("Unexpected NaN value was found")
                })
                .map(|a| a.0);
            indptr.push(Iptr::from_usize(indptr_counter));
            if let Some(inner_id) = hot_element {
                indices.push(I::from_usize(inner_id));
                data.push(N::one());
                indptr_counter += 1;
            }
        }
        indptr.push(Iptr::from_usize(indptr_counter));
        CsMatI {
            storage: self.storage,
            nrows: self.rows(),
            ncols: self.cols(),
            indptr: crate::IndPtr::new_trusted(indptr),
            indices,
            data,
        }
    }

    /// Owned copy with the value type converted via `Into` and the index
    /// types converted with `SpIndex::from_usize`.
    pub fn to_other_types<I2, N2, Iptr2>(&self) -> CsMatI<N2, I2, Iptr2>
    where
        N: Clone + Into<N2>,
        I2: SpIndex,
        Iptr2: SpIndex,
    {
        let indptr = crate::IndPtr::new_trusted(
            self.indptr
                .raw_storage()
                .iter()
                .map(|i| Iptr2::from_usize(i.index_unchecked()))
                .collect(),
        );
        let indices = self
            .indices
            .iter()
            .map(|i| I2::from_usize(i.index_unchecked()))
            .collect();
        let data = self.data.iter().map(|x| x.clone().into()).collect();
        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr,
            indices,
            data,
        }
    }

    /// Borrowed view of this matrix.
    pub fn view(&self) -> CsMatViewI<N, I, Iptr> {
        CsMatViewI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: &self.data[..],
        }
    }

    /// View of the sparsity pattern only (values replaced by `()`).
    pub fn structure_view(&self) -> CsStructureViewI<I, Iptr> {
        // SAFETY: `()` is zero-sized with alignment 1, so any non-null
        // pointer is valid for reads of `len` elements; the pointer comes
        // from a live slice and is therefore non-null.
        let zst_data = unsafe {
            std::slice::from_raw_parts(
                self.data.as_ptr() as *const (),
                self.data.len(),
            )
        };
        CsStructureViewI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: zst_data,
        }
    }

    /// Dense copy of this matrix (unstored positions are zero).
    pub fn to_dense(&self) -> Array<N, Ix2>
    where
        N: Clone + Zero,
    {
        let mut res = Array::zeros((self.rows(), self.cols()));
        assign_to_dense(res.view_mut(), self.view());
        res
    }

    /// Iterator over the outer dimensions as sparse vector views.
    pub fn outer_iterator(
        &self,
    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewI<N, I>>
           + std::iter::ExactSizeIterator<Item = CsVecViewI<N, I>>
           + '_ {
        self.indptr.iter_outer_sz().map(move |range| {
            CsVecViewI::new_trusted(
                self.inner_dims(),
                &self.indices[range.clone()],
                &self.data[range],
            )
        })
    }

    /// Iterates outer dimensions in the order given by `perm`, yielding
    /// `(permuted outer index, vector view)`.
    #[doc(hidden)]
    pub fn outer_iterator_papt<'a, 'perm: 'a>(
        &'a self,
        perm: PermViewI<'perm, I>,
    ) -> impl std::iter::DoubleEndedIterator<Item = (usize, CsVecViewI<N, I>)>
           + std::iter::ExactSizeIterator<Item = (usize, CsVecViewI<N, I>)>
           + '_ {
        (0..self.outer_dims()).map(move |outer_ind| {
            let outer_ind_perm = perm.at(outer_ind);
            let range = self.indptr.outer_inds_sz(outer_ind_perm);
            let indices = &self.indices[range.clone()];
            let data = &self.data[range];
            let vec = CsVecBase::new_trusted(self.inner_dims(), indices, data);
            (outer_ind_perm, vec)
        })
    }

    /// Largest number of stored entries in any outer dimension (0 if none).
    pub fn max_outer_nnz(&self) -> usize {
        self.outer_iterator()
            .map(|outer| outer.indices().len())
            .max()
            .unwrap_or(0)
    }

    /// For each outer dimension, the number of stored entries off the
    /// diagonal (inner index != outer index).
    pub fn degrees(&self) -> Vec<usize> {
        self.outer_iterator()
            .enumerate()
            .map(|(outer_dim, outer)| {
                outer
                    .indices()
                    .iter()
                    .filter(|ind| ind.index() != outer_dim)
                    .count()
            })
            .collect()
    }

    /// View of outer dimension `i`, or `None` if `i` is out of range.
    pub fn outer_view(&self, i: usize) -> Option<CsVecViewI<N, I>> {
        if i >= self.outer_dims() {
            return None;
        }
        let range = self.indptr.outer_inds_sz(i);
        Some(CsVecViewI::new_trusted(
            self.inner_dims(),
            &self.indices[range.clone()],
            &self.data[range],
        ))
    }

    /// Stored diagonal entries as a sparse vector of length
    /// `min(rows, cols)`.
    pub fn diag(&self) -> CsVecI<N, I>
    where
        N: Clone,
    {
        let shape = self.shape();
        let smallest_dim: usize = cmp::min(shape.0, shape.1);
        let heuristic = smallest_dim / 2;
        let mut index_vec = Vec::with_capacity(heuristic);
        let mut data_vec = Vec::with_capacity(heuristic);
        for i in 0..smallest_dim {
            let optional_index = self.nnz_index(i, i);
            if let Some(idx) = optional_index {
                data_vec.push(self[idx].clone());
                index_vec.push(I::from_usize(i));
            }
        }
        data_vec.shrink_to_fit();
        index_vec.shrink_to_fit();
        CsVecI::new_trusted(smallest_dim, index_vec, data_vec)
    }

    /// Iterator over the `min(rows, cols)` diagonal positions, yielding
    /// `None` for unstored ones.
    pub fn diag_iter(
        &self,
    ) -> impl ExactSizeIterator<Item = Option<&N>>
           + DoubleEndedIterator<Item = Option<&N>> {
        let smallest_dim = cmp::min(self.ncols, self.nrows);
        (0..smallest_dim).map(move |i| self.get_outer_inner(i, i))
    }

    /// Iterator over views of consecutive blocks of `block_size` outer
    /// dimensions; the last block may be smaller.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero.
    pub fn outer_block_iter(
        &self,
        block_size: usize,
    ) -> impl std::iter::DoubleEndedIterator<Item = CsMatViewI<N, I, Iptr>>
           + std::iter::ExactSizeIterator<Item = CsMatViewI<N, I, Iptr>>
           + '_ {
        (0..self.outer_dims()).step_by(block_size).map(move |i| {
            // `i < outer_dims`, so the subtraction cannot underflow, and the
            // minimum avoids overflowing `i + block_size` for huge blocks.
            let count = cmp::min(block_size, self.outer_dims() - i);
            self.view().slice_outer_rbr(i..i + count)
        })
    }

    /// New matrix with the same structure and values mapped through `f`.
    pub fn map<F, N2>(&self, f: F) -> CsMatI<N2, I, Iptr>
    where
        F: FnMut(&N) -> N2,
    {
        let data: Vec<N2> = self.data.iter().map(f).collect();
        CsMatI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: self.indptr.to_owned(),
            indices: self.indices.to_vec(),
            data,
        }
    }

    /// `get` expressed in storage coordinates.
    pub fn get_outer_inner(
        &self,
        outer_ind: usize,
        inner_ind: usize,
    ) -> Option<&N> {
        self.outer_view(outer_ind)
            .and_then(|vec| vec.get_rbr(inner_ind))
    }

    /// Position of `(row, col)` in the `data` array, if stored.
    pub fn nnz_index(&self, row: usize, col: usize) -> Option<NnzIndex> {
        match self.storage() {
            CSR => self.nnz_index_outer_inner(row, col),
            CSC => self.nnz_index_outer_inner(col, row),
        }
    }

    /// `nnz_index` expressed in storage coordinates.
    pub fn nnz_index_outer_inner(
        &self,
        outer_ind: usize,
        inner_ind: usize,
    ) -> Option<NnzIndex> {
        if outer_ind >= self.outer_dims() {
            return None;
        }
        let offset = self.indptr.outer_inds_sz(outer_ind).start;
        self.outer_view(outer_ind)
            .and_then(|vec| vec.nnz_index(inner_ind))
            .map(|vec::NnzIndex(ind)| NnzIndex(ind + offset))
    }

    /// Validates the compressed structure of this matrix.
    ///
    /// # Errors
    ///
    /// Returns `SizeMismatch` if `indices` and `data` lengths differ, or the
    /// error reported by `utils::check_compressed_structure`.
    pub fn check_compressed_structure(&self) -> Result<(), StructureError> {
        let inner = self.inner_dims();
        let outer = self.outer_dims();
        if self.indices.len() != self.data.len() {
            return Err(StructureError::SizeMismatch(
                "Indices and data lengths do not match",
            ));
        }
        utils::check_compressed_structure(
            inner,
            outer,
            self.indptr.raw_storage(),
            &self.indices,
        )
    }

    /// Iterator over `(&value, (row, col))` for every stored entry.
    pub fn iter(&self) -> CsIter<N, I, Iptr> {
        CsIter {
            storage: self.storage,
            cur_outer: I::zero(),
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            inner_iter: self.indices.iter().zip(self.data.iter()).enumerate(),
        }
    }
}
impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
where
    N: Default,
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: Deref<Target = [Iptr]>,
    IndStorage: Deref<Target = [I]>,
    DataStorage: Deref<Target = [N]>,
{
    /// Owned copy of this matrix in the opposite storage order (CSR becomes
    /// CSC and vice versa), with the same shape and values.
    pub fn to_other_storage(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone,
    {
        let nnz = self.nnz();
        // Zero-filled output buffers sized for the transposed layout.
        let mut new_indptr = vec![Iptr::zero(); self.inner_dims() + 1];
        let mut new_indices = vec![I::zero(); nnz];
        let mut new_data = vec![N::default(); nnz];
        raw::convert_mat_storage(
            self.view(),
            &mut new_indptr,
            &mut new_indices,
            &mut new_data,
        );
        CsMatI {
            storage: self.storage().other_storage(),
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtr::new_trusted(new_indptr),
            indices: new_indices,
            data: new_data,
        }
    }

    /// Owned CSC copy (converted only if currently CSR).
    pub fn to_csc(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone,
    {
        match self.storage {
            CSC => self.to_owned(),
            CSR => self.to_other_storage(),
        }
    }

    /// Owned CSR copy (converted only if currently CSC).
    pub fn to_csr(&self) -> CsMatI<N, I, Iptr>
    where
        N: Clone,
    {
        match self.storage {
            CSC => self.to_other_storage(),
            CSR => self.to_owned(),
        }
    }
}
impl<N, I, Iptr> CsMatI<N, I, Iptr>
where
    N: Default,
    I: SpIndex,
    Iptr: SpIndex,
{
    /// Consumes the matrix and returns it in CSC order, converting only when
    /// it is currently CSR.
    pub fn into_csc(self) -> Self
    where
        N: Clone,
    {
        if self.is_csc() {
            self
        } else {
            self.to_other_storage()
        }
    }

    /// Consumes the matrix and returns it in CSR order, converting only when
    /// it is currently CSC.
    pub fn into_csr(self) -> Self
    where
        N: Clone,
    {
        if self.is_csr() {
            self
        } else {
            self.to_other_storage()
        }
    }
}
impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: Deref<Target = [Iptr]>,
    IndStorage: Deref<Target = [I]>,
    DataStorage: DerefMut<Target = [N]>,
{
    /// Mutable access to the values of all stored entries.
    pub fn data_mut(&mut self) -> &mut [N] {
        &mut self.data[..]
    }

    /// Multiplies every stored value by `val` in place.
    pub fn scale(&mut self, val: N)
    where
        for<'r> N: MulAssign<&'r N>,
    {
        for data in self.data_mut() {
            *data *= &val;
        }
    }

    /// Mutable view of outer dimension `i`, or `None` if out of range.
    pub fn outer_view_mut(&mut self, i: usize) -> Option<CsVecViewMutI<N, I>> {
        if i >= self.outer_dims() {
            return None;
        }
        let range = self.indptr.outer_inds_sz(i);
        Some(CsVecBase::new_trusted(
            self.inner_dims(),
            &self.indices[range.clone()],
            &mut self.data[range],
        ))
    }

    /// Mutable reference to the stored value at `(i, j)`, if stored.
    pub fn get_mut(&mut self, i: usize, j: usize) -> Option<&mut N> {
        match self.storage {
            CSR => self.get_outer_inner_mut(i, j),
            CSC => self.get_outer_inner_mut(j, i),
        }
    }

    /// `get_mut` expressed in storage coordinates.
    pub fn get_outer_inner_mut(
        &mut self,
        outer_ind: usize,
        inner_ind: usize,
    ) -> Option<&mut N> {
        if let Some(NnzIndex(index)) =
            self.nnz_index_outer_inner(outer_ind, inner_ind)
        {
            Some(&mut self.data[index])
        } else {
            None
        }
    }

    /// Overwrites the value of an already-stored entry.
    ///
    /// # Panics
    ///
    /// Panics if `(row, col)` is not a stored entry; this method never
    /// inserts.
    pub fn set(&mut self, row: usize, col: usize, val: N) {
        let outer = outer_dimension(self.storage(), row, col);
        let inner = inner_dimension(self.storage(), row, col);
        let vec::NnzIndex(index) = self
            .outer_view(outer)
            .and_then(|vec| vec.nnz_index(inner))
            .unwrap();
        self.data[index] = val;
    }

    /// Replaces every stored value `v` with `f(&v)`.
    pub fn map_inplace<F>(&mut self, mut f: F)
    where
        F: FnMut(&N) -> N,
    {
        for val in &mut self.data[..] {
            *val = f(val);
        }
    }

    /// Iterator over mutable views of the outer dimensions.
    pub fn outer_iterator_mut(
        &mut self,
    ) -> impl std::iter::DoubleEndedIterator<Item = CsVecViewMutI<N, I>>
           + std::iter::ExactSizeIterator<Item = CsVecViewMutI<N, I>>
           + '_ {
        let inner_dim = self.inner_dims();
        let indices = &self.indices[..];
        let data_ptr: *mut N = self.data.as_mut_ptr();
        self.indptr.iter_outer_sz().map(move |range| {
            // SAFETY: with a valid (non-decreasing) indptr the ranges lie
            // inside `data` and never overlap, each is yielded once, and
            // `self` stays mutably borrowed while the iterator lives, so the
            // produced slices never alias.
            let data: &mut [N] = unsafe {
                std::slice::from_raw_parts_mut(
                    data_ptr.add(range.start),
                    range.end - range.start,
                )
            };
            CsVecViewMutI::new_trusted(inner_dim, &indices[range], data)
        })
    }

    /// Mutable view of this matrix.
    pub fn view_mut(&mut self) -> CsMatViewMutI<N, I, Iptr> {
        CsMatViewMutI {
            storage: self.storage,
            nrows: self.nrows,
            ncols: self.ncols,
            indptr: crate::IndPtrView::new_trusted(self.indptr.raw_storage()),
            indices: &self.indices[..],
            data: &mut self.data[..],
        }
    }

    /// Iterator over mutable references to the `min(rows, cols)` diagonal
    /// positions, yielding `None` for unstored ones.
    pub fn diag_iter_mut(
        &mut self,
    ) -> impl ExactSizeIterator<Item = Option<&mut N>>
           + DoubleEndedIterator<Item = Option<&mut N>>
           + '_ {
        let smallest_dim = cmp::min(self.ncols, self.nrows);
        // Locate every diagonal entry before taking the raw data pointer.
        // Looking them up lazily would re-borrow the whole of `self.data`
        // (through `outer_view`) while `&mut` items handed out earlier may
        // still be alive, which breaks Rust's aliasing rules.
        let positions: Vec<Option<usize>> = (0..smallest_dim)
            .map(|i| {
                self.nnz_index_outer_inner(i, i)
                    .map(|NnzIndex(idx)| idx)
            })
            .collect();
        let data_ptr: *mut N = self.data[..].as_mut_ptr();
        positions.into_iter().map(move |position| {
            // SAFETY: each position comes from `nnz_index_outer_inner`, so it
            // is in bounds of `data`; distinct diagonal coordinates map to
            // distinct positions, so yielded references never alias; and
            // `self` stays mutably borrowed while the iterator lives.
            position.map(|idx| unsafe { &mut *data_ptr.add(idx) })
        })
    }
}
impl<N, I, Iptr, IptrStorage, IndStorage, DataStorage>
    CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: DerefMut<Target = [Iptr]>,
    IndStorage: DerefMut<Target = [I]>,
    DataStorage: DerefMut<Target = [N]>,
{
    /// Hands `f` mutable access to the raw `indptr`, `indices` and `data`
    /// arrays, then re-validates the structure.
    ///
    /// # Panics
    ///
    /// Panics if `f` leaves the matrix in an invalid compressed state.
    pub fn modify<F>(&mut self, mut f: F)
    where
        F: FnMut(&mut [Iptr], &mut [I], &mut [N]),
    {
        let raw_indptr = self.indptr.raw_storage_mut();
        let raw_indices = &mut self.indices[..];
        let raw_data = &mut self.data[..];
        f(raw_indptr, raw_indices, raw_data);
        let validity = self.check_compressed_structure();
        validity.unwrap();
    }
}
/// Low-level routines working directly on raw compressed-storage arrays.
pub mod raw {
    use crate::indexing::SpIndex;
    use crate::sparse::prelude::*;
    use std::mem::swap;

    /// Converts `mat` to the opposite storage order, writing the result into
    /// the caller-provided `indptr`, `indices` and `data` buffers.
    ///
    /// The conversion is a counting sort over inner indices:
    ///
    /// 1. Count the entries of each inner index.
    /// 2. Turn the counts into start offsets.
    /// 3. Scatter every entry to its slot.
    /// 4. Shift the advanced cursors back into a proper `indptr`.
    ///
    /// # Panics
    ///
    /// Panics unless `indptr.len() == mat.inner_dims() + 1`, `indices` and
    /// `data` have as many elements as `mat` stores, and `indptr` is all
    /// zeros on entry.
    pub fn convert_mat_storage<N: Clone, I: SpIndex, Iptr: SpIndex>(
        mat: CsMatViewI<N, I, Iptr>,
        indptr: &mut [Iptr],
        indices: &mut [I],
        data: &mut [N],
    ) {
        assert_eq!(indptr.len(), mat.inner_dims() + 1);
        assert_eq!(indices.len(), mat.indices().len());
        assert_eq!(data.len(), mat.data().len());
        assert!(indptr.iter().all(|x| x.is_zero()));
        // Pass 1: count the entries of each inner index.
        for vec in mat.outer_iterator() {
            for (inner_dim, _) in vec.iter() {
                indptr[inner_dim] += Iptr::one();
            }
        }
        // Exclusive prefix sum: indptr[k] becomes the start offset of inner
        // index k. The final slot is never counted into, so it ends up
        // holding the total number of entries.
        let mut cumsum = Iptr::zero();
        for iptr in indptr.iter_mut() {
            let tmp = *iptr;
            *iptr = cumsum;
            cumsum += tmp;
        }
        if let Some(last_iptr) = indptr.last() {
            assert_eq!(last_iptr.index(), mat.nnz());
        }
        // Pass 2: scatter entries, using indptr[k] as the write cursor of
        // inner index k. Outer dimensions are visited in increasing order,
        // so the new indices (the old outer indices) come out sorted.
        for (outer_dim, vec) in mat.outer_iterator().enumerate() {
            for (inner_dim, val) in vec.iter() {
                let dest = indptr[inner_dim].index();
                data[dest] = val.clone();
                indices[dest] = I::from_usize_unchecked(outer_dim);
                indptr[inner_dim] += Iptr::one();
            }
        }
        // Each cursor now holds the end of its range (= start of the next).
        // Shift everything right by one slot and put 0 in front to recover
        // the start offsets.
        let mut last = Iptr::zero();
        for iptr in indptr.iter_mut() {
            swap(iptr, &mut last);
        }
    }
}
impl<'a, I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::MulAssign<T>
    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpStorage: 'a + Deref<Target = [Iptr]>,
    IStorage: 'a + Deref<Target = [I]>,
    DStorage: 'a + DerefMut<Target = [T]>,
    T: std::ops::MulAssign<T> + Clone,
{
    /// Multiplies every stored value by `rhs` in place; the sparsity
    /// structure is unchanged.
    fn mul_assign(&mut self, rhs: T) {
        for value in self.data_mut() {
            *value *= rhs.clone();
        }
    }
}
impl<'a, I, Iptr, IpStorage, IStorage, DStorage, T> std::ops::DivAssign<T>
    for CsMatBase<T, I, IpStorage, IStorage, DStorage, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpStorage: 'a + Deref<Target = [Iptr]>,
    IStorage: 'a + Deref<Target = [I]>,
    DStorage: 'a + DerefMut<Target = [T]>,
    T: std::ops::DivAssign<T> + Clone,
{
    /// Divides every stored value by `rhs` in place; the sparsity structure
    /// is unchanged.
    fn div_assign(&mut self, rhs: T) {
        for value in self.data_mut() {
            *value /= rhs.clone();
        }
    }
}
impl<'a, 'b, N, I, Iptr, IpS1, IS1, DS1, IpS2, IS2, DS2>
    Mul<&'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>>
    for &'a CsMatBase<N, I, IpS1, IS1, DS1, Iptr>
where
    N: 'a + Clone + crate::MulAcc + num_traits::Zero + Default + Send + Sync,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS1: 'a + Deref<Target = [Iptr]>,
    IS1: 'a + Deref<Target = [I]>,
    DS1: 'a + Deref<Target = [N]>,
    IpS2: 'b + Deref<Target = [Iptr]>,
    IS2: 'b + Deref<Target = [I]>,
    DS2: 'b + Deref<Target = [N]>,
{
    type Output = CsMatI<N, I, Iptr>;

    /// Sparse-sparse product; [`csmat_mul_csmat`] picks the strategy based
    /// on both operands' storage orders.
    fn mul(
        self,
        rhs: &'b CsMatBase<N, I, IpS2, IS2, DS2, Iptr>,
    ) -> Self::Output {
        csmat_mul_csmat(self, rhs)
    }
}
/// Multiplies two compressed sparse matrices, whatever their storage orders.
///
/// Every case is reduced to a CSR x CSR product. A CSC left operand is
/// handled through `(lhs * rhs)^T = rhs^T * lhs^T`, because the transpose
/// of a CSC matrix is a CSR matrix.
pub fn csmat_mul_csmat<
    'a,
    'b,
    N,
    A,
    B,
    I,
    Iptr,
    IpS1,
    IS1,
    DS1,
    IpS2,
    IS2,
    DS2,
>(
    lhs: &'a CsMatBase<A, I, IpS1, IS1, DS1, Iptr>,
    rhs: &'b CsMatBase<B, I, IpS2, IS2, DS2, Iptr>,
) -> CsMatI<N, I, Iptr>
where
    N: 'a
        + Clone
        + crate::MulAcc<A, B>
        + crate::MulAcc<B, A>
        + num_traits::Zero
        + Default
        + Send
        + Sync,
    A: 'a + Clone + num_traits::Zero + Default + Send + Sync,
    B: 'a + Clone + num_traits::Zero + Default + Send + Sync,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS1: 'a + Deref<Target = [Iptr]>,
    IS1: 'a + Deref<Target = [I]>,
    DS1: 'a + Deref<Target = [A]>,
    IpS2: 'b + Deref<Target = [Iptr]>,
    IS2: 'b + Deref<Target = [I]>,
    DS2: 'b + Deref<Target = [B]>,
{
    let orders = (lhs.storage(), rhs.storage());
    match orders {
        // Already CSR x CSR.
        (CSR, CSR) => smmp::mul_csr_csr(lhs.view(), rhs.view()),
        // Only the right operand needs converting to CSR.
        (CSR, CSC) => {
            let rhs_as_csr = rhs.to_other_storage();
            smmp::mul_csr_csr(lhs.view(), rhs_as_csr.view())
        }
        // Convert rhs to CSC so that its transpose is CSR, then evaluate
        // rhs^T * lhs^T and transpose back.
        (CSC, CSR) => {
            let rhs_as_csc = rhs.to_other_storage();
            let transposed_product = smmp::mul_csr_csr(
                rhs_as_csc.transpose_view(),
                lhs.transpose_view(),
            );
            transposed_product.transpose_into()
        }
        // Both transposes are already CSR.
        (CSC, CSC) => {
            let transposed_product = smmp::mul_csr_csr(
                rhs.transpose_view(),
                lhs.transpose_view(),
            );
            transposed_product.transpose_into()
        }
    }
}
impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Add<&'b ArrayBase<DS2, Ix2>>
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    N: 'a + Copy + Num + Default,
    for<'r> &'r N: Mul<Output = N>,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS: 'a + Deref<Target = [Iptr]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix2>;

    /// Sparse + dense addition producing a dense array.
    ///
    /// The sparse operand is converted first when its storage order does not
    /// match the dense operand's memory layout.
    fn add(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
        let rhs_is_row_major =
            utils::fastest_axis(rhs.view()) == ndarray::Axis(1);
        // Passed for both coefficient arguments (presumably scaling factors of
        // each operand; TODO confirm against `binop`).
        let one = N::one();
        let orders_match = match self.storage() {
            CSR => rhs_is_row_major,
            CSC => !rhs_is_row_major,
        };
        if orders_match {
            binop::add_dense_mat_same_ordering(self, rhs, one, one)
        } else {
            let converted = self.to_other_storage();
            binop::add_dense_mat_same_ordering(&converted, rhs, one, one)
        }
    }
}
impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix2>>
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    N: 'a + crate::MulAcc + num_traits::Zero + Clone,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS: 'a + Deref<Target = [Iptr]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix2>;

    /// Sparse x dense product producing a dense array.
    fn mul(self, rhs: &'b ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
        let shape = (self.rows(), rhs.shape()[1]);
        // Wide right-hand sides (8+ columns) get a row-major result, narrow
        // ones a column-major result; the kernel is chosen to match.
        let row_major = shape.1 >= 8;
        let mut res: Array<N, Ix2> = if row_major {
            Array::zeros(shape)
        } else {
            Array::zeros(shape.f())
        };
        match (self.storage(), row_major) {
            (CSR, true) => {
                prod::csr_mulacc_dense_rowmaj(
                    self.view(),
                    rhs.view(),
                    res.view_mut(),
                );
            }
            (CSR, false) => {
                prod::csr_mulacc_dense_colmaj(
                    self.view(),
                    rhs.view(),
                    res.view_mut(),
                );
            }
            (CSC, true) => {
                prod::csc_mulacc_dense_rowmaj(
                    self.view(),
                    rhs.view(),
                    res.view_mut(),
                );
            }
            (CSC, false) => {
                prod::csc_mulacc_dense_colmaj(
                    self.view(),
                    rhs.view(),
                    res.view_mut(),
                );
            }
        }
        res
    }
}
impl<'a, 'b, N, I, IpS, IS, DS, DS2> Dot<CsMatBase<N, I, IpS, IS, DS>>
    for ArrayBase<DS2, Ix2>
where
    N: 'a + Clone + crate::MulAcc + num_traits::Zero + std::fmt::Debug,
    I: 'a + SpIndex,
    IpS: 'a + Deref<Target = [I]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix2>;

    /// Dense x sparse product, evaluated as `(rhs^T * self^T)^T` so that the
    /// sparse operand sits on the left where the kernels expect it.
    fn dot(&self, rhs: &CsMatBase<N, I, IpS, IS, DS>) -> Array<N, Ix2> {
        let rhs_t = rhs.transpose_view();
        let lhs_t = self.t();
        let shape = (rhs_t.rows(), lhs_t.ncols());
        // Same layout heuristic as the sparse x dense product.
        let row_major = shape.1 >= 8;
        let mut res_t: Array<N, Ix2> = if row_major {
            Array::zeros(shape)
        } else {
            Array::zeros(shape.f())
        };
        match (rhs_t.storage(), row_major) {
            (CSR, true) => {
                prod::csr_mulacc_dense_rowmaj(rhs_t, lhs_t, res_t.view_mut());
            }
            (CSR, false) => {
                prod::csr_mulacc_dense_colmaj(rhs_t, lhs_t, res_t.view_mut());
            }
            (CSC, true) => {
                prod::csc_mulacc_dense_rowmaj(rhs_t, lhs_t, res_t.view_mut());
            }
            (CSC, false) => {
                prod::csc_mulacc_dense_colmaj(rhs_t, lhs_t, res_t.view_mut());
            }
        }
        let rres = res_t.reversed_axes();
        assert_eq!(self.shape()[0], rres.shape()[0]);
        assert_eq!(rhs.cols(), rres.shape()[1]);
        rres
    }
}
/// `sparse.dot(&dense)` for a dense matrix; identical to `&sparse * &dense`.
impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix2>>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS: 'a + Deref<Target = [Iptr]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix2>;
    fn dot(&self, rhs: &ArrayBase<DS2, Ix2>) -> Array<N, Ix2> {
        self * rhs
    }
}
/// Sparse matrix times dense vector: `&sparse * &vector`.
///
/// Both vectors are viewed as single-column matrices, and the column-major
/// sparse-times-dense kernel for the matrix's storage order is applied. The
/// result is a new vector of length `self.rows()`.
///
/// The column views are built with `insert_axis`. Unlike the previous
/// `into_shape` call, it accepts inputs of any stride, such as stepped or
/// reversed slices, and cannot fail.
///
/// # Panics
///
/// Panics if `rhs.len() != self.cols()`; the check happens inside the
/// product kernels.
impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Mul<&'b ArrayBase<DS2, Ix1>>
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS: 'a + Deref<Target = [Iptr]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix1>;
    fn mul(self, rhs: &'b ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
        let mut res: Array<N, Ix1> = Array::zeros(self.rows());
        {
            // (len,) -> (len, 1) views; valid for arbitrary strides.
            let rhs_col = rhs.view().insert_axis(ndarray::Axis(1));
            let res_col = res.view_mut().insert_axis(ndarray::Axis(1));
            match self.storage() {
                CSR => {
                    prod::csr_mulacc_dense_colmaj(self.view(), rhs_col, res_col)
                }
                CSC => {
                    prod::csc_mulacc_dense_colmaj(self.view(), rhs_col, res_col)
                }
            }
        }
        res
    }
}
/// `sparse.dot(&vector)` for a dense vector; identical to `&sparse * &vector`.
impl<'a, 'b, N, I, Iptr, IpS, IS, DS, DS2> Dot<ArrayBase<DS2, Ix1>>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    N: 'a + Clone + crate::MulAcc + num_traits::Zero,
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    IpS: 'a + Deref<Target = [Iptr]>,
    IS: 'a + Deref<Target = [I]>,
    DS: 'a + Deref<Target = [N]>,
    DS2: 'b + ndarray::Data<Elem = N>,
{
    type Output = Array<N, Ix1>;
    fn dot(&self, rhs: &ArrayBase<DS2, Ix1>) -> Array<N, Ix1> {
        self * rhs
    }
}
/// Read access to a stored entry by `[row, col]` position.
///
/// # Panics
///
/// Panics if no entry is stored at `(row, col)`. This covers both an
/// out-of-bounds position and a structural zero inside the matrix. Use
/// `get` for a non-panicking lookup.
impl<N, I, Iptr, IpS, IS, DS> Index<[usize; 2]>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    type Output = N;
    fn index(&self, index: [usize; 2]) -> &N {
        let [i, j] = index;
        self.get(i, j).unwrap_or_else(|| {
            panic!(
                "no stored entry at ({}, {}): out of bounds or structural zero",
                i, j
            )
        })
    }
}
/// Mutable access to a stored entry by `[row, col]` position.
///
/// Only existing entries can be modified; this never inserts a new one.
///
/// # Panics
///
/// Panics if no entry is stored at `(row, col)`. This covers both an
/// out-of-bounds position and a structural zero inside the matrix.
impl<N, I, Iptr, IpS, IS, DS> IndexMut<[usize; 2]>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: DerefMut<Target = [N]>,
{
    fn index_mut(&mut self, index: [usize; 2]) -> &mut N {
        let [i, j] = index;
        self.get_mut(i, j).unwrap_or_else(|| {
            panic!(
                "no stored entry at ({}, {}): out of bounds or structural zero",
                i, j
            )
        })
    }
}
/// Read access to a stored entry by its position in the data array, as
/// returned by `nnz_index`.
///
/// # Panics
///
/// Panics if the index is not smaller than the number of stored entries.
impl<N, I, Iptr, IpS, IS, DS> Index<NnzIndex>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    type Output = N;
    fn index(&self, index: NnzIndex) -> &N {
        let NnzIndex(i) = index;
        // Slice indexing reports the offending index and the length on failure.
        &self.data()[i]
    }
}
/// Mutable access to a stored entry by its position in the data array, as
/// returned by `nnz_index`.
///
/// # Panics
///
/// Panics if the index is not smaller than the number of stored entries.
impl<N, I, Iptr, IpS, IS, DS> IndexMut<NnzIndex>
    for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: DerefMut<Target = [N]>,
{
    fn index_mut(&mut self, index: NnzIndex) -> &mut N {
        let NnzIndex(i) = index;
        // Slice indexing reports the offending index and the length on failure.
        &mut self.data_mut()[i]
    }
}
/// Exposes a compressed matrix's dimensions and stored-entry count through the
/// generic `SparseMat` trait.
impl<N, I, Iptr, IpS, IS, DS> SparseMat for CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: SpIndex,
    Iptr: SpIndex,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    // Each method forwards to the inherent method of the same name.
    // Method-call resolution prefers inherent methods over trait methods, so
    // these calls do not recurse into the trait impl.
    fn rows(&self) -> usize {
        self.rows()
    }
    fn cols(&self) -> usize {
        self.cols()
    }
    fn nnz(&self) -> usize {
        self.nnz()
    }
}
/// `SparseMat` for a borrowed compressed matrix, so generic code can accept
/// `&CsMatBase` directly.
impl<'a, N, I, Iptr, IpS, IS, DS> SparseMat
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    // `self` is `&&CsMatBase`; dereferencing once gives `&CsMatBase`. That
    // receiver resolves to the inherent methods rather than to this impl.
    fn rows(&self) -> usize {
        (*self).rows()
    }
    fn cols(&self) -> usize {
        (*self).cols()
    }
    fn nnz(&self) -> usize {
        (*self).nnz()
    }
}
/// Allows `for (val, (row, col)) in &mat { ... }`: iterates over the stored
/// entries exactly like `mat.iter()`.
impl<'a, N, I, IpS, IS, DS, Iptr> IntoIterator
    for &'a CsMatBase<N, I, IpS, IS, DS, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
    IpS: Deref<Target = [Iptr]>,
    IS: Deref<Target = [I]>,
    DS: Deref<Target = [N]>,
{
    type Item = (&'a N, (I, I));
    type IntoIter = CsIter<'a, N, I, Iptr>;
    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// Iterates over the stored entries of a matrix view, consuming the view.
impl<'a, N, I, Iptr> IntoIterator for CsMatViewI<'a, N, I, Iptr>
where
    I: 'a + SpIndex,
    Iptr: 'a + SpIndex,
    N: 'a,
{
    type Item = (&'a N, (I, I));
    type IntoIter = CsIter<'a, N, I, Iptr>;
    fn into_iter(self) -> Self::IntoIter {
        // The view is consumed here, so the iterator must borrow the data for
        // the view's lifetime `'a`, not for this call's `&self` borrow.
        // `iter_rbr` is used for that reason (its return type carries `'a`).
        self.iter_rbr()
    }
}
#[cfg(test)]
mod test {
    use super::CompressedStorage::CSR;
    use crate::errors::StructureErrorKind;
    use crate::sparse::{CsMat, CsMatI, CsMatView, CsVec};
    use crate::test_data::{mat1, mat1_csc, mat1_times_2};
    use ndarray::{arr2, Array};

    #[test]
    fn test_copy() {
        let m = mat1();
        let view1 = m.view();
        // Views are `Copy`: `view1` is still usable after the copy.
        let view2 = view1;
        assert_eq!(view1, view2);
    }

    #[test]
    fn test_new_csr_success() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let m = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_ok);
        assert!(m.is_ok());
    }

    #[test]
    #[should_panic]
    fn test_new_csr_bad_indptr_length() {
        let indptr_fail1: &[usize] = &[0, 1, 2];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_fail1, indices_ok, data_ok);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_out_of_bounds_index() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indices_fail2: &[usize] = &[0, 1, 4];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail2, data_ok);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_bad_nnz_count() {
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indptr_fail2: &[usize] = &[0, 1, 2, 4];
        let res = CsMatView::try_new((3, 3), indptr_fail2, indices_ok, data_ok);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch1() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indices_fail1: &[usize] = &[0, 1];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_fail1, data_ok);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch2() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_fail1: &[f64] = &[1., 1., 1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail1);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    #[should_panic]
    fn test_new_csr_data_indices_mismatch3() {
        let indptr_ok: &[usize] = &[0, 1, 2, 3];
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_fail2: &[f64] = &[1., 1.];
        let res = CsMatView::try_new((3, 3), indptr_ok, indices_ok, data_fail2);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    fn test_new_csr_fails() {
        let indices_ok: &[usize] = &[0, 1, 2];
        let data_ok: &[f64] = &[1., 1., 1.];
        let indptr_fail3: &[usize] = &[0, 2, 1, 3];
        assert_eq!(
            CsMatView::try_new((3, 3), indptr_fail3, indices_ok, data_ok)
                .unwrap_err()
                .3
                .kind(),
            StructureErrorKind::Unsorted
        );
    }

    #[test]
    fn test_new_csr_fail_indices_ordering() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        let indices: &[usize] = &[3, 2, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];
        assert_eq!(
            CsMatView::try_new((5, 5), indptr, indices, data)
                .unwrap_err()
                .3
                .kind(),
            StructureErrorKind::Unsorted
        );
    }

    #[test]
    fn test_new_csr_csc_success() {
        let indptr_ok: &[usize] = &[0, 2, 5, 6];
        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
        let data_ok: &[f64] = &[
            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
            0.46202352,
        ];
        assert!(
            CsMatView::try_new((3, 4), indptr_ok, indices_ok, data_ok).is_ok()
        );
        assert!(
            CsMatView::try_new_csc((4, 3), indptr_ok, indices_ok, data_ok)
                .is_ok()
        );
    }

    #[test]
    #[should_panic]
    fn test_new_csc_bad_indptr_length() {
        let indptr_ok: &[usize] = &[0, 2, 5, 6];
        let indices_ok: &[usize] = &[2, 3, 1, 2, 3, 3];
        let data_ok: &[f64] = &[
            0.05734571, 0.15543348, 0.75628258, 0.83054515, 0.71851547,
            0.46202352,
        ];
        let res =
            CsMatView::try_new_csc((3, 4), indptr_ok, indices_ok, data_ok);
        // Construction must fail, so unwrapping panics.
        res.unwrap();
    }

    #[test]
    fn test_new_csr_vec_borrowed() {
        let indptr_ok = vec![0, 1, 2, 3];
        let indices_ok = vec![0, 1, 2];
        let data_ok: Vec<f64> = vec![1., 1., 1.];
        assert!(
            CsMatView::try_new((3, 3), &indptr_ok, &indices_ok, &data_ok)
                .is_ok()
        );
    }

    #[test]
    fn test_new_csr_vec_owned() {
        let indptr_ok = vec![0, 1, 2, 3];
        let indices_ok = vec![0, 1, 2];
        let data_ok: Vec<f64> = vec![1., 1., 1.];
        assert!(CsMat::new_from_unsorted(
            (3, 3),
            indptr_ok,
            indices_ok,
            data_ok
        )
        .is_ok());
    }

    #[test]
    fn test_csr_from_dense() {
        let m = Array::eye(3);
        let m_sparse = CsMat::csr_from_dense(m.view(), 0.);
        assert_eq!(m_sparse, CsMat::eye(3));
        let m = arr2(&[
            [1., 0., 2., 1e-7, 1.],
            [0., 0., 0., 1., 0.],
            [3., 0., 1., 0., 0.],
        ]);
        let m_sparse = CsMat::csr_from_dense(m.view(), 1e-5);
        let expected_output = CsMat::new(
            (3, 5),
            vec![0, 3, 4, 6],
            vec![0, 2, 4, 3, 0, 2],
            vec![1., 2., 1., 1., 3., 1.],
        );
        assert_eq!(m_sparse, expected_output);
    }

    #[test]
    fn test_csc_from_dense() {
        let m = Array::eye(3);
        let m_sparse = CsMat::csc_from_dense(m.view(), 0.);
        assert_eq!(m_sparse, CsMat::eye_csc(3));
        let m = arr2(&[
            [1., 0., 2., 1e-7, 1.],
            [0., 0., 0., 1., 0.],
            [3., 0., 1., 0., 0.],
        ]);
        let m_sparse = CsMat::csc_from_dense(m.view(), 1e-5);
        let expected_output = CsMat::new_csc(
            (3, 5),
            vec![0, 2, 2, 4, 5, 6],
            vec![0, 2, 0, 2, 1, 0],
            vec![1., 3., 2., 1., 1., 1.],
        );
        assert_eq!(m_sparse, expected_output);
    }

    #[test]
    fn owned_csr_unsorted_indices() {
        let indptr = vec![0, 3, 3, 5, 6, 7];
        let indices_sorted = &[1, 2, 3, 2, 3, 4, 4];
        let indices_shuffled = vec![1, 3, 2, 2, 3, 4, 4];
        let mut data: Vec<i32> = (0..7).collect();
        let m = CsMat::new_from_unsorted(
            (5, 5),
            indptr,
            indices_shuffled,
            data.clone(),
        )
        .unwrap();
        assert_eq!(m.indices(), indices_sorted);
        data.swap(1, 2);
        assert_eq!(m.data(), &data[..]);
    }

    #[test]
    fn new_csr_with_empty_row() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];
        assert!(CsMatView::try_new((5, 5), indptr, indices, data).is_ok());
    }

    #[test]
    fn csr_to_csc() {
        let a = mat1();
        let a_csc_ground_truth = mat1_csc();
        let a_csc = a.to_other_storage();
        assert_eq!(a_csc, a_csc_ground_truth);
    }

    #[test]
    fn test_self_smul() {
        let mut a = mat1();
        a.scale(2.);
        let c_true = mat1_times_2();
        assert_eq!(a.indptr(), c_true.indptr());
        assert_eq!(a.indices(), c_true.indices());
        assert_eq!(a.data(), c_true.data());
    }

    #[test]
    fn outer_block_iter() {
        let mat: CsMat<f64> = CsMat::eye(11);
        let mut block_iter = mat.outer_block_iter(3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 3);
        assert_eq!(block_iter.next().unwrap().rows(), 2);
        assert_eq!(block_iter.next(), None);
        let mut block_iter = mat.outer_block_iter(4);
        assert_eq!(block_iter.next().unwrap().cols(), 11);
        block_iter.next().unwrap();
        block_iter.next().unwrap();
        assert_eq!(block_iter.next(), None);
    }

    #[test]
    fn middle_outer_views() {
        let size = 11;
        let csr: CsMat<f64> = CsMat::eye(size);
        #[allow(deprecated)]
        let v = csr.view().middle_outer_views(1, 3);
        assert_eq!(v.shape(), (3, size));
        assert_eq!(v.nnz(), 3);
        let csc = csr.to_other_storage();
        #[allow(deprecated)]
        let v = csc.view().middle_outer_views(1, 3);
        assert_eq!(v.shape(), (size, 3));
        assert_eq!(v.nnz(), 3);
    }

    #[test]
    fn nnz_index() {
        let mat: CsMat<f64> = CsMat::eye(11);
        assert_eq!(mat.nnz_index(2, 3), None);
        assert_eq!(mat.nnz_index(5, 7), None);
        assert_eq!(mat.nnz_index(0, 11), None);
        assert_eq!(mat.nnz_index(0, 0), Some(super::NnzIndex(0)));
        assert_eq!(mat.nnz_index(7, 7), Some(super::NnzIndex(7)));
        assert_eq!(mat.nnz_index(10, 10), Some(super::NnzIndex(10)));
        let index = mat.nnz_index(8, 8).unwrap();
        assert_eq!(mat[index], 1.);
        let mut mat = mat;
        mat[index] = 2.;
        assert_eq!(mat[index], 2.);
    }

    #[test]
    fn index() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 2., 3., 4.],
        );
        assert_eq!(mat[[1, 0]], 1.);
        assert_eq!(mat[[0, 1]], 2.);
        assert_eq!(mat[[2, 1]], 3.);
        assert_eq!(mat[[2, 2]], 4.);
        assert_eq!(mat.get(0, 0), None);
        assert_eq!(mat.get(4, 4), None);
    }

    #[test]
    fn get_mut() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        *mat.get_mut(2, 1).unwrap() = 3.;
        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 1.],
        );
        assert_eq!(mat, exp);
        mat[[2, 2]] = 5.;
        let exp = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1., 1., 3., 5.],
        );
        assert_eq!(mat, exp);
    }

    #[test]
    fn map() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        let mut res = mat.map(|&x| x + 2.);
        let expected = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![3.; 4],
        );
        assert_eq!(res, expected);
        res.map_inplace(|&x| x / 3.);
        assert_eq!(res, mat);
    }

    #[test]
    fn insert() {
        let mut mat = CsMat::empty(CSR, 0);
        mat.reserve_outer_dim(3);
        mat.reserve_nnz(4);
        mat.insert(0, 1, 1.);
        mat.insert(1, 0, 1.);
        mat.insert(2, 1, 1.);
        mat.insert(2, 2, 1.);
        let expected =
            CsMat::new((3, 3), vec![0, 1, 2, 4], vec![1, 0, 1, 2], vec![1.; 4]);
        assert_eq!(mat, expected);
        mat.insert(0, 0, 2.);
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 1., 1., 1.],
        );
        assert_eq!(mat, expected);
        mat.insert(1, 0, 3.);
        let expected = CsMat::new(
            (3, 3),
            vec![0, 2, 3, 5],
            vec![0, 1, 0, 1, 2],
            vec![2., 1., 3., 1., 1.],
        );
        assert_eq!(mat, expected);
    }

    #[test]
    fn bug_129() {
        let mut mat = CsMat::zero((3, 100));
        mat.insert(2, 3, 42);
        let mut iter = mat.iter();
        assert_eq!(iter.next(), Some((&42, (2, 3))));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn iter_mut() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        for mut col_vec in mat.outer_iterator_mut() {
            for (row_ind, val) in col_vec.iter_mut() {
                *val = row_ind as f64 + 1.;
            }
        }
        let expected = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![2., 1., 3., 3.],
        );
        assert_eq!(mat, expected);
    }

    #[test]
    #[should_panic]
    fn modify_fail() {
        let mut mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        mat.modify(|indptr, indices, data| {
            indptr[1] = 2;
            indptr[2] = 4;
            indices[0] = 0;
            indices[1] = 1;
            data[2] = 2.;
        });
    }

    #[test]
    fn convert_types() {
        let mat: CsMat<f32> = CsMat::eye(3);
        let mat_: CsMatI<f64, u32> = mat.to_other_types();
        assert_eq!(mat_.indptr(), &[0, 1, 2, 3][..]);
        let mat = CsMatI::new_csc(
            (3, 3),
            vec![0u32, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        let mat_: CsMatI<f32, usize, u32> = mat.to_other_types();
        assert_eq!(mat_.indptr(), &[0, 1, 3, 4][..]);
        assert_eq!(mat_.data(), &[1.0f32, 1., 1., 1.]);
    }

    #[test]
    fn iter() {
        let mat = CsMat::new_csc(
            (3, 3),
            vec![0, 1, 3, 4],
            vec![1, 0, 2, 2],
            vec![1.; 4],
        );
        let mut iter = mat.iter();
        assert_eq!(iter.next(), Some((&1., (1, 0))));
        assert_eq!(iter.next(), Some((&1., (0, 1))));
        assert_eq!(iter.next(), Some((&1., (2, 1))));
        assert_eq!(iter.next(), Some((&1., (2, 2))));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn degrees() {
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );
        let degrees = mat.degrees();
        assert_eq!(&degrees, &[2, 0, 1, 2, 1]);
    }

    #[test]
    fn diag() {
        let mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );
        let diag = mat.diag();
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
        assert_eq!(diag, expected);
        let mut iter = mat.diag_iter();
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&2));
        assert_eq!(iter.next().unwrap(), None);
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn diag_mut() {
        let mut mat = CsMat::new_csc(
            (5, 5),
            vec![0, 3, 4, 5, 8, 10],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1],
        );
        let mut diags = mat.diag_iter_mut().collect::<Vec<_>>();
        if let Some(x) = diags[4].as_mut() {
            **x *= 3;
        }
        if let Some(x) = diags[3].as_mut() {
            **x -= 4;
        }
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, -3, 3]);
        assert_eq!(mat.diag(), expected);
    }

    #[test]
    fn diag_rectangular() {
        let mat = CsMat::new_csc(
            (5, 6),
            vec![0, 3, 4, 5, 8, 10, 12],
            vec![0, 3, 4, 1, 3, 0, 2, 3, 0, 4, 0, 2],
            vec![1, 3, 1, 2, 1, 3, 1, 1, 1, 1, 3, 1],
        );
        let diag = mat.diag();
        let expected = CsVec::new(5, vec![0, 1, 3, 4], vec![1, 2, 1, 1]);
        assert_eq!(diag, expected);
        let mut iter = mat.diag_iter();
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&2));
        assert_eq!(iter.next().unwrap(), None);
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next().unwrap(), Some(&1));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn onehot_zero() {
        let onehot: CsMat<f32> = CsMat::zero((3, 3)).to_inner_onehot();
        assert!(onehot.is_csr());
        assert_eq!(CsMat::zero((3, 3)), onehot);
    }

    #[test]
    fn onehot_eye() {
        let mat = CsMat::new(
            (2, 2),
            vec![0, 2, 4],
            vec![0, 1, 0, 1],
            vec![2.0, 0.0, 0.0, 2.0],
        );
        let onehot = mat.to_inner_onehot();
        assert!(onehot.is_csr());
        assert_eq!(CsMat::eye(2), onehot);
    }

    #[test]
    fn onehot_sparse_csc() {
        let mat = CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![2.0]);
        let onehot = mat.to_inner_onehot();
        let expected =
            CsMat::new_csc((2, 3), vec![0, 0, 1, 1], vec![1], vec![1.0]);
        assert!(onehot.is_csc());
        assert_eq!(expected, onehot);
    }

    #[test]
    fn onehot_ignores_nan() {
        let mat = CsMat::new(
            (2, 2),
            vec![0, 2, 3],
            vec![0, 1, 1],
            vec![2.0, f64::NAN, 2.0],
        );
        let onehot = mat.to_inner_onehot();
        assert!(onehot.is_csr());
        assert_eq!(CsMat::eye(2), onehot);
    }

    #[test]
    fn mul_assign() {
        let mut m1 = crate::TriMat::new((6, 9));
        m1.add_triplet(1, 1, 8_i32);
        m1.add_triplet(1, 2, 7);
        m1.add_triplet(0, 1, 6);
        m1.add_triplet(0, 8, 5);
        m1.add_triplet(4, 2, 4);
        let mut m1: CsMat<_> = m1.to_csr();
        m1 *= 2;
        for (&v, (j, i)) in m1.iter() {
            match (j, i) {
                (1, 1) => assert_eq!(v, 16),
                (1, 2) => assert_eq!(v, 14),
                (0, 1) => assert_eq!(v, 12),
                (0, 8) => assert_eq!(v, 10),
                (4, 2) => assert_eq!(v, 8),
                _ => panic!(),
            }
        }
    }

    #[test]
    fn div_assign() {
        let mut m1 = crate::TriMat::new((6, 9));
        m1.add_triplet(1, 1, 8_i32);
        m1.add_triplet(1, 2, 7);
        m1.add_triplet(0, 1, 6);
        m1.add_triplet(0, 8, 5);
        m1.add_triplet(4, 2, 4);
        let mut m1: CsMat<_> = m1.to_csr();
        m1 /= 2;
        for (&v, (j, i)) in m1.iter() {
            match (j, i) {
                (1, 1) => assert_eq!(v, 4),
                (1, 2) => assert_eq!(v, 3),
                (0, 1) => assert_eq!(v, 3),
                (0, 8) => assert_eq!(v, 2),
                (4, 2) => assert_eq!(v, 2),
                _ => panic!(),
            }
        }
    }

    #[test]
    fn issue_99() {
        let a = crate::TriMat::<i32>::new((10, 1)).to_csc::<usize>();
        let b = crate::TriMat::<i32>::new((1, 9)).to_csr();
        let _c = &a * &b;
    }
}
#[cfg(feature = "approx")]
mod approx_impls {
    use super::*;
    use approx::*;

    // All three comparisons share the same structure:
    // - matrices of different shapes are never equal;
    // - with the same storage order, corresponding outer vectors are compared;
    // - with different storage orders, every stored entry of each matrix is
    //   compared against the other matrix's value at the same position. A
    //   missing entry counts as zero.
    // Checking both directions ensures entries stored only in `other` are
    // also compared.

    /// Absolute-difference comparison of two compressed matrices. The storage
    /// orders may differ; see the comment at the top of this module.
    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        AbsDiffEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: AbsDiffEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        type Epsilon = N::Epsilon;
        fn default_epsilon() -> N::Epsilon {
            N::default_epsilon()
        }
        fn abs_diff_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.abs_diff_eq(&r2, epsilon.clone()))
            } else {
                // Built once, instead of once per compared entry.
                let zero = N::zero();
                let all_matching = self.iter().all(|(n, (i, j))| {
                    n.abs_diff_eq(
                        other
                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                    )
                });
                if !all_matching {
                    return false;
                }
                other.iter().all(|(n, (i, j))| {
                    n.abs_diff_eq(
                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                    )
                })
            }
        }
    }

    /// ULPs-based comparison of two compressed matrices. The storage orders
    /// may differ; see the comment at the top of this module.
    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        UlpsEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: UlpsEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_ulps() -> u32 {
            N::default_max_ulps()
        }
        fn ulps_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_ulps: u32,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator()
                    .zip(other.outer_iterator())
                    .all(|(r1, r2)| r1.ulps_eq(&r2, epsilon.clone(), max_ulps))
            } else {
                // Built once, instead of once per compared entry.
                let zero = N::zero();
                let all_matches = self.iter().all(|(n, (i, j))| {
                    n.ulps_eq(
                        other
                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                        max_ulps,
                    )
                });
                if !all_matches {
                    return false;
                }
                other.iter().all(|(n, (i, j))| {
                    n.ulps_eq(
                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                        max_ulps,
                    )
                })
            }
        }
    }

    /// Relative-difference comparison of two compressed matrices. The storage
    /// orders may differ; see the comment at the top of this module.
    impl<N, I, Iptr, IS1, DS1, ISptr1, IS2, ISptr2, DS2>
        RelativeEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>
        for CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>
    where
        I: SpIndex,
        Iptr: SpIndex,
        CsMatBase<N, I, ISptr1, IS1, DS1, Iptr>:
            std::cmp::PartialEq<CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>>,
        IS1: Deref<Target = [I]>,
        IS2: Deref<Target = [I]>,
        ISptr1: Deref<Target = [Iptr]>,
        ISptr2: Deref<Target = [Iptr]>,
        DS1: Deref<Target = [N]>,
        DS2: Deref<Target = [N]>,
        N: RelativeEq,
        N::Epsilon: Clone,
        N: num_traits::Zero,
    {
        fn default_max_relative() -> N::Epsilon {
            N::default_max_relative()
        }
        fn relative_eq(
            &self,
            other: &CsMatBase<N, I, ISptr2, IS2, DS2, Iptr>,
            epsilon: N::Epsilon,
            max_relative: Self::Epsilon,
        ) -> bool {
            if self.shape() != other.shape() {
                return false;
            }
            if self.storage() == other.storage() {
                self.outer_iterator().zip(other.outer_iterator()).all(
                    |(r1, r2)| {
                        r1.relative_eq(
                            &r2,
                            epsilon.clone(),
                            max_relative.clone(),
                        )
                    },
                )
            } else {
                // Built once, instead of once per compared entry.
                let zero = N::zero();
                let all_matches = self.iter().all(|(n, (i, j))| {
                    n.relative_eq(
                        other
                            .get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                        max_relative.clone(),
                    )
                });
                if !all_matches {
                    return false;
                }
                other.iter().all(|(n, (i, j))| {
                    n.relative_eq(
                        self.get(i.to_usize().unwrap(), j.to_usize().unwrap())
                            .unwrap_or(&zero),
                        epsilon.clone(),
                        max_relative.clone(),
                    )
                })
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use crate::*;

        #[test]
        fn different_shapes() {
            let mut m1 = TriMat::new((3, 2));
            m1.add_triplet(1, 1, 8_u8);
            let m1: CsMat<_> = m1.to_csr();
            let mut m2 = TriMat::new((2, 3));
            m2.add_triplet(1, 1, 8_u8);
            let m2 = m2.to_csr();
            ::approx::assert_abs_diff_ne!(m1, m2);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc());
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8_u8);
            m1.add_triplet(1, 2, 7_u8);
            m1.add_triplet(0, 1, 6_u8);
            m1.add_triplet(0, 8, 5_u8);
            m1.add_triplet(4, 2, 4_u8);
            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();
            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0
            );
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);
            let m1: CsMat<_> = m1.to_csr();
            let m2 = m1.clone();
            ::approx::assert_abs_diff_eq!(m1, m2);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc());
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2.to_csc());
            ::approx::assert_relative_eq!(m1, m2);
            ::approx::assert_relative_eq!(m1.to_csc(), m2);
            ::approx::assert_relative_eq!(m1, m2.to_csc());
            ::approx::assert_relative_eq!(m1.to_csc(), m2.to_csc());
            ::approx::assert_ulps_eq!(m1, m2);
            ::approx::assert_ulps_eq!(m1.to_csc(), m2);
            ::approx::assert_ulps_eq!(m1, m2.to_csc());
            ::approx::assert_ulps_eq!(m1.to_csc(), m2.to_csc());
        }

        #[test]
        fn almost_equal_elements() {
            let mut m1 = TriMat::new((6, 9));
            m1.add_triplet(1, 1, 8.0_f32);
            m1.add_triplet(1, 2, 7.0);
            m1.add_triplet(0, 1, 6.0);
            m1.add_triplet(0, 8, 5.0);
            m1.add_triplet(4, 2, 4.0);
            let m1: CsMat<_> = m1.to_csr();
            let mut m2 = TriMat::new((6, 9));
            m2.add_triplet(1, 1, 8.0_f32);
            // Differs from m1 by 0.5.
            m2.add_triplet(1, 2, 7.0 - 0.5);
            m2.add_triplet(0, 1, 6.0);
            m2.add_triplet(0, 8, 5.0);
            m2.add_triplet(4, 2, 4.0);
            // Not stored in m1, so it is compared against zero.
            m2.add_triplet(4, 3, 0.2);
            let m2 = m2.to_csr();
            ::approx::assert_abs_diff_eq!(m1, m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1.to_csc(), m2, epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(m1, m2.to_csc(), epsilon = 0.6);
            ::approx::assert_abs_diff_eq!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.6
            );
            ::approx::assert_abs_diff_ne!(m1, m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1.to_csc(), m2, epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(m1, m2.to_csc(), epsilon = 0.4);
            ::approx::assert_abs_diff_ne!(
                m1.to_csc(),
                m2.to_csc(),
                epsilon = 0.4
            );
        }
    }
}