use crate::error::LinalgResult;
use crate::parallel::{algorithms, WorkerConfig};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::iter::Sum;
pub struct ParallelDecomposition;
impl ParallelDecomposition {
pub fn cholesky<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_cholesky(a, &config);
}
}
crate::decomposition::cholesky(a, workers)
}
pub fn lu<F>(
a: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array2<F>, Array2<F>)>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_lu(a, &config);
}
}
crate::decomposition::lu(a, workers)
}
pub fn qr<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_qr(a, &config);
}
}
crate::decomposition::qr(a, workers)
}
pub fn svd<F>(
a: &ArrayView2<F>,
full_matrices: bool,
workers: Option<usize>,
) -> LinalgResult<(Array2<F>, Array1<F>, Array2<F>)>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold && !full_matrices {
return algorithms::parallel_svd(a, &config);
}
}
crate::decomposition::svd(a, full_matrices, workers)
}
}
pub struct ParallelSolver;
impl ParallelSolver {
pub fn conjugate_gradient<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
max_iter: usize,
tolerance: F,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_conjugate_gradient(a, b, max_iter, tolerance, &config);
}
}
crate::iterative_solvers::conjugate_gradient(a, b, max_iter, tolerance, None)
}
pub fn gmres<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
max_iter: usize,
tolerance: F,
restart: usize,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ std::fmt::Debug
+ std::fmt::Display
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_gmres(a, b, max_iter, tolerance, restart, &config);
}
}
let options = crate::solvers::iterative::IterativeSolverOptions {
max_iterations: max_iter,
tolerance,
verbose: false,
restart: Some(restart),
};
crate::solvers::iterative::gmres(a, b, None, &options).map(|result| result.solution)
}
pub fn bicgstab<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
max_iter: usize,
tolerance: F,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_bicgstab(a, b, max_iter, tolerance, &config);
}
}
crate::iterative_solvers::bicgstab(a, b, max_iter, tolerance, None)
}
pub fn jacobi<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
max_iter: usize,
tolerance: F,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_jacobi(a, b, max_iter, tolerance, &config);
}
}
crate::iterative_solvers::jacobi_method(a, b, max_iter, tolerance, None)
}
pub fn sor<F>(
a: &ArrayView2<F>,
b: &ArrayView1<F>,
omega: F,
max_iter: usize,
tolerance: F,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float
+ NumAssign
+ One
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_sor(a, b, omega, max_iter, tolerance, &config);
}
}
crate::iterative_solvers::successive_over_relaxation(a, b, omega, max_iter, tolerance, None)
}
}
pub struct ParallelOperations;
impl ParallelOperations {
pub fn matmul<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Zero + Sum + Send + Sync + 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, k) = a.dim();
let (_, n) = b.dim();
if m * k * n > config.parallel_threshold {
return algorithms::parallel_gemm(a, b, &config);
}
}
Ok(a.dot(b))
}
pub fn matvec<F>(
a: &ArrayView2<F>,
x: &ArrayView1<F>,
workers: Option<usize>,
) -> LinalgResult<Array1<F>>
where
F: Float + Zero + Sum + Send + Sync + 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_matvec(a, x, &config);
}
}
Ok(a.dot(x))
}
pub fn power_iteration<F>(
a: &ArrayView2<F>,
max_iter: usize,
tolerance: F,
workers: Option<usize>,
) -> LinalgResult<(F, Array1<F>)>
where
F: Float
+ NumAssign
+ One
+ Zero
+ Sum
+ Send
+ Sync
+ scirs2_core::ndarray::ScalarOperand
+ 'static,
{
if let Some(num_workers) = workers {
let config = WorkerConfig::new().with_workers(num_workers);
let (m, n) = a.dim();
if m * n > config.parallel_threshold {
return algorithms::parallel_power_iteration(a, max_iter, tolerance, &config);
}
}
crate::eigen::power_iteration(a, max_iter, tolerance)
}
}
/// Builder that accumulates parallelism preferences and produces a
/// `WorkerConfig` on demand.
pub struct ParallelConfig {
    // Explicit worker-thread count; `None` keeps the WorkerConfig default.
    workers: Option<usize>,
    // Scale factor applied to WorkerConfig's base parallel threshold.
    threshold_multiplier: f64,
}

impl ParallelConfig {
    /// Creates a configuration with no explicit worker count and an
    /// unscaled (×1.0) threshold.
    pub fn new() -> Self {
        ParallelConfig {
            workers: None,
            threshold_multiplier: 1.0,
        }
    }

    /// Sets an explicit number of worker threads (consuming builder style).
    pub fn with_workers(mut self, workers: usize) -> Self {
        self.workers = Some(workers);
        self
    }

    /// Scales the size threshold above which parallel code paths engage.
    pub fn with_threshold_multiplier(mut self, multiplier: f64) -> Self {
        self.threshold_multiplier = multiplier;
        self
    }

    /// Materializes a `WorkerConfig` from the accumulated settings: the
    /// worker count is applied first, then the base threshold is scaled by
    /// `threshold_multiplier`.
    pub fn build(&self) -> WorkerConfig {
        let base = match self.workers {
            Some(count) => WorkerConfig::new().with_workers(count),
            None => WorkerConfig::new(),
        };
        // Read the base threshold before `with_threshold` consumes `base`.
        let scaled = (base.parallel_threshold as f64 * self.threshold_multiplier) as usize;
        base.with_threshold(scaled)
    }
}

impl Default for ParallelConfig {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // A small SPD matrix must route through the serial fallback and still
    // factorize successfully.
    #[test]
    fn test_parallel_dispatch_smallmatrix() {
        let spd = array![[1.0, 2.0], [2.0, 5.0]];
        let factorization = ParallelDecomposition::cholesky(&spd.view(), Some(4));
        assert!(factorization.is_ok());
    }

    // Builder settings must propagate into the produced WorkerConfig.
    #[test]
    fn test_parallel_config_builder() {
        let built = ParallelConfig::new()
            .with_workers(8)
            .with_threshold_multiplier(2.0)
            .build();
        assert_eq!(built.workers, Some(8));
        // 2.0 × the default base threshold (expected to be 1000 here).
        assert_eq!(built.parallel_threshold, 2000);
    }
}