use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::{Debug, Display};
use std::iter::Sum;
/// Scalar type accepted by every routine in this module.
///
/// This is a pure "bundle" trait: it adds no methods of its own and exists only
/// so that the floating-point, compound-assignment, summation, ndarray-scalar,
/// formatting and thread-safety bounds can be spelled once. The blanket
/// implementation below makes every type that meets those bounds a `CtrlFloat`.
pub trait CtrlFloat
where
    Self: Float + NumAssign + Sum + ScalarOperand + Debug + Display + Send + Sync + 'static,
{
}

impl<T> CtrlFloat for T where
    T: Float + NumAssign + Sum + ScalarOperand + Debug + Display + Send + Sync + 'static
{
}
/// Multiply two `n × n` matrices and return `a · b`.
///
/// Row `row` of the product is accumulated as a linear combination of the rows
/// of `b`. Coefficients of `a` that are exactly zero are skipped, which is a
/// cheap shortcut for sparse inputs.
fn matmul_sq<F: CtrlFloat>(a: &Array2<F>, b: &Array2<F>, n: usize) -> Array2<F> {
    let mut prod = Array2::<F>::zeros((n, n));
    for row in 0..n {
        for inner in 0..n {
            let coeff = a[[row, inner]];
            if coeff != F::zero() {
                for col in 0..n {
                    prod[[row, col]] += coeff * b[[inner, col]];
                }
            }
        }
    }
    prod
}
/// Return the transpose of the leading `n × n` block of `a` as a new,
/// owned matrix.
fn transpose_sq<F: CtrlFloat>(a: &Array2<F>, n: usize) -> Array2<F> {
    Array2::from_shape_fn((n, n), |(row, col)| a[[col, row]])
}
/// Add `b` into `a` element-wise over the leading `n × n` block (`a += b`).
fn mat_add_inplace<F: CtrlFloat>(a: &mut Array2<F>, b: &Array2<F>, n: usize) {
    let cells = (0..n).flat_map(|row| (0..n).map(move |col| (row, col)));
    for (row, col) in cells {
        a[[row, col]] += b[[row, col]];
    }
}
/// Invert the leading `n × n` block of `a` by Gauss–Jordan elimination with
/// partial pivoting.
///
/// The matrix is augmented with the identity, `[A | I]`, and reduced in place
/// to `[I | A⁻¹]`.
///
/// # Errors
/// Returns [`LinalgError::SingularMatrixError`] when the chosen pivot is not
/// larger than `1000 · ε · max|aᵢⱼ|`.
///
/// Scaling the threshold by the largest entry makes the singularity test
/// invariant under uniform rescaling of `a`. As a result:
/// - a well-conditioned matrix with tiny entries is no longer rejected;
/// - an all-zero matrix is always rejected.
fn mat_inv_gauss<F: CtrlFloat>(a: &Array2<F>, n: usize) -> LinalgResult<Array2<F>> {
    // Augmented matrix [A | I], stored one row per Vec.
    let mut aug: Vec<Vec<F>> = (0..n)
        .map(|i| {
            let mut row = Vec::with_capacity(2 * n);
            row.extend((0..n).map(|j| a[[i, j]]));
            row.extend((0..n).map(|j| if i == j { F::one() } else { F::zero() }));
            row
        })
        .collect();

    // Relative singularity threshold, proportional to the largest |a_ij|.
    let scale = aug
        .iter()
        .flat_map(|row| row[..n].iter())
        .fold(F::zero(), |acc, &x| acc.max(x.abs()));
    let threshold = F::epsilon() * F::from(1_000.0).unwrap_or(F::one()) * scale;

    for col in 0..n {
        // Partial pivoting: bring the row with the largest |entry| in this
        // column onto the diagonal.
        let pivot_row = (col..n)
            .max_by(|&r1, &r2| {
                aug[r1][col]
                    .abs()
                    .partial_cmp(&aug[r2][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);
        aug.swap(col, pivot_row);

        let pivot = aug[col][col];
        if pivot.abs() <= threshold {
            return Err(LinalgError::SingularMatrixError(
                "Singular matrix in Gauss-Jordan inversion (control_theory)".into(),
            ));
        }

        // Normalise the pivot row, then eliminate this column from every other
        // row. The pivot row is moved out temporarily so it can be read while
        // the remaining rows are mutated.
        let mut pivot_vals = std::mem::take(&mut aug[col]);
        let inv_pivot = F::one() / pivot;
        for x in pivot_vals.iter_mut() {
            *x *= inv_pivot;
        }
        for (i, row) in aug.iter_mut().enumerate() {
            if i == col {
                continue;
            }
            let factor = row[col];
            if factor == F::zero() {
                continue;
            }
            for (x, &p) in row.iter_mut().zip(pivot_vals.iter()) {
                *x -= factor * p;
            }
        }
        aug[col] = pivot_vals;
    }

    // The right half of the reduced augmented matrix is the inverse.
    Ok(Array2::from_shape_fn((n, n), |(i, j)| aug[i][n + j]))
}
/// Solve the dense `n × n` linear system `a · x = b` by Gauss–Jordan
/// elimination with partial pivoting.
///
/// `a` is given row by row. Each row must have length `n`, and `b` must have
/// length `n`.
///
/// # Errors
/// - [`LinalgError::ShapeError`] if `a` or `b` does not match `n`.
/// - [`LinalgError::SingularMatrixError`] if a pivot is not larger than
///   `1000 · ε · max|aᵢⱼ|`. Because the threshold is relative to the largest
///   coefficient, uniformly rescaling the system does not change whether it is
///   accepted, and an all-zero coefficient matrix is always rejected.
fn solve_linear<F: CtrlFloat>(a: &[Vec<F>], b: &[F], n: usize) -> LinalgResult<Vec<F>> {
    if a.len() != n || b.len() != n || a.iter().any(|row| row.len() != n) {
        return Err(LinalgError::ShapeError(format!(
            "solve_linear (control_theory): expected an {n}×{n} system with a length-{n} right-hand side"
        )));
    }

    // Augmented matrix [A | b].
    let mut aug: Vec<Vec<F>> = a
        .iter()
        .zip(b)
        .map(|(row, &bi)| {
            let mut r = Vec::with_capacity(n + 1);
            r.extend_from_slice(row);
            r.push(bi);
            r
        })
        .collect();

    // Relative singularity threshold, proportional to the largest |a_ij|.
    let scale = aug
        .iter()
        .flat_map(|row| row[..n].iter())
        .fold(F::zero(), |acc, &x| acc.max(x.abs()));
    let threshold = F::epsilon() * F::from(1_000.0).unwrap_or(F::one()) * scale;

    for col in 0..n {
        // Partial pivoting on the current column.
        let pivot_row = (col..n)
            .max_by(|&r1, &r2| {
                aug[r1][col]
                    .abs()
                    .partial_cmp(&aug[r2][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(col);
        aug.swap(col, pivot_row);

        let pivot = aug[col][col];
        if pivot.abs() <= threshold {
            return Err(LinalgError::SingularMatrixError(
                "Singular system in solve_linear (control_theory)".into(),
            ));
        }

        // Normalise the pivot row and eliminate the column from all other rows.
        // The pivot row is moved out so it can be read while the others mutate.
        let mut pivot_vals = std::mem::take(&mut aug[col]);
        let inv_pivot = F::one() / pivot;
        for x in pivot_vals.iter_mut() {
            *x *= inv_pivot;
        }
        for (i, row) in aug.iter_mut().enumerate() {
            if i == col {
                continue;
            }
            let factor = row[col];
            if factor == F::zero() {
                continue;
            }
            for (x, &p) in row.iter_mut().zip(pivot_vals.iter()) {
                *x -= factor * p;
            }
        }
        aug[col] = pivot_vals;
    }

    // After full reduction the last column holds the solution.
    Ok(aug.iter().map(|row| row[n]).collect())
}
/// Solve the continuous-time Lyapunov equation `A·X + X·Aᵀ + Q = 0` for `X`.
///
/// Row-major vectorisation (`x[i·n + j] = X[i, j]`) turns the equation into
/// the `n² × n²` linear system `(A ⊗ I + I ⊗ A)·x = −vec(Q)`. That system is
/// solved densely with [`solve_linear`], so the cost grows as `O(n⁶)`.
///
/// # Errors
/// - from [`check_square`] if `A` is empty or not square;
/// - `ShapeError` if `Q` is not the same size as `A`;
/// - `SingularMatrixError` if the Kronecker system is numerically singular
///   (e.g. when two eigenvalues of `A` sum to zero).
pub fn lyapunov_continuous<F: CtrlFloat>(
    a: &ArrayView2<F>,
    q: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    check_square(a, "lyapunov_continuous: A")?;
    if (q.nrows(), q.ncols()) != (n, n) {
        return Err(LinalgError::ShapeError(
            "lyapunov_continuous: Q must be the same size as A".into(),
        ));
    }

    let dim = n * n;
    // Equation row (i, k) against unknown X[j, l]:
    //   X·Aᵀ contributes A[k, l] when j == i,
    //   A·X  contributes A[i, j] when l == k.
    let system: Vec<Vec<F>> = (0..dim)
        .map(|row| {
            let (i, k) = (row / n, row % n);
            (0..dim)
                .map(|col| {
                    let (j, l) = (col / n, col % n);
                    let mut coeff = F::zero();
                    if i == j {
                        coeff += a[[k, l]];
                    }
                    if k == l {
                        coeff += a[[i, j]];
                    }
                    coeff
                })
                .collect()
        })
        .collect();

    // -vec(Q) in row-major order (ndarray iterates in logical order).
    let rhs: Vec<F> = q.iter().map(|&v| -v).collect();
    let solution = solve_linear(&system, &rhs, dim)?;
    Ok(Array2::from_shape_fn((n, n), |(i, j)| solution[i * n + j]))
}
/// Solve the discrete-time Lyapunov (Stein) equation `A·X·Aᵀ − X + Q = 0`
/// for `X`.
///
/// Row-major vectorisation (`x[i·n + j] = X[i, j]`) turns the equation into
/// the `n² × n²` system `(A ⊗ A − I)·x = −vec(Q)`. That system is solved
/// densely with [`solve_linear`] (cost `O(n⁶)`).
///
/// # Errors
/// - from [`check_square`] if `A` is empty or not square;
/// - `ShapeError` if `Q` is not the same size as `A`;
/// - `SingularMatrixError` if the system is numerically singular (e.g. when a
///   product of two eigenvalues of `A` equals one).
pub fn lyapunov_discrete<F: CtrlFloat>(
    a: &ArrayView2<F>,
    q: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    check_square(a, "lyapunov_discrete: A")?;
    if (q.nrows(), q.ncols()) != (n, n) {
        return Err(LinalgError::ShapeError(
            "lyapunov_discrete: Q must be the same size as A".into(),
        ));
    }

    let dim = n * n;
    // Entry (i·n + k, j·n + l) of A ⊗ A is A[i, j]·A[k, l]; subtract the
    // identity on the diagonal.
    let system: Vec<Vec<F>> = (0..dim)
        .map(|row| {
            let (i, k) = (row / n, row % n);
            (0..dim)
                .map(|col| {
                    let (j, l) = (col / n, col % n);
                    let mut coeff = F::zero();
                    coeff += a[[i, j]] * a[[k, l]];
                    if row == col {
                        coeff -= F::one();
                    }
                    coeff
                })
                .collect()
        })
        .collect();

    // -vec(Q) in row-major order (ndarray iterates in logical order).
    let rhs: Vec<F> = q.iter().map(|&v| -v).collect();
    let solution = solve_linear(&system, &rhs, dim)?;
    Ok(Array2::from_shape_fn((n, n), |(i, j)| solution[i * n + j]))
}
/// Solve the continuous-time algebraic Riccati equation (CARE)
/// `Aᵀ·X + X·A − X·B·R⁻¹·Bᵀ·X + Q = 0` by Newton–Kleinman iteration.
///
/// With `S = B·R⁻¹·Bᵀ` and starting from `X₀ = 0`, each step solves the
/// Lyapunov equation
/// `(A − S·Xₖ)ᵀ·Xₖ₊₁ + Xₖ₊₁·(A − S·Xₖ) + Q + Xₖ·S·Xₖ = 0`.
/// Iteration continues until the largest entrywise change is below `10⁶·ε`,
/// for at most 100 steps. The result is symmetrised before returning.
///
/// Because `X₀ = 0`, the first closed-loop matrix is `A` itself. The usual
/// Newton–Kleinman convergence guarantee assumes that starting point is
/// stabilising, i.e. `A` Hurwitz. This is NOT checked. Non-convergence within
/// 100 steps is not reported either; the last iterate is returned.
///
/// # Errors
/// - from [`check_square`] if `A` is empty or not square;
/// - `ShapeError` if `B`, `Q` or `R` have incompatible shapes;
/// - `SingularMatrixError` if `R` or one of the Lyapunov systems is singular.
pub fn riccati_continuous<F: CtrlFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    q: &ArrayView2<F>,
    r: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    let m = b.ncols();
    check_square(a, "riccati_continuous: A")?;
    if b.nrows() != n {
        return Err(LinalgError::ShapeError(
            "riccati_continuous: B must have n rows".into(),
        ));
    }
    if q.nrows() != n || q.ncols() != n {
        return Err(LinalgError::ShapeError(
            "riccati_continuous: Q must be n×n".into(),
        ));
    }
    if r.nrows() != m || r.ncols() != m {
        return Err(LinalgError::ShapeError(
            "riccati_continuous: R must be m×m".into(),
        ));
    }

    let r_inv = mat_inv_gauss(&r.to_owned(), m)?;
    // R⁻¹·Bᵀ  (m × n)
    let r_inv_bt = Array2::from_shape_fn((m, n), |(i, j)| {
        (0..m).fold(F::zero(), |acc, k| acc + r_inv[[i, k]] * b[[j, k]])
    });
    // S = B·R⁻¹·Bᵀ  (n × n); exact-zero entries of B are skipped.
    let mut s_mat = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for k in 0..m {
            let b_ik = b[[i, k]];
            if b_ik == F::zero() {
                continue;
            }
            for j in 0..n {
                s_mat[[i, j]] += b_ik * r_inv_bt[[k, j]];
            }
        }
    }

    let tol = F::epsilon() * F::from(1e6).unwrap_or(F::one());
    let max_iter = 100usize;
    let mut x = Array2::<F>::zeros((n, n));
    for _ in 0..max_iter {
        // Closed loop A − S·X and forcing term Q + X·S·X for this step.
        let sx = matmul_sq(&s_mat, &x, n);
        let a_cl = Array2::from_shape_fn((n, n), |(i, j)| a[[i, j]] - sx[[i, j]]);
        let xs = matmul_sq(&x, &s_mat, n);
        let mut q_k = q.to_owned();
        mat_add_inplace(&mut q_k, &matmul_sq(&xs, &x, n), n);

        // lyapunov_continuous(M, Q) solves M·X + X·Mᵀ + Q = 0, so pass A_clᵀ.
        let a_cl_t = transpose_sq(&a_cl, n);
        let x_new = lyapunov_continuous(&a_cl_t.view(), &q_k.view())?;

        let diff = x_new
            .iter()
            .zip(x.iter())
            .fold(F::zero(), |d, (&new, &old)| {
                let e = (new - old).abs();
                if e > d {
                    e
                } else {
                    d
                }
            });
        x = x_new;
        if diff < tol {
            break;
        }
    }

    // Remove round-off asymmetry.
    let two = F::from(2.0).unwrap_or(F::one());
    Ok(Array2::from_shape_fn((n, n), |(i, j)| (x[[i, j]] + x[[j, i]]) / two))
}
/// Solve the discrete-time algebraic Riccati equation (DARE)
/// `X = Aᵀ·X·A − Aᵀ·X·B·(R + Bᵀ·X·B)⁻¹·Bᵀ·X·A + Q`.
///
/// The solution is computed by iterating the Riccati difference equation
/// from `X₀ = Q`:
///
/// ```text
/// Kₖ   = (R + Bᵀ·Xₖ·B)⁻¹ · Bᵀ·Xₖ·A
/// Xₖ₊₁ = Q + Aᵀ·Xₖ·(A − B·Kₖ)
/// ```
///
/// The second line is exactly the DARE right-hand side evaluated at `Xₖ`.
/// The iterates converge to the stabilising solution under the usual
/// conditions, even when `A` itself is unstable:
/// - `R` positive definite;
/// - `(A, B)` stabilisable;
/// - `(A, Q^{1/2})` detectable;
/// - `Q` positive semidefinite.
///
/// Convergence is linear, so up to 1000 steps are allowed. Iteration stops
/// once the largest entrywise change is at most `10⁶·ε·(1 + max|Xₖ₊₁|)`. If
/// that is never reached, the last iterate is returned without an error. The
/// result is symmetrised before returning.
///
/// # Errors
/// - from [`check_square`] if `A` is empty or not square;
/// - `ShapeError` if `B`, `Q` or `R` have incompatible shapes;
/// - `SingularMatrixError` if `R + Bᵀ·Xₖ·B` is numerically singular at some
///   step.
pub fn riccati_discrete<F: CtrlFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
    q: &ArrayView2<F>,
    r: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    // Dense product of a (p × k) and a (k × s) matrix.
    fn mm<T: CtrlFloat>(lhs: &ArrayView2<T>, rhs: &ArrayView2<T>) -> Array2<T> {
        let inner = lhs.ncols();
        Array2::from_shape_fn((lhs.nrows(), rhs.ncols()), |(i, j)| {
            (0..inner).fold(T::zero(), |acc, k| acc + lhs[[i, k]] * rhs[[k, j]])
        })
    }

    let n = a.nrows();
    let m = b.ncols();
    check_square(a, "riccati_discrete: A")?;
    if b.nrows() != n {
        return Err(LinalgError::ShapeError(
            "riccati_discrete: B must have n rows".into(),
        ));
    }
    if q.nrows() != n || q.ncols() != n {
        return Err(LinalgError::ShapeError(
            "riccati_discrete: Q must be n×n".into(),
        ));
    }
    if r.nrows() != m || r.ncols() != m {
        return Err(LinalgError::ShapeError(
            "riccati_discrete: R must be m×m".into(),
        ));
    }

    let a_t = a.t();
    let b_t = b.t();
    let tol = F::epsilon() * F::from(1e6).unwrap_or(F::one());
    let max_iter = 1_000usize;
    let mut x = q.to_owned();
    for _ in 0..max_iter {
        // Gain K = (R + BᵀXB)⁻¹ · BᵀXA.
        let btx = mm(&b_t, &x.view()); // m × n
        let btxb = mm(&btx.view(), b); // m × m
        let gain_lhs = Array2::from_shape_fn((m, m), |(i, j)| r[[i, j]] + btxb[[i, j]]);
        let gain_lhs_inv = mat_inv_gauss(&gain_lhs, m)?;
        let btxa = mm(&btx.view(), a); // m × n
        let k_gain = mm(&gain_lhs_inv.view(), &btxa.view()); // m × n

        // Closed loop A − B·K and the Riccati difference step.
        let bk = mm(b, &k_gain.view());
        let a_cl = Array2::from_shape_fn((n, n), |(i, j)| a[[i, j]] - bk[[i, j]]);
        let x_a_cl = mm(&x.view(), &a_cl.view());
        let at_x_a_cl = mm(&a_t, &x_a_cl.view());
        let x_new = Array2::from_shape_fn((n, n), |(i, j)| q[[i, j]] + at_x_a_cl[[i, j]]);

        // Largest entrywise change and magnitude, for a scale-aware stop test.
        let mut diff = F::zero();
        let mut magnitude = F::zero();
        for (&new, &old) in x_new.iter().zip(x.iter()) {
            diff = diff.max((new - old).abs());
            magnitude = magnitude.max(new.abs());
        }
        x = x_new;
        if diff <= tol * (F::one() + magnitude) {
            break;
        }
    }

    // Remove round-off asymmetry.
    let two = F::from(2.0).unwrap_or(F::one());
    Ok(Array2::from_shape_fn((n, n), |(i, j)| (x[[i, j]] + x[[j, i]]) / two))
}
/// Build the Kalman controllability matrix `[B, A·B, A²·B, …, Aⁿ⁻¹·B]`.
///
/// The result is `n × (n·m)`, where `n = a.nrows()` and `m = b.ncols()`. Block
/// `k` (columns `k·m .. (k+1)·m`) holds `Aᵏ·B`. Shapes are not validated here:
/// mismatched inputs panic on indexing.
pub fn controllability_matrix<F: CtrlFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
) -> Array2<F> {
    let n = a.nrows();
    let m = b.ncols();
    let mut ctrb = Array2::<F>::zeros((n, n * m));
    // Holds Aᵏ·B for the current block k.
    let mut power_b = b.to_owned();
    for block in 0..n {
        let offset = block * m;
        for row in 0..n {
            for col in 0..m {
                ctrb[[row, offset + col]] = power_b[[row, col]];
            }
        }
        // Advance to Aᵏ⁺¹·B, skipping exact-zero entries of A.
        power_b = Array2::from_shape_fn((n, m), |(row, col)| {
            (0..n).fold(F::zero(), |acc, l| {
                let weight = a[[row, l]];
                if weight == F::zero() {
                    acc
                } else {
                    acc + weight * power_b[[l, col]]
                }
            })
        });
    }
    ctrb
}
/// Build the observability matrix `[C; C·A; C·A²; …; C·Aⁿ⁻¹]`.
///
/// The result is `(n·p) × n`, where `n = a.nrows()` and `p = c.nrows()`. Block
/// row `k` (rows `k·p .. (k+1)·p`) holds `C·Aᵏ`. Shapes are not validated here:
/// mismatched inputs panic on indexing.
pub fn observability_matrix<F: CtrlFloat>(
    a: &ArrayView2<F>,
    c: &ArrayView2<F>,
) -> Array2<F> {
    let n = a.nrows();
    let p = c.nrows();
    let mut obsv = Array2::<F>::zeros((n * p, n));
    // Holds C·Aᵏ for the current block k.
    let mut c_power = c.to_owned();
    for block in 0..n {
        let offset = block * p;
        for row in 0..p {
            for col in 0..n {
                obsv[[offset + row, col]] = c_power[[row, col]];
            }
        }
        // Advance to C·Aᵏ⁺¹, skipping exact-zero entries of C·Aᵏ.
        c_power = Array2::from_shape_fn((p, n), |(row, col)| {
            (0..n).fold(F::zero(), |acc, l| {
                let weight = c_power[[row, l]];
                if weight == F::zero() {
                    acc
                } else {
                    acc + weight * a[[l, col]]
                }
            })
        });
    }
    obsv
}
/// Controllability gramian `W_c` of the continuous-time pair `(A, B)`.
///
/// `W_c` is the solution of `A·W_c + W_c·Aᵀ + B·Bᵀ = 0`, obtained from
/// [`lyapunov_continuous`]. It equals the infinite-horizon gramian when `A` is
/// Hurwitz; stability is assumed, not checked.
///
/// # Errors
/// Shape errors for `A`/`B`, and any error from [`lyapunov_continuous`].
pub fn controllability_gramian<F: CtrlFloat>(
    a: &ArrayView2<F>,
    b: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    let m = b.ncols();
    check_square(a, "controllability_gramian: A")?;
    if b.nrows() != n {
        return Err(LinalgError::ShapeError(
            "controllability_gramian: B must have n rows".into(),
        ));
    }
    // B·Bᵀ, accumulated over input channels in ascending order.
    let b_bt = Array2::from_shape_fn((n, n), |(i, j)| {
        (0..m).fold(F::zero(), |acc, k| acc + b[[i, k]] * b[[j, k]])
    });
    lyapunov_continuous(a, &b_bt.view())
}
/// Observability gramian `W_o` of the continuous-time pair `(A, C)`.
///
/// `W_o` is the solution of `Aᵀ·W_o + W_o·A + Cᵀ·C = 0`, obtained by calling
/// [`lyapunov_continuous`] with the transpose of `A`.
///
/// # Errors
/// Shape errors for `A`/`C`, and any error from [`lyapunov_continuous`].
pub fn observability_gramian<F: CtrlFloat>(
    a: &ArrayView2<F>,
    c: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    let p = c.nrows();
    check_square(a, "observability_gramian: A")?;
    if c.ncols() != n {
        return Err(LinalgError::ShapeError(
            "observability_gramian: C must have n columns".into(),
        ));
    }
    // Cᵀ·C, accumulated over output channels in ascending order.
    let c_tc = Array2::from_shape_fn((n, n), |(i, j)| {
        (0..p).fold(F::zero(), |acc, k| acc + c[[k, i]] * c[[k, j]])
    });
    // A transposed view holds the same values as an owned transpose.
    lyapunov_continuous(&a.t(), &c_tc.view())
}
/// Output of [`balanced_truncation`]: a reduced-order model of order `r`
/// together with the values used to choose it.
#[derive(Debug, Clone)]
pub struct BalancedTruncationResult<F: CtrlFloat> {
    /// Reduced state matrix, `r × r`.
    pub a_r: Array2<F>,
    /// Reduced input matrix, `r × m`.
    pub b_r: Array2<F>,
    /// Reduced output matrix, `p × r`; equal to `C · transform`.
    pub c_r: Array2<F>,
    /// Values reported as Hankel singular values: one per state (`n` in
    /// total), sorted in decreasing order.
    pub hankel_singular_values: Array1<F>,
    /// `n × r` basis of the retained state subspace (right projection).
    pub transform: Array2<F>,
}
pub fn balanced_truncation<F: CtrlFloat>(
a: &ArrayView2<F>,
b: &ArrayView2<F>,
c: &ArrayView2<F>,
r: usize,
) -> LinalgResult<BalancedTruncationResult<F>> {
let n = a.nrows();
let m = b.ncols();
let p = c.nrows();
check_square(a, "balanced_truncation: A")?;
if b.nrows() != n {
return Err(LinalgError::ShapeError(
"balanced_truncation: B must have n rows".into(),
));
}
if c.ncols() != n {
return Err(LinalgError::ShapeError(
"balanced_truncation: C must have n columns".into(),
));
}
if r == 0 || r > n {
return Err(LinalgError::ValueError(format!(
"balanced_truncation: r={r} must be in 1..={n}"
)));
}
let wc = controllability_gramian(a, b)?;
let wo = observability_gramian(a, c)?;
let wc_wo = matmul_sq(&wc, &wo, n);
let wo_wc = matmul_sq(&wo, &wc, n);
let mut wc_wo_sym = Array2::<F>::zeros((n, n));
let two = F::from(2.0).unwrap_or(F::one());
for i in 0..n {
for j in 0..n {
wc_wo_sym[[i, j]] = (wc_wo[[i, j]] + wo_wc[[i, j]]) / two;
}
}
let tol = F::epsilon() * F::from(1e6).unwrap_or(F::one());
let (eig_vals, _eig_vecs) = power_iter_eig(&wc_wo_sym, n, tol, 500)?;
let mut hsv_indexed: Vec<(usize, F)> = eig_vals
.iter()
.enumerate()
.map(|(i, &e)| (i, e.abs().sqrt()))
.collect();
hsv_indexed.sort_by(|a, b| {
b.1.partial_cmp(&a.1)
.unwrap_or(std::cmp::Ordering::Equal)
});
let hsv: Array1<F> = Array1::from_vec(hsv_indexed.iter().map(|&(_, v)| v).collect());
let (_, eig_vecs_sorted) = power_iter_eig(&wc_wo_sym, n, tol, 500)?;
let mut transform = Array2::<F>::zeros((n, r));
for col in 0..r {
let orig_idx = hsv_indexed[col].0;
for row in 0..n {
transform[[row, col]] = eig_vecs_sorted[[row, orig_idx]];
}
}
let t_t = transform.t().to_owned();
let at_t = {
let mut tmp = Array2::<F>::zeros((r, n));
for i in 0..r {
for j in 0..n {
for k in 0..n {
tmp[[i, j]] = tmp[[i, j]] + t_t[[i, k]] * a[[k, j]];
}
}
}
tmp
};
let mut a_r = Array2::<F>::zeros((r, r));
for i in 0..r {
for j in 0..r {
for k in 0..n {
a_r[[i, j]] = a_r[[i, j]] + at_t[[i, k]] * transform[[k, j]];
}
}
}
let mut b_r = Array2::<F>::zeros((r, m));
for i in 0..r {
for j in 0..m {
for k in 0..n {
b_r[[i, j]] = b_r[[i, j]] + t_t[[i, k]] * b[[k, j]];
}
}
}
let mut c_r = Array2::<F>::zeros((p, r));
for i in 0..p {
for j in 0..r {
for k in 0..n {
c_r[[i, j]] = c_r[[i, j]] + c[[i, k]] * transform[[k, j]];
}
}
}
Ok(BalancedTruncationResult {
a_r,
b_r,
c_r,
hankel_singular_values: hsv,
transform,
})
}
/// Approximate all `n` eigenpairs of the `n × n` matrix `a` by power iteration
/// with rank-one (Hotelling) deflation.
///
/// For each `k in 0..n`:
/// 1. Start from the unit vector `e_k`, remove its components along the
///    eigenvectors found so far (one Gram–Schmidt pass) and normalise it. If
///    the remainder has norm below `tol`, fall back to `e_k` itself.
/// 2. Run at most `max_iter` power-iteration steps on the deflated matrix. The
///    Rayleigh quotient `vᵀ·(A·v)` is the eigenvalue estimate. Stop early when
///    `‖A·v‖ < tol` or when the estimate changes by less than `tol`.
/// 3. Deflate: `A ← A − λ·v·vᵀ`.
///
/// `tol` is used as an absolute threshold in all of these tests, so the
/// results depend on the scale of `a`.
///
/// NOTE(review): the `λ·v·vᵀ` deflation is only valid for symmetric `a`. The
/// callers in this file pass symmetrised matrices; confirm this for any new
/// caller.
///
/// # Returns
/// `(eigenvalues, q)`, where `eigenvalues[k]` is paired with column `k` of
/// `q`. Eigenvalues are in the order they were found, not sorted. In its
/// current form this function never returns `Err`.
fn power_iter_eig<F: CtrlFloat>(
    a: &Array2<F>,
    n: usize,
    tol: F,
    max_iter: usize,
) -> LinalgResult<(Vec<F>, Array2<F>)> {
    let mut eigenvalues = Vec::with_capacity(n);
    let mut evecs: Vec<Array1<F>> = Vec::with_capacity(n);
    // Working copy, deflated after each eigenpair is extracted.
    let mut a_work = a.clone();
    for k in 0..n {
        // Start vector e_k, orthogonalised against the eigenvectors found so far.
        let mut v = Array1::<F>::zeros(n);
        v[k] = F::one();
        for ev in &evecs {
            let dot: F = v.iter().zip(ev.iter()).map(|(&vi, &ei)| vi * ei).sum();
            for i in 0..n {
                v[i] = v[i] - dot * ev[i];
            }
        }
        let norm: F = v.iter().map(|&x| x * x).sum::<F>().sqrt();
        if norm < tol {
            // e_k lies (numerically) in the span of earlier eigenvectors; fall
            // back to e_k itself (k < n, so `k % n == k`).
            v = Array1::<F>::zeros(n);
            v[k % n] = F::one();
        } else {
            for x in v.iter_mut() {
                *x = *x / norm;
            }
        }
        let mut eigenval = F::zero();
        for _ in 0..max_iter {
            // av = a_work · v
            let mut av = Array1::<F>::zeros(n);
            for i in 0..n {
                for j in 0..n {
                    av[i] = av[i] + a_work[[i, j]] * v[j];
                }
            }
            // Rayleigh quotient vᵀ·(a_work·v); v has unit norm at this point.
            let new_eigenval: F = v.iter().zip(av.iter()).map(|(&vi, &avi)| vi * avi).sum();
            let new_norm: F = av.iter().map(|&x| x * x).sum::<F>().sqrt();
            if new_norm < tol {
                // v is (numerically) annihilated by the deflated matrix: keep v
                // unchanged and accept the current estimate.
                eigenval = new_eigenval;
                break;
            }
            let new_v: Array1<F> = av.mapv(|x| x / new_norm);
            if (new_eigenval - eigenval).abs() < tol {
                eigenval = new_eigenval;
                v = new_v;
                break;
            }
            eigenval = new_eigenval;
            v = new_v;
        }
        eigenvalues.push(eigenval);
        evecs.push(v.clone());
        // Hotelling deflation: a_work -= λ·v·vᵀ.
        for i in 0..n {
            for j in 0..n {
                a_work[[i, j]] = a_work[[i, j]] - eigenval * v[i] * v[j];
            }
        }
    }
    // Pack the eigenvectors as the columns of q.
    let mut q = Array2::<F>::zeros((n, n));
    for (col, ev) in evecs.iter().enumerate() {
        for row in 0..n {
            q[[row, col]] = ev[row];
        }
    }
    Ok((eigenvalues, q))
}
/// Validate that `a` is a non-empty square matrix and return its order.
///
/// `ctx` prefixes the error message so callers can tell which argument failed.
///
/// # Errors
/// - `DimensionError` if `a` has no rows;
/// - `ShapeError` if `a` has rows but is not square.
fn check_square<F: CtrlFloat>(a: &ArrayView2<F>, ctx: &str) -> LinalgResult<usize> {
    match a.dim() {
        (0, _) => Err(LinalgError::DimensionError(format!(
            "{ctx}: matrix is empty"
        ))),
        (rows, cols) if rows != cols => Err(LinalgError::ShapeError(format!(
            "{ctx}: matrix must be square"
        ))),
        (rows, _) => Ok(rows),
    }
}
/// Short alias for [`BalancedTruncationResult`].
pub type BtResult<F> = BalancedTruncationResult<F>;
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// Determinant of a 2 × 2 matrix; a non-zero value means full rank.
    fn det2(m: &Array2<f64>) -> f64 {
        m[[0, 0]] * m[[1, 1]] - m[[0, 1]] * m[[1, 0]]
    }

    #[test]
    fn test_continuous_lyapunov_diagonal() {
        let a = array![[-1.0_f64, 0.0], [0.0, -2.0]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = lyapunov_continuous(&a.view(), &q.view()).expect("continuous_lyapunov failed");
        // Diagonal A with Q = I decouples into 2·aᵢᵢ·xᵢᵢ + 1 = 0 and xᵢⱼ = 0.
        let expected = array![[0.5_f64, 0.0], [0.0, 0.25]];
        for ((i, j), &got) in x.indexed_iter() {
            let want = expected[[i, j]];
            assert!(
                (got - want).abs() < 1e-10,
                "Lyapunov solution mismatch at ({i},{j}): got {}, expected {}",
                got,
                want
            );
        }
        let residual = a.dot(&x) + x.dot(&a.t()) + &q;
        for &v in &residual {
            assert!(v.abs() < 1e-10, "Residual = {v}");
        }
    }

    #[test]
    fn test_discrete_lyapunov_residual() {
        let a = array![[0.5_f64, 0.1], [0.0, 0.3]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = lyapunov_discrete(&a.view(), &q.view()).expect("discrete_lyapunov failed");
        let residual = a.dot(&x).dot(&a.t()) - &x + &q;
        for &v in &residual {
            assert!(v.abs() < 1e-8, "Discrete Lyapunov residual = {v}");
        }
    }

    #[test]
    fn test_controllability_matrix_rank() {
        let a = array![[0.0_f64, 1.0], [-2.0, -3.0]];
        let b = array![[0.0_f64], [1.0]];
        let ctrl = controllability_matrix(&a.view(), &b.view());
        assert_eq!(ctrl.dim(), (2, 2));
        let det = det2(&ctrl);
        assert!(
            det.abs() > 1e-10,
            "Controllability matrix should be full rank, det = {det}"
        );
    }

    #[test]
    fn test_observability_matrix() {
        let a = array![[0.0_f64, 1.0], [-2.0, -3.0]];
        let c = array![[1.0_f64, 0.0]];
        let obs = observability_matrix(&a.view(), &c.view());
        assert_eq!(obs.dim(), (2, 2));
        let det = det2(&obs);
        assert!(
            det.abs() > 1e-10,
            "Observability matrix should be full rank, det = {det}"
        );
    }

    #[test]
    fn test_controllability_gramian() {
        let a = array![[-1.0_f64, 0.0], [0.0, -2.0]];
        let b = array![[1.0_f64], [1.0]];
        let wc = controllability_gramian(&a.view(), &b.view()).expect("gramian failed");
        let bbt = b.dot(&b.t());
        let residual = a.dot(&wc) + wc.dot(&a.t()) + &bbt;
        for &v in &residual {
            assert!(v.abs() < 1e-8, "Gramian Lyapunov residual = {v}");
        }
    }

    #[test]
    fn test_balanced_truncation_output_shapes() {
        let a = array![[-1.0_f64, 0.0], [0.0, -10.0]];
        let b = array![[1.0_f64], [1.0]];
        let c = array![[1.0_f64, 1.0]];
        let result =
            balanced_truncation(&a.view(), &b.view(), &c.view(), 1).expect("balanced_truncation");
        assert_eq!(result.a_r.dim(), (1, 1));
        assert_eq!(result.b_r.dim(), (1, 1));
        assert_eq!(result.c_r.dim(), (1, 1));
        assert_eq!(result.hankel_singular_values.len(), 2);
        for &v in &result.hankel_singular_values {
            assert!(v >= 0.0, "HSV must be non-negative, got {v}");
        }
    }
}