use crate::error::FeralError;
#[cfg(test)]
thread_local! {
    /// Per-thread tally of `CscMatrix::clone` calls; exists only in test builds so
    /// tests can assert how many deep copies a code path performs.
    pub(crate) static CSC_MATRIX_CLONES: std::cell::Cell<usize> =
        const { std::cell::Cell::new(0) };
}

/// Zeroes the calling thread's clone tally (test builds only).
#[cfg(test)]
pub(crate) fn reset_csc_matrix_clones() {
    CSC_MATRIX_CLONES.with(|tally| tally.set(0));
}

/// Reads the calling thread's clone tally since the last reset (test builds only).
#[cfg(test)]
pub(crate) fn csc_matrix_clones() -> usize {
    CSC_MATRIX_CLONES.with(std::cell::Cell::get)
}
/// Square sparse matrix in compressed sparse column (CSC) form that stores only
/// the lower triangle (`row >= col`, diagonal included) of a symmetric matrix.
///
/// The fields are public, so a hand-built value is not checked automatically;
/// call `CscMatrix::validate` to verify the invariants that `from_triplets`
/// establishes.
// `Clone` is implemented by hand (not derived) so test builds can count clones.
#[derive(Debug)]
pub struct CscMatrix {
    /// Matrix dimension (number of rows == number of columns).
    pub n: usize,
    /// Column start offsets, expected length `n + 1`; column `j` occupies
    /// `col_ptr[j]..col_ptr[j + 1]` in `row_idx` / `values`.
    pub col_ptr: Vec<usize>,
    /// Row index of each stored entry; expected strictly increasing within a column.
    pub row_idx: Vec<usize>,
    /// Value of each stored entry, parallel to `row_idx`.
    pub values: Vec<f64>,
}
impl Clone for CscMatrix {
    /// Deep-copies the dimension and all three buffers.
    ///
    /// In test builds every call also increments this thread's
    /// `CSC_MATRIX_CLONES` tally, letting tests detect unintended copies.
    fn clone(&self) -> Self {
        #[cfg(test)]
        CSC_MATRIX_CLONES.with(|tally| {
            let so_far = tally.get();
            tally.set(so_far + 1);
        });
        let Self {
            n,
            col_ptr,
            row_idx,
            values,
        } = self;
        Self {
            n: *n,
            col_ptr: col_ptr.clone(),
            row_idx: row_idx.clone(),
            values: values.clone(),
        }
    }
}
/// Sparsity pattern (structure only, no values) in CSC form.
///
/// When produced by `CscMatrix::symmetric_pattern` it contains both triangles
/// of the symmetric matrix, with row indices sorted within each column.
#[derive(Debug, Clone)]
pub struct CscPattern {
    /// Matrix dimension.
    pub n: usize,
    /// Column start offsets into `row_idx`, expected length `n + 1`.
    pub col_ptr: Vec<usize>,
    /// Row indices of the pattern entries, grouped by column.
    pub row_idx: Vec<usize>,
}
impl CscMatrix {
    /// Returns the number of explicitly stored entries (lower triangle only,
    /// diagonal included).
    pub fn nnz(&self) -> usize {
        self.values.len()
    }

