use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;
use crate::eigen::eig;
use crate::error::{LinalgError, LinalgResult};
use crate::norm::matrix_norm;
use crate::solve::solve_multiple;
use crate::validation::validate_decomposition;
/// Computes the matrix exponential `exp(A)` of a square matrix.
///
/// The general path uses scaling and squaring: `A` is divided by a power of
/// two, a diagonal [5/5] Padé approximant `D(A)^-1 N(A)` is evaluated on the
/// scaled matrix, and the result is squared back the same number of times.
/// 1x1 matrices and matrices whose off-diagonal entries are all at most
/// `F::epsilon()` in magnitude are handled by exponentiating the diagonal
/// entries directly.
///
/// # Arguments
///
/// * `a` - Square input matrix.
/// * `workers` - Forwarded to `parallel::configure_workers`.
///
/// # Errors
///
/// Propagates any error returned by `validate_decomposition`, `matrix_norm`
/// or `solve_multiple`.
#[allow(dead_code)]
pub fn expm<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + scirs2_core::ndarray::ScalarOperand + Send + Sync + 'static,
{
    use crate::parallel;

    // Dense product of two square matrices of equal order.
    fn square_product<T: Float>(lhs: &Array2<T>, rhs: &Array2<T>) -> Array2<T> {
        let order = lhs.nrows();
        let mut out = Array2::<T>::zeros((order, order));
        for i in 0..order {
            for j in 0..order {
                let mut acc = T::zero();
                for k in 0..order {
                    acc = acc + lhs[[i, k]] * rhs[[k, j]];
                }
                out[[i, j]] = acc;
            }
        }
        out
    }

    parallel::configure_workers(workers);
    validate_decomposition(a, "Matrix exponential computation", true)?;
    let n = a.nrows();

    // Scalar case: the ordinary exponential.
    if n == 1 {
        return Ok(Array2::from_elem((1, 1), a[[0, 0]].exp()));
    }

    // Diagonal case: exponentiate each diagonal entry independently.
    let has_off_diagonal = a
        .indexed_iter()
        .any(|((i, j), &v)| i != j && v.abs() > F::epsilon());
    if !has_off_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            result[[i, i]] = a[[i, i]].exp();
        }
        return Ok(result);
    }

    // Scale A by 2^-s so that ||A / 2^s||_1 <= 1/4, the range in which the
    // degree-5 Padé approximant is accurate to roughly double precision.
    // A zero, infinite or NaN norm ends up with s = 0.
    let norm_a = matrix_norm(a, "1", None)?;
    let two = F::from(2.0).expect("Operation failed");
    let scaling_f = (norm_a.log2() + two).ceil().max(F::zero());
    let scaling = scaling_f.to_i32().unwrap_or(0);
    let s = F::from(2.0_f64.powi(-scaling)).unwrap_or(F::one());
    let a_scaled = a.mapv(|v| v * s);

    // Coefficients of the [5/5] Padé numerator N(x) = sum c_k x^k, with
    // c_k = (10 - k)! 5! / (10! k! (5 - k)!). The denominator is N(-x).
    // Plain Taylor coefficients 1/(k+1)! here give a far less accurate
    // approximant.
    let c = [
        F::from(1.0).expect("Operation failed"),
        F::from(1.0 / 2.0).expect("Operation failed"),
        F::from(1.0 / 9.0).expect("Operation failed"),
        F::from(1.0 / 72.0).expect("Operation failed"),
        F::from(1.0 / 1008.0).expect("Operation failed"),
        F::from(1.0 / 30240.0).expect("Operation failed"),
    ];

    // Powers of the scaled matrix needed by the approximant.
    let a2 = square_product(&a_scaled, &a_scaled);
    let a3 = square_product(&a_scaled, &a2);
    let a4 = square_product(&a2, &a2);
    let a5 = square_product(&a_scaled, &a4);

    // N = even + odd and D = even - odd, where `even` and `odd` collect the
    // even- and odd-degree terms of the polynomial.
    let mut n_pade = Array2::<F>::zeros((n, n));
    let mut d_pade = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            let unit = if i == j { c[0] } else { F::zero() };
            let even = unit + c[2] * a2[[i, j]] + c[4] * a4[[i, j]];
            let odd = c[1] * a_scaled[[i, j]] + c[3] * a3[[i, j]] + c[5] * a5[[i, j]];
            n_pade[[i, j]] = even + odd;
            d_pade[[i, j]] = even - odd;
        }
    }

    // exp(A / 2^s) ~= D^-1 N, obtained by solving D X = N.
    let mut exp_a = solve_multiple(&d_pade.view(), &n_pade.view(), None)?;

    // Undo the scaling: exp(A) = (exp(A / 2^s))^(2^s).
    for _ in 0..scaling {
        exp_a = square_product(&exp_a, &exp_a);
    }
    Ok(exp_a)
}
/// Computes the real matrix logarithm of a square matrix.
///
/// Public entry point that forwards directly to `logm_impl` and returns its
/// result unchanged.
///
/// # Errors
///
/// Returns whatever error `logm_impl` produces, for example when the input
/// is not square or when a 1x1 or diagonal input has a non-positive entry on
/// its diagonal.
#[allow(dead_code)]
pub fn logm<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
logm_impl(a)
}
/// Core implementation of the real matrix logarithm.
///
/// Strategy, in order:
///
/// 1. 1x1 input: the scalar natural logarithm.
/// 2. Diagonal input (every off-diagonal magnitude at most `F::epsilon()`):
///    the natural logarithm of each diagonal entry.
/// 3. `max |A - I| <= 0.5` (entry-wise): the six-term series
///    `X - X^2/2 + X^3/3 - X^4/4 + X^5/5 - X^6/6` with `X = A - I`.
/// 4. Otherwise: take repeated matrix square roots (at most ten) until
///    `max |A^(1/2^k) - I| <= 0.2`, apply the same series, and multiply the
///    result by `2^k`.
///
/// # Errors
///
/// * `LinalgError::ShapeError` if `a` is not square.
/// * `LinalgError::InvalidInputError` if a 1x1 or diagonal input has a
///   non-positive diagonal entry.
/// * `LinalgError::ImplementationError` if a square root fails, or if the
///   matrix is still not close enough to the identity after ten square roots.
#[allow(dead_code)]
fn logm_impl<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    // Upper bound on the number of repeated square roots used for scaling.
    const MAX_SQUARE_ROOTS: i32 = 10;

    // Dense product of two square matrices of equal order.
    fn square_product<T: Float>(lhs: &Array2<T>, rhs: &Array2<T>) -> Array2<T> {
        let order = lhs.nrows();
        let mut out = Array2::<T>::zeros((order, order));
        for i in 0..order {
            for j in 0..order {
                let mut acc = T::zero();
                for k in 0..order {
                    acc = acc + lhs[[i, k]] * rhs[[k, j]];
                }
                out[[i, j]] = acc;
            }
        }
        out
    }

    // Returns `m - I` for a square matrix `m`.
    fn offset_from_identity<T: Float>(m: &ArrayView2<T>) -> Array2<T> {
        let mut shifted = m.to_owned();
        for d in 0..m.nrows() {
            shifted[[d, d]] = shifted[[d, d]] - T::one();
        }
        shifted
    }

    // Largest entry magnitude. NaN entries never compare greater and are
    // therefore skipped.
    fn max_abs_entry<T: Float>(m: &Array2<T>) -> T {
        m.iter().fold(T::zero(), |largest, &v| {
            let magnitude = v.abs();
            if magnitude > largest {
                magnitude
            } else {
                largest
            }
        })
    }

    // Six-term series for log(I + X).
    fn log_series<T: Float>(x: &Array2<T>) -> Array2<T> {
        let x2 = square_product(x, x);
        let x3 = square_product(&x2, x);
        let x4 = square_product(&x3, x);
        let x5 = square_product(&x4, x);
        let x6 = square_product(&x5, x);

        let half = T::from(0.5).expect("Operation failed");
        let third = T::from(1.0 / 3.0).expect("Operation failed");
        let fourth = T::from(0.25).expect("Operation failed");
        let fifth = T::from(0.2).expect("Operation failed");
        let sixth = T::from(1.0 / 6.0).expect("Operation failed");

        let order = x.nrows();
        let mut out = Array2::<T>::zeros((order, order));
        for i in 0..order {
            for j in 0..order {
                out[[i, j]] = x[[i, j]] - half * x2[[i, j]] + third * x3[[i, j]]
                    - fourth * x4[[i, j]]
                    + fifth * x5[[i, j]]
                    - sixth * x6[[i, j]];
            }
        }
        out
    }

    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "Matrix must be square to compute logarithm, got shape {:?}",
            a.shape()
        )));
    }
    let n = a.nrows();

    // Scalar case.
    if n == 1 {
        let val = a[[0, 0]];
        if val <= F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Cannot compute real logarithm of non-positive scalar".to_string(),
            ));
        }
        return Ok(Array2::from_elem((1, 1), val.ln()));
    }

    // Diagonal case. This also covers the identity matrix (ln 1 = 0) and any
    // 2x2 diagonal matrix, so neither needs a dedicated branch.
    let has_off_diagonal = a
        .indexed_iter()
        .any(|((i, j), &v)| i != j && v.abs() > F::epsilon());
    if !has_off_diagonal {
        for i in 0..n {
            if a[[i, i]] <= F::zero() {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real logarithm of matrix with non-positive eigenvalues"
                        .to_string(),
                ));
            }
        }
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            result[[i, i]] = a[[i, i]].ln();
        }
        return Ok(result);
    }

    // Close to the identity: apply the series to X = A - I directly.
    let x = offset_from_identity(a);
    if max_abs_entry(&x) <= F::from(0.5).expect("Operation failed") {
        return Ok(log_series(&x));
    }

    // Inverse scaling and squaring: log(A) = 2^k * log(A^(1/2^k)).
    let close_enough = F::from(0.2).expect("Operation failed");
    let sqrt_tol = F::from(1e-12).expect("Operation failed");
    let mut root = a.to_owned();
    let mut roots_taken: i32 = 0;
    let x_scaled = loop {
        let candidate = offset_from_identity(&root.view());
        // The closeness test runs after every square root, including the
        // last permitted one, so a matrix that only becomes close enough on
        // the final root is still accepted.
        if max_abs_entry(&candidate) <= close_enough {
            break candidate;
        }
        if roots_taken >= MAX_SQUARE_ROOTS {
            return Err(LinalgError::ImplementationError(
                "Matrix logarithm: Matrix could not be scaled close enough to identity".to_string(),
            ));
        }
        root = sqrtm(&root.view(), 20, sqrt_tol).map_err(|_| {
            LinalgError::ImplementationError(
                "Matrix logarithm: Could not compute matrix square root for scaling".to_string(),
            )
        })?;
        roots_taken += 1;
    };

    let scale_factor = F::from(2.0_f64.powi(roots_taken)).expect("Operation failed");
    Ok(log_series(&x_scaled).mapv(|v| v * scale_factor))
}
/// Matrix logarithm with a configurable worker count.
///
/// `workers` is handed to `parallel::configure_workers`. Inputs with fewer
/// than 50 rows or fewer than 50 columns are routed to `logm`; everything
/// else is routed to `logm_impl_parallel`.
///
/// # Errors
///
/// Returns whatever error the selected implementation produces.
#[allow(dead_code)]
pub fn logm_parallel<F>(a: &ArrayView2<F>, workers: Option<usize>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    use crate::parallel;

    // Minimum number of rows and columns for the parallel entry point.
    const PARALLEL_THRESHOLD: usize = 50;

    parallel::configure_workers(workers);

    let large_enough = a.nrows() >= PARALLEL_THRESHOLD && a.ncols() >= PARALLEL_THRESHOLD;
    if large_enough {
        logm_impl_parallel(a)
    } else {
        logm(a)
    }
}
/// Implementation behind `logm_parallel` for inputs above its size threshold.
///
/// As written, this function contains no parallel code of its own: it
/// forwards to `logm_impl` and returns that result (and any error) unchanged.
#[allow(dead_code)]
fn logm_impl_parallel<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
logm_impl(a)
}
/// Computes a matrix square root of a square matrix.
///
/// Public entry point that forwards directly to `sqrtm_impl`.
///
/// # Arguments
///
/// * `a` - Input matrix.
/// * `maxiter` - Maximum number of iterations, passed through to `sqrtm_impl`.
/// * `tol` - Convergence tolerance, passed through to `sqrtm_impl`.
///
/// # Errors
///
/// Returns whatever error `sqrtm_impl` produces.
#[allow(dead_code)]
pub fn sqrtm<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
where
F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
sqrtm_impl(a, maxiter, tol)
}
/// Core implementation of the matrix square root.
///
/// 1x1 and diagonal inputs (every off-diagonal magnitude at most
/// `F::epsilon()`) are handled entry-wise. All other inputs use the coupled
/// iteration
///
/// ```text
/// X_{k+1} = (X_k + Y_k^-1) / 2,    Y_{k+1} = (Y_k + X_k^-1) / 2
/// ```
///
/// started from `X_0 = A`, `Y_0 = I`. The loop stops after `maxiter` steps or
/// as soon as the largest entry-wise change in `X` drops below `tol`. The
/// last iterate is returned either way; hitting `maxiter` is not an error.
///
/// # Errors
///
/// * Any error from `validate_decomposition` or `solve_multiple`.
/// * `LinalgError::InvalidInputError` if a 1x1 or diagonal input has a
///   negative diagonal entry.
#[allow(dead_code)]
fn sqrtm_impl<F>(a: &ArrayView2<F>, maxiter: usize, tol: F) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    validate_decomposition(a, "Matrix square root computation", true)?;
    let dim = a.nrows();

    // Scalar case.
    if dim == 1 {
        let scalar = a[[0, 0]];
        if scalar < F::zero() {
            return Err(LinalgError::InvalidInputError(
                "Cannot compute real square root of negative number".to_string(),
            ));
        }
        return Ok(Array2::from_elem((1, 1), scalar.sqrt()));
    }

    // Diagonal case: take the square root of each diagonal entry.
    let has_off_diagonal = a
        .indexed_iter()
        .any(|((row, col), &v)| row != col && v.abs() > F::epsilon());
    if !has_off_diagonal {
        let mut root = Array2::<F>::zeros((dim, dim));
        for d in 0..dim {
            let entry = a[[d, d]];
            if entry < F::zero() {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real square root of matrix with negative eigenvalues"
                        .to_string(),
                ));
            }
            root[[d, d]] = entry.sqrt();
        }
        return Ok(root);
    }

    // General case: coupled iteration on (X, Y).
    let identity = Array2::<F>::eye(dim);
    let half = F::from(0.5).expect("Operation failed");
    let mut x = a.to_owned();
    let mut y = identity.clone();

    for _ in 0..maxiter {
        // Inverses are obtained by solving against the identity.
        let y_inv = solve_multiple(&y.view(), &identity.view(), None)?;
        let x_inv = solve_multiple(&x.view(), &identity.view(), None)?;

        let mut next_x = Array2::<F>::zeros((dim, dim));
        let mut next_y = Array2::<F>::zeros((dim, dim));
        // Largest entry-wise change of X in this step.
        let mut step = F::zero();
        for row in 0..dim {
            for col in 0..dim {
                let updated = (x[[row, col]] + y_inv[[row, col]]) * half;
                next_y[[row, col]] = (y[[row, col]] + x_inv[[row, col]]) * half;

                let change = (updated - x[[row, col]]).abs();
                if change > step {
                    step = change;
                }
                next_x[[row, col]] = updated;
            }
        }
        x = next_x;
        y = next_y;

        if step < tol {
            break;
        }
    }
    Ok(x)
}
/// Matrix square root with a configurable worker count.
///
/// `workers` is handed to `parallel::configure_workers`. Inputs with fewer
/// than 50 rows are routed through `sqrtm`; larger inputs call `sqrtm_impl`
/// directly.
///
/// # Arguments
///
/// * `a` - Input matrix.
/// * `maxiter` - Maximum number of iterations.
/// * `tol` - Convergence tolerance.
/// * `workers` - Optional worker count.
///
/// # Errors
///
/// Returns whatever error the selected implementation produces.
#[allow(dead_code)]
pub fn sqrtm_parallel<F>(
    a: &ArrayView2<F>,
    maxiter: usize,
    tol: F,
    workers: Option<usize>,
) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
{
    use crate::parallel;

    // Minimum row count for bypassing the public `sqrtm` entry point.
    const PARALLEL_THRESHOLD: usize = 50;

    parallel::configure_workers(workers);

    match a.nrows() {
        rows if rows < PARALLEL_THRESHOLD => sqrtm(a, maxiter, tol),
        _ => sqrtm_impl(a, maxiter, tol),
    }
}
/// Raises a square matrix to a real power `p`.
///
/// Cases, in order:
///
/// 1. `|p| < F::epsilon()`: the identity matrix.
/// 2. `|p - 1| < F::epsilon()`: a copy of `a`.
/// 3. Diagonal `a` (every off-diagonal magnitude at most `F::epsilon()`):
///    each diagonal entry raised to `p` with `powf`.
/// 4. `p` is (within `is_integer`'s tolerance) a non-negative integer that
///    fits in an `i32`: exponentiation by repeated squaring.
/// 5. Everything else, including negative integers and integers too large
///    for an `i32`: `fractionalmatrix_power` with the `"eigen"` method.
///
/// # Errors
///
/// * Any error from `validate_decomposition` or `fractionalmatrix_power`.
/// * `LinalgError::InvalidInputError` if `a` is diagonal with a negative
///   entry and `p` is not an integer.
#[allow(dead_code)]
pub fn matrix_power<F>(a: &ArrayView2<F>, p: F) -> LinalgResult<Array2<F>>
where
    F: Float
        + NumAssign
        + Sum
        + One
        + 'static
        + Send
        + Sync
        + scirs2_core::ndarray::ScalarOperand
        + std::fmt::Display,
{
    // Dense product of two square matrices of equal order.
    fn square_product<T: Float>(lhs: &Array2<T>, rhs: &Array2<T>) -> Array2<T> {
        let order = lhs.nrows();
        let mut out = Array2::<T>::zeros((order, order));
        for i in 0..order {
            for j in 0..order {
                let mut acc = T::zero();
                for k in 0..order {
                    acc = acc + lhs[[i, k]] * rhs[[k, j]];
                }
                out[[i, j]] = acc;
            }
        }
        out
    }

    validate_decomposition(a, "Matrix power computation", true)?;
    let n = a.nrows();

    if p.abs() < F::epsilon() {
        return Ok(Array2::eye(n));
    }
    if (p - F::one()).abs() < F::epsilon() {
        return Ok(a.to_owned());
    }

    let p_is_integer = is_integer(p);

    // Diagonal case: powers act entry-wise on the diagonal.
    let has_off_diagonal = a
        .indexed_iter()
        .any(|((i, j), &v)| i != j && v.abs() > F::epsilon());
    if !has_off_diagonal {
        let mut result = Array2::<F>::zeros((n, n));
        for i in 0..n {
            let val = a[[i, i]];
            if val < F::zero() && !p_is_integer {
                return Err(LinalgError::InvalidInputError(
                    "Cannot compute real fractional power of negative number".to_string(),
                ));
            }
            result[[i, i]] = val.powf(p);
        }
        return Ok(result);
    }

    if p_is_integer {
        // Round before converting: `to_i32` truncates toward zero, so an
        // exponent such as 3 - 1e-11 would otherwise become 2. `to_i32`
        // yields `None` for values outside the `i32` range; those exponents,
        // like negative ones, are left to the general routine below instead
        // of being treated as zero.
        if let Some(int_p) = p.round().to_i32().filter(|k| *k >= 0) {
            let mut result = Array2::<F>::eye(n);
            let mut base = a.to_owned();
            let mut exp = int_p as u32;
            // Binary exponentiation: multiply in `base` for every set bit.
            while exp > 0 {
                if exp % 2 == 1 {
                    result = square_product(&result, &base);
                }
                exp /= 2;
                // Square only while more bits remain; the last squaring
                // would never be used.
                if exp > 0 {
                    base = square_product(&base, &base);
                }
            }
            return Ok(result);
        }
    }

    super::fractional::fractionalmatrix_power(a, p, "eigen")
}
/// Reports whether `x` lies within `1e-10` of its nearest integer.
///
/// If `1e-10` cannot be converted to `F`, `F::epsilon()` is used as the
/// tolerance instead.
fn is_integer<F: Float>(x: F) -> bool {
    let tolerance = F::from(1e-10).unwrap_or(F::epsilon());
    let distance_to_nearest = (x - x.round()).abs();
    distance_to_nearest < tolerance
}