use crate::DType;
use super::sylvester::sylvester_impl;
use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::error::{Error, Result};
use numr::ops::{MatmulOps, ScalarOps, ShapeOps, TensorOps, UnaryOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn continuous_lyapunov_impl<R, C>(client: &C, a: &Tensor<R>, q: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R>
+ ScalarOps<R>
+ ShapeOps<R>
+ UnaryOps<R>
+ MatmulOps<R>
+ UtilityOps<R>
+ LinearAlgebraAlgorithms<R>
+ RuntimeClient<R>,
{
let a_shape = a.shape();
if a_shape.len() != 2 || a_shape[0] != a_shape[1] {
return Err(Error::InvalidArgument {
arg: "a",
reason: "continuous Lyapunov: A must be square".into(),
});
}
let n = a_shape[0];
let q_shape = q.shape();
if q_shape != [n, n] {
return Err(Error::InvalidArgument {
arg: "q",
reason: format!("continuous Lyapunov: Q must be {n}×{n}, got {:?}", q_shape),
});
}
let at = a.transpose(0, 1)?.contiguous()?;
sylvester_impl(client, a, &at, q)
}
pub fn discrete_lyapunov_impl<R, C>(client: &C, a: &Tensor<R>, q: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R>
+ ScalarOps<R>
+ ShapeOps<R>
+ UnaryOps<R>
+ MatmulOps<R>
+ UtilityOps<R>
+ LinearAlgebraAlgorithms<R>
+ RuntimeClient<R>,
{
let a_shape = a.shape();
if a_shape.len() != 2 || a_shape[0] != a_shape[1] {
return Err(Error::InvalidArgument {
arg: "a",
reason: "discrete Lyapunov: A must be square".into(),
});
}
let n = a_shape[0];
let q_shape = q.shape();
if q_shape != [n, n] {
return Err(Error::InvalidArgument {
arg: "q",
reason: format!("discrete Lyapunov: Q must be {n}×{n}, got {:?}", q_shape),
});
}
let dtype = a.dtype();
let eye = client.eye(n, None, dtype)?;
let a_minus_i = client.sub(a, &eye)?;
let a_plus_i = client.add(a, &eye)?;
let a_minus_i_inv = LinearAlgebraAlgorithms::inverse(client, &a_minus_i)?;
let a_c = client.matmul(&a_minus_i_inv, &a_plus_i)?;
let a_minus_i_inv_t = a_minus_i_inv.transpose(0, 1)?.contiguous()?;
let q_c = client.matmul(&client.matmul(&a_minus_i_inv, q)?, &a_minus_i_inv_t)?;
let rhs = client.mul_scalar(&client.neg(&q_c)?, 2.0)?;
continuous_lyapunov_impl(client, &a_c, &rhs)
}