use crate::scalar::Scalar;
/// A sparse matrix stored in compressed sparse column (CSC) format.
///
/// The stored entries of column `j` occupy positions `colptr[j]..colptr[j + 1]`
/// of both `rowind` (their row indices) and `values` (their values).
#[derive(Debug)]
pub struct CscMatrix<T>
where
    T: Scalar,
{
    /// Number of rows.
    nrows: usize,
    /// Number of columns.
    ncols: usize,
    /// Column pointers: `ncols + 1` offsets into `rowind`/`values`, starting at 0.
    colptr: Vec<usize>,
    /// Row index of every stored entry; strictly increasing within a column.
    rowind: Vec<usize>,
    /// Value of every stored entry, parallel to `rowind`.
    values: Vec<T>,
}
mod ops;
mod conv;
/// Borrowing iterator over the stored entries of a [`CscMatrix`], yielding
/// `(row, col, &value)` triples.
#[derive(Debug)]
pub struct Iter<'iter, T> {
    iter: std::vec::IntoIter<(usize, usize, &'iter T)>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone`, but the
// buffered items only hold shared references, which are always cloneable.
impl<'iter, T> Clone for Iter<'iter, T> {
    fn clone(&self) -> Self {
        Iter {
            iter: self.iter.clone(),
        }
    }
}
/// Mutably borrowing iterator over the stored entries of a [`CscMatrix`],
/// yielding `(row, col, &mut value)` triples.
#[derive(Debug)]
pub struct IterMut<'a, T> {
    iter: std::vec::IntoIter<(usize, usize, &'a mut T)>,
}
/// Owning iterator over the stored entries of a [`CscMatrix`], yielding
/// `(row, col, value)` triples.
#[derive(Debug)]
pub struct IntoIter<V> {
    iter: std::vec::IntoIter<(usize, usize, V)>,
}
impl<T: Scalar> CscMatrix<T> {
    /// Builds a matrix from raw compressed sparse column (CSC) arrays.
    ///
    /// Column `j` stores its row indices in `rowind[colptr[j]..colptr[j + 1]]`
    /// and the matching values in the same range of `values`.
    ///
    /// # Panics
    ///
    /// Panics unless all of the following hold:
    /// - `nrows > 0` and `ncols > 0`;
    /// - `colptr.len() == ncols + 1` and `colptr[0] == 0`;
    /// - `rowind.len()` and `values.len()` both equal `colptr[ncols]`;
    /// - `colptr` is non-decreasing;
    /// - every row index is less than `nrows`;
    /// - the row indices of each column are strictly increasing
    ///   (sorted, with no duplicates).
    pub fn new(
        nrows: usize,
        ncols: usize,
        colptr: Vec<usize>,
        rowind: Vec<usize>,
        values: Vec<T>,
    ) -> Self {
        // The checks are ordered so that each one only indexes `colptr` /
        // `rowind` at positions already proven valid by the earlier checks.
        assert!(nrows > 0, "CscMatrix::new: nrows must be positive");
        assert!(ncols > 0, "CscMatrix::new: ncols must be positive");
        assert_eq!(
            colptr.len(),
            ncols + 1,
            "CscMatrix::new: colptr must have ncols + 1 entries"
        );
        assert_eq!(colptr[0], 0, "CscMatrix::new: colptr must start at 0");
        let nnz = colptr[ncols];
        assert_eq!(
            rowind.len(),
            nnz,
            "CscMatrix::new: rowind length must equal colptr[ncols]"
        );
        assert_eq!(
            values.len(),
            nnz,
            "CscMatrix::new: values length must equal colptr[ncols]"
        );
        assert!(
            colptr.windows(2).all(|bounds| bounds[0] <= bounds[1]),
            "CscMatrix::new: colptr must be non-decreasing"
        );
        assert!(
            rowind.iter().all(|&row| row < nrows),
            "CscMatrix::new: row index out of bounds"
        );
        for (col, bounds) in colptr.windows(2).enumerate() {
            assert!(
                rowind[bounds[0]..bounds[1]]
                    .windows(2)
                    .all(|rows| rows[0] < rows[1]),
                "CscMatrix::new: row indices of column {} must be strictly increasing",
                col
            );
        }
        Self {
            nrows,
            ncols,
            colptr,
            rowind,
            values,
        }
    }

    /// Returns the `size`-by-`size` identity matrix (`T::one()` on the diagonal).
    ///
    /// # Panics
    ///
    /// Panics if `size == 0`.
    pub fn eye(size: usize) -> Self {
        assert!(size > 0, "CscMatrix::eye: size must be positive");
        Self {
            nrows: size,
            ncols: size,
            colptr: (0..=size).collect(),
            rowind: (0..size).collect(),
            values: vec![T::one(); size],
        }
    }

    /// Returns the number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Returns the number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns the column pointer array (`ncols + 1` offsets, starting at 0).
    pub fn colptr(&self) -> &[usize] {
        &self.colptr
    }

    /// Returns the row index of every stored entry, in storage order.
    pub fn rowind(&self) -> &[usize] {
        &self.rowind
    }

    /// Returns the value of every stored entry, in storage order.
    pub fn values(&self) -> &[T] {
        &self.values
    }

    /// Returns the stored values mutably; the sparsity pattern cannot change.
    pub fn values_mut(&mut self) -> &mut [T] {
        &mut self.values
    }

    /// Returns the number of stored entries.
    pub fn nnz(&self) -> usize {
        // The constructors in this impl always leave `ncols + 1 >= 2` column
        // pointers, so `last` cannot be `None` for matrices built through them.
        *self.colptr.last().unwrap()
    }

    /// Returns an iterator over the stored entries as `(row, col, &value)`
    /// triples, column by column and in stored row order within a column.
    pub fn iter(&self) -> Iter<T> {
        let mut entries = Vec::with_capacity(self.nnz());
        for (col, bounds) in self.colptr.windows(2).enumerate() {
            for ptr in bounds[0]..bounds[1] {
                entries.push((self.rowind[ptr], col, &self.values[ptr]));
            }
        }
        Iter {
            iter: entries.into_iter(),
        }
    }

    /// Returns an iterator over the stored entries as `(row, col, &mut value)`
    /// triples, in the same order as [`CscMatrix::iter`].
    pub fn iter_mut(&mut self) -> IterMut<T> {
        let mut entries = Vec::with_capacity(self.nnz());
        // The loops below visit storage positions 0, 1, ..., nnz - 1 in order,
        // so walking `values` sequentially hands out each `&mut` exactly once
        // while `colptr`/`rowind` are only borrowed immutably.
        let mut values = self.values.iter_mut();
        for (col, bounds) in self.colptr.windows(2).enumerate() {
            for &row in &self.rowind[bounds[0]..bounds[1]] {
                let val = values
                    .next()
                    .expect("CscMatrix invariant: values.len() == nnz");
                entries.push((row, col, val));
            }
        }
        IterMut {
            iter: entries.into_iter(),
        }
    }

    /// Returns the transpose as a new `ncols`-by-`nrows` matrix.
    ///
    /// Counts the entries in each row, prefix-sums the counts into the
    /// column pointers of the result, then scatters every entry into place,
    /// so it does `O(nrows + ncols + nnz)` work.
    pub fn transpose(&self) -> Self {
        let (nrows, ncols, nnz) = (self.nrows, self.ncols, self.nnz());

        // Row `r` of `self` becomes column `r` of the result; count its entries.
        let mut row_counts = vec![0usize; nrows];
        for &row in &self.rowind {
            row_counts[row] += 1;
        }

        // Prefix sums of the counts form the column pointers of the result.
        let mut rowptr = Vec::with_capacity(nrows + 1);
        rowptr.push(0);
        let mut running = 0;
        for count in row_counts {
            running += count;
            rowptr.push(running);
        }

        // `next_slot[r]` is the position where the next entry of row `r` goes.
        // Columns are visited in increasing order, so within each column of
        // the result the row indices (the column indices of `self`) come out
        // sorted, as `new` requires.
        let mut next_slot = rowptr[..nrows].to_vec();
        let mut colind = vec![0; nnz];
        let mut rowval = vec![T::zero(); nnz];
        for (col, bounds) in self.colptr.windows(2).enumerate() {
            for ptr in bounds[0]..bounds[1] {
                let slot = &mut next_slot[self.rowind[ptr]];
                colind[*slot] = col;
                rowval[*slot] = self.values[ptr];
                *slot += 1;
            }
        }

        Self {
            nrows: ncols,
            ncols: nrows,
            colptr: rowptr,
            rowind: colind,
            values: rowval,
        }
    }
}
impl<T: Scalar> IntoIterator for CscMatrix<T> {
type Item = (usize, usize, T);
type IntoIter = IntoIter<T>;
fn into_iter(self) -> Self::IntoIter {
let mut vec = Vec::with_capacity(self.nnz());
let mut values = self.values.into_iter();
for col in 0..self.ncols {
for ptr in self.colptr[col]..self.colptr[col + 1] {
let row = self.rowind[ptr];
let val = values.next().unwrap();
vec.push((row, col, val));
}
}
IntoIter {
iter: vec.into_iter(),
}
}
}
impl<'iter, T> Iterator for Iter<'iter, T> {
type Item = (usize, usize, &'iter T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl<'iter, T> Iterator for IterMut<'iter, T> {
type Item = (usize, usize, &'iter mut T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
impl<T: Scalar> Iterator for IntoIter<T> {
type Item = (usize, usize, T);
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // The 2x3 matrix
    //     [1 0 2]
    //     [0 3 4]
    fn sample() -> CscMatrix<f64> {
        CscMatrix::new(
            2,
            3,
            vec![0, 1, 2, 4],
            vec![0, 1, 0, 1],
            vec![1.0, 3.0, 2.0, 4.0],
        )
    }

    #[test]
    fn new_valid() {
        let m = sample();
        assert_eq!(m.nrows(), 2);
        assert_eq!(m.ncols(), 3);
        assert_eq!(m.nnz(), 4);
        assert_eq!(m.colptr(), &[0, 1, 2, 4]);
        assert_eq!(m.rowind(), &[0, 1, 0, 1]);
        assert_eq!(m.values(), &[1.0, 3.0, 2.0, 4.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_nrows() {
        CscMatrix::<f64>::new(0, 2, vec![0, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_ncols() {
        CscMatrix::<f64>::new(1, 0, vec![0, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_colptr_first_not_zero() {
        CscMatrix::<f64>::new(1, 2, vec![1, 1, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_colptr_invalid_length() {
        CscMatrix::<f64>::new(1, 2, vec![0, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_decreasing_colptr() {
        // Every other invariant holds; only colptr[1] > colptr[2] is wrong.
        CscMatrix::<f64>::new(2, 2, vec![0, 2, 1], vec![0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_rowind() {
        CscMatrix::<f64>::new(1, 2, vec![0, 1, 1], vec![1], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_rowind_length_mismatch() {
        CscMatrix::<f64>::new(1, 2, vec![0, 1, 1], vec![0, 0], vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn new_unsorted_rowind() {
        // nrows = 2 keeps both row indices in range, so the only violated
        // invariant is the ordering of rows within column 0.
        CscMatrix::<f64>::new(2, 2, vec![0, 2, 2], vec![1, 0], vec![1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn new_duplicate_rowind() {
        CscMatrix::<f64>::new(2, 1, vec![0, 2], vec![1, 1], vec![1.0, 2.0]);
    }

    #[test]
    #[should_panic]
    fn new_invalid_rowind_values() {
        CscMatrix::<f64>::new(1, 2, vec![0, 1, 1], vec![0], vec![1.0, 2.0]);
    }

    #[test]
    fn eye_is_identity() {
        let m = CscMatrix::<f64>::eye(3);
        assert_eq!((m.nrows(), m.ncols(), m.nnz()), (3, 3, 3));
        assert_eq!(m.colptr(), &[0, 1, 2, 3]);
        assert_eq!(m.rowind(), &[0, 1, 2]);
        assert_eq!(m.values(), &[1.0, 1.0, 1.0]);
    }

    #[test]
    #[should_panic]
    fn eye_zero_size() {
        CscMatrix::<f64>::eye(0);
    }

    #[test]
    fn iter_is_column_major() {
        let m = sample();
        let entries: Vec<(usize, usize, f64)> =
            m.iter().map(|(row, col, &val)| (row, col, val)).collect();
        assert_eq!(
            entries,
            vec![(0, 0, 1.0), (1, 1, 3.0), (0, 2, 2.0), (1, 2, 4.0)]
        );
    }

    #[test]
    fn iter_mut_updates_values() {
        let mut m = sample();
        for (_, col, val) in m.iter_mut() {
            *val += col as f64;
        }
        assert_eq!(m.values(), &[1.0, 4.0, 4.0, 6.0]);
    }

    #[test]
    fn into_iter_yields_owned_entries() {
        let entries: Vec<_> = sample().into_iter().collect();
        assert_eq!(
            entries,
            vec![(0, 0, 1.0), (1, 1, 3.0), (0, 2, 2.0), (1, 2, 4.0)]
        );
    }

    #[test]
    fn transpose_matches_definition() {
        let t = sample().transpose();
        assert_eq!((t.nrows(), t.ncols(), t.nnz()), (3, 2, 4));
        assert_eq!(t.colptr(), &[0, 2, 4]);
        assert_eq!(t.rowind(), &[0, 2, 1, 2]);
        assert_eq!(t.values(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn transpose_with_empty_rows() {
        let m = CscMatrix::<f64>::new(3, 1, vec![0, 1], vec![2], vec![5.0]);
        let t = m.transpose();
        assert_eq!((t.nrows(), t.ncols(), t.nnz()), (1, 3, 1));
        assert_eq!(t.colptr(), &[0, 0, 0, 1]);
        assert_eq!(t.rowind(), &[0]);
        assert_eq!(t.values(), &[5.0]);
    }

    #[test]
    fn transpose_twice_is_identity() {
        let m = sample();
        let tt = m.transpose().transpose();
        assert_eq!((tt.nrows(), tt.ncols()), (m.nrows(), m.ncols()));
        assert_eq!(tt.colptr(), m.colptr());
        assert_eq!(tt.rowind(), m.rowind());
        assert_eq!(tt.values(), m.values());
    }
}