use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Complex, Float, NumAssign, One, Zero};
use std::fmt::{Debug, Display};
use std::iter::Sum;
/// Scalar bound bundle required by the Schur-decomposition routines in this module.
///
/// Collects the floating-point, in-place arithmetic (`NumAssign`), formatting,
/// `ndarray` scalar-operand, summation and thread-safety bounds under one name so
/// the generic signatures below stay short.
pub trait SchurFloat:
Float
+ NumAssign
+ Debug
+ Display
+ scirs2_core::ndarray::ScalarOperand
+ Sum
+ 'static
+ Send
+ Sync
{
}
// Blanket implementation: every type that satisfies the bounds is automatically a
// `SchurFloat`, so the trait never needs to be implemented by hand.
impl<T> SchurFloat for T where
T: Float
+ NumAssign
+ Debug
+ Display
+ scirs2_core::ndarray::ScalarOperand
+ Sum
+ 'static
+ Send
+ Sync
{
}
/// Factors of a real Schur decomposition `A = Q T Qᵀ`, as returned by
/// `real_schur_decompose`.
#[derive(Debug, Clone)]
pub struct RealSchurResult<T> {
/// Schur vectors: the accumulated product of the orthogonal (Householder and
/// rotation) transformations applied during the reduction.
pub q: Array2<T>,
/// Upper quasi-triangular Schur form: 1x1 diagonal blocks for real eigenvalues,
/// 2x2 diagonal blocks for complex-conjugate pairs.
pub t: Array2<T>,
}
/// Factors of a complex Schur decomposition, as returned by `complex_schur_decompose`.
///
/// Built from the real Schur form by rotating each 2x2 diagonal block into upper
/// triangular form. The intended relation is `A = Q T Qᴴ`, with `Q` unitary and `T`
/// upper triangular.
#[derive(Debug, Clone)]
pub struct ComplexSchurResult<T> {
/// Complex Schur vectors.
pub q: Array2<Complex<T>>,
/// Complex Schur form; its diagonal holds the eigenvalues.
pub t: Array2<Complex<T>>,
}
/// Eigen-decomposition recovered from a real Schur form by `schur_to_eigen`.
#[derive(Debug, Clone)]
pub struct SchurEigenResult<T> {
/// One eigenvalue per diagonal position of `T`, in diagonal order. A 2x2 block
/// contributes two entries; for a complex pair the one with positive imaginary
/// part comes first.
pub eigenvalues: Array1<Complex<T>>,
/// Column `j` is the eigenvector belonging to `eigenvalues[j]`. Each column is
/// scaled to unit Euclidean norm; columns whose norm is below machine epsilon are
/// left unscaled.
pub eigenvectors: Array2<Complex<T>>,
}
pub fn real_schur_decompose<T: SchurFloat>(
a: &ArrayView2<T>,
max_iter: usize,
tol: T,
) -> LinalgResult<RealSchurResult<T>> {
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"real_schur_decompose: A must be square".into(),
));
}
if n == 0 {
return Ok(RealSchurResult {
q: Array2::zeros((0, 0)),
t: Array2::zeros((0, 0)),
});
}
if n == 1 {
return Ok(RealSchurResult {
q: Array2::eye(1),
t: a.to_owned(),
});
}
let (mut h, mut q) = hessenberg_decompose(a)?;
let max_total = max_iter * n;
let mut total_iters = 0usize;
let mut ihi = n; let mut _stall_count = 0usize;
while ihi > 1 {
if total_iters >= max_total {
return Err(LinalgError::ConvergenceError(format!(
"real_schur_decompose: failed to converge after {total_iters} iterations"
)));
}
total_iters += 1;
let mut new_ihi = ihi;
while new_ihi > 1 {
let k = new_ihi - 1;
let off_diag = h[[k, k - 1]].abs();
let diag_sum = h[[k - 1, k - 1]].abs() + h[[k, k]].abs();
let threshold = if diag_sum > T::epsilon() {
tol * diag_sum
} else {
tol
};
if off_diag <= threshold {
h[[k, k - 1]] = T::zero();
new_ihi = k;
} else {
break;
}
}
if new_ihi < ihi {
if ihi - new_ihi >= 2 {
if new_ihi + 1 < ihi && h[[new_ihi + 1, new_ihi]].abs() > tol {
schur_split_2x2(&mut h, &mut q, new_ihi, n, tol);
}
}
ihi = new_ihi;
_stall_count = 0;
continue;
}
let mut ilo = 0usize;
for k in (1..ihi - 1).rev() {
let off_diag = h[[k, k - 1]].abs();
let diag_sum = h[[k - 1, k - 1]].abs() + h[[k, k]].abs();
let threshold = if diag_sum > T::epsilon() {
tol * diag_sum
} else {
tol
};
if off_diag <= threshold {
h[[k, k - 1]] = T::zero();
ilo = k;
break;
}
}
let active_size = ihi - ilo;
if active_size <= 1 {
ihi = ilo;
_stall_count = 0;
continue;
}
if active_size == 2 {
schur_split_2x2(&mut h, &mut q, ilo, n, tol);
ihi = ilo;
_stall_count = 0;
continue;
}
_stall_count += 1;
francis_qr_step(&mut h, &mut q, ilo, ihi, n, tol)?;
for k in (ilo + 1..ihi).rev() {
let off_diag = h[[k, k - 1]].abs();
let diag_sum = h[[k - 1, k - 1]].abs() + h[[k, k]].abs();
let threshold = if diag_sum > T::epsilon() {
tol * diag_sum
} else {
tol
};
if off_diag <= threshold {
h[[k, k - 1]] = T::zero();
}
}
}
Ok(RealSchurResult { q, t: h })
}
/// Tries to split the 2x2 diagonal block of `h` at rows/columns `pos..pos + 2` into two
/// 1x1 blocks with one orthogonal rotation, accumulating the rotation into `q`.
///
/// * If the subdiagonal `h[pos+1, pos]` is already zero there is nothing to do.
/// * If the block has a complex-conjugate eigenvalue pair (negative discriminant), it is
///   left untouched as a 2x2 Schur block.
/// * Otherwise the rotation maps the first column of `A - λI` to a multiple of `e1`. That
///   column is an eigenvector of the other eigenvalue (Cayley–Hamilton), so the block
///   becomes upper triangular.
///
/// The rotation `G = [[c, s], [-s, c]]` is applied as `H <- G H Gᵀ` over the full matrix
/// and `Q <- Q Gᵀ`, preserving `A = Q H Qᵀ`.
fn schur_split_2x2<T: SchurFloat>(
    h: &mut Array2<T>,
    q: &mut Array2<T>,
    pos: usize,
    n: usize,
    _tol: T,
) {
    let a11 = h[[pos, pos]];
    let a12 = h[[pos, pos + 1]];
    let a21 = h[[pos + 1, pos]];
    let a22 = h[[pos + 1, pos + 1]];
    // Compare with exact zero: an absolute epsilon would make the split scale dependent.
    if a21 == T::zero() {
        return;
    }
    let two = T::one() + T::one();
    // disc/4 written as ((a11 - a22)/2)^2 + a12*a21 rather than tr^2/4 - det. The latter
    // cancels catastrophically for close or large eigenvalues and can flip the sign,
    // i.e. the real/complex classification.
    let half_gap = (a11 - a22) / two;
    let quarter_disc = half_gap * half_gap + a12 * a21;
    if quarter_disc < T::zero() {
        return;
    }
    let mean = (a11 + a22) / two;
    let root = quarter_disc.sqrt();
    let lam1 = mean + root;
    let lam2 = mean - root;
    let shift = if (lam1 - a22).abs() < (lam2 - a22).abs() {
        lam1
    } else {
        lam2
    };
    let x = a11 - shift;
    let y = a21;
    let r = (x * x + y * y).sqrt();
    if r == T::zero() {
        return;
    }
    let c = x / r;
    let s = y / r;
    // H <- G H (rows pos, pos+1).
    for col in 0..n {
        let t1 = h[[pos, col]];
        let t2 = h[[pos + 1, col]];
        h[[pos, col]] = c * t1 + s * t2;
        h[[pos + 1, col]] = -s * t1 + c * t2;
    }
    // H <- H Gᵀ (columns pos, pos+1).
    for row in 0..n {
        let t1 = h[[row, pos]];
        let t2 = h[[row, pos + 1]];
        h[[row, pos]] = c * t1 + s * t2;
        h[[row, pos + 1]] = -s * t1 + c * t2;
    }
    // Q <- Q Gᵀ.
    for row in 0..n {
        let t1 = q[[row, pos]];
        let t2 = q[[row, pos + 1]];
        q[[row, pos]] = c * t1 + s * t2;
        q[[row, pos + 1]] = -s * t1 + c * t2;
    }
    // The entry is mathematically zero now. Store the exact zero so rounding residue
    // is never mistaken for a 2x2 block by later block detection.
    h[[pos + 1, pos]] = T::zero();
}
/// Performs one implicit Francis double-shift QR step on the active block `ilo..ihi` of the
/// upper Hessenberg matrix `h`, accumulating the orthogonal transformations into `q`.
///
/// The two shifts are the eigenvalues of the trailing 2x2 block of the active window. They
/// are represented by their sum `s` and product `t_val`, so the step stays in real
/// arithmetic. The bulge is chased in two phases:
/// * `k = ilo..=ihi-3`: a 3-element Householder reflector;
/// * `k = ihi-2`: a final Givens rotation that removes the last bulge entry `h[ihi-1, ihi-3]`
///   and restores Hessenberg form.
///
/// Transformations are applied to the full matrix (columns up to `n`, rows from 0), so
/// `A = Q H Qᵀ` holds for the whole matrix, not only for the active window.
///
/// Precondition (ensured by the caller): `ihi - ilo >= 3`.
fn francis_qr_step<T: SchurFloat>(
    h: &mut Array2<T>,
    q: &mut Array2<T>,
    ilo: usize,
    ihi: usize,
    n: usize,
    _tol: T,
) -> LinalgResult<()> {
    let p = ihi;
    let s = h[[p - 2, p - 2]] + h[[p - 1, p - 1]];
    let t_val = h[[p - 2, p - 2]] * h[[p - 1, p - 1]] - h[[p - 2, p - 1]] * h[[p - 1, p - 2]];
    let h00 = h[[ilo, ilo]];
    let h01 = h[[ilo, ilo + 1]];
    let h10 = h[[ilo + 1, ilo]];
    let h11 = h[[ilo + 1, ilo + 1]];
    let h21 = if ilo + 2 < p {
        h[[ilo + 2, ilo + 1]]
    } else {
        T::zero()
    };
    // First column of H^2 - s*H + t*I; only its first three entries are non-zero.
    let mut x = h00 * h00 + h01 * h10 - s * h00 + t_val;
    let mut y = h10 * (h00 + h11 - s);
    let mut z = if ilo + 2 < p { h10 * h21 } else { T::zero() };

    // `p - 1` is exclusive, so the last iteration is k = p - 2: the final 2x2 rotation.
    for k in ilo..p - 1 {
        let col_start = if k > ilo { k - 1 } else { ilo };
        if k + 3 <= p {
            let (v, beta) = householder_vector_3(x, y, z);
            // beta == 0 is the helper's "nothing to reflect" signal. Do not compare with
            // epsilon: beta = 2/|v|^2 is legitimately tiny once x, y, z (quadratic in the
            // matrix entries) are large, and skipping the reflector would stall the
            // iteration.
            if beta != T::zero() {
                // H <- P H on rows k..k+2.
                for col in col_start..n {
                    let dot = v[0] * h[[k, col]] + v[1] * h[[k + 1, col]] + v[2] * h[[k + 2, col]];
                    h[[k, col]] -= beta * v[0] * dot;
                    h[[k + 1, col]] -= beta * v[1] * dot;
                    h[[k + 2, col]] -= beta * v[2] * dot;
                }
                // H <- H P on columns k..k+2. Rows below k+3 are zero there.
                let row_end = (k + 4).min(p);
                for row in 0..row_end {
                    let dot = h[[row, k]] * v[0] + h[[row, k + 1]] * v[1] + h[[row, k + 2]] * v[2];
                    h[[row, k]] -= beta * dot * v[0];
                    h[[row, k + 1]] -= beta * dot * v[1];
                    h[[row, k + 2]] -= beta * dot * v[2];
                }
                // Q <- Q P.
                for row in 0..n {
                    let dot = q[[row, k]] * v[0] + q[[row, k + 1]] * v[1] + q[[row, k + 2]] * v[2];
                    q[[row, k]] -= beta * dot * v[0];
                    q[[row, k + 1]] -= beta * dot * v[1];
                    q[[row, k + 2]] -= beta * dot * v[2];
                }
                if k > ilo {
                    // The reflector was built to annihilate these bulge entries.
                    h[[k + 1, k - 1]] = T::zero();
                    h[[k + 2, k - 1]] = T::zero();
                }
            }
        } else {
            // Final step (k = p - 2): Givens rotation on rows/columns k, k+1.
            let r = (x * x + y * y).sqrt();
            if r > T::zero() {
                let c = x / r;
                let s_val = y / r;
                for col in col_start..n {
                    let t1 = h[[k, col]];
                    let t2 = h[[k + 1, col]];
                    h[[k, col]] = c * t1 + s_val * t2;
                    h[[k + 1, col]] = -s_val * t1 + c * t2;
                }
                let row_end = (k + 3).min(p);
                for row in 0..row_end {
                    let t1 = h[[row, k]];
                    let t2 = h[[row, k + 1]];
                    h[[row, k]] = c * t1 + s_val * t2;
                    h[[row, k + 1]] = -s_val * t1 + c * t2;
                }
                for row in 0..n {
                    let t1 = q[[row, k]];
                    let t2 = q[[row, k + 1]];
                    q[[row, k]] = c * t1 + s_val * t2;
                    q[[row, k + 1]] = -s_val * t1 + c * t2;
                }
                if k > ilo {
                    h[[k + 1, k - 1]] = T::zero();
                }
            }
        }
        // Next bulge vector: column k below the diagonal.
        if k + 3 < p {
            x = h[[k + 1, k]];
            y = h[[k + 2, k]];
            z = h[[k + 3, k]];
        } else if k + 2 < p {
            x = h[[k + 1, k]];
            y = h[[k + 2, k]];
            z = T::zero();
        }
    }
    Ok(())
}
/// Householder vector for the 3-vector `(x, y, z)`.
///
/// Returns `(v, beta)` such that `P = I - beta * v vᵀ` maps `(x, y, z)` onto a multiple of
/// `e1`. `beta == 0` (with `v = e1`) signals an exactly zero input, for which no reflection
/// is needed.
///
/// The input is first divided by its largest magnitude. This makes the result independent
/// of the data's scale: tiny inputs are no longer mistaken for zero, and huge inputs
/// cannot overflow. `P` depends only on the direction of `v`, so scaling does not change
/// the reflector.
fn householder_vector_3<T: SchurFloat>(x: T, y: T, z: T) -> ([T; 3], T) {
    let scale = x.abs().max(y.abs()).max(z.abs());
    if scale == T::zero() {
        return ([T::one(), T::zero(), T::zero()], T::zero());
    }
    let (xs, ys, zs) = (x / scale, y / scale, z / scale);
    let norm = (xs * xs + ys * ys + zs * zs).sqrt();
    // Add the norm with x's sign to avoid cancellation in the first component.
    let v0 = if xs >= T::zero() { xs + norm } else { xs - norm };
    let v_norm_sq = v0 * v0 + ys * ys + zs * zs;
    let beta = (T::one() + T::one()) / v_norm_sq;
    ([v0, ys, zs], beta)
}
/// Left-multiplies rows `r1` and `r2` of `h` by the rotation `[[c, s], [-s, c]]`.
///
/// Only columns `c0..c1` are touched. Always returns `Ok(())`; the `Result` keeps the
/// signature uniform with the other helpers.
fn apply_givens_cols<T: SchurFloat>(
    h: &mut Array2<T>,
    r1: usize,
    r2: usize,
    c0: usize,
    c1: usize,
    c: T,
    s: T,
) -> LinalgResult<()> {
    (c0..c1).for_each(|col| {
        let (upper, lower) = (h[[r1, col]], h[[r2, col]]);
        h[[r1, col]] = c * upper + s * lower;
        h[[r2, col]] = -s * upper + c * lower;
    });
    Ok(())
}
/// Right-multiplies columns `c1_idx` and `c2_idx` of `h` by the rotation `[[c, s], [-s, c]]`.
///
/// Only rows `r0..r1` are touched. Each row `[a, b]` becomes `[c*a - s*b, s*a + c*b]`.
/// Always returns `Ok(())`.
fn apply_givens_rows<T: SchurFloat>(
    h: &mut Array2<T>,
    c1_idx: usize,
    c2_idx: usize,
    r0: usize,
    r1: usize,
    c: T,
    s: T,
) -> LinalgResult<()> {
    (r0..r1).for_each(|row| {
        let (left, right) = (h[[row, c1_idx]], h[[row, c2_idx]]);
        h[[row, c1_idx]] = c * left - s * right;
        h[[row, c2_idx]] = s * left + c * right;
    });
    Ok(())
}
/// Reduces a square matrix to upper Hessenberg form with Householder reflections.
///
/// Returns `(H, Q)` with `Q` orthogonal and `A = Q H Qᵀ`. Every entry of `H` below the
/// first subdiagonal is stored as an exact zero.
///
/// # Errors
/// Currently infallible; the `Result` keeps the signature uniform with the other helpers.
fn hessenberg_decompose<T: SchurFloat>(a: &ArrayView2<T>) -> LinalgResult<(Array2<T>, Array2<T>)> {
    let n = a.nrows();
    let mut h = a.to_owned();
    let mut q = Array2::<T>::eye(n);
    for k in 0..n.saturating_sub(2) {
        // The reflector acts on indices k+1..n. Since k < n-2 it has at least two entries.
        let col_len = n - k - 1;
        let x: Vec<T> = (k + 1..n).map(|i| h[[i, k]]).collect();
        let (v, beta) = householder_vector(&x);
        // `householder_vector` signals "nothing to eliminate" with beta == 0. Do not compare
        // with epsilon: beta = 2/‖v‖² is legitimately below epsilon once the column norm
        // reaches roughly 5e7 (f64). Skipping those reflectors would leave H non-Hessenberg.
        if beta == T::zero() {
            continue;
        }
        // H <- P H
        for col in k..n {
            let dot: T = (0..col_len).map(|i| v[i] * h[[k + 1 + i, col]]).sum();
            for i in 0..col_len {
                h[[k + 1 + i, col]] -= beta * v[i] * dot;
            }
        }
        // H <- H P
        for row in 0..n {
            let dot: T = (0..col_len).map(|i| h[[row, k + 1 + i]] * v[i]).sum();
            for i in 0..col_len {
                h[[row, k + 1 + i]] -= beta * dot * v[i];
            }
        }
        // Q <- Q P
        for row in 0..n {
            let dot: T = (0..col_len).map(|i| q[[row, k + 1 + i]] * v[i]).sum();
            for i in 0..col_len {
                q[[row, k + 1 + i]] -= beta * dot * v[i];
            }
        }
        // The reflector annihilates column k below the subdiagonal; store exact zeros so
        // downstream code sees a true Hessenberg matrix.
        for i in k + 2..n {
            h[[i, k]] = T::zero();
        }
    }
    Ok((h, q))
}
/// Householder vector for an arbitrary-length vector `x`.
///
/// Returns `(v, beta)` such that `P = I - beta * v vᵀ` maps `x` onto a multiple of `e1`.
/// * An empty input yields `(vec![], 0)`.
/// * An exactly zero input yields `(e1, 0)`; `beta == 0` means "no reflection needed".
///
/// `x` is divided by its largest magnitude first. This makes the result scale independent:
/// tiny inputs are not mistaken for zero, and huge inputs cannot overflow. `P` depends
/// only on the direction of `v`, so the reflector is unchanged.
fn householder_vector<T: SchurFloat>(x: &[T]) -> (Vec<T>, T) {
    let n = x.len();
    if n == 0 {
        return (vec![], T::zero());
    }
    let scale = x.iter().fold(T::zero(), |m, &xi| m.max(xi.abs()));
    if scale == T::zero() {
        let mut v = vec![T::zero(); n];
        v[0] = T::one();
        return (v, T::zero());
    }
    let mut v: Vec<T> = x.iter().map(|&xi| xi / scale).collect();
    let norm: T = v.iter().map(|&vi| vi * vi).sum::<T>().sqrt();
    // Add the norm with x[0]'s sign to avoid cancellation in the first component.
    let head = v[0];
    v[0] = if head >= T::zero() { head + norm } else { head - norm };
    let v_norm_sq: T = v.iter().map(|&vi| vi * vi).sum();
    let beta = (T::one() + T::one()) / v_norm_sq;
    (v, beta)
}
/// Computes a complex Schur decomposition `A = Q T Qᴴ`, with `Q` unitary and `T` upper
/// triangular.
///
/// Starts from the real Schur form. Each 2x2 diagonal block is then triangularized by a
/// unitary rotation whose first column is the (normalized) eigenvector of the block for
/// its first eigenvalue `lam1` (see `eigen2x2_complex`). After the rotation the block's
/// diagonal is `(lam1, lam2)` and its subdiagonal is zero.
///
/// # Arguments
/// `max_iter` and `tol` are forwarded to `real_schur_decompose`. `tol` is also the
/// absolute threshold above which a subdiagonal entry marks a 2x2 block.
///
/// # Errors
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * Errors propagated from `real_schur_decompose`.
pub fn complex_schur_decompose<T: SchurFloat>(
    a: &ArrayView2<T>,
    max_iter: usize,
    tol: T,
) -> LinalgResult<ComplexSchurResult<T>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "complex_schur_decompose: A must be square".into(),
        ));
    }
    let real_res = real_schur_decompose(a, max_iter, tol)?;
    let rt = &real_res.t;
    let mut q_out: Array2<Complex<T>> = real_to_complex_matrix(&real_res.q);
    let mut t_c: Array2<Complex<T>> = real_to_complex_matrix(rt);
    let re = |v: T| Complex::new(v, T::zero());
    let mag = |z: Complex<T>| z.re.abs() + z.im.abs();

    let mut k = 0usize;
    while k < n {
        if k + 1 < n && rt[[k + 1, k]].abs() > tol {
            // Rotations on earlier blocks only touch entries below this block's rows
            // (which are zero) or its columns' upper part, so the block values are intact.
            let (a11, a12, a21, a22) = (rt[[k, k]], rt[[k, k + 1]], rt[[k + 1, k]], rt[[k + 1, k + 1]]);
            let (lam1, lam2) = eigen2x2_complex(a11, a12, a21, a22);

            // Two algebraically equivalent eigenvectors for lam1, from the block's first
            // row and from its second row. Use the larger one for numerical robustness.
            let from_row0 = (re(a12), lam1 - re(a11));
            let from_row1 = (lam1 - re(a22), re(a21));
            let (v0, v1) = if mag(from_row0.0) + mag(from_row0.1) >= mag(from_row1.0) + mag(from_row1.1) {
                from_row0
            } else {
                from_row1
            };
            // Normalize so complex_givens' absolute thresholds act relative to |v| = 1.
            let v_norm = (v0.re * v0.re + v0.im * v0.im + v1.re * v1.re + v1.im * v1.im).sqrt();
            let inv = re(T::one() / v_norm);
            // G = [[gc, gs], [-conj(gs), gc]] (gc real) maps v onto a multiple of e1, so the
            // first column of Gᴴ is proportional to v. Hence G T Gᴴ puts lam1 at (k, k).
            let (gc, gs) = complex_givens(v0 * inv, v1 * inv);

            // T <- G T (rows k, k+1; columns left of k are zero in these rows).
            for col in k..n {
                let a_val = t_c[[k, col]];
                let b_val = t_c[[k + 1, col]];
                t_c[[k, col]] = gc * a_val + gs * b_val;
                t_c[[k + 1, col]] = -gs.conj() * a_val + gc.conj() * b_val;
            }
            // T <- T Gᴴ with Gᴴ = [[gc, -gs], [conj(gs), gc]].
            for row in 0..n {
                let a_val = t_c[[row, k]];
                let b_val = t_c[[row, k + 1]];
                t_c[[row, k]] = gc * a_val + gs.conj() * b_val;
                t_c[[row, k + 1]] = -gs * a_val + gc * b_val;
            }
            // Q <- Q Gᴴ.
            for row in 0..n {
                let a_val = q_out[[row, k]];
                let b_val = q_out[[row, k + 1]];
                q_out[[row, k]] = gc * a_val + gs.conj() * b_val;
                q_out[[row, k + 1]] = -gs * a_val + gc * b_val;
            }
            // Replace rounding residue with the exact values of the transformed block.
            t_c[[k, k]] = lam1;
            t_c[[k + 1, k + 1]] = lam2;
            t_c[[k + 1, k]] = Complex::new(T::zero(), T::zero());
            k += 2;
        } else {
            k += 1;
        }
    }
    Ok(ComplexSchurResult { q: q_out, t: t_c })
}
/// Eigenvalues of the real 2x2 matrix `[[a, b], [c, d]]`.
///
/// Returns `(lam1, lam2)`:
/// * real eigenvalues: the larger one first;
/// * complex-conjugate pair: the one with positive imaginary part first.
///
/// The discriminant is evaluated as `((a - d)/2)^2 + b*c` rather than `tr^2/4 - det`.
/// The latter cancels catastrophically when the eigenvalues are close relative to their
/// size, and can even flip the real/complex classification.
fn eigen2x2_complex<T: SchurFloat>(a: T, b: T, c: T, d: T) -> (Complex<T>, Complex<T>) {
    let two = T::one() + T::one();
    let mean = (a + d) / two;
    let half_gap = (a - d) / two;
    let quarter_disc = half_gap * half_gap + b * c;
    if quarter_disc >= T::zero() {
        let root = quarter_disc.sqrt();
        (
            Complex::new(mean + root, T::zero()),
            Complex::new(mean - root, T::zero()),
        )
    } else {
        let root = (-quarter_disc).sqrt();
        (Complex::new(mean, root), Complex::new(mean, -root))
    }
}
/// Complex Givens rotation for the pair `(a, b)`.
///
/// Returns `(c, s)` with `c` real (stored as a complex number with zero imaginary part)
/// and `s = a * conj(b) / (|a| * r)`, where `r = sqrt(|a|^2 + |b|^2)`. The matrix
/// `[[c, s], [-conj(s), c]]` then maps `(a, b)` onto `(a * r / |a|, 0)`.
///
/// Degenerate cases:
/// * `r < epsilon`: returns the identity rotation `(1, 0)`;
/// * `|a| < epsilon`: returns `s = 1`.
fn complex_givens<T: SchurFloat>(a: Complex<T>, b: Complex<T>) -> (Complex<T>, Complex<T>) {
    let zero = T::zero();
    let one = T::one();
    let mag_a = (a.re * a.re + a.im * a.im).sqrt();
    let mag_b = (b.re * b.re + b.im * b.im).sqrt();
    let radius = (mag_a * mag_a + mag_b * mag_b).sqrt();
    if radius < T::epsilon() {
        return (Complex::new(one, zero), Complex::new(zero, zero));
    }
    let cosine = Complex::new(mag_a / radius, zero);
    if mag_a < T::epsilon() {
        return (cosine, Complex::new(one, zero));
    }
    let a_conj_b = Complex::new(a.re * b.re + a.im * b.im, a.im * b.re - a.re * b.im);
    let sine = a_conj_b * Complex::new(one / (mag_a * radius), zero);
    (cosine, sine)
}
/// Embeds a real matrix into the complex field (all imaginary parts zero).
fn real_to_complex_matrix<T: SchurFloat>(a: &Array2<T>) -> Array2<Complex<T>> {
    a.mapv(|value| Complex::new(value, T::zero()))
}
/// Reorders a real Schur form so that the diagonal blocks accepted by `select` come first.
///
/// `select(re, im)` is evaluated once per diagonal block of `t`, on the original layout:
/// * 1x1 blocks are tested with `(t[k, k], 0)`;
/// * 2x2 blocks are tested with their first eigenvalue from `eigen2x2_complex`.
///
/// Selected blocks are bubbled upward, one adjacent swap at a time, with
/// `swap_schur_blocks`, preserving their relative order.
///
/// Returns `(Q', T', k)`. `Q'`/`T'` are the updated Schur vectors and form, and `k` is
/// the number of rows/columns occupied by the selected blocks, which now lead `T'`.
///
/// # Errors
/// * [`LinalgError::ShapeError`] if `q` and `t` are not square and of equal size.
/// * Errors propagated from `swap_schur_blocks`.
pub fn schur_reorder<T, F>(
    q: &Array2<T>,
    t: &Array2<T>,
    select: F,
    tol: T,
) -> LinalgResult<(Array2<T>, Array2<T>, usize)>
where
    T: SchurFloat,
    F: Fn(T, T) -> bool,
{
    let n = t.nrows();
    let conformant = t.ncols() == n && q.nrows() == n && q.ncols() == n;
    if !conformant {
        return Err(LinalgError::ShapeError(
            "schur_reorder: Q and T must be square and conformant".into(),
        ));
    }
    let mut q_out = q.to_owned();
    let mut t_out = t.to_owned();

    // Current (start, size) of every block and whether it is wanted up front.
    let mut layout = identify_schur_blocks(&t_out, tol);
    let mut wanted: Vec<bool> = layout
        .iter()
        .map(|&(start, size)| match size {
            1 => select(t_out[[start, start]], T::zero()),
            _ => {
                let (lam, _) = eigen2x2_complex(
                    t_out[[start, start]],
                    t_out[[start, start + 1]],
                    t_out[[start + 1, start]],
                    t_out[[start + 1, start + 1]],
                );
                select(lam.re, lam.im)
            }
        })
        .collect();

    let mut n_selected = 0usize;
    for slot in 0..layout.len() {
        if !wanted[slot] {
            // Nearest wanted block below this slot; if none remain, we are done.
            let source = match wanted[slot + 1..].iter().position(|&w| w) {
                Some(offset) => slot + 1 + offset,
                None => break,
            };
            for pos in (slot..source).rev() {
                let (start, upper_size) = layout[pos];
                let lower_size = layout[pos + 1].1;
                swap_schur_blocks(&mut q_out, &mut t_out, start, upper_size, lower_size, tol)?;
                layout[pos] = (start, lower_size);
                layout[pos + 1] = (start + lower_size, upper_size);
                wanted.swap(pos, pos + 1);
            }
        }
        n_selected += layout[slot].1;
    }
    Ok((q_out, t_out, n_selected))
}
/// Partitions the diagonal of a quasi-triangular matrix into Schur blocks.
///
/// Returns `(start, size)` pairs in diagonal order. A 2x2 block starts at `k` whenever
/// `|t[k+1, k]| > tol`; every other position is a 1x1 block.
fn identify_schur_blocks<T: SchurFloat>(t: &Array2<T>, tol: T) -> Vec<(usize, usize)> {
    let dim = t.nrows();
    let mut layout = Vec::new();
    let mut start = 0usize;
    while start < dim {
        let coupled = start + 1 < dim && t[[start + 1, start]].abs() > tol;
        let width = if coupled { 2 } else { 1 };
        layout.push((start, width));
        start += width;
    }
    layout
}
fn swap_schur_blocks<T: SchurFloat>(
q: &mut Array2<T>,
t: &mut Array2<T>,
start: usize,
size_a: usize,
size_b: usize,
_tol: T,
) -> LinalgResult<()> {
let n = t.nrows();
let sa = size_a;
let sb = size_b;
let p = start;
let blk_size = sa + sb;
if p + blk_size > n {
return Err(LinalgError::IndexError(
"swap_schur_blocks: block exceeds matrix dimensions".into(),
));
}
if sa == 1 && sb == 1 {
let t11 = t[[p, p]];
let t12 = t[[p, p + 1]];
let t22 = t[[p + 1, p + 1]];
let diff = t22 - t11;
let (c, s) = if diff.abs() < T::epsilon() && t12.abs() < T::epsilon() {
(T::one(), T::zero())
} else {
let val = diff / (t12 + T::epsilon());
let theta = T::one() / (val + (T::one() + val * val).sqrt());
let c = T::one() / (T::one() + theta * theta).sqrt();
let s = c * theta;
(c, s)
};
apply_givens_cols(t, p, p + 1, p, n, c, s)?;
apply_givens_rows(t, p, p + 1, 0, p + 2, c, s)?;
apply_givens_rows(q, p, p + 1, 0, n, c, s)?;
return Ok(());
}
let t_aa = t.slice(s![p..p + sa, p..p + sa]).to_owned();
let t_bb = t
.slice(s![p + sa..p + blk_size, p + sa..p + blk_size])
.to_owned();
let t_ab = t.slice(s![p..p + sa, p + sa..p + blk_size]).to_owned();
let neg_tbb = t_bb.mapv(|v| -v);
let neg_tab_t = t_ab.t().mapv(|v| -v).to_owned();
let x_t =
crate::matrix_equations::solve_sylvester(&t_aa.view(), &neg_tbb.view(), &neg_tab_t.view())?;
let x = x_t.t().to_owned();
let mut e = Array2::<T>::eye(blk_size);
for i in 0..sa {
for j in 0..sb {
e[[i, sa + j]] = x[[i, j]];
}
}
let e_t = e.t().to_owned();
let (u_t, _r) = crate::decomposition::qr(&e_t.view(), None)?;
let u = u_t.t().to_owned();
let block_rows = t.slice(s![p..p + blk_size, ..]).to_owned();
let new_block_rows = u.t().dot(&block_rows);
t.slice_mut(s![p..p + blk_size, ..]).assign(&new_block_rows);
let block_cols = t.slice(s![.., p..p + blk_size]).to_owned();
let new_block_cols = block_cols.dot(&u);
t.slice_mut(s![.., p..p + blk_size]).assign(&new_block_cols);
let q_block = q.slice(s![.., p..p + blk_size]).to_owned();
let new_q_block = q_block.dot(&u);
q.slice_mut(s![.., p..p + blk_size]).assign(&new_q_block);
Ok(())
}
pub fn schur_to_eigen<T: SchurFloat>(
q: &Array2<T>,
t: &Array2<T>,
tol: T,
) -> LinalgResult<SchurEigenResult<T>> {
let n = t.nrows();
let mut eigenvalues: Vec<Complex<T>> = Vec::with_capacity(n);
let mut k = 0usize;
while k < n {
if k + 1 < n && t[[k + 1, k]].abs() > tol {
let (lam1, lam2) =
eigen2x2_complex(t[[k, k]], t[[k, k + 1]], t[[k + 1, k]], t[[k + 1, k + 1]]);
eigenvalues.push(lam1);
eigenvalues.push(lam2);
k += 2;
} else {
eigenvalues.push(Complex::new(t[[k, k]], T::zero()));
k += 1;
}
}
let mut evec_schur: Array2<Complex<T>> = Array2::zeros((n, n));
let blocks = identify_schur_blocks(t, tol);
for (blk_idx, &(blk_start, blk_size)) in blocks.iter().enumerate() {
if blk_size == 1 {
let lam = eigenvalues[blk_start];
let mut y: Vec<Complex<T>> = vec![Complex::new(T::zero(), T::zero()); n];
y[blk_start] = Complex::new(T::one(), T::zero());
if blk_start > 0 {
for row in (0..blk_start).rev() {
let mut sum = Complex::new(T::zero(), T::zero());
for col in (row + 1)..=blk_start {
sum += Complex::new(t[[row, col]], T::zero()) * y[col];
}
let diag = Complex::new(t[[row, row]], T::zero()) - lam;
if diag.re.abs() + diag.im.abs()
< T::epsilon() * T::from(100.0).unwrap_or(T::one())
{
y[row] = Complex::new(T::zero(), T::zero());
} else {
y[row] = -sum / diag;
}
}
}
for i in 0..n {
evec_schur[[i, blk_start]] = y[i];
}
} else {
let lam1 = eigenvalues[blk_start];
let lam2 = eigenvalues[blk_start + 1];
let a11 = Complex::new(t[[blk_start, blk_start]], T::zero()) - lam1;
let a12 = Complex::new(t[[blk_start, blk_start + 1]], T::zero());
let (v0, v1) = if a12.re.abs() + a12.im.abs() > T::epsilon() {
(
a12,
lam1 - Complex::new(t[[blk_start, blk_start]], T::zero()),
)
} else {
(
Complex::new(T::one(), T::zero()),
Complex::new(T::zero(), T::zero()),
)
};
let mut y1: Vec<Complex<T>> = vec![Complex::new(T::zero(), T::zero()); n];
y1[blk_start] = v0;
y1[blk_start + 1] = v1;
let mut y2: Vec<Complex<T>> = vec![Complex::new(T::zero(), T::zero()); n];
y2[blk_start] = v0.conj();
y2[blk_start + 1] = v1.conj();
if blk_start > 0 {
for row in (0..blk_start).rev() {
let mut s1 = Complex::new(T::zero(), T::zero());
let mut s2 = Complex::new(T::zero(), T::zero());
for col in (row + 1)..blk_start + 2 {
let t_val = Complex::new(t[[row, col]], T::zero());
s1 += t_val * y1[col];
s2 += t_val * y2[col];
}
let diag1 = Complex::new(t[[row, row]], T::zero()) - lam1;
let diag2 = Complex::new(t[[row, row]], T::zero()) - lam2;
y1[row] = if diag1.re.abs() + diag1.im.abs()
< T::epsilon() * T::from(100.0).unwrap_or(T::one())
{
Complex::new(T::zero(), T::zero())
} else {
-s1 / diag1
};
y2[row] = if diag2.re.abs() + diag2.im.abs()
< T::epsilon() * T::from(100.0).unwrap_or(T::one())
{
Complex::new(T::zero(), T::zero())
} else {
-s2 / diag2
};
}
}
for i in 0..n {
evec_schur[[i, blk_start]] = y1[i];
evec_schur[[i, blk_start + 1]] = y2[i];
}
let _ = (blk_idx, lam2);
}
}
let q_c: Array2<Complex<T>> = real_to_complex_matrix(q);
let eigenvectors: Array2<Complex<T>> = q_c.dot(&evec_schur);
let mut evec_norm = eigenvectors.clone();
for j in 0..n {
let col = evec_norm.column(j).to_owned();
let norm: T = col
.iter()
.map(|c| c.re * c.re + c.im * c.im)
.sum::<T>()
.sqrt();
if norm > T::epsilon() {
let scale = Complex::new(T::one() / norm, T::zero());
evec_norm.column_mut(j).mapv_inplace(|c| c * scale);
}
}
Ok(SchurEigenResult {
eigenvalues: Array1::from_vec(eigenvalues),
eigenvectors: evec_norm,
})
}
pub fn invariant_subspace<T, F>(
a: &ArrayView2<T>,
select: F,
max_iter: usize,
tol: T,
) -> LinalgResult<Array2<T>>
where
T: SchurFloat,
F: Fn(T, T) -> bool,
{
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"invariant_subspace: A must be square".into(),
));
}
let res = real_schur_decompose(a, max_iter, tol)?;
let (q_reordered, _t_reordered, n_selected) = schur_reorder(&res.q, &res.t, select, tol)?;
if n_selected == 0 {
return Ok(Array2::zeros((n, 0)));
}
Ok(q_reordered.slice(s![.., 0..n_selected]).to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Frobenius norm of the element-wise difference `a - b`.
    fn frobenius_err(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        let squared: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
        squared.sqrt()
    }

    /// Rebuilds `Q T Qᵀ` from a real Schur factorization.
    fn reconstruct(res: &RealSchurResult<f64>) -> Array2<f64> {
        let qt = res.q.t().to_owned();
        res.q.dot(&res.t).dot(&qt)
    }

    #[test]
    fn test_real_schur_2x2_diagonal() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let res = real_schur_decompose(&a.view(), 200, 1e-12).expect("ok");
        let err = frobenius_err(&a, &reconstruct(&res));
        assert!(err < 1e-8, "Frobenius err too large");
    }

    #[test]
    fn test_real_schur_2x2_non_symmetric() {
        let a = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let res = real_schur_decompose(&a.view(), 300, 1e-12).expect("ok");
        assert!(frobenius_err(&a, &reconstruct(&res)) < 1e-7);
    }

    #[test]
    fn test_real_schur_3x3() {
        let a = array![[1.0_f64, 2.0, 0.0], [0.0, 3.0, 1.0], [0.0, 0.0, 2.0]];
        let res = real_schur_decompose(&a.view(), 300, 1e-12).expect("ok");
        assert!(frobenius_err(&a, &reconstruct(&res)) < 1e-7);
    }

    #[test]
    fn test_schur_to_eigen_diagonal() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let res = real_schur_decompose(&a.view(), 200, 1e-12).expect("ok");
        let eig = schur_to_eigen(&res.q, &res.t, 1e-10).expect("ok");
        let mut real_parts: Vec<f64> = eig.eigenvalues.iter().map(|c| c.re).collect();
        real_parts.sort_by(f64::total_cmp);
        assert!((real_parts[0] - 2.0).abs() < 1e-8);
        assert!((real_parts[1] - 3.0).abs() < 1e-8);
    }

    #[test]
    fn test_complex_schur_rotation() {
        let a = array![[0.0_f64, -1.0], [1.0, 0.0]];
        let res = complex_schur_decompose(&a.view(), 300, 1e-12).expect("ok");
        let sub = res.t[[1, 0]];
        assert!(sub.re.abs() < 1e-8, "sub-diagonal should be ~0");
        assert!(sub.im.abs() < 1e-8, "sub-diagonal im should be ~0");
    }

    #[test]
    fn test_invariant_subspace_3x3() {
        let a = array![[3.0_f64, 1.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 1.0]];
        let u = invariant_subspace(&a.view(), |re, _im| re < 2.5, 300, 1e-10).expect("ok");
        assert_eq!(u.nrows(), 3);
        assert!((1..=3).contains(&u.ncols()));
        // span(U) is invariant iff A U equals its projection U Uᵀ A U.
        let au = a.dot(&u);
        let projected = u.dot(&u.t().dot(&au));
        let residual: Array2<f64> = &au - &projected;
        let frob: f64 = residual.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(frob < 1e-6, "invariant subspace error: {frob}");
    }

    #[test]
    fn test_hessenberg_decompose() {
        let a = array![[4.0_f64, 3.0, 2.0], [1.0, 5.0, 4.0], [2.0, 1.0, 3.0]];
        let (h, q) = hessenberg_decompose(&a.view()).expect("ok");
        let qt = q.t().to_owned();
        let reconstructed = q.dot(&h).dot(&qt);
        assert!(frobenius_err(&a, &reconstructed) < 1e-9);
        // Everything below the first subdiagonal must vanish.
        for i in 2..h.nrows() {
            for j in 0..i - 1 {
                assert!(
                    h[[i, j]].abs() < 1e-9,
                    "H[{i},{j}] = {} not zero",
                    h[[i, j]]
                );
            }
        }
    }
}