use super::super::{adaptive, WorkerConfig};
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use scirs2_core::parallel_ops::*;
use std::iter::Sum;
/// Matrix-vector product `matrix * vector`, computed with one parallel task
/// per output row when the problem is large enough to amortize the
/// scheduling overhead (decided by `adaptive::should_use_parallel`).
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the matrix column count does not
/// match the vector length.
pub fn parallel_matvec<F>(
    matrix: &ArrayView2<F>,
    vector: &ArrayView1<F>,
    config: &WorkerConfig,
) -> LinalgResult<Array1<F>>
where
    F: Float + Send + Sync + Zero + Sum + 'static,
{
    let (rows, cols) = matrix.dim();
    if cols != vector.len() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix-vector dimensions incompatible: {}x{} * {}",
            rows,
            cols,
            vector.len()
        )));
    }
    // Small inputs: the serial ndarray kernel is faster than spawning tasks.
    if !adaptive::should_use_parallel(rows * cols, config) {
        return Ok(matrix.dot(vector));
    }
    config.apply();
    // Each row's dot product is independent, so rows parallelize trivially.
    let entries: Vec<F> = (0..rows)
        .into_par_iter()
        .map(|row_idx| {
            let row = matrix.row(row_idx);
            row.iter().zip(vector.iter()).map(|(&a, &x)| a * x).sum()
        })
        .collect();
    Ok(Array1::from_vec(entries))
}
/// Dominant-eigenpair estimation via power iteration, using the parallel
/// matrix-vector product for each step on sufficiently large matrices.
///
/// Returns `(eigenvalue, eigenvector)` once consecutive eigenvalue estimates
/// differ by less than `tolerance`.
///
/// # Errors
///
/// * `LinalgError::ShapeError` if the matrix is not square.
/// * `LinalgError::ComputationError` if the iterate collapses to (numerically)
///   zero, or if `max_iter` iterations pass without convergence.
pub fn parallel_power_iteration<F>(
    matrix: &ArrayView2<F>,
    max_iter: usize,
    tolerance: F,
    config: &WorkerConfig,
) -> LinalgResult<(F, Array1<F>)>
where
    F: Float + Send + Sync + Zero + Sum + NumAssign + One + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return Err(LinalgError::ShapeError(
            "Power iteration requires square matrix".to_string(),
        ));
    }
    // Delegate small problems to the serial implementation.
    if !adaptive::should_use_parallel(rows * cols, config) {
        return crate::eigen::power_iteration(&matrix.view(), max_iter, tolerance);
    }
    config.apply();
    // Start from the all-ones vector scaled to unit length.
    let mut v = Array1::ones(cols);
    let initial_norm = v.iter().map(|&x| x * x).sum::<F>().sqrt();
    v /= initial_norm;
    let mut previous_estimate = F::zero();
    for _ in 0..max_iter {
        let av = parallel_matvec(matrix, &v.view(), config)?;
        // `v` has unit length, so vᵀ(Av) is the Rayleigh-quotient estimate.
        let estimate = av
            .iter()
            .zip(v.iter())
            .map(|(&avi, &vi)| avi * vi)
            .sum::<F>();
        let av_norm = av.iter().map(|&x| x * x).sum::<F>().sqrt();
        if av_norm < F::epsilon() {
            return Err(LinalgError::ComputationError(
                "Vector became zero during iteration".to_string(),
            ));
        }
        let unit_av = av / av_norm;
        if (estimate - previous_estimate).abs() < tolerance {
            return Ok((estimate, unit_av));
        }
        previous_estimate = estimate;
        v = unit_av;
    }
    Err(LinalgError::ComputationError(
        "Power iteration failed to converge".to_string(),
    ))
}
/// Parallel BLAS-level-1 style vector kernels (dot, Euclidean norm, AXPY).
pub mod vector_ops {
    use super::*;

