use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::{Debug, Display};
use std::iter::Sum;
/// Scalar element types accepted by the Lyapunov solvers in this module.
///
/// This is a bound alias: it gathers the numeric (`Float`, `NumAssign`, `Sum`,
/// `ScalarOperand`), formatting (`Debug`, `Display`) and threading (`Send`,
/// `Sync`, `'static`) requirements under one name and declares no methods.
pub trait LyapFloat
where
    Self: Float + NumAssign + Sum + ScalarOperand + Debug + Display + Send + Sync + 'static,
{
}

/// Blanket implementation: any type meeting all of the bounds above is a
/// `LyapFloat` automatically, so callers never implement it by hand.
impl<T> LyapFloat for T where
    T: Float + NumAssign + Sum + ScalarOperand + Debug + Display + Send + Sync + 'static
{
}
/// Dense matrix product `lhs · rhs` of two owned 2-D arrays.
///
/// Thin wrapper over ndarray's `dot`. ndarray documents that `dot` panics
/// when the inner dimensions are incompatible, so callers should validate
/// shapes first.
fn matmul<F: LyapFloat>(lhs: &Array2<F>, rhs: &Array2<F>) -> Array2<F> {
    lhs.dot(rhs)
}
/// Solves the continuous-time Lyapunov equation `A·X + X·Aᵀ + Q = 0` for `X`.
///
/// `a` and `q` must both be square and of the same order.
/// - A 0×0 system yields an empty 0×0 matrix.
/// - A 1×1 system is solved in closed form as `x = -q / (2·a)`.
/// - Larger systems are delegated to
///   `crate::matrix_functions::sylvester::solve_continuous_lyapunov`.
///
/// # Errors
/// * `ShapeError` if `a` or `q` is not square (checked in that order).
/// * `DimensionError` if `a` and `q` have different orders.
/// * `SingularMatrixError` in the 1×1 case when `|2·a[0,0]|` is below `1e-14`.
///   Only this near-singular case is rejected: a positive `a[0,0]` is still
///   solved, even though the error text mentions stability.
/// * Whatever the delegated solver returns for larger systems.
pub fn lyapunov_continuous<F: LyapFloat>(
    a: &ArrayView2<F>,
    q: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "lyapunov_continuous: A")?;
    let m = check_square(q, "lyapunov_continuous: Q")?;
    if n != m {
        return Err(LinalgError::DimensionError(format!(
            "lyapunov_continuous: A is {n}×{n} but Q is {m}×{m}"
        )));
    }

    match n {
        // Empty system: the solution is the empty matrix.
        0 => Ok(Array2::<F>::zeros((0, 0))),
        // Scalar system: 2·a·x + q = 0  =>  x = -q / (2·a).
        1 => {
            let pivot = F::from(2.0).unwrap_or(F::one()) * a[[0, 0]];
            let cutoff = F::from(1e-14).unwrap_or(F::epsilon());
            if pivot.abs() < cutoff {
                Err(LinalgError::SingularMatrixError(
                    "lyapunov_continuous: A is not stable (a[0,0] must be < 0)".to_string(),
                ))
            } else {
                Ok(Array2::from_elem((1, 1), -q[[0, 0]] / pivot))
            }
        }
        _ => crate::matrix_functions::sylvester::solve_continuous_lyapunov(a, q),
    }
}
/// Solves the discrete-time Lyapunov equation `A·X·Aᵀ − X + Q = 0` for `X`.
///
/// `a` and `q` must both be square and of the same order.
/// - A 0×0 system yields an empty 0×0 matrix.
/// - A 1×1 system is solved in closed form as `x = -q / (a² − 1)`.
/// - Larger systems are delegated to
///   `crate::matrix_functions::sylvester::solve_discrete_lyapunov`.
///
/// # Errors
/// * `ShapeError` if `a` or `q` is not square (checked in that order).
/// * `DimensionError` if `a` and `q` have different orders.
/// * `SingularMatrixError` in the 1×1 case when `|a[0,0]² − 1|` is below
///   `1e-14`. Only this near-singular case is rejected: `|a[0,0]| > 1` is
///   still solved, even though the error text mentions Schur stability.
/// * Whatever the delegated solver returns for larger systems.
pub fn lyapunov_discrete<F: LyapFloat>(
    a: &ArrayView2<F>,
    q: &ArrayView2<F>,
) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "lyapunov_discrete: A")?;
    let m = check_square(q, "lyapunov_discrete: Q")?;
    if n != m {
        return Err(LinalgError::DimensionError(format!(
            "lyapunov_discrete: A is {n}×{n} but Q is {m}×{m}"
        )));
    }

    match n {
        // Empty system: the solution is the empty matrix.
        0 => Ok(Array2::<F>::zeros((0, 0))),
        // Scalar system: a²·x − x + q = 0  =>  x = -q / (a² − 1).
        1 => {
            let diag = a[[0, 0]];
            let pivot = diag * diag - F::one();
            let cutoff = F::from(1e-14).unwrap_or(F::epsilon());
            if pivot.abs() < cutoff {
                Err(LinalgError::SingularMatrixError(
                    "lyapunov_discrete: A is not Schur-stable (|a[0,0]| must be < 1)".to_string(),
                ))
            } else {
                Ok(Array2::from_elem((1, 1), -q[[0, 0]] / pivot))
            }
        }
        _ => crate::matrix_functions::sylvester::solve_discrete_lyapunov(a, q),
    }
}
pub fn lyapunov_continuous_refine<F: LyapFloat>(
a: &ArrayView2<F>,
q: &ArrayView2<F>,
x0: &Array2<F>,
tol: F,
max_iter: usize,
) -> LinalgResult<Array2<F>> {
let a_owned = a.to_owned();
let q_owned = q.to_owned();
let mut x = x0.clone();
for _ in 0..max_iter {
let ax = matmul(&a_owned, &x);
let xat = matmul(&x, &a_owned.t().to_owned());
let residual = ax + xat + &q_owned;
let res_norm: F = residual.iter().map(|&v| v * v).sum::<F>().sqrt();
if res_norm <= tol {
return Ok(x);
}
let correction = lyapunov_continuous(a, &residual.view())?;
x = x - correction;
}
Ok(x)
}
/// Returns the order `n` of `a` when it is an `n × n` matrix.
///
/// `ctx` is a caller-supplied prefix that identifies the offending argument
/// in the error message.
///
/// # Errors
/// `ShapeError` (message `"{ctx}: not square"`) when the row and column
/// counts differ.
fn check_square<F: LyapFloat>(a: &ArrayView2<F>, ctx: &str) -> LinalgResult<usize> {
    let (rows, cols) = a.dim();
    if rows == cols {
        Ok(rows)
    } else {
        Err(LinalgError::ShapeError(format!("{ctx}: not square")))
    }
}
/// Unit tests for the Lyapunov solvers.
///
/// They cover the general (delegated) solves through residual checks, the
/// 0×0 and 1×1 fast paths, the error paths, and the refinement entry point.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    #[test]
    fn test_continuous_lyapunov_diagonal() {
        let a = array![[-1.0_f64, 0.0], [0.0, -2.0]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = lyapunov_continuous(&a.view(), &q.view()).expect("lyapunov_continuous failed");
        let expected = array![[0.5_f64, 0.0], [0.0, 0.25]];
        for i in 0..2 {
            for j in 0..2 {
                let diff = (x[[i, j]] - expected[[i, j]]).abs();
                assert!(
                    diff < 1e-8,
                    "Mismatch at ({i},{j}): got {}, expected {}",
                    x[[i, j]],
                    expected[[i, j]]
                );
            }
        }
    }
    #[test]
    fn test_continuous_lyapunov_residual_2x2() {
        let a = array![[-2.0_f64, 1.0], [0.0, -3.0]];
        let q = array![[2.0_f64, 1.0], [1.0, 2.0]];
        let x = lyapunov_continuous(&a.view(), &q.view()).expect("failed");
        let res = a.dot(&x) + x.dot(&a.t()) + &q;
        for &v in res.iter() {
            assert!(v.abs() < 1e-7, "Residual {v} too large");
        }
    }
    #[test]
    fn test_discrete_lyapunov_residual_2x2() {
        let a = array![[0.5_f64, 0.1], [0.0, 0.3]];
        let q = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let x = lyapunov_discrete(&a.view(), &q.view()).expect("failed");
        let res = a.dot(&x).dot(&a.t()) - &x + &q;
        for &v in res.iter() {
            assert!(v.abs() < 1e-6, "Discrete Lyapunov residual {v} too large");
        }
    }
    #[test]
    fn test_continuous_lyapunov_3x3_residual() {
        let a = array![
            [-3.0_f64, 1.0, 0.0],
            [0.0, -2.0, 0.5],
            [0.0, 0.0, -1.0]
        ];
        let q = array![
            [2.0_f64, 0.5, 0.0],
            [0.5, 1.0, 0.0],
            [0.0, 0.0, 3.0]
        ];
        let x = lyapunov_continuous(&a.view(), &q.view()).expect("3x3 failed");
        let res = a.dot(&x) + x.dot(&a.t()) + &q;
        for &v in res.iter() {
            assert!(v.abs() < 1e-6, "3x3 continuous residual {v}");
        }
    }
    #[test]
    fn test_discrete_lyapunov_3x3_residual() {
        let a = array![
            [0.4_f64, 0.1, 0.0],
            [0.0, 0.5, 0.2],
            [0.0, 0.0, 0.3]
        ];
        let q = array![
            [1.0_f64, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0]
        ];
        let x = lyapunov_discrete(&a.view(), &q.view()).expect("3x3 discrete failed");
        let res = a.dot(&x).dot(&a.t()) - &x + &q;
        for &v in res.iter() {
            assert!(v.abs() < 1e-5, "3x3 discrete residual {v}");
        }
    }
    #[test]
    fn test_empty_system_yields_empty_solution() {
        let empty = Array2::<f64>::zeros((0, 0));
        let xc = lyapunov_continuous(&empty.view(), &empty.view()).expect("0x0 continuous");
        let xd = lyapunov_discrete(&empty.view(), &empty.view()).expect("0x0 discrete");
        assert_eq!(xc.dim(), (0, 0));
        assert_eq!(xd.dim(), (0, 0));
    }
    #[test]
    fn test_scalar_closed_forms() {
        // Continuous: 2·(-2)·x + 4 = 0  =>  x = 1.
        let a = array![[-2.0_f64]];
        let q = array![[4.0_f64]];
        let x = lyapunov_continuous(&a.view(), &q.view()).expect("1x1 continuous failed");
        assert!((x[[0, 0]] - 1.0).abs() < 1e-12, "got {}", x[[0, 0]]);
        // Discrete: 0.25·x − x + 3 = 0  =>  x = 4.
        let a = array![[0.5_f64]];
        let q = array![[3.0_f64]];
        let x = lyapunov_discrete(&a.view(), &q.view()).expect("1x1 discrete failed");
        assert!((x[[0, 0]] - 4.0).abs() < 1e-12, "got {}", x[[0, 0]]);
    }
    #[test]
    fn test_scalar_singular_cases_are_rejected() {
        let q = array![[1.0_f64]];
        let zero = array![[0.0_f64]];
        let plus_one = array![[1.0_f64]];
        let minus_one = array![[-1.0_f64]];
        assert!(lyapunov_continuous(&zero.view(), &q.view()).is_err());
        assert!(lyapunov_discrete(&plus_one.view(), &q.view()).is_err());
        assert!(lyapunov_discrete(&minus_one.view(), &q.view()).is_err());
    }
    #[test]
    fn test_shape_and_dimension_errors() {
        let a2 = array![[0.5_f64, 0.0], [0.0, -0.5]];
        let q3 = Array2::<f64>::eye(3);
        let rect = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        // Order mismatch between A and Q.
        assert!(lyapunov_continuous(&a2.view(), &q3.view()).is_err());
        assert!(lyapunov_discrete(&a2.view(), &q3.view()).is_err());
        // Non-square A or Q.
        assert!(lyapunov_continuous(&rect.view(), &a2.view()).is_err());
        assert!(lyapunov_discrete(&a2.view(), &rect.view()).is_err());
    }
    #[test]
    fn test_refine_keeps_converged_solution() {
        let a = array![[-2.0_f64, 1.0], [0.0, -3.0]];
        let q = array![[2.0_f64, 1.0], [1.0, 2.0]];
        let x0 = lyapunov_continuous(&a.view(), &q.view()).expect("solve failed");
        // Zero iterations must hand back the starting guess untouched.
        let unchanged = lyapunov_continuous_refine(&a.view(), &q.view(), &x0, 1.0, 0)
            .expect("refine with max_iter = 0 failed");
        assert_eq!(unchanged, x0);
        // An already-accurate guess must stay accurate.
        let refined = lyapunov_continuous_refine(&a.view(), &q.view(), &x0, 1e-6, 10)
            .expect("refine failed");
        let res = a.dot(&refined) + refined.dot(&a.t()) + &q;
        for &v in res.iter() {
            assert!(v.abs() < 1e-6, "refined residual {v} too large");
        }
    }
}