use crate::array_backend::Array2;
use crate::errors::StructureError;
use crate::indexing::SpIndex;
use crate::IndPtrBase;
use std::ops::Deref;
#[cfg(feature = "serde")]
mod serde_traits;
#[cfg(feature = "serde")]
use serde_traits::{CsMatBaseShadow, CsVecBaseShadow, Deserialize, Serialize};
pub use self::csmat::CompressedStorage;
/// Compressed sparse matrix, generic over the containers holding its index
/// pointer, indices and values.
///
/// Each container only needs to `Deref` to a slice, so the same type covers
/// owned matrices (`Vec`-backed, see [`CsMatI`]) and borrowed views (see
/// [`CsMatViewI`] and [`CsMatViewMutI`]).
///
/// Type parameters:
/// - `N`: value type stored in `data`;
/// - `I`: index type stored in `indices`;
/// - `Iptr`: index-pointer type stored in `indptr` (defaults to `I`).
///
/// With the `serde` feature, serialization is derived from the fields, and
/// the fields of `indptr` are inlined through `serde(flatten)`.
/// Deserialization first builds a `CsMatBaseShadow` and then converts it with
/// `TryFrom`, so a failed conversion surfaces as a deserialization error.
#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash)]
// `Serialize` is derived alongside `Deserialize`, mirroring `CsVecBase`.
// With `try_from`, the derived `Deserialize` goes through the shadow type and
// ignores field attributes. The `flatten` on `indptr` therefore only has an
// effect through the derived `Serialize`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(
        try_from = "CsMatBaseShadow<N, I, IptrStorage, IndStorage, DataStorage, Iptr>"
    )
)]
pub struct CsMatBase<N, I, IptrStorage, IndStorage, DataStorage, Iptr = I>
where
    I: SpIndex,
    Iptr: SpIndex,
    IptrStorage: Deref<Target = [Iptr]>,
    IndStorage: Deref<Target = [I]>,
    DataStorage: Deref<Target = [N]>,
{
    /// Compressed storage layout of the matrix.
    storage: CompressedStorage,
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Offsets delimiting each outer slice of `indices` (and, presumably,
    /// of `data`).
    #[cfg_attr(feature = "serde", serde(flatten))]
    indptr: IndPtrBase<Iptr, IptrStorage>,
    /// Inner indices of the stored entries.
    indices: IndStorage,
    /// Values of the stored entries, presumably parallel to `indices`.
    data: DataStorage,
}
/// Owned compressed sparse matrix with configurable index types.
pub type CsMatI<N, I, Iptr = I> =
    CsMatBase<N, I, Vec<Iptr>, Vec<I>, Vec<N>, Iptr>;
/// Read-only view of a compressed sparse matrix, borrowing all of its storage.
pub type CsMatViewI<'a, N, I, Iptr = I> =
    CsMatBase<N, I, &'a [Iptr], &'a [I], &'a [N], Iptr>;
/// View of a compressed sparse matrix whose values are mutable while its
/// structure (`indptr`, `indices`) is borrowed immutably.
pub type CsMatViewMutI<'a, N, I, Iptr = I> =
    CsMatBase<N, I, &'a [Iptr], &'a [I], &'a mut [N], Iptr>;
/// Matrix whose index pointer lives in an owned `Array2<Iptr>` while its
/// indices and values are borrowed.
// NOTE(review): presumably used to present a sparse vector as a matrix with
// a single outer slice; confirm against `crate::array_backend`.
pub type CsMatVecView_<'a, N, I, Iptr = I> =
    CsMatBase<N, I, Array2<Iptr>, &'a [I], &'a [N], Iptr>;
/// [`CsMatI`] with `usize` indices.
pub type CsMat<N> = CsMatI<N, usize>;
/// [`CsMatViewI`] with `usize` indices.
pub type CsMatView<'a, N> = CsMatViewI<'a, N, usize>;
/// [`CsMatViewMutI`] with `usize` indices.
pub type CsMatViewMut<'a, N> = CsMatViewMutI<'a, N, usize>;
/// [`CsMatVecView_`] with `usize` indices.
pub type CsMatVecView<'a, N> = CsMatVecView_<'a, N, usize>;
/// Borrowed sparsity pattern: a matrix view whose value type is `()`.
pub type CsStructureViewI<'a, I, Iptr = I> = CsMatViewI<'a, (), I, Iptr>;
/// [`CsStructureViewI`] with `usize` indices.
pub type CsStructureView<'a> = CsStructureViewI<'a, usize>;
/// Owned sparsity pattern: a matrix whose value type is `()`.
pub type CsStructureI<I, Iptr = I> = CsMatI<(), I, Iptr>;
/// [`CsStructureI`] with `usize` indices.
pub type CsStructure = CsStructureI<usize>;
/// Sparse vector, generic over the containers holding its indices and values.
///
/// `IStorage` and `DStorage` only need to `Deref` to slices, so the same type
/// covers owned vectors ([`CsVecI`]) and borrowed views ([`CsVecViewI`],
/// [`CsVecViewMutI`]). The index type `I` defaults to `usize`.
///
/// With the `serde` feature, deserialization first builds a
/// `CsVecBaseShadow` and converts it with `TryFrom`, so a failed conversion
/// surfaces as a deserialization error.
#[derive(Eq, PartialEq, Debug, Copy, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde",
    serde(try_from = "CsVecBaseShadow<IStorage, DStorage, N, I>")
)]
pub struct CsVecBase<IStorage, DStorage, N, I: SpIndex = usize>
where
    IStorage: Deref<Target = [I]>,
    DStorage: Deref<Target = [N]>,
{
    /// Dimension of the vector.
    dim: usize,
    /// Positions of the stored entries.
    indices: IStorage,
    /// Values of the stored entries, presumably parallel to `indices`.
    data: DStorage,
}
/// Owned sparse vector with index type `I` (defaults to `usize`).
pub type CsVecI<N, I = usize> = CsVecBase<Vec<I>, Vec<N>, N, I>;
/// Read-only sparse vector view borrowing its indices and values.
pub type CsVecViewI<'a, N, I = usize> = CsVecBase<&'a [I], &'a [N], N, I>;
/// Sparse vector view with mutable values and immutably borrowed indices.
pub type CsVecViewMutI<'a, N, I = usize> =
    CsVecBase<&'a [I], &'a mut [N], N, I>;