    /// Builds a lower-triangular CSC matrix from coordinate (triplet) arrays.
    ///
    /// Triplet `k` contributes `vals[k]` at position `(rows[k], cols[k])`. Only
    /// lower-triangle positions (`row >= col`) are accepted. On success, row
    /// indices are sorted within every column, and duplicate positions are
    /// merged by summing their values in the order they appear in the input.
    /// A sum that cancels to `0.0` is still stored as an explicit entry.
    ///
    /// # Errors
    ///
    /// Returns [`FeralError::InvalidInput`] if the three slices differ in
    /// length, if any column or row index is `>= n`, or if any triplet lies
    /// strictly above the diagonal.
    pub fn from_triplets(
        n: usize,
        rows: &[usize],
        cols: &[usize],
        vals: &[f64],
    ) -> Result<Self, FeralError> {
        if rows.len() != cols.len() || cols.len() != vals.len() {
            return Err(FeralError::InvalidInput(
                "triplet arrays must have equal length".to_string(),
            ));
        }

        // Pass 1: count entries per column, rejecting out-of-range columns
        // before anything is indexed by them.
        let mut col_counts = vec![0usize; n];
        for &c in cols {
            if c >= n {
                return Err(FeralError::InvalidInput(format!(
                    "column index {} out of bounds for n={}",
                    c, n
                )));
            }
            col_counts[c] += 1;
        }

        // Exclusive prefix sum: col_ptr[j] is where column j starts.
        let mut col_ptr = vec![0usize; n + 1];
        for j in 0..n {
            col_ptr[j + 1] = col_ptr[j] + col_counts[j];
        }

        // Pass 2: scatter triplets into their columns. Within a column the
        // entries keep input order; `sort_and_sum_duplicates` preserves that
        // order for equal rows when it sums duplicates.
        let nnz = col_ptr[n];
        let mut row_idx = vec![0usize; nnz];
        let mut values = vec![0.0f64; nnz];
        let mut offsets = col_ptr[..n].to_vec();
        for (k, ((&r, &c), &v)) in rows.iter().zip(cols).zip(vals).enumerate() {
            if r >= n {
                return Err(FeralError::InvalidInput(format!(
                    "row index {} out of bounds for n={}",
                    r, n
                )));
            }
            if r < c {
                return Err(FeralError::InvalidInput(format!(
                    "triplet {} ({}, {}) is upper-triangle; \
                     CscMatrix stores only the lower triangle (row >= col)",
                    k, r, c
                )));
            }
            let pos = offsets[c];
            row_idx[pos] = r;
            values[pos] = v;
            offsets[c] += 1;
        }

        let mut result = CscMatrix {
            n,
            col_ptr,
            row_idx,
            values,
        };
        result.sort_and_sum_duplicates();
        Ok(result)
    }

    /// Sorts row indices within every column and merges duplicate rows by
    /// summing their values.
    ///
    /// A stable sort is used, so duplicates are accumulated in their stored
    /// order (input order when called from `from_triplets`). Floating-point
    /// addition is not associative, so this keeps the result determined by the
    /// input order alone. Expects `col_ptr` to hold valid column offsets.
    fn sort_and_sum_duplicates(&mut self) {
        let nnz = self.row_idx.len();
        let mut new_row_idx = Vec::with_capacity(nnz);
        let mut new_values = Vec::with_capacity(nnz);
        let mut new_col_ptr = vec![0usize; self.n + 1];
        // Scratch buffer reused for every column instead of reallocating.
        let mut pairs: Vec<(usize, f64)> = Vec::new();
        for j in 0..self.n {
            let (start, end) = (self.col_ptr[j], self.col_ptr[j + 1]);
            pairs.clear();
            pairs.extend(
                self.row_idx[start..end]
                    .iter()
                    .copied()
                    .zip(self.values[start..end].iter().copied()),
            );
            pairs.sort_by_key(|&(r, _)| r);

            let mut entries = pairs.iter().copied();
            if let Some((mut cur_row, mut cur_val)) = entries.next() {
                for (r, v) in entries {
                    if r == cur_row {
                        cur_val += v;
                    } else {
                        new_row_idx.push(cur_row);
                        new_values.push(cur_val);
                        cur_row = r;
                        cur_val = v;
                    }
                }
                new_row_idx.push(cur_row);
                new_values.push(cur_val);
            }
            new_col_ptr[j + 1] = new_row_idx.len();
        }
        self.col_ptr = new_col_ptr;
        self.row_idx = new_row_idx;
        self.values = new_values;
    }

