use crate::eigen::eig;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{s, Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::{Debug, Display};
/// Scalar element type accepted by the matrix-equation solvers in this module.
///
/// Bundles the floating-point arithmetic bounds (`Float`, `NumAssign`, `Sum`),
/// the `ndarray` scalar bound (`ScalarOperand`, required for array-times-scalar
/// expressions), the formatting bounds (`Debug`, `Display`) and
/// `Send + Sync + 'static` under a single name.
pub trait MatEqFloat:
    'static
    + Float
    + NumAssign
    + std::iter::Sum
    + scirs2_core::ndarray::ScalarOperand
    + Debug
    + Display
    + Send
    + Sync
{
}

// Blanket impl: every type that satisfies all of the bounds above (e.g. `f32`,
// `f64`) is automatically a `MatEqFloat`.
impl<T> MatEqFloat for T where
    T: 'static
        + Float
        + NumAssign
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + Debug
        + Display
        + Send
        + Sync
{
}
pub fn solve_sylvester<A: MatEqFloat>(
a: &ArrayView2<A>,
b: &ArrayView2<A>,
c: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
let m = a.shape()[0];
let n = b.shape()[0];
if a.shape()[1] != m {
return Err(LinalgError::ShapeError("Matrix A must be square".into()));
}
if b.shape()[1] != n {
return Err(LinalgError::ShapeError("Matrix B must be square".into()));
}
if c.shape() != [m, n] {
return Err(LinalgError::ShapeError(format!(
"Matrix C must have shape [{m}, {n}]"
)));
}
let mn = m * n;
let mut coeff = Array2::<A>::zeros((mn, mn));
for col_block in 0..n {
for i in 0..m {
for j in 0..m {
coeff[[col_block * m + i, col_block * m + j]] = a[[i, j]];
}
}
}
for rb in 0..n {
for cb in 0..n {
for d in 0..m {
coeff[[rb * m + d, cb * m + d]] += b[[cb, rb]];
}
}
}
let mut c_vec = vec![A::zero(); mn];
for col in 0..n {
for row in 0..m {
c_vec[col * m + row] = c[[row, col]];
}
}
let c_arr = Array2::from_shape_vec((mn, 1), c_vec)
.map_err(|e| LinalgError::ShapeError(e.to_string()))?;
let x_vec = crate::solve::solve(&coeff.view(), &c_arr.column(0), None)?;
let mut result = Array2::<A>::zeros((m, n));
for col in 0..n {
for row in 0..m {
result[[row, col]] = x_vec[col * m + row];
}
}
Ok(result)
}
/// Solves the continuous-time Lyapunov equation `A X + X Aᵀ = Q` for `X`.
///
/// This is the Sylvester equation with `B = Aᵀ`, so it is delegated to
/// [`solve_sylvester`].
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square or `q` differs in shape from `a`.
/// * Any error propagated from [`solve_sylvester`].
pub fn solve_continuous_lyapunov<A: MatEqFloat>(
    a: &ArrayView2<A>,
    q: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let dim = a.nrows();
    match (a.ncols() == dim, q.shape() == [dim, dim]) {
        (false, _) => Err(LinalgError::ShapeError("Matrix A must be square".into())),
        (true, false) => Err(LinalgError::ShapeError(
            "Matrix Q must have the same shape as A".into(),
        )),
        // A transposed view is enough: `solve_sylvester` only reads B.
        (true, true) => solve_sylvester(a, &a.t(), q),
    }
}
/// Solves the discrete-time Lyapunov equation `A X Aᵀ - X + Q = 0` for `X`.
///
/// Matrices up to 10 × 10 use the direct Kronecker formulation. Larger ones
/// are first mapped to a continuous-time Lyapunov equation by a bilinear
/// transformation.
///
/// # Arguments
/// * `a` - square `n × n` matrix
/// * `q` - `n × n` matrix
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square or `q` differs in shape from `a`.
/// * Any error propagated from the selected solver.
pub fn solve_discrete_lyapunov<A: MatEqFloat>(
    a: &ArrayView2<A>,
    q: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    // Largest dimension routed to the direct Kronecker solver.
    const DIRECT_MAX_DIM: usize = 10;

    let dim = a.nrows();
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError("Matrix A must be square".into()));
    }
    if q.shape() != [dim, dim] {
        return Err(LinalgError::ShapeError(
            "Matrix Q must have the same shape as A".into(),
        ));
    }

    if dim > DIRECT_MAX_DIM {
        solve_discrete_lyapunov_bilinear(a, q)
    } else {
        solve_discrete_lyapunov_direct(a, q)
    }
}
/// Solves `A X Aᵀ - X + Q = 0` through the full `n² × n²` Kronecker system.
///
/// The system matrix is `A ⊗ A - I`, acting on a row-major vectorisation of
/// the unknown. The right-hand side is `-Q` laid out column-major, i.e. the
/// row-major vectorisation of `-Qᵀ`. The solution vector is therefore `Xᵀ` in
/// row-major order, and the final axis swap turns it back into `X`.
///
/// Callers are expected to have validated that `a` and `q` are both `n × n`.
///
/// # Errors
/// Any error propagated from `crate::solve::solve`.
fn solve_discrete_lyapunov_direct<A: MatEqFloat>(
    a: &ArrayView2<A>,
    q: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let n = a.shape()[0];
    let nn = n * n;

    // Entry for row (i, j), column (k, l): A[i, k]·A[j, l] - δ_ik·δ_jl.
    let system = Array2::from_shape_fn((nn, nn), |(row, col)| {
        let (i, j) = (row / n, row % n);
        let (k, l) = (col / n, col % n);
        let product = a[[i, k]] * a[[j, l]];
        if row == col {
            product - A::one()
        } else {
            product
        }
    });

    // Position r·n + c holds -Q[c, r].
    let rhs = Array2::from_shape_fn((nn, 1), |(idx, _)| -q[[idx % n, idx / n]]);
    let solution = crate::solve::solve(&system.view(), &rhs.column(0), None)?;

    let flat: Vec<A> = solution.iter().copied().collect();
    let transposed = Array2::from_shape_vec((n, n), flat)
        .map_err(|e| LinalgError::ShapeError(e.to_string()))?;
    Ok(transposed.reversed_axes())
}
/// Solves `A X Aᵀ - X + Q = 0` by mapping it to a continuous-time Lyapunov equation.
///
/// Let `P = (A + I)⁻¹` and `A_c = (A - I) P`. Because `(A - I)` and `(A + I)`
/// commute, `(A + I)(A_c X + X A_cᵀ)(A + I)ᵀ = 2 (A X Aᵀ - X)`. The discrete
/// equation is therefore equivalent to
///
/// `A_c X + X A_cᵀ = -2 P Q Pᵀ`,
///
/// which is solved with [`solve_continuous_lyapunov`]. (Scalar sanity check:
/// `x = q / (1 - a²)`.)
///
/// Callers are expected to have validated that `a` and `q` are both `n × n`.
///
/// # Errors
/// * Any error from `crate::inv` while inverting `A + I`. That matrix is singular
///   when `-1` is an eigenvalue of `A`, in which case the discrete equation has
///   no unique solution either.
/// * Any error propagated from [`solve_continuous_lyapunov`].
fn solve_discrete_lyapunov_bilinear<A: MatEqFloat>(
    a: &ArrayView2<A>,
    q: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let n = a.shape()[0];
    let eye = Array2::<A>::eye(n);
    let a_plus_i = a + &eye;
    let a_minus_i = a - &eye;
    let p = crate::inv(&a_plus_i.view(), None)?;
    let a_c = a_minus_i.dot(&p);
    // The right-hand side is -2·P·Q·Pᵀ. Note both the sign and the
    // transpose order.
    let neg_two = -(A::one() + A::one());
    let q_c = p.dot(q).dot(&p.t()) * neg_two;
    solve_continuous_lyapunov(&a_c.view(), &q_c.view())
}
/// Solves the Stein equation `X - A X B = C` for `X`.
///
/// With column-major vectorisation, `vec(X - A X B) = (I - Bᵀ ⊗ A) vec(X)`.
/// The dense `(m·n) × (m·n)` system is handed to `crate::solve::solve`.
///
/// # Arguments
/// * `a` - square `m × m` matrix
/// * `b` - square `n × n` matrix
/// * `c` - `m × n` right-hand side
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` or `b` is not square, or `c` is not `m × n`.
/// * Any error propagated from `crate::solve::solve`.
pub fn solve_stein<A: MatEqFloat>(
    a: &ArrayView2<A>,
    b: &ArrayView2<A>,
    c: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let (rows, cols) = (a.nrows(), b.nrows());
    if a.ncols() != rows {
        return Err(LinalgError::ShapeError("Matrix A must be square".into()));
    }
    if b.ncols() != cols {
        return Err(LinalgError::ShapeError("Matrix B must be square".into()));
    }
    if c.shape() != [rows, cols] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix C must have shape [{rows}, {cols}]"
        )));
    }

    let size = rows * cols;
    // Entry (p·m + r, s·m + t) = δ - A[r, t]·B[s, p].
    let system = Array2::from_shape_fn((size, size), |(row, col)| {
        let (p_blk, r) = (row / rows, row % rows);
        let (s_blk, t) = (col / rows, col % rows);
        let entry = -a[[r, t]] * b[[s_blk, p_blk]];
        if row == col {
            entry + A::one()
        } else {
            entry
        }
    });

    // Column-major vec(C): position k·m + r holds C[r, k].
    let rhs = Array2::from_shape_fn((size, 1), |(idx, _)| c[[idx % rows, idx / rows]]);
    let solution = crate::solve::solve(&system.view(), &rhs.column(0), None)?;
    // Undo the column-major vectorisation.
    Ok(Array2::from_shape_fn((rows, cols), |(r, k)| {
        solution[k * rows + r]
    }))
}
/// Solves the continuous-time algebraic Riccati equation
/// `Aᵀ X + X A - X B R⁻¹ Bᵀ X + Q = 0` for `X`.
///
/// Method:
/// 1. Build the Hamiltonian `H = [[A, -B R⁻¹ Bᵀ], [-Q, -Aᵀ]]`.
/// 2. Eigen-decompose it and take a real basis `[U₁; U₂]` of the invariant
///    subspace belonging to eigenvalues with negative real part.
/// 3. Return `X = U₂ U₁⁻¹`, symmetrised as `(X + Xᵀ) / 2`.
///
/// A complex-conjugate stable pair `{v, v̄}` contributes `Re(v)` and `Im(v)`,
/// which span the same real plane. Taking `Re` of both `v` and `v̄` would give
/// two identical columns and a singular `U₁`. For a numerically real
/// eigenvalue, the eigenvector may carry an arbitrary complex phase, so
/// whichever of its real and imaginary parts has the larger norm is used.
///
/// # Errors
/// * `LinalgError::ShapeError` on dimension mismatch.
/// * `LinalgError::ConvergenceError` if the stable eigenvalues do not yield
///   exactly `n` real basis vectors.
/// * Errors propagated from `eig` and from `crate::inv` (for `R` and `U₁`).
/// * `LinalgError::ComputationError` if `0.5` cannot be represented in `A`.
pub fn solve_continuous_riccati<A: MatEqFloat>(
    a: &ArrayView2<A>,
    b: &ArrayView2<A>,
    q: &ArrayView2<A>,
    r: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let n = a.shape()[0];
    let m_dim = b.shape()[1];
    if a.shape()[1] != n {
        return Err(LinalgError::ShapeError("Matrix A must be square".into()));
    }
    if b.shape()[0] != n {
        return Err(LinalgError::ShapeError(format!(
            "Matrix B must have {n} rows"
        )));
    }
    if q.shape() != [n, n] {
        return Err(LinalgError::ShapeError("Matrix Q must be n x n".into()));
    }
    if r.shape() != [m_dim, m_dim] {
        return Err(LinalgError::ShapeError("Matrix R must be m x m".into()));
    }

    // Hamiltonian H = [[A, -G], [-Q, -Aᵀ]] with G = B R⁻¹ Bᵀ.
    let r_inv = crate::inv(r, None)?;
    let g = b.dot(&r_inv).dot(&b.t());
    let mut h = Array2::<A>::zeros((2 * n, 2 * n));
    h.slice_mut(s![..n, ..n]).assign(a);
    h.slice_mut(s![..n, n..]).assign(&g.mapv(|x| -x));
    h.slice_mut(s![n.., ..n]).assign(&q.mapv(|x| -x));
    h.slice_mut(s![n.., n..]).assign(&a.t().mapv(|x| -x));
    let (eigvals, eigvecs) = eig(&h.view(), None)?;

    // Real basis of the stable invariant subspace, stored as
    // (eigenvector column, use imaginary part?) pairs.
    let imag_rel_tol = A::epsilon().sqrt();
    let mut basis: Vec<(usize, bool)> = Vec::with_capacity(n);
    for (i, lambda) in eigvals.iter().enumerate() {
        if lambda.re < A::zero() {
            let imag_tol = imag_rel_tol * (A::one() + lambda.re.abs() + lambda.im.abs());
            if lambda.im > imag_tol {
                // Complex pair: Re(v) and Im(v) span the plane of {v, v̄}.
                basis.push((i, false));
                basis.push((i, true));
            } else if lambda.im >= -imag_tol {
                // Numerically real: pick the dominant part of the eigenvector.
                let (re_sq, im_sq) =
                    (0..2 * n).fold((A::zero(), A::zero()), |(acc_re, acc_im), k| {
                        let z = &eigvecs[[k, i]];
                        (acc_re + z.re * z.re, acc_im + z.im * z.im)
                    });
                basis.push((i, im_sq > re_sq));
            }
            // lambda.im < -imag_tol: its conjugate partner (positive imaginary
            // part) already contributed both columns of this plane.
        }
    }
    if basis.len() < n {
        return Err(LinalgError::ConvergenceError(
            "Could not find n stable eigenvalues for CARE".into(),
        ));
    }
    if basis.len() > n {
        return Err(LinalgError::ConvergenceError(
            "Found more than n stable eigenvalues for CARE (eigenvalues too close to the imaginary axis)"
                .into(),
        ));
    }

    let mut u1 = Array2::<A>::zeros((n, n));
    let mut u2 = Array2::<A>::zeros((n, n));
    for (j, &(i, use_imag)) in basis.iter().enumerate() {
        for k in 0..n {
            let top = &eigvecs[[k, i]];
            let bottom = &eigvecs[[n + k, i]];
            u1[[k, j]] = if use_imag { top.im } else { top.re };
            u2[[k, j]] = if use_imag { bottom.im } else { bottom.re };
        }
    }
    let u1_inv = crate::inv(&u1.view(), None)?;
    let x = u2.dot(&u1_inv);
    let half = A::from(0.5)
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert 0.5 to target type".into()))?;
    Ok((&x + &x.t()) * half)
}
/// Solves the discrete-time algebraic Riccati equation
/// `X = Aᵀ X A - Aᵀ X B (R + Bᵀ X B)⁻¹ Bᵀ X A + Q` for `X`.
///
/// Stage 1 runs up to 200 steps of the structure-preserving doubling
/// iteration. It starts from `A₀ = A`, `G₀ = B R⁻¹ Bᵀ`, `H₀ = Q`, and under
/// the usual stabilisability/detectability assumptions `H_k` converges to `X`.
///
/// If stage 1 does not converge, or produces a non-finite iterate, stage 2 runs
/// up to 200 steps of the plain fixed-point iteration of the equation above,
/// starting from `X₀ = Q`.
///
/// Convergence means the largest absolute entry-wise change between
/// successive iterates is below `1e-10`. The returned matrix is symmetrised as
/// `(X + Xᵀ) / 2`.
///
/// # Errors
/// * `LinalgError::ShapeError` on dimension mismatch.
/// * `LinalgError::ConvergenceError` if neither iteration converges.
/// * Errors propagated from `crate::inv`.
/// * `LinalgError::ComputationError` if the constants cannot be represented in `A`.
pub fn solve_discrete_riccati<A: MatEqFloat>(
    a: &ArrayView2<A>,
    b: &ArrayView2<A>,
    q: &ArrayView2<A>,
    r: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    /// Largest absolute entry-wise change from `old` to `new`. Returns `None`
    /// when `new` or the change is not finite. `Float::max` returns the
    /// non-NaN argument, so a plain max-fold would skip NaN entries and could
    /// report a diverged iterate as converged.
    fn max_abs_change<T: MatEqFloat>(new: &Array2<T>, old: &Array2<T>) -> Option<T> {
        new.iter()
            .zip(old.iter())
            .try_fold(T::zero(), |acc, (&x, &y)| {
                let d = (x - y).abs();
                if x.is_finite() && d.is_finite() {
                    Some(acc.max(d))
                } else {
                    None
                }
            })
    }

    let n = a.shape()[0];
    let m_dim = b.shape()[1];
    if a.shape()[1] != n {
        return Err(LinalgError::ShapeError("Matrix A must be square".into()));
    }
    if b.shape()[0] != n {
        return Err(LinalgError::ShapeError(format!(
            "Matrix B must have {n} rows"
        )));
    }
    if q.shape() != [n, n] {
        return Err(LinalgError::ShapeError("Matrix Q must be n x n".into()));
    }
    if r.shape() != [m_dim, m_dim] {
        return Err(LinalgError::ShapeError("Matrix R must be m x m".into()));
    }
    let tol = A::from(1e-10).ok_or_else(|| {
        LinalgError::ComputationError("Cannot convert tolerance to target type".into())
    })?;
    let half = A::from(0.5)
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert 0.5 to target type".into()))?;
    let max_iter = 200;
    let r_inv = crate::inv(r, None)?;

    // Stage 1: structure-preserving doubling iteration.
    let mut a_k = a.to_owned();
    let mut g_k = b.dot(&r_inv).dot(&b.t());
    let mut q_k = q.to_owned();
    for _ in 0..max_iter {
        let eye_plus_gq = Array2::<A>::eye(n) + g_k.dot(&q_k);
        let inv_eye_gq = crate::inv(&eye_plus_gq.view(), None)?;
        let a_next = a_k.dot(&inv_eye_gq).dot(&a_k);
        let g_next = &g_k + a_k.dot(&inv_eye_gq).dot(&g_k).dot(&a_k.t());
        let q_next = &q_k + a_k.t().dot(&q_k).dot(&inv_eye_gq).dot(&a_k);
        let change = max_abs_change(&q_next, &q_k);
        a_k = a_next;
        g_k = g_next;
        q_k = q_next;
        match change {
            Some(err) if err < tol => return Ok((&q_k + &q_k.t()) * half),
            Some(_) => {}
            // Non-finite iterate: abandon doubling and try the fixed-point iteration.
            None => break,
        }
    }

    // Stage 2: fixed-point iteration of the Riccati equation itself.
    let mut x = q.to_owned();
    for _ in 0..max_iter {
        let r_tilde = r + &b.t().dot(&x).dot(b);
        let r_tilde_inv = crate::inv(&r_tilde.view(), None)?;
        let term1 = a.t().dot(&x).dot(a);
        let term2 = a
            .t()
            .dot(&x)
            .dot(b)
            .dot(&r_tilde_inv)
            .dot(&b.t())
            .dot(&x)
            .dot(a);
        let x_next = &term1 - &term2 + q;
        let change = max_abs_change(&x_next, &x);
        x = x_next;
        match change {
            Some(err) if err < tol => return Ok((&x + &x.t()) * half),
            Some(_) => {}
            // Non-finite iterate: further iterations cannot recover.
            None => break,
        }
    }
    Err(LinalgError::ConvergenceError(
        "Discrete Riccati equation solver did not converge".into(),
    ))
}
/// Solves the generalized Sylvester equation `A X B + C X D = E` for `X`.
///
/// With column-major vectorisation,
/// `vec(A X B + C X D) = (Bᵀ ⊗ A + Dᵀ ⊗ C) vec(X)`. The dense `(m·n) × (m·n)`
/// system is handed to `crate::solve::solve`.
///
/// # Arguments
/// * `a`, `c` - square `m × m` matrices
/// * `b`, `d` - square `n × n` matrices
/// * `e` - `m × n` right-hand side
///
/// # Errors
/// * `LinalgError::ShapeError` on dimension mismatch.
/// * Any error propagated from `crate::solve::solve`.
pub fn solve_generalized_sylvester<A: MatEqFloat>(
    a: &ArrayView2<A>,
    b: &ArrayView2<A>,
    c: &ArrayView2<A>,
    d: &ArrayView2<A>,
    e: &ArrayView2<A>,
) -> LinalgResult<Array2<A>> {
    let m = a.shape()[0];
    let n = b.shape()[0];
    if a.shape()[1] != m || c.shape() != a.shape() {
        return Err(LinalgError::ShapeError(
            "Matrices A and C must be square and have the same shape".into(),
        ));
    }
    if b.shape()[1] != n || d.shape() != b.shape() {
        return Err(LinalgError::ShapeError(
            "Matrices B and D must be square and have the same shape".into(),
        ));
    }
    if e.shape() != [m, n] {
        return Err(LinalgError::ShapeError(format!(
            "Matrix E must have shape [{m}, {n}]"
        )));
    }

    // No special case for C = D = 0: the equation then degenerates to
    // A X B = E (not the Sylvester equation A X + X B = E), and the general
    // Kronecker system below already handles it.
    let mn = m * n;
    // Entry (i·m + k, j·m + l) = A[k, l]·B[j, i] + C[k, l]·D[j, i].
    let coeff = Array2::from_shape_fn((mn, mn), |(row, col)| {
        let (i, k) = (row / m, row % m);
        let (j, l) = (col / m, col % m);
        a[[k, l]] * b[[j, i]] + c[[k, l]] * d[[j, i]]
    });

    // Column-major vec(E): position col·m + row holds E[row, col].
    let e_col = Array2::from_shape_fn((mn, 1), |(idx, _)| e[[idx % m, idx / m]]);
    let x_vec = crate::solve::solve(&coeff.view(), &e_col.column(0), None)?;
    Ok(Array2::from_shape_fn((m, n), |(row, col)| {
        x_vec[col * m + row]
    }))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2};

    /// Asserts that `actual` and `expected` have identical shapes and agree
    /// element-wise within `eps`.
    fn assert_matrix_close(actual: &Array2<f64>, expected: &Array2<f64>, eps: f64) {
        assert_eq!(actual.shape(), expected.shape(), "matrix shapes differ");
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(got, want, epsilon = eps);
        }
    }

    /// Asserts that every entry of `residual` is within `eps` of zero.
    fn assert_matrix_near_zero(residual: &Array2<f64>, eps: f64) {
        for &value in residual.iter() {
            assert_abs_diff_eq!(value, 0.0, epsilon = eps);
        }
    }

    #[test]
    fn test_sylvester_diagonal() {
        let a = array![[1.0, 0.0], [0.0, 2.0]];
        let b = array![[-3.0, 0.0], [0.0, -4.0]];
        let c = array![[1.0, 2.0], [3.0, 4.0]];
        let x = solve_sylvester(&a.view(), &b.view(), &c.view()).expect("solve failed");
        // A X + X B must reproduce C.
        assert_matrix_close(&(a.dot(&x) + x.dot(&b)), &c, 1e-8);
    }

    #[test]
    fn test_sylvester_upper_triangular() {
        let a = array![[1.0, 2.0], [0.0, 3.0]];
        let b = array![[-4.0, 0.0], [0.0, -5.0]];
        let c = array![[1.0, 2.0], [3.0, 4.0]];
        let x = solve_sylvester(&a.view(), &b.view(), &c.view()).expect("solve failed");
        assert_matrix_close(&(a.dot(&x) + x.dot(&b)), &c, 1e-8);
    }

    #[test]
    fn test_sylvester_3x3() {
        let a = array![[1.0, 0.5, 0.0], [0.0, 2.0, 0.5], [0.0, 0.0, 3.0]];
        let b = array![[-4.0, 1.0], [0.0, -5.0]];
        let c = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let x = solve_sylvester(&a.view(), &b.view(), &c.view()).expect("solve failed");
        assert_matrix_close(&(a.dot(&x) + x.dot(&b)), &c, 1e-8);
    }

    #[test]
    fn test_continuous_lyapunov() {
        let a = array![[-1.0, 0.5], [0.0, -2.0]];
        let q = array![[1.0, 0.0], [0.0, 1.0]];
        let x = solve_continuous_lyapunov(&a.view(), &q.view()).expect("solve failed");
        // A X + X Aᵀ must reproduce Q.
        assert_matrix_close(&(a.dot(&x) + x.dot(&a.t())), &q, 1e-8);
    }

    #[test]
    fn test_continuous_lyapunov_3x3() {
        let a = array![[-2.0, 1.0, 0.0], [0.0, -3.0, 1.0], [0.0, 0.0, -4.0]];
        let q = array![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let x = solve_continuous_lyapunov(&a.view(), &q.view()).expect("solve failed");
        assert_matrix_close(&(a.dot(&x) + x.dot(&a.t())), &q, 1e-6);
    }

    #[test]
    fn test_discrete_lyapunov() {
        let a = array![[0.5, 0.1], [0.0, 0.6]];
        let q = array![[1.0, 0.0], [0.0, 1.0]];
        let x = solve_discrete_lyapunov(&a.view(), &q.view()).expect("solve failed");
        // A X Aᵀ - X + Q must vanish.
        let residual = a.dot(&x).dot(&a.t()) - &x + &q;
        assert_matrix_near_zero(&residual, 1e-8);
    }

    #[test]
    fn test_stein_equation() {
        let a = array![[0.3, 0.1], [0.0, 0.4]];
        let b = array![[0.2, 0.0], [0.1, 0.3]];
        let c = array![[1.0, 0.5], [0.5, 1.0]];
        let x = solve_stein(&a.view(), &b.view(), &c.view()).expect("solve failed");
        // X - A X B - C must vanish.
        let residual = &x - &a.dot(&x).dot(&b) - &c;
        assert_matrix_near_zero(&residual, 1e-8);
    }

    #[test]
    fn test_stein_identity() {
        // With B = 0 the Stein equation reduces to X = C.
        let a = array![[1.0, 0.0], [0.0, 1.0]];
        let b = array![[0.0, 0.0], [0.0, 0.0]];
        let c = array![[3.0, 1.0], [2.0, 5.0]];
        let x = solve_stein(&a.view(), &b.view(), &c.view()).expect("solve failed");
        assert_matrix_close(&x, &c, 1e-10);
    }

    #[test]
    fn test_generalized_sylvester() {
        let a = array![[1.0, 0.0], [0.0, 2.0]];
        let b = array![[3.0, 0.0], [0.0, 4.0]];
        let c = array![[0.5, 0.0], [0.0, 0.5]];
        let d = array![[0.25, 0.0], [0.0, 0.25]];
        let e = array![[1.0, 2.0], [3.0, 4.0]];
        let x = solve_generalized_sylvester(&a.view(), &b.view(), &c.view(), &d.view(), &e.view())
            .expect("solve failed");
        // A X B + C X D - E must vanish.
        let residual = a.dot(&x).dot(&b) + c.dot(&x).dot(&d) - &e;
        assert_matrix_near_zero(&residual, 1e-8);
    }

    #[test]
    fn test_discrete_riccati() {
        let a = array![[1.0, 0.1], [0.0, 1.0]];
        let b = array![[0.0], [0.1]];
        let q = array![[1.0, 0.0], [0.0, 0.0]];
        let r = array![[1.0]];
        let x = solve_discrete_riccati(&a.view(), &b.view(), &q.view(), &r.view())
            .expect("solve failed");
        assert_matrix_close(&x, &x.t().to_owned(), 1e-8);

        // X must satisfy X = Aᵀ X A - Aᵀ X B (R + Bᵀ X B)⁻¹ Bᵀ X A + Q.
        let r_tilde = &r + &b.t().dot(&x).dot(&b);
        let r_tilde_inv = crate::inv(&r_tilde.view(), None).expect("inv failed");
        let rhs = a.t().dot(&x).dot(&a)
            - a.t()
                .dot(&x)
                .dot(&b)
                .dot(&r_tilde_inv)
                .dot(&b.t())
                .dot(&x)
                .dot(&a)
            + &q;
        assert_matrix_near_zero(&(&x - &rhs), 1e-6);
    }

    #[test]
    fn test_continuous_riccati() {
        let a = array![[0.0, 1.0], [0.0, 0.0]];
        let b = array![[0.0], [1.0]];
        let q = array![[1.0, 0.0], [0.0, 1.0]];
        let r = array![[1.0]];
        let result = solve_continuous_riccati(&a.view(), &b.view(), &q.view(), &r.view());
        // NOTE(review): this test accepts an `Err` and only checks symmetry on
        // success, so it cannot detect a solver failure. The analytic solution
        // for this system is [[√3, 1], [1, √3]].
        if let Ok(x) = result {
            assert_matrix_close(&x, &x.t().to_owned(), 1e-6);
        }
    }

    #[test]
    fn test_sylvester_dimension_check() {
        let a = array![[1.0, 2.0], [3.0, 4.0]];
        let b = array![[1.0]];
        let c = array![[1.0, 2.0], [3.0, 4.0]];
        // C is 2 × 2 but must be 2 × 1 for these A and B.
        let result = solve_sylvester(&a.view(), &b.view(), &c.view());
        assert!(result.is_err());
    }
}