use rayon::prelude::*;
/// Number of consecutive rows of `y` that the SpMV kernels below split off as
/// one unit of parallel work (via `par_chunks_mut`); the rows inside a chunk
/// are then processed sequentially by a single closure call.
const CHUNK_SIZE: usize = 512;
/// Accumulates a sparse matrix–vector product into `y`: `y += A * x`.
///
/// `A` is given in CSR form with one row per element of `y`. Row `r` owns the
/// entries `indptr[r]..indptr[r + 1]` of `indices` (column numbers into `x`)
/// and `data` (the matching values). Rows are processed in parallel with
/// rayon, `CHUNK_SIZE` rows of `y` per chunk. Each row's products are summed
/// left to right in storage order, starting from `0.0`, and that sum is then
/// added to the existing value of `y[r]`. Calling this twice therefore
/// accumulates twice.
///
/// # Panics
///
/// - Panics if a row's entry range lies outside `indices` or `data`, or is
///   decreasing (`indptr[r] > indptr[r + 1]`).
/// - Panics if a column index is out of bounds for `x`.
/// - `indptr`/`indices` values are converted with `as usize`, so negative
///   values wrap to huge indices and panic on access.
/// - In debug builds, also panics when `indptr.len() != y.len() + 1` or
///   `indices.len() != data.len()`.
pub fn parallel_csr_spmv_add(
    indptr: &[i32],
    indices: &[i32],
    data: &[f64],
    x: &[f64],
    y: &mut [f64],
) {
    // Same style of shape checks as `parallel_csr_multi_spmv_add`: catch
    // mismatched CSR arrays early instead of computing a partial product.
    debug_assert_eq!(
        indptr.len(),
        y.len() + 1,
        "indptr must have one entry per row of y plus one"
    );
    debug_assert_eq!(
        indices.len(),
        data.len(),
        "indices and data must have the same length"
    );
    y.par_chunks_mut(CHUNK_SIZE)
        .enumerate()
        .for_each(|(chunk_idx, chunk)| {
            let row_start = chunk_idx * CHUNK_SIZE;
            for (i, yi) in chunk.iter_mut().enumerate() {
                let r = row_start + i;
                let start = indptr[r] as usize;
                let end = indptr[r + 1] as usize;
                // Slicing the row once lets the zip walk `indices`/`data`
                // without a bounds check per entry; `fold` keeps the exact
                // left-to-right summation order of a plain loop from +0.0.
                let sum = indices[start..end]
                    .iter()
                    .zip(&data[start..end])
                    .fold(0.0_f64, |acc, (&col, &val)| acc + val * x[col as usize]);
                *yi += sum;
            }
        });
}
/// Accumulates a sum of sparse matrix–vector products into `y`:
/// `y += A_0 * x_0 + A_1 * x_1 + ... + A_{k-1} * x_{k-1}`.
///
/// Block `b` is a CSR matrix described by `indptr_blocks[b]`,
/// `indices_blocks[b]` and `data_blocks[b]`, laid out as in
/// `parallel_csr_spmv_add`, and is multiplied by `x_blocks[b]`. Every block
/// must have one row per element of `y`.
///
/// Rows of `y` are processed in parallel, `CHUNK_SIZE` rows per chunk, and
/// each row is visited once. The contributions of all blocks go into a single
/// running sum that starts at `0.0`, in block order and then storage order.
/// That sum is added to `y[r]` once at the end. Because rounding is applied to
/// the combined sum, the result may differ in the last bits from calling
/// `parallel_csr_spmv_add` once per block.
///
/// With zero blocks, `0.0` is added to every element of `y`.
///
/// # Panics
///
/// - Panics if a block has fewer `indices_blocks`, `data_blocks` or
///   `x_blocks` entries than `indptr_blocks`.
/// - Panics if a row's entry range lies outside that block's `indices`/`data`,
///   or is decreasing.
/// - Panics if a column index is out of bounds for that block's `x`.
/// - In debug builds, also panics when the four block lists differ in length,
///   or when any `indptr` block does not have `y.len() + 1` entries.
pub fn parallel_csr_multi_spmv_add(
    indptr_blocks: &[&[i32]],
    indices_blocks: &[&[i32]],
    data_blocks: &[&[f64]],
    x_blocks: &[&[f64]],
    y: &mut [f64],
) {
    let n_blocks = indptr_blocks.len();
    debug_assert_eq!(n_blocks, indices_blocks.len());
    debug_assert_eq!(n_blocks, data_blocks.len());
    debug_assert_eq!(n_blocks, x_blocks.len());
    let n_rows = y.len();
    debug_assert!(
        indptr_blocks.iter().all(|indptr| indptr.len() == n_rows + 1),
        "every indptr block must have one entry per row of y plus one"
    );
    y.par_chunks_mut(CHUNK_SIZE)
        .enumerate()
        .for_each(|(chunk_idx, chunk)| {
            let row_start = chunk_idx * CHUNK_SIZE;
            for (i, yi) in chunk.iter_mut().enumerate() {
                let r = row_start + i;
                let mut sum = 0.0_f64;
                for (b, &indptr) in indptr_blocks.iter().enumerate() {
                    let (indices, data, x) = (indices_blocks[b], data_blocks[b], x_blocks[b]);
                    let start = indptr[r] as usize;
                    let end = indptr[r + 1] as usize;
                    // Thread the running sum through `fold` so the
                    // accumulation order matches a single flat loop over all
                    // blocks, while the row slices avoid per-entry bounds
                    // checks on `indices`/`data`.
                    sum = indices[start..end]
                        .iter()
                        .zip(&data[start..end])
                        .fold(sum, |acc, (&col, &val)| acc + val * x[col as usize]);
                }
                *yi += sum;
            }
        });
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic_csr_spmv() {
        let indptr: Vec<i32> = vec![0, 2, 3, 5];
        let indices: Vec<i32> = vec![0, 2, 1, 0, 2];
        let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let x: Vec<f64> = vec![1.0, 1.0, 1.0];
        let mut y: Vec<f64> = vec![0.0, 0.0, 0.0];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![3.0, 3.0, 9.0]);
    }

    #[test]
    fn test_empty_row() {
        let indptr: Vec<i32> = vec![0, 1, 1, 2];
        let indices: Vec<i32> = vec![0, 1];
        let data: Vec<f64> = vec![10.0, 20.0];
        let x: Vec<f64> = vec![1.0, 2.0];
        let mut y: Vec<f64> = vec![100.0, 100.0, 100.0];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![110.0, 100.0, 140.0]);
    }

    #[test]
    fn test_accumulates_into_y() {
        let indptr: Vec<i32> = vec![0, 1, 2];
        let indices: Vec<i32> = vec![0, 0];
        let data: Vec<f64> = vec![3.0, 5.0];
        let x: Vec<f64> = vec![2.0];
        let mut y: Vec<f64> = vec![0.0, 0.0];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![6.0, 10.0]);
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y, vec![12.0, 20.0]);
    }

    #[test]
    fn test_large_dense_diagonal() {
        let n = 1024;
        let indptr: Vec<i32> = (0..=n).map(|i| i as i32).collect();
        let indices: Vec<i32> = (0..n).map(|i| i as i32).collect();
        let data: Vec<f64> = (0..n).map(|i| (i as f64) + 1.0).collect();
        let x: Vec<f64> = vec![1.0; n];
        let mut y: Vec<f64> = vec![0.0; n];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert_eq!(y.len(), n);
        for (i, &yi) in y.iter().enumerate() {
            assert_eq!(yi, (i as f64) + 1.0);
        }
    }

    #[test]
    fn test_partial_last_chunk() {
        // One full chunk followed by a 3-row tail chunk. Row r holds a single
        // entry (r + 1) in column (r + 1) % n, so rows in the tail chunk
        // reference columns from the first chunk and vice versa.
        let n = CHUNK_SIZE + 3;
        let indptr: Vec<i32> = (0..=n).map(|i| i as i32).collect();
        let indices: Vec<i32> = (0..n).map(|i| ((i + 1) % n) as i32).collect();
        let data: Vec<f64> = (0..n).map(|i| (i as f64) + 1.0).collect();
        let x: Vec<f64> = (0..n).map(|j| j as f64).collect();
        let mut y = vec![1.0_f64; n];
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        for (r, &yr) in y.iter().enumerate() {
            let expected = 1.0 + ((r as f64) + 1.0) * (((r + 1) % n) as f64);
            assert_eq!(yr, expected);
        }
    }

    #[test]
    fn test_empty_matrix() {
        // A 0-row CSR matrix still has a single-element indptr.
        let indptr: Vec<i32> = vec![0];
        let indices: Vec<i32> = Vec::new();
        let data: Vec<f64> = Vec::new();
        let x: Vec<f64> = Vec::new();
        let mut y: Vec<f64> = Vec::new();
        parallel_csr_spmv_add(&indptr, &indices, &data, &x, &mut y);
        assert!(y.is_empty());
    }

    #[test]
    fn test_multi_spmv_matches_sequential() {
        let indptr0: Vec<i32> = vec![0, 1, 2, 3, 4];
        let indices0: Vec<i32> = vec![0, 1, 2, 0];
        let data0: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
        let x0: Vec<f64> = vec![10.0, 20.0, 30.0];
        let indptr1: Vec<i32> = vec![0, 0, 1, 1, 2];
        let indices1: Vec<i32> = vec![1, 2];
        let data1: Vec<f64> = vec![5.0, 6.0];
        let x1: Vec<f64> = vec![100.0, 200.0, 300.0];
        let indptr2: Vec<i32> = vec![0, 1, 1, 2, 3];
        let indices2: Vec<i32> = vec![2, 0, 1];
        let data2: Vec<f64> = vec![7.0, 8.0, 9.0];
        let x2: Vec<f64> = vec![1000.0, 2000.0, 3000.0];
        let mut y_seq = vec![0.0_f64; 4];
        parallel_csr_spmv_add(&indptr0, &indices0, &data0, &x0, &mut y_seq);
        parallel_csr_spmv_add(&indptr1, &indices1, &data1, &x1, &mut y_seq);
        parallel_csr_spmv_add(&indptr2, &indices2, &data2, &x2, &mut y_seq);
        let mut y_batched = vec![0.0_f64; 4];
        let indptrs: Vec<&[i32]> = vec![&indptr0, &indptr1, &indptr2];
        let indices_b: Vec<&[i32]> = vec![&indices0, &indices1, &indices2];
        let data_b: Vec<&[f64]> = vec![&data0, &data1, &data2];
        let xs: Vec<&[f64]> = vec![&x0, &x1, &x2];
        parallel_csr_multi_spmv_add(&indptrs, &indices_b, &data_b, &xs, &mut y_batched);
        // Exact equality is safe here only because every value is a small
        // integer; in general the batched kernel rounds differently.
        assert_eq!(y_seq, y_batched);
    }

    #[test]
    fn test_multi_spmv_spans_chunks() {
        // Two full chunks plus a 7-row partial chunk.
        let n = 2 * CHUNK_SIZE + 7;
        let indptr: Vec<i32> = (0..=n).map(|i| i as i32).collect();
        // Block 0: diagonal with values r + 1, multiplied by all-ones.
        let indices0: Vec<i32> = (0..n).map(|i| i as i32).collect();
        let data0: Vec<f64> = (0..n).map(|i| (i as f64) + 1.0).collect();
        let x0 = vec![1.0_f64; n];
        // Block 1: anti-diagonal with value 2, multiplied by x[j] = j.
        let indices1: Vec<i32> = (0..n).map(|i| (n - 1 - i) as i32).collect();
        let data1 = vec![2.0_f64; n];
        let x1: Vec<f64> = (0..n).map(|j| j as f64).collect();
        let indptrs: Vec<&[i32]> = vec![&indptr, &indptr];
        let indices_b: Vec<&[i32]> = vec![&indices0, &indices1];
        let data_b: Vec<&[f64]> = vec![&data0, &data1];
        let xs: Vec<&[f64]> = vec![&x0, &x1];
        let mut y = vec![0.0_f64; n];
        parallel_csr_multi_spmv_add(&indptrs, &indices_b, &data_b, &xs, &mut y);
        for (r, &yr) in y.iter().enumerate() {
            let expected = (r as f64) + 1.0 + 2.0 * ((n - 1 - r) as f64);
            assert_eq!(yr, expected);
        }
    }

    #[test]
    fn test_multi_spmv_no_blocks_leaves_y_unchanged() {
        let indptrs: Vec<&[i32]> = Vec::new();
        let indices_b: Vec<&[i32]> = Vec::new();
        let data_b: Vec<&[f64]> = Vec::new();
        let xs: Vec<&[f64]> = Vec::new();
        let mut y = vec![1.0_f64, -2.5, 3.0];
        parallel_csr_multi_spmv_add(&indptrs, &indices_b, &data_b, &xs, &mut y);
        assert_eq!(y, vec![1.0, -2.5, 3.0]);
    }
}