use super::super::super::jacobi::LinalgElement;
use super::super::super::{CpuClient, CpuRuntime};
use crate::algorithm::linalg::{
ComplexSchurDecomposition, SchurDecomposition, linalg_demote, linalg_promote,
validate_linalg_dtype,
};
use crate::dtype::{DType, Element};
use crate::error::{Error, Result};
use crate::runtime::RuntimeClient;
use crate::tensor::Tensor;
/// Converts a real Schur decomposition (`schur.t` quasi-upper-triangular, `schur.z` its Schur
/// vectors) into complex Schur form, returned as separate real and imaginary parts of `Z` and
/// `T`.
///
/// Both inputs are promoted with `linalg_promote` for the computation. All four outputs are
/// demoted back to the original dtype of `schur.t`.
///
/// # Errors
///
/// Returns an error if `schur.t`'s dtype is rejected by `validate_linalg_dtype`, if `T` is not
/// a square 2-D matrix, or if `Z` does not have the same shape and (promoted) dtype as `T`.
/// Errors from promotion and demotion are propagated.
pub fn rsf2csf_impl(
    client: &CpuClient,
    schur: &SchurDecomposition<CpuRuntime>,
) -> Result<ComplexSchurDecomposition<CpuRuntime>> {
    validate_linalg_dtype(schur.t.dtype())?;
    let (t, original_dtype) = linalg_promote(client, &schur.t)?;
    let (z, _) = linalg_promote(client, &schur.z)?;
    let schur = SchurDecomposition {
        t: t.into_owned(),
        z: z.into_owned(),
    };

    let shape = schur.t.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(Error::Internal(
            "rsf2csf: Schur form T must be square".to_string(),
        ));
    }
    let n = shape[0];

    // The typed kernel reads Z as an n x n buffer with T's element type. A different shape would
    // index out of bounds, and a different dtype would not match the type it reads.
    let z_shape = schur.z.shape();
    if z_shape.len() != 2 || z_shape[0] != n || z_shape[1] != n {
        return Err(Error::Internal(
            "rsf2csf: Schur vectors Z must have the same shape as T".to_string(),
        ));
    }
    if schur.z.dtype() != schur.t.dtype() {
        return Err(Error::Internal(
            "rsf2csf: Schur vectors Z must have the same dtype as T".to_string(),
        ));
    }

    let result = match schur.t.dtype() {
        DType::F32 => rsf2csf_typed::<f32>(client, &schur, n),
        DType::F64 => rsf2csf_typed::<f64>(client, &schur, n),
        // `linalg_promote` only hands back F32/F64 for dtypes accepted by validation.
        _ => unreachable!(),
    }?;

    Ok(ComplexSchurDecomposition {
        z_real: linalg_demote(client, result.z_real, original_dtype)?,
        z_imag: linalg_demote(client, result.z_imag, original_dtype)?,
        t_real: linalg_demote(client, result.t_real, original_dtype)?,
        t_imag: linalg_demote(client, result.t_imag, original_dtype)?,
    })
}
/// Converts a real Schur pair whose tensors hold elements of type `T` (an `n x n` problem)
/// into complex Schur form.
///
/// The diagonal of the quasi-triangular `T` is scanned top to bottom. A subdiagonal entry
/// `T[i+1][i]` counts as a 2x2 block when it exceeds `eps * (|T[i][i]| + |T[i+1][i+1]|)`,
/// where `eps` is `T::epsilon_val()`. The test is relative, so the decision does not depend on
/// the overall scale of the matrix. Such a block is handed to `convert_2x2_block`, which
/// updates `T` and `Z` in place.
///
/// Negligible subdiagonal entries are flushed to zero so the returned `T` is upper
/// triangular. NaN entries fail both comparisons and are passed through untouched.
fn rsf2csf_typed<T: Element + LinalgElement>(
    client: &CpuClient,
    schur: &SchurDecomposition<CpuRuntime>,
    n: usize,
) -> Result<ComplexSchurDecomposition<CpuRuntime>> {
    let device = client.device();
    if n == 0 {
        return Ok(ComplexSchurDecomposition {
            z_real: Tensor::<CpuRuntime>::from_slice(&[] as &[T], &[0, 0], device),
            z_imag: Tensor::<CpuRuntime>::from_slice(&[] as &[T], &[0, 0], device),
            t_real: Tensor::<CpuRuntime>::from_slice(&[] as &[T], &[0, 0], device),
            t_imag: Tensor::<CpuRuntime>::from_slice(&[] as &[T], &[0, 0], device),
        });
    }

    // Complex matrices are kept as separate real/imaginary planes. Both inputs start out real.
    let mut t_real: Vec<T> = schur.t.to_vec();
    let mut t_imag: Vec<T> = vec![T::zero(); n * n];
    let mut z_real: Vec<T> = schur.z.to_vec();
    let mut z_imag: Vec<T> = vec![T::zero(); n * n];

    let eps = T::epsilon_val();
    let mut i = 0;
    while i + 1 < n {
        let sub_idx = (i + 1) * n + i;
        let subdiag = t_real[sub_idx].to_f64().abs();
        // Earlier blocks only touch rows/columns < i, so these diagonal entries are still real.
        let scale =
            t_real[i * n + i].to_f64().abs() + t_real[(i + 1) * n + (i + 1)].to_f64().abs();
        let threshold = eps * scale;
        if subdiag > threshold {
            convert_2x2_block::<T>(&mut t_real, &mut t_imag, &mut z_real, &mut z_imag, n, i);
            i += 2;
            continue;
        }
        if subdiag <= threshold {
            t_real[sub_idx] = T::zero();
        }
        i += 1;
    }

    Ok(ComplexSchurDecomposition {
        z_real: Tensor::<CpuRuntime>::from_slice(&z_real, &[n, n], device),
        z_imag: Tensor::<CpuRuntime>::from_slice(&z_imag, &[n, n], device),
        t_real: Tensor::<CpuRuntime>::from_slice(&t_real, &[n, n], device),
        t_imag: Tensor::<CpuRuntime>::from_slice(&t_imag, &[n, n], device),
    })
}
/// Triangularizes the real 2x2 diagonal block of `T` at `(i, i)` with a unitary similarity and
/// accumulates the same transform into `Z`.
///
/// `t_*` and `z_*` are the real/imaginary planes of `n x n` row-major matrices. Only the real
/// plane of the 2x2 block is read. This assumes the block is still real, which holds when
/// blocks are processed top-down.
///
/// Let `B = [[a, b], [c, d]]` and let `lambda` be an eigenvalue of `B`. For `c != 0` the vector
/// `(lambda - d, c)` is an eigenvector of `B`. Normalized to `(cs, sn)` (`sn` is real because
/// `c` is), it forms the first column of the unitary matrix
///
/// ```text
/// Q = [[cs, -sn], [sn, conj(cs)]],    G = Q^H = [[conj(cs), sn], [-sn, cs]]
/// ```
///
/// so `G B G^H` is upper triangular with `lambda` at `(0, 0)`. The rotation is applied to the
/// whole matrix: rows `i, i+1` of `T` are left-multiplied by `G`, and columns `i, i+1` of `T`
/// and `Z` are right-multiplied by `G^H`. This keeps `Z T Z^H` invariant. It is the same
/// rotation SciPy's `rsf2csf` uses.
///
/// For a complex pair `mu ± i*omega`, the eigenvalue with positive imaginary part goes first.
/// A block with real eigenvalues is triangularized by a real rotation. A block whose
/// subdiagonal entry is exactly zero, or whose rotation would be non-finite, is left as is.
fn convert_2x2_block<T: Element + LinalgElement>(
    t_real: &mut [T],
    t_imag: &mut [T],
    z_real: &mut [T],
    z_imag: &mut [T],
    n: usize,
    i: usize,
) {
    debug_assert!(i + 1 < n, "2x2 block at row {i} does not fit in an {n}x{n} matrix");
    let a = t_real[i * n + i].to_f64();
    let b = t_real[i * n + (i + 1)].to_f64();
    let c = t_real[(i + 1) * n + i].to_f64();
    let d = t_real[(i + 1) * n + (i + 1)].to_f64();
    if c == 0.0 {
        // Already triangular; a rotation could only flip signs/phases.
        return;
    }

    // `shift = lambda - d`, where `lambda = mu ± sqrt(disc)` and `mu = (a + d) / 2`,
    // so `lambda - d = (a - d) / 2 ± sqrt(disc)`.
    let half_diff = 0.5 * (a - d);
    let disc = half_diff * half_diff + b * c;
    let shift = if disc < 0.0 {
        // Complex pair mu ± i*omega: choose mu + i*omega so T[i][i] gets the positive part.
        C64::new(half_diff, (-disc).sqrt())
    } else {
        // Real pair: pick the root whose sign matches `half_diff` to avoid cancellation.
        C64::new(half_diff + disc.sqrt().copysign(half_diff), 0.0)
    };

    let r = shift.modulus().hypot(c);
    if !r.is_finite() {
        return;
    }
    // For finite data r >= |c| > 0, so the divisions are safe.
    let cs = shift.scale(1.0 / r);
    let sn = c / r;

    // T <- G T on rows i, i+1. Entries left of column i are zero in both rows.
    for col in i..n {
        let (p, q) = (i * n + col, (i + 1) * n + col);
        let x = load_complex(t_real, t_imag, p);
        let y = load_complex(t_real, t_imag, q);
        store_complex(t_real, t_imag, p, cs.conj() * x + y.scale(sn));
        store_complex(t_real, t_imag, q, cs * y - x.scale(sn));
    }
    // T <- T G^H on columns i, i+1. Entries below row i+1 are zero in both columns.
    rotate_column_pair(t_real, t_imag, n, i, i + 2, cs, sn);
    // Z <- Z G^H on every row, so that Z T Z^H is unchanged.
    rotate_column_pair(z_real, z_imag, n, i, n, cs, sn);

    // The rotation annihilates T[i+1][i] up to rounding; store an exact zero.
    let sub = (i + 1) * n + i;
    t_real[sub] = T::zero();
    t_imag[sub] = T::zero();
}

