use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
use std::ops::{Add, Mul};
use crate::error::SparseResult;
use crate::sym_coo::SymCooMatrix;
use crate::sym_csr::SymCsrMatrix;
use scirs2_core::parallel_ops::*;
#[allow(dead_code)]
/// Computes `y = A * x` for a symmetric CSR matrix that stores only its
/// lower triangle, picking a parallel or serial kernel based on problem size.
///
/// # Errors
///
/// Returns `SparseError::DimensionMismatch` when `x.len()` differs from the
/// matrix dimension.
pub fn sym_csr_matvec<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (rows, _) = matrix.shape();
    if x.len() != rows {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: rows,
            found: x.len(),
        });
    }
    // Below this many stored entries, thread spawn/merge overhead outweighs
    // any parallel speedup, so fall back to the scalar kernel.
    if matrix.nnz() < 1000 {
        sym_csr_matvec_scalar(matrix, x)
    } else {
        sym_csr_matvec_parallel(matrix, x)
    }
}
#[allow(dead_code)]
/// Parallel symmetric CSR matrix-vector product (`y = A * x`).
///
/// Only the lower triangle is stored, so each stored off-diagonal entry
/// contributes twice: once for its own row and once mirrored into the
/// symmetric column position. Rows are split into chunks; each chunk
/// accumulates into its own full-length buffer (the mirrored update can
/// touch any column, not just columns in the chunk), and the buffers are
/// summed serially at the end so no two threads ever write the same slot.
///
/// NOTE(review): assumes `x.len()` matches the matrix dimension — the public
/// wrapper `sym_csr_matvec` validates this before dispatching here.
fn sym_csr_matvec_parallel<T>(
    matrix: &SymCsrMatrix<T>,
    x: &ArrayView1<T>,
) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Send + Sync,
{
    let (n, _) = matrix.shape();
    let mut y = Array1::zeros(n);
    // Roughly n / num_threads rows per chunk, clamped to [1, 256] so the
    // O(n)-per-chunk merge cost stays bounded while all threads get work.
    let chunk_size = std::cmp::max(1, n / scirs2_core::parallel_ops::get_num_threads()).min(256);
    let chunks: Vec<_> = (0..n)
        .collect::<Vec<_>>()
        .chunks(chunk_size)
        .map(|chunk| chunk.to_vec())
        .collect();
    let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
        // Thread-local accumulator sized to the full result vector.
        let mut local_y = Array1::zeros(n);
        for &row_i in row_chunk {
            let row_start = matrix.indptr[row_i];
            let row_end = matrix.indptr[row_i + 1];
            let mut sum = T::sparse_zero();
            for j in row_start..row_end {
                let col = matrix.indices[j];
                let val = matrix.data[j];
                // Row contribution from the stored (lower-triangle) entry.
                sum = sum + val * x[col];
                // Mirror off-diagonal entries into the symmetric position.
                if row_i != col {
                    local_y[col] = local_y[col] + val * x[row_i];
                }
            }
            local_y[row_i] = local_y[row_i] + sum;
        }
        local_y
    });
    // Serial reduction of the per-chunk buffers into the final result.
    for local_y in results {
        for i in 0..n {
            y[i] = y[i] + local_y[i];
        }
    }
    Ok(y)
}
#[allow(dead_code)]
/// Serial symmetric CSR matrix-vector product (`y = A * x`).
///
/// Only the lower triangle is stored; every stored off-diagonal entry is
/// applied twice, once directly and once mirrored across the diagonal.
/// Dimension validation is the caller's responsibility.
fn sym_csr_matvec_scalar<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (dim, _) = matrix.shape();
    let mut result = Array1::zeros(dim);
    for row in 0..dim {
        let start = matrix.indptr[row];
        let end = matrix.indptr[row + 1];
        for idx in start..end {
            let col = matrix.indices[idx];
            let entry = matrix.data[idx];
            // Stored (lower-triangle) contribution to this row.
            result[row] = result[row] + entry * x[col];
            // Mirrored upper-triangle contribution for off-diagonal entries.
            if row != col {
                result[col] = result[col] + entry * x[row];
            }
        }
    }
    Ok(result)
}
#[allow(dead_code)]
/// Computes `y = A * x` for a symmetric COO matrix that stores only its
/// lower triangle; off-diagonal entries are mirrored across the diagonal.
///
/// # Errors
///
/// Returns `SparseError::DimensionMismatch` when `x.len()` differs from the
/// matrix dimension.
pub fn sym_coo_matvec<T>(matrix: &SymCooMatrix<T>, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (dim, _) = matrix.shape();
    if x.len() != dim {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: dim,
            found: x.len(),
        });
    }
    let mut result = Array1::zeros(dim);
    // Walk the parallel (rows, cols, data) triplet arrays in lockstep.
    let triplets = matrix
        .rows
        .iter()
        .zip(matrix.cols.iter())
        .zip(matrix.data.iter());
    for ((&row, &col), &value) in triplets {
        result[row] = result[row] + value * x[col];
        // Apply the symmetric counterpart of each off-diagonal entry.
        if row != col {
            result[col] = result[col] + value * x[row];
        }
    }
    Ok(result)
}
#[allow(dead_code)]
/// Applies the symmetric rank-1 update `A += alpha * x * x^T` in place.
///
/// The update is performed on a dense copy of the lower triangle and the
/// result is re-sparsified, so the matrix's sparsity pattern may grow or
/// shrink (entries that become exactly zero are dropped).
///
/// # Errors
///
/// Returns `SparseError::DimensionMismatch` when `x.len()` differs from the
/// matrix dimension.
pub fn sym_csr_rank1_update<T>(
    matrix: &mut SymCsrMatrix<T>,
    x: &ArrayView1<T>,
    alpha: T,
) -> SparseResult<()>
where
    T: Float
        + SparseElement
        + Debug
        + Copy
        + Add<Output = T>
        + Mul<Output = T>
        + std::ops::AddAssign,
{
    let (dim, _) = matrix.shape();
    if x.len() != dim {
        return Err(crate::error::SparseError::DimensionMismatch {
            expected: dim,
            found: x.len(),
        });
    }
    // Densify, then add alpha * x_i * x_j over the lower triangle only —
    // symmetry makes the upper triangle redundant.
    let mut dense = matrix.to_dense();
    for row in 0..dim {
        for col in 0..=row {
            dense[row][col] += alpha * x[row] * x[col];
        }
    }
    // Rebuild CSR storage from the updated lower triangle, skipping entries
    // that are exactly zero.
    let mut new_data = Vec::new();
    let mut new_indices = Vec::new();
    let mut new_indptr = Vec::with_capacity(dim + 1);
    new_indptr.push(0);
    for (row, dense_row) in dense.iter().enumerate().take(dim) {
        for (col, &value) in dense_row.iter().enumerate().take(row + 1) {
            if value != T::sparse_zero() {
                new_data.push(value);
                new_indices.push(col);
            }
        }
        new_indptr.push(new_data.len());
    }
    matrix.data = new_data;
    matrix.indices = new_indices;
    matrix.indptr = new_indptr;
    Ok(())
}
#[allow(dead_code)]
/// Evaluates the quadratic form `x^T * A * x` for a symmetric CSR matrix.
///
/// # Errors
///
/// Propagates `SparseError::DimensionMismatch` from the underlying matvec
/// when `x.len()` differs from the matrix dimension.
pub fn sym_csr_quadratic_form<T>(matrix: &SymCsrMatrix<T>, x: &ArrayView1<T>) -> SparseResult<T>
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T> + Mul<Output = T> + Send + Sync,
{
    // Reuse the symmetric matvec for A*x, then take the dot product with x.
    let ax = sym_csr_matvec(matrix, x)?;
    let value = x
        .iter()
        .zip(ax.iter())
        .fold(T::sparse_zero(), |acc, (&xi, &axi)| acc + xi * axi);
    Ok(value)
}
#[allow(dead_code)]
/// Returns the trace (sum of diagonal entries) of a symmetric CSR matrix.
///
/// Rows with no stored diagonal entry contribute zero.
pub fn sym_csr_trace<T>(matrix: &SymCsrMatrix<T>) -> T
where
    T: Float + SparseElement + Debug + Copy + Add<Output = T>,
{
    let (dim, _) = matrix.shape();
    let mut total = T::sparse_zero();
    for row in 0..dim {
        let start = matrix.indptr[row];
        let end = matrix.indptr[row + 1];
        // At most one diagonal entry can be stored per row; stop at the first.
        if let Some(pos) = (start..end).find(|&k| matrix.indices[k] == row) {
            total = total + matrix.data[pos];
        }
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sym_coo::SymCooMatrix;
    use crate::sym_csr::SymCsrMatrix;
    use crate::AsLinearOperator; use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;
    // Shared test matrix (symmetric, lower triangle stored):
    //     [2 1 0]
    // A = [1 2 3]
    //     [0 3 1]
    fn create_test_sym_csr() -> SymCsrMatrix<f64> {
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let indices = vec![0, 0, 1, 1, 2];
        let indptr = vec![0, 1, 3, 5];
        SymCsrMatrix::new(data, indptr, indices, (3, 3)).expect("Operation failed")
    }
    // Same matrix A as above, expressed in symmetric COO form
    // (lower-triangle triplets only).
    fn create_test_sym_coo() -> SymCooMatrix<f64> {
        let data = vec![2.0, 1.0, 2.0, 3.0, 1.0];
        let rows = vec![0, 1, 1, 2, 2];
        let cols = vec![0, 0, 1, 1, 2];
        SymCooMatrix::new(data, rows, cols, (3, 3)).expect("Operation failed")
    }
    #[test]
    fn test_sym_csr_matvec() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = sym_csr_matvec(&matrix, &x.view()).expect("Operation failed");
        assert_eq!(y.len(), 3);
        // A*x = [2+2, 1+4+9, 6+3] = [4, 14, 9].
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }
    #[test]
    fn test_sym_coo_matvec() {
        let matrix = create_test_sym_coo();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = sym_coo_matvec(&matrix, &x.view()).expect("Operation failed");
        assert_eq!(y.len(), 3);
        // COO result must match the CSR result for the same A and x.
        assert_relative_eq!(y[0], 4.0);
        assert_relative_eq!(y[1], 14.0);
        assert_relative_eq!(y[2], 9.0);
    }
    #[test]
    fn test_sym_csr_rank1_update() {
        let mut matrix = create_test_sym_csr();
        // x = e0, alpha = 3: only entry (0,0) changes, 2 + 3*1*1 = 5.
        let x = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let alpha = 3.0;
        sym_csr_rank1_update(&mut matrix, &x.view(), alpha).expect("Operation failed");
        assert_relative_eq!(matrix.get(0, 0), 5.0);
        assert_relative_eq!(matrix.get(0, 1), 1.0);
        assert_relative_eq!(matrix.get(1, 1), 2.0);
        assert_relative_eq!(matrix.get(1, 2), 3.0);
        assert_relative_eq!(matrix.get(2, 2), 1.0);
    }
    #[test]
    fn test_sym_csr_quadratic_form() {
        let matrix = create_test_sym_csr();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = sym_csr_quadratic_form(&matrix, &x.view()).expect("Operation failed");
        // x^T A x = x . (A x) = 1*4 + 2*14 + 3*9 = 59.
        assert_relative_eq!(result, 59.0);
    }
    #[test]
    fn test_sym_csr_trace() {
        let matrix = create_test_sym_csr();
        let trace = sym_csr_trace(&matrix);
        // Diagonal of A is [2, 2, 1], so trace = 5.
        assert_relative_eq!(trace, 5.0);
    }
    // Cross-check: the triangle-only symmetric kernel must agree with a
    // standard matvec on the fully expanded CSR matrix.
    #[test]
    fn test_compare_with_standard_matvec() {
        let sym_csr = create_test_sym_csr();
        let full_csr = sym_csr.to_csr().expect("Operation failed");
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y_optimized = sym_csr_matvec(&sym_csr, &x.view()).expect("Operation failed");
        let linear_op = full_csr.as_linear_operator();
        let y_standard = linear_op
            .matvec(x.as_slice().expect("Operation failed"))
            .expect("Operation failed");
        for i in 0..y_optimized.len() {
            assert_relative_eq!(y_optimized[i], y_standard[i]);
        }
    }
}