use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::linalg_optimized::OptimizedBlas;
use num_traits::Float;
use std::fmt::Debug;
/// Entry points for dense linear-algebra routines that accept a thread-count hint.
///
/// NOTE(review): none of these routines spawn threads yet. Each public method validates its
/// inputs and then delegates to the sequential kernels in `crate::linalg_optimized` (or to the
/// Householder QR implemented below). The `num_threads` arguments are accepted for API
/// stability but are currently ignored.
pub struct ParallelLinAlg;

impl ParallelLinAlg {
    /// General matrix multiply `C = alpha * op(A) * op(B) + beta * C`, where `op(X)` is `X`
    /// or `Xᵀ` depending on `trans_a` / `trans_b`.
    ///
    /// The computation itself is delegated to `OptimizedBlas::gemm`; `num_threads` is
    /// currently unused.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::DimensionMismatch`] if any operand is not 2-D, or if the shapes of
    /// `op(A)` (`m x k`), `op(B)` (`k x n`) and `C` (`m x n`) do not agree. Any error produced
    /// by `OptimizedBlas::gemm` is forwarded unchanged.
    // BLAS-style signature: the argument list mirrors `OptimizedBlas::gemm` plus the thread hint.
    #[allow(clippy::too_many_arguments)]
    pub fn parallel_gemm<T>(
        a: &Array<T>,
        b: &Array<T>,
        c: &mut Array<T>,
        alpha: T,
        beta: T,
        trans_a: bool,
        trans_b: bool,
        num_threads: Option<usize>,
    ) -> Result<()>
    where
        T: Float
            + num_traits::NumAssign
            + num_traits::NumCast
            + Clone
            + Debug
            + Send
            + Sync
            + 'static,
    {
        // Reserved for a future threaded kernel; both the small and the large case currently
        // use the same sequential implementation.
        let _ = num_threads;

        let a_shape = a.shape();
        let b_shape = b.shape();
        let c_shape = c.shape();
        if a_shape.len() != 2 || b_shape.len() != 2 || c_shape.len() != 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "Parallel GEMM requires 2D matrices".to_string(),
            ));
        }

        // Logical shapes after the optional transposes: op(A) is m x k_a, op(B) is k_b x n.
        let (m, k_a) = if trans_a {
            (a_shape[1], a_shape[0])
        } else {
            (a_shape[0], a_shape[1])
        };
        let (k_b, n) = if trans_b {
            (b_shape[1], b_shape[0])
        } else {
            (b_shape[0], b_shape[1])
        };
        if k_a != k_b || c_shape[0] != m || c_shape[1] != n {
            return Err(NumRs2Error::DimensionMismatch(
                "Matrix dimensions incompatible for multiplication".to_string(),
            ));
        }

        OptimizedBlas::gemm(a, b, c, alpha, beta, trans_a, trans_b)
    }

    /// LU decomposition of a square matrix, delegated to `crate::linalg_optimized::lu_optimized`.
    ///
    /// Returns whatever `lu_optimized` returns — presumably `(L, U, pivots)`; see that
    /// function for the exact pivot convention. `num_threads` is currently unused.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::DimensionMismatch`] if `a` is not a square 2-D matrix; errors from
    /// `lu_optimized` are forwarded unchanged.
    pub fn parallel_lu<T>(
        a: &Array<T>,
        num_threads: Option<usize>,
    ) -> Result<(Array<T>, Array<T>, Array<usize>)>
    where
        T: Float
            + num_traits::NumAssign
            + num_traits::NumCast
            + Clone
            + Debug
            + std::iter::Sum
            + Send
            + Sync
            + scirs2_core::ndarray::ScalarOperand
            + 'static,
    {
        // Reserved for a future threaded kernel.
        let _ = num_threads;

        let shape = a.shape();
        if shape.len() != 2 || shape[0] != shape[1] {
            return Err(NumRs2Error::DimensionMismatch(
                "Parallel LU decomposition requires a square matrix".to_string(),
            ));
        }
        crate::linalg_optimized::lu_optimized(a)
    }

    /// QR decomposition of an `m x n` matrix using Householder reflections.
    ///
    /// Returns `(Q, R)` where `Q` is an orthogonal `m x m` matrix and `R` is an upper-triangular
    /// `m x n` matrix with `A = Q * R`. `num_threads` is currently unused.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D, and forwards indexing errors
    /// from the underlying array accessors.
    pub fn parallel_qr<T>(a: &Array<T>, num_threads: Option<usize>) -> Result<(Array<T>, Array<T>)>
    where
        T: Float + Clone + Debug + Send + Sync + 'static,
    {
        // Reserved for a future threaded kernel.
        let _ = num_threads;

        let shape = a.shape();
        if shape.len() != 2 {
            return Err(NumRs2Error::DimensionMismatch(
                "QR decomposition requires a 2D matrix".to_string(),
            ));
        }
        Self::parallel_qr_householder(a, shape[0], shape[1])
    }

    /// Householder QR of the `m x n` matrix `a`, returning `(Q, R)` with `A = Q * R`.
    ///
    /// For each column `k < min(m, n)` a reflector `H_k = I - beta * v * vᵀ` is built from
    /// `R[k.., k]`. It is applied from the left to `R` and from the right to `Q`, so that
    /// finally `R = H_p ⋯ H_1 * A` and `Q = H_1 ⋯ H_p`. Because every `H_k` is its own
    /// inverse, `Q * R = A`.
    fn parallel_qr_householder<T>(a: &Array<T>, m: usize, n: usize) -> Result<(Array<T>, Array<T>)>
    where
        T: Float + Clone + Debug,
    {
        let min_mn = m.min(n);
        let mut q = Array::eye(m, m, 0);
        let mut r = a.clone();

        for k in 0..min_mn {
            let mut col_k = Vec::with_capacity(m - k);
            for i in k..m {
                col_k.push(r.get(&[i, k])?);
            }
            let (v, beta) = Self::householder_vector(&col_k)?;
            if beta == T::zero() {
                // H_k is the identity (e.g. the sub-column is already zero): nothing to update.
                continue;
            }

            // R <- H_k * R on rows k..m. Columns < k are already zero in those rows.
            for j in k..n {
                let mut col_j = Vec::with_capacity(m - k);
                for i in k..m {
                    col_j.push(r.get(&[i, j])?);
                }
                let reflected = Self::apply_householder(&col_j, &v, beta)?;
                for (offset, value) in reflected.into_iter().enumerate() {
                    r.set(&[k + offset, j], value)?;
                }
            }
            // The reflector annihilates R[k+1.., k] in exact arithmetic; store true zeros
            // instead of rounding residue so R is exactly upper triangular.
            for i in (k + 1)..m {
                r.set(&[i, k], T::zero())?;
            }

            // Q <- Q * H_k: right-multiplication only touches columns k..m of each row.
            for i in 0..m {
                let mut row_i = Vec::with_capacity(m - k);
                for j in k..m {
                    row_i.push(q.get(&[i, j])?);
                }
                let reflected = Self::apply_householder(&row_i, &v, beta)?;
                for (offset, value) in reflected.into_iter().enumerate() {
                    q.set(&[i, k + offset], value)?;
                }
            }
        }
        Ok((q, r))
    }

    /// Matrix-vector product `y = alpha * op(A) * x + beta * y`, where `op(A)` is `A` or `Aᵀ`
    /// depending on `trans`.
    ///
    /// The computation is delegated to `OptimizedBlas::gemv`; `num_threads` is currently unused.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D, `x`/`y` are not 1-D, or
    /// their lengths do not match `op(A)`. Errors from `OptimizedBlas::gemv` are forwarded.
    pub fn parallel_matvec<T>(
        a: &Array<T>,
        x: &Array<T>,
        y: &mut Array<T>,
        alpha: T,
        beta: T,
        trans: bool,
        num_threads: Option<usize>,
    ) -> Result<()>
    where
        T: Float
            + num_traits::NumAssign
            + num_traits::NumCast
            + Clone
            + Debug
            + Send
            + Sync
            + 'static,
    {
        // Reserved for a future threaded kernel.
        let _ = num_threads;

        let a_shape = a.shape();
        let x_shape = x.shape();
        let y_shape = y.shape();
        if a_shape.len() != 2 || x_shape.len() != 1 || y_shape.len() != 1 {
            return Err(NumRs2Error::DimensionMismatch(
                "Matrix-vector multiplication requires 2D matrix and 1D vectors".to_string(),
            ));
        }

        let (m, n) = (a_shape[0], a_shape[1]);
        if trans {
            // Aᵀ is n x m: x must have m entries and y must have n.
            if n != y_shape[0] || m != x_shape[0] {
                return Err(NumRs2Error::DimensionMismatch(
                    "Incompatible dimensions for transposed matrix-vector multiplication"
                        .to_string(),
                ));
            }
        } else if m != y_shape[0] || n != x_shape[0] {
            return Err(NumRs2Error::DimensionMismatch(
                "Incompatible dimensions for matrix-vector multiplication".to_string(),
            ));
        }

        OptimizedBlas::gemv(a, x, y, alpha, beta, trans)
    }

    /// Suggests a `(block_rows, block_cols)` tile size for splitting an `m x n` matrix across
    /// `num_threads` workers.
    ///
    /// The tile side is `sqrt(m * n / num_threads)`, clamped to `[32, 512]` and then to the
    /// matrix dimensions. A `num_threads` of zero is treated as one thread.
    pub fn optimal_block_size(m: usize, n: usize, num_threads: usize) -> (usize, usize) {
        // Guard against division by zero and against `m * n` overflowing usize.
        let threads = num_threads.max(1);
        let elements_per_thread = m.saturating_mul(n) / threads;
        let block_size = ((elements_per_thread as f64).sqrt() as usize).clamp(32, 512);
        (block_size.min(m), block_size.min(n))
    }

    /// Picks a work-partitioning strategy name for `operation` on matrix `a`.
    ///
    /// - `"gemm"` yields `"row_wise"`, `"col_wise"` or `"block_wise"`, depending on which
    ///   dimension can feed `num_threads` workers.
    /// - `"matvec"` yields `"dense_optimized"` when the estimated fraction of non-zeros exceeds
    ///   0.5, otherwise `"sparse_optimized"`.
    /// - Any other operation, or a non-2-D `a`, yields `"sequential"`.
    pub fn compute_load_balance_strategy<T>(
        a: &Array<T>,
        operation: &str,
        num_threads: usize,
    ) -> String
    where
        T: Float + Clone + Debug,
    {
        let shape = a.shape();
        if shape.len() != 2 {
            return "sequential".to_string();
        }
        let (m, n) = (shape[0], shape[1]);

        let strategy = match operation {
            "gemm" if m >= n && m >= num_threads => "row_wise",
            "gemm" if n >= num_threads => "col_wise",
            "gemm" => "block_wise",
            // Density is only needed for matvec, so the matrix is sampled only in this arm.
            "matvec" if Self::estimate_density(a) > 0.5 => "dense_optimized",
            "matvec" => "sparse_optimized",
            _ => "sequential",
        };
        strategy.to_string()
    }

    /// Estimates the fraction of non-zero entries of a 2-D array by probing a regular grid.
    ///
    /// Roughly `clamp(m * n / 100, 100, 1000)` entries are probed on an evenly spaced
    /// `rows x cols` grid. When the matrix is small enough, the grid covers every entry and the
    /// result is exact. Non-2-D inputs report `1.0` and empty matrices report `0.0`. Entries
    /// that cannot be read count as zero.
    fn estimate_density<T>(a: &Array<T>) -> f64
    where
        T: Float + Clone + Debug,
    {
        let shape = a.shape();
        if shape.len() != 2 {
            return 1.0;
        }
        let (m, n) = (shape[0], shape[1]);
        if m == 0 || n == 0 {
            return 0.0;
        }

        let budget = (m.saturating_mul(n) / 100).clamp(100, 1000);
        // Spread the budget over a 2-D grid. Sampling along a single line (e.g. the main
        // diagonal) would report a diagonal matrix as fully dense.
        let side = (budget as f64).sqrt().ceil() as usize;
        let cols = (budget / side.min(m)).clamp(1, n);
        let rows = (budget / cols).clamp(1, m);

        let non_zero = (0..rows)
            .flat_map(|gi| (0..cols).map(move |gj| (gi * m / rows, gj * n / cols)))
            .filter(|&(row, col)| matches!(a.get(&[row, col]), Ok(value) if value != T::zero()))
            .count();
        non_zero as f64 / (rows * cols) as f64
    }

    /// Builds the Householder vector for `x`.
    ///
    /// Returns `(v, beta)` such that `(I - beta * v * vᵀ) * x = [alpha, 0, …, 0]` with
    /// `|alpha| = ‖x‖`. An all-zero `x` yields a zero `v` and `beta = 0`, i.e. the identity
    /// reflector.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::InvalidOperation`] for an empty slice.
    fn householder_vector<T>(x: &[T]) -> Result<(Vec<T>, T)>
    where
        T: Float + Clone,
    {
        if x.is_empty() {
            return Err(NumRs2Error::InvalidOperation("Empty vector".to_string()));
        }

        let x_norm = x.iter().fold(T::zero(), |acc, &xi| acc + xi * xi).sqrt();
        if x_norm == T::zero() {
            return Ok((vec![T::zero(); x.len()], T::zero()));
        }

        // Give alpha the opposite sign of x[0], so that v[0] = x[0] - alpha adds magnitudes
        // instead of cancelling.
        let alpha = if x[0] >= T::zero() { -x_norm } else { x_norm };
        let mut v = x.to_vec();
        v[0] = x[0] - alpha;

        let v_norm_sq = v.iter().fold(T::zero(), |acc, &vi| acc + vi * vi);
        if v_norm_sq == T::zero() {
            return Ok((v, T::zero()));
        }
        let two = T::one() + T::one();
        Ok((v, two / v_norm_sq))
    }

    /// Applies the reflector `I - beta * v * vᵀ` to `x`, returning the reflected vector.
    ///
    /// # Errors
    /// Returns [`NumRs2Error::DimensionMismatch`] when `x` and `v` differ in length.
    fn apply_householder<T>(x: &[T], v: &[T], beta: T) -> Result<Vec<T>>
    where
        T: Float + Clone,
    {
        if x.len() != v.len() {
            return Err(NumRs2Error::DimensionMismatch(
                "Vector length mismatch".to_string(),
            ));
        }
        let dot = x.iter().zip(v).fold(T::zero(), |acc, (&xi, &vi)| acc + xi * vi);
        // Loop-invariant factor of the rank-one update x - beta * (vᵀx) * v.
        let scale = beta * dot;
        Ok(x.iter().zip(v).map(|(&xi, &vi)| xi - scale * vi).collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Reads `mat[i, j]`, panicking on an invalid index.
    fn at(mat: &Array<f64>, i: usize, j: usize) -> f64 {
        mat.get(&[i, j]).expect("valid index")
    }

    /// Checks that `parallel_qr` factors the `m x n` matrix `a` into an orthogonal `Q` and an
    /// upper-triangular `R` with `Q * R == a`.
    fn assert_valid_qr(a: &Array<f64>, m: usize, n: usize) {
        let (q, r) = ParallelLinAlg::parallel_qr(a, Some(2)).expect("parallel QR should succeed");
        assert_eq!(q.shape(), vec![m, m]);
        assert_eq!(r.shape(), vec![m, n]);

        for i in 0..m {
            for j in 0..n {
                let qr_ij: f64 = (0..m).map(|p| at(&q, i, p) * at(&r, p, j)).sum();
                assert_relative_eq!(qr_ij, at(a, i, j), epsilon = 1e-10);
            }
        }
        for i in 0..m {
            for j in 0..m {
                let dot: f64 = (0..m).map(|p| at(&q, p, i) * at(&q, p, j)).sum();
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_relative_eq!(dot, expected, epsilon = 1e-10);
            }
        }
        for i in 0..m {
            for j in 0..i.min(n) {
                assert_relative_eq!(at(&r, i, j), 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_parallel_matrix_multiplication() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![5.0, 6.0, 7.0, 8.0]).reshape(&[2, 2]);
        let mut c = Array::zeros(&[2, 2]);
        ParallelLinAlg::parallel_gemm(&a, &b, &mut c, 1.0, 0.0, false, false, Some(2))
            .expect("parallel gemm should succeed");
        assert_relative_eq!(c.get(&[0, 0]).expect("valid index"), 19.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[0, 1]).expect("valid index"), 22.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 0]).expect("valid index"), 43.0, epsilon = 1e-10);
        assert_relative_eq!(c.get(&[1, 1]).expect("valid index"), 50.0, epsilon = 1e-10);
    }

    #[test]
    fn test_parallel_lu_decomposition() {
        let a = Array::from_vec(vec![2.0, 1.0, 1.0, 3.0]).reshape(&[2, 2]);
        let (l, u, _p) =
            ParallelLinAlg::parallel_lu(&a, Some(2)).expect("parallel LU should succeed");
        assert_relative_eq!(l.get(&[0, 0]).expect("valid index"), 1.0, epsilon = 1e-10);
        assert_relative_eq!(l.get(&[1, 1]).expect("valid index"), 1.0, epsilon = 1e-10);
        assert_relative_eq!(l.get(&[0, 1]).expect("valid index"), 0.0, epsilon = 1e-10);
        assert_relative_eq!(u.get(&[1, 0]).expect("valid index"), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_parallel_qr_decomposition() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let (q, r) = ParallelLinAlg::parallel_qr(&a, Some(2)).expect("parallel QR should succeed");
        assert_eq!(q.shape(), vec![2, 2]);
        assert_eq!(r.shape(), vec![2, 2]);
        assert!(q.get(&[0, 0]).expect("valid index").abs() <= 1.0);
        assert!(q.get(&[1, 1]).expect("valid index").abs() <= 1.0);
    }

    #[test]
    fn test_parallel_qr_reconstructs_square_matrix() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        assert_valid_qr(&a, 2, 2);
    }

    #[test]
    fn test_parallel_qr_reconstructs_tall_matrix() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[3, 2]);
        assert_valid_qr(&a, 3, 2);
    }

    #[test]
    fn test_parallel_qr_reconstructs_wide_matrix() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
        assert_valid_qr(&a, 2, 3);
    }

    #[test]
    fn test_parallel_qr_handles_zero_column() {
        let a = Array::from_vec(vec![0.0, 1.0, 0.0, 2.0]).reshape(&[2, 2]);
        assert_valid_qr(&a, 2, 2);
    }

    #[test]
    fn test_parallel_matrix_vector_multiplication() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let x = Array::from_vec(vec![1.0, 2.0]);
        let mut y = Array::zeros(&[2]);
        ParallelLinAlg::parallel_matvec(&a, &x, &mut y, 1.0, 0.0, false, Some(2))
            .expect("parallel matvec should succeed");
        assert_relative_eq!(y.get(&[0]).expect("valid index"), 5.0, epsilon = 1e-10);
        assert_relative_eq!(y.get(&[1]).expect("valid index"), 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_optimal_block_size() {
        let (block_m, block_n) = ParallelLinAlg::optimal_block_size(1000, 1000, 4);
        assert!(block_m >= 32);
        assert!(block_n >= 32);
        assert!(block_m <= 512);
        assert!(block_n <= 512);
    }

    #[test]
    fn test_optimal_block_size_edge_cases() {
        // Zero threads must not divide by zero; it behaves like a single thread.
        assert_eq!(ParallelLinAlg::optimal_block_size(1000, 1000, 0), (512, 512));
        // Tiles never exceed the matrix itself.
        assert_eq!(ParallelLinAlg::optimal_block_size(10, 20, 4), (10, 20));
    }

    #[test]
    fn test_load_balance_strategy() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
        let strategy = ParallelLinAlg::compute_load_balance_strategy(&a, "gemm", 4);
        assert!(strategy == "row_wise" || strategy == "col_wise" || strategy == "block_wise");
        let strategy = ParallelLinAlg::compute_load_balance_strategy(&a, "matvec", 4);
        assert!(strategy == "dense_optimized" || strategy == "sparse_optimized");
        let strategy = ParallelLinAlg::compute_load_balance_strategy(&a, "svd", 4);
        assert_eq!(strategy, "sequential");
    }

    #[test]
    fn test_load_balance_strategy_detects_sparse_identity() {
        let size = 100;
        let mut data = vec![0.0; size * size];
        for value in data.iter_mut().step_by(size + 1) {
            *value = 1.0;
        }
        let identity = Array::from_vec(data).reshape(&[size, size]);
        assert_eq!(
            ParallelLinAlg::compute_load_balance_strategy(&identity, "matvec", 4),
            "sparse_optimized"
        );

        let dense = Array::from_vec(vec![1.0; size * size]).reshape(&[size, size]);
        assert_eq!(
            ParallelLinAlg::compute_load_balance_strategy(&dense, "matvec", 4),
            "dense_optimized"
        );
    }

    #[test]
    fn test_householder_vector() {
        let x = vec![1.0, 2.0, 3.0];
        let (v, beta) = ParallelLinAlg::householder_vector(&x).expect("householder should succeed");
        assert_eq!(v.len(), 3);
        assert!(beta >= 0.0);
        let result = ParallelLinAlg::apply_householder(&x, &v, beta)
            .expect("apply householder should succeed");
        let x_norm = (1.0 + 4.0 + 9.0_f64).sqrt();
        assert_relative_eq!(result[0].abs(), x_norm, epsilon = 1e-10);
        assert_relative_eq!(result[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(result[2], 0.0, epsilon = 1e-10);
    }
}