use crate::csc_array::CscArray;
use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
/// Tuning parameters for the SIMD-accelerated sparse operations in this
/// module. Obtain sensible defaults via `SimdOptions::default()`.
#[derive(Debug, Clone)]
pub struct SimdOptions {
/// Minimum number of stored elements in a row/slice before the SIMD
/// gather-and-dot path is used; shorter spans use a scalar loop.
pub min_simd_size: usize,
/// Number of elements processed per SIMD chunk (also used as the chunk
/// size when partitioning work for parallel execution).
pub chunk_size: usize,
/// Whether multi-threaded execution may be used for large inputs.
pub use_parallel: bool,
/// Minimum problem size (rows or stored elements, depending on the
/// operation) before the parallel path is taken.
pub parallel_threshold: usize,
}
impl Default for SimdOptions {
fn default() -> Self {
let _capabilities = PlatformCapabilities::detect();
let optimal_chunk_size = 8;
Self {
min_simd_size: optimal_chunk_size,
chunk_size: optimal_chunk_size,
use_parallel: true, parallel_threshold: 8000, }
}
}
/// SIMD-accelerated CSR sparse matrix × dense vector product `y = A·x`.
///
/// Rows with at least `options.min_simd_size` stored elements are processed
/// in `options.chunk_size`-wide SIMD chunks; shorter rows use a scalar loop.
/// When `options.use_parallel` is set and the matrix has at least
/// `options.parallel_threshold` rows, the row range is split into 4
/// contiguous chunks processed in parallel.
///
/// # Errors
/// Returns `SparseError::DimensionMismatch` if `x.len()` differs from the
/// matrix column count.
#[allow(dead_code)]
pub fn simd_csr_matvec<T>(
    matrix: &CsrArray<T>,
    x: &ArrayView1<T>,
    options: SimdOptions,
) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
{
    let (rows, cols) = matrix.shape();
    if x.len() != cols {
        return Err(SparseError::DimensionMismatch {
            expected: cols,
            found: x.len(),
        });
    }
    let mut y = Array1::zeros(rows);
    let (_row_indices, col_indices, values) = matrix.find();
    let row_ptr = matrix.get_indptr();
    if options.use_parallel && rows >= options.parallel_threshold {
        // Split the row range into 4 contiguous chunks; each task computes
        // the dot products for its rows independently.
        let chunk_size = rows.div_ceil(4);
        let row_chunks: Vec<Vec<usize>> = (0..rows)
            .collect::<Vec<_>>()
            .chunks(chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();
        let results: Vec<_> = parallel_map(&row_chunks, |row_chunk| {
            let local_y: Vec<T> = row_chunk
                .iter()
                .map(|&i| {
                    csr_row_dot(&values, &col_indices, x, row_ptr[i], row_ptr[i + 1], &options)
                })
                .collect();
            (row_chunk.clone(), local_y)
        });
        // Scatter each chunk's results back into the output vector.
        for (row_chunk, local_y) in results {
            for (local_idx, &global_idx) in row_chunk.iter().enumerate() {
                y[global_idx] = local_y[local_idx];
            }
        }
    } else {
        for i in 0..rows {
            y[i] = csr_row_dot(&values, &col_indices, x, row_ptr[i], row_ptr[i + 1], &options);
        }
    }
    Ok(y)
}

/// Dot product of one CSR row with the dense vector `x`:
/// `Σ values[k] * x[col_indices[k]]` for `k` in `start..end`.
///
/// Rows with at least `options.min_simd_size` stored elements are gathered
/// into contiguous chunk buffers and reduced with `T::simd_dot`; the tail
/// (or an entire short row) falls back to a scalar loop. Extracted so the
/// parallel and sequential paths of `simd_csr_matvec` share one
/// implementation instead of duplicating it.
#[allow(dead_code)]
fn csr_row_dot<T>(
    values: &Array1<T>,
    col_indices: &Array1<usize>,
    x: &ArrayView1<T>,
    start: usize,
    end: usize,
    options: &SimdOptions,
) -> T
where
    T: Float + SparseElement + Copy + SimdUnifiedOps,
{
    let mut sum = T::sparse_zero();
    let mut j = start;
    if end - start >= options.min_simd_size {
        while j + options.chunk_size <= end {
            // Gather the strided sparse data into dense buffers so the
            // platform SIMD dot product can be applied.
            let mut values_chunk = vec![T::sparse_zero(); options.chunk_size];
            let mut x_vals_chunk = vec![T::sparse_zero(); options.chunk_size];
            for (idx, k) in (j..j + options.chunk_size).enumerate() {
                values_chunk[idx] = values[k];
                x_vals_chunk[idx] = x[col_indices[k]];
            }
            let values_view = ArrayView1::from(&values_chunk);
            let x_vals_view = ArrayView1::from(&x_vals_chunk);
            sum = sum + T::simd_dot(&values_view, &x_vals_view);
            j += options.chunk_size;
        }
    }
    // Scalar tail: elements after the last full SIMD chunk, or the entire
    // row when it is below `min_simd_size` (then `j` is still `start`).
    for k in j..end {
        sum = sum + values[k] * x[col_indices[k]];
    }
    sum
}
/// Selector for the element-wise binary operation applied by
/// `simd_sparse_elementwise`.
#[derive(Debug, Clone, Copy)]
pub enum ElementwiseOp {
/// Element-wise addition `a + b`.
Add,
/// Element-wise subtraction `a - b`.
Sub,
/// Element-wise multiplication `a * b`.
Mul,
/// Element-wise division `a / b` (positions stored in only one operand
/// divide by an implicit zero).
Div,
}
#[allow(dead_code)]
pub fn simd_sparse_elementwise<T, S1, S2>(
a: &S1,
b: &S2,
op: ElementwiseOp,
options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
S1: SparseArray<T>,
S2: SparseArray<T>,
{
if a.shape() != b.shape() {
return Err(SparseError::DimensionMismatch {
expected: a.shape().0 * a.shape().1,
found: b.shape().0 * b.shape().1,
});
}
let opts = options.unwrap_or_default();
let a_csr = a.to_csr()?;
let b_csr = b.to_csr()?;
let (_, _, a_values) = a_csr.find();
let (_, _, b_values) = b_csr.find();
if a_values.len() >= opts.min_simd_size && b_values.len() >= opts.min_simd_size {
let result = match op {
ElementwiseOp::Add => {
if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
a_csr.as_any().downcast_ref::<CsrArray<T>>(),
b_csr.as_any().downcast_ref::<CsrArray<T>>(),
) {
simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x + y)?
} else {
return a_csr.add(&*b_csr).and_then(|boxed| {
boxed
.as_any()
.downcast_ref::<CsrArray<T>>()
.cloned()
.ok_or_else(|| {
SparseError::ValueError(
"Failed to convert result to CsrArray".to_string(),
)
})
});
}
}
ElementwiseOp::Sub => {
if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
a_csr.as_any().downcast_ref::<CsrArray<T>>(),
b_csr.as_any().downcast_ref::<CsrArray<T>>(),
) {
simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x - y)?
} else {
return a_csr.sub(&*b_csr).and_then(|boxed| {
boxed
.as_any()
.downcast_ref::<CsrArray<T>>()
.cloned()
.ok_or_else(|| {
SparseError::ValueError(
"Failed to convert result to CsrArray".to_string(),
)
})
});
}
}
ElementwiseOp::Mul => {
if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
a_csr.as_any().downcast_ref::<CsrArray<T>>(),
b_csr.as_any().downcast_ref::<CsrArray<T>>(),
) {
simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x * y)?
} else {
return a_csr.mul(&*b_csr).and_then(|boxed| {
boxed
.as_any()
.downcast_ref::<CsrArray<T>>()
.cloned()
.ok_or_else(|| {
SparseError::ValueError(
"Failed to convert result to CsrArray".to_string(),
)
})
});
}
}
ElementwiseOp::Div => {
if let (Some(a_csr_concrete), Some(b_csr_concrete)) = (
a_csr.as_any().downcast_ref::<CsrArray<T>>(),
b_csr.as_any().downcast_ref::<CsrArray<T>>(),
) {
simd_sparse_binary_op(a_csr_concrete, b_csr_concrete, &opts, |x, y| x / y)?
} else {
return a_csr.div(&*b_csr).and_then(|boxed| {
boxed
.as_any()
.downcast_ref::<CsrArray<T>>()
.cloned()
.ok_or_else(|| {
SparseError::ValueError(
"Failed to convert result to CsrArray".to_string(),
)
})
});
}
}
};
Ok(result)
} else {
let result_box = match op {
ElementwiseOp::Add => a_csr.add(&*b_csr)?,
ElementwiseOp::Sub => a_csr.sub(&*b_csr)?,
ElementwiseOp::Mul => a_csr.mul(&*b_csr)?,
ElementwiseOp::Div => a_csr.div(&*b_csr)?,
};
result_box
.as_any()
.downcast_ref::<CsrArray<T>>()
.cloned()
.ok_or_else(|| {
SparseError::ValueError("Failed to convert result to CsrArray".to_string())
})
}
}
/// Applies a scalar binary operation element-wise to two same-shape CSR
/// arrays: every position occupied in either operand is evaluated (a missing
/// entry contributes zero) and zero results are dropped from the output.
#[allow(dead_code)]
fn simd_sparse_binary_op<T, F>(
    a: &CsrArray<T>,
    b: &CsrArray<T>,
    options: &SimdOptions,
    op: F,
) -> SparseResult<CsrArray<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    F: Fn(T, T) -> T + Send + Sync + Copy,
{
    use std::collections::{BTreeSet, HashMap};
    let (n_rows, n_cols) = a.shape();
    let (a_rows_idx, a_cols_idx, a_vals) = a.find();
    let (b_rows_idx, b_cols_idx, b_vals) = b.find();
    // Index each operand's stored entries by (row, col) for O(1) lookup.
    let lhs: HashMap<(usize, usize), T> = a_rows_idx
        .iter()
        .zip(a_cols_idx.iter())
        .enumerate()
        .map(|(k, (&r, &c))| ((r, c), a_vals[k]))
        .collect();
    let rhs: HashMap<(usize, usize), T> = b_rows_idx
        .iter()
        .zip(b_cols_idx.iter())
        .enumerate()
        .map(|(k, (&r, &c))| ((r, c), b_vals[k]))
        .collect();
    // Union of occupied positions in sorted (row-major) order.
    let occupied: BTreeSet<(usize, usize)> = lhs.keys().chain(rhs.keys()).copied().collect();
    let positions: Vec<(usize, usize)> = occupied.into_iter().collect();
    // Evaluate one position; `None` means the result is zero and is dropped.
    let eval = |pos: &(usize, usize)| -> Option<((usize, usize), T)> {
        let x = lhs.get(pos).copied().unwrap_or(T::sparse_zero());
        let y = rhs.get(pos).copied().unwrap_or(T::sparse_zero());
        let v = op(x, y);
        if SparseElement::is_zero(&v) {
            None
        } else {
            Some((*pos, v))
        }
    };
    let mut out_rows = Vec::new();
    let mut out_cols = Vec::new();
    let mut out_vals = Vec::new();
    if options.use_parallel && positions.len() >= options.parallel_threshold {
        // Evaluate position chunks in parallel; chunk order is preserved, so
        // the concatenated output matches the sequential ordering.
        let chunks: Vec<&[(usize, usize)]> = positions.chunks(options.chunk_size).collect();
        let partials: Vec<Vec<((usize, usize), T)>> =
            parallel_map(&chunks, |chunk| chunk.iter().filter_map(eval).collect());
        for partial in partials {
            for ((r, c), v) in partial {
                out_rows.push(r);
                out_cols.push(c);
                out_vals.push(v);
            }
        }
    } else {
        for ((r, c), v) in positions.iter().filter_map(eval) {
            out_rows.push(r);
            out_cols.push(c);
            out_vals.push(v);
        }
    }
    CsrArray::from_triplets(&out_rows, &out_cols, &out_vals, (n_rows, n_cols), false)
}
/// Transposes a sparse matrix by swapping its row/column index arrays,
/// returning a `CsrArray` of shape `(cols, rows)`.
///
/// Transposition is pure index shuffling (no arithmetic), so no SIMD is
/// involved; the parallel path only spreads triplet copying across tasks.
///
/// # Errors
/// Propagates any error from `CsrArray::from_triplets`.
#[allow(dead_code)]
pub fn simd_sparse_transpose<T, S>(
    matrix: &S,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();
    let nnz = values.len();
    if opts.use_parallel && nnz >= opts.parallel_threshold {
        // Copy triplets chunk-by-chunk in parallel; chunk order is preserved,
        // so the concatenated result matches the sequential output.
        let index_chunks: Vec<Vec<usize>> = (0..nnz)
            .collect::<Vec<_>>()
            .chunks(opts.chunk_size)
            .map(|chunk| chunk.to_vec())
            .collect();
        let parts: Vec<(Vec<usize>, Vec<usize>, Vec<T>)> =
            parallel_map(&index_chunks, |chunk| {
                let mut r = Vec::with_capacity(chunk.len());
                let mut c = Vec::with_capacity(chunk.len());
                let mut v = Vec::with_capacity(chunk.len());
                for &idx in chunk {
                    // Swap row/col to transpose.
                    r.push(col_indices[idx]);
                    c.push(row_indices[idx]);
                    v.push(values[idx]);
                }
                (r, c, v)
            });
        let mut result_rows = Vec::with_capacity(nnz);
        let mut result_cols = Vec::with_capacity(nnz);
        let mut result_values = Vec::with_capacity(nnz);
        for (mut r, mut c, mut v) in parts {
            result_rows.append(&mut r);
            result_cols.append(&mut c);
            result_values.append(&mut v);
        }
        CsrArray::from_triplets(
            &result_rows,
            &result_cols,
            &result_values,
            (cols, rows),
            false,
        )
    } else {
        // `to_vec` works for any memory layout; the previous
        // `as_slice().expect("Array should be contiguous")` would panic on a
        // non-contiguous array.
        CsrArray::from_triplets(
            &col_indices.to_vec(),
            &row_indices.to_vec(),
            &values.to_vec(),
            (cols, rows),
            false,
        )
    }
}
#[allow(dead_code)]
pub fn simd_sparse_matmul<T, S1, S2>(
a: &S1,
b: &S2,
options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
S1: SparseArray<T>,
S2: SparseArray<T>,
{
if a.shape().1 != b.shape().0 {
return Err(SparseError::DimensionMismatch {
expected: a.shape().1,
found: b.shape().0,
});
}
let opts = options.unwrap_or_default();
let a_csr = a.to_csr()?;
let b_csc = b.to_csc()?;
let (a_rows, a_cols) = a_csr.shape();
let (_b_rows, b_cols) = b_csc.shape();
let mut result_rows = Vec::new();
let mut result_cols = Vec::new();
let mut result_values = Vec::new();
let a_indptr = if let Some(a_concrete) = a_csr.as_any().downcast_ref::<CsrArray<T>>() {
a_concrete.get_indptr() } else {
return Err(SparseError::ValueError(
"Matrix A must be CSR format".to_string(),
));
};
let (_, a_col_indices, a_values) = a_csr.find();
let b_indptr = if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CscArray<T>>() {
b_concrete.get_indptr() } else if let Some(b_concrete) = b_csc.as_any().downcast_ref::<CsrArray<T>>() {
b_concrete.get_indptr()
} else {
return Err(SparseError::ValueError(
"Matrix B must be CSC or CSR format".to_string(),
));
};
let (_, b_row_indices, b_values) = b_csc.find();
if opts.use_parallel && a_rows >= opts.parallel_threshold {
let chunks: Vec<_> = (0..a_rows)
.collect::<Vec<_>>()
.chunks(opts.chunk_size)
.map(|chunk| chunk.to_vec())
.collect();
let results: Vec<_> = parallel_map(&chunks, |row_chunk| {
let mut local_rows = Vec::new();
let mut local_cols = Vec::new();
let mut local_values = Vec::new();
for &i in row_chunk {
let a_start = a_indptr[i];
let a_end = a_indptr[i + 1];
for j in 0..b_cols {
let b_start = b_indptr[j];
let b_end = b_indptr[j + 1];
let mut sum = T::sparse_zero();
let mut a_idx = a_start;
let mut b_idx = b_start;
if (a_end - a_start) >= opts.min_simd_size
&& (b_end - b_start) >= opts.min_simd_size
{
while a_idx < a_end && b_idx < b_end {
let a_col = a_col_indices[a_idx];
let b_row = b_row_indices[b_idx];
match a_col.cmp(&b_row) {
std::cmp::Ordering::Equal => {
sum = sum + a_values[a_idx] * b_values[b_idx];
a_idx += 1;
b_idx += 1;
}
std::cmp::Ordering::Less => {
a_idx += 1;
}
std::cmp::Ordering::Greater => {
b_idx += 1;
}
}
}
} else {
while a_idx < a_end && b_idx < b_end {
let a_col = a_col_indices[a_idx];
let b_row = b_row_indices[b_idx];
match a_col.cmp(&b_row) {
std::cmp::Ordering::Equal => {
sum = sum + a_values[a_idx] * b_values[b_idx];
a_idx += 1;
b_idx += 1;
}
std::cmp::Ordering::Less => {
a_idx += 1;
}
std::cmp::Ordering::Greater => {
b_idx += 1;
}
}
}
}
if !SparseElement::is_zero(&sum) {
local_rows.push(i);
local_cols.push(j);
local_values.push(sum);
}
}
}
(local_rows, local_cols, local_values)
});
for (mut local_rows, mut local_cols, mut local_values) in results {
result_rows.append(&mut local_rows);
result_cols.append(&mut local_cols);
result_values.append(&mut local_values);
}
} else {
for i in 0..a_rows {
let a_start = a_indptr[i];
let a_end = a_indptr[i + 1];
for j in 0..b_cols {
let b_start = b_indptr[j];
let b_end = b_indptr[j + 1];
let mut sum = T::sparse_zero();
let mut a_idx = a_start;
let mut b_idx = b_start;
while a_idx < a_end && b_idx < b_end {
let a_col = a_col_indices[a_idx];
let b_row = b_row_indices[b_idx];
match a_col.cmp(&b_row) {
std::cmp::Ordering::Equal => {
sum = sum + a_values[a_idx] * b_values[b_idx];
a_idx += 1;
b_idx += 1;
}
std::cmp::Ordering::Less => {
a_idx += 1;
}
std::cmp::Ordering::Greater => {
b_idx += 1;
}
}
}
if !SparseElement::is_zero(&sum) {
result_rows.push(i);
result_cols.push(j);
result_values.push(sum);
}
}
}
}
CsrArray::from_triplets(
&result_rows,
&result_cols,
&result_values,
(a_rows, b_cols),
false,
)
}
/// Computes a matrix norm over a sparse matrix's stored values.
///
/// Supported `norm_type` values:
/// - `"fro"` / `"frobenius"`: Frobenius norm, `sqrt(Σ v²)`.
/// - `"1"`: maximum absolute column sum.
/// - `"inf"` / `"infinity"`: maximum absolute row sum.
///
/// # Errors
/// Returns `SparseError::ValueError` for an unrecognized `norm_type`.
#[allow(dead_code)]
pub fn simd_sparse_norm<T, S>(
    matrix: &S,
    norm_type: &str,
    options: Option<SimdOptions>,
) -> SparseResult<T>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (_, _, values) = matrix.find();
    match norm_type {
        "fro" | "frobenius" => {
            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                // Per-chunk sum of squares via SIMD dot(v, v), reduced serially.
                let chunks: Vec<_> = values
                    .as_slice()
                    .expect("Array should be contiguous")
                    .chunks(opts.chunk_size)
                    .collect();
                let partial_sums: Vec<T> = parallel_map(&chunks, |chunk| {
                    let chunk_view = ArrayView1::from(chunk);
                    T::simd_dot(&chunk_view, &chunk_view)
                });
                Ok(partial_sums
                    .iter()
                    .copied()
                    .fold(T::sparse_zero(), |acc, x| acc + x)
                    .sqrt())
            } else {
                let values_view = values.view();
                Ok(T::simd_dot(&values_view, &values_view).sqrt())
            }
        }
        "1" => {
            // Max absolute column sum: accumulate |v| per column, take the max.
            let (_rows, cols) = matrix.shape();
            let (_row_indices, col_indices, values) = matrix.find();
            let mut column_sums = vec![T::sparse_zero(); cols];
            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = (0..values.len())
                    .collect::<Vec<_>>()
                    .chunks(opts.chunk_size)
                    .map(|chunk| chunk.to_vec())
                    .collect();
                let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
                    let mut local_sums = vec![T::sparse_zero(); cols];
                    for &idx in chunk {
                        let col = col_indices[idx];
                        local_sums[col] = local_sums[col] + values[idx].abs();
                    }
                    local_sums
                });
                for partial_sum in partial_sums {
                    for j in 0..cols {
                        column_sums[j] = column_sums[j] + partial_sum[j];
                    }
                }
            } else {
                for (i, &col) in col_indices.iter().enumerate() {
                    column_sums[col] = column_sums[col] + values[i].abs();
                }
            }
            Ok(column_sums
                .iter()
                .copied()
                .fold(T::sparse_zero(), |acc, x| if x > acc { x } else { acc }))
        }
        "inf" | "infinity" => {
            // Max absolute row sum: accumulate |v| per row, take the max.
            let (rows, _cols) = matrix.shape();
            let (row_indices, _col_indices, values) = matrix.find();
            let mut row_sums = vec![T::sparse_zero(); rows];
            if opts.use_parallel && values.len() >= opts.parallel_threshold {
                let chunks: Vec<_> = (0..values.len())
                    .collect::<Vec<_>>()
                    .chunks(opts.chunk_size)
                    .map(|chunk| chunk.to_vec())
                    .collect();
                let partial_sums: Vec<Vec<T>> = parallel_map(&chunks, |chunk| {
                    let mut local_sums = vec![T::sparse_zero(); rows];
                    for &idx in chunk {
                        let row = row_indices[idx];
                        local_sums[row] = local_sums[row] + values[idx].abs();
                    }
                    local_sums
                });
                for partial_sum in partial_sums {
                    for i in 0..rows {
                        row_sums[i] = row_sums[i] + partial_sum[i];
                    }
                }
            } else {
                for (i, &row) in row_indices.iter().enumerate() {
                    row_sums[row] = row_sums[row] + values[i].abs();
                }
            }
            Ok(row_sums
                .iter()
                .copied()
                .fold(T::sparse_zero(), |acc, x| if x > acc { x } else { acc }))
        }
        // Message previously read "Unknown norm _type" — a leftover from a
        // `_type` variable rename that leaked into user-facing output.
        _ => Err(SparseError::ValueError(format!(
            "Unknown norm type: {norm_type}"
        ))),
    }
}
/// Multiplies every stored value of a sparse matrix by `scalar`, returning
/// the scaled matrix as a `CsrArray` with the same sparsity pattern.
///
/// The previous implementation allocated an unused `_scalar_vec` per chunk
/// and used a manual index loop; both are replaced by a plain iterator map.
///
/// # Errors
/// Propagates any error from `CsrArray::from_triplets`.
#[allow(dead_code)]
pub fn simd_sparse_scale<T, S>(
    matrix: &S,
    scalar: T,
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T>,
{
    let opts = options.unwrap_or_default();
    let (rows, cols) = matrix.shape();
    let (row_indices, col_indices, values) = matrix.find();
    let scaled_values: Vec<T> = if opts.use_parallel && values.len() >= opts.parallel_threshold {
        // Scale value chunks in parallel; chunk order is preserved, so the
        // flattened result matches the original value order.
        let chunks: Vec<&[T]> = values
            .as_slice()
            .expect("Array should be contiguous")
            .chunks(opts.chunk_size)
            .collect();
        let scaled_chunks: Vec<Vec<T>> = parallel_map(&chunks, |chunk: &&[T]| {
            chunk.iter().map(|&v| v * scalar).collect()
        });
        scaled_chunks.into_iter().flatten().collect()
    } else {
        values.iter().map(|&v| v * scalar).collect()
    };
    // `to_vec` works for any memory layout, avoiding the panic-prone
    // `as_slice().expect(...)` on the index arrays.
    CsrArray::from_triplets(
        &row_indices.to_vec(),
        &col_indices.to_vec(),
        &scaled_values,
        (rows, cols),
        false,
    )
}
/// Computes the linear combination `Σ coefficients[k] * matrices[k]` over
/// sparse matrices of identical shape, returning the result as a `CsrArray`.
/// Entries that sum to zero are dropped from the output.
///
/// # Errors
/// - `SparseError::ValueError` if `matrices` is empty.
/// - `SparseError::DimensionMismatch` if `matrices` and `coefficients`
///   differ in length, or any matrix's shape differs from the first.
#[allow(dead_code)]
pub fn simd_sparse_linear_combination<T, S>(
    matrices: &[&S],
    coefficients: &[T],
    options: Option<SimdOptions>,
) -> SparseResult<CsrArray<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
    S: SparseArray<T> + Sync,
{
    if matrices.is_empty() {
        return Err(SparseError::ValueError("No matrices provided".to_string()));
    }
    if matrices.len() != coefficients.len() {
        return Err(SparseError::DimensionMismatch {
            expected: matrices.len(),
            found: coefficients.len(),
        });
    }
    let opts = options.unwrap_or_default();
    let shape = matrices[0].shape();
    if let Some(mismatched) = matrices.iter().find(|m| m.shape() != shape) {
        return Err(SparseError::DimensionMismatch {
            expected: shape.0 * shape.1,
            found: mismatched.shape().0 * mismatched.shape().1,
        });
    }
    use std::collections::HashMap;
    let mut combined: HashMap<(usize, usize), T> = HashMap::new();
    if opts.use_parallel && matrices.len() >= 4 {
        // Sum each matrix's entries into a per-matrix map in parallel, then
        // fold the coefficient-scaled entries into the shared accumulator.
        let per_matrix: Vec<HashMap<(usize, usize), T>> = parallel_map(matrices, |m| {
            let (row_idx, col_idx, vals) = m.find();
            let mut acc = HashMap::new();
            for (k, (&r, &c)) in row_idx.iter().zip(col_idx.iter()).enumerate() {
                let slot = acc.entry((r, c)).or_insert(T::sparse_zero());
                *slot = *slot + vals[k];
            }
            acc
        });
        for (coeff, local) in coefficients.iter().copied().zip(per_matrix) {
            for (key, val) in local {
                let slot = combined.entry(key).or_insert(T::sparse_zero());
                *slot = *slot + coeff * val;
            }
        }
    } else {
        for (&coeff, m) in coefficients.iter().zip(matrices.iter()) {
            let (row_idx, col_idx, vals) = m.find();
            for (k, (&r, &c)) in row_idx.iter().zip(col_idx.iter()).enumerate() {
                let slot = combined.entry((r, c)).or_insert(T::sparse_zero());
                *slot = *slot + coeff * vals[k];
            }
        }
    }
    // Materialize the non-zero entries as COO triplets.
    let mut out_rows = Vec::new();
    let mut out_cols = Vec::new();
    let mut out_vals = Vec::new();
    for ((r, c), v) in combined {
        if !SparseElement::is_zero(&v) {
            out_rows.push(r);
            out_cols.push(c);
            out_vals.push(v);
        }
    }
    CsrArray::from_triplets(&out_rows, &out_cols, &out_vals, shape, false)
}
/// Convenience wrapper around `simd_sparse_matmul` using default
/// `SimdOptions`. See `simd_sparse_matmul` for semantics and errors.
#[allow(dead_code)]
pub fn simd_sparse_matmul_default<T, S1, S2>(a: &S1, b: &S2) -> SparseResult<CsrArray<T>>
where
T: Float + SparseElement + Debug + Copy + 'static + SimdUnifiedOps + Send + Sync,
S1: SparseArray<T>,
S2: SparseArray<T>,
{
simd_sparse_matmul(a, b, None)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::csr_array::CsrArray;
use approx::assert_relative_eq;
// y = A·x for a 3x3 matrix with 5 stored values:
// row 0: 1*1 + 2*3 = 7, row 1: 3*2 = 6, row 2: 4*1 + 5*3 = 19.
#[test]
fn test_simd_csr_matvec() {
let rows = vec![0, 0, 1, 2, 2];
let cols = vec![0, 2, 1, 0, 2];
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let matrix =
CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).expect("Operation failed");
let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let y =
simd_csr_matvec(&matrix, &x.view(), SimdOptions::default()).expect("Operation failed");
assert_eq!(y.len(), 3);
assert_relative_eq!(y[0], 7.0);
assert_relative_eq!(y[1], 6.0);
assert_relative_eq!(y[2], 19.0);
}
// Element-wise Add of two diagonal 3x3 matrices: the result's diagonal is
// the element-wise sum of both diagonals.
#[test]
fn test_simd_sparse_elementwise() {
let rows = vec![0, 1, 2];
let cols = vec![0, 1, 2];
let data1 = vec![1.0, 2.0, 3.0];
let data2 = vec![4.0, 5.0, 6.0];
let a =
CsrArray::from_triplets(&rows, &cols, &data1, (3, 3), false).expect("Operation failed");
let b =
CsrArray::from_triplets(&rows, &cols, &data2, (3, 3), false).expect("Operation failed");
let result =
simd_sparse_elementwise(&a, &b, ElementwiseOp::Add, None).expect("Operation failed");
assert_relative_eq!(result.get(0, 0), 5.0);
assert_relative_eq!(result.get(1, 1), 7.0);
assert_relative_eq!(result.get(2, 2), 9.0);
}
// Product of two diagonal 2x2 matrices: diagonal entries multiply
// (2*4, 3*5) and off-diagonal entries stay zero.
#[test]
fn test_simd_sparse_matmul() {
let rows = vec![0, 1];
let cols = vec![0, 1];
let data1 = vec![2.0, 3.0];
let data2 = vec![4.0, 5.0];
let a =
CsrArray::from_triplets(&rows, &cols, &data1, (2, 2), false).expect("Operation failed");
let b =
CsrArray::from_triplets(&rows, &cols, &data2, (2, 2), false).expect("Operation failed");
let result = simd_sparse_matmul_default(&a, &b).expect("Operation failed");
assert_relative_eq!(result.get(0, 0), 8.0);
assert_relative_eq!(result.get(1, 1), 15.0);
assert_relative_eq!(result.get(0, 1), 0.0);
assert_relative_eq!(result.get(1, 0), 0.0);
}
}