use std::ops::{Add, AddAssign, Mul, MulAssign};
use ferrolearn_core::FerroError;
use ndarray::{Array1, Array2, ArrayView2};
use num_traits::Zero;
use sprs::CsMat;
use crate::coo::CooMatrix;
use crate::csr::CsrMatrix;
/// Sparse matrix stored in compressed sparse column (CSC) format.
///
/// Thin wrapper around a `sprs` [`CsMat`]. The constructors in this module
/// build the inner matrix in CSC storage, and `from_inner` checks this in
/// debug builds.
#[derive(Debug, Clone)]
pub struct CscMatrix<T> {
    /// Underlying `sprs` matrix; intended to always use CSC storage.
    inner: CsMat<T>,
}
impl<T> CscMatrix<T>
where
T: Clone,
{
pub fn new(
n_rows: usize,
n_cols: usize,
indptr: Vec<usize>,
indices: Vec<usize>,
data: Vec<T>,
) -> Result<Self, FerroError> {
CsMat::try_new_csc((n_rows, n_cols), indptr, indices, data)
.map(|inner| Self { inner })
.map_err(|(_, _, _, err)| FerroError::InvalidParameter {
name: "CscMatrix raw components".into(),
reason: err.to_string(),
})
}
pub(crate) fn from_inner(inner: CsMat<T>) -> Self {
debug_assert!(inner.is_csc(), "inner matrix must be in CSC storage");
Self { inner }
}
pub fn n_rows(&self) -> usize {
self.inner.rows()
}
pub fn n_cols(&self) -> usize {
self.inner.cols()
}
pub fn nnz(&self) -> usize {
self.inner.nnz()
}
pub fn inner(&self) -> &CsMat<T> {
&self.inner
}
pub fn into_inner(self) -> CsMat<T> {
self.inner
}
pub fn from_coo(coo: &CooMatrix<T>) -> Result<Self, FerroError>
where
T: Clone + Add<Output = T> + 'static,
{
let inner: CsMat<T> = coo.inner().to_csc();
Ok(Self { inner })
}
pub fn from_csr(csr: &CsrMatrix<T>) -> Result<Self, FerroError>
where
T: Clone + Default + 'static,
{
Ok(csr.to_csc())
}
pub fn to_csr(&self) -> CsrMatrix<T>
where
T: Clone + Default + 'static,
{
CsrMatrix::from_csc(self).unwrap()
}
pub fn to_coo(&self) -> CooMatrix<T> {
let mut coo = CooMatrix::with_capacity(self.n_rows(), self.n_cols(), self.nnz());
for (val, (r, c)) in &self.inner {
let _ = coo.push(r, c, val.clone());
}
coo
}
pub fn to_dense(&self) -> Array2<T>
where
T: Clone + Zero + 'static,
{
self.inner.to_dense()
}
pub fn from_dense(dense: &ArrayView2<'_, T>, epsilon: T) -> Self
where
T: Copy + Zero + PartialOrd + num_traits::Signed + 'static,
{
let inner = CsMat::csc_from_dense(dense.view(), epsilon);
Self { inner }
}
pub fn col_slice(&self, start: usize, end: usize) -> Result<CscMatrix<T>, FerroError>
where
T: Clone + Default + 'static,
{
if start > end {
return Err(FerroError::InvalidParameter {
name: "col_slice range".into(),
reason: format!("start ({start}) must be <= end ({end})"),
});
}
if end > self.n_cols() {
return Err(FerroError::InvalidParameter {
name: "col_slice range".into(),
reason: format!("end ({end}) exceeds n_cols ({})", self.n_cols()),
});
}
let view = self.inner.slice_outer(start..end);
Ok(Self {
inner: view.to_owned(),
})
}
pub fn scale(&mut self, scalar: T)
where
for<'r> T: MulAssign<&'r T>,
{
self.inner.scale(scalar);
}
pub fn mul_scalar(&self, scalar: T) -> CscMatrix<T>
where
T: Copy + Mul<Output = T> + Zero + 'static,
{
let new_inner = self.inner.map(|&v| v * scalar);
Self { inner: new_inner }
}
pub fn add(&self, rhs: &CscMatrix<T>) -> Result<CscMatrix<T>, FerroError>
where
T: Zero + Default + Clone + 'static,
for<'r> &'r T: Add<&'r T, Output = T>,
{
if self.n_rows() != rhs.n_rows() || self.n_cols() != rhs.n_cols() {
return Err(FerroError::ShapeMismatch {
expected: vec![self.n_rows(), self.n_cols()],
actual: vec![rhs.n_rows(), rhs.n_cols()],
context: "CscMatrix::add".into(),
});
}
let result = &self.inner + &rhs.inner;
Ok(Self { inner: result })
}
pub fn mul_vec(&self, rhs: &Array1<T>) -> Result<Array1<T>, FerroError>
where
T: Clone + Zero + 'static,
for<'r> &'r T: Mul<Output = T>,
T: AddAssign,
{
if rhs.len() != self.n_cols() {
return Err(FerroError::ShapeMismatch {
expected: vec![self.n_cols()],
actual: vec![rhs.len()],
context: "CscMatrix::mul_vec".into(),
});
}
let result = &self.inner * rhs;
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// 3x3 fixture whose dense form is
    /// `[[1, 0, 2], [0, 3, 0], [4, 0, 5]]`.
    fn sample_csc() -> CscMatrix<f64> {
        CscMatrix::new(
            3,
            3,
            vec![0, 2, 3, 5],
            vec![0, 2, 1, 0, 2],
            vec![1.0, 4.0, 3.0, 2.0, 5.0],
        )
        .unwrap()
    }

    #[test]
    fn test_new_valid() {
        let m = sample_csc();
        assert_eq!(m.n_rows(), 3);
        assert_eq!(m.n_cols(), 3);
        assert_eq!(m.nnz(), 5);
    }

    #[test]
    fn test_new_rejects_bad_indptr_length() {
        // indptr must have n_cols + 1 entries; here it has only 2 for 2 columns.
        let r = CscMatrix::<f64>::new(2, 2, vec![0, 1], vec![0], vec![1.0]);
        assert!(r.is_err());
    }

    #[test]
    fn test_new_rejects_out_of_bounds_row() {
        // Row index 5 is outside a 2-row matrix.
        let r = CscMatrix::<f64>::new(2, 2, vec![0, 1, 1], vec![5], vec![1.0]);
        assert!(r.is_err());
    }

    #[test]
    fn test_to_dense() {
        let m = sample_csc();
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[0, 2]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
        assert_abs_diff_eq!(d[[2, 0]], 4.0);
        assert_abs_diff_eq!(d[[2, 2]], 5.0);
    }

    #[test]
    fn test_from_dense() {
        let dense = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let m = CscMatrix::from_dense(&dense.view(), 0.0);
        assert_eq!(m.nnz(), 2);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 1.0);
        assert_abs_diff_eq!(back[[1, 1]], 2.0);
    }

    #[test]
    fn test_from_dense_epsilon_threshold() {
        // 0.5 falls below epsilon = 1.0 and is dropped; 2.0 is kept.
        let dense = array![[0.5_f64, 0.0], [0.0, 2.0]];
        let m = CscMatrix::from_dense(&dense.view(), 1.0);
        assert_eq!(m.nnz(), 1);
        let back = m.to_dense();
        assert_abs_diff_eq!(back[[0, 0]], 0.0);
        assert_abs_diff_eq!(back[[1, 1]], 2.0);
    }

    #[test]
    fn test_from_coo_roundtrip() {
        let mut coo: CooMatrix<f64> = CooMatrix::new(3, 3);
        coo.push(0, 0, 1.0).unwrap();
        coo.push(1, 2, 4.0).unwrap();
        coo.push(2, 1, 7.0).unwrap();
        let csc = CscMatrix::from_coo(&coo).unwrap();
        let dense = csc.to_dense();
        assert_abs_diff_eq!(dense[[0, 0]], 1.0);
        assert_abs_diff_eq!(dense[[1, 2]], 4.0);
        assert_abs_diff_eq!(dense[[2, 1]], 7.0);
    }

    #[test]
    fn test_to_coo_roundtrip() {
        let csc = sample_csc();
        let coo = csc.to_coo();
        let back = CscMatrix::from_coo(&coo).unwrap();
        assert_eq!(back.nnz(), csc.nnz());
        assert_eq!(back.to_dense(), csc.to_dense());
    }

    #[test]
    fn test_csc_csr_roundtrip() {
        let csc = sample_csc();
        let csr = csc.to_csr();
        let back = CscMatrix::from_csr(&csr).unwrap();
        assert_eq!(back.to_dense(), csc.to_dense());
    }

    #[test]
    fn test_col_slice() {
        let m = sample_csc();
        let sliced = m.col_slice(0, 2).unwrap();
        assert_eq!(sliced.n_rows(), 3);
        assert_eq!(sliced.n_cols(), 2);
        let d = sliced.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 1.0);
        assert_abs_diff_eq!(d[[1, 1]], 3.0);
    }

    #[test]
    fn test_col_slice_from_nonzero_start() {
        // Starting past column 0 exercises an indptr slice that does not begin at 0.
        let m = sample_csc();
        let sliced = m.col_slice(1, 3).unwrap();
        assert_eq!(sliced.n_rows(), 3);
        assert_eq!(sliced.n_cols(), 2);
        assert_eq!(sliced.nnz(), 3);
        let d = sliced.to_dense();
        assert_abs_diff_eq!(d[[1, 0]], 3.0);
        assert_abs_diff_eq!(d[[0, 1]], 2.0);
        assert_abs_diff_eq!(d[[2, 1]], 5.0);
        assert_abs_diff_eq!(d[[0, 0]], 0.0);
    }

    #[test]
    fn test_col_slice_empty() {
        let m = sample_csc();
        let sliced = m.col_slice(1, 1).unwrap();
        assert_eq!(sliced.n_cols(), 0);
    }

    #[test]
    fn test_col_slice_invalid() {
        let m = sample_csc();
        assert!(m.col_slice(2, 1).is_err());
        assert!(m.col_slice(0, 4).is_err());
    }

    #[test]
    fn test_mul_scalar() {
        let m = sample_csc();
        let m2 = m.mul_scalar(2.0);
        let d = m2.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
    }

    #[test]
    fn test_scale_in_place() {
        let mut m = sample_csc();
        m.scale(3.0);
        let d = m.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 3.0);
        assert_abs_diff_eq!(d[[2, 2]], 15.0);
    }

    #[test]
    fn test_add() {
        let m = sample_csc();
        let sum = m.add(&m).unwrap();
        let d = sum.to_dense();
        assert_abs_diff_eq!(d[[0, 0]], 2.0);
        assert_abs_diff_eq!(d[[1, 1]], 6.0);
    }

    #[test]
    fn test_add_shape_mismatch() {
        let m1 = sample_csc();
        let m2 = CscMatrix::new(2, 3, vec![0, 0, 0, 0], vec![], vec![]).unwrap();
        assert!(m1.add(&m2).is_err());
    }

    #[test]
    fn test_mul_vec() {
        let m = sample_csc();
        let v = Array1::from(vec![1.0_f64, 2.0, 3.0]);
        let result = m.mul_vec(&v).unwrap();
        assert_abs_diff_eq!(result[0], 7.0);
        assert_abs_diff_eq!(result[1], 6.0);
        assert_abs_diff_eq!(result[2], 19.0);
    }

    #[test]
    fn test_mul_vec_shape_mismatch() {
        let m = sample_csc();
        let v = Array1::from(vec![1.0_f64, 2.0]);
        assert!(m.mul_vec(&v).is_err());
    }
}
#[cfg(kani)]
mod kani_proofs {
    //! Bounded model-checking harnesses for `CscMatrix`.
    //!
    //! Harnesses whose inputs are valid CSC by construction unwrap results
    //! with `expect` instead of guarding with `if let Ok(..)`. A guard would
    //! let a harness pass vacuously if construction ever started failing.
    use super::*;
    use crate::coo::CooMatrix;

