use super::super::super::jacobi::LinalgElement;
use super::super::super::{CpuClient, CpuRuntime};
use crate::algorithm::linalg::{
GeneralizedSchurDecomposition, linalg_demote, linalg_promote, validate_linalg_dtype,
validate_square_matrix,
};
use crate::dtype::{DType, Element};
use crate::error::{Error, Result};
use crate::runtime::RuntimeClient;
use crate::tensor::Tensor;
/// Computes the generalized Schur (QZ) decomposition of the pencil `(a, b)` on the CPU.
///
/// The dtype of `a` is checked with `validate_linalg_dtype` and must equal the dtype
/// of `b`. Both inputs go through `linalg_promote`, are checked to be square matrices
/// of the same order, and are factorized by the `f32` or `f64` specialization chosen
/// from the promoted dtype. Every output tensor is then passed through
/// `linalg_demote` together with the dtype reported by `linalg_promote` for `a`.
///
/// # Errors
///
/// * [`Error::DTypeMismatch`] when `a` and `b` have different dtypes.
/// * [`Error::ShapeMismatch`] when the two square matrices have different orders.
/// * Any error returned by `validate_linalg_dtype`, `linalg_promote`,
///   `validate_square_matrix`, the typed factorization, or `linalg_demote`.
///
/// # Panics
///
/// Hits `unreachable!()` if the promoted dtype is neither `F32` nor `F64`.
pub fn qz_decompose_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
) -> Result<GeneralizedSchurDecomposition<CpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    if a.dtype() != b.dtype() {
        return Err(Error::DTypeMismatch {
            lhs: a.dtype(),
            rhs: b.dtype(),
        });
    }

    let (a_work, original_dtype) = linalg_promote(client, a)?;
    let (b_work, _) = linalg_promote(client, b)?;

    let n = validate_square_matrix(a_work.shape())?;
    let n_b = validate_square_matrix(b_work.shape())?;
    if n != n_b {
        return Err(Error::ShapeMismatch {
            expected: vec![n, n],
            got: vec![n_b, n_b],
        });
    }

    // Dispatch on the working precision chosen by the promotion step.
    let GeneralizedSchurDecomposition {
        q,
        z,
        s,
        t,
        eigenvalues_real,
        eigenvalues_imag,
    } = match a_work.dtype() {
        DType::F32 => qz_decompose_typed::<f32>(client, &a_work, &b_work, n)?,
        DType::F64 => qz_decompose_typed::<f64>(client, &a_work, &b_work, n)?,
        _ => unreachable!(),
    };

    // Demote in field order so the first failing conversion is the one reported.
    let q = linalg_demote(client, q, original_dtype)?;
    let z = linalg_demote(client, z, original_dtype)?;
    let s = linalg_demote(client, s, original_dtype)?;
    let t = linalg_demote(client, t, original_dtype)?;
    let eigenvalues_real = linalg_demote(client, eigenvalues_real, original_dtype)?;
    let eigenvalues_imag = linalg_demote(client, eigenvalues_imag, original_dtype)?;

    Ok(GeneralizedSchurDecomposition {
        q,
        z,
        s,
        t,
        eigenvalues_real,
        eigenvalues_imag,
    })
}
/// Dtype-specialized QZ driver operating on host copies of `a` and `b`.
///
/// `n` is the common order of the two square inputs.
///
/// * `n == 0`: every output is an empty tensor.
/// * `n == 1`: `q` and `z` are `[1]`, `s` and `t` hold the two scalars, and the
///   eigenvalue is `s / t`, or a signed infinity when `|t|` does not exceed
///   `T::epsilon_val()`.
/// * Otherwise the pencil is reduced by `hessenberg_triangular_reduction`, iterated
///   by `qz_iteration`, cleaned of negligible entries below the diagonal, and its
///   eigenvalues are read off by `extract_generalized_eigenvalues`.
fn qz_decompose_typed<T: Element + LinalgElement>(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    b: &Tensor<CpuRuntime>,
    n: usize,
) -> Result<GeneralizedSchurDecomposition<CpuRuntime>> {
    let device = client.device();

    match n {
        0 => {
            let nothing: &[T] = &[];
            return Ok(GeneralizedSchurDecomposition {
                q: Tensor::<CpuRuntime>::from_slice(nothing, &[0, 0], device),
                z: Tensor::<CpuRuntime>::from_slice(nothing, &[0, 0], device),
                s: Tensor::<CpuRuntime>::from_slice(nothing, &[0, 0], device),
                t: Tensor::<CpuRuntime>::from_slice(nothing, &[0, 0], device),
                eigenvalues_real: Tensor::<CpuRuntime>::from_slice(nothing, &[0], device),
                eigenvalues_imag: Tensor::<CpuRuntime>::from_slice(nothing, &[0], device),
            });
        }
        1 => {
            let s_val: T = a.to_vec::<T>()[0];
            let t_val: T = b.to_vec::<T>()[0];
            let numer = s_val.to_f64();
            let denom = t_val.to_f64();
            let lambda = if denom.abs() > T::epsilon_val() {
                numer / denom
            } else if numer >= 0.0 {
                f64::INFINITY
            } else {
                f64::NEG_INFINITY
            };
            return Ok(GeneralizedSchurDecomposition {
                q: Tensor::<CpuRuntime>::from_slice(&[T::one()], &[1, 1], device),
                z: Tensor::<CpuRuntime>::from_slice(&[T::one()], &[1, 1], device),
                s: Tensor::<CpuRuntime>::from_slice(&[s_val], &[1, 1], device),
                t: Tensor::<CpuRuntime>::from_slice(&[t_val], &[1, 1], device),
                eigenvalues_real: Tensor::<CpuRuntime>::from_slice(
                    &[T::from_f64(lambda)],
                    &[1],
                    device,
                ),
                eigenvalues_imag: Tensor::<CpuRuntime>::from_slice(&[T::zero()], &[1], device),
            });
        }
        _ => {}
    }

    let mut s_data: Vec<T> = a.to_vec();
    let mut t_data: Vec<T> = b.to_vec();

    // Row-major identity: the diagonal sits at every (n + 1)-th flat index.
    let mut q_data: Vec<T> = vec![T::zero(); n * n];
    for diag in (0..n * n).step_by(n + 1) {
        q_data[diag] = T::one();
    }
    let mut z_data: Vec<T> = q_data.clone();

    hessenberg_triangular_reduction::<T>(&mut s_data, &mut t_data, &mut q_data, &mut z_data, n);
    qz_iteration::<T>(&mut s_data, &mut t_data, &mut q_data, &mut z_data, n);

    // Flush negligible first-sub-diagonal entries of `s` and negligible
    // strictly-lower entries of `t` to exact zero.
    let eps = T::epsilon_val();
    for row in 1..n {
        let prev = row - 1;
        let upper_diag = s_data[prev * n + prev].to_f64().abs();
        let lower_diag = s_data[row * n + row].to_f64().abs();
        let threshold = eps * (upper_diag + lower_diag).max(1.0);
        if s_data[row * n + prev].to_f64().abs() <= threshold {
            s_data[row * n + prev] = T::zero();
        }
        for col in 0..row {
            let bound = eps * t_data[col * n + col].to_f64().abs().max(1.0);
            if t_data[row * n + col].to_f64().abs() <= bound {
                t_data[row * n + col] = T::zero();
            }
        }
    }

    let (eigenvalues_real, eigenvalues_imag) =
        extract_generalized_eigenvalues::<T>(&s_data, &t_data, n);

    Ok(GeneralizedSchurDecomposition {
        q: Tensor::<CpuRuntime>::from_slice(&q_data, &[n, n], device),
        z: Tensor::<CpuRuntime>::from_slice(&z_data, &[n, n], device),
        s: Tensor::<CpuRuntime>::from_slice(&s_data, &[n, n], device),
        t: Tensor::<CpuRuntime>::from_slice(&t_data, &[n, n], device),
        eigenvalues_real: Tensor::<CpuRuntime>::from_slice(&eigenvalues_real, &[n], device),
        eigenvalues_imag: Tensor::<CpuRuntime>::from_slice(&eigenvalues_imag, &[n], device),
    })
}
/// Returns the pair `(c, s)` of a plane rotation built from `(a, b)`.
///
/// * `b == 0`: `(1, 0)`.
/// * `a == 0` with `b != 0`: `(0, 1)` for non-negative `b`, otherwise `(0, -1)`.
/// * Otherwise `(a / r, b / r)` with `r = hypot(a, b)`.
#[inline]
fn givens_params(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        (1.0, 0.0)
    } else if a == 0.0 {
        let sign = if b >= 0.0 { 1.0 } else { -1.0 };
        (0.0, sign)
    } else {
        let r = a.hypot(b);
        (a / r, b / r)
    }
}
/// Applies the rotation `[[c, s], [-s, c]]` to rows `i1` and `i2` of `mat`.
///
/// `mat` is indexed as a row-major matrix with `n` columns; only columns in
/// `col_lo..col_hi` are updated. Arithmetic is carried out in `f64` and converted
/// back with `T::from_f64`.
#[inline]
fn left_givens<T: Element + LinalgElement>(
    mat: &mut [T],
    n: usize,
    i1: usize,
    i2: usize,
    c: f64,
    s: f64,
    col_lo: usize,
    col_hi: usize,
) {
    for col in col_lo..col_hi {
        let first = i1 * n + col;
        let second = i2 * n + col;
        let x = mat[first].to_f64();
        let y = mat[second].to_f64();
        mat[first] = T::from_f64(c * x + s * y);
        mat[second] = T::from_f64(-s * x + c * y);
    }
}
/// Applies a plane rotation to columns `j1` and `j2` of `mat`.
///
/// For every row in `row_lo..row_hi` of the row-major, `n`-column matrix, the pair
/// `(x, y)` taken from columns `j1` and `j2` is replaced by
/// `(c * x + s * y, -s * x + c * y)`. Arithmetic is carried out in `f64` and
/// converted back with `T::from_f64`.
#[inline]
fn right_givens<T: Element + LinalgElement>(
    mat: &mut [T],
    n: usize,
    j1: usize,
    j2: usize,
    c: f64,
    s: f64,
    row_lo: usize,
    row_hi: usize,
) {
    for row in row_lo..row_hi {
        let first = row * n + j1;
        let second = row * n + j2;
        let x = mat[first].to_f64();
        let y = mat[second].to_f64();
        mat[first] = T::from_f64(c * x + s * y);
        mat[second] = T::from_f64(-s * x + c * y);
    }
}
/// Rotates rows `i1`/`i2` of the pencil and then cancels the entries that the
/// rotation left in row `i2` of `r` to the left of the diagonal.
///
/// The row rotation `(c, s)` is applied to all `n` columns of `h` and `r`, and the
/// same pair is applied to columns `i1`/`i2` of `q`. Afterwards, scanning columns
/// `i2 - 1` down to `0`, every entry `r[i2, col]` whose magnitude exceeds
/// `f64::MIN_POSITIVE` is cancelled against `r[i2, i2]` by a rotation of columns
/// `col` and `i2`. That column rotation covers rows `0..ihi` of `h` and `r`, and all
/// `n` rows of `z`.
#[allow(clippy::too_many_arguments)]
fn left_givens_with_r_restore<T: Element + LinalgElement>(
    h: &mut [T],
    r: &mut [T],
    q: &mut [T],
    z: &mut [T],
    n: usize,
    i1: usize,
    i2: usize,
    c: f64,
    s: f64,
    ihi: usize,
) {
    left_givens(h, n, i1, i2, c, s, 0, n);
    left_givens(r, n, i1, i2, c, s, 0, n);
    right_givens(q, n, i1, i2, c, s, 0, n);

    for col in (0..i2).rev() {
        let below_diag = r[i2 * n + col].to_f64();
        if below_diag.abs() > f64::MIN_POSITIVE {
            let diag = r[i2 * n + i2].to_f64();
            let norm = diag.hypot(below_diag);
            let cr = diag / norm;
            let sr = -below_diag / norm;
            right_givens(h, n, col, i2, cr, sr, 0, ihi);
            right_givens(r, n, col, i2, cr, sr, 0, ihi);
            right_givens(z, n, col, i2, cr, sr, 0, n);
        }
    }
}
/// Reduces the pencil `(h, r)` in place to Hessenberg-triangular form.
///
/// All four slices are row-major `n x n` matrices. On return `h` is upper
/// Hessenberg and `r` is upper triangular, up to entries smaller than
/// `T::epsilon_val()` that were skipped rather than rotated away. Every row
/// rotation is accumulated into the columns of `q` and every column rotation into
/// the columns of `z`, so the products `q * h * zᵀ` and `q * r * zᵀ` are preserved
/// (up to rounding) whenever `q` and `z` enter as orthogonal matrices.
///
/// Orders `n < 2` are a no-op.
fn hessenberg_triangular_reduction<T: Element + LinalgElement>(
    h: &mut [T],
    r: &mut [T],
    q: &mut [T],
    z: &mut [T],
    n: usize,
) {
    let eps = T::epsilon_val();

    // Phase 1: triangularize `r` column by column. Rotating row `i` against the
    // pivot row `k` only touches columns >= k, so earlier columns stay reduced.
    // `saturating_sub` keeps the range empty (instead of underflowing) for n == 0.
    for k in 0..n.saturating_sub(1) {
        for i in (k + 1)..n {
            let a_val = r[k * n + k].to_f64();
            let b_val = r[i * n + k].to_f64();
            if b_val.abs() < eps {
                continue;
            }
            let (c, s) = givens_params(a_val, b_val);
            left_givens(r, n, k, i, c, s, 0, n);
            left_givens(h, n, k, i, c, s, 0, n);
            right_givens(q, n, k, i, c, s, 0, n);
            // The rotation annihilates this entry; drop the rounding residue.
            r[i * n + k] = T::zero();
        }
    }

    // Phase 2: reduce `h` to Hessenberg form while keeping `r` triangular.
    // Rows are rotated in ADJACENT pairs (i - 1, i), bottom-up. Such a rotation
    // creates exactly one sub-diagonal entry in `r`, namely r[i, i - 1], which a
    // single rotation of columns (i - 1, i) removes again. Rotating non-adjacent
    // rows would also fill r[i, k + 2..i], which no single column rotation can
    // undo.
    for k in 0..n.saturating_sub(2) {
        for i in ((k + 2)..n).rev() {
            let upper = i - 1;
            let a_val = h[upper * n + k].to_f64();
            let b_val = h[i * n + k].to_f64();
            if b_val.abs() < eps {
                continue;
            }
            let (c, s) = givens_params(a_val, b_val);
            left_givens(h, n, upper, i, c, s, 0, n);
            left_givens(r, n, upper, i, c, s, 0, n);
            right_givens(q, n, upper, i, c, s, 0, n);
            h[i * n + k] = T::zero();

            // Restore triangularity of `r`. Columns i - 1 and i both lie to the
            // right of column k, so the zeros already created in `h` survive.
            let b1 = r[i * n + i].to_f64();
            let b2 = r[i * n + upper].to_f64();
            if b2.abs() > eps {
                let rr2 = b1.hypot(b2);
                let c2 = b1 / rr2;
                let s2 = -b2 / rr2;
                right_givens(r, n, upper, i, c2, s2, 0, n);
                right_givens(h, n, upper, i, c2, s2, 0, n);
                right_givens(z, n, upper, i, c2, s2, 0, n);
                r[i * n + upper] = T::zero();
            }
        }
    }
}
/// Drives the QZ sweeps on the pencil `(h, r)`, accumulating into `q` and `z`.
///
/// At most `60 * n` passes are made. Each pass:
///
/// 1. shrinks the active upper bound `ihi` while the last sub-diagonal entry of `h`
///    is negligible (setting it to exact zero);
/// 2. searches upward for the next negligible sub-diagonal entry to obtain the
///    lower bound `ilo` (0 when none is found);
/// 3. handles the active block `ilo..ihi` by size: a block of two rows goes to
///    `qz_step_2x2` and is then retired, larger blocks get one
///    `implicit_double_shift_qz_step` followed by flushing negligible
///    strictly-lower entries of `r` inside the block to zero.
///
/// A sub-diagonal entry `h[i, i-1]` counts as negligible when its magnitude is at
/// most `eps * max(|h[i-1, i-1]| + |h[i, i]|, 1)`.
fn qz_iteration<T: Element + LinalgElement>(
    h: &mut [T],
    r: &mut [T],
    q: &mut [T],
    z: &mut [T],
    n: usize,
) {
    if n < 2 {
        return;
    }

    let eps = T::epsilon_val();
    let max_iter = 60 * n;

    // Deflation criterion for the sub-diagonal entry in row `i`.
    let negligible = |m: &[T], i: usize| -> bool {
        let above = m[(i - 1) * n + (i - 1)].to_f64().abs();
        let below = m[i * n + i].to_f64().abs();
        m[i * n + (i - 1)].to_f64().abs() <= eps * (above + below).max(1.0)
    };

    let mut ihi = n;
    for _ in 0..max_iter {
        // Deflate converged trailing rows.
        while ihi > 1 && negligible(&*h, ihi - 1) {
            h[(ihi - 1) * n + (ihi - 2)] = T::zero();
            ihi -= 1;
        }
        if ihi <= 1 {
            break;
        }

        // Locate the start of the active block.
        let split = (1..ihi).rev().find(|&i| negligible(&*h, i));
        let ilo = match split {
            Some(i) => {
                h[i * n + (i - 1)] = T::zero();
                i
            }
            None => 0,
        };

        match ihi - ilo {
            0 | 1 => ihi = ilo,
            2 => {
                qz_step_2x2::<T>(h, r, q, z, n, ilo);
                ihi = ilo;
            }
            _ => {
                implicit_double_shift_qz_step::<T>(h, r, q, z, n, ilo, ihi);
                for ii in (ilo + 1)..ihi {
                    for jj in ilo..ii {
                        let bound = eps * r[jj * n + jj].to_f64().abs().max(1.0);
                        if r[ii * n + jj].to_f64().abs() <= bound {
                            r[ii * n + jj] = T::zero();
                        }
                    }
                }
            }
        }
    }
}
/// Processes the two-row block starting at row/column `ilo`.
///
/// First, if `r[ilo + 1, ilo]` is not negligible relative to `r[ilo, ilo]`, it is
/// cancelled by a rotation of columns `ilo` and `ilo + 1` applied to all rows of
/// `r`, `h` and `z`. Then the 2x2 block of `h` is inspected: when its discriminant
/// `trace² - 4·det` is non-negative and `h[ilo + 1, ilo]` is not negligible, a row
/// rotation built from the block's first column is applied through
/// `left_givens_with_r_restore`.
fn qz_step_2x2<T: Element + LinalgElement>(
    h: &mut [T],
    r: &mut [T],
    q: &mut [T],
    z: &mut [T],
    n: usize,
    ilo: usize,
) {
    let eps = T::epsilon_val();
    let top = ilo;
    let bot = ilo + 1;

    let r_below = r[bot * n + top].to_f64();
    if r_below.abs() > eps * r[top * n + top].to_f64().abs().max(1.0) {
        let r_diag = r[bot * n + bot].to_f64();
        let norm = r_diag.hypot(r_below);
        let c = r_diag / norm;
        let s = -r_below / norm;
        right_givens(r, n, top, bot, c, s, 0, n);
        right_givens(h, n, top, bot, c, s, 0, n);
        right_givens(z, n, top, bot, c, s, 0, n);
    }

    let a00 = h[top * n + top].to_f64();
    let a10 = h[bot * n + top].to_f64();
    let a11 = h[bot * n + bot].to_f64();
    let a01 = h[top * n + bot].to_f64();

    let trace = a00 + a11;
    let det = a00 * a11 - a01 * a10;
    let disc = trace * trace - 4.0 * det;
    let scale = (a00.abs() + a11.abs()).max(1.0);

    if disc >= 0.0 && a10.abs() > eps * scale {
        let (c, s) = givens_params(a00, a10);
        left_givens_with_r_restore(h, r, q, z, n, top, bot, c, s, n);
    }
}
/// Performs one shifted sweep on the active block `ilo..ihi` of the pencil `(h, r)`.
///
/// All slices are row-major `n x n` matrices; rotations are accumulated into `q`
/// and `z` through `left_givens_with_r_restore`.
///
/// The step returns without modifying anything when either of the two leading
/// diagonal entries of `r` in the block is smaller than `T::epsilon_val()` in
/// magnitude, or when all three components of the starting vector are.
///
/// Rows `ilo + 1` and `ilo + 2` are rotated unconditionally, so the block must span
/// at least three rows.
fn implicit_double_shift_qz_step<T: Element + LinalgElement>(
h: &mut [T],
r: &mut [T],
q: &mut [T],
z: &mut [T],
n: usize,
ilo: usize,
ihi: usize,
) {
let eps = T::epsilon_val();
// Index of the last row/column of the active block.
let m = ihi - 1;
// Trailing 2x2 blocks of `h` and `r` (rows/columns m-1 and m).
let h_pp = h[(m - 1) * n + (m - 1)].to_f64();
let h_qp = h[m * n + (m - 1)].to_f64();
let h_pq = h[(m - 1) * n + m].to_f64();
let h_qq = h[m * n + m].to_f64();
let r_pp = r[(m - 1) * n + (m - 1)].to_f64();
let r_pq = r[(m - 1) * n + m].to_f64();
let r_qq = r[m * n + m].to_f64();
// Shift data: trace and determinant of H2 * inv(R2), where H2 and R2 are the
// trailing 2x2 blocks and R2 is treated as upper triangular. When a diagonal
// entry of R2 is too small to invert, the trace/determinant of H2 alone is used.
let (s_tr, p_det) = if r_pp.abs() > eps && r_qq.abs() > eps {
let inv_rpp = 1.0 / r_pp;
let inv_rqq = 1.0 / r_qq;
let m00 = h_pp * inv_rpp;
let m01 = (h_pq - h_pp * r_pq * inv_rpp) * inv_rqq;
let m10 = h_qp * inv_rpp;
let m11 = (h_qq - h_qp * r_pq * inv_rpp) * inv_rqq;
(m00 + m11, m00 * m11 - m01 * m10)
} else {
(h_pp + h_qq, h_pp * h_qq - h_pq * h_qp)
};
// Leading entries of `h` in the block (rows/columns ilo, ilo+1, plus the
// sub-diagonal entry of row ilo+2 when that row is inside the block).
let a00 = h[ilo * n + ilo].to_f64();
let a10 = h[(ilo + 1) * n + ilo].to_f64();
let a01 = h[ilo * n + (ilo + 1)].to_f64();
let a11 = h[(ilo + 1) * n + (ilo + 1)].to_f64();
let a21 = if ilo + 2 < ihi {
h[(ilo + 2) * n + (ilo + 1)].to_f64()
} else {
0.0
};
// Leading 2x2 block of `r`, treated as upper triangular.
let rr00 = r[ilo * n + ilo].to_f64();
let rr01 = r[ilo * n + (ilo + 1)].to_f64();
let rr11 = r[(ilo + 1) * n + (ilo + 1)].to_f64();
// The leading block of `r` must be invertible to form the entries below.
if rr00.abs() < eps || rr11.abs() < eps {
return;
}
let inv_r00 = 1.0 / rr00;
let inv_r11 = 1.0 / rr11;
// Leading entries of M = H * inv(R), built from the leading blocks only.
let mm00 = a00 * inv_r00;
let mm10 = a10 * inv_r00;
let mm01 = (a01 - a00 * rr01 * inv_r00) * inv_r11;
let mm11 = (a11 - a10 * rr01 * inv_r00) * inv_r11;
let mm21 = a21 * inv_r11;
// First three components of the first column of M^2 - s_tr * M + p_det * I
// (presumably the starting vector of the implicit double-shift sweep the
// function is named after).
let v0 = mm00 * mm00 + mm01 * mm10 - s_tr * mm00 + p_det;
let v1 = mm10 * (mm00 + mm11 - s_tr);
let v2 = mm21 * mm10;
if v0.abs() < eps && v1.abs() < eps && v2.abs() < eps {
return;
}
// Two rotations built from (v0, v1, v2): first rows ilo+1/ilo+2 from (v1, v2),
// then rows ilo/ilo+1 from (v0, rotated v1).
let (c1, s1) = givens_params(v1, v2);
left_givens_with_r_restore(h, r, q, z, n, ilo + 1, ilo + 2, c1, s1, ihi);
let v1_new = c1 * v1 + s1 * v2;
let (c2, s2) = givens_params(v0, v1_new);
left_givens_with_r_restore(h, r, q, z, n, ilo, ilo + 1, c2, s2, ihi);
// Walk down the block and rotate away the non-negligible entries found two
// rows below the diagonal: h[k+2, k] first, then h[k+3, k+1] when row k+3
// exists inside the block.
for k in ilo..(ihi - 2) {
if h[(k + 2) * n + k].to_f64().abs()
> eps
* (h[(k + 1) * n + k].to_f64().abs() + h[(k + 2) * n + (k + 2)].to_f64().abs())
.max(1.0)
{
let (c, s) = givens_params(h[(k + 1) * n + k].to_f64(), h[(k + 2) * n + k].to_f64());
left_givens_with_r_restore(h, r, q, z, n, k + 1, k + 2, c, s, ihi);
}
if k + 3 < ihi
&& h[(k + 3) * n + (k + 1)].to_f64().abs()
> eps
* (h[(k + 2) * n + (k + 1)].to_f64().abs()
+ h[(k + 3) * n + (k + 3)].to_f64().abs())
.max(1.0)
{
let (c, s) = givens_params(
h[(k + 2) * n + (k + 1)].to_f64(),
h[(k + 3) * n + (k + 1)].to_f64(),
);
left_givens_with_r_restore(h, r, q, z, n, k + 2, k + 3, c, s, ihi);
}
}
}
/// Appends the two roots of `λ² - trace·λ + det = 0` to the output vectors.
///
/// With a negative discriminant the roots are pushed as a conjugate pair, the one
/// with positive imaginary part first. Otherwise both roots are real, the larger
/// one is pushed first, and their imaginary parts are `T::zero()`.
fn push_eigenvalues_from_2x2<T: Element + LinalgElement>(
    trace: f64,
    det: f64,
    real_parts: &mut Vec<T>,
    imag_parts: &mut Vec<T>,
) {
    let disc = trace * trace - 4.0 * det;

    let (first, second) = if disc < 0.0 {
        let re = trace / 2.0;
        let im = (-disc).sqrt() / 2.0;
        (
            (T::from_f64(re), T::from_f64(im)),
            (T::from_f64(re), T::from_f64(-im)),
        )
    } else {
        let root = disc.sqrt();
        (
            (T::from_f64((trace + root) / 2.0), T::zero()),
            (T::from_f64((trace - root) / 2.0), T::zero()),
        )
    };

    real_parts.push(first.0);
    imag_parts.push(first.1);
    real_parts.push(second.0);
    imag_parts.push(second.1);
}
/// Reads the generalized eigenvalues off the pencil `(s, t)`.
///
/// `s` and `t` are row-major `n x n` matrices. The diagonal is scanned top to
/// bottom; whenever the sub-diagonal entry `s[i + 1, i]` exceeds
/// `T::epsilon_val()` in magnitude, rows `i` and `i + 1` are treated as a 2x2
/// block, otherwise row `i` is a 1x1 block.
///
/// * 1x1 block: the eigenvalue is `s[i, i] / t[i, i]`, or a signed infinity
///   (sign taken from `s[i, i]`, non-negative mapping to `+inf`) when `|t[i, i]|`
///   does not exceed the epsilon.
/// * 2x2 block: the eigenvalues of `S2 · T2⁻¹` (with `T2` treated as upper
///   triangular) are pushed through `push_eigenvalues_from_2x2`; if a diagonal
///   entry of `T2` is not above the epsilon, the eigenvalues of `S2` alone are
///   used instead.
///
/// Returns `(real_parts, imag_parts)`.
fn extract_generalized_eigenvalues<T: Element + LinalgElement>(
    s: &[T],
    t: &[T],
    n: usize,
) -> (Vec<T>, Vec<T>) {
    let eps = T::epsilon_val();
    let mut real_parts = Vec::with_capacity(n);
    let mut imag_parts = Vec::with_capacity(n);

    // Row-major element access, widened to f64.
    let at = |m: &[T], row: usize, col: usize| -> f64 { m[row * n + col].to_f64() };

    let mut i = 0;
    while i < n {
        let next = i + 1;
        let coupled = next < n && at(s, next, i).abs() > eps;

        if !coupled {
            let s_ii = at(s, i, i);
            let t_ii = at(t, i, i);
            let lambda = if t_ii.abs() > eps {
                s_ii / t_ii
            } else if s_ii >= 0.0 {
                f64::INFINITY
            } else {
                f64::NEG_INFINITY
            };
            real_parts.push(T::from_f64(lambda));
            imag_parts.push(T::zero());
            i = next;
            continue;
        }

        let a = at(s, i, i);
        let b = at(s, i, next);
        let c = at(s, next, i);
        let d = at(s, next, next);
        let t_ii = at(t, i, i);
        let t_i1i1 = at(t, next, next);
        let t_ii1 = at(t, i, next);

        let (trace, det) = if t_ii.abs() > eps && t_i1i1.abs() > eps {
            let inv_tii = 1.0 / t_ii;
            let inv_ti1 = 1.0 / t_i1i1;
            let m00 = a * inv_tii;
            let m01 = (b - a * t_ii1 * inv_tii) * inv_ti1;
            let m10 = c * inv_tii;
            let m11 = (d - c * t_ii1 * inv_tii) * inv_ti1;
            (m00 + m11, m00 * m11 - m01 * m10)
        } else {
            (a + d, a * d - b * c)
        };
        push_eigenvalues_from_2x2(trace, det, &mut real_parts, &mut imag_parts);
        i += 2;
    }

    (real_parts, imag_parts)
}