use super::super::jacobi::LinalgElement;
use super::super::{CpuClient, CpuRuntime};
use crate::algorithm::linalg::{
SchurDecomposition, linalg_demote, linalg_promote, validate_linalg_dtype,
validate_square_matrix,
};
use crate::dtype::{DType, Element};
use crate::error::Result;
use crate::runtime::RuntimeClient;
use crate::tensor::Tensor;
/// Computes the real Schur decomposition `A = Z T Zᵀ` of a square matrix on the CPU.
///
/// Validates the dtype and shape, promotes the input to a supported float
/// dtype, dispatches to the typed kernel, and demotes both factors back to
/// the caller's original dtype.
pub fn schur_decompose_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
) -> Result<SchurDecomposition<CpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (promoted, original_dtype) = linalg_promote(client, a)?;
    let n = validate_square_matrix(promoted.shape())?;
    // Promotion guarantees the dtype is one of the two float types.
    let decomposition = if promoted.dtype() == DType::F64 {
        schur_decompose_typed::<f64>(client, &promoted, n)?
    } else if promoted.dtype() == DType::F32 {
        schur_decompose_typed::<f32>(client, &promoted, n)?
    } else {
        unreachable!()
    };
    Ok(SchurDecomposition {
        z: linalg_demote(client, decomposition.z, original_dtype)?,
        t: linalg_demote(client, decomposition.t, original_dtype)?,
    })
}
/// Typed Schur kernel: reduce to Hessenberg form, then run shifted QR
/// sweeps until every subdiagonal entry is negligible, and finally flush
/// numerically-zero entries so `t` is exactly quasi-upper-triangular.
fn schur_decompose_typed<T: Element + LinalgElement>(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    n: usize,
) -> Result<SchurDecomposition<CpuRuntime>> {
    let device = client.device();

    // Trivial sizes: a 0x0 or 1x1 matrix is its own Schur form.
    match n {
        0 => {
            let empty: &[T] = &[];
            return Ok(SchurDecomposition {
                z: Tensor::<CpuRuntime>::from_slice(empty, &[0, 0], device),
                t: Tensor::<CpuRuntime>::from_slice(empty, &[0, 0], device),
            });
        }
        1 => {
            let scalar: Vec<T> = a.to_vec();
            return Ok(SchurDecomposition {
                z: Tensor::<CpuRuntime>::from_slice(&[T::one()], &[1, 1], device),
                t: Tensor::<CpuRuntime>::from_slice(&scalar, &[1, 1], device),
            });
        }
        _ => {}
    }

    // T starts as a copy of A; Z accumulates orthogonal transforms from I.
    let mut t_mat: Vec<T> = a.to_vec();
    let mut z_mat: Vec<T> = vec![T::zero(); n * n];
    for d in 0..n {
        z_mat[d * n + d] = T::one();
    }

    hessenberg_reduction::<T>(&mut t_mat, &mut z_mat, n);

    let eps = T::epsilon_val();
    // Relative tolerance for subdiagonal entry (i+1, i), scaled by the
    // magnitudes of the two adjacent diagonal entries.
    let threshold = |m: &[T], i: usize| -> f64 {
        let d0 = m[i * n + i].to_f64().abs();
        let d1 = m[(i + 1) * n + (i + 1)].to_f64().abs();
        eps * (d0 + d1).max(1.0)
    };

    // Shifted QR sweeps, capped so complex-pair blocks cannot loop forever.
    let max_sweeps = 30 * n;
    for _ in 0..max_sweeps {
        let converged = (0..n - 1)
            .all(|i| t_mat[(i + 1) * n + i].to_f64().abs() <= threshold(&t_mat, i));
        if converged {
            break;
        }
        qr_iteration_step::<T>(&mut t_mat, &mut z_mat, n);
    }

    // Flush negligible subdiagonal entries to exact zero.
    for i in 0..n - 1 {
        if t_mat[(i + 1) * n + i].to_f64().abs() <= threshold(&t_mat, i) {
            t_mat[(i + 1) * n + i] = T::zero();
        }
    }
    // Everything below the first subdiagonal is structurally zero.
    for row in 2..n {
        for col in 0..row - 1 {
            t_mat[row * n + col] = T::zero();
        }
    }

    Ok(SchurDecomposition {
        z: Tensor::<CpuRuntime>::from_slice(&z_mat, &[n, n], device),
        t: Tensor::<CpuRuntime>::from_slice(&t_mat, &[n, n], device),
    })
}
/// Reduces `h` (an n×n row-major matrix) to upper Hessenberg form using
/// Householder reflections, accumulating the reflectors into `q`
/// (`q ← q · P` for each reflector `P`).
///
/// `h` and `q` must each hold at least `n * n` elements; `q` is typically
/// the identity on entry so it ends up holding the orthogonal factor.
pub fn hessenberg_reduction<T: Element + LinalgElement>(h: &mut [T], q: &mut [T], n: usize) {
    // Matrices with n < 3 are already upper Hessenberg; returning early also
    // prevents the usize underflow that `n - 2` would cause for n == 0 or
    // n == 1 (panic in debug builds, a near-infinite loop in release).
    if n < 3 {
        return;
    }
    for k in 0..(n - 2) {
        // Householder vector built from column k below the subdiagonal.
        let mut v = vec![T::zero(); n - k - 1];
        let mut norm_sq = 0.0;
        for i in (k + 1)..n {
            let val = h[i * n + k].to_f64();
            v[i - k - 1] = T::from_f64(val);
            norm_sq += val * val;
        }
        // Column is already numerically zero below the subdiagonal: skip.
        if norm_sq < T::epsilon_val() {
            continue;
        }
        let norm = norm_sq.sqrt();
        let x0 = v[0].to_f64();
        // Pick the sign of alpha that avoids cancellation in x0 - alpha.
        let alpha = if x0 >= 0.0 { -norm } else { norm };
        v[0] = T::from_f64(x0 - alpha);
        let mut v_norm_sq = 0.0;
        for vi in &v {
            v_norm_sq += vi.to_f64() * vi.to_f64();
        }
        if v_norm_sq < T::epsilon_val() {
            continue;
        }
        // Normalize v so the reflector is P = I - 2 v vᵀ.
        let v_norm = v_norm_sq.sqrt();
        for vi in &mut v {
            *vi = T::from_f64(vi.to_f64() / v_norm);
        }
        // H ← P H: update rows k+1..n.
        for j in 0..n {
            let mut dot = 0.0;
            for i in 0..v.len() {
                dot += v[i].to_f64() * h[(k + 1 + i) * n + j].to_f64();
            }
            for i in 0..v.len() {
                let old_val = h[(k + 1 + i) * n + j].to_f64();
                h[(k + 1 + i) * n + j] = T::from_f64(old_val - 2.0 * v[i].to_f64() * dot);
            }
        }
        // H ← H P: update columns k+1..n (similarity transform).
        for i in 0..n {
            let mut dot = 0.0;
            for j in 0..v.len() {
                dot += h[i * n + (k + 1 + j)].to_f64() * v[j].to_f64();
            }
            for j in 0..v.len() {
                let old_val = h[i * n + (k + 1 + j)].to_f64();
                h[i * n + (k + 1 + j)] = T::from_f64(old_val - 2.0 * dot * v[j].to_f64());
            }
        }
        // Q ← Q P: accumulate the reflector into the orthogonal factor.
        for i in 0..n {
            let mut dot = 0.0;
            for j in 0..v.len() {
                dot += q[i * n + (k + 1 + j)].to_f64() * v[j].to_f64();
            }
            for j in 0..v.len() {
                let old_val = q[i * n + (k + 1 + j)].to_f64();
                q[i * n + (k + 1 + j)] = T::from_f64(old_val - 2.0 * dot * v[j].to_f64());
            }
        }
    }
}
/// Performs one shifted QR iteration step on the n×n row-major Hessenberg
/// matrix `h`, accumulating the Givens rotations into `q`.
///
/// The shift is the Wilkinson choice: the eigenvalue of the trailing 2×2
/// block closest to `h[n-1][n-1]`, falling back to `trace / 2` (the shared
/// real part) when that block's eigenvalues are complex. The step computes
/// H ← Gᵀ (H − σI) G + σI via successive row/column rotations.
pub fn qr_iteration_step<T: Element + LinalgElement>(h: &mut [T], q: &mut [T], n: usize) {
    // A 0x0 or 1x1 matrix needs no iteration; returning early also guards
    // the `n - 2` index arithmetic below against usize underflow.
    if n < 2 {
        return;
    }
    // Trailing 2x2 block [[a, b], [c, d]] determines the shift.
    let a = h[(n - 2) * n + (n - 2)].to_f64();
    let b = h[(n - 2) * n + (n - 1)].to_f64();
    let c = h[(n - 1) * n + (n - 2)].to_f64();
    let d = h[(n - 1) * n + (n - 1)].to_f64();
    let trace = a + d;
    let det = a * d - b * c;
    let disc = trace * trace - 4.0 * det;
    let shift = if disc >= 0.0 {
        // Real eigenvalues: pick the one closest to the last diagonal entry.
        let sqrt_disc = disc.sqrt();
        let lambda1 = (trace + sqrt_disc) / 2.0;
        let lambda2 = (trace - sqrt_disc) / 2.0;
        if (lambda1 - d).abs() < (lambda2 - d).abs() {
            lambda1
        } else {
            lambda2
        }
    } else {
        // Complex pair: use the shared real part as a real shift.
        trace / 2.0
    };
    // H ← H - σI.
    for i in 0..n {
        h[i * n + i] = T::from_f64(h[i * n + i].to_f64() - shift);
    }
    // Implicit QR: zero each subdiagonal entry with a Givens rotation and
    // immediately apply the transpose from the right (similarity transform).
    for i in 0..(n - 1) {
        let a_val = h[i * n + i].to_f64();
        let b_val = h[(i + 1) * n + i].to_f64();
        // Subdiagonal entry already negligible: no rotation needed.
        if b_val.abs() < T::epsilon_val() {
            continue;
        }
        let r = (a_val * a_val + b_val * b_val).sqrt();
        let c = a_val / r;
        let s = -b_val / r;
        // Rotate rows i and i+1 (left application).
        for j in 0..n {
            let t1 = h[i * n + j].to_f64();
            let t2 = h[(i + 1) * n + j].to_f64();
            h[i * n + j] = T::from_f64(c * t1 - s * t2);
            h[(i + 1) * n + j] = T::from_f64(s * t1 + c * t2);
        }
        // Rotate columns i and i+1 (right application).
        for k in 0..n {
            let t1 = h[k * n + i].to_f64();
            let t2 = h[k * n + (i + 1)].to_f64();
            h[k * n + i] = T::from_f64(c * t1 - s * t2);
            h[k * n + (i + 1)] = T::from_f64(s * t1 + c * t2);
        }
        // Accumulate the rotation into the orthogonal factor q.
        for k in 0..n {
            let t1 = q[k * n + i].to_f64();
            let t2 = q[k * n + (i + 1)].to_f64();
            q[k * n + i] = T::from_f64(c * t1 - s * t2);
            q[k * n + (i + 1)] = T::from_f64(s * t1 + c * t2);
        }
    }
    // H ← H + σI: restore the shift.
    for i in 0..n {
        h[i * n + i] = T::from_f64(h[i * n + i].to_f64() + shift);
    }
}