/// [`CsVecViewI`] with `usize` indices.
pub type CsVecView<'a, N> = CsVecViewI<'a, N>;
/// [`CsVecViewMutI`] with `usize` indices.
pub type CsVecViewMut<'a, N> = CsVecViewMutI<'a, N>;
/// [`CsVecI`] with `usize` indices.
pub type CsVec<N> = CsVecI<N>;
/// Sparse matrix in triplet form: row indices, column indices and values are
/// kept in separate containers.
///
/// `row_inds` and `col_inds` share the single `IStorage` type.
// NOTE(review): the three containers are presumably of equal length (one
// element per stored entry); that invariant is not enforced in this file.
#[derive(PartialEq, Eq, Debug, Hash)]
pub struct TriMatBase<IStorage, DStorage> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Row index of each entry.
    row_inds: IStorage,
    /// Column index of each entry.
    col_inds: IStorage,
    /// Value of each entry.
    data: DStorage,
}
/// Owned triplet matrix with index type `I`.
pub type TriMatI<N, I> = TriMatBase<Vec<I>, Vec<N>>;
/// Read-only triplet matrix view borrowing its indices and values.
pub type TriMatViewI<'a, N, I> = TriMatBase<&'a [I], &'a [N]>;
/// Triplet matrix view with mutably borrowed indices and values.
pub type TriMatViewMutI<'a, N, I> = TriMatBase<&'a mut [I], &'a mut [N]>;
/// [`TriMatI`] with `usize` indices.
pub type TriMat<N> = TriMatI<N, usize>;
/// [`TriMatViewI`] with `usize` indices.
pub type TriMatView<'a, N> = TriMatViewI<'a, N, usize>;
/// [`TriMatViewMutI`] with `usize` indices.
pub type TriMatViewMut<'a, N> = TriMatViewMutI<'a, N, usize>;
/// Triplet matrix whose row indices, column indices and values come from the
/// generic sources `RI`, `CI` and `DI`.
// NOTE(review): the sources are presumably iterators; confirm against the
// `triplet_iter` module.
#[derive(PartialEq, Eq, Debug, Clone)]
pub struct TriMatIter<RI, CI, DI> {
    /// Number of rows.
    rows: usize,
    /// Number of columns.
    cols: usize,
    /// Number of entries; presumably what each source yields (TODO confirm).
    nnz: usize,
    /// Source of row indices.
    row_inds: RI,
    /// Source of column indices.
    col_inds: CI,
    /// Source of values.
    data: DI,
}
mod prelude {
#[allow(unused_imports)]
pub use super::{
CsMat, CsMatBase, CsMatI, CsMatVecView, CsMatVecView_, CsMatView,
CsMatViewI, CsMatViewMut, CsMatViewMutI, CsStructure, CsStructureI,
CsStructureView, CsStructureViewI, CsVec, CsVecBase, CsVecI, CsVecView,
CsVecViewI, CsVecViewMut, CsVecViewMutI, SparseMat, TriMat, TriMatBase,
TriMatI, TriMatIter, TriMatView, TriMatViewI, TriMatViewMut,
TriMatViewMutI,
};
}
/// Dimension and size queries shared by sparse matrix types.
pub trait SparseMat {
    /// Number of rows.
    fn rows(&self) -> usize;
    /// Number of columns.
    fn cols(&self) -> usize;
    /// Number of stored entries, as reported by the implementor.
    fn nnz(&self) -> usize;
}
/// Internal helpers shared by the sparse matrix and vector implementations.
pub(crate) mod utils {
    use super::*;
    use ndarray::Axis;

