use crate::error::{LinalgError, LinalgResult};
use crate::matrix_equations::{solve_continuous_lyapunov, solve_discrete_lyapunov};
use scirs2_core::ndarray::{s, Array2, ArrayView2};
pub fn lyapunov(a: &ArrayView2<f64>, q: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
let n = a.shape()[0];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError("Matrix A must be square".into()));
}
if q.shape() != [n, n] {
return Err(LinalgError::ShapeError(
"Matrix Q must have the same shape as A".into(),
));
}
let neg_q = q.mapv(|v| -v);
solve_continuous_lyapunov(&a, &neg_q.view())
}
pub fn discrete_lyapunov(a: &ArrayView2<f64>, q: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
solve_discrete_lyapunov(a, q)
}
pub fn care(
a: &ArrayView2<f64>,
b: &ArrayView2<f64>,
q: &ArrayView2<f64>,
r: &ArrayView2<f64>,
) -> LinalgResult<Array2<f64>> {
let n = a.shape()[0];
let m = b.shape()[1];
if a.shape()[1] != n {
return Err(LinalgError::ShapeError("Matrix A must be square".into()));
}
if b.shape()[0] != n {
return Err(LinalgError::ShapeError(format!(
"Matrix B must have {n} rows"
)));
}
if q.shape() != [n, n] {
return Err(LinalgError::ShapeError("Matrix Q must be n×n".into()));
}
if r.shape() != [m, m] {
return Err(LinalgError::ShapeError("Matrix R must be m×m".into()));
}
let r_inv = crate::inv(r, None)?;
let b_r_inv_bt = b.dot(&r_inv).dot(&b.t());
let mut h = Array2::<f64>::zeros((2 * n, 2 * n));
h.slice_mut(s![..n, ..n]).assign(a);
h.slice_mut(s![..n, n..]).assign(&b_r_inv_bt.mapv(|v| -v));
h.slice_mut(s![n.., ..n]).assign(&q.mapv(|v| -v));
h.slice_mut(s![n.., n..]).assign(&a.t().mapv(|v| -v));
let (eigvals, eigvecs) = crate::eigen::eig(&h.view(), None)?;
let mut stable_idx: Vec<usize> = (0..2 * n).filter(|&i| eigvals[i].re < 0.0).collect();
if stable_idx.len() < n {
return Err(LinalgError::ConvergenceError(
"CARE: could not find n stable eigenvalues in Hamiltonian spectrum".into(),
));
}
stable_idx.truncate(n);
let mut u1 = Array2::<f64>::zeros((n, n));
let mut u2 = Array2::<f64>::zeros((n, n));
for (col, &idx) in stable_idx.iter().enumerate() {
for row in 0..n {
u1[[row, col]] = eigvecs[[row, idx]].re;
u2[[row, col]] = eigvecs[[n + row, idx]].re;
}
}
let u1_inv = crate::inv(&u1.view(), None)?;
let x = u2.dot(&u1_inv);
Ok((&x + &x.t()) * 0.5)
}
pub fn dare(
a: &ArrayView2<f64>,
b: &ArrayView2<f64>,
q: &ArrayView2<f64>,
r: &ArrayView2<f64>,
) -> LinalgResult<Array2<f64>> {
crate::matrix_equations::solve_discrete_riccati(a, b, q, r)
}
pub fn lqr_gain(
a: &ArrayView2<f64>,
b: &ArrayView2<f64>,
q: &ArrayView2<f64>,
r: &ArrayView2<f64>,
) -> LinalgResult<Array2<f64>> {
let x = care(a, b, q, r)?;
let r_inv = crate::inv(r, None)?;
Ok(r_inv.dot(&b.t()).dot(&x))
}
pub fn controllability_matrix(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> Array2<f64> {
let n = a.shape()[0];
let m = b.shape()[1];
let total_cols = n * m;
let mut result = Array2::<f64>::zeros((n, total_cols));
let mut a_pow_b = b.to_owned();
for k in 0..n {
let col_start = k * m;
let col_end = col_start + m;
result
.slice_mut(s![.., col_start..col_end])
.assign(&a_pow_b);
if k + 1 < n {
a_pow_b = a.dot(&a_pow_b);
}
}
result
}
pub fn is_controllable(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> bool {
let n = a.shape()[0];
let ctrl = controllability_matrix(a, b);
numerical_rank_matrix(&ctrl.view(), None) >= n
}
pub fn observability_matrix(a: &ArrayView2<f64>, c: &ArrayView2<f64>) -> Array2<f64> {
let n = a.shape()[0];
let p = c.shape()[0];
let total_rows = n * p;
let mut result = Array2::<f64>::zeros((total_rows, n));
let mut c_a_pow = c.to_owned();
for k in 0..n {
let row_start = k * p;
let row_end = row_start + p;
result
.slice_mut(s![row_start..row_end, ..])
.assign(&c_a_pow);
if k + 1 < n {
c_a_pow = c_a_pow.dot(a);
}
}
result
}
pub fn is_observable(a: &ArrayView2<f64>, c: &ArrayView2<f64>) -> bool {
let n = a.shape()[0];
let obs = observability_matrix(a, c);
numerical_rank_matrix(&obs.view(), None) >= n
}
fn numerical_rank_matrix(a: &ArrayView2<f64>, tol: Option<f64>) -> usize {
match crate::decomposition::svd(a, false, None) {
Ok((_, s, _)) => {
let max_sv = s.iter().cloned().fold(0.0_f64, f64::max);
let threshold = tol.unwrap_or_else(|| {
let (m, n) = (a.nrows(), a.ncols());
let eps = f64::EPSILON;
eps * (m.max(n) as f64) * max_sv
});
s.iter().filter(|&&sv| sv > threshold).count()
}
Err(_) => 0,
}
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
const TOL: f64 = 1e-7;
#[test]
fn test_lyapunov_diagonal() {
let a = array![[-1.0_f64, 0.0], [0.0, -2.0]];
let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
let x = lyapunov(&a.view(), &q.view()).expect("lyapunov failed");
let residual = a.dot(&x) + x.dot(&a.t()) + &q;
for &v in residual.iter() {
assert!(v.abs() < TOL, "lyapunov residual {v}");
}
}
#[test]
fn test_lyapunov_full() {
let a = array![[-2.0_f64, 1.0], [-1.0, -3.0]];
let q = array![[2.0_f64, 0.5], [0.5, 3.0]];
let x = lyapunov(&a.view(), &q.view()).expect("lyapunov failed");
let residual = a.dot(&x) + x.dot(&a.t()) + &q;
for &v in residual.iter() {
assert!(v.abs() < TOL, "lyapunov full residual {v}");
}
}
#[test]
fn test_discrete_lyapunov_basic() {
let a = array![[0.5_f64, 0.1], [0.0, 0.6]];
let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
let x = discrete_lyapunov(&a.view(), &q.view()).expect("discrete_lyapunov failed");
let residual = a.dot(&x).dot(&a.t()) - &x + &q;
for &v in residual.iter() {
assert!(v.abs() < TOL, "discrete_lyapunov residual {v}");
}
}
#[test]
fn test_care_double_integrator() {
let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
let b = array![[0.0_f64], [1.0]];
let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
let r = array![[1.0_f64]];
let x = care(&a.view(), &b.view(), &q.view(), &r.view()).expect("CARE failed");
assert!(x[[0, 0]] > 0.0);
assert!(x[[1, 1]] > 0.0);
let r_inv = crate::inv(&r.view(), None).expect("inv failed");
let residual = a.t().dot(&x) + x.dot(&a) - x.dot(&b).dot(&r_inv).dot(&b.t()).dot(&x) + &q;
for &v in residual.iter() {
assert!(v.abs() < 1e-6, "CARE residual {v}");
}
}
#[test]
fn test_dare_basic() {
let a = array![[1.0_f64, 0.1], [0.0, 1.0]];
let b = array![[0.0_f64], [0.1]];
let q = array![[1.0_f64, 0.0], [0.0, 0.0]];
let r = array![[1.0_f64]];
let x = dare(&a.view(), &b.view(), &q.view(), &r.view()).expect("DARE failed");
assert!(x[[0, 0]] > 0.0);
}
#[test]
fn test_lqr_gain_shape() {
let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
let b = array![[0.0_f64], [1.0]];
let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
let r = array![[1.0_f64]];
let k = lqr_gain(&a.view(), &b.view(), &q.view(), &r.view()).expect("LQR failed");
assert_eq!(k.shape(), &[1, 2]);
}
#[test]
fn test_controllability_matrix_shape() {
let a = Array2::<f64>::eye(3);
let b = Array2::<f64>::zeros((3, 2));
let c = controllability_matrix(&a.view(), &b.view());
assert_eq!(c.shape(), &[3, 6]);
}
#[test]
fn test_is_controllable_true() {
let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
let b = array![[0.0_f64], [1.0]];
assert!(is_controllable(&a.view(), &b.view()));
}
#[test]
fn test_is_controllable_false() {
let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
let b = Array2::<f64>::zeros((2, 1));
assert!(!is_controllable(&a.view(), &b.view()));
}
#[test]
fn test_observability_matrix_shape() {
let a = Array2::<f64>::eye(3);
let c = Array2::<f64>::zeros((1, 3));
let o = observability_matrix(&a.view(), &c.view());
assert_eq!(o.shape(), &[3, 3]);
}
#[test]
fn test_is_observable_true() {
let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
let c = array![[1.0_f64, 0.0]];
assert!(is_observable(&a.view(), &c.view()));
}
#[test]
fn test_is_observable_false() {
let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
let c = Array2::<f64>::zeros((1, 2));
assert!(!is_observable(&a.view(), &c.view()));
}
}