use cartan_core::Real;
use nalgebra::SMatrix;
use rand::Rng;
use rand_distr::{Distribution, StandardNormal};
/// Advances a matrix state by one Euler–Maruyama step of the SDE
///
/// ```text
/// dX = √X · dB + dBᵀ · √X + shape · I · dt
/// ```
///
/// where `dB` is an `N × N` matrix of independent Gaussian increments with
/// variance `dt` (the Wishart-process SDE, as the function name indicates).
///
/// # Arguments
///
/// * `x` – current state, expected to be symmetric positive semi-definite.
///   Its square root is taken through a symmetric eigendecomposition. Any
///   negative eigenvalues (e.g. from round-off) are clamped to zero first.
/// * `shape` – coefficient of the identity drift term.
/// * `dt` – time step. It should be non-negative: a negative value makes
///   `dt.sqrt()` NaN, which propagates into the whole result.
/// * `rng` – randomness source. Exactly `N * N` standard-normal draws are
///   taken, filling the increment matrix row by row.
///
/// # Returns
///
/// The next state, explicitly symmetrized as `(Y + Yᵀ) / 2`. No projection
/// onto the positive semi-definite cone is performed, so a large `dt` can
/// produce a matrix with negative eigenvalues.
///
/// For `N == 0` the (empty) state is returned and no random numbers are drawn.
pub fn wishart_step<const N: usize, R: Rng + ?Sized>(
    x: &SMatrix<Real, N, N>,
    shape: Real,
    dt: Real,
    rng: &mut R,
) -> SMatrix<Real, N, N>
where
    StandardNormal: Distribution<Real>,
{
    // nalgebra's symmetric eigendecomposition panics on a 0 × 0 matrix, and an
    // empty state has nothing to evolve.
    if N == 0 {
        return SMatrix::zeros();
    }

    let sqrt_dt = dt.sqrt();

    // Brownian increment. Entries are filled row by row (i outer, j inner) so
    // a seeded RNG always maps draws to the same entries.
    let mut db = SMatrix::<Real, N, N>::zeros();
    for i in 0..N {
        for j in 0..N {
            let z: Real = StandardNormal.sample(rng);
            db[(i, j)] = z * sqrt_dt;
        }
    }

    // Principal square root: X = V Λ Vᵀ  ⇒  √X = V √Λ Vᵀ.
    // The decomposition runs on a `DMatrix`, presumably because the bounds of
    // nalgebra's `symmetric_eigen` are not satisfied by a generic `Const<N>`.
    let sqrt_x = {
        let dm = nalgebra::DMatrix::from_column_slice(N, N, x.as_slice());
        let eig = dm.symmetric_eigen();
        let eigenvectors = eig.eigenvectors;
        let floor: Real = 0.0;

        // Scale column k of V by √λ_k rather than building diag(√Λ) and
        // paying for a full matrix product with it.
        let mut scaled = eigenvectors.clone();
        for (mut col, &lam) in scaled.column_iter_mut().zip(eig.eigenvalues.iter()) {
            // Clamp round-off negatives so the square root stays real.
            let lam = if lam < floor { floor } else { lam };
            col *= lam.sqrt();
        }
        let product = scaled * eigenvectors.transpose();

        // Both matrix types are column-major, so the raw storage copies over.
        SMatrix::<Real, N, N>::from_column_slice(product.as_slice())
    };

    // √X is symmetric, so the two diffusion terms are transposes of each
    // other and their sum is symmetric up to round-off.
    let diffusion = sqrt_x * db + db.transpose() * sqrt_x;
    let drift = SMatrix::<Real, N, N>::identity() * (shape * dt);
    let next = x + diffusion + drift;

    // Remove any round-off asymmetry accumulated above.
    (next + next.transpose()) * 0.5
}