    /// Checks the structural invariants of a lower-triangular CSC matrix.
    ///
    /// The checks run in this order:
    /// 1. `col_ptr.len() == n + 1`.
    /// 2. `row_idx` and `values` have equal length.
    /// 3. `col_ptr[0] == 0`.
    /// 4. `col_ptr[n] == nnz`.
    /// 5. `col_ptr` is non-decreasing.
    /// 6. Column by column, every row index is `< n`, is on or below the
    ///    diagonal, and is strictly increasing.
    ///
    /// Numeric values are not inspected.
    ///
    /// # Errors
    ///
    /// Returns [`FeralError::InvalidInput`] describing the first violation found.
    pub fn validate(&self) -> Result<(), FeralError> {
        // `n` is a public field, so a corrupted matrix may carry
        // `n == usize::MAX`. Comparing against `len - 1` reports an error
        // instead of overflowing `n + 1`.
        if self.col_ptr.len().checked_sub(1) != Some(self.n) {
            return Err(FeralError::InvalidInput(format!(
                "col_ptr length {} != n+1={}",
                self.col_ptr.len(),
                // Widened so the expected length prints correctly even at usize::MAX.
                self.n as u128 + 1
            )));
        }
        if self.row_idx.len() != self.values.len() {
            return Err(FeralError::InvalidInput(
                "row_idx and values length mismatch".to_string(),
            ));
        }
        if self.col_ptr[0] != 0 {
            return Err(FeralError::InvalidInput(format!(
                "col_ptr[0] must be 0, got {}",
                self.col_ptr[0]
            )));
        }
        if self.col_ptr[self.n] != self.row_idx.len() {
            return Err(FeralError::InvalidInput("col_ptr[n] != nnz".to_string()));
        }
        if let Some(j) = self.col_ptr.windows(2).position(|w| w[1] < w[0]) {
            return Err(FeralError::InvalidInput(format!(
                "col_ptr not monotonically non-decreasing at column {} ({} > {})",
                j,
                self.col_ptr[j],
                self.col_ptr[j + 1]
            )));
        }
        // The checks above guarantee every column range lies inside row_idx.
        for j in 0..self.n {
            let col_rows = &self.row_idx[self.col_ptr[j]..self.col_ptr[j + 1]];
            for &r in col_rows {
                if r >= self.n {
                    return Err(FeralError::InvalidInput(format!(
                        "row index {} out of bounds in column {}",
                        r, j
                    )));
                }
                if r < j {
                    return Err(FeralError::InvalidInput(format!(
                        "row index {} in column {} is upper-triangle; \
                         CscMatrix stores only the lower triangle (row >= col)",
                        r, j
                    )));
                }
            }
            if let Some(w) = col_rows.windows(2).find(|w| w[1] <= w[0]) {
                return Err(FeralError::InvalidInput(format!(
                    "row indices not sorted in column {} ({}>={})",
                    j, w[0], w[1]
                )));
            }
        }
        Ok(())
    }

    /// Returns the sparsity pattern of the full symmetric matrix.
    ///
    /// Every stored entry `(i, j)` appears in column `j`. Each off-diagonal
    /// entry is also mirrored as `(j, i)` into column `i`, so the result holds
    /// both triangles. Row indices are sorted within each column.
    ///
    /// Expects a matrix that passes [`validate`](Self::validate). A row index
    /// `>= n` panics, and duplicate stored entries would be duplicated in the
    /// pattern.
    pub fn symmetric_pattern(&self) -> CscPattern {
        // Count pattern entries per column: one for the stored position, plus
        // one in the mirrored column for each off-diagonal entry.
        let mut col_counts = vec![0usize; self.n];
        for j in 0..self.n {
            for &i in &self.row_idx[self.col_ptr[j]..self.col_ptr[j + 1]] {
                col_counts[j] += 1;
                if i != j {
                    col_counts[i] += 1;
                }
            }
        }

        let mut pat_col_ptr = vec![0usize; self.n + 1];
        for j in 0..self.n {
            pat_col_ptr[j + 1] = pat_col_ptr[j] + col_counts[j];
        }

        let pat_nnz = pat_col_ptr[self.n];
        let mut pat_row_idx = vec![0usize; pat_nnz];
        let mut offsets = pat_col_ptr[..self.n].to_vec();
        for j in 0..self.n {
            for &i in &self.row_idx[self.col_ptr[j]..self.col_ptr[j + 1]] {
                pat_row_idx[offsets[j]] = i;
                offsets[j] += 1;
                if i != j {
                    pat_row_idx[offsets[i]] = j;
                    offsets[i] += 1;
                }
            }
        }

        // Sorting keeps each column ordered even if the input's row indices
        // were not sorted.
        for j in 0..self.n {
            pat_row_idx[pat_col_ptr[j]..pat_col_ptr[j + 1]].sort_unstable();
        }

        CscPattern {
            n: self.n,
            col_ptr: pat_col_ptr,
            row_idx: pat_row_idx,
        }
    }

