use std::iter;
use std::fmt;
use super::coo::CooMatrix;
/// A sparse matrix stored in compressed sparse column (CSC) form.
///
/// Layout, as read and maintained by the methods below:
///
/// * `shape` is `(rows, cols)`.
/// * `indptr` has `cols + 1` entries; the stored entries of column `j` occupy
///   positions `indptr[j]..indptr[j + 1]` of both `data` and `indices`.
/// * `indices[k]` is the row of the value `data[k]`.
/// * [`CscMatrix::set`] keeps the row indices of each column in ascending
///   order; code in this module that fills the fields directly should do the same.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CscMatrix<T> {
    shape: (usize, usize),
    data: Vec<T>,
    indices: Vec<usize>,
    indptr: Vec<usize>,
}

impl<T: Copy> CscMatrix<T> {
    /// Creates a matrix of the given `(rows, cols)` shape with no stored entries.
    ///
    /// `indptr` is filled with `cols + 1` zeros so that every column `j`,
    /// including the last one, has a valid (empty) range `indptr[j]..indptr[j + 1]`.
    pub fn empty(shape: (usize, usize)) -> Self {
        let ncols = shape.1;
        CscMatrix {
            shape,
            data: vec![],
            indices: vec![],
            // One start pointer per column plus the trailing end pointer.
            indptr: iter::repeat(0).take(ncols + 1).collect(),
        }
    }

    /// Number of explicitly stored entries.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// The matrix dimensions as `(rows, cols)`.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    /// Position in `data`/`indices` of the entry stored at `(row, col)`, if any.
    ///
    /// Scans the column linearly, so it does not depend on rows being sorted.
    fn find_entry(&self, row: usize, col: usize) -> Option<usize> {
        let start = self.indptr[col];
        let end = self.indptr[col + 1];
        self.indices[start..end]
            .iter()
            .position(|&r| r == row)
            .map(|off| start + off)
    }

    /// Returns a reference to the value stored at `index = (row, col)`, or
    /// `None` if that position holds no explicit entry.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn get(&self, index: (usize, usize)) -> Option<&T> {
        let (row, col) = index;
        assert!(row < self.shape.0);
        assert!(col < self.shape.1);
        let k = self.find_entry(row, col)?;
        self.data.get(k)
    }

    /// Mutable counterpart of [`CscMatrix::get`]; same `None` and panic behaviour.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn get_mut(&mut self, index: (usize, usize)) -> Option<&mut T> {
        let (row, col) = index;
        assert!(row < self.shape.0);
        assert!(col < self.shape.1);
        let k = self.find_entry(row, col)?;
        self.data.get_mut(k)
    }

    /// Stores `it` at `index = (row, col)`.
    ///
    /// An existing entry is overwritten in place. Otherwise a new entry is
    /// inserted where it keeps the column's row indices ascending, every later
    /// column pointer is shifted by one, and a warning is printed to stdout
    /// (inserting into the middle of `data`/`indices` moves all later entries).
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn set(&mut self, index: (usize, usize), it: T) {
        let (row, col) = index;
        assert!(row < self.shape.0);
        assert!(col < self.shape.1);
        let start = self.indptr[col];
        let end = self.indptr[col + 1];
        // The first stored row >= `row` is either the entry to overwrite or the
        // slot the new entry must take. If every stored row is smaller, the new
        // entry belongs at the end of the column; putting it anywhere else
        // would break the ascending order this scan relies on.
        let found = self.indices[start..end].iter().position(|&r| r >= row);
        let insert_pos = match found {
            Some(off) if self.indices[start + off] == row => {
                self.data[start + off] = it;
                return;
            }
            Some(off) => start + off,
            None => end,
        };
        // NOTE(review): library code printing to stdout; consider a logger or removing it.
        println!("WARNING: Changing the sparsity structure of a CSC matrix is expensive.");
        self.data.insert(insert_pos, it);
        self.indices.insert(insert_pos, row);
        // Every column after `col` now starts one slot later.
        for ptr in &mut self.indptr[col + 1..] {
            *ptr += 1;
        }
    }

    /// Conversion to block sparse row form. Not implemented yet: always panics.
    pub fn to_bsr(&self) {
        unimplemented!()
    }

    /// Converts to coordinate (COO) form, listing the stored entries column by
    /// column in storage order.
    pub fn to_coo(&self) -> CooMatrix<T> {
        let nnz = self.nnz();
        let mut data = Vec::with_capacity(nnz);
        let mut row = Vec::with_capacity(nnz);
        let mut col = Vec::with_capacity(nnz);
        for (j, bounds) in self.indptr.windows(2).take(self.shape.1).enumerate() {
            let (start, end) = (bounds[0], bounds[1]);
            data.extend_from_slice(&self.data[start..end]);
            row.extend_from_slice(&self.indices[start..end]);
            col.extend(iter::repeat(j).take(end - start));
        }
        CooMatrix {
            shape: self.shape(),
            data,
            row,
            col,
        }
    }

    /// Returns a copy of this matrix (it is already in CSC form).
    pub fn to_csc(&self) -> CscMatrix<T> {
        self.clone()
    }
}
impl<T: Copy + fmt::Display> fmt::Display for CscMatrix<T> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for (j, &ptr) in self.indptr.iter().take(self.shape.1).enumerate() {
for (off, val) in self.data[ptr .. self.indptr[j+1]].iter().enumerate() {
let i = self.indices[ptr + off];
try!(writeln!(f, " {:?}\t{}", (i, j), val));
}
}
Ok(())
}
}
/// Builds a 6x6 matrix by hand and checks the following:
/// * `Display` prints one line per stored entry.
/// * `set` inserts new entries, updating `nnz`, the column pointers and the stored value.
/// * `set` overwrites existing entries without changing the structure.
/// * `get_mut` edits a stored value in place.
#[test]
fn test_csc() {
    let mut mat = CscMatrix {
        shape: (6, 6),
        data: vec![10, 45, 40, 2, 4, 3, 3, 9, 19, 7],
        indptr: vec![0, 3, 5, 8, 8, 8, 10],
        indices: vec![0, 1, 3, 0, 2, 0, 1, 2, 0, 5],
    };
    assert_eq!(format!("{}", mat).lines().count(), 10);
    assert_eq!(mat.nnz(), 10);
    assert_eq!(mat.get((3, 0)), Some(&40));
    assert_eq!(mat.get((4, 0)), None);

    // Inserting into the empty column 4 adds exactly one stored entry.
    mat.set((3, 4), 20);
    assert_eq!(format!("{}", mat).lines().count(), 11);
    assert_eq!(mat.nnz(), 11);
    assert_eq!(mat.get((3, 4)), Some(&20));
    assert_eq!(mat.indptr, vec![0, 3, 5, 8, 8, 9, 11]);

    // Overwriting an existing entry leaves the structure unchanged.
    mat.set((0, 0), 11);
    assert_eq!(mat.nnz(), 11);
    assert_eq!(mat.get((0, 0)), Some(&11));

    // get_mut edits a stored value in place.
    *mat.get_mut((5, 5)).expect("(5, 5) is stored") = 8;
    assert_eq!(mat.get((5, 5)), Some(&8));
}