use super::compress::build_compressed_format;
use super::csc::CscMatrix;
use crate::error::SolverError;
/// Sparse matrix stored in compressed sparse row (CSR) format.
///
/// The stored entries of row `i` occupy the half-open range
/// `row_ptr[i]..row_ptr[i + 1]` of both `col_ind` and `values`.
#[derive(Debug, Clone)]
pub struct CsrMatrix {
    /// Row offsets into `col_ind` / `values`; expected to have length
    /// `nrows + 1` (relied on by [`CsrMatrix::get_row`]).
    pub row_ptr: Vec<usize>,
    /// Column index of each stored entry, grouped by row.
    pub col_ind: Vec<usize>,
    /// Value of each stored entry, parallel to `col_ind`.
    pub values: Vec<f64>,
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
}

impl CsrMatrix {
    /// Returns the number of stored entries (the length of `values`).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Builds a CSR matrix from coordinate (COO) triplets.
    ///
    /// Entry `k` places `vals[k]` at `(rows[k], cols[k])`. This function only
    /// checks that the three arrays have equal length. Everything else
    /// (presumably index bounds, ordering and duplicate handling) is delegated
    /// to `build_compressed_format`.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::DimensionMismatch`] with `field: "triplet_arrays"`
    /// when `cols` or `vals` differs in length from `rows`. `expected` is
    /// `rows.len()` and `got` is the length of the first offending array
    /// (`cols` is checked before `vals`). Errors from `build_compressed_format`
    /// are propagated unchanged.
    pub fn from_triplets(
        rows: &[usize],
        cols: &[usize],
        vals: &[f64],
        nrows: usize,
        ncols: usize,
    ) -> Result<Self, SolverError> {
        let expected = rows.len();
        // Report whichever array actually disagrees with `rows`, so the error
        // never claims `expected == got` when only `cols` has the wrong length.
        let got = if cols.len() != expected { cols.len() } else { vals.len() };
        if got != expected {
            return Err(SolverError::DimensionMismatch {
                field: "triplet_arrays",
                expected,
                got,
            });
        }
        let (row_ptr, col_ind, values) =
            build_compressed_format(nrows, ncols, rows, cols, vals)?;
        Ok(Self { row_ptr, col_ind, values, nrows, ncols })
    }

    /// Returns the column indices and values stored in row `i`.
    ///
    /// Both returned slices cover the same range and therefore have equal
    /// length. They borrow from `self`.
    ///
    /// # Errors
    ///
    /// Returns [`SolverError::IndexOutOfBounds`] with `context: "row"` if
    /// `i >= self.nrows`.
    ///
    /// # Panics
    ///
    /// Panics if `row_ptr` is malformed, for example if it is shorter than
    /// `nrows + 1`, decreases, or points past the end of `col_ind` / `values`.
    pub fn get_row(&self, i: usize) -> Result<(&[usize], &[f64]), SolverError> {
        if i >= self.nrows {
            return Err(SolverError::IndexOutOfBounds { context: "row", index: i, bound: self.nrows });
        }
        let start = self.row_ptr[i];
        let end = self.row_ptr[i + 1];
        Ok((&self.col_ind[start..end], &self.values[start..end]))
    }

    /// Converts a compressed sparse column (CSC) matrix to CSR format.
    ///
    /// Runs in O(nnz + nrows + ncols) time using a counting pass over the row
    /// indices followed by one scatter pass. Columns are visited in increasing
    /// order, so the column indices within each output row come out sorted
    /// ascending. Repeated entries are copied through as-is.
    ///
    /// # Panics
    ///
    /// Panics if `csc` is internally inconsistent, for example if a row index
    /// is `>= csc.nrows` or `col_ptr` is shorter than `ncols + 1`.
    pub fn from_csc(csc: &CscMatrix) -> Self {
        let nnz = csc.nnz();
        let nrows = csc.nrows;
        let ncols = csc.ncols;

        // Pass 1: count entries per row (shifted by one), then prefix-sum
        // the counts into row offsets.
        let mut row_ptr = vec![0usize; nrows + 1];
        for &r in &csc.row_ind {
            row_ptr[r + 1] += 1;
        }
        for i in 0..nrows {
            row_ptr[i + 1] += row_ptr[i];
        }

        // Pass 2: scatter each CSC entry into the next free slot of its row.
        let mut col_ind = vec![0usize; nnz];
        let mut values = vec![0.0f64; nnz];
        let mut next_slot = row_ptr[..nrows].to_vec();
        for j in 0..ncols {
            for k in csc.col_ptr[j]..csc.col_ptr[j + 1] {
                let r = csc.row_ind[k];
                let pos = next_slot[r];
                col_ind[pos] = j;
                values[pos] = csc.values[k];
                next_slot[r] += 1;
            }
        }

        Self {
            row_ptr,
            col_ind,
            values,
            nrows,
            ncols,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that row `i` of `mat` stores exactly `cols` / `vals`, in order.
    fn assert_row(mat: &CsrMatrix, i: usize, cols: &[usize], vals: &[f64]) {
        let (ci, v) = mat.get_row(i).unwrap();
        assert_eq!(ci, cols, "column indices of row {}", i);
        assert_eq!(v, vals, "values of row {}", i);
    }

    // Checks `mat` against the 3x3 reference matrix shared by the tests below:
    //   [1 0 2]
    //   [0 3 0]
    //   [4 0 5]
    fn assert_reference_matrix(mat: &CsrMatrix) {
        assert_eq!(mat.nrows, 3);
        assert_eq!(mat.ncols, 3);
        assert_eq!(mat.nnz(), 5);
        assert_row(mat, 0, &[0, 2], &[1.0, 2.0]);
        assert_row(mat, 1, &[1], &[3.0]);
        assert_row(mat, 2, &[0, 2], &[4.0, 5.0]);
    }

    #[test]
    fn test_csr_from_triplets() {
        let rows = vec![0, 0, 1, 2, 2];
        let cols = vec![0, 2, 1, 0, 2];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mat = CsrMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        assert_reference_matrix(&mat);
    }

    #[test]
    fn test_csr_from_csc() {
        let rows = vec![0, 2, 1, 0, 2];
        let cols = vec![0, 0, 1, 2, 2];
        let vals = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let csc = CscMatrix::from_triplets(&rows, &cols, &vals, 3, 3).unwrap();
        let csr = CsrMatrix::from_csc(&csc);
        assert_reference_matrix(&csr);
    }

    #[test]
    fn test_from_triplets_rejects_mismatched_lengths() {
        let short_cols = CsrMatrix::from_triplets(&[0, 1], &[0], &[1.0, 2.0], 2, 2);
        assert!(matches!(short_cols, Err(SolverError::DimensionMismatch { .. })));
        let short_vals = CsrMatrix::from_triplets(&[0, 1], &[0, 1], &[1.0], 2, 2);
        assert!(matches!(short_vals, Err(SolverError::DimensionMismatch { .. })));
    }

    #[test]
    fn test_get_row_bounds_and_empty_row() {
        // Built by hand so the test does not depend on `build_compressed_format`.
        let mat = CsrMatrix {
            row_ptr: vec![0, 1, 1],
            col_ind: vec![0],
            values: vec![1.0],
            nrows: 2,
            ncols: 2,
        };
        assert_row(&mat, 0, &[0], &[1.0]);
        assert_row(&mat, 1, &[], &[]);
        assert!(matches!(
            mat.get_row(2),
            Err(SolverError::IndexOutOfBounds { index: 2, bound: 2, .. })
        ));
    }
}