use crate::error::{LinalgError, LinalgResult};
use crate::schur_enhanced::{real_schur_decompose, SchurFloat};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
use scirs2_core::numeric::{Complex, Float, NumAssign, One, Zero};
use std::fmt::{Debug, Display};
use std::iter::Sum;
/// Solves the Sylvester equation `A X + X B = C` for `X` (Bartels–Stewart).
///
/// Both `A` (m×m) and `B` (n×n) are reduced to real Schur form, `A = Ua Sa Uaᵀ` and
/// `B = Ub Sb Ubᵀ`. The equation becomes `Sa Y + Y Sb = Uaᵀ C Ub` with
/// `Y = Uaᵀ X Ub`. That quasi-triangular system is solved by `bartels_stewart_solve`
/// and mapped back via `X = Ua Y Ubᵀ`.
///
/// # Errors
/// * `ShapeError` if `A` or `B` is not square.
/// * `DimensionError` if `C` is not `m × n`.
/// * Any error from the Schur decompositions or the quasi-triangular solve, e.g.
///   `SingularMatrixError` when a shifted diagonal pivot becomes (near) zero.
pub fn solve_sylvester<T: SchurFloat>(
    a: &ArrayView2<T>,
    b: &ArrayView2<T>,
    c: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let (m, n) = (a.nrows(), b.nrows());
    if a.ncols() != m {
        return Err(LinalgError::ShapeError("A must be square".into()));
    }
    if b.ncols() != n {
        return Err(LinalgError::ShapeError("B must be square".into()));
    }
    if c.dim() != (m, n) {
        return Err(LinalgError::DimensionError(format!(
            "C must be {m}×{n}, got {}×{}",
            c.nrows(),
            c.ncols()
        )));
    }

    let schur_tol = T::epsilon() * T::from(100.0).unwrap_or(T::one());
    let dec_a = real_schur_decompose(a, 300, schur_tol)?;
    let dec_b = real_schur_decompose(b, 300, schur_tol)?;

    // Right-hand side expressed in the two Schur bases: Uaᵀ C Ub.
    let rotated_rhs: Array2<T> = dec_a.q.t().dot(c).dot(&dec_b.q);
    let y = bartels_stewart_solve(&dec_a.t, &dec_b.t, &rotated_rhs.view())?;

    // Undo the change of basis: X = Ua Y Ubᵀ.
    Ok(dec_a.q.dot(&y).dot(&dec_b.q.t()))
}
/// Solves `Sa·X + X·Sb = C` for `X`, where `sa` and `sb` are expected to be upper
/// quasi-triangular (as returned by `real_schur_decompose`).
///
/// Columns of `X` are produced left to right, following the diagonal blocks of `sb`:
/// * a 1×1 block at `j` gives the shifted quasi-triangular solve
///   `(Sa + sb[j,j]·I) x_j = c_j` (`solve_shifted_upper_quasi`);
/// * a 2×2 block couples two adjacent columns and is handed to
///   `solve_shifted_upper_quasi_2x2`.
///
/// After each block is solved, its contribution `X[:, block] · Sb[block, j]` is
/// subtracted from the working right-hand side of every later column `j`.
///
/// # Errors
/// Propagates errors from the per-block solvers (e.g. `SingularMatrixError`).
fn bartels_stewart_solve<T: SchurFloat>(
    sa: &Array2<T>,
    sb: &Array2<T>,
    c: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let m = sa.nrows();
    let n = sb.nrows();
    let tol = T::epsilon() * T::from(100.0).unwrap_or(T::one());
    let mut x = Array2::<T>::zeros((m, n));
    // Diagonal block structure of Sb: (start, size) with size 1 or 2.
    let b_blocks = identify_quasi_blocks(sb, tol);
    // Working right-hand side; later columns are updated in place as blocks are solved.
    let mut rhs_x = c.to_owned();
    // Blocks are contiguous and start at 0, so `col` always equals `blk_start`.
    let mut col = 0usize;
    for &(blk_start, blk_size) in &b_blocks {
        if blk_size == 1 {
            let mu = sb[[blk_start, blk_start]];
            let rhs_col = rhs_x.column(col).to_owned();
            let x_col = solve_shifted_upper_quasi(sa, mu, &rhs_col.view(), tol)?;
            x.column_mut(col).assign(&x_col);
            // Eliminate this column's coupling into every later column of the RHS.
            for j in (blk_start + 1)..n {
                let sb_val = sb[[blk_start, j]];
                // Couplings with |sb_val| <= machine epsilon are skipped. This is an
                // absolute threshold; the 2×2 branch below always subtracts.
                if sb_val.abs() > T::epsilon() {
                    x.column(col)
                        .iter()
                        .zip(rhs_x.column_mut(j).iter_mut())
                        .for_each(|(&xi, rj)| *rj -= xi * sb_val);
                }
            }
        } else {
            // 2×2 block of Sb: columns `col` and `col + 1` of X are coupled.
            let sb00 = sb[[blk_start, blk_start]];
            let sb01 = sb[[blk_start, blk_start + 1]];
            let sb10 = sb[[blk_start + 1, blk_start]];
            let sb11 = sb[[blk_start + 1, blk_start + 1]];
            let rhs0 = rhs_x.column(col).to_owned();
            let rhs1 = rhs_x.column(col + 1).to_owned();
            let (x0, x1) = solve_shifted_upper_quasi_2x2(
                sa,
                sb00,
                sb01,
                sb10,
                sb11,
                &rhs0.view(),
                &rhs1.view(),
                tol,
            )?;
            x.column_mut(col).assign(&x0);
            x.column_mut(col + 1).assign(&x1);
            // Subtract both solved columns' contributions from later RHS columns.
            for j in (blk_start + 2)..n {
                let sb0 = sb[[blk_start, j]];
                let sb1 = sb[[blk_start + 1, j]];
                for i in 0..m {
                    rhs_x[[i, j]] -= x0[i] * sb0 + x1[i] * sb1;
                }
            }
        }
        col += blk_size;
    }
    Ok(x)
}
/// Splits the diagonal of an upper quasi-triangular matrix into 1×1 and 2×2 blocks.
///
/// Returns `(start, size)` pairs in ascending order that cover `0..t.nrows()`.
/// A block has size 2 exactly when the sub-diagonal entry `t[start + 1, start]`
/// exceeds `tol` in magnitude; otherwise it has size 1.
fn identify_quasi_blocks<T: SchurFloat>(t: &Array2<T>, tol: T) -> Vec<(usize, usize)> {
    let dim = t.nrows();
    let mut out = Vec::new();
    let mut pos = 0;
    while pos < dim {
        let couples_with_next = pos + 1 < dim && t[[pos + 1, pos]].abs() > tol;
        let width = if couples_with_next { 2 } else { 1 };
        out.push((pos, width));
        pos += width;
    }
    out
}
/// Solves `(Sa + mu·I) x = rhs` for `x`, where `sa` is upper quasi-triangular.
///
/// The diagonal blocks of `sa` (1×1 or 2×2, as reported by `identify_quasi_blocks`
/// with the same `tol`) are visited bottom-up. Each block therefore only needs the
/// components of `x` below it, which have already been solved.
///
/// # Errors
/// Returns `LinalgError::SingularMatrixError` when a shifted 1×1 diagonal entry, or
/// the determinant of a shifted 2×2 block, has magnitude below `tol`.
fn solve_shifted_upper_quasi<T: SchurFloat>(
    sa: &Array2<T>,
    mu: T,
    rhs: &scirs2_core::ndarray::ArrayView1<T>,
    tol: T,
) -> LinalgResult<Array1<T>> {
    let n = sa.nrows();
    let blocks = identify_quasi_blocks(sa, tol);
    let mut x = Array1::<T>::zeros(n);
    // `rhs` is only read, so it is indexed directly instead of being copied.
    for &(blk_start, blk_size) in blocks.iter().rev() {
        if blk_size == 1 {
            let mut s = rhs[blk_start];
            for j in (blk_start + 1)..n {
                s -= sa[[blk_start, j]] * x[j];
            }
            let diag = sa[[blk_start, blk_start]] + mu;
            if diag.abs() < tol {
                return Err(LinalgError::SingularMatrixError(format!(
                    "solve_shifted_upper_quasi: near-zero diagonal at {blk_start} (shift {mu:?})"
                )));
            }
            x[blk_start] = s / diag;
        } else {
            let r0 = blk_start;
            let r1 = blk_start + 1;
            let mut s0 = rhs[r0];
            let mut s1 = rhs[r1];
            for j in (r1 + 1)..n {
                s0 -= sa[[r0, j]] * x[j];
                s1 -= sa[[r1, j]] * x[j];
            }
            // Shifted 2×2 block [[a, b], [c, d]], solved with the explicit inverse.
            let a = sa[[r0, r0]] + mu;
            let b = sa[[r0, r1]];
            let c = sa[[r1, r0]];
            let d = sa[[r1, r1]] + mu;
            let det = a * d - b * c;
            if det.abs() < tol {
                return Err(LinalgError::SingularMatrixError(format!(
                    "solve_shifted_upper_quasi: singular 2x2 block at {blk_start}"
                )));
            }
            x[r0] = (d * s0 - b * s1) / det;
            x[r1] = (-c * s0 + a * s1) / det;
        }
    }
    Ok(x)
}
/// Solves for two coupled columns `x0`, `x1` in a quasi-triangular Sylvester step.
///
/// The system is
///
/// ```text
/// Sa·x0 + sb00·x0 + sb10·x1 = rhs0
/// Sa·x1 + sb01·x0 + sb11·x1 = rhs1
/// ```
///
/// i.e. `Sa·[x0 x1] + [x0 x1]·[[sb00, sb01], [sb10, sb11]] = [rhs0 rhs1]`, with `sa`
/// upper quasi-triangular.
///
/// Blocks of `sa` are processed bottom-up:
/// * a 1×1 block of `sa` gives a 2×2 system, solved with the explicit inverse;
/// * a 2×2 block of `sa` gives a 4×4 system, solved with `crate::solve::solve`.
///
/// # Errors
/// * `SingularMatrixError` if a 2×2 determinant has magnitude below `tol`.
/// * Any error returned by `crate::solve::solve` for the 4×4 systems.
fn solve_shifted_upper_quasi_2x2<T: SchurFloat>(
    sa: &Array2<T>,
    sb00: T,
    sb01: T,
    sb10: T,
    sb11: T,
    rhs0: &scirs2_core::ndarray::ArrayView1<T>,
    rhs1: &scirs2_core::ndarray::ArrayView1<T>,
    tol: T,
) -> LinalgResult<(Array1<T>, Array1<T>)> {
    let n = sa.nrows();
    let blocks = identify_quasi_blocks(sa, tol);
    let nb = blocks.len();
    let mut x0 = Array1::<T>::zeros(n);
    let mut x1 = Array1::<T>::zeros(n);
    // Walk the blocks of `sa` from the last one to the first (back-substitution).
    let mut bidx = nb;
    loop {
        if bidx == 0 {
            break;
        }
        bidx -= 1;
        let (blk_start, blk_size) = blocks[bidx];
        if blk_size == 1 {
            // Remove contributions from already-solved rows below this one.
            let mut s0 = rhs0[blk_start];
            let mut s1 = rhs1[blk_start];
            for j in (blk_start + 1)..n {
                s0 -= sa[[blk_start, j]] * x0[j];
                s1 -= sa[[blk_start, j]] * x1[j];
            }
            // System [[d0, sb10], [sb01, d1]] · [x0; x1] = [s0; s1].
            let d0 = sa[[blk_start, blk_start]] + sb00;
            let d1 = sa[[blk_start, blk_start]] + sb11;
            let det = d0 * d1 - sb01 * sb10;
            if det.abs() < tol {
                return Err(LinalgError::SingularMatrixError(format!(
                    "solve_shifted_upper_quasi_2x2: near-zero det at {blk_start}"
                )));
            }
            x0[blk_start] = (d1 * s0 - sb10 * s1) / det;
            x1[blk_start] = (-sb01 * s0 + d0 * s1) / det;
        } else {
            let r0 = blk_start;
            let r1 = blk_start + 1;
            let mut s00 = rhs0[r0];
            let mut s01 = rhs0[r1];
            let mut s10 = rhs1[r0];
            let mut s11_rhs = rhs1[r1];
            for j in (r1 + 1)..n {
                s00 -= sa[[r0, j]] * x0[j];
                s01 -= sa[[r1, j]] * x0[j];
                s10 -= sa[[r0, j]] * x1[j];
                s11_rhs -= sa[[r1, j]] * x1[j];
            }
            let sa00 = sa[[r0, r0]];
            let sa01 = sa[[r0, r1]];
            let sa10 = sa[[r1, r0]];
            let sa11_val = sa[[r1, r1]];
            // 4×4 system in unknown order [x0[r0], x0[r1], x1[r0], x1[r1]].
            // Rows 0-1 are the x0 equations, rows 2-3 the x1 equations.
            let mut mat = Array2::<T>::zeros((4, 4));
            mat[[0, 0]] = sa00 + sb00;
            mat[[0, 1]] = sa01;
            mat[[0, 2]] = sb10;
            mat[[0, 3]] = T::zero();
            mat[[1, 0]] = sa10;
            mat[[1, 1]] = sa11_val + sb00;
            mat[[1, 2]] = T::zero();
            mat[[1, 3]] = sb10;
            mat[[2, 0]] = sb01;
            mat[[2, 1]] = T::zero();
            mat[[2, 2]] = sa00 + sb11;
            mat[[2, 3]] = sa01;
            mat[[3, 0]] = T::zero();
            mat[[3, 1]] = sb01;
            mat[[3, 2]] = sa10;
            mat[[3, 3]] = sa11_val + sb11;
            let rhs4 = Array1::from_vec(vec![s00, s01, s10, s11_rhs]);
            let sol4 = crate::solve::solve(&mat.view(), &rhs4.view(), None)?;
            x0[r0] = sol4[0];
            x0[r1] = sol4[1];
            x1[r0] = sol4[2];
            x1[r1] = sol4[3];
        }
    }
    Ok((x0, x1))
}
/// Solves the continuous Lyapunov equation `A X + X Aᵀ + Q = 0` for `X`.
///
/// This is the Sylvester equation `A X + X B = C` with `B = Aᵀ` and `C = −Q`.
/// The returned matrix is averaged with its transpose, so it is exactly symmetric.
///
/// # Errors
/// * `ShapeError` if `A` is not square.
/// * `DimensionError` if `Q` is not `n × n`.
/// * Any error propagated from `solve_sylvester`.
pub fn solve_continuous_lyapunov<T: SchurFloat>(
    a: &ArrayView2<T>,
    q: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let dim = a.nrows();
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError("A must be square".into()));
    }
    if q.dim() != (dim, dim) {
        return Err(LinalgError::DimensionError("Q must be n×n".into()));
    }

    let a_transposed = a.t().to_owned();
    let minus_q = q.mapv(|v| -v);
    let raw = solve_sylvester(a, &a_transposed.view(), &minus_q.view())?;

    let half = T::from(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
    Ok((&raw + &raw.t()) * half)
}
/// Solves the discrete Lyapunov (Stein) equation `A X Aᵀ − X + Q = 0` for `X`.
///
/// `A` is reduced to real Schur form `A = U T Uᵀ`. The transformed equation
/// `T Y Tᵀ − Y + Uᵀ Q U = 0` is solved by `solve_discrete_lyapunov_triangular`,
/// and the result is mapped back with `X = U Y Uᵀ`. The returned matrix is averaged
/// with its transpose.
///
/// # Errors
/// * `ShapeError` if `A` is not square.
/// * `DimensionError` if `Q` is not `n × n`.
/// * Any error from the Schur decomposition or the triangular solve.
pub fn solve_discrete_lyapunov<T: SchurFloat>(
    a: &ArrayView2<T>,
    q: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let dim = a.nrows();
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError("A must be square".into()));
    }
    if q.dim() != (dim, dim) {
        return Err(LinalgError::DimensionError("Q must be n×n".into()));
    }

    let tol = T::epsilon() * T::from(100.0).unwrap_or(T::one());
    let schur = real_schur_decompose(a, 300, tol)?;
    let (basis, quasi_tri) = (&schur.q, &schur.t);

    let q_rotated = basis.t().dot(&q.to_owned()).dot(basis);
    let y = solve_discrete_lyapunov_triangular(quasi_tri, &q_rotated, tol)?;
    let x = basis.dot(&y).dot(&basis.t());

    let half = T::from(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
    Ok((&x + &x.t()) * half)
}
/// Solves `T Y Tᵀ − Y + C = 0` for `Y`, where `t` is upper quasi-triangular.
///
/// Unknowns are solved block by block:
/// * column blocks are swept right to left;
/// * within each column block, row blocks are swept bottom to top.
///
/// For a target entry `(i, j)`, every term `T[i,k]·Y[k,l]·T[j,l]` whose `(k, l)` lies in
/// a later column block, or in the same column block but a later row block, is already
/// known and moves to the right-hand side.
///
/// * A 1×1 × 1×1 block pair is solved directly. If `T[i,i]·T[j,j] − 1` has magnitude
///   below `tol`, the entry is set to zero (best effort) rather than reported as an error.
/// * Any pair involving a 2×2 block becomes a small dense system solved by
///   `solve_small_dense`.
///
/// # Errors
/// Propagates errors from `solve_small_dense`.
fn solve_discrete_lyapunov_triangular<T: SchurFloat>(
    t: &Array2<T>,
    c: &Array2<T>,
    tol: T,
) -> LinalgResult<Array2<T>> {
    let n = t.nrows();
    let mut y = Array2::<T>::zeros((n, n));
    let blocks = identify_quasi_blocks(t, tol);
    let nb = blocks.len();
    // Block membership of every row/column index, computed once. A linear
    // `find_block_index` scan for each (k, l) inside the O(n²) accumulation loops
    // made every unknown cost O(n²·nb).
    let block_of: Vec<usize> = (0..n).map(|idx| find_block_index(&blocks, idx)).collect();
    // True when Y[k, l] has already been solved while the current target lies in
    // row block `bi_idx` / column block `bj_idx`.
    let already_computed = |k: usize, l: usize, bi_idx: usize, bj_idx: usize| -> bool {
        let (k_block, l_block) = (block_of[k], block_of[l]);
        l_block > bj_idx || (l_block == bj_idx && k_block > bi_idx)
    };
    for bj_idx in (0..nb).rev() {
        let (j_start, j_size) = blocks[bj_idx];
        for bi_idx in (0..nb).rev() {
            let (i_start, i_size) = blocks[bi_idx];
            if i_size == 1 && j_size == 1 {
                let i = i_start;
                let j = j_start;
                let mut rhs = c[[i, j]];
                for k in 0..n {
                    for l in 0..n {
                        if k == i && l == j {
                            continue;
                        }
                        if already_computed(k, l, bi_idx, bj_idx) {
                            rhs += t[[i, k]] * y[[k, l]] * t[[j, l]];
                        }
                    }
                }
                // Scalar equation: (T[i,i]·T[j,j] − 1)·Y[i,j] + rhs = 0.
                let denom = t[[i, i]] * t[[j, j]] - T::one();
                if denom.abs() < tol {
                    y[[i, j]] = T::zero();
                } else {
                    y[[i, j]] = -rhs / denom;
                }
            } else {
                let rows: Vec<usize> = (i_start..i_start + i_size).collect();
                let cols: Vec<usize> = (j_start..j_start + j_size).collect();
                let sys_size = i_size * j_size;
                // Known right-hand side for each unknown in the current block pair.
                let mut rhs_vec = vec![T::zero(); sys_size];
                for (ri, &i) in rows.iter().enumerate() {
                    for (ci, &j) in cols.iter().enumerate() {
                        let idx = ri * j_size + ci;
                        rhs_vec[idx] = c[[i, j]];
                        for k in 0..n {
                            for l in 0..n {
                                let is_current = rows.contains(&k) && cols.contains(&l);
                                if is_current {
                                    continue;
                                }
                                if already_computed(k, l, bi_idx, bj_idx) {
                                    rhs_vec[idx] += t[[i, k]] * y[[k, l]] * t[[j, l]];
                                }
                            }
                        }
                    }
                }
                // Row-major coefficient matrix: Σ T[i,k]·T[j,l]·Y[k,l] − Y[i,j]
                // over the current block's unknowns.
                let mut mat = vec![T::zero(); sys_size * sys_size];
                for (ri, &i) in rows.iter().enumerate() {
                    for (ci, &j) in cols.iter().enumerate() {
                        let row_idx = ri * j_size + ci;
                        for (rk, &k) in rows.iter().enumerate() {
                            for (cl, &l) in cols.iter().enumerate() {
                                let col_idx = rk * j_size + cl;
                                mat[row_idx * sys_size + col_idx] += t[[i, k]] * t[[j, l]];
                            }
                        }
                        mat[row_idx * sys_size + row_idx] -= T::one();
                    }
                }
                let neg_rhs: Vec<T> = rhs_vec.iter().map(|&v| -v).collect();
                let y_vec = solve_small_dense(&mat, &neg_rhs, sys_size)?;
                for (ri, &i) in rows.iter().enumerate() {
                    for (ci, &j) in cols.iter().enumerate() {
                        y[[i, j]] = y_vec[ri * j_size + ci];
                    }
                }
            }
        }
    }
    Ok(y)
}
/// Returns the index of the first `(start, size)` block whose half-open range
/// `start..start + size` contains `idx`, or `blocks.len()` if no block does.
fn find_block_index(blocks: &[(usize, usize)], idx: usize) -> usize {
    blocks
        .iter()
        .position(|&(start, size)| (start..start + size).contains(&idx))
        .unwrap_or(blocks.len())
}
/// Solves the dense `n × n` system `mat · x = rhs` by Gaussian elimination with
/// partial (row) pivoting.
///
/// `mat` is read row-major (`mat[row * n + col]`) and only its first `n * n` entries
/// are used; it panics if `mat` or `rhs` is shorter than required.
///
/// This is a best-effort solver that never returns `Err`:
/// * if a pivot column has no entry of magnitude at least `T::epsilon()`, an all-zero
///   vector is returned;
/// * a back-substitution diagonal below `T::epsilon()` yields a zero component.
fn solve_small_dense<T: SchurFloat>(mat: &[T], rhs: &[T], n: usize) -> LinalgResult<Vec<T>> {
    let mut m = mat[..n * n].to_vec();
    let mut v = rhs.to_vec();

    // Forward elimination to upper-triangular form.
    for p in 0..n {
        // First row (from `p` down) with the strictly largest magnitude in column `p`.
        let (pivot_row, pivot_mag) =
            ((p + 1)..n).fold((p, m[p * n + p].abs()), |(best_row, best_mag), row| {
                let mag = m[row * n + p].abs();
                if mag > best_mag {
                    (row, mag)
                } else {
                    (best_row, best_mag)
                }
            });
        if pivot_mag < T::epsilon() {
            return Ok(vec![T::zero(); n]);
        }
        if pivot_row != p {
            for col in 0..n {
                m.swap(p * n + col, pivot_row * n + col);
            }
            v.swap(p, pivot_row);
        }
        let pivot = m[p * n + p];
        for row in (p + 1)..n {
            let factor = m[row * n + p] / pivot;
            for col in p..n {
                let upper = m[p * n + col];
                m[row * n + col] -= factor * upper;
            }
            let vp = v[p];
            v[row] -= factor * vp;
        }
    }

    // Back substitution.
    let mut x = vec![T::zero(); n];
    for i in (0..n).rev() {
        let s = ((i + 1)..n).fold(v[i], |acc, j| acc - m[i * n + j] * x[j]);
        let d = m[i * n + i];
        x[i] = if d.abs() < T::epsilon() { T::zero() } else { s / d };
    }
    Ok(x)
}
/// Selects which algebraic Riccati equation `solve_algebraic_riccati` solves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RiccatiType {
    /// Continuous-time equation (CARE), dispatched to `solve_care`.
    Continuous,
    /// Discrete-time equation (DARE), dispatched to `solve_dare`.
    Discrete,
}

