use super::FixedPoint;
use super::FixedMatrix;
use super::decompose::lu_decompose;
use super::linalg::{convergence_threshold, upscale_to_compute, ComputeStorage};
use super::compute_matrix::{ComputeMatrix, compute_lu_decompose};
use crate::fixed_point::core_types::errors::OverflowDetected;
/// Decimal-string table of the seven coefficients `b_0 ..= b_6` used by
/// `matrix_exp_compute` for its degree-6 rational approximation.
///
/// The entries are stored as strings and parsed with `FixedPoint::from_str`
/// in `pade_coeff_compute`. Each literal is a truncated decimal expansion of
/// the fraction noted beside it.
const PADE_B: [&str; 7] = [
    "1",                       // b0 = 1
    "0.5",                     // b1 = 1/2
    "0.113636363636363636363", // b2 ~ 5/44
    "0.015151515151515151515", // b3 ~ 1/66
    "0.001262626262626262626", // b4 ~ 1/792
    "0.000063131313131313131", // b5 ~ 1/15840
    "0.000001503126503126503", // b6 ~ 1/665280
];
/// Returns the `k`-th entry of `PADE_B` as a `ComputeStorage` value.
///
/// The string is parsed with `FixedPoint::from_str`, and the raw storage of
/// the parsed value is passed through `upscale_to_compute`.
///
/// # Panics
///
/// Panics if `k` is not a valid index into `PADE_B` (i.e. `k >= 7`).
fn pade_coeff_compute(k: usize) -> ComputeStorage {
    let parsed = FixedPoint::from_str(PADE_B[k]);
    upscale_to_compute(parsed.raw())
}
/// Matrix exponential on `ComputeMatrix`, via scaling-and-squaring with a
/// degree-6 rational (Padé-style) approximant.
///
/// Steps:
/// 1. If `frobenius_norm_compute()` of `a` is zero, return the identity.
/// 2. Halve `a` `s` times, where `s` is the number of halvings needed to
///    bring `norm_1_compute()` of `a` below one half.
/// 3. With `B` the scaled matrix, build the even part
///    `V = b0*I + b2*B^2 + b4*B^4 + b6*B^6` and the odd part
///    `U = B * (b1*I + b3*B^2 + b5*B^4)`.
/// 4. Solve `(V - U) * R = (V + U)` column by column through an LU
///    factorisation of `V - U`.
/// 5. Square `R` `s` times to undo the scaling.
///
/// The dimension is taken from `a.rows()`; callers are expected to pass a
/// square matrix (the public wrapper asserts this).
///
/// # Errors
///
/// Propagates `OverflowDetected` from `compute_lu_decompose` or from the
/// per-column `solve`.
pub(crate) fn matrix_exp_compute(a: &ComputeMatrix) -> Result<ComputeMatrix, OverflowDetected> {
    let n = a.rows();
    if a.frobenius_norm_compute().is_zero() {
        return Ok(ComputeMatrix::identity(n));
    }

    // --- Scaling: halve until the tracked 1-norm drops below 1/2. ---
    let one = FixedPoint::one();
    let two = one + one;
    let half = one / FixedPoint::from_int(2);

    let mut remaining_norm = a.norm_1_compute();
    let mut squarings = 0u32;
    let mut scaled = a.copy();
    while remaining_norm >= half {
        remaining_norm = remaining_norm / two;
        scaled = scaled.halve();
        squarings += 1;
    }

    // --- Even powers of the scaled matrix. ---
    let pow2 = scaled.mat_mul(&scaled);
    let pow4 = pow2.mat_mul(&pow2);
    let pow6 = pow2.mat_mul(&pow4);
    let id = ComputeMatrix::identity(n);

    // Even-degree part of the approximant.
    let even = id
        .scalar_mul(pade_coeff_compute(0))
        .add(&pow2.scalar_mul(pade_coeff_compute(2)))
        .add(&pow4.scalar_mul(pade_coeff_compute(4)))
        .add(&pow6.scalar_mul(pade_coeff_compute(6)));

    // Odd-degree part, factored as B * (polynomial in B^2).
    let odd_inner = id
        .scalar_mul(pade_coeff_compute(1))
        .add(&pow2.scalar_mul(pade_coeff_compute(3)))
        .add(&pow4.scalar_mul(pade_coeff_compute(5)));
    let odd = scaled.mat_mul(&odd_inner);

    let numerator = even.add(&odd);
    let denominator = even.sub(&odd);

    // --- Solve denominator * R = numerator one column at a time. ---
    let factorised = compute_lu_decompose(&denominator)?;
    let mut quotient = ComputeMatrix::new(n, n);
    for col in 0..n {
        let rhs = numerator.col_vec(col);
        let solved = factorised.solve(&rhs)?;
        for row in 0..n {
            quotient.set(row, col, solved[row]);
        }
    }

    // --- Squaring: undo the earlier halvings. ---
    Ok((0..squarings).fold(quotient, |acc, _| acc.mat_mul(&acc)))
}
/// Matrix exponential of a square `FixedMatrix`.
///
/// Converts `a` to a `ComputeMatrix`, delegates to `matrix_exp_compute`, and
/// converts the result back with `to_fixed_matrix`.
///
/// # Panics
///
/// Panics if `a.is_square()` is false.
///
/// # Errors
///
/// Returns any `OverflowDetected` raised by `matrix_exp_compute`.
pub fn matrix_exp(a: &FixedMatrix) -> Result<FixedMatrix, OverflowDetected> {
    assert!(a.is_square(), "matrix_exp: matrix must be square");
    let wide = ComputeMatrix::from_fixed_matrix(a);
    matrix_exp_compute(&wide).map(|exp| exp.to_fixed_matrix())
}
/// Matrix square root on `ComputeMatrix` using the Denman–Beavers iteration.
///
/// Starting from `Y = a` and `Z = I`, each step computes
/// `Y' = (Y + Z^-1) / 2` and `Z' = (Z + Y^-1) / 2`, with both inverses taken
/// from the values *before* the update. The `Y` iterate is returned.
///
/// Iteration stops as soon as `frobenius_norm_compute()` of `Y' - Y` falls
/// below `convergence_threshold(a.frobenius_norm_compute())`, or after 50
/// steps. In the latter case the last iterate is returned as-is, with no
/// error signalling non-convergence.
///
/// # Errors
///
/// Propagates `OverflowDetected` from LU decomposition or inversion of
/// either iterate.
pub(crate) fn matrix_sqrt_compute(a: &ComputeMatrix) -> Result<ComputeMatrix, OverflowDetected> {
    const MAX_ITER: usize = 50;

    let tolerance = convergence_threshold(a.frobenius_norm_compute());
    let mut y = a.copy();
    let mut z = ComputeMatrix::identity(a.rows());

    for _ in 0..MAX_ITER {
        // Both inverses use the current (not yet updated) iterates.
        let z_inv = compute_lu_decompose(&z)?.inverse()?;
        let y_inv = compute_lu_decompose(&y)?.inverse()?;

        let y_next = y.add(&z_inv).halve();
        z = z.add(&y_inv).halve();

        // Size of the change in Y made by this step.
        let step = y_next.sub(&y).frobenius_norm_compute();
        y = y_next;
        if step < tolerance {
            break;
        }
    }
    Ok(y)
}
/// Square root of a square `FixedMatrix`.
///
/// Converts `a` to a `ComputeMatrix`, delegates to `matrix_sqrt_compute`, and
/// converts the result back with `to_fixed_matrix`.
///
/// # Panics
///
/// Panics if `a.is_square()` is false.
///
/// # Errors
///
/// Returns any `OverflowDetected` raised by `matrix_sqrt_compute`.
pub fn matrix_sqrt(a: &FixedMatrix) -> Result<FixedMatrix, OverflowDetected> {
    assert!(a.is_square(), "matrix_sqrt: matrix must be square");
    let wide = ComputeMatrix::from_fixed_matrix(a);
    matrix_sqrt_compute(&wide).map(|root| root.to_fixed_matrix())
}
/// Matrix logarithm on `ComputeMatrix` by inverse scaling and squaring.
///
/// 1. Take repeated square roots (via `matrix_sqrt_compute`, at most 30) until
///    `frobenius_norm_compute()` of `A_s - I` is below one quarter; `s`
///    counts the roots taken.
/// 2. With `X = A_s - I`, evaluate the truncated series
///    `X - X^2/2 + X^3/3 - ... - X^22/22` in nested (Horner) form:
///    `X * (I - 1/2 X (I - 2/3 X (I - 3/4 X (...))))`.
/// 3. Double the result `s` times, since `log(A) = 2^s * log(A^(1/2^s))`.
///
/// If the norm test still fails after 30 roots the series is evaluated
/// anyway; no error is raised for that case.
///
/// # Errors
///
/// Propagates `OverflowDetected` from `matrix_sqrt_compute`.
pub(crate) fn matrix_log_compute(a: &ComputeMatrix) -> Result<ComputeMatrix, OverflowDetected> {
    const MAX_ROOTS: u32 = 30;
    const SERIES_TERMS: i32 = 22;

    let dim = a.rows();
    let identity = ComputeMatrix::identity(dim);
    let radius = FixedPoint::one() / FixedPoint::from_int(4);

    // --- Inverse scaling: pull the argument towards the identity. ---
    let mut reduced = a.copy();
    let mut roots_taken = 0u32;
    for _ in 0..MAX_ROOTS {
        if reduced.sub(&identity).frobenius_norm_compute() < radius {
            break;
        }
        reduced = matrix_sqrt_compute(&reduced)?;
        roots_taken += 1;
    }

    // --- Series for log(I + X), innermost factor first. ---
    let x = reduced.sub(&identity);
    let nested = (1..SERIES_TERMS)
        .rev()
        .fold(ComputeMatrix::identity(dim), |inner, k| {
            // Ratio k/(k+1) turns the 1/k coefficient into 1/(k+1).
            let ratio = FixedPoint::from_int(k) / FixedPoint::from_int(k + 1);
            let x_scaled = x.scalar_mul(upscale_to_compute(ratio.raw()));
            identity.sub(&x_scaled.mat_mul(&inner))
        });
    let series = x.mat_mul(&nested);

    // --- Undo the square roots: one doubling per root taken. ---
    Ok((0..roots_taken).fold(series, |acc, _| acc.add(&acc)))
}
/// Logarithm of a square `FixedMatrix`.
///
/// Converts `a` to a `ComputeMatrix`, delegates to `matrix_log_compute`, and
/// converts the result back with `to_fixed_matrix`.
///
/// # Panics
///
/// Panics if `a.is_square()` is false.
///
/// # Errors
///
/// Returns any `OverflowDetected` raised by `matrix_log_compute`.
pub fn matrix_log(a: &FixedMatrix) -> Result<FixedMatrix, OverflowDetected> {
    assert!(a.is_square(), "matrix_log: matrix must be square");
    let wide = ComputeMatrix::from_fixed_matrix(a);
    matrix_log_compute(&wide).map(|log| log.to_fixed_matrix())
}
/// Raises a square `FixedMatrix` to the fixed-point power `p`.
///
/// Three exponents are short-circuited:
/// * `p == 0`  -> identity of the same dimension,
/// * `p == 1`  -> a clone of `a`,
/// * `p == -1` -> the inverse obtained from `lu_decompose(a)`.
///
/// Every other exponent is evaluated as `exp(p * log(a))` using
/// `matrix_log_compute` and `matrix_exp_compute`, then converted back with
/// `to_fixed_matrix`.
///
/// # Panics
///
/// Panics if `a.is_square()` is false.
///
/// # Errors
///
/// Propagates `OverflowDetected` from the LU inverse, the logarithm, or the
/// exponential.
pub fn matrix_pow(a: &FixedMatrix, p: FixedPoint) -> Result<FixedMatrix, OverflowDetected> {
    assert!(a.is_square(), "matrix_pow: matrix must be square");

    let one = FixedPoint::one();
    if p.is_zero() {
        return Ok(FixedMatrix::identity(a.rows()));
    }
    if p == one {
        return Ok(a.clone());
    }
    if p == -one {
        return lu_decompose(a)?.inverse();
    }

    // General exponent: A^p = exp(p * log A).
    let wide = ComputeMatrix::from_fixed_matrix(a);
    let scaled_log = matrix_log_compute(&wide)?.scalar_mul(upscale_to_compute(p.raw()));
    matrix_exp_compute(&scaled_log).map(|pow| pow.to_fixed_matrix())
}