use crate::DType;
use numr::algorithm::linalg::{LinearAlgebraAlgorithms, MatrixNormOrder};
use numr::error::{Error, Result};
use numr::ops::{MatmulOps, ScalarOps, ShapeOps, TensorOps, UnaryOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Convergence threshold shared by the iterative solvers in this module.
///
/// The sign-function iterations (CARE/DARE) compare it against the relative
/// Frobenius size of one Newton step. The Lyapunov doubling iteration compares it
/// against the absolute Frobenius norm of the repeatedly squared `A`.
const TOL: f64 = 1.0e-12;
/// Upper bound on the number of iterations any solver in this module performs.
const MAX_ITER: usize = 100;
/// Solves the continuous-time algebraic Riccati equation (CARE)
///
/// `AᵀX + XA − X·S·X + Q = 0`, where `S = B·R⁻¹·Bᵀ`,
///
/// using the Newton iteration for the matrix sign function.
///
/// The Hamiltonian `H = [[A, −S], [−Q, −Aᵀ]]` is iterated as
/// `Zₖ₊₁ = (Zₖ + Zₖ⁻¹) / 2`, which tends to `sign(H)`. The matrix
/// `P = (I − sign(H)) / 2` projects onto the stable invariant subspace of `H`.
/// When the stabilizing solution exists, that subspace is `span [I; X]`, so the
/// leading `n` columns `[P₁₁; P₂₁]` give `X = P₂₁·P₁₁⁻¹`. This assumes `P₁₁` is
/// invertible.
///
/// # Arguments
/// * `client` – runtime client that executes the tensor ops.
/// * `a` – `n×n` matrix.
/// * `b` – `n×m` matrix.
/// * `q` – `n×n` matrix (presumably symmetric — TODO confirm with callers).
/// * `r` – `m×m` matrix; it is inverted, so it must be nonsingular.
///
/// # Returns
/// The symmetrised solution `(X + Xᵀ) / 2`.
///
/// # Errors
/// Returns `Error::InvalidArgument` when `a`, `b`, `q` or `r` have incompatible
/// shapes. Errors from the underlying ops are propagated, for example from
/// `inverse` on `R`, on a sign iterate, or on `P₁₁`. Whether a singular input
/// yields an error depends on numr.
///
/// If the relative Frobenius step does not fall below `TOL` within `MAX_ITER`
/// iterations, the last iterate is used without reporting an error.
pub fn solve_care_iterative_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
    q: &Tensor<R>,
    r: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + UnaryOps<R>
        + MatmulOps<R>
        + UtilityOps<R>
        + LinearAlgebraAlgorithms<R>
        + RuntimeClient<R>,
{
    // ---- Shape validation: A is n×n, B is n×m, Q is n×n, R is m×m. ----
    let a_dims = a.shape();
    let a_is_square = a_dims.len() == 2 && a_dims[0] == a_dims[1];
    if !a_is_square {
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: "CARE iterative: A must be square".into(),
        });
    }
    let n = a_dims[0];

    let b_dims = b.shape();
    let b_is_n_by_m = b_dims.len() == 2 && b_dims[0] == n;
    if !b_is_n_by_m {
        return Err(Error::InvalidArgument {
            arg: "b",
            reason: format!("CARE iterative: B must be {n}×m"),
        });
    }
    let n_inputs = b_dims[1];

    if q.shape() != [n, n] {
        return Err(Error::InvalidArgument {
            arg: "q",
            reason: format!("CARE iterative: Q must be {n}×{n}"),
        });
    }
    if r.shape() != [n_inputs, n_inputs] {
        return Err(Error::InvalidArgument {
            arg: "r",
            reason: format!("CARE iterative: R must be {n_inputs}×{n_inputs}"),
        });
    }

    // ---- Hamiltonian H = [[A, -S], [-Q, -Aᵀ]] with S = B·R⁻¹·Bᵀ. ----
    let r_inv = LinearAlgebraAlgorithms::inverse(client, r)?;
    let b_t = b.transpose(0, 1)?.contiguous()?;
    let b_r_inv = client.matmul(b, &r_inv)?;
    let s = client.matmul(&b_r_inv, &b_t)?;
    let a_t = a.transpose(0, 1)?.contiguous()?;
    let neg_s = client.neg(&s)?;
    let neg_q = client.neg(q)?;
    let neg_a_t = client.neg(&a_t)?;
    let hamiltonian = {
        let upper = client.cat(&[a, &neg_s], 1)?;
        let lower = client.cat(&[&neg_q, &neg_a_t], 1)?;
        client.cat(&[&upper, &lower], 0)?
    };

    // ---- Newton iteration Z ← (Z + Z⁻¹)/2, converging to sign(H). ----
    let mut sign = hamiltonian;
    for _ in 0..MAX_ITER {
        let sign_inv = LinearAlgebraAlgorithms::inverse(client, &sign)?;
        let sum = client.add(&sign, &sign_inv)?;
        let next = client.mul_scalar(&sum, 0.5)?;
        let step = client.sub(&next, &sign)?;
        let step_norm =
            LinearAlgebraAlgorithms::matrix_norm(client, &step, MatrixNormOrder::Frobenius)?;
        let sign_norm =
            LinearAlgebraAlgorithms::matrix_norm(client, &sign, MatrixNormOrder::Frobenius)?;
        let step_size: f64 = step_norm.to_vec()[0];
        let scale: f64 = sign_norm.to_vec()[0];
        sign = next;
        // The step is measured relative to the previous iterate's norm.
        let converged = scale > 0.0 && step_size / scale < TOL;
        if converged {
            break;
        }
    }

    // ---- Stable-subspace projector P = (I − sign(H))/2; X = P₂₁·P₁₁⁻¹. ----
    let eye_2n = client.eye(2 * n, None, a.dtype())?;
    let i_minus_sign = client.sub(&eye_2n, &sign)?;
    let projector = client.mul_scalar(&i_minus_sign, 0.5)?;
    let p11 = projector.narrow(0, 0, n)?.narrow(1, 0, n)?.contiguous()?;
    let p21 = projector.narrow(0, n, n)?.narrow(1, 0, n)?.contiguous()?;
    let p11_inv = LinearAlgebraAlgorithms::inverse(client, &p11)?;
    let x = client.matmul(&p21, &p11_inv)?;

    // Averaging with Xᵀ removes round-off asymmetry from the computed solution.
    let x_t = x.transpose(0, 1)?.contiguous()?;
    let x_plus_x_t = client.add(&x, &x_t)?;
    let x_sym = client.mul_scalar(&x_plus_x_t, 0.5)?;
    Ok(x_sym)
}
/// Solves the discrete-time algebraic Riccati equation (DARE) in the form
///
/// `X = Aᵀ·X·(I + S·X)⁻¹·A + Q`, where `S = B·R⁻¹·Bᵀ`.
///
/// By the Woodbury identity this equals
/// `X = AᵀXA − AᵀXB(R + BᵀXB)⁻¹BᵀXA + Q`.
///
/// # Method
/// - The symplectic pencil `L − λM` is built with `L = [[A, 0], [−Q, I]]` and
///   `M = [[I, S], [0, Aᵀ]]`, then reduced to `G = M⁻¹·L`.
/// - `M` is block upper-triangular with `Aᵀ` on its diagonal, so `A` must be
///   nonsingular.
/// - The Cayley transform `(G − I)⁻¹(G + I)` maps eigenvalues inside the unit
///   circle to the open left half-plane.
/// - Its matrix sign function, computed by the Newton iteration
///   `Zₖ₊₁ = (Zₖ + Zₖ⁻¹) / 2`, gives the stable-subspace projector
///   `P = (I − sign) / 2`.
/// - With `[P₁₁; P₂₁]` the leading `n` columns of `P`, `X = P₂₁·P₁₁⁻¹`. This
///   assumes a stabilizing solution exists and `P₁₁` is invertible.
///
/// # Arguments
/// * `client` – runtime client that executes the tensor ops.
/// * `a` – `n×n` matrix (must be nonsingular, see above).
/// * `b` – `n×m` matrix.
/// * `q` – `n×n` matrix (presumably symmetric — TODO confirm with callers).
/// * `r` – `m×m` matrix; it is inverted, so it must be nonsingular.
///
/// # Returns
/// The symmetrised solution `(X + Xᵀ) / 2`.
///
/// # Errors
/// Returns `Error::InvalidArgument` when `a`, `b`, `q` or `r` have incompatible
/// shapes. Errors from the underlying ops are propagated, for example from
/// `inverse` on `R`, `M`, `G − I`, a sign iterate, or `P₁₁`.
///
/// If the relative Frobenius step does not fall below `TOL` within `MAX_ITER`
/// iterations, the last iterate is used without reporting an error.
pub fn solve_dare_iterative_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
    q: &Tensor<R>,
    r: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + UnaryOps<R>
        + MatmulOps<R>
        + UtilityOps<R>
        + LinearAlgebraAlgorithms<R>
        + RuntimeClient<R>,
{
    // ---- Shape validation: A is n×n, B is n×m, Q is n×n, R is m×m. ----
    let a_dims = a.shape();
    let a_is_square = a_dims.len() == 2 && a_dims[0] == a_dims[1];
    if !a_is_square {
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: "DARE iterative: A must be square".into(),
        });
    }
    let n = a_dims[0];

    let b_dims = b.shape();
    let b_is_n_by_m = b_dims.len() == 2 && b_dims[0] == n;
    if !b_is_n_by_m {
        return Err(Error::InvalidArgument {
            arg: "b",
            reason: format!("DARE iterative: B must be {n}×m"),
        });
    }
    let n_inputs = b_dims[1];

    if q.shape() != [n, n] {
        return Err(Error::InvalidArgument {
            arg: "q",
            reason: format!("DARE iterative: Q must be {n}×{n}"),
        });
    }
    if r.shape() != [n_inputs, n_inputs] {
        return Err(Error::InvalidArgument {
            arg: "r",
            reason: format!("DARE iterative: R must be {n_inputs}×{n_inputs}"),
        });
    }

    let dtype = a.dtype();

    // ---- S = B·R⁻¹·Bᵀ. ----
    let r_inv = LinearAlgebraAlgorithms::inverse(client, r)?;
    let b_t = b.transpose(0, 1)?.contiguous()?;
    let b_r_inv = client.matmul(b, &r_inv)?;
    let s = client.matmul(&b_r_inv, &b_t)?;

    // ---- Pencil blocks L = [[A, 0], [-Q, I]] and M = [[I, S], [0, Aᵀ]]. ----
    let eye_n = client.eye(n, None, dtype)?;
    let zero_n = client.mul_scalar(&eye_n, 0.0)?;
    let a_t = a.transpose(0, 1)?.contiguous()?;
    let neg_q = client.neg(q)?;
    let l_pencil = {
        let upper = client.cat(&[a, &zero_n], 1)?;
        let lower = client.cat(&[&neg_q, &eye_n], 1)?;
        client.cat(&[&upper, &lower], 0)?
    };
    let m_pencil = {
        let upper = client.cat(&[&eye_n, &s], 1)?;
        let lower = client.cat(&[&zero_n, &a_t], 1)?;
        client.cat(&[&upper, &lower], 0)?
    };
    let m_inv = LinearAlgebraAlgorithms::inverse(client, &m_pencil)?;
    let g = client.matmul(&m_inv, &l_pencil)?;

    // ---- Cayley transform (G − I)⁻¹(G + I): unit disc → left half-plane. ----
    let eye_2n = client.eye(2 * n, None, dtype)?;
    let g_minus_eye = client.sub(&g, &eye_2n)?;
    let g_plus_eye = client.add(&g, &eye_2n)?;
    let g_minus_eye_inv = LinearAlgebraAlgorithms::inverse(client, &g_minus_eye)?;
    let mut sign = client.matmul(&g_minus_eye_inv, &g_plus_eye)?;

    // ---- Newton iteration Z ← (Z + Z⁻¹)/2, converging to the matrix sign. ----
    for _ in 0..MAX_ITER {
        let sign_inv = LinearAlgebraAlgorithms::inverse(client, &sign)?;
        let sum = client.add(&sign, &sign_inv)?;
        let next = client.mul_scalar(&sum, 0.5)?;
        let step = client.sub(&next, &sign)?;
        let step_norm =
            LinearAlgebraAlgorithms::matrix_norm(client, &step, MatrixNormOrder::Frobenius)?;
        let sign_norm =
            LinearAlgebraAlgorithms::matrix_norm(client, &sign, MatrixNormOrder::Frobenius)?;
        let step_size: f64 = step_norm.to_vec()[0];
        let scale: f64 = sign_norm.to_vec()[0];
        sign = next;
        // The step is measured relative to the previous iterate's norm.
        let converged = scale > 0.0 && step_size / scale < TOL;
        if converged {
            break;
        }
    }

    // ---- Stable-subspace projector P = (I − sign)/2; X = P₂₁·P₁₁⁻¹. ----
    let i_minus_sign = client.sub(&eye_2n, &sign)?;
    let projector = client.mul_scalar(&i_minus_sign, 0.5)?;
    let p11 = projector.narrow(0, 0, n)?.narrow(1, 0, n)?.contiguous()?;
    let p21 = projector.narrow(0, n, n)?.narrow(1, 0, n)?.contiguous()?;
    let p11_inv = LinearAlgebraAlgorithms::inverse(client, &p11)?;
    let x = client.matmul(&p21, &p11_inv)?;

    // Averaging with Xᵀ removes round-off asymmetry from the computed solution.
    let x_t = x.transpose(0, 1)?.contiguous()?;
    let x_plus_x_t = client.add(&x, &x_t)?;
    let x_sym = client.mul_scalar(&x_plus_x_t, 0.5)?;
    Ok(x_sym)
}
/// Solves the discrete Lyapunov (Stein) equation `A·X·Aᵀ − X + Q = 0`
/// with the Smith doubling iteration.
///
/// Starting from `X₀ = Q` and `A₀ = A`, each step performs
/// `Xₖ₊₁ = Aₖ·Xₖ·Aₖᵀ + Xₖ` and `Aₖ₊₁ = Aₖ²`. After `k` steps, `Xₖ` therefore
/// holds the first `2ᵏ` terms of the series `Σⱼ Aʲ·Q·(Aʲ)ᵀ`. Iteration stops
/// once `‖Aₖ‖_F < TOL`, at which point the remaining tail is negligible.
///
/// # Arguments
/// * `client` – runtime client that executes the tensor ops.
/// * `a` – `n×n` matrix; the series only converges when its spectral radius is
///   below 1.
/// * `q` – `n×n` matrix (presumably symmetric — TODO confirm with callers).
///
/// # Returns
/// The symmetrised solution `(X + Xᵀ) / 2`.
///
/// # Errors
/// Returns `Error::InvalidArgument` in these cases:
/// - `a` is not square, or `q` is not `n×n`.
/// - `‖A^(2ᵏ)‖_F` fails to drop below `TOL` within `MAX_ITER` doublings, or
///   becomes non-finite. This means `A` is not Schur-stable (or contains
///   NaN/Inf), and the doubling series diverges.
///
/// Errors from the underlying tensor ops are propagated.
pub fn solve_discrete_lyapunov_iterative_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    q: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + UnaryOps<R>
        + MatmulOps<R>
        + UtilityOps<R>
        + LinearAlgebraAlgorithms<R>
        + RuntimeClient<R>,
{
    let a_shape = a.shape();
    if a_shape.len() != 2 || a_shape[0] != a_shape[1] {
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: "discrete Lyapunov iterative: A must be square".into(),
        });
    }
    let n = a_shape[0];
    if q.shape() != [n, n] {
        return Err(Error::InvalidArgument {
            arg: "q",
            reason: format!("discrete Lyapunov iterative: Q must be {n}×{n}"),
        });
    }

    let mut x_k = q.clone();
    let mut a_k = a.clone();
    let mut converged = false;
    for _ in 0..MAX_ITER {
        let a_k_t = a_k.transpose(0, 1)?.contiguous()?;
        let ax = client.matmul(&a_k, &x_k)?;
        let axat = client.matmul(&ax, &a_k_t)?;
        x_k = client.add(&axat, &x_k)?;
        // a_k becomes A^(2^(k+1)); its norm bounds the neglected series tail.
        a_k = client.matmul(&a_k, &a_k)?;
        let a_norm =
            LinearAlgebraAlgorithms::matrix_norm(client, &a_k, MatrixNormOrder::Frobenius)?;
        let a_val: f64 = a_norm.to_vec()[0];
        if a_val < TOL {
            converged = true;
            break;
        }
        // A non-finite norm means the repeated squaring overflowed (or A held
        // NaN/Inf). Further doubling cannot recover, so stop immediately.
        if !a_val.is_finite() {
            break;
        }
    }
    if !converged {
        // Previously this path returned Ok with a garbage/NaN matrix.
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: "discrete Lyapunov iterative: doubling iteration did not converge \
                     (A must have spectral radius < 1)"
                .into(),
        });
    }

    // Average with Xᵀ to remove round-off asymmetry.
    let xt = x_k.transpose(0, 1)?.contiguous()?;
    let x_sym = client.mul_scalar(&client.add(&x_k, &xt)?, 0.5)?;
    Ok(x_sym)
}