    /// Dot product `x · y`.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::ShapeError` when the vectors differ in length.
    pub fn parallel_dot<F>(
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        config: &WorkerConfig,
    ) -> LinalgResult<F>
    where
        F: Float + Send + Sync + Zero + Sum + 'static,
    {
        if x.len() != y.len() {
            return Err(LinalgError::ShapeError(
                "Vectors must have same length for dot product".to_string(),
            ));
        }
        let len = x.len();
        // Serial path for small vectors.
        if !adaptive::should_use_parallel(len, config) {
            let dot = x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * yi).sum();
            return Ok(dot);
        }
        config.apply();
        let dot = (0..len).into_par_iter().map(|i| x[i] * y[i]).sum();
        Ok(dot)
    }

    /// Euclidean (L2) norm of `x`.
    pub fn parallel_norm<F>(x: &ArrayView1<F>, config: &WorkerConfig) -> LinalgResult<F>
    where
        F: Float + Send + Sync + Zero + Sum + 'static,
    {
        let len = x.len();
        // Serial path for small vectors.
        if !adaptive::should_use_parallel(len, config) {
            let sum_squares: F = x.iter().map(|&xi| xi * xi).sum();
            return Ok(sum_squares.sqrt());
        }
        config.apply();
        let sum_squares: F = (0..len).into_par_iter().map(|i| x[i] * x[i]).sum();
        Ok(sum_squares.sqrt())
    }

    /// AXPY operation: returns `alpha * x + y` element-wise.
    ///
    /// # Errors
    ///
    /// Returns `LinalgError::ShapeError` when the vectors differ in length.
    pub fn parallel_axpy<F>(
        alpha: F,
        x: &ArrayView1<F>,
        y: &ArrayView1<F>,
        config: &WorkerConfig,
    ) -> LinalgResult<Array1<F>>
    where
        F: Float + Send + Sync + 'static,
    {
        if x.len() != y.len() {
            return Err(LinalgError::ShapeError(
                "Vectors must have same length for AXPY".to_string(),
            ));
        }
        let len = x.len();
        // Serial path for small vectors.
        if !adaptive::should_use_parallel(len, config) {
            let combined: Vec<F> = x
                .iter()
                .zip(y.iter())
                .map(|(&xi, &yi)| alpha * xi + yi)
                .collect();
            return Ok(Array1::from_vec(combined));
        }
        config.apply();
        let combined: Vec<F> = (0..len)
            .into_par_iter()
            .map(|i| alpha * x[i] + y[i])
            .collect();
        Ok(Array1::from_vec(combined))
    }
}
/// General matrix-matrix product `a * b`, parallelized over output rows with
/// a blocked inner loop over the shared dimension for cache friendliness.
///
/// Falls back to the serial `ndarray` `dot` for small problems.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the inner dimensions of `a` and `b`
/// do not match.
pub fn parallel_gemm<F>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    config: &WorkerConfig,
) -> LinalgResult<scirs2_core::ndarray::Array2<F>>
where
    F: Float + Send + Sync + Zero + Sum + NumAssign + 'static,
{
    let (m, k) = a.dim();
    let (k2, n) = b.dim();
    if k != k2 {
        return Err(LinalgError::ShapeError(format!(
            "Matrix dimensions incompatible for multiplication: {m}x{k} * {k2}x{n}"
        )));
    }
    let datasize = m * k * n;
    if !adaptive::should_use_parallel(datasize, config) {
        return Ok(a.dot(b));
    }
    config.apply();
    // Clamp to at least 1: `Iterator::step_by` panics on a zero step, so a
    // zero `chunksize` in the config must not reach the loop below.
    let blocksize = config.chunksize.max(1);
    let mut result = scirs2_core::ndarray::Array2::zeros((m, n));
    // Each output row is written by exactly one task, so rows are independent.
    result
        .outer_iter_mut()
        .enumerate()
        .par_bridge()
        .for_each(|(i, mut row)| {
            for j in 0..n {
                let mut sum = F::zero();
                // Blocked traversal of the shared dimension.
                for kb in (0..k).step_by(blocksize) {
                    let k_end = std::cmp::min(kb + blocksize, k);
                    for ki in kb..k_end {
                        sum += a[[i, ki]] * b[[ki, j]];
                    }
                }
                row[j] = sum;
            }
        });
    Ok(result)
}
/// Conjugate-gradient solver for `matrix * x = b`, built on the parallel
/// matvec/dot/axpy kernels for large systems; delegates small systems to the
/// serial `crate::iterative_solvers::conjugate_gradient`.
///
/// # Errors
///
/// * `LinalgError::ShapeError` if the matrix is not square or its size does
///   not match `b`.
/// * `LinalgError::ComputationError` if `max_iter` iterations pass without
///   the residual norm dropping below `tolerance`.
pub fn parallel_conjugate_gradient<F>(
    matrix: &ArrayView2<F>,
    b: &ArrayView1<F>,
    max_iter: usize,
    tolerance: F,
    config: &WorkerConfig,
) -> LinalgResult<Array1<F>>
where
    F: Float + Send + Sync + Zero + Sum + One + NumAssign + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (rows, cols) = matrix.dim();
    if rows != cols {
        return Err(LinalgError::ShapeError(
            "CG requires square matrix".to_string(),
        ));
    }
    if cols != b.len() {
        return Err(LinalgError::ShapeError(
            "Matrix and vector dimensions incompatible".to_string(),
        ));
    }
    if !adaptive::should_use_parallel(rows * cols, config) {
        return crate::iterative_solvers::conjugate_gradient(
            &matrix.view(),
            &b.view(),
            max_iter,
            tolerance,
            None,
        );
    }
    config.apply();
    // x₀ = 0, so the initial residual is simply b - A·0 = b.
    let mut solution = Array1::zeros(cols);
    let initial_ax = parallel_matvec(matrix, &solution.view(), config)?;
    let mut residual = b - &initial_ax;
    let mut direction = residual.clone();
    let mut residual_sq = vector_ops::parallel_dot(&residual.view(), &residual.view(), config)?;
    for _ in 0..max_iter {
        let a_dir = parallel_matvec(matrix, &direction.view(), config)?;
        // Optimal step length along the current search direction.
        let step = residual_sq / vector_ops::parallel_dot(&direction.view(), &a_dir.view(), config)?;
        solution = vector_ops::parallel_axpy(step, &direction.view(), &solution.view(), config)?;
        residual = vector_ops::parallel_axpy(-step, &a_dir.view(), &residual.view(), config)?;
        let next_residual_sq = vector_ops::parallel_dot(&residual.view(), &residual.view(), config)?;
        // Converged when the residual norm is below tolerance.
        if next_residual_sq.sqrt() < tolerance {
            return Ok(solution);
        }
        // Next direction: r + beta * p (conjugate to all previous directions).
        let beta = next_residual_sq / residual_sq;
        direction = vector_ops::parallel_axpy(beta, &direction.view(), &residual.view(), config)?;
        residual_sq = next_residual_sq;
    }
    Err(LinalgError::ComputationError(
        "Conjugate gradient failed to converge".to_string(),
    ))
}
/// Jacobi iteration for `matrix * x = b`, parallelized over rows for large
/// systems; small systems are delegated to the serial
/// `crate::iterative_solvers::jacobi_method`.
///
/// # Errors
///
/// * `LinalgError::ShapeError` if the matrix is not square or its size does
///   not match `b`.
/// * `LinalgError::ComputationError` if a diagonal entry is (numerically)
///   zero, or if `max_iter` iterations pass without convergence.
pub fn parallel_jacobi<F>(
    matrix: &ArrayView2<F>,
    b: &ArrayView1<F>,
    max_iter: usize,
    tolerance: F,
    config: &WorkerConfig,
) -> LinalgResult<Array1<F>>
where
    F: Float + Send + Sync + Zero + Sum + One + NumAssign + scirs2_core::ndarray::ScalarOperand + 'static,
{
    let (m, n) = matrix.dim();
    if m != n {
        return Err(LinalgError::ShapeError(
            "Jacobi method requires square matrix".to_string(),
        ));
    }
    if n != b.len() {
        return Err(LinalgError::ShapeError(
            "Matrix and vector dimensions incompatible".to_string(),
        ));
    }
    let datasize = m * n;
    if !adaptive::should_use_parallel(datasize, config) {
        return crate::iterative_solvers::jacobi_method(
            &matrix.view(),
            &b.view(),
            max_iter,
            tolerance,
            None,
        );
    }
    config.apply();
    // Jacobi divides by the diagonal, so every diagonal entry must be
    // non-zero. Fail loudly instead of silently substituting 1.0, which
    // would make the iteration converge to the wrong answer.
    let diag: Vec<F> = (0..n).map(|i| matrix[[i, i]]).collect();
    if diag.iter().any(|d| d.abs() < F::epsilon()) {
        return Err(LinalgError::ComputationError(
            "Jacobi method requires non-zero diagonal entries".to_string(),
        ));
    }
    let mut x = Array1::zeros(n);
    for _iter in 0..max_iter {
        // x_new[i] = (b[i] - Σ_{j≠i} a[i][j] * x[j]) / a[i][i], rows in parallel.
        let x_new_vec: Vec<F> = (0..n)
            .into_par_iter()
            .map(|i| {
                let mut sum = b[i];
                for j in 0..n {
                    if i != j {
                        sum -= matrix[[i, j]] * x[j];
                    }
                }
                sum / diag[i]
            })
            .collect();
        let x_new = Array1::from_vec(x_new_vec);
        let diff = &x_new - &x;
        let error = vector_ops::parallel_norm(&diff.view(), config)?;
        if error < tolerance {
            return Ok(x_new);
        }
        // Move, don't clone: x_new is not used after this point.
        x = x_new;
    }
    Err(LinalgError::ComputationError(
        "Jacobi method failed to converge".to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2};

    #[test]
    fn test_parallel_matvec() {
        let config = WorkerConfig::default();
        let a = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let x = arr1(&[1.0, 2.0]);
        let y = parallel_matvec(&a.view(), &x.view(), &config).expect("Operation failed");
        assert_eq!(y, arr1(&[5.0, 11.0]));
    }

    #[test]
    fn test_parallel_dot() {
        let config = WorkerConfig::default();
        let lhs = arr1(&[1.0, 2.0, 3.0]);
        let rhs = arr1(&[4.0, 5.0, 6.0]);
        let dot = vector_ops::parallel_dot(&lhs.view(), &rhs.view(), &config)
            .expect("Operation failed");
        assert_eq!(dot, 32.0);
    }

    #[test]
    fn test_parallel_norm() {
        let config = WorkerConfig::default();
        let v = arr1(&[3.0, 4.0]);
        let norm = vector_ops::parallel_norm(&v.view(), &config).expect("Operation failed");
        assert_eq!(norm, 5.0);
    }

    #[test]
    fn test_parallel_axpy() {
        let config = WorkerConfig::default();
        let x = arr1(&[1.0, 2.0]);
        let y = arr1(&[3.0, 4.0]);
        let out = vector_ops::parallel_axpy(2.0, &x.view(), &y.view(), &config)
            .expect("Operation failed");
        assert_eq!(out, arr1(&[5.0, 8.0]));
    }

    #[test]
    fn test_parallel_gemm() {
        let config = WorkerConfig::default();
        let lhs = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let rhs = arr2(&[[5.0, 6.0], [7.0, 8.0]]);
        let product = parallel_gemm(&lhs.view(), &rhs.view(), &config).expect("Operation failed");
        assert_eq!(product, arr2(&[[19.0, 22.0], [43.0, 50.0]]));
    }
}