/// Solves the continuous- or discrete-time algebraic Riccati equation defined by
/// `(A, B, Q, R)`, depending on `riccati_type`.
///
/// # Errors
/// Propagates any error from the selected solver, including input-shape validation.
pub fn solve_algebraic_riccati<T: SchurFloat>(
    a: &ArrayView2<T>,
    b: &ArrayView2<T>,
    q: &ArrayView2<T>,
    r: &ArrayView2<T>,
    riccati_type: RiccatiType,
) -> LinalgResult<Array2<T>> {
    let solution = match riccati_type {
        RiccatiType::Discrete => solve_dare(a, b, q, r)?,
        RiccatiType::Continuous => solve_care(a, b, q, r)?,
    };
    Ok(solution)
}
/// Solves the continuous-time algebraic Riccati equation
/// `Aᵀ X + X A − X B R⁻¹ Bᵀ X + Q = 0` by Newton–Kleinman iteration.
///
/// Each step forms `A_k = A − B R⁻¹ Bᵀ X_k` and solves the Lyapunov equation
/// `A_kᵀ X_{k+1} + X_{k+1} A_k + Q + X_k B R⁻¹ Bᵀ X_k = 0`.
/// The starting point comes from `find_stabilizing_initial_x`.
///
/// Iteration stops when any of the following holds:
/// * 200 steps have been taken;
/// * `‖X_{k+1} − X_k‖_F ≤ 1e-10·(‖X_{k+1}‖_F + 1)`;
/// * a Lyapunov solve fails. This is a best-effort fallback that keeps and returns the
///   most recent iterate.
///
/// The result is averaged with its transpose.
///
/// # Errors
/// Propagates input-validation errors and failure to invert `R`.
fn solve_care<T: SchurFloat>(
    a: &ArrayView2<T>,
    b: &ArrayView2<T>,
    q: &ArrayView2<T>,
    r: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let n = a.nrows();
    validate_riccati_inputs(a, b, q, r)?;
    let r_inv = crate::basic::inv(r, None)?;
    // G = B R⁻¹ Bᵀ
    let g_mat: Array2<T> = b.dot(&r_inv).dot(&b.t());
    let tolerance = T::from(1e-10).unwrap_or(T::epsilon());
    let frobenius = |m: &Array2<T>| -> T { m.iter().map(|&v| v * v).sum::<T>().sqrt() };

    let mut x = find_stabilizing_initial_x(a, b, &g_mat, q, n)?;
    for _ in 0..200usize {
        let closed_loop = a.to_owned() - g_mat.dot(&x);
        let forcing = q.to_owned() + x.dot(&g_mat).dot(&x);
        let closed_loop_t = closed_loop.t().to_owned();
        let next = match solve_continuous_lyapunov(&closed_loop_t.view(), &forcing.view()) {
            Ok(sol) => sol,
            // Best effort: keep the last successfully computed iterate.
            Err(_) => break,
        };
        let step = frobenius(&(&next - &x));
        let scale = frobenius(&next);
        x = next;
        if step <= tolerance * (scale + T::one()) {
            break;
        }
    }

    let half = T::from(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
    Ok((&x + &x.t()) * half)
}
/// Heuristically picks the starting iterate `X₀` for the Newton–Kleinman loop in
/// `solve_care`.
///
/// Candidates are `X₀ = α·(I + 𝟙𝟙ᵀ)` for `α = 0.5, 1.0, …, 99.5`. The first candidate
/// is returned for which every row of `A − B R⁻¹ Bᵀ X₀` satisfies
/// `diag + Σ|off-diag| < 0`. In that case every Gershgorin disc lies in the open left
/// half-plane.
///
/// If no candidate qualifies, `α = 100` is returned without any such guarantee.
/// `_b` and `_q` are accepted but unused.
fn find_stabilizing_initial_x<T: SchurFloat>(
    a: &ArrayView2<T>,
    _b: &ArrayView2<T>,
    b_r_inv_bt: &Array2<T>,
    _q: &ArrayView2<T>,
    n: usize,
) -> LinalgResult<Array2<T>> {
    let a_owned = a.to_owned();
    let pattern = &Array2::<T>::eye(n) + &Array2::<T>::ones((n, n));
    // A row "fails" when diag + Σ|off-diagonal| >= 0.
    let row_fails = |m: &Array2<T>, i: usize| -> bool {
        let radius = (0..n)
            .filter(|&j| j != i)
            .fold(T::zero(), |acc, j| acc + m[[i, j]].abs());
        m[[i, i]] + radius >= T::zero()
    };
    for k in 1..200 {
        let alpha = T::from(k as f64 * 0.5).unwrap_or(T::one());
        let candidate = &pattern * alpha;
        let closed_loop = &a_owned - &b_r_inv_bt.dot(&candidate);
        if !(0..n).any(|i| row_fails(&closed_loop, i)) {
            return Ok(candidate);
        }
    }
    let fallback_alpha = T::from(100.0).unwrap_or(T::one());
    Ok(&pattern * fallback_alpha)
}
/// Solves the discrete-time algebraic Riccati equation
///
/// `X = Aᵀ X A − Aᵀ X B (R + Bᵀ X B)⁻¹ Bᵀ X A + Q`
///
/// with Hewer's policy iteration. For the current gain `K`:
/// 1. Solve the closed-loop Lyapunov equation `A_clᵀ X A_cl − X + Q + Kᵀ R K = 0`,
///    where `A_cl = A − B K`.
/// 2. Update `K = (R + Bᵀ X B)⁻¹ Bᵀ X A`.
///
/// Iteration starts from `K = 0` and stops after 100 steps, or when
/// `‖X_{k+1} − X_k‖_F ≤ 1e6·ε·(‖X_{k+1}‖_F + 1)`. The result is averaged with its
/// transpose.
///
/// NOTE(review): Hewer's method is normally started from a stabilizing gain.
/// `K = 0` is only stabilizing when `A` itself is Schur-stable — confirm callers
/// respect this.
///
/// # Errors
/// Propagates input-validation errors, failures of the inner Lyapunov solve, and
/// failure to invert `R + Bᵀ X B`.
fn solve_dare<T: SchurFloat>(
    a: &ArrayView2<T>,
    b: &ArrayView2<T>,
    q: &ArrayView2<T>,
    r: &ArrayView2<T>,
) -> LinalgResult<Array2<T>> {
    let n = a.nrows();
    let m = b.ncols();
    validate_riccati_inputs(a, b, q, r)?;
    let max_iter = 100usize;
    let conv_tol = T::epsilon() * T::from(1e6).unwrap_or(T::one());
    let mut k_gain = Array2::<T>::zeros((m, n));
    let a_owned = a.to_owned();
    let b_owned = b.to_owned();
    let q_owned = q.to_owned();
    let r_owned = r.to_owned();
    let mut x = q.to_owned();
    for _iter in 0..max_iter {
        let a_cl = &a_owned - b_owned.dot(&k_gain);
        let rhs = &q_owned + k_gain.t().dot(&r_owned).dot(&k_gain);
        // `solve_discrete_lyapunov(M, W)` solves `M X Mᵀ − X + W = 0`. The value step
        // needs `A_clᵀ X A_cl`, so pass the transpose of the closed-loop matrix (the
        // same convention `solve_care` uses for its continuous Lyapunov step).
        let a_cl_t = a_cl.t().to_owned();
        let x_new = solve_discrete_lyapunov(&a_cl_t.view(), &rhs.view())?;
        let bt_x = b_owned.t().dot(&x_new);
        let bt_x_b = bt_x.dot(&b_owned);
        let r_plus = &r_owned + bt_x_b;
        let r_plus_inv = crate::basic::inv(&r_plus.view(), None)?;
        // Policy improvement: K = (R + Bᵀ X B)⁻¹ Bᵀ X A.
        let k_new = r_plus_inv.dot(&bt_x).dot(&a_owned);
        let diff: T = (&x_new - &x).iter().map(|&v| v * v).sum::<T>().sqrt();
        let x_norm: T = x_new.iter().map(|&v| v * v).sum::<T>().sqrt();
        x = x_new;
        k_gain = k_new;
        if diff <= conv_tol * (x_norm + T::one()) {
            break;
        }
    }
    let half = T::from(0.5).unwrap_or_else(|| T::one() / (T::one() + T::one()));
    Ok((&x + &x.t()) * half)
}
/// Checks the shapes of a Riccati problem `(A, B, Q, R)`.
///
/// With `n = A.nrows()` and `m = B.ncols()`, the required shapes are:
/// `A` n×n, `B` with n rows, `Q` n×n and `R` m×m.
///
/// # Errors
/// * `ShapeError` if `A` is not square.
/// * `DimensionError` for a wrong shape of `B`, `Q` or `R`, checked in that order.
fn validate_riccati_inputs<T: SchurFloat>(
    a: &ArrayView2<T>,
    b: &ArrayView2<T>,
    q: &ArrayView2<T>,
    r: &ArrayView2<T>,
) -> LinalgResult<()> {
    let (n, m) = (a.nrows(), b.ncols());
    if a.ncols() != n {
        return Err(LinalgError::ShapeError("A must be square".into()));
    }
    if b.nrows() != n {
        let msg = format!("B must have {n} rows, got {}", b.nrows());
        return Err(LinalgError::DimensionError(msg));
    }
    if q.dim() != (n, n) {
        return Err(LinalgError::DimensionError("Q must be n×n".into()));
    }
    if r.dim() != (m, m) {
        return Err(LinalgError::DimensionError("R must be m×m".into()));
    }
    Ok(())
}
/// Reads eigenvalues off an upper quasi-triangular matrix.
///
/// * A 1×1 diagonal block contributes its diagonal entry as a real eigenvalue.
/// * A 2×2 block (sub-diagonal magnitude above `tol`) contributes the two roots of
///   `λ² − tr·λ + det = 0`: a real pair when the discriminant is non-negative,
///   otherwise a complex-conjugate pair with the positive imaginary part first.
fn extract_schur_eigenvalues<T: SchurFloat>(t: &Array2<T>, tol: T) -> Vec<Complex<T>> {
    let n = t.nrows();
    let two = T::one() + T::one();
    let mut out = Vec::with_capacity(n);
    let mut k = 0usize;
    while k < n {
        let is_pair = k + 1 < n && t[[k + 1, k]].abs() > tol;
        if !is_pair {
            out.push(Complex::new(t[[k, k]], T::zero()));
            k += 1;
            continue;
        }
        let (t00, t01) = (t[[k, k]], t[[k, k + 1]]);
        let (t10, t11) = (t[[k + 1, k]], t[[k + 1, k + 1]]);
        let trace = t00 + t11;
        let det = t00 * t11 - t01 * t10;
        let disc = trace * trace - two * two * det;
        if disc >= T::zero() {
            let root = disc.sqrt();
            out.push(Complex::new((trace + root) / two, T::zero()));
            out.push(Complex::new((trace - root) / two, T::zero()));
        } else {
            let imag = (-disc).sqrt() / two;
            out.push(Complex::new(trace / two, imag));
            out.push(Complex::new(trace / two, -imag));
        }
        k += 2;
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Largest element-wise absolute difference between two equally shaped matrices.
    fn max_abs_err(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y).abs())
            .fold(0.0_f64, f64::max)
    }

    #[test]
    fn test_sylvester_diagonal() {
        let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
        let b = array![[-3.0_f64, 0.0], [0.0, -4.0]];
        let c = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let x = solve_sylvester(&a.view(), &b.view(), &c.view()).expect("ok");
        let resid = a.dot(&x) + x.dot(&b) - &c;
        assert!(max_abs_err(&resid, &Array2::zeros((2, 2))) < 1e-9);
    }

    #[test]
    fn test_sylvester_general_2x2() {
        let a = array![[2.0_f64, 1.0], [0.0, 3.0]];
        let b = array![[-5.0_f64, 0.0], [0.0, -4.0]];
        let c = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = solve_sylvester(&a.view(), &b.view(), &c.view()).expect("ok");
        let resid = a.dot(&x) + x.dot(&b) - &c;
        let err = resid.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(err < 1e-8, "Sylvester residual {err}");
    }

    #[test]
    fn test_sylvester_rejects_non_square_a() {
        let a = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[1.0_f64]];
        let c = array![[1.0_f64], [2.0]];
        assert!(solve_sylvester(&a.view(), &b.view(), &c.view()).is_err());
    }

    #[test]
    fn test_continuous_lyapunov_stable() {
        let a = array![[-1.0_f64, 0.5], [0.0, -2.0]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = solve_continuous_lyapunov(&a.view(), &q.view()).expect("ok");
        let resid = a.dot(&x) + x.dot(&a.t()) + &q;
        let err = resid.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(err < 1e-7, "CLyapunov residual {err}");
    }

    #[test]
    fn test_discrete_lyapunov() {
        let a = array![[0.5_f64, 0.1], [0.0, 0.6]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = solve_discrete_lyapunov(&a.view(), &q.view()).expect("ok");
        let resid = a.dot(&x).dot(&a.t()) - &x + &q;
        let err = resid.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(err < 1e-7, "DLyapunov residual {err}");
    }

    #[test]
    fn test_care_double_integrator() {
        let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
        let b = array![[0.0_f64], [1.0]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let r = array![[1.0_f64]];
        let x = solve_algebraic_riccati(
            &a.view(),
            &b.view(),
            &q.view(),
            &r.view(),
            RiccatiType::Continuous,
        )
        .expect("CARE ok");
        assert!(x[[0, 0]] > 0.0, "CARE solution should be PD");
        assert!(x[[1, 1]] > 0.0, "CARE solution should be PD");
        let r_inv = array![[1.0_f64]];
        let xbrbt = x.dot(&b).dot(&r_inv).dot(&b.t()).dot(&x);
        let resid = a.t().dot(&x) + x.dot(&a) - xbrbt + &q;
        let err = resid.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(err < 1e-5, "CARE residual {err}");
    }

    #[test]
    fn test_dare_diagonal_decoupled() {
        // With diagonal A and B = Q = R = I the DARE decouples into the scalar
        // equations x = a²x − a²x²/(1 + x) + 1, i.e. x² − a²x − 1 = 0.
        // The positive root is the stabilizing solution.
        let a = array![[0.5_f64, 0.0], [0.0, 0.6]];
        let eye = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = solve_algebraic_riccati(
            &a.view(),
            &eye.view(),
            &eye.view(),
            &eye.view(),
            RiccatiType::Discrete,
        )
        .expect("DARE ok");
        for (i, &ai) in [0.5_f64, 0.6].iter().enumerate() {
            let a2 = ai * ai;
            let expected = (a2 + (a2 * a2 + 4.0).sqrt()) / 2.0;
            let got = x[[i, i]];
            assert!(
                (got - expected).abs() < 1e-8,
                "DARE x[{i},{i}] = {got}, expected {expected}"
            );
        }
        assert!(
            x[[0, 1]].abs() < 1e-8 && x[[1, 0]].abs() < 1e-8,
            "DARE off-diagonal entries should vanish"
        );
    }
}