use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, Axis, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use scirs2_core::parallel_ops::*;
use std::cmp;
use std::iter::Sum;
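/// Strategy used by [`blocked_matmul`]: plain `dot`, a cache-blocked serial
/// loop, a block-parallel loop, or an adaptive choice based on output size.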
#[derive(Debug, Clone, Copy)]
pub enum OptAlgorithm {
Standard,
Blocked,
Parallel,
Adaptive,
}
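/// Tuning knobs for the optimized kernels: `blocksize` controls cache
/// blocking, `parallel_threshold` gates the parallel paths, `num_threads`
/// is an optional worker-count hint, and `algorithm` picks the strategy.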
#[derive(Debug, Clone)]
pub struct OptConfig {
pub blocksize: usize,
pub parallel_threshold: usize,
pub num_threads: Option<usize>,
pub algorithm: OptAlgorithm,
}
impl Default for OptConfig {
fn default() -> Self {
OptConfig {
blocksize: 64,
parallel_threshold: 1000,
num_threads: None,
algorithm: OptAlgorithm::Blocked,
}
}
}
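/// Builder-style setters, chainable from `OptConfig::default()`.
///
/// ```ignore
/// // Usage sketch (`ignore`: the enclosing crate path is not assumed here).
/// let config = OptConfig::default()
///     .with_blocksize(128)
///     .with_algorithm(OptAlgorithm::Parallel);
/// ```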
impl OptConfig {
pub fn with_blocksize(mut self, size: usize) -> Self {
self.blocksize = size;
self
}
pub fn with_parallel_threshold(mut self, threshold: usize) -> Self {
self.parallel_threshold = threshold;
self
}
pub fn with_num_threads(mut self, threads: usize) -> Self {
self.num_threads = Some(threads);
self
}
pub fn with_algorithm(mut self, algorithm: OptAlgorithm) -> Self {
self.algorithm = algorithm;
self
}
}
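/// Multiplies `a` by `b` with the strategy selected in `config`, returning a
/// `DimensionError` when the inner dimensions disagree.
///
/// ```ignore
/// // Minimal sketch (`ignore`: the enclosing crate path is not assumed).
/// use scirs2_core::ndarray::array;
///
/// let a = array![[1.0, 2.0], [3.0, 4.0]];
/// let b = array![[5.0, 6.0], [7.0, 8.0]];
/// let c = blocked_matmul(&a.view(), &b.view(), &OptConfig::default())?;
/// assert_eq!(c[[0, 0]], 19.0);
/// ```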
#[allow(dead_code)]
pub fn blocked_matmul<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
config: &OptConfig,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let (m, k) = (a.nrows(), a.ncols());
let (k2, n) = (b.nrows(), b.ncols());
if k != k2 {
return Err(LinalgError::DimensionError(format!(
"Matrix dimensions don't match: ({m}, {k}) x ({k2}, {n})"
)));
}
    let blocksize = config.blocksize;
    match config.algorithm {
        OptAlgorithm::Standard => Ok(a.dot(b)),
        OptAlgorithm::Blocked => {
            let mut c = Array2::zeros((m, n));
            serial_blocked_matmul(a, b, &mut c, blocksize)?;
            Ok(c)
        }
        OptAlgorithm::Parallel => {
            let mut c = Array2::zeros((m, n));
            parallel_blocked_matmul(a, b, &mut c, blocksize)?;
            Ok(c)
        }
        OptAlgorithm::Adaptive => {
            // Heuristic: parallelize large outputs, block medium ones
            // serially, and fall back to ndarray's `dot` for small ones.
            // The serial branch is reachable only when
            // `parallel_threshold > 10000`.
            if m * n > config.parallel_threshold {
                let mut c = Array2::zeros((m, n));
                parallel_blocked_matmul(a, b, &mut c, blocksize)?;
                Ok(c)
            } else if m * n > 10000 {
                let mut c = Array2::zeros((m, n));
                serial_blocked_matmul(a, b, &mut c, blocksize)?;
                Ok(c)
            } else {
                Ok(a.dot(b))
            }
        }
    }
}
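/// Serial cache-blocked GEMM kernel; accumulates `a * b` into `c`, which the
/// caller must pre-zero.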
#[allow(dead_code)]
fn serial_blocked_matmul<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
c: &mut Array2<F>,
blocksize: usize,
) -> LinalgResult<()>
where
F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let (m, k) = (a.nrows(), a.ncols());
let n = b.ncols();
    // Iterate over output blocks (ii, jj), accumulating partial products
    // from each block (kk) of the shared dimension.
    for ii in (0..m).step_by(blocksize) {
        let i_end = cmp::min(ii + blocksize, m);
        for jj in (0..n).step_by(blocksize) {
            let j_end = cmp::min(jj + blocksize, n);
            for kk in (0..k).step_by(blocksize) {
                let k_end = cmp::min(kk + blocksize, k);
                for i in ii..i_end {
                    for j in jj..j_end {
                        let mut sum = c[[i, j]];
                        for ki in kk..k_end {
                            sum += a[[i, ki]] * b[[ki, j]];
                        }
                        c[[i, j]] = sum;
                    }
                }
            }
        }
    }
Ok(())
}
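/// Parallel cache-blocked GEMM kernel: each task computes one output block
/// independently, and the blocks are copied into `c` afterwards.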
#[allow(dead_code)]
fn parallel_blocked_matmul<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
c: &mut Array2<F>,
blocksize: usize,
) -> LinalgResult<()>
where
F: Float + NumAssign + Sum + Send + Sync,
{
let (m, k) = (a.nrows(), a.ncols());
let n = b.ncols();
let block_indices: Vec<(usize, usize)> = (0..m)
.step_by(blocksize)
.flat_map(|i| (0..n).step_by(blocksize).map(move |j| (i, j)))
.collect();
let results: Vec<_> = parallel_map(&block_indices, |&(ii, jj)| {
        let i_end = cmp::min(ii + blocksize, m);
        let j_end = cmp::min(jj + blocksize, n);
        // Each task accumulates into its own output block, so no
        // synchronization on `c` is needed inside the parallel region.
        let mut local_c = Array2::zeros((i_end - ii, j_end - jj));
        for kk in (0..k).step_by(blocksize) {
            let k_end = cmp::min(kk + blocksize, k);
            for (i_local, i) in (ii..i_end).enumerate() {
                for (j_local, j) in (jj..j_end).enumerate() {
                    let mut sum = local_c[[i_local, j_local]];
                    for ki in kk..k_end {
                        sum += a[[i, ki]] * b[[ki, j]];
                    }
                    local_c[[i_local, j_local]] = sum;
                }
            }
        }
        ((ii, jj), local_c)
});
    // Copy each task's finished block back into the shared output.
    for ((ii, jj), local_c) in results {
        let i_end = cmp::min(ii + blocksize, m);
        let j_end = cmp::min(jj + blocksize, n);
        for (i_local, i) in (ii..i_end).enumerate() {
            for (j_local, j) in (jj..j_end).enumerate() {
                c[[i, j]] = local_c[[i_local, j_local]];
            }
        }
    }
Ok(())
}
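/// Convenience wrapper around [`blocked_matmul`] that first configures the
/// global worker pool via `crate::parallel::configure_workers`.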
#[allow(dead_code)]
pub fn blocked_matmul_with_workers<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
use crate::parallel;
parallel::configure_workers(workers);
let config = OptConfig {
num_threads: workers,
..OptConfig::default()
};
blocked_matmul(a, b, &config)
}
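/// Cache-blocked transpose of `a`; see [`optimized_transpose_with_workers`]
/// for the worker-configuring variant.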
#[allow(dead_code)]
pub fn optimized_transpose<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
F: Float + Send + Sync,
{
optimized_transpose_with_workers(a, None)
}
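/// Cache-blocked transpose. Despite the `workers` hint, the copy loop itself
/// currently runs serially; the blocking alone improves write locality.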
#[allow(dead_code)]
pub fn optimized_transpose_with_workers<F>(
a: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
F: Float + Send + Sync,
{
use crate::parallel;
parallel::configure_workers(workers);
let (m, n) = (a.nrows(), a.ncols());
let mut result = Array2::zeros((n, m));
    // Fixed tile size: 32x32 tiles keep the strided writes to `result`
    // cache-friendly.
    let blocksize = 32;
for i in (0..m).step_by(blocksize) {
for j in (0..n).step_by(blocksize) {
let i_end = cmp::min(i + blocksize, m);
let j_end = cmp::min(j + blocksize, n);
for ii in i..i_end {
for jj in j..j_end {
result[[jj, ii]] = a[[ii, jj]];
}
}
}
}
Ok(result)
}
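/// Computes `a * x`, where `x` may hold one or more column vectors; result
/// rows are filled in parallel once `a.nrows()` exceeds
/// `config.parallel_threshold`.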
#[allow(dead_code)]
pub fn parallel_matvec<F>(
a: &ArrayView2<F>,
x: &ArrayView2<F>,
config: &OptConfig,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync,
{
if a.ncols() != x.nrows() {
return Err(LinalgError::DimensionError(format!(
"Matrix and vector dimensions don't match: ({}, {}) x ({}, {})",
a.nrows(),
a.ncols(),
x.nrows(),
x.ncols()
)));
}
let m = a.nrows();
let n = x.ncols();
let mut result = Array2::zeros((m, n));
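    // Parallelize across output rows only when there are enough rows to
    // amortize the thread overhead.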
if m > config.parallel_threshold {
let rows: Vec<_> = result.axis_iter_mut(Axis(0)).enumerate().collect();
rows.into_par_iter().for_each(|(i, mut row)| {
for j in 0..n {
let mut sum = F::zero();
for k in 0..a.ncols() {
sum += a[[i, k]] * x[[k, j]];
}
row[j] = sum;
}
});
} else {
for i in 0..m {
for j in 0..n {
let mut sum = F::zero();
for k in 0..a.ncols() {
sum += a[[i, k]] * x[[k, j]];
}
result[[i, j]] = sum;
}
}
}
Ok(result)
}
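/// Worker-configuring wrapper around [`parallel_matvec`].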
#[allow(dead_code)]
pub fn parallel_matvec_with_workers<F>(
a: &ArrayView2<F>,
x: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + Send + Sync,
{
use crate::parallel;
parallel::configure_workers(workers);
let config = OptConfig {
num_threads: workers,
..OptConfig::default()
};
parallel_matvec(a, x, &config)
}
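/// In-place element-wise operations on owned matrices.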
pub mod inplace {
use super::*;
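    /// Adds `b` to `a` element-wise, failing on shape mismatch.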
pub fn add_assign<F>(a: &mut Array2<F>, b: &ArrayView2<F>) -> LinalgResult<()>
where
F: Float + NumAssign,
{
if a.shape() != b.shape() {
return Err(LinalgError::DimensionError(format!(
"Matrix dimensions don't match: {:?} != {:?}",
a.shape(),
b.shape()
)));
}
for (a_elem, b_elem) in a.iter_mut().zip(b.iter()) {
*a_elem += *b_elem;
}
Ok(())
}
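    /// Multiplies every element of `a` by `scalar` in place.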
pub fn scalar_mul_assign<F>(a: &mut Array2<F>, scalar: F) -> LinalgResult<()>
where
F: Float + NumAssign + Send + Sync,
{
for elem in a.iter_mut() {
*elem *= scalar;
}
Ok(())
}
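    /// Transposes a square matrix in place by swapping across the diagonal.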
pub fn transpose_square<F>(a: &mut Array2<F>) -> LinalgResult<()>
where
F: Float,
{
let n = a.nrows();
if n != a.ncols() {
return Err(LinalgError::DimensionError(
"In-place transpose requires square matrix".to_string(),
));
}
for i in 0..n {
for j in (i + 1)..n {
let temp = a[[i, j]];
a[[i, j]] = a[[j, i]];
a[[j, i]] = temp;
}
}
Ok(())
}
}
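/// Size-adaptive matrix multiply: plain `dot` for small inputs, blocked for
/// medium ones, and a lowered parallel threshold for very large ones.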
#[allow(dead_code)]
pub fn adaptive_matmul<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
adaptive_matmul_with_workers(a, b, None)
}
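/// Worker-configuring variant of [`adaptive_matmul`]; "size" here is the
/// total element count of both operands.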
#[allow(dead_code)]
pub fn adaptive_matmul_with_workers<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
use crate::parallel;
parallel::configure_workers(workers);
let size = a.nrows() * a.ncols() + b.nrows() * b.ncols();
if size < 10000 {
Ok(a.dot(b))
} else if size < 1000000 {
let config = OptConfig {
num_threads: workers,
..OptConfig::default()
};
blocked_matmul(a, b, &config)
} else {
let config = OptConfig {
parallel_threshold: 50000,
num_threads: workers,
..OptConfig::default()
};
blocked_matmul(a, b, &config)
}
}
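/// Adds `b` into the mutable view `a` element-wise.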
#[allow(dead_code)]
pub fn inplace_add<F>(
a: &mut scirs2_core::ndarray::ArrayViewMut2<F>,
b: &ArrayView2<F>,
) -> LinalgResult<()>
where
F: Float + NumAssign + Send + Sync,
{
    if a.shape() != b.shape() {
        return Err(LinalgError::DimensionError(format!(
            "Matrix dimensions don't match: {:?} != {:?}",
            a.shape(),
            b.shape()
        )));
    }
    // Add element-wise in place; no temporary copy of `a` is needed.
    for (a_elem, b_elem) in a.iter_mut().zip(b.iter()) {
        *a_elem += *b_elem;
    }
    Ok(())
}
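/// Scales the mutable view `a` by `scalar` in place.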
#[allow(dead_code)]
pub fn inplace_scale<F>(
a: &mut scirs2_core::ndarray::ArrayViewMut2<F>,
scalar: F,
) -> LinalgResult<()>
where
F: Float + NumAssign + Send + Sync,
{
for elem in a.iter_mut() {
*elem *= scalar;
}
Ok(())
}
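/// Times ndarray's `dot` against [`blocked_matmul`] on the same inputs and
/// reports both durations plus the speedup ratio.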
#[allow(dead_code)]
pub fn matmul_benchmark<F>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
config: &OptConfig,
) -> LinalgResult<String>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
use std::time::Instant;
let start = Instant::now();
let _c1 = a.dot(b);
let time_standard = start.elapsed();
let start = Instant::now();
let _c2 = blocked_matmul(a, b, config)?;
let time_optimized = start.elapsed();
Ok(format!(
"Standard: {:?}, Optimized: {:?}, Speedup: {:.2}x",
time_standard,
time_optimized,
time_standard.as_secs_f64() / time_optimized.as_secs_f64()
))
}
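/// Workspace-reusing helpers for decomposition kernels. The QR routines in
/// this module are simplified sketches: they form Householder data in place
/// but do not accumulate Q.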
pub mod decomposition_opt {
use super::*;
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
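    /// Pre-allocated buffers (`tau`, scratch `work`, and a temporary matrix)
    /// that can be reused across repeated QR factorizations.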
pub struct QRWorkspace<F: Float> {
pub tau: Array1<F>,
pub work: Array1<F>,
pub tempmatrix: Array2<F>,
}
impl<F: Float> QRWorkspace<F> {
        pub fn new(max_rows: usize, max_cols: usize) -> Self {
            let min_dim = max_rows.min(max_cols);
            Self {
                tau: Array1::zeros(min_dim),
                // 64 scratch entries per column, matching the sizing rule in
                // `resize_if_needed`.
                work: Array1::zeros(max_cols * 64),
                tempmatrix: Array2::zeros((max_rows, max_cols)),
            }
        }
pub fn resize_if_needed(&mut self, rows: usize, cols: usize) {
let min_dim = rows.min(cols);
if self.tau.len() < min_dim {
self.tau = Array1::zeros(min_dim);
}
let worksize = cols * 64;
if self.work.len() < worksize {
self.work = Array1::zeros(worksize);
}
if self.tempmatrix.nrows() < rows || self.tempmatrix.ncols() < cols {
self.tempmatrix = Array2::zeros((rows, cols));
}
}
}
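    /// Simplified QR pass over `a` using the buffers in `workspace`:
    /// Householder vectors are formed in `tempmatrix` and column norms are
    /// recorded in `tau`, but Q is returned as the identity (not accumulated)
    /// and R is the upper triangle of the working copy.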
pub fn qr_with_workspace<F>(
a: &ArrayView2<F>,
workspace: &mut QRWorkspace<F>,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + Clone,
{
let (m, n) = a.dim();
workspace.resize_if_needed(m, n);
let mut a_work = workspace.tempmatrix.slice_mut(s![..m, ..n]);
a_work.assign(a);
let min_dim = m.min(n);
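        // Form Householder vectors column by column; `tau[j]` records the
        // column norm (not a LAPACK-style tau) for reflector `j`.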
for j in 0..min_dim {
let column = a_work.column(j);
let norm = column
.slice(s![j..])
.fold(F::zero(), |acc, &x| acc + x * x)
.sqrt();
if norm > F::epsilon() {
workspace.tau[j] = norm;
let alpha = column[j];
let sign = if alpha >= F::zero() {
F::one()
} else {
-F::one()
};
let u1 = alpha + sign * norm;
if u1.abs() > F::epsilon() {
let scale = F::one() / u1;
for i in (j + 1)..m {
a_work[[i, j]] *= scale;
}
}
}
}
        // NOTE: the reflectors are not accumulated, so Q is returned as the
        // identity placeholder and R is read off the upper triangle of the
        // working copy.
        let q = Array2::eye(m);
        let mut r = Array2::zeros((m, n));
        for i in 0..min_dim {
            for j in i..n {
                r[[i, j]] = a_work[[i, j]];
            }
        }
        Ok((q, r))
}
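    /// A small free list of matrices and vectors, capped at `max_poolsize`
    /// entries per kind, for reuse across decomposition calls.
    ///
    /// ```ignore
    /// // Usage sketch (`ignore`: the enclosing crate path is not assumed).
    /// let mut pool = DecompositionMemoryPool::<f64>::new(8);
    /// let scratch = pool.getarray(100, 100);
    /// // ... use scratch ...
    /// pool.returnarray(scratch);
    /// ```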
pub struct DecompositionMemoryPool<F: Float> {
pub arrays: Vec<Array2<F>>,
pub vectors: Vec<Array1<F>>,
pub max_poolsize: usize,
}
impl<F: Float> DecompositionMemoryPool<F> {
        pub fn new(max_poolsize: usize) -> Self {
            Self {
                arrays: Vec::new(),
                vectors: Vec::new(),
                max_poolsize,
            }
        }
        pub fn getarray(&mut self, rows: usize, cols: usize) -> Array2<F> {
            // Reuse a pooled buffer only on an exact shape match: slicing a
            // larger buffer would force a fresh allocation via `to_owned`
            // anyway, defeating the purpose of the pool.
            for (i, array) in self.arrays.iter().enumerate() {
                if array.nrows() == rows && array.ncols() == cols {
                    let mut result = self.arrays.swap_remove(i);
                    result.fill(F::zero());
                    return result;
                }
            }
            Array2::zeros((rows, cols))
        }
pub fn returnarray(&mut self, array: Array2<F>) {
if self.arrays.len() < self.max_poolsize {
self.arrays.push(array);
}
}
        pub fn get_vector(&mut self, len: usize) -> Array1<F> {
            // Same exact-match policy as `getarray`.
            for (i, vector) in self.vectors.iter().enumerate() {
                if vector.len() == len {
                    let mut result = self.vectors.swap_remove(i);
                    result.fill(F::zero());
                    return result;
                }
            }
            Array1::zeros(len)
        }
pub fn return_vector(&mut self, vector: Array1<F>) {
if self.vectors.len() < self.max_poolsize {
self.vectors.push(vector);
}
}
pub fn clear(&mut self) {
self.arrays.clear();
self.vectors.clear();
}
}
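    /// Panel-blocked Householder QR of `a`: the working copy is
    /// triangularized into R, but the reflectors are not accumulated, so the
    /// returned Q is the identity placeholder.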
pub fn blocked_qr<F>(
a: &ArrayView2<F>,
blocksize: usize,
) -> LinalgResult<(Array2<F>, Array2<F>)>
where
F: Float + NumAssign + Sum + Clone,
{
let (m, n) = a.dim();
let mut a_copy = a.to_owned();
let q = Array2::eye(m);
let min_dim = m.min(n);
        // Process panels of `blocksize` columns; within each panel, eliminate
        // below the diagonal with classical Householder reflectors.
        for start_col in (0..min_dim).step_by(blocksize) {
            let end_col = (start_col + blocksize).min(min_dim);
            for j in start_col..end_col {
                let col_norm = a_copy
                    .slice(s![j.., j])
                    .fold(F::zero(), |acc, &x| acc + x * x)
                    .sqrt();
                if col_norm > F::epsilon() {
                    let alpha = a_copy[[j, j]];
                    let sign = if alpha >= F::zero() {
                        F::one()
                    } else {
                        -F::one()
                    };
                    let beta = alpha + sign * col_norm;
                    if beta.abs() > F::epsilon() {
                        // Scale the reflector tail so the leading entry is an
                        // implicit 1, and store it below the diagonal.
                        let scale = F::one() / beta;
                        for i in (j + 1)..m {
                            a_copy[[i, j]] *= scale;
                        }
                        a_copy[[j, j]] = -sign * col_norm;
                        // v^T v for the stored reflector (implicit leading 1),
                        // needed since H = I - (2 / v^T v) v v^T.
                        let mut vtv = F::one();
                        for i in (j + 1)..m {
                            vtv += a_copy[[i, j]] * a_copy[[i, j]];
                        }
                        let two = F::from(2.0).expect("2.0 converts to any Float");
                        // Apply the reflector to the trailing columns.
                        for k in (j + 1)..n {
                            let mut dot_product = a_copy[[j, k]];
                            for i in (j + 1)..m {
                                dot_product += a_copy[[i, j]] * a_copy[[i, k]];
                            }
                            let tau = dot_product * two / vtv;
                            a_copy[[j, k]] -= tau;
                            for i in (j + 1)..m {
                                let householder_val = a_copy[[i, j]];
                                a_copy[[i, k]] -= tau * householder_val;
                            }
                        }
                    }
                }
            }
        }
        // NOTE: Q is not accumulated from the reflectors; the identity is
        // returned as a placeholder while R is read off the triangularized
        // working copy.
        let mut r = Array2::zeros((m, n));
        for i in 0..min_dim {
            for j in i..n {
                r[[i, j]] = a_copy[[i, j]];
            }
        }
        Ok((q, r))
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_blocked_matmul() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0]];
let config = OptConfig {
blocksize: 1,
parallel_threshold: 1000,
num_threads: None,
algorithm: OptAlgorithm::Blocked,
};
let c = blocked_matmul(&a.view(), &b.view(), &config).expect("Operation failed");
assert_relative_eq!(c[[0, 0]], 19.0);
assert_relative_eq!(c[[0, 1]], 22.0);
assert_relative_eq!(c[[1, 0]], 43.0);
assert_relative_eq!(c[[1, 1]], 50.0);
}
#[test]
fn test_optimized_transpose() {
let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
let a_t = optimized_transpose(&a.view()).expect("Operation failed");
assert_eq!(a_t.shape(), &[3, 2]);
assert_relative_eq!(a_t[[0, 0]], 1.0);
assert_relative_eq!(a_t[[1, 0]], 2.0);
assert_relative_eq!(a_t[[2, 0]], 3.0);
assert_relative_eq!(a_t[[0, 1]], 4.0);
assert_relative_eq!(a_t[[1, 1]], 5.0);
assert_relative_eq!(a_t[[2, 1]], 6.0);
}
#[test]
fn test_parallel_matvec() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let x = array![[5.0], [6.0]];
let config = OptConfig::default();
let y = parallel_matvec(&a.view(), &x.view(), &config).expect("Operation failed");
assert_relative_eq!(y[[0, 0]], 17.0);
assert_relative_eq!(y[[1, 0]], 39.0);
}
#[test]
fn test_inplace_operations() {
let mut a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0]];
inplace::add_assign(&mut a, &b.view()).expect("Operation failed");
assert_relative_eq!(a[[0, 0]], 6.0);
assert_relative_eq!(a[[0, 1]], 8.0);
assert_relative_eq!(a[[1, 0]], 10.0);
assert_relative_eq!(a[[1, 1]], 12.0);
inplace::scalar_mul_assign(&mut a, 2.0).expect("Operation failed");
assert_relative_eq!(a[[0, 0]], 12.0);
assert_relative_eq!(a[[0, 1]], 16.0);
assert_relative_eq!(a[[1, 0]], 20.0);
assert_relative_eq!(a[[1, 1]], 24.0);
}
#[test]
fn test_inplace_transpose() {
let mut a = array![[1.0, 2.0], [3.0, 4.0]];
inplace::transpose_square(&mut a).expect("Operation failed");
assert_relative_eq!(a[[0, 0]], 1.0);
assert_relative_eq!(a[[0, 1]], 3.0);
assert_relative_eq!(a[[1, 0]], 2.0);
assert_relative_eq!(a[[1, 1]], 4.0);
}
#[test]
fn test_adaptive_matmul() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0]];
let c = adaptive_matmul(&a.view(), &b.view()).expect("Operation failed");
assert_relative_eq!(c[[0, 0]], 19.0);
assert_relative_eq!(c[[0, 1]], 22.0);
assert_relative_eq!(c[[1, 0]], 43.0);
assert_relative_eq!(c[[1, 1]], 50.0);
}
#[test]
fn test_largematrix_blocked() {
let n = 100;
let a = Array2::from_shape_fn((n, n), |(i, j)| (i + j) as f64);
let b = Array2::eye(n);
let config = OptConfig {
blocksize: 16,
parallel_threshold: 10000,
num_threads: None,
algorithm: OptAlgorithm::Blocked,
};
let c = blocked_matmul(&a.view(), &b.view(), &config).expect("Operation failed");
for i in 0..n {
for j in 0..n {
assert_relative_eq!(c[[i, j]], a[[i, j]], epsilon = 1e-10);
}
}
}
}