use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{MatmulOps, ScalarOps, ShapeOps, TensorOps, UnaryOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Solves the Sylvester equation `A·X + X·B = C` for `X` with a
/// Schur-based (Bartels–Stewart style) reduction.
///
/// Steps performed:
/// 1. Validate that `a` and `b` are square 2-D tensors, that `c` has shape
///    `[rows(a), rows(b)]`, and that `a`'s dtype is `F32` or `F64`.
/// 2. Schur-decompose `a` and `b` via `client.schur_decompose`
///    (presumably `A = U·T_A·Uᵀ`, `B = V·T_B·Vᵀ` with `.z` the orthogonal
///    factor and `.t` the quasi-upper-triangular factor — confirm against
///    the `LinearAlgebraAlgorithms` contract).
/// 3. Form `F = Uᵀ·C·V`.
/// 4. On the host, solve `T_A·Y + Y·T_B = F` one column (or one coupled
///    pair of columns, for a 2×2 diagonal block of `T_B`) at a time.
/// 5. Return `X = U·Y·Vᵀ`.
///
/// # Errors
/// * `Error::InvalidArgument` if `a` or `b` is not square, or `c` has the
///   wrong shape.
/// * `Error::UnsupportedDType` if `a`'s dtype is neither `F32` nor `F64`.
/// * Any error propagated from `schur_decompose`, `transpose` or `matmul`.
///
/// NOTE(review): the host buffers are read with `to_vec::<f64>()` and `Y`
/// is rebuilt from `f64` data even though `F32` passes the dtype check;
/// only `a`'s dtype is inspected. Confirm that `to_vec`/`from_slice`
/// convert as needed and that mixed dtypes are rejected downstream.
pub fn sylvester_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
    c: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + UnaryOps<R>
        + MatmulOps<R>
        + UtilityOps<R>
        + LinearAlgebraAlgorithms<R>
        + RuntimeClient<R>,
{
    let (a_shape, b_shape, c_shape) = (a.shape(), b.shape(), c.shape());

    // --- Argument validation -------------------------------------------
    let a_is_square = a_shape.len() == 2 && a_shape[0] == a_shape[1];
    if !a_is_square {
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: "Sylvester: A must be square".into(),
        });
    }
    let b_is_square = b_shape.len() == 2 && b_shape[0] == b_shape[1];
    if !b_is_square {
        return Err(Error::InvalidArgument {
            arg: "b",
            reason: "Sylvester: B must be square".into(),
        });
    }

    let na = a_shape[0];
    let mb = b_shape[0];
    if c_shape != [na, mb] {
        return Err(Error::InvalidArgument {
            arg: "c",
            reason: format!("Sylvester: C must be {}×{}, got {:?}", na, mb, c_shape),
        });
    }

    let dtype = a.dtype();
    let dtype_supported = dtype == DType::F32 || dtype == DType::F64;
    if !dtype_supported {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "solve_sylvester",
        });
    }

    // --- Reduce both coefficient matrices to Schur form ----------------
    let schur_a = client.schur_decompose(a)?;
    let schur_b = client.schur_decompose(b)?;
    let u = &schur_a.z;
    let ta = &schur_a.t;
    let v = &schur_b.z;
    let tb = &schur_b.t;

    // F = Uᵀ · C · V
    let ut = u.transpose(0, 1)?;
    let ut_c = client.matmul(&ut, c)?;
    let f = client.matmul(&ut_c, v)?;

    // --- Host-side solve of T_A·Y + Y·T_B = F --------------------------
    let ta_data = ta.to_vec::<f64>();
    let tb_data = tb.to_vec::<f64>();
    let f_data = f.to_vec::<f64>();
    let mut y_data = vec![0.0f64; na * mb];

    let mut col = 0;
    while col < mb {
        // A non-negligible sub-diagonal entry marks a 2×2 diagonal block of
        // T_B; its two columns must then be solved together.
        let paired = col + 1 < mb && tb_data[(col + 1) * mb + col].abs() > 1e-10;

        // Right-hand side for column `target`: F[:, target] minus the
        // contribution of every column of Y solved before this block.
        let rhs_for = |target: usize, solved: &[f64]| -> Vec<f64> {
            (0..na)
                .map(|i| {
                    (0..col).fold(f_data[i * mb + target], |acc, k| {
                        acc - tb_data[k * mb + target] * solved[i * mb + k]
                    })
                })
                .collect()
        };

        if paired {
            let rhs0 = rhs_for(col, &y_data);
            let rhs1 = rhs_for(col + 1, &y_data);
            solve_quasi_upper_coupled(
                &ta_data,
                na,
                tb_data[col * mb + col],
                tb_data[col * mb + col + 1],
                tb_data[(col + 1) * mb + col],
                tb_data[(col + 1) * mb + col + 1],
                &rhs0,
                &rhs1,
                &mut y_data,
                mb,
                col,
            );
            col += 2;
        } else {
            let rhs = rhs_for(col, &y_data);
            solve_quasi_upper_shifted(
                &ta_data,
                na,
                tb_data[col * mb + col],
                &rhs,
                &mut y_data,
                mb,
                col,
            );
            col += 1;
        }
    }

    // --- Back-transform: X = U · Y · Vᵀ --------------------------------
    let device = a.device();
    let y = Tensor::from_slice(&y_data, &[na, mb], device);
    let vt = v.transpose(0, 1)?;
    let u_y = client.matmul(u, &y)?;
    client.matmul(&u_y, &vt)
}
/// Back-substitutes `(T + shift·I)·x = rhs` for a quasi-upper-triangular
/// `T` and stores `x` into one column of a row-major matrix.
///
/// * `ta` — `T`, row-major, `n × n`. A sub-diagonal entry whose magnitude
///   exceeds `1e-10` marks a 2×2 diagonal block.
/// * `n` — order of `T` and length of `rhs`.
/// * `shift` — scalar added to every diagonal entry of `T`.
/// * `rhs` — right-hand side vector.
/// * `y` — row-major output matrix with `y_cols` columns; entry
///   `y[i * y_cols + col]` receives `x[i]`.
///
/// Rows are processed from the bottom up. When a pivot (or a 2×2 block
/// determinant) has magnitude at most `1e-15`, the affected unknowns are
/// left at `0.0` rather than reported as an error.
fn solve_quasi_upper_shifted(
    ta: &[f64],
    n: usize,
    shift: f64,
    rhs: &[f64],
    y: &mut [f64],
    y_cols: usize,
    col: usize,
) {
    let entry = |r: usize, c: usize| ta[r * n + c];
    // rhs[row] − Σ_{k ≥ from} T[row, k]·sol[k], accumulated left to right.
    let reduce = |row: usize, from: usize, sol: &[f64]| -> f64 {
        (from..n).fold(rhs[row], |acc, k| acc - entry(row, k) * sol[k])
    };

    let mut sol = vec![0.0f64; n];
    // `pending` counts the leading rows that are still unsolved.
    let mut pending = n;
    while pending > 0 {
        let lower = pending - 1;
        let in_block = lower > 0 && entry(lower, lower - 1).abs() > 1e-10;

        if in_block {
            // Rows `upper` and `lower` form a 2×2 block: solve it by Cramer's rule.
            let upper = lower - 1;
            let a00 = entry(upper, upper) + shift;
            let a01 = entry(upper, lower);
            let a10 = entry(lower, upper);
            let a11 = entry(lower, lower) + shift;
            let r0 = reduce(upper, pending, &sol);
            let r1 = reduce(lower, pending, &sol);
            let det = a00 * a11 - a01 * a10;
            if det.abs() > 1e-15 {
                sol[upper] = (a11 * r0 - a01 * r1) / det;
                sol[lower] = (-a10 * r0 + a00 * r1) / det;
            }
            pending -= 2;
        } else {
            let diag = entry(lower, lower) + shift;
            let r = reduce(lower, pending, &sol);
            if diag.abs() > 1e-15 {
                sol[lower] = r / diag;
            }
            pending -= 1;
        }
    }

    for (row, &value) in sol.iter().enumerate() {
        y[row * y_cols + col] = value;
    }
}
/// Solves the coupled pair of systems that arises from a 2×2 diagonal block
/// of the right-hand quasi-triangular factor:
///
/// ```text
/// (T + tb00·I)·x0 + tb10·x1 = rhs0
/// tb01·x0 + (T + tb11·I)·x1 = rhs1
/// ```
///
/// where `x0`/`x1` are two adjacent columns of `Y` in `T·Y + Y·S = F` and
/// `[[tb00, tb01], [tb10, tb11]]` is the 2×2 block of `S` on those columns.
/// Column `j` of `Y·S` is `Σ_k Y[:, k]·S[k, j]`, so the *lower-left* entry
/// `tb10` multiplies `x1` in the first equation and the *upper-right* entry
/// `tb01` multiplies `x0` in the second.
///
/// * `ta` — `T`, row-major, `n × n`, quasi-upper-triangular. A sub-diagonal
///   entry whose magnitude exceeds `1e-10` marks a 2×2 diagonal block.
/// * `rhs0`, `rhs1` — right-hand sides of length `n`.
/// * `y` — row-major output with `y_cols` columns; `x0` is written to
///   column `col` and `x1` to column `col + 1`.
///
/// Rows are processed bottom-up. A 2×2 block of `T` yields a 4×4 linear
/// system solved by Gaussian elimination with partial pivoting; a 1×1 block
/// yields a 2×2 system solved by Cramer's rule. Pivots or determinants with
/// magnitude at most `1e-15` leave the affected unknowns at `0.0`.
// The argument list mirrors the scalar pieces the caller already has on hand.
#[allow(clippy::too_many_arguments)]
// Index loops keep the row/column arithmetic of the elimination explicit.
#[allow(clippy::needless_range_loop)]
fn solve_quasi_upper_coupled(
    ta: &[f64],
    n: usize,
    tb00: f64,
    tb01: f64,
    tb10: f64,
    tb11: f64,
    rhs0: &[f64],
    rhs1: &[f64],
    y: &mut [f64],
    y_cols: usize,
    col: usize,
) {
    let mut x0 = vec![0.0f64; n];
    let mut x1 = vec![0.0f64; n];
    let mut i = n;
    while i > 0 {
        i -= 1;
        if i > 0 && ta[i * n + i - 1].abs() > 1e-10 {
            // Rows i0 and i of T form a 2×2 block: four coupled unknowns.
            let i0 = i - 1;
            let a00 = ta[i0 * n + i0];
            let a01 = ta[i0 * n + i];
            let a10 = ta[i * n + i0];
            let a11 = ta[i * n + i];

            // Remove the contribution of the rows solved so far.
            let mut r00 = rhs0[i0];
            let mut r10 = rhs0[i];
            let mut r01 = rhs1[i0];
            let mut r11 = rhs1[i];
            for k in i + 1..n {
                r00 -= ta[i0 * n + k] * x0[k];
                r10 -= ta[i * n + k] * x0[k];
                r01 -= ta[i0 * n + k] * x1[k];
                r11 -= ta[i * n + k] * x1[k];
            }

            // Unknown order: [x0[i0], x0[i], x1[i0], x1[i]].
            // The x0 equations couple to x1 through tb10, and the x1
            // equations couple to x0 through tb01.
            let mut m = [
                [a00 + tb00, a01, tb10, 0.0],
                [a10, a11 + tb00, 0.0, tb10],
                [tb01, 0.0, a00 + tb11, a01],
                [0.0, tb01, a10, a11 + tb11],
            ];
            let mut rhs_4 = [r00, r10, r01, r11];

            // Forward elimination with partial pivoting.
            for col_idx in 0..4 {
                let mut max_row = col_idx;
                let mut max_val = m[col_idx][col_idx].abs();
                for row in col_idx + 1..4 {
                    if m[row][col_idx].abs() > max_val {
                        max_val = m[row][col_idx].abs();
                        max_row = row;
                    }
                }
                if max_val < 1e-15 {
                    // Numerically zero column: nothing to eliminate.
                    continue;
                }
                if max_row != col_idx {
                    m.swap(col_idx, max_row);
                    rhs_4.swap(col_idx, max_row);
                }
                let pivot = m[col_idx][col_idx];
                for row in col_idx + 1..4 {
                    let factor = m[row][col_idx] / pivot;
                    for jj in col_idx..4 {
                        m[row][jj] -= factor * m[col_idx][jj];
                    }
                    rhs_4[row] -= factor * rhs_4[col_idx];
                }
            }

            // Back substitution; unknowns with a vanishing pivot stay 0.
            let mut sol = [0.0f64; 4];
            for ii in (0..4).rev() {
                let mut s = rhs_4[ii];
                for jj in ii + 1..4 {
                    s -= m[ii][jj] * sol[jj];
                }
                if m[ii][ii].abs() > 1e-15 {
                    sol[ii] = s / m[ii][ii];
                }
            }
            x0[i0] = sol[0];
            x0[i] = sol[1];
            x1[i0] = sol[2];
            x1[i] = sol[3];
            // Skip the upper row of the block; it has just been solved.
            i -= 1;
        } else {
            // 1×1 block of T: a 2×2 system in (x0[i], x1[i]).
            let aii = ta[i * n + i];
            let mut r0 = rhs0[i];
            let mut r1 = rhs1[i];
            for k in i + 1..n {
                r0 -= ta[i * n + k] * x0[k];
                r1 -= ta[i * n + k] * x1[k];
            }
            // [m00 m01; m10 m11]·[x0; x1] = [r0; r1], with tb10 coupling x1
            // into the first equation and tb01 coupling x0 into the second.
            let m00 = aii + tb00;
            let m01 = tb10;
            let m10 = tb01;
            let m11 = aii + tb11;
            let det = m00 * m11 - m01 * m10;
            if det.abs() > 1e-15 {
                x0[i] = (m11 * r0 - m01 * r1) / det;
                x1[i] = (-m10 * r0 + m00 * r1) / det;
            }
        }
    }
    for i in 0..n {
        y[i * y_cols + col] = x0[i];
        y[i * y_cols + col + 1] = x1[i];
    }
}