#![cfg(feature = "lapack")]
use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::new_modules::matrix_decomp::qr::{identity_matrix, qr};
use num_traits::Float;
use scirs2_core::ndarray::ArrayView2;
use std::fmt::Debug;
pub fn schur<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>)>
where
T: Float
+ Clone
+ Debug
+ std::ops::AddAssign
+ std::ops::MulAssign
+ std::ops::DivAssign
+ std::ops::SubAssign
+ std::fmt::Display,
{
let shape = a.shape();
if shape.len() != 2 || shape[0] != shape[1] {
return Err(NumRs2Error::DimensionMismatch(
"Schur decomposition requires a square matrix".to_string(),
));
}
let _a_view: ArrayView2<T> = a.view_2d()?;
let n = shape[0];
let mut balanced = a.clone();
let mut d = vec![num_traits::One::one(); n];
let max_iterations = 5; let tol = num_traits::NumCast::from(0.95).expect("0.95 should convert to float type");
for _ in 0..max_iterations {
let mut converged = true;
for i in 0..n {
let mut row_sum: T = <T as num_traits::Zero>::zero();
let mut col_sum: T = <T as num_traits::Zero>::zero();
for j in 0..n {
if i != j {
row_sum += num_traits::Float::abs(balanced.get(&[i, j])?);
col_sum += num_traits::Float::abs(balanced.get(&[j, i])?);
}
}
if row_sum > <T as num_traits::Zero>::zero()
&& col_sum > <T as num_traits::Zero>::zero()
{
let f = num_traits::Float::sqrt(col_sum / row_sum);
let s = if f > tol && f < <T as num_traits::One>::one() / tol {
f
} else {
<T as num_traits::One>::one()
};
if s != <T as num_traits::One>::one() {
converged = false;
d[i] *= s;
for j in 0..n {
let val = balanced.get(&[i, j])?;
balanced.set(&[i, j], val * s)?;
}
for j in 0..n {
let val = balanced.get(&[j, i])?;
balanced.set(&[j, i], val / s)?;
}
}
}
}
if converged {
break;
}
}
let mut h = balanced.clone();
let mut q = identity_matrix(n);
for k in 0..(n - 2) {
let mut x = Vec::with_capacity(n - k - 1);
for i in (k + 1)..n {
x.push(h.get(&[i, k])?);
}
let mut sum_xx: T = <T as num_traits::Zero>::zero();
for &val in &x {
sum_xx += val * val;
}
let x_norm = num_traits::Float::sqrt(sum_xx);
let eps = num_traits::Float::epsilon();
if x_norm > eps {
let alpha = if x[0] >= num_traits::Zero::zero() {
-x_norm
} else {
x_norm
};
let mut v = x.clone();
v[0] -= alpha;
let mut sum_vv: T = <T as num_traits::Zero>::zero();
for &val in &v {
sum_vv += val * val;
}
let v_norm = num_traits::Float::sqrt(sum_vv);
if v_norm > eps {
for val in &mut v {
*val /= v_norm;
}
for _j in k..n {
let mut w: Vec<T> = vec![<T as num_traits::Zero>::zero(); n - k - 1];
for i in 0..(n - k - 1) {
for l in 0..(n - k - 1) {
w[i] += h.get(&[i + k + 1, l + k + 1])? * v[l];
}
}
for i in 0..(n - k - 1) {
for l in 0..(n - k - 1) {
let h_val = h.get(&[i + k + 1, l + k + 1])?;
h.set(
&[i + k + 1, l + k + 1],
h_val
- <T as num_traits::NumCast>::from(2.0)
.expect("2.0 should convert to float type")
* v[i]
* w[l],
)?;
}
}
}
for i in 0..n {
let mut q_row_dot_v: T = num_traits::Zero::zero();
for l in 0..(n - k - 1) {
q_row_dot_v += q.get(&[i, l + k + 1])? * v[l];
}
for j in (k + 1)..n {
let q_val = q.get(&[i, j])?;
q.set(
&[i, j],
q_val
- <T as num_traits::NumCast>::from(2.0)
.expect("2.0 should convert to float type")
* q_row_dot_v
* v[j - k - 1],
)?;
}
}
}
}
}
let max_iterations = 50 * n; let mut iterations = 0;
let tol = T::epsilon()
* num_traits::NumCast::from(n * 10).expect("n * 10 should convert to float type");
while iterations < max_iterations {
let mut done = true;
for i in 0..(n - 1) {
if num_traits::Float::abs(h.get(&[i + 1, i])?) > tol {
done = false;
break;
}
}
if done {
break;
}
let shift = h.get(&[n - 1, n - 1])?;
for i in 0..n {
let diag = h.get(&[i, i])?;
h.set(&[i, i], diag - shift)?;
}
let (q_i, r_i) = qr(&h)?;
h = r_i.matmul(&q_i)?;
for i in 0..n {
let diag = h.get(&[i, i])?;
h.set(&[i, i], diag + shift)?;
}
q = q.matmul(&q_i)?;
iterations += 1;
}
if d.iter().any(|&s| s != num_traits::One::one()) {
for i in 0..n {
for j in 0..n {
let val = q.get(&[i, j])?;
q.set(&[i, j], val * d[j] / d[i])?;
}
}
}
for i in 1..n {
for j in 0..(i - 1) {
let val = h.get(&[i, j])?;
if num_traits::Float::abs(val) < tol {
h.set(&[i, j], num_traits::Zero::zero())?;
}
}
}
Ok((q, h))
}