use crate::scalar::Scalar;
/// A sparse matrix stored in compressed sparse row (CSR) format.
///
/// The stored entries of row `i` occupy positions `rowptr[i]..rowptr[i + 1]`
/// of `colind` (their column indices) and `values` (their values).
///
/// [`CsrMatrix::new`] rejects (panics on) inputs that violate any of:
/// - `nrows > 0` and `ncols > 0`;
/// - `rowptr.len() == nrows + 1`, `rowptr[0] == 0`, and `rowptr` non-decreasing;
/// - `colind.len() == values.len() == rowptr[nrows]`;
/// - every column index is `< ncols`, and indices are strictly increasing within a row.
#[derive(Debug)]
pub struct CsrMatrix<T: Scalar> {
/// Number of rows.
nrows: usize,
/// Number of columns.
ncols: usize,
/// Row offsets into `colind`/`values`; `nrows + 1` entries, starting at 0
/// and ending at the number of stored entries.
rowptr: Vec<usize>,
/// Column index of each stored entry, grouped by row.
colind: Vec<usize>,
/// Value of each stored entry, parallel to `colind`.
values: Vec<T>,
}
mod ops;
mod conv;
/// Iterator over the stored entries of a [`CsrMatrix`] by shared reference,
/// yielding `(row, col, &value)` triplets in row-major order.
///
/// Returned by [`CsrMatrix::iter`], which collects all triplets into a buffer
/// when the iterator is created.
#[derive(Clone, Debug)]
pub struct Iter<'iter, T> {
/// Pre-built `(row, col, &value)` triplets, consumed front to back.
iter: std::vec::IntoIter<(usize, usize, &'iter T)>,
}
/// Iterator over the stored entries of a [`CsrMatrix`] by mutable reference,
/// yielding `(row, col, &mut value)` triplets in row-major order.
///
/// Returned by [`CsrMatrix::iter_mut`]. Only values can be modified through it;
/// the sparsity pattern (row and column indices) is fixed.
#[derive(Debug)]
pub struct IterMut<'iter, T> {
/// Pre-built `(row, col, &mut value)` triplets, consumed front to back.
iter: std::vec::IntoIter<(usize, usize, &'iter mut T)>,
}
/// Owning iterator over the stored entries of a [`CsrMatrix`], yielding
/// `(row, col, value)` triplets in row-major order.
///
/// Returned by the [`IntoIterator`] implementation for [`CsrMatrix`], which
/// consumes the matrix.
#[derive(Debug)]
pub struct IntoIter<T> {
/// Pre-built owned `(row, col, value)` triplets, consumed front to back.
iter: std::vec::IntoIter<(usize, usize, T)>,
}
impl<T: Scalar> CsrMatrix<T> {
    /// Builds a matrix from raw CSR arrays after validating them.
    ///
    /// The stored entries of row `i` occupy positions `rowptr[i]..rowptr[i + 1]`
    /// of `colind` (column indices) and `values` (stored values).
    ///
    /// # Panics
    ///
    /// Panics unless all of the following hold:
    /// - `nrows > 0` and `ncols > 0`;
    /// - `rowptr.len() == nrows + 1` and `rowptr[0] == 0`;
    /// - `colind.len() == rowptr[nrows]` and `values.len() == rowptr[nrows]`;
    /// - `rowptr` is non-decreasing;
    /// - every column index is `< ncols`;
    /// - within each row, column indices are strictly increasing (sorted, no duplicates).
    pub fn new(
        nrows: usize,
        ncols: usize,
        rowptr: Vec<usize>,
        colind: Vec<usize>,
        values: Vec<T>,
    ) -> Self {
        assert!(nrows > 0, "matrix must have at least one row");
        assert!(ncols > 0, "matrix must have at least one column");
        assert_eq!(rowptr.len(), nrows + 1, "rowptr must have nrows + 1 entries");
        assert_eq!(rowptr[0], 0, "rowptr must start at 0");
        let nnz = rowptr[nrows];
        assert_eq!(colind.len(), nnz, "colind length must equal rowptr[nrows]");
        assert_eq!(values.len(), nnz, "values length must equal rowptr[nrows]");
        assert!(
            rowptr.windows(2).all(|w| w[0] <= w[1]),
            "rowptr must be non-decreasing"
        );
        assert!(
            colind.iter().all(|&col| col < ncols),
            "column index out of bounds"
        );
        // Safe to slice here: rowptr is non-decreasing and ends at colind.len().
        assert!(
            rowptr
                .windows(2)
                .all(|w| colind[w[0]..w[1]].windows(2).all(|c| c[0] < c[1])),
            "column indices must be strictly increasing within each row"
        );
        Self {
            nrows,
            ncols,
            rowptr,
            colind,
            values,
        }
    }