    /// Upper bound on every symbolic dimension.
    ///
    /// Small values keep the loops here, and in the `sprs` code they call,
    /// within `#[kani::unwind(5)]`.
    const MAX_DIM: usize = 3;

    /// Asserts the CSC structural invariants of `m`:
    /// - `indptr` has `n_cols + 1` entries, starts at 0, is non-decreasing,
    ///   and ends at the number of stored entries;
    /// - every row index is `< n_rows`;
    /// - `indices` and `data` have equal length.
    ///
    /// `T: Clone` is required because `inner`, `n_rows` and `n_cols` are
    /// defined in an impl bounded by `T: Clone`.
    fn assert_csc_invariants<T: Clone>(m: &CscMatrix<T>) {
        let inner = m.inner();
        let indptr = inner.indptr();
        let indptr_raw = indptr.raw_storage();
        assert!(indptr_raw.len() == m.n_cols() + 1);
        assert!(indptr_raw[0] == 0);
        for i in 0..m.n_cols() {
            assert!(indptr_raw[i] <= indptr_raw[i + 1]);
        }
        assert!(indptr_raw[m.n_cols()] == inner.indices().len());
        for &row_idx in inner.indices() {
            assert!(row_idx < m.n_rows());
        }
        assert!(inner.indices().len() == inner.data().len());
    }