    /// Computes `y = A * x`, where `A` is the full symmetric matrix whose lower
    /// triangle is stored in `self`.
    ///
    /// Only `y[..n]` is overwritten; entries of `y` beyond `n` are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is too short for an index that a stored entry
    /// refers to. Normally both should have length at least `n`.
    pub fn symv(&self, x: &[f64], y: &mut [f64]) {
        y.iter_mut().take(self.n).for_each(|yi| *yi = 0.0);
        for j in 0..self.n {
            let (start, end) = (self.col_ptr[j], self.col_ptr[j + 1]);
            for (&i, &v) in self.row_idx[start..end]
                .iter()
                .zip(&self.values[start..end])
            {
                // Stored (i, j) contributes to row i; its mirror (j, i) to row j.
                y[i] += v * x[j];
                if i != j {
                    y[j] += v * x[i];
                }
            }
        }
    }

    /// Converts to a dense `SymmetricMatrix` in a freshly allocated buffer.
    ///
    /// Equivalent to `self.to_dense_into(Vec::new())`.
    pub fn to_dense(&self) -> crate::dense::matrix::SymmetricMatrix {
        self.to_dense_into(Vec::new())
    }

    /// Converts to a dense `SymmetricMatrix`, reusing `buf`'s allocation.
    ///
    /// `buf` is cleared and resized to `n * n` zeros. Each stored entry
    /// `(i, j)` is then written at index `j * n + i`. Only stored entries are
    /// written, which is the lower triangle for a valid matrix, so the strict
    /// upper triangle of the buffer stays zero.
    // NOTE(review): this presumes `SymmetricMatrix` reads symmetric entries from
    // the lower triangle — confirm against `dense::matrix`.
    pub fn to_dense_into(&self, mut buf: Vec<f64>) -> crate::dense::matrix::SymmetricMatrix {
        let n = self.n;
        buf.clear();
        buf.resize(n * n, 0.0);
        for j in 0..n {
            let col = j * n;
            for k in self.col_ptr[j]..self.col_ptr[j + 1] {
                buf[col + self.row_idx[k]] = self.values[k];
            }
        }
        crate::dense::matrix::SymmetricMatrix { n, data: buf }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // 3x3 symmetric tridiagonal matrix, lower triangle stored:
    //   [ 2 -1  0]
    //   [-1  3 -1]
    //   [ 0 -1  4]
    fn sample_3x3() -> CscMatrix {
        CscMatrix::from_triplets(
            3,
            &[0, 1, 1, 2, 2],
            &[0, 0, 1, 1, 2],
            &[2.0, -1.0, 3.0, -1.0, 4.0],
        )
        .unwrap()
    }

    #[test]
    fn test_from_triplets_basic() {
        let m = sample_3x3();
        assert_eq!(m.n, 3);
        assert_eq!(m.nnz(), 5);
        m.validate().unwrap();
    }

    #[test]
    fn test_from_triplets_duplicate_summing() {
        let m = CscMatrix::from_triplets(2, &[0, 0, 1], &[0, 0, 1], &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(m.nnz(), 2);
        // (0, 0) appears twice: 1.0 + 2.0.
        assert_eq!(m.values[0], 3.0);
        assert_eq!(m.values[1], 3.0);
    }

    #[test]
    fn test_from_triplets_sorts_and_sums_unsorted_duplicates() {
        let m = CscMatrix::from_triplets(
            3,
            &[2, 1, 2, 1],
            &[0, 0, 0, 1],
            &[1.0, 2.0, 3.0, 4.0],
        )
        .unwrap();
        m.validate().unwrap();
        assert_eq!(m.col_ptr, vec![0, 2, 3, 3]);
        assert_eq!(m.row_idx, vec![1, 2, 1]);
        assert_eq!(m.values, vec![2.0, 4.0, 4.0]);
    }

    #[test]
    fn test_from_triplets_rejects_length_mismatch() {
        let err = CscMatrix::from_triplets(2, &[0, 1], &[0], &[1.0, 2.0])
            .expect_err("mismatched triplet lengths must be rejected");
        let msg = format!("{}", err);
        assert!(msg.contains("equal length"), "got: {}", msg);
    }

    #[test]
    fn test_from_triplets_rejects_out_of_bounds_indices() {
        let col_err = CscMatrix::from_triplets(2, &[1], &[2], &[1.0])
            .expect_err("column index >= n must be rejected");
        let col_msg = format!("{}", col_err);
        assert!(
            col_msg.contains("column index 2 out of bounds"),
            "got: {}",
            col_msg
        );

        let row_err = CscMatrix::from_triplets(2, &[3], &[0], &[1.0])
            .expect_err("row index >= n must be rejected");
        let row_msg = format!("{}", row_err);
        assert!(
            row_msg.contains("row index 3 out of bounds"),
            "got: {}",
            row_msg
        );
    }

    #[test]
    fn test_symmetric_pattern() {
        let m = sample_3x3();
        let pat = m.symmetric_pattern();
        assert_eq!(pat.n, 3);
        assert_eq!(pat.col_ptr[3], 7);
        assert_eq!(&pat.row_idx[pat.col_ptr[0]..pat.col_ptr[1]], &[0, 1]);
        assert_eq!(&pat.row_idx[pat.col_ptr[1]..pat.col_ptr[2]], &[0, 1, 2]);
        assert_eq!(&pat.row_idx[pat.col_ptr[2]..pat.col_ptr[3]], &[1, 2]);
    }

    #[test]
    fn test_symv() {
        let m = sample_3x3();
        let x = [1.0, 2.0, 3.0];
        let mut y = [0.0; 3];
        m.symv(&x, &mut y);
        assert!((y[0] - 0.0).abs() < 1e-14);
        assert!((y[1] - 2.0).abs() < 1e-14);
        assert!((y[2] - 10.0).abs() < 1e-14);
    }

    #[test]
    fn test_to_dense_roundtrip() {
        let m = sample_3x3();
        let dense = m.to_dense();
        assert_eq!(dense.get(0, 0), 2.0);
        assert_eq!(dense.get(1, 0), -1.0);
        assert_eq!(dense.get(0, 1), -1.0);
        assert_eq!(dense.get(1, 1), 3.0);
        assert_eq!(dense.get(2, 1), -1.0);
        assert_eq!(dense.get(1, 2), -1.0);
        assert_eq!(dense.get(2, 2), 4.0);
        assert_eq!(dense.get(2, 0), 0.0);
    }

    #[test]
    fn test_to_dense_into_reuses_and_resets_buffer() {
        let m = sample_3x3();
        // Wrong length and stale contents: both must be replaced.
        let dense = m.to_dense_into(vec![9.0; 20]);
        assert_eq!(dense.n, 3);
        assert_eq!(dense.data.len(), 9);
        assert_eq!(dense.get(2, 0), 0.0);
        assert_eq!(dense.get(0, 2), 0.0);
        assert_eq!(dense.get(1, 1), 3.0);
    }

    #[test]
    fn test_validate_rejects_bad_input() {
        let mut m = sample_3x3();
        // Row index beyond n = 3.
        m.row_idx[0] = 5;
        assert!(m.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_unsorted_and_duplicate_rows() {
        let unsorted = CscMatrix {
            n: 2,
            col_ptr: vec![0, 2, 2],
            row_idx: vec![1, 0],
            values: vec![1.0, 1.0],
        };
        let msg = format!(
            "{}",
            unsorted
                .validate()
                .expect_err("unsorted rows must be rejected")
        );
        assert!(msg.contains("not sorted"), "got: {}", msg);

        let duplicate = CscMatrix {
            n: 2,
            col_ptr: vec![0, 2, 2],
            row_idx: vec![1, 1],
            values: vec![1.0, 1.0],
        };
        let msg = format!(
            "{}",
            duplicate
                .validate()
                .expect_err("duplicate rows must be rejected")
        );
        assert!(msg.contains("not sorted"), "got: {}", msg);
    }

    #[test]
    fn test_validate_rejects_shape_mismatches() {
        let short_col_ptr = CscMatrix {
            n: 2,
            col_ptr: vec![0, 1],
            row_idx: vec![0],
            values: vec![1.0],
        };
        let msg = format!(
            "{}",
            short_col_ptr
                .validate()
                .expect_err("col_ptr of wrong length must be rejected")
        );
        assert!(msg.contains("col_ptr length"), "got: {}", msg);

        let mismatched = CscMatrix {
            n: 1,
            col_ptr: vec![0, 1],
            row_idx: vec![0],
            values: vec![],
        };
        let msg = format!(
            "{}",
            mismatched
                .validate()
                .expect_err("row_idx/values length mismatch must be rejected")
        );
        assert!(msg.contains("length mismatch"), "got: {}", msg);

        let bad_end = CscMatrix {
            n: 1,
            col_ptr: vec![0, 2],
            row_idx: vec![0],
            values: vec![1.0],
        };
        let msg = format!(
            "{}",
            bad_end
                .validate()
                .expect_err("col_ptr[n] != nnz must be rejected")
        );
        assert!(msg.contains("col_ptr[n]"), "got: {}", msg);
    }

    #[test]
    fn test_from_triplets_rejects_upper_triangle() {
        let lower = CscMatrix::from_triplets(2, &[0, 1, 1], &[0, 0, 1], &[2.0, 1.0, 2.0]).unwrap();
        lower.validate().unwrap();
        let err = CscMatrix::from_triplets(2, &[0, 0, 1], &[0, 1, 1], &[2.0, 1.0, 2.0])
            .expect_err("upper-triangle triplet must be rejected");
        let msg = format!("{}", err);
        assert!(
            msg.contains("upper-triangle"),
            "error should mention upper-triangle, got: {}",
            msg
        );
    }

    #[test]
    fn test_validate_rejects_upper_triangle_row() {
        let mut m = sample_3x3();
        // Entry 2 is in column 1; row 0 would be above the diagonal.
        m.row_idx[2] = 0;
        let err = m
            .validate()
            .expect_err("validate must reject upper-triangle row");
        let msg = format!("{}", err);
        assert!(
            msg.contains("upper-triangle"),
            "error should mention upper-triangle, got: {}",
            msg
        );
    }

    #[test]
    fn test_diagonal_matrix() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2], &[0, 1, 2], &[1.0, 2.0, 3.0]).unwrap();
        assert_eq!(m.nnz(), 3);
        let pat = m.symmetric_pattern();
        // No off-diagonal entries, so nothing is mirrored.
        assert_eq!(pat.col_ptr[3], 3);
    }

    #[test]
    fn test_empty_matrix() {
        let m = CscMatrix::from_triplets(3, &[], &[], &[]).unwrap();
        assert_eq!(m.nnz(), 0);
        m.validate().unwrap();
        let pat = m.symmetric_pattern();
        assert_eq!(pat.col_ptr[3], 0);
    }

    #[test]
    fn test_kkt_structure() {
        let m = CscMatrix::from_triplets(
            3,
            &[0, 1, 2, 2, 2],
            &[0, 1, 0, 1, 2],
            &[2.0, 3.0, 1.0, 1.0, -1e-8],
        )
        .unwrap();
        assert_eq!(m.nnz(), 5);
        m.validate().unwrap();
        let x = [1.0, 1.0, 1.0];
        let mut y = [0.0; 3];
        m.symv(&x, &mut y);
        // Row 0: 2 + 1; row 1: 3 + 1; row 2: 1 + 1 - 1e-8.
        assert!((y[0] - 3.0).abs() < 1e-14);
        assert!((y[1] - 4.0).abs() < 1e-14);
        assert!((y[2] - (2.0 - 1e-8)).abs() < 1e-14);
    }

    #[test]
    fn test_clone_counter_tracks_clones() {
        reset_csc_matrix_clones();
        let m = sample_3x3();
        let copy = m.clone();
        assert_eq!(csc_matrix_clones(), 1);
        assert_eq!(copy.n, m.n);
        assert_eq!(copy.col_ptr, m.col_ptr);
        assert_eq!(copy.row_idx, m.row_idx);
        assert_eq!(copy.values, m.values);
    }

    #[test]
    fn validate_rejects_non_monotone_col_ptr() {
        let m = CscMatrix {
            n: 3,
            col_ptr: vec![0, 2, 1, 2],
            row_idx: vec![0, 2],
            values: vec![1.0, 1.0],
        };
        let err = m
            .validate()
            .expect_err("non-monotone col_ptr must be rejected (X6)");
        let msg = format!("{}", err);
        assert!(
            msg.contains("col_ptr") && msg.contains("monoton"),
            "error should mention non-monotone col_ptr, got: {}",
            msg
        );
    }

    #[test]
    fn validate_rejects_nonzero_col_ptr_start() {
        let m = CscMatrix {
            n: 2,
            col_ptr: vec![1, 1, 2],
            row_idx: vec![0, 1],
            values: vec![1.0, 1.0],
        };
        let err = m
            .validate()
            .expect_err("a col_ptr that does not start at 0 must be rejected (X6)");
        let msg = format!("{}", err);
        assert!(
            msg.contains("col_ptr[0]"),
            "error should mention col_ptr[0], got: {}",
            msg
        );
    }
}