/// Right-multiplies columns `i, i+1` of rows `0..rows` by `G^H = [[cs, -sn], [sn, conj(cs)]]`.
/// The matrix is row-major with `n` columns and split into `re`/`im` planes.
fn rotate_column_pair<T: Element + LinalgElement>(
    re: &mut [T],
    im: &mut [T],
    n: usize,
    i: usize,
    rows: usize,
    cs: C64,
    sn: f64,
) {
    for row in 0..rows {
        let p = row * n + i;
        let q = p + 1;
        let x = load_complex(re, im, p);
        let y = load_complex(re, im, q);
        store_complex(re, im, p, x * cs + y.scale(sn));
        store_complex(re, im, q, y * cs.conj() - x.scale(sn));
    }
}

/// Reads element `idx` of a split real/imaginary buffer as an `f64` complex number.
fn load_complex<T: Element + LinalgElement>(re: &[T], im: &[T], idx: usize) -> C64 {
    C64::new(re[idx].to_f64(), im[idx].to_f64())
}

/// Writes `value` to element `idx` of a split real/imaginary buffer.
fn store_complex<T: Element + LinalgElement>(re: &mut [T], im: &mut [T], idx: usize, value: C64) {
    re[idx] = T::from_f64(value.re);
    im[idx] = T::from_f64(value.im);
}

/// Minimal `f64` complex number used for the 2x2 rotation arithmetic.
#[derive(Clone, Copy, Debug)]
struct C64 {
    re: f64,
    im: f64,
}

impl C64 {
    fn new(re: f64, im: f64) -> Self {
        Self { re, im }
    }

    fn conj(self) -> Self {
        Self::new(self.re, -self.im)
    }

    fn scale(self, s: f64) -> Self {
        Self::new(self.re * s, self.im * s)
    }

    fn modulus(self) -> f64 {
        self.re.hypot(self.im)
    }
}

impl std::ops::Add for C64 {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        Self::new(self.re + other.re, self.im + other.im)
    }
}

impl std::ops::Sub for C64 {
    type Output = Self;

    fn sub(self, other: Self) -> Self {
        Self::new(self.re - other.re, self.im - other.im)
    }
}

impl std::ops::Mul for C64 {
    type Output = Self;

    fn mul(self, other: Self) -> Self {
        Self::new(
            self.re * other.re - self.im * other.im,
            self.re * other.im + self.im * other.re,
        )
    }
}