    /// Returns the `size`×`size` identity matrix.
    ///
    /// # Panics
    ///
    /// Panics if `size == 0`.
    pub fn eye(size: usize) -> Self {
        assert!(size > 0, "identity matrix must have a positive size");
        Self {
            nrows: size,
            ncols: size,
            rowptr: (0..=size).collect(),
            colind: (0..size).collect(),
            values: vec![T::one(); size],
        }
    }

    /// Returns the number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Returns the number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the row offsets (`nrows + 1` entries).
    pub fn rowptr(&self) -> &[usize] {
        &self.rowptr
    }

    /// Returns the column index of every stored entry, grouped by row.
    pub fn colind(&self) -> &[usize] {
        &self.colind
    }

    /// Returns the stored values, parallel to [`colind`](Self::colind).
    pub fn values(&self) -> &[T] {
        &self.values
    }

    /// Returns the stored values mutably; the sparsity pattern stays fixed.
    pub fn values_mut(&mut self) -> &mut [T] {
        &mut self.values
    }

    /// Returns the number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        // rowptr always holds nrows + 1 >= 2 offsets; the last one is the entry count.
        *self.rowptr.last().expect("rowptr is never empty")
    }

    /// Returns an iterator over the stored entries as `(row, col, &value)`
    /// triplets, in row-major order.
    ///
    /// All triplets are collected into a buffer up front (`O(nnz)` memory).
    pub fn iter(&self) -> Iter<'_, T> {
        let mut entries = Vec::with_capacity(self.nnz());
        // `values` is laid out in the same order as `colind`, so one cursor that
        // advances once per column index pairs every entry with its value.
        let mut values = self.values.iter();
        for (row, bounds) in self.rowptr.windows(2).enumerate() {
            let cols = &self.colind[bounds[0]..bounds[1]];
            // `zip` polls `cols` first, so `values` never advances past this row.
            for (&col, val) in cols.iter().zip(values.by_ref()) {
                entries.push((row, col, val));
            }
        }
        Iter {
            iter: entries.into_iter(),
        }
    }

    /// Returns an iterator over the stored entries as `(row, col, &mut value)`
    /// triplets, in row-major order.
    ///
    /// All triplets are collected into a buffer up front (`O(nnz)` memory).
    pub fn iter_mut(&mut self) -> IterMut<'_, T> {
        let mut entries = Vec::with_capacity(self.nnz());
        let mut values = self.values.iter_mut();
        // `rowptr.windows(2)` yields exactly one window per row, so every row is
        // visited regardless of whether the matrix is square.
        for (row, bounds) in self.rowptr.windows(2).enumerate() {
            let cols = &self.colind[bounds[0]..bounds[1]];
            for (&col, val) in cols.iter().zip(values.by_ref()) {
                entries.push((row, col, val));
            }
        }
        IterMut {
            iter: entries.into_iter(),
        }
    }

    /// Returns the transpose as a new `ncols`×`nrows` CSR matrix.
    ///
    /// Runs in `O(nrows + ncols + nnz)`. Source rows are scanned in increasing
    /// order, so the column indices within each output row come out sorted.
    pub fn transpose(&self) -> Self {
        let nnz = self.nnz();
        // Count the entries of each source column at position `col + 1`, so an
        // in-place running sum turns the counts into the output row pointer.
        let mut colptr = vec![0usize; self.ncols + 1];
        for &col in &self.colind {
            colptr[col + 1] += 1;
        }
        let mut running = 0;
        for offset in colptr.iter_mut() {
            running += *offset;
            *offset = running;
        }
        // `next[col]` is the next free slot of output row `col`.
        let mut next = colptr[..self.ncols].to_vec();
        let mut rowind = vec![0; nnz];
        let mut colval = vec![T::zero(); nnz];
        for (row, bounds) in self.rowptr.windows(2).enumerate() {
            for ptr in bounds[0]..bounds[1] {
                let col = self.colind[ptr];
                let dst = next[col];
                rowind[dst] = row;
                colval[dst] = self.values[ptr];
                next[col] += 1;
            }
        }
        Self {
            nrows: self.ncols,
            ncols: self.nrows,
            rowptr: colptr,
            colind: rowind,
            values: colval,
        }
    }
}
impl<T: Scalar> IntoIterator for CsrMatrix<T> {
type Item = (usize, usize, T);
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
let mut vec = Vec::with_capacity(self.nnz());
let mut values = self.values.into_iter();
for row in 0..self.nrows {
for ptr in self.rowptr[row]..self.rowptr[row + 1] {
let col = self.colind[ptr];
let val = values.next().unwrap();
vec.push((row, col, val));
}
}
IntoIter {
iter: vec.into_iter(),
}
}
}
impl<'iter, T> Iterator for Iter<'iter, T> {
type Item = (usize, usize, &'iter T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl<'iter, T> Iterator for IterMut<'iter, T> {
type Item = (usize, usize, &'iter mut T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl<T: Scalar> Iterator for IntoIter<T> {
type Item = (usize, usize, T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The 2×3 matrix `[[1, 0, 2], [0, 3, 0]]`.
    fn sample() -> CsrMatrix<f64> {
        CsrMatrix::new(2, 3, vec![0, 2, 3], vec![0, 2, 1], vec![1.0, 2.0, 3.0])
    }

    #[test]
    #[should_panic]
    fn new_invalid_nrows() {
        CsrMatrix::<f64>::new(0, 1, vec![0, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_ncols() {
        CsrMatrix::<f64>::new(2, 0, vec![0, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_rowptr_first_not_zero() {
        CsrMatrix::<f64>::new(2, 1, vec![1, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_rowptr_length() {
        CsrMatrix::<f64>::new(2, 1, vec![0, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_decreasing_rowptr() {
        CsrMatrix::<f64>::new(2, 2, vec![0, 2, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_colind_out_of_bounds() {
        CsrMatrix::<f64>::new(2, 1, vec![0, 1, 1], vec![1], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_unsorted_colind() {
        // ncols = 2 keeps both indices in range, so only the per-row
        // ordering check can fire (ncols = 1 would trip the bounds check first).
        CsrMatrix::<f64>::new(2, 2, vec![0, 2, 2], vec![1, 0], vec![1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn new_duplicate_colind() {
        CsrMatrix::<f64>::new(1, 2, vec![0, 2], vec![1, 1], vec![1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn new_values_length_mismatch() {
        CsrMatrix::<f64>::new(2, 1, vec![0, 1, 1], vec![0], vec![1.0, 2.0]);
    }

    #[test]
    fn new_accepts_valid_input() {
        let m = sample();
        assert_eq!((m.nrows(), m.ncols(), m.nnz()), (2, 3, 3));
        assert_eq!(m.rowptr(), &[0, 2, 3]);
        assert_eq!(m.colind(), &[0, 2, 1]);
        assert_eq!(m.values(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn eye_is_diagonal() {
        let m = CsrMatrix::<f64>::eye(3);
        assert_eq!(m.rowptr(), &[0, 1, 2, 3]);
        assert_eq!(m.colind(), &[0, 1, 2]);
        assert_eq!(m.values(), &[1.0, 1.0, 1.0]);
    }

    #[test]
    fn iter_yields_row_major_triplets() {
        let m = sample();
        let triplets: Vec<_> = m.iter().map(|(r, c, &v)| (r, c, v)).collect();
        assert_eq!(triplets, vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)]);
    }

    #[test]
    fn iter_mut_updates_values_in_place() {
        let mut m = CsrMatrix::<f64>::eye(2);
        let coords: Vec<_> = m.iter_mut().map(|(r, c, _)| (r, c)).collect();
        assert_eq!(coords, vec![(0, 0), (1, 1)]);
        for (_, _, value) in m.iter_mut() {
            *value *= 4.0;
        }
        assert_eq!(m.values(), &[4.0, 4.0]);
    }

    #[test]
    fn transpose_rectangular() {
        let t = sample().transpose();
        assert_eq!((t.nrows(), t.ncols()), (3, 2));
        assert_eq!(t.rowptr(), &[0, 1, 2, 3]);
        assert_eq!(t.colind(), &[0, 1, 0]);
        assert_eq!(t.values(), &[1.0, 3.0, 2.0]);
    }

    #[test]
    fn into_iter_yields_owned_triplets() {
        let triplets: Vec<_> = sample().into_iter().collect();
        assert_eq!(triplets, vec![(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)]);
    }
}