    /// Checks that `indptr` and `indices` describe a valid compressed
    /// structure with `outer` outer slices whose indices are below `inner`.
    ///
    /// # Errors
    ///
    /// The checks run in this order, and the first failure is returned:
    /// - any error reported by `IndPtrView::new_checked` on `indptr`;
    /// - `SizeMismatch` if `indptr` does not have `outer + 1` entries;
    /// - `OutOfRange` if `inner` does not fit in `I`, or `outer + 1` does not
    ///   fit in `Iptr`;
    /// - `OutOfRange` if some index cannot be converted to `usize`;
    /// - `SizeMismatch` if `indices.len()` differs from the nnz implied by
    ///   `indptr`;
    /// - `Unsorted` if the indices of some outer slice are not strictly
    ///   increasing;
    /// - `OutOfRange` if some index is `>= inner`.
    pub(crate) fn check_compressed_structure<I: SpIndex, Iptr: SpIndex>(
        inner: usize,
        outer: usize,
        indptr: &[Iptr],
        indices: &[I],
    ) -> Result<(), StructureError> {
        let indptr =
            crate::IndPtrView::new_checked(indptr).map_err(|(_, e)| e)?;
        if indptr.len() != outer + 1 {
            return Err(StructureError::SizeMismatch(
                "Indptr length does not match dimension",
            ));
        }
        // Both index types must be able to represent the matrix dimensions.
        if I::from(inner).is_none() {
            return Err(StructureError::OutOfRange(
                "Index type not large enough for this matrix",
            ));
        }
        if Iptr::from(outer + 1).is_none() {
            return Err(StructureError::OutOfRange(
                "Iptr type not large enough for this matrix",
            ));
        }
        // Every index must be convertible to `usize`. A negative value in a
        // signed index type is one way this can fail.
        if indices.iter().any(|i| i.try_index().is_none()) {
            return Err(StructureError::OutOfRange(
                "Indices value out of range of usize",
            ));
        }
        if indices.len() != indptr.nnz() {
            return Err(StructureError::SizeMismatch(
                "Indices length and inpdtr's nnz do not match",
            ));
        }
        for range in indptr.iter_outer_sz() {
            let slice = &indices[range];
            if !sorted_indices(slice) {
                return Err(StructureError::Unsorted("Indices are not sorted"));
            }
            // The slice is strictly increasing, so its last index is the
            // largest one. Only that index needs the bounds check.
            if let Some(last) = slice.last() {
                // Cannot fail: every index passed `try_index` above.
                if last.to_usize().unwrap() >= inner {
                    return Err(StructureError::OutOfRange(
                        "Indice is larger than inner dimension",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Returns `true` when `indices` is strictly increasing, which also rules
    /// out duplicates. Empty and single-element slices count as sorted.
    pub fn sorted_indices<I: SpIndex>(indices: &[I]) -> bool {
        indices.windows(2).all(|pair| pair[0] < pair[1])
    }

    /// Sorts `indices` in increasing order and permutes `data` alongside, so
    /// each value stays attached to its index.
    ///
    /// `buf` is scratch space. It is cleared on entry and, on return, holds
    /// the sorted `(index, value)` pairs. Reusing one buffer across calls
    /// avoids repeated allocation. The sort is unstable, so entries with
    /// equal indices may end up in any relative order.
    ///
    /// # Panics
    ///
    /// Panics if `indices` and `data` have different lengths.
    pub fn sort_indices_data_slices<N: Clone, I: SpIndex>(
        indices: &mut [I],
        data: &mut [N],
        buf: &mut Vec<(I, N)>,
    ) {
        let len = indices.len();
        assert_eq!(len, data.len());
        buf.clear();
        buf.reserve_exact(len);
        buf.extend(indices.iter().copied().zip(data.iter().cloned()));
        buf.sort_unstable_by_key(|&(i, _)| i);
        for ((sorted_ind, sorted_val), (ind, val)) in
            buf.iter().zip(indices.iter_mut().zip(data.iter_mut()))
        {
            *ind = *sorted_ind;
            *val = sorted_val.clone();
        }
    }

    /// Returns the axis along which consecutive elements of `mat` are
    /// closest in memory, i.e. the axis with the smaller stride magnitude.
    /// Ties resolve to `Axis(1)`.
    pub(crate) fn fastest_axis<T>(mat: ndarray::ArrayView2<T>) -> Axis {
        let strides = mat.strides();
        // Strides are signed. A reversed axis has a negative stride but
        // still steps through memory by the same amount, so compare
        // magnitudes rather than signed values.
        if strides[1].abs() > strides[0].abs() {
            Axis(0)
        } else {
            Axis(1)
        }
    }

    /// Returns the axis of `mat` that is not [`fastest_axis`].
    pub(crate) fn slowest_axis<T>(mat: ndarray::ArrayView2<T>) -> Axis {
        match fastest_axis(mat) {
            Axis(1) => Axis(0),
            Axis(0) => Axis(1),
            // `fastest_axis` only ever returns `Axis(0)` or `Axis(1)`.
            _ => unreachable!(),
        }
    }
}
pub mod binop;
pub mod compressed;
pub mod construct;
pub mod csmat;
pub mod indptr;
pub mod kronecker;
pub mod linalg;
pub mod permutation;
pub mod prod;
pub mod slicing;
pub mod smmp;
pub mod special_mats;
pub mod symmetric;
pub mod to_dense;
pub mod triplet;
pub mod triplet_iter;
pub mod vec;
pub mod visu;
#[cfg(test)]
mod test {
    use super::utils;

    #[test]
    fn test_sort_indices() {
        let mut indices: Vec<usize> = vec![4, 1, 6, 2];
        let mut values: Vec<i32> = vec![4, -1, 2, -3];
        let mut scratch: Vec<(usize, i32)> = Vec::new();
        utils::sort_indices_data_slices(
            &mut indices[..],
            &mut values[..],
            &mut scratch,
        );
        // Each value must travel with its index: 1->-1, 2->-3, 4->4, 6->2.
        assert_eq!(indices, vec![1, 2, 4, 6]);
        assert_eq!(values, vec![-1, -3, 4, 2]);
    }

    #[test]
    fn test_sorted_indices() {
        use utils::sorted_indices;
        let accepted: [&[i32]; 4] = [&[1, 2, 3], &[1, 2, 8], &[1, 2], &[1]];
        for &case in accepted.iter() {
            assert!(sorted_indices(case), "{:?} should be sorted", case);
        }
        // Repeated and decreasing indices are both rejected.
        let rejected: [&[i32]; 2] = [&[1, 1, 3], &[2, 1, 3]];
        for &case in rejected.iter() {
            assert!(!sorted_indices(case), "{:?} should not be sorted", case);
        }
    }

    #[test]
    fn test_fastest_axis() {
        use ndarray::{arr2, s, Array2, Axis, ShapeBuilder};
        use utils::fastest_axis;

        // Row-major (C order) storage: elements along axis 1 are adjacent.
        let small = arr2(&[[1, 2], [3, 4]]);
        assert_eq!(fastest_axis(small.view()), Axis(1));
        let c_order = Array2::<i32>::zeros((10, 9));
        assert_eq!(fastest_axis(c_order.view()), Axis(1));
        let c_strided = c_order.slice(s![..;2, ..;3]);
        assert_eq!(fastest_axis(c_strided.view()), Axis(1));

        // Column-major (Fortran order) storage: elements along axis 0 are
        // adjacent.
        let f_order = Array2::<i32>::zeros((10, 9).f());
        assert_eq!(fastest_axis(f_order.view()), Axis(0));
        let f_strided = f_order.slice(s![..;2, ..;3]);
        assert_eq!(fastest_axis(f_strided.view()), Axis(0));
    }
}