    /// An all-empty structure is accepted and keeps `n_cols + 1` indptr entries.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indptr_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];
        let m = CscMatrix::new(n_rows, n_cols, indptr, indices, data)
            .expect("an all-empty CSC structure is valid");
        let inner_indptr = m.inner().indptr();
        assert!(inner_indptr.raw_storage().len() == n_cols + 1);
    }

    /// A single symbolic entry is accepted and yields a non-decreasing indptr.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indptr_monotonic() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let row: usize = kani::any();
        let col: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);
        // Column `col` holds the single entry, so every later pointer is 1.
        let mut indptr = vec![0usize; n_cols + 1];
        for i in (col + 1)..=n_cols {
            indptr[i] = 1;
        }
        let indices = vec![row];
        let data = vec![42i32];
        let m = CscMatrix::new(n_rows, n_cols, indptr, indices, data)
            .expect("a single in-bounds entry is a valid CSC structure");
        let inner_indptr = m.inner().indptr().raw_storage().to_vec();
        for i in 0..m.n_cols() {
            assert!(inner_indptr[i] <= inner_indptr[i + 1]);
        }
    }

    /// A single symbolic entry is accepted and its row index stays in bounds.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_row_indices_in_bounds() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let row: usize = kani::any();
        let col: usize = kani::any();
        kani::assume(row < n_rows);
        kani::assume(col < n_cols);
        let mut indptr = vec![0usize; n_cols + 1];
        for i in (col + 1)..=n_cols {
            indptr[i] = 1;
        }
        let indices = vec![row];
        let data = vec![1i32];
        let m = CscMatrix::new(n_rows, n_cols, indptr, indices, data)
            .expect("a single in-bounds entry is a valid CSC structure");
        for &r in m.inner().indices() {
            assert!(r < m.n_rows());
        }
    }

    /// `indices` and `data` of an accepted matrix have equal length.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_indices_data_same_length() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let indices: Vec<usize> = vec![];
        let data: Vec<i32> = vec![];
        let m = CscMatrix::new(n_rows, n_cols, indptr, indices, data)
            .expect("an all-empty CSC structure is valid");
        assert!(m.inner().indices().len() == m.inner().data().len());
    }

    /// Mismatched `indices`/`data` lengths are rejected.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_new_rejects_mismatched_lengths() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let indices = vec![0usize];
        let data: Vec<i32> = vec![];
        let result = CscMatrix::new(n_rows, n_cols, indptr, indices, data);
        assert!(result.is_err());
    }

    /// COO -> CSC conversion preserves shape, entry count and CSC invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_from_coo_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let mut coo = CooMatrix::<i32>::new(n_rows, n_cols);
        let do_insert: bool = kani::any();
        if do_insert {
            let row: usize = kani::any();
            let col: usize = kani::any();
            kani::assume(row < n_rows);
            kani::assume(col < n_cols);
            assert!(coo.push(row, col, 1i32).is_ok());
        }
        let csc = CscMatrix::from_coo(&coo).expect("from_coo does not fail");
        assert_csc_invariants(&csc);
        assert!(csc.n_rows() == n_rows);
        assert!(csc.n_cols() == n_cols);
        assert!(csc.nnz() == usize::from(do_insert));
    }

    /// Adding two empty matrices of equal shape yields an empty valid matrix.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_add_preserves_invariants() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let a = CscMatrix::<i32>::new(n_rows, n_cols, indptr.clone(), vec![], vec![])
            .expect("an all-empty CSC structure is valid");
        let b = CscMatrix::<i32>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an all-empty CSC structure is valid");
        let sum = a.add(&b).expect("operands have identical shapes");
        assert!(sum.n_rows() == n_rows);
        assert!(sum.n_cols() == n_cols);
        assert!(sum.nnz() == 0);
        assert_csc_invariants(&sum);
    }

    /// Adding matrices with disjoint entries keeps both entries and the invariants.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_add_nonempty_preserves_invariants() {
        let a = CscMatrix::<i32>::new(2, 2, vec![0, 1, 1], vec![0], vec![1])
            .expect("valid 2x2 CSC with entry (0, 0)");
        let b = CscMatrix::<i32>::new(2, 2, vec![0, 0, 1], vec![1], vec![2])
            .expect("valid 2x2 CSC with entry (1, 1)");
        let sum = a.add(&b).expect("operands have identical shapes");
        assert!(sum.n_rows() == 2);
        assert!(sum.n_cols() == 2);
        assert!(sum.nnz() == 2);
        assert_csc_invariants(&sum);
    }

    /// `mul_vec` with a correctly sized vector returns `n_rows` values.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_output_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let m = CscMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an all-empty CSC structure is valid");
        let v = Array1::<f64>::zeros(n_cols);
        let result = m.mul_vec(&v).expect("vector length equals n_cols");
        assert!(result.len() == n_rows);
    }

    /// `mul_vec` rejects any vector whose length differs from `n_cols`.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_rejects_wrong_dimension() {
        let n_rows: usize = kani::any();
        let n_cols: usize = kani::any();
        kani::assume(n_rows > 0 && n_rows <= MAX_DIM);
        kani::assume(n_cols > 0 && n_cols <= MAX_DIM);
        let indptr = vec![0usize; n_cols + 1];
        let m = CscMatrix::<f64>::new(n_rows, n_cols, indptr, vec![], vec![])
            .expect("an all-empty CSC structure is valid");
        let wrong_len: usize = kani::any();
        kani::assume(wrong_len <= MAX_DIM);
        kani::assume(wrong_len != n_cols);
        let v = Array1::<f64>::zeros(wrong_len);
        let result = m.mul_vec(&v);
        assert!(result.is_err());
    }

    /// `mul_vec` on a concrete non-empty matrix stays in bounds and is exact.
    #[kani::proof]
    #[kani::unwind(5)]
    fn csc_mul_vec_nonempty_no_oob() {
        // Dense form: [[0, 3, 0], [0, 0, 4]].
        let m = CscMatrix::<f64>::new(2, 3, vec![0, 0, 1, 2], vec![0, 1], vec![3.0, 4.0])
            .expect("valid 2x3 CSC structure");
        let v = Array1::from(vec![1.0, 2.0, 3.0]);
        let result = m.mul_vec(&v).expect("vector length equals n_cols");
        assert!(result.len() == 2);
        // Small integers are exact in f64, so exact comparison is sound here.
        assert!(result[0] == 6.0);
        assert!(result[1] == 12.0);
    }
}