mtl 0.1.5

Matrix template library. Dense2D matrix, SparseMatrix.
use std::iter;
use std::fmt;

use super::coo::CooMatrix;

/// Compressed Sparse Column (CSC) matrix.
///
/// The stored entries of column `j` occupy positions `indptr[j]..indptr[j + 1]`
/// of the parallel arrays `data` (the values) and `indices` (their row indices).
/// `indptr` therefore has `ncols + 1` elements, and `data.len() == indices.len()`
/// is the number of stored entries. Positions with no stored entry are absent
/// (lookups report them as `None`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CscMatrix<T> {
    /// `(nrows, ncols)`.
    shape: (usize, usize),
    /// Values of the stored entries, laid out column by column.
    data: Vec<T>,
    /// Row index of each entry in `data` (same length as `data`).
    indices: Vec<usize>,
    /// Start offset of each column in `data`/`indices`; length `ncols + 1`.
    indptr: Vec<usize>,
}


impl<T: Copy> CscMatrix<T> {
    /// Creates a matrix of the given `(nrows, ncols)` shape with no stored entries.
    ///
    /// `indptr` gets `ncols + 1` zeros so that every column `j` has a valid,
    /// empty range `indptr[j]..indptr[j + 1]`.
    pub fn empty(shape: (usize, usize)) -> Self {
        let ncols = shape.1;
        CscMatrix {
            shape: shape,
            data: vec![],
            indices: vec![],
            // One more pointer than columns: column `j` ends where column `j + 1` starts.
            indptr: iter::repeat(0).take(ncols + 1).collect(),
        }
    }

    /// Number of explicitly stored entries.
    #[inline]
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// `(nrows, ncols)` of the matrix.
    #[inline]
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    /// Position in `data`/`indices` of the stored entry at `(row, col)`, if any.
    ///
    /// # Panics
    /// If `row >= nrows` or `col >= ncols`.
    fn entry_position(&self, index: (usize, usize)) -> Option<usize> {
        let (row, col) = index;
        assert!(row < self.shape.0);
        assert!(col < self.shape.1);

        let start = self.indptr[col];
        let end = self.indptr[col + 1];
        self.indices[start..end]
            .iter()
            .position(|&r| r == row)
            .map(|offset| start + offset)
    }

    /// Returns the stored element at `(row, col)`, or `None` if it is not stored.
    ///
    /// # Panics
    /// If `row >= nrows` or `col >= ncols`.
    pub fn get(&self, index: (usize, usize)) -> Option<&T> {
        self.entry_position(index).and_then(|i| self.data.get(i))
    }

    /// Returns a mutable reference to the stored element at `(row, col)`,
    /// or `None` if it is not stored.
    ///
    /// # Panics
    /// If `row >= nrows` or `col >= ncols`.
    pub fn get_mut(&mut self, index: (usize, usize)) -> Option<&mut T> {
        match self.entry_position(index) {
            Some(i) => self.data.get_mut(i),
            None => None,
        }
    }

    /// Stores `it` at `(row, col)`, overwriting an existing entry or inserting a new one.
    ///
    /// New entries are placed so that row indices within each column stay in
    /// ascending order. Inserting a new entry shifts `data`/`indices` and bumps
    /// every later column pointer, so it costs O(nnz + ncols). Prefer another
    /// format when many entries must be added.
    ///
    /// # Panics
    /// If `row >= nrows` or `col >= ncols`.
    pub fn set(&mut self, index: (usize, usize), it: T) {
        let (row, col) = index;
        assert!(row < self.shape.0);
        assert!(col < self.shape.1);

        let start = self.indptr[col];
        let end = self.indptr[col + 1];

        // The first stored row in this column that is >= `row` is either the
        // entry itself or the slot the new entry must take to keep rows ascending.
        let insert_at = match self.indices[start..end].iter().position(|&r| r >= row) {
            Some(offset) if self.indices[start + offset] == row => {
                self.data[start + offset] = it;
                return;
            }
            Some(offset) => start + offset,
            // Every stored row is smaller (or the column is empty): append at the column end.
            None => end,
        };

        self.data.insert(insert_at, it);
        self.indices.insert(insert_at, row);
        // Every later column now starts one slot further along.
        for ptr in &mut self.indptr[col + 1..] {
            *ptr += 1;
        }
    }

    /// Conversion to Block Sparse Row format (not implemented yet).
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub fn to_bsr(&self) {
        unimplemented!()
    }

    /// Converts to coordinate (COO) format.
    ///
    /// Entries are emitted column by column, in each column's storage order.
    pub fn to_coo(&self) -> CooMatrix<T> {
        let nnz = self.nnz();
        let mut data = Vec::with_capacity(nnz);
        let mut row = Vec::with_capacity(nnz);
        let mut col = Vec::with_capacity(nnz);

        for (j, bounds) in self.indptr.windows(2).take(self.shape.1).enumerate() {
            let (start, end) = (bounds[0], bounds[1]);
            for (&i, &val) in self.indices[start..end].iter().zip(&self.data[start..end]) {
                data.push(val);
                row.push(i);
                col.push(j);
            }
        }

        CooMatrix {
            shape: self.shape(),
            data: data,
            row: row,
            col: col,
        }
    }

    /// Returns an independent copy (the matrix is already in CSC format).
    pub fn to_csc(&self) -> CscMatrix<T> {
        self.clone()
    }
}

impl<T: Copy + fmt::Display> fmt::Display for CscMatrix<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        for (j, &ptr) in self.indptr.iter().take(self.shape.1).enumerate() {
            for (off, val) in self.data[ptr .. self.indptr[j+1]].iter().enumerate() {
                let i = self.indices[ptr + off];
                try!(writeln!(f, "  {:?}\t{}", (i, j), val));
            }
        }
        Ok(())
    }
}


#[test]
fn test_csc() {
    // 6x6 matrix with 10 stored entries; columns 3 and 4 are empty.
    let mut mat = CscMatrix {
        shape: (6, 6),
        data: vec![10, 45, 40, 2, 4, 3, 3, 9, 19, 7],
        indptr: vec![0, 3, 5, 8, 8, 8, 10],
        indices: vec![0, 1, 3, 0, 2, 0, 1, 2, 0, 5]
    };

    // Display prints exactly one line per stored entry.
    assert_eq!(format!("{}", mat).lines().count(), 10);
    assert_eq!(mat.nnz(), 10);
    assert_eq!(mat.get((3, 0)), Some(&40));
    assert_eq!(mat.get((5, 5)), Some(&7));
    assert_eq!(mat.get((3, 4)), None);

    // Inserting into an empty column adds one entry and shifts later column pointers.
    mat.set((3, 4), 20);
    assert_eq!(format!("{}", mat).lines().count(), 11);
    assert_eq!(mat.nnz(), 11);
    assert_eq!(mat.get((3, 4)), Some(&20));
    assert_eq!(mat.indptr, vec![0, 3, 5, 8, 8, 9, 11]);
    assert_eq!(mat.get((5, 5)), Some(&7));

    // Overwriting an existing entry leaves the sparsity structure untouched.
    mat.set((3, 0), 41);
    assert_eq!(mat.nnz(), 11);
    assert_eq!(mat.get((3, 0)), Some(&41));
}