use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::Complex;
use scirs2_core::numeric::{Float, One, Zero};
use std::fmt::Debug;
use scirs2_core::validation::{check_2d, check_square};
use crate::complex::hermitian_transpose;
use crate::error::{LinalgError, LinalgResult};
/// Returns the trace of a complex matrix: the sum of the entries on its main diagonal.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn trace<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<Complex<F>>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    // Accumulate the diagonal from the top-left corner downwards.
    let diagonal_sum = (0..a.nrows())
        .map(|idx| a[[idx, idx]])
        .fold(Complex::<F>::zero(), |acc, entry| acc + entry);
    Ok(diagonal_sum)
}
/// Computes the determinant of a square complex matrix.
///
/// Sizes 1 and 2 use the closed-form expression. Larger matrices are reduced
/// with Gaussian elimination using partial (row) pivoting, and the determinant
/// is the product of the resulting diagonal, negated once per row exchange.
/// An empty (0x0) matrix yields `1`, the value of the empty product.
///
/// A pivot column whose largest magnitude is below `F::epsilon()` is treated
/// as exactly singular and the function returns zero. Note that this is an
/// absolute threshold, so matrices with uniformly tiny entries are reported
/// as singular.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn det<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<Complex<F>>
where
    F: Float + Debug + Zero + One,
{
    check_square(a, "matrix")?;
    let n = a.nrows();
    match n {
        // Guard the `n - 1` below against underflow on an empty matrix.
        0 => return Ok(Complex::one()),
        1 => return Ok(a[[0, 0]]),
        2 => return Ok(a[[0, 0]] * a[[1, 1]] - a[[0, 1]] * a[[1, 0]]),
        _ => {}
    }

    let mut lu = a.to_owned();
    // Parity of the row permutation: each swap flips the determinant's sign.
    let mut negate = false;

    for k in 0..n - 1 {
        // Partial pivoting: pick the row with the largest magnitude in column k.
        let mut pivot_row = k;
        let mut pivot_mag = lu[[k, k]].norm();
        for i in k + 1..n {
            let candidate = lu[[i, k]].norm();
            if candidate > pivot_mag {
                pivot_mag = candidate;
                pivot_row = i;
            }
        }
        if pivot_mag < F::epsilon() {
            return Ok(Complex::zero());
        }
        if pivot_row != k {
            for j in 0..n {
                let held = lu[[k, j]];
                lu[[k, j]] = lu[[pivot_row, j]];
                lu[[pivot_row, j]] = held;
            }
            negate = !negate;
        }

        let pivot = lu[[k, k]];
        for i in k + 1..n {
            // Entries that are already negligible need no elimination.
            if lu[[i, k]].norm() < F::epsilon() {
                continue;
            }
            let factor = lu[[i, k]] / pivot;
            for j in k + 1..n {
                lu[[i, j]] = lu[[i, j]] - factor * lu[[k, j]];
            }
        }
    }

    let product = (0..n).fold(Complex::<F>::one(), |acc, i| acc * lu[[i, i]]);
    Ok(if negate { -product } else { product })
}
/// Multiplies a complex matrix by a complex vector and returns `a * x`.
///
/// Entry `i` of the result is the sum over `j` of `a[i, j] * x[j]`, accumulated
/// left to right. No input is special-cased: the result is always the computed
/// product.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when the number of columns of `a` differs
/// from the length of `x`.
#[allow(dead_code)]
pub fn matvec<F>(
    a: &ArrayView2<Complex<F>>,
    x: &ArrayView1<Complex<F>>,
) -> LinalgResult<Array1<Complex<F>>>
where
    F: Float + Debug + 'static,
{
    if a.ncols() != x.len() {
        return Err(LinalgError::ShapeError(format!(
            "Incompatible dimensions for matrix-vector multiplication: {:?} and {:?}",
            a.shape(),
            x.shape()
        )));
    }

    // One dot product per matrix row.
    let y: Array1<Complex<F>> = a
        .outer_iter()
        .map(|row| {
            row.iter()
                .zip(x.iter())
                .fold(Complex::<F>::zero(), |acc, (&a_ij, &x_j)| acc + a_ij * x_j)
        })
        .collect();
    Ok(y)
}
/// Computes the complex inner product `sum_i conj(x[i]) * y[i]`.
///
/// The first argument is the conjugated one. No input is special-cased: the
/// result is always the computed sum.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` when `x` and `y` have different lengths.
#[allow(dead_code)]
pub fn inner_product<F>(
    x: &ArrayView1<Complex<F>>,
    y: &ArrayView1<Complex<F>>,
) -> LinalgResult<Complex<F>>
where
    F: Float + Debug + 'static,
{
    if x.len() != y.len() {
        return Err(LinalgError::ShapeError(format!(
            "Vectors must have the same length for inner product, got {:?} and {:?}",
            x.shape(),
            y.shape()
        )));
    }

    let sum = x
        .iter()
        .zip(y.iter())
        .fold(Complex::<F>::zero(), |acc, (&x_i, &y_i)| acc + x_i.conj() * y_i);
    Ok(sum)
}
/// Reports whether `a` equals its own conjugate transpose within `tol`.
///
/// The matrix is rejected as soon as a diagonal entry has an imaginary part
/// whose magnitude exceeds `tol`, or an upper-triangle entry `a[i, j]` differs
/// from `conj(a[j, i])` by more than `tol` in modulus.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn is_hermitian<F>(a: &ArrayView2<Complex<F>>, tol: F) -> LinalgResult<bool>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    let size = a.nrows();
    let violated = (0..size).any(|row| {
        let diagonal_not_real = a[[row, row]].im.abs() > tol;
        diagonal_not_real
            || (row + 1..size)
                .any(|col| (a[[row, col]] - a[[col, row]].conj()).norm() > tol)
    });
    Ok(!violated)
}
/// Reports whether `a` is unitary within `tol`.
///
/// Forms the product of `hermitian_transpose(a)` with `a` and compares every
/// entry against the identity matrix; any entry deviating by more than `tol`
/// in modulus makes the answer `false`.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn is_unitary<F>(a: &ArrayView2<Complex<F>>, tol: F) -> LinalgResult<bool>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    let size = a.nrows();
    let gram = hermitian_transpose(a).dot(a);
    let deviates = (0..size).any(|row| {
        (0..size).any(|col| {
            let identity_entry = if row == col {
                Complex::<F>::one()
            } else {
                Complex::<F>::zero()
            };
            (gram[[row, col]] - identity_entry).norm() > tol
        })
    });
    Ok(!deviates)
}
/// Estimates an eigenpair of a Hermitian matrix with the power method.
///
/// Starting from the first standard basis vector, the iteration repeatedly
/// applies `a`, takes the Rayleigh quotient `v^H (a v)` as the eigenvalue
/// estimate and renormalises. It stops after `max_iter` steps, or earlier when
/// two consecutive estimates differ by less than `tol`, or when `a v` is the
/// zero vector (then `v` is an eigenvector for eigenvalue zero).
///
/// `tol` is used both for the Hermitian check and for the convergence test.
///
/// # Errors
///
/// * Propagates whatever error `check_square` reports for `a`.
/// * `LinalgError::ValueError` when `a` is not Hermitian within `tol`.
/// * `LinalgError::ShapeError` when `a` has no rows.
#[allow(dead_code)]
pub fn power_method<F>(
    a: &ArrayView2<Complex<F>>,
    max_iter: usize,
    tol: F,
) -> LinalgResult<(Complex<F>, Array1<Complex<F>>)>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    if !is_hermitian(a, tol)? {
        return Err(LinalgError::ValueError(
            "Power method can only be applied to Hermitian matrices".to_string(),
        ));
    }
    let n = a.nrows();
    if n == 0 {
        // There is no start vector to build; indexing `v[0]` would panic.
        return Err(LinalgError::ShapeError(
            "Power method requires a non-empty matrix".to_string(),
        ));
    }

    // The first standard basis vector already has unit norm.
    let mut v = Array1::<Complex<F>>::zeros(n);
    v[0] = Complex::one();

    let mut lambda = Complex::zero();
    // `None` until one estimate exists, so that a first Rayleigh quotient
    // close to zero is not mistaken for convergence.
    let mut prev_lambda: Option<Complex<F>> = None;

    for _ in 0..max_iter {
        let av = matvec(a, &v.view())?;
        lambda = inner_product(&v.view(), &av.view())?;
        if let Some(previous) = prev_lambda {
            if (lambda - previous).norm() < tol {
                break;
            }
        }

        let norm = av
            .iter()
            .fold(F::zero(), |acc, z| acc + z.norm_sqr())
            .sqrt();
        if norm == F::zero() {
            // a * v vanished: normalising would divide by zero, and v is
            // already an eigenvector for the eigenvalue held in `lambda`.
            break;
        }

        prev_lambda = Some(lambda);
        v = av.mapv(|z| z / Complex::new(norm, F::zero()));
    }
    Ok((lambda, v))
}
/// Estimates the numerical rank of a complex matrix.
///
/// Uses modified Gram-Schmidt orthogonalisation with column pivoting: at each
/// step the remaining column with the largest Euclidean norm is selected,
/// normalised, and its direction is removed from every other remaining column.
/// The process stops once the largest remaining column norm is below `tol`
/// (or is exactly zero); the number of columns accepted so far is the rank.
///
/// Works for any shape, including matrices with more rows than columns.
///
/// # Errors
///
/// Propagates whatever error `check_2d` reports for `a`.
#[allow(dead_code)]
pub fn rank<F>(a: &ArrayView2<Complex<F>>, tol: F) -> LinalgResult<usize>
where
    F: Float + Debug + 'static,
{
    check_2d(a, "matrix")?;
    let (m, n) = (a.nrows(), a.ncols());
    // Working copy whose trailing columns hold the not-yet-explained residuals.
    let mut work = a.to_owned();
    let mut independent = 0;

    for k in 0..m.min(n) {
        // Column pivoting: find the remaining column with the largest norm.
        let mut pivot_col = k;
        let mut pivot_norm_sq = F::zero();
        for j in k..n {
            let norm_sq = (0..m).fold(F::zero(), |acc, i| acc + work[[i, j]].norm_sqr());
            if norm_sq > pivot_norm_sq {
                pivot_norm_sq = norm_sq;
                pivot_col = j;
            }
        }
        let pivot_norm = pivot_norm_sq.sqrt();
        // Every remaining column is negligible: the rank has been found.
        // The explicit zero test also protects the division below when
        // `tol` is zero or negative.
        if pivot_norm < tol || pivot_norm == F::zero() {
            break;
        }

        if pivot_col != k {
            for i in 0..m {
                let held = work[[i, k]];
                work[[i, k]] = work[[i, pivot_col]];
                work[[i, pivot_col]] = held;
            }
        }

        // Normalise the pivot column ...
        let divisor = Complex::new(pivot_norm, F::zero());
        for i in 0..m {
            work[[i, k]] = work[[i, k]] / divisor;
        }
        // ... and remove its direction from the columns still to be examined.
        for j in k + 1..n {
            let mut dot = Complex::<F>::zero();
            for i in 0..m {
                dot = dot + work[[i, k]].conj() * work[[i, j]];
            }
            for i in 0..m {
                work[[i, j]] = work[[i, j]] - dot * work[[i, k]];
            }
        }
        independent += 1;
    }
    Ok(independent)
}
/// Returns `(a + a^H) / 2`, where `a^H` is obtained from `hermitian_transpose`.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn hermitian_part<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<Array2<Complex<F>>>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    let size = a.nrows();
    let adjoint = hermitian_transpose(a);
    let averaged = Array2::from_shape_fn((size, size), |(row, col)| {
        let half = Complex::new(
            F::from(0.5).expect("Failed to convert constant to float"),
            F::zero(),
        );
        (a[[row, col]] + adjoint[[row, col]]) * half
    });
    Ok(averaged)
}
/// Returns `(a - a^H) / 2`, where `a^H` is obtained from `hermitian_transpose`.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`.
#[allow(dead_code)]
pub fn skew_hermitian_part<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<Array2<Complex<F>>>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    let size = a.nrows();
    let adjoint = hermitian_transpose(a);
    let halved_difference = Array2::from_shape_fn((size, size), |(row, col)| {
        let half = Complex::new(
            F::from(0.5).expect("Failed to convert constant to float"),
            F::zero(),
        );
        (a[[row, col]] - adjoint[[row, col]]) * half
    });
    Ok(halved_difference)
}
/// Computes the Frobenius norm: the square root of the sum of the squared
/// moduli of all entries, accumulated row by row.
///
/// This function performs no validation and always returns `Ok`.
#[allow(dead_code)]
pub fn frobenius_norm<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<F>
where
    F: Float + Debug + 'static,
{
    let squared_total = a
        .iter()
        .map(|entry| entry.norm_sqr())
        .fold(F::zero(), |acc, squared| acc + squared);
    Ok(squared_total.sqrt())
}
/// Pair of owned complex matrices. In this file it is the success payload of
/// `polar_decomposition` (returned as `(u, p)`) and of `schur` (returned as
/// `(q, t)`).
pub type ComplexMatrixPair<F> = (Array2<Complex<F>>, Array2<Complex<F>>);
/// Computes a polar-style factorisation `(u, p)` of a square complex matrix.
///
/// Starting from `x = a`, each step replaces `x` by the average of `x` and the
/// conjugate transpose of its inverse. The loop ends after `max_iter` steps or
/// as soon as the Frobenius norm of the change drops below `tol`; in the latter
/// case the iterate from before that final step is kept. The returned `u` is
/// the final iterate and `p` is `u^H * a`.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`, and any error
/// returned by `crate::complex::complex_inverse` during the iteration.
#[allow(dead_code)]
pub fn polar_decomposition<F>(
    a: &ArrayView2<Complex<F>>,
    max_iter: usize,
    tol: F,
) -> LinalgResult<ComplexMatrixPair<F>>
where
    F: Float + Debug + 'static,
{
    check_square(a, "matrix")?;
    let size = a.nrows();
    let mut current = a.to_owned();

    for _ in 0..max_iter {
        let inverse = crate::complex::complex_inverse(&current.view())?;
        let inverse_adjoint = hermitian_transpose(&inverse.view());

        // Average of the iterate and the conjugate transpose of its inverse.
        let candidate = Array2::from_shape_fn((size, size), |(row, col)| {
            let half = Complex::new(
                F::from(0.5).expect("Failed to convert constant to float"),
                F::zero(),
            );
            (current[[row, col]] + inverse_adjoint[[row, col]]) * half
        });

        // Frobenius norm of the update, used as the stopping criterion.
        let step = candidate
            .iter()
            .zip(current.iter())
            .fold(F::zero(), |acc, (next, prev)| acc + (*next - *prev).norm_sqr())
            .sqrt();
        if step < tol {
            break;
        }
        current = candidate;
    }

    let unitary = current;
    let remainder = hermitian_transpose(&unitary.view()).dot(a);
    Ok((unitary, remainder))
}
/// Builds the `p + 1` coefficients `c[0], ..., c[p]` defined by `c[0] = 1` and
/// `c[j + 1] = c[j] * (p - j) / ((p + q - j) * (j + 1))`.
///
/// `matrix_exp` in this file uses them as the polynomial coefficients of its
/// rational approximation.
///
/// # Panics
///
/// Panics if an integer operand cannot be converted to `F`.
#[allow(dead_code)]
fn pade_factors<F>(p: usize, q: usize) -> Vec<F>
where
    F: Float + Debug + 'static,
{
    let mut coefficients = Vec::with_capacity(p + 1);
    let mut current = F::one();
    coefficients.push(current);
    for step in 0..p {
        let numerator = F::from(p - step).expect("Failed to convert to float");
        let denominator = F::from((p + q - step) * (step + 1)).expect("Operation failed");
        current = current * numerator / denominator;
        coefficients.push(current);
    }
    coefficients
}
/// Computes the matrix exponential of a square complex matrix.
///
/// Uses scaling and squaring around a diagonal Pade approximant of order 6:
///
/// 1. `a` is divided by `2^s`, with `s` the smallest non-negative integer that
///    brings its infinity norm (largest row sum of moduli) down to at most 0.5.
/// 2. With coefficients `c` from `pade_factors`, the numerator
///    `N = sum_k c[k] * B^k` and the denominator `D = sum_k (-1)^k c[k] * B^k`
///    are built for the scaled matrix `B`, and `D^-1 * N` is evaluated.
/// 3. The result is squared `s` times, since `exp(a) = exp(a / 2^s)^(2^s)`.
///
/// When the norm is already at most 0.5 (or is not finite) no scaling is
/// applied and the approximant is evaluated on `a` directly.
///
/// # Errors
///
/// Propagates whatever error `check_square` reports for `a`, and any error
/// returned by `crate::complex::complex_inverse` for the denominator.
///
/// # Panics
///
/// Panics if the constant `0.5` cannot be converted to `F`.
#[allow(dead_code)]
pub fn matrix_exp<F>(a: &ArrayView2<Complex<F>>) -> LinalgResult<Array2<Complex<F>>>
where
    F: Float + Debug + 'static,
{
    const PADE_ORDER: usize = 6;

    check_square(a, "matrix")?;
    let n = a.nrows();

    // Infinity norm: the largest row sum of entry moduli.
    let norm_inf = (0..n)
        .map(|i| (0..n).fold(F::zero(), |acc, j| acc + a[[i, j]].norm()))
        .fold(F::zero(), |largest, row_sum| {
            if row_sum > largest {
                row_sum
            } else {
                largest
            }
        });

    // Halve the scale until the scaled norm is small enough for the
    // approximant to be accurate. Non-finite norms are left unscaled.
    let half = F::from(0.5).expect("Failed to convert constant to float");
    let mut scale = F::one();
    let mut squarings = 0usize;
    if norm_inf.is_finite() {
        while norm_inf * scale > half {
            scale = scale * half;
            squarings += 1;
        }
    }
    let scaled = a.mapv(|z| Complex::new(z.re * scale, z.im * scale));

    // Powers B^0 ..= B^PADE_ORDER of the scaled matrix.
    let c = pade_factors::<F>(PADE_ORDER, PADE_ORDER);
    let mut powers = Vec::with_capacity(PADE_ORDER + 1);
    powers.push(Array2::<Complex<F>>::eye(n));
    powers.push(scaled);
    for k in 2..=PADE_ORDER {
        let next_power = powers[k - 1].dot(&powers[1]);
        powers.push(next_power);
    }

    // Numerator and denominator share coefficients; the denominator
    // alternates in sign.
    let mut num = Array2::<Complex<F>>::zeros((n, n));
    let mut den = Array2::<Complex<F>>::zeros((n, n));
    for (k, power) in powers.iter().enumerate() {
        let sign = if k % 2 == 0 { F::one() } else { -F::one() };
        let coeff_num = Complex::new(c[k], F::zero());
        let coeff_den = Complex::new(sign * c[k], F::zero());
        for i in 0..n {
            for j in 0..n {
                num[[i, j]] = num[[i, j]] + coeff_num * power[[i, j]];
                den[[i, j]] = den[[i, j]] + coeff_den * power[[i, j]];
            }
        }
    }

    let den_inv = crate::complex::complex_inverse(&den.view())?;
    let mut exp_a = den_inv.dot(&num);
    // Undo the scaling by repeated squaring.
    for _ in 0..squarings {
        exp_a = exp_a.dot(&exp_a);
    }
    Ok(exp_a)
}
#[allow(dead_code)]
pub fn schur<F>(
a: &ArrayView2<Complex<F>>,
max_iter: usize,
tol: F,
) -> LinalgResult<ComplexMatrixPair<F>>
where
F: Float + Debug + 'static,
{
check_square(a, "matrix")?;
let n = a.nrows();
if n == 1 {
let q = Array2::<Complex<F>>::eye(1);
let t = a.to_owned();
return Ok((q, t));
}
let mut q = Array2::<Complex<F>>::eye(n);
let mut t = a.to_owned();
for _ in 0..max_iter {
let mut is_upper = true;
for i in 1..n {
for j in 0..i {
if t[[i, j]].norm() > tol {
is_upper = false;
break;
}
}
if !is_upper {
break;
}
}
if is_upper {
break;
}
let mut q_iter = Array2::<Complex<F>>::zeros((n, n));
let mut r = t.clone();
for k in 0..n {
let mut norm = F::zero();
for i in 0..n {
norm = norm + r[[i, k]].norm_sqr();
}
norm = norm.sqrt();
for i in 0..n {
q_iter[[i, k]] = r[[i, k]] / Complex::new(norm, F::zero());
}
for j in k + 1..n {
let mut proj: Complex<F> = Complex::zero();
for i in 0..n {
proj = proj + q_iter[[i, k]].conj() * r[[i, j]];
}
for i in 0..n {
r[[i, j]] = r[[i, j]] - proj * q_iter[[i, k]];
}
}
}
let q_iter_h = hermitian_transpose(&q_iter.view());
t = r.dot(&q_iter_h);
q = q.dot(&q_iter);
}
for i in 1..n {
for j in 0..i {
if t[[i, j]].norm() < tol {
t[[i, j]] = Complex::zero();
}
}
}
Ok((q, t))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;
    use scirs2_core::numeric::Complex64;

    #[test]
    fn test_trace() {
        let a = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 1.0)],
            [Complex64::new(3.0, -1.0), Complex64::new(4.0, 0.0)]
        ];
        let tr = trace(&a.view()).expect("Operation failed");
        // 1 + 4 = 5
        assert_relative_eq!(tr.re, 5.0);
        assert_relative_eq!(tr.im, 0.0);
    }

    #[test]
    fn test_det() {
        let a = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 1.0)],
            [Complex64::new(3.0, -1.0), Complex64::new(4.0, 0.0)]
        ];
        let d = det(&a.view()).expect("Operation failed");
        // 1 * 4 - (2 + i)(3 - i) = 4 - (7 + i) = -3 - i
        assert_relative_eq!(d.re, -3.0);
        assert_relative_eq!(d.im, -1.0);
    }

    #[test]
    fn test_matvec() {
        let a = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 1.0)],
            [Complex64::new(3.0, -1.0), Complex64::new(4.0, 0.0)]
        ];
        let x = array![Complex64::new(2.0, 1.0), Complex64::new(1.0, -1.0)];
        let y = matvec(&a.view(), &x.view()).expect("Operation failed");
        // Row 0: 1 * (2 + i) + (2 + i)(1 - i) = (2 + i) + (3 - i) = 5
        assert_relative_eq!(y[0].re, 5.0);
        assert_relative_eq!(y[0].im, 0.0);
        // Row 1: (3 - i)(2 + i) + 4 * (1 - i) = (7 + i) + (4 - 4i) = 11 - 3i
        assert_relative_eq!(y[1].re, 11.0);
        assert_relative_eq!(y[1].im, -3.0);
    }

    #[test]
    fn test_inner_product() {
        let x = array![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)];
        let y = array![Complex64::new(5.0, 6.0), Complex64::new(7.0, 8.0)];
        let ip = inner_product(&x.view(), &y.view()).expect("Operation failed");
        // conj(1 + 2i)(5 + 6i) + conj(3 + 4i)(7 + 8i)
        //   = (17 - 4i) + (53 - 4i) = 70 - 8i
        assert_relative_eq!(ip.re, 70.0);
        assert_relative_eq!(ip.im, -8.0);
    }

    #[test]
    fn test_is_hermitian() {
        let h = array![
            [Complex64::new(2.0, 0.0), Complex64::new(3.0, 1.0)],
            [Complex64::new(3.0, -1.0), Complex64::new(5.0, 0.0)]
        ];
        let nh = array![
            [Complex64::new(2.0, 0.0), Complex64::new(3.0, 1.0)],
            [Complex64::new(4.0, -1.0), Complex64::new(5.0, 0.0)]
        ];
        assert!(is_hermitian(&h.view(), 1e-10).expect("Operation failed"));
        assert!(!is_hermitian(&nh.view(), 1e-10).expect("Operation failed"));
    }

    #[test]
    fn test_is_unitary() {
        let u = array![
            [Complex64::new(0.0, 1.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(0.0, 1.0)]
        ];
        let nu = array![
            [Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)]
        ];
        assert!(is_unitary(&u.view(), 1e-10).expect("Operation failed"));
        assert!(!is_unitary(&nu.view(), 1e-10).expect("Operation failed"));
    }
}