use super::compress::build_compressed_format;
use crate::error::SolverError;
/// A sparse matrix stored in compressed sparse column (CSC) format.
///
/// Layout relied on by every method below:
/// - `col_ptr` has `ncols + 1` entries; the stored entries of column `j` occupy
///   positions `col_ptr[j]..col_ptr[j + 1]` of `row_ind` and `values`.
/// - `row_ind` and `values` are parallel arrays of length [`nnz`](Self::nnz):
///   `values[k]` sits at row `row_ind[k]`. Methods index per-row buffers with
///   `row_ind[k]`, so every row index is assumed to be `< nrows`.
#[derive(Debug, Clone)]
pub struct CscMatrix {
    /// Offset of each column's first entry in `row_ind`/`values` (length `ncols + 1`).
    pub(crate) col_ptr: Vec<usize>,
    /// Row index of every stored entry, grouped by column.
    pub(crate) row_ind: Vec<usize>,
    /// Value of every stored entry, parallel to `row_ind`.
    pub(crate) values: Vec<f64>,
    /// Number of rows.
    pub(crate) nrows: usize,
    /// Number of columns.
    pub(crate) ncols: usize,
}

impl CscMatrix {
    /// Creates an `nrows x ncols` matrix with no stored entries.
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            col_ptr: vec![0; ncols + 1],
            row_ind: Vec::new(),
            values: Vec::new(),
            nrows,
            ncols,
        }
    }

    /// Number of stored entries (explicitly stored zeros included).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Column start offsets into [`row_ind`](Self::row_ind) / [`values`](Self::values)
    /// (length `ncols + 1`).
    pub fn col_ptr(&self) -> &[usize] {
        &self.col_ptr
    }

    /// Row index of every stored entry, grouped by column.
    pub fn row_ind(&self) -> &[usize] {
        &self.row_ind
    }

    /// Value of every stored entry, parallel to [`row_ind`](Self::row_ind).
    pub fn values(&self) -> &[f64] {
        &self.values
    }

    /// Returns a copy of the matrix with every stored value multiplied by `factor`.
    ///
    /// The sparsity pattern is copied unchanged, so entries that become zero
    /// (e.g. when `factor == 0.0`) remain stored.
    pub fn scale_values(&self, factor: f64) -> Self {
        Self {
            col_ptr: self.col_ptr.clone(),
            row_ind: self.row_ind.clone(),
            values: self.values.iter().map(|&v| v * factor).collect(),
            nrows: self.nrows,
            ncols: self.ncols,
        }
    }

    /// Number of rows.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Computes the infinity norm (largest absolute stored value) of every row.
    ///
    /// Rows without stored entries get `0.0`. The running maximum is replaced
    /// only when `|value|` compares strictly greater, so NaN entries are ignored
    /// rather than propagated into the result.
    pub fn row_infinity_norms(&self) -> Vec<f64> {
        let mut norms = vec![0.0_f64; self.nrows];
        for (&val, &row) in self.values.iter().zip(&self.row_ind) {
            let abs_val = val.abs();
            if abs_val > norms[row] {
                norms[row] = abs_val;
            }
        }
        norms
    }

    /// Builds an `nrows x ncols` matrix from coordinate (triplet) arrays.
    ///
    /// Entry `k` places `vals[k]` at `(rows[k], cols[k])`. Compression is
    /// delegated to `build_compressed_format`. The unit tests in this file expect
    /// three things from it: duplicate coordinates are summed, row indices within
    /// each column come out ascending, and out-of-range indices are rejected.
    ///
    /// # Errors
    ///
    /// - [`SolverError::DimensionMismatch`] with field `"triplet_arrays"` when the
    ///   three slices differ in length. `expected` is `rows.len()`, and `got` is
    ///   the length of the first slice (`cols`, then `vals`) that disagrees with it.
    /// - Any error returned by `build_compressed_format`, propagated unchanged.
    pub fn from_triplets(
        rows: &[usize],
        cols: &[usize],
        vals: &[f64],
        nrows: usize,
        ncols: usize,
    ) -> Result<Self, SolverError> {
        let expected = rows.len();
        if cols.len() != expected || vals.len() != expected {
            // Report the length that actually disagrees with `rows`, so the error
            // never claims `expected == got` when only `cols` is off.
            let got = if cols.len() != expected {
                cols.len()
            } else {
                vals.len()
            };
            return Err(SolverError::DimensionMismatch {
                field: "triplet_arrays",
                expected,
                got,
            });
        }
        // NOTE(review): assumes the helper takes (compressed-axis length,
        // other-axis length, compressed-axis indices, other-axis indices, values).
        // Columns are the compressed axis in CSC. Confirm against `compress`.
        let (col_ptr, row_ind, values) =
            build_compressed_format(ncols, nrows, cols, rows, vals)?;
        Ok(Self {
            col_ptr,
            row_ind,
            values,
            nrows,
            ncols,
        })
    }

    /// Returns the transpose as a new CSC matrix of shape `ncols x nrows`.
    ///
    /// Uses a counting sort over row indices, in O(nnz + nrows + ncols) time.
    /// The columns of `self` are visited in ascending order, so the row indices
    /// inside each column of the result come out in ascending order.
    pub fn transpose(&self) -> Self {
        let nnz = self.nnz();

        // Row `r` of `self` becomes column `r` of the result. Count its entries
        // into slot `r + 1`, then prefix-sum the counts into start offsets.
        let mut col_ptr = vec![0usize; self.nrows + 1];
        for &r in &self.row_ind {
            col_ptr[r + 1] += 1;
        }
        for r in 0..self.nrows {
            col_ptr[r + 1] += col_ptr[r];
        }

        let mut row_ind = vec![0usize; nnz];
        let mut values = vec![0.0f64; nnz];
        // Next free slot in each output column.
        let mut next = col_ptr[..self.nrows].to_vec();
        for col in 0..self.ncols {
            for k in self.col_ptr[col]..self.col_ptr[col + 1] {
                let dest = &mut next[self.row_ind[k]];
                row_ind[*dest] = col;
                values[*dest] = self.values[k];
                *dest += 1;
            }
        }

        Self {
            col_ptr,
            row_ind,
            values,
            nrows: self.ncols,
            ncols: self.nrows,
        }
    }

    /// Computes the matrix-vector product `y = A * x`.
    ///
    /// # Errors
    ///
    /// [`SolverError::DimensionMismatch`] with field `"vector"` if `x.len() != ncols`.
    pub fn mat_vec_mul(&self, x: &[f64]) -> Result<Vec<f64>, SolverError> {
        if x.len() != self.ncols {
            return Err(SolverError::DimensionMismatch {
                field: "vector",
                expected: self.ncols,
                got: x.len(),
            });
        }
        let mut y = vec![0.0; self.nrows];
        // Column-oriented traversal: add x[col] * A[:, col] into y, one column at a time.
        for (bounds, &x_val) in self.col_ptr.windows(2).zip(x) {
            let (start, end) = (bounds[0], bounds[1]);
            let col_rows = &self.row_ind[start..end];
            let col_vals = &self.values[start..end];
            for (&row, &a_val) in col_rows.iter().zip(col_vals) {
                y[row] += a_val * x_val;
            }
        }
        Ok(y)
    }

    /// Returns the row indices and values stored in column `j`.
    ///
    /// # Errors
    ///
    /// [`SolverError::IndexOutOfBounds`] with context `"column"` if `j >= ncols`.
    pub fn get_column(&self, j: usize) -> Result<(&[usize], &[f64]), SolverError> {
        if j >= self.ncols {
            return Err(SolverError::IndexOutOfBounds {
                context: "column",
                index: j,
                bound: self.ncols,
            });
        }
        let range = self.col_ptr[j]..self.col_ptr[j + 1];
        Ok((&self.row_ind[range.clone()], &self.values[range]))
    }

    /// Creates the `n x n` identity matrix (one stored `1.0` per column).
    pub fn identity(n: usize) -> Self {
        Self {
            col_ptr: (0..=n).collect(),
            row_ind: (0..n).collect(),
            values: vec![1.0; n],
            nrows: n,
            ncols: n,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the 3x3 matrix shared by several tests:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn sample_3x3() -> CscMatrix {
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap()
    }

    #[test]
    fn test_from_triplets_basic() {
        let mat = sample_3x3();
        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[1.0, 4.0]);
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
        let (row_idx, values) = mat.get_column(2).unwrap();
        assert_eq!(row_idx, &[0, 2]);
        assert_eq!(values, &[2.0, 5.0]);
    }

    #[test]
    fn test_accessors_expose_csc_layout() {
        let mat = sample_3x3();
        assert_eq!(mat.nrows(), 3);
        assert_eq!(mat.ncols(), 3);
        assert_eq!(mat.col_ptr(), &[0, 2, 3, 5]);
        assert_eq!(mat.row_ind(), &[0, 2, 1, 0, 2]);
        assert_eq!(mat.values(), &[1.0, 4.0, 3.0, 2.0, 5.0]);
    }

    #[test]
    fn test_from_triplets_duplicate_entries() {
        let rows = vec![0, 0, 1];
        let cols = vec![0, 0, 1];
        let vals = vec![1.0, 2.0, 3.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 2).unwrap();
        // Duplicates at (0, 0) are summed: 1.0 + 2.0.
        let (row_idx, values) = mat.get_column(0).unwrap();
        assert_eq!(row_idx, &[0]);
        assert_eq!(values, &[3.0]);
        let (row_idx, values) = mat.get_column(1).unwrap();
        assert_eq!(row_idx, &[1]);
        assert_eq!(values, &[3.0]);
    }

    #[test]
    fn test_new_is_empty() {
        let mat = CscMatrix::new(2, 3);
        assert_eq!(mat.nrows(), 2);
        assert_eq!(mat.ncols(), 3);
        assert_eq!(mat.nnz(), 0);
        assert_eq!(mat.col_ptr(), &[0, 0, 0, 0]);
        assert_eq!(mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap(), vec![0.0, 0.0]);
    }

    #[test]
    fn test_scale_values() {
        let mat = sample_3x3();
        let scaled = mat.scale_values(2.0);
        assert_eq!(scaled.nrows(), mat.nrows());
        assert_eq!(scaled.ncols(), mat.ncols());
        assert_eq!(scaled.col_ptr(), mat.col_ptr());
        assert_eq!(scaled.row_ind(), mat.row_ind());
        assert_eq!(scaled.values(), &[2.0, 8.0, 6.0, 4.0, 10.0]);
        // The source matrix is left untouched.
        assert_eq!(mat.values(), &[1.0, 4.0, 3.0, 2.0, 5.0]);
    }

    #[test]
    fn test_row_infinity_norms() {
        // [ 1 0 2]
        // [ 0 3 0]
        // [-4 0 5]
        // [ 0 0 0]   <- row with no stored entries
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, -4.0, 3.0, 2.0, 5.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 4, 3).unwrap();
        assert_eq!(mat.row_infinity_norms(), vec![2.0, 3.0, 5.0, 0.0]);
    }

    #[test]
    fn test_transpose() {
        let rows = vec![0, 0, 1];
        let cols = vec![0, 1, 2];
        let vals = vec![1.0, 2.0, 3.0];
        let mat = CscMatrix::from_triplets(&rows, &cols, &vals, 2, 3).unwrap();
        let mat_t = mat.transpose();
        assert_eq!(mat_t.nrows, 3);
        assert_eq!(mat_t.ncols, 2);
        assert_eq!(mat_t.nnz(), 3);
        let (row_idx, values) = mat_t.get_column(0).unwrap();
        assert_eq!(row_idx, &[0, 1]);
        assert_eq!(values, &[1.0, 2.0]);
        let (row_idx, values) = mat_t.get_column(1).unwrap();
        assert_eq!(row_idx, &[2]);
        assert_eq!(values, &[3.0]);
        let mat_tt = mat_t.transpose();
        assert_eq!(mat_tt.nrows, mat.nrows);
        assert_eq!(mat_tt.ncols, mat.ncols);
        assert_eq!(mat_tt.row_ind, mat.row_ind);
        assert_eq!(mat_tt.col_ptr, mat.col_ptr);
        assert_eq!(mat_tt.values, mat.values);
    }

    #[test]
    fn test_transpose_rectangular_mat_vec() {
        // A = [1 2 0; 0 0 3], so A * [1, 2, 3] = [5, 9] and A^T * [1, 2] = [1, 2, 6].
        let mat =
            CscMatrix::from_triplets(&[0, 0, 1], &[0, 1, 2], &[1.0, 2.0, 3.0], 2, 3).unwrap();
        assert_eq!(mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap(), vec![5.0, 9.0]);
        let y = mat.transpose().mat_vec_mul(&[1.0, 2.0]).unwrap();
        assert_eq!(y, vec![1.0, 2.0, 6.0]);
    }

    #[test]
    fn test_transpose_empty() {
        let mat_t = CscMatrix::new(2, 3).transpose();
        assert_eq!(mat_t.nrows(), 3);
        assert_eq!(mat_t.ncols(), 2);
        assert_eq!(mat_t.nnz(), 0);
        assert_eq!(mat_t.col_ptr(), &[0, 0, 0]);
    }

    #[test]
    fn test_mat_vec_mul() {
        let mat = sample_3x3();
        let x = vec![1.0, 2.0, 3.0];
        let y = mat.mat_vec_mul(&x).unwrap();
        assert_eq!(y.len(), 3);
        assert!((y[0] - 7.0).abs() < 1e-10);
        assert!((y[1] - 6.0).abs() < 1e-10);
        assert!((y[2] - 19.0).abs() < 1e-10);
    }

    #[test]
    fn test_mat_vec_mul_dimension_mismatch() {
        let mat = CscMatrix::identity(3);
        let x = vec![1.0, 2.0];
        let result = mat.mat_vec_mul(&x);
        assert!(matches!(
            result,
            Err(SolverError::DimensionMismatch { field: "vector", expected: 3, got: 2 })
        ));
    }

    #[test]
    fn test_identity() {
        let id = CscMatrix::identity(4);
        assert_eq!(id.nrows, 4);
        assert_eq!(id.ncols, 4);
        assert_eq!(id.nnz(), 4);
        for j in 0..4 {
            let (row_idx, values) = id.get_column(j).unwrap();
            assert_eq!(row_idx, &[j]);
            assert_eq!(values, &[1.0]);
        }
        let x = vec![1.0, 2.0, 3.0, 4.0];
        let y = id.mat_vec_mul(&x).unwrap();
        assert_eq!(y, x);
    }

    #[test]
    fn test_empty_matrix() {
        let mat = CscMatrix::from_triplets(&[], &[], &[], 2, 3).unwrap();
        assert_eq!(mat.nrows, 2);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 0);
        for j in 0..3 {
            let (row_idx, values) = mat.get_column(j).unwrap();
            assert_eq!(row_idx.len(), 0);
            assert_eq!(values.len(), 0);
        }
        let y = mat.mat_vec_mul(&[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(y, vec![0.0, 0.0]);
    }

    #[test]
    fn test_get_column_out_of_bounds() {
        let mat = CscMatrix::identity(3);
        let result = mat.get_column(3);
        assert!(matches!(
            result,
            Err(SolverError::IndexOutOfBounds { context: "column", index: 3, bound: 3 })
        ));
    }

    #[test]
    fn test_from_triplets_out_of_bounds() {
        // Row index 3 is out of range for 3 rows.
        let result = CscMatrix::from_triplets(&[0, 3], &[0, 0], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
        // Column index 2 is out of range for 2 columns.
        let result = CscMatrix::from_triplets(&[0, 0], &[0, 2], &[1.0, 2.0], 3, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_from_triplets_mismatched_lengths() {
        let result = CscMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(matches!(
            result,
            Err(SolverError::DimensionMismatch { field: "triplet_arrays", .. })
        ));
    }
}