use core::ops::Mul;
use ndarray::{prelude::Axis, Array2, Slice};
use ndarray_linalg::{Lapack, SVDInto};
use num::{traits::FloatConst, Float, Integer, NumCast, ToPrimitive, Unsigned};
// Submodules of this module, declared through the `moddef` crate's macro.
// NOTE(review): `flat(pub)` presumably declares each listed module privately and publicly
// re-exports its items at this level — confirm against the `moddef` documentation.
moddef::moddef!(
flat(pub) mod {
chain,
complex_op,
len_eq,
maybe_len_eq,
not_range,
overlay,
result_or_ok,
truncate_im,
two_sided_range
}
);
pub(crate) fn pinv<T>(m: Array2<T>) -> Array2<T>
where
T: Lapack<Real: Into<T>> + Mul<T::Real, Output = T>
{
let mdim = m.dim();
let (u, s, v_h) = m.svd_into(true, true).unwrap();
let u = u.unwrap();
let v_h = v_h.unwrap();
let threshold = T::Real::epsilon()*NumCast::from(mdim.0.max(mdim.1)).unwrap();
let (num_keep, v_s_inv) = {
let mut v_h_t = v_h.reversed_axes();
let mut num_keep = 0;
for (&sing_val, mut v_h_t_col) in s.iter().zip(v_h_t.columns_mut()) {
if sing_val > threshold {
let sing_val_recip = sing_val.recip();
v_h_t_col.map_inplace(|v_h_t| *v_h_t = T::from_real(sing_val_recip) * v_h_t.conj());
num_keep += 1;
} else {
break;
}
}
v_h_t.slice_axis_inplace(Axis(1), Slice::from(..num_keep));
(num_keep, v_h_t)
};
let u_h = {
let mut u_t = u.reversed_axes();
u_t.slice_axis_inplace(Axis(0), Slice::from(..num_keep));
u_t.map_inplace(|x| *x = x.conj());
u_t
};
v_s_inv.dot(&u_h)
}
/// Closed-form approximation of `I₀(x)`, evaluated as
///
/// `cosh(x) · (1 + λ²x²)^(-1/4) · (1 + p1·x²) / (1 + q1·x²)` with `λ = 1/2`.
///
/// For `λ = 1/2` the coefficients evaluate to `p1 ≈ 0.2427` and `q1 ≈ 0.4302`. The expression is
/// even in `x` and returns exactly `1` at `x = 0`. For large `|x|` it tends to
/// `e^|x| / √(2π|x|)`, the asymptote of the modified Bessel function `I₀`.
/// NOTE(review): this appears to be the Olivares et al. simple `I₀` approximation — confirm.
pub(crate) fn i0<T>(x: T) -> T
where
    T: Float + FloatConst,
{
    let one = T::one();
    let two = one + one;
    let lambda = two.recip();
    let lambda_sq = lambda*lambda;
    let sqrt_lambda_over_pi = (lambda/T::PI()).sqrt();

    // Rational-correction coefficients, derived from λ.
    let q1 = (one - lambda_sq)/(two + two)/(one - T::SQRT_2()*sqrt_lambda_over_pi);
    let p1 = two*sqrt_lambda_over_pi*q1*T::FRAC_1_SQRT_2();

    let envelope = (one + lambda_sq*x*x).sqrt().sqrt().recip()*x.cosh();
    let numerator = one + p1*x*x;
    let denominator = one + q1*x*x;
    envelope*numerator/denominator
}
/// Gamma function `Γ(x)`, always evaluated in `f64` precision.
///
/// `x` is widened to `f64`, passed through `f64::gamma`, and the result is cast back to `T`.
///
/// # Panics
///
/// Panics if either numeric conversion fails (both are `unwrap`ped).
pub(crate) fn gamma<T>(x: T) -> T
where
    T: Float,
{
    let widened: f64 = NumCast::from(x).unwrap();
    let value = widened.gamma();
    T::from(value).unwrap()
}
/// Inverse error function, always evaluated in `f64` precision.
///
/// `x` is widened to `f64`, passed to `statrs::function::erf::erf_inv`, and the result is cast
/// back to `T`. NOTE(review): behavior outside `[-1, 1]` is whatever `statrs` returns — confirm.
///
/// # Panics
///
/// Panics if either numeric conversion fails (both are `unwrap`ped).
pub(crate) fn erf_inv<T>(x: T) -> T
where
    T: Float,
{
    let widened: f64 = NumCast::from(x).unwrap();
    let value = statrs::function::erf::erf_inv(widened);
    T::from(value).unwrap()
}
/// Factorial `x!`, returned as the float type `T`.
///
/// The value is computed exactly in `u128` while the product fits (through `34!`). Beyond that
/// it is computed as `Γ(x + 1)`, which may overflow to infinity for large `x`.
///
/// # Panics
///
/// Panics if `x` cannot be converted to `u128` or to `T` (the conversions are `unwrap`ped).
pub(crate) fn factorial<T, U>(x: U) -> T
where
    T: Float,
    U: Unsigned + Integer + ToPrimitive + Copy,
{
    let n: u128 = NumCast::from(x).unwrap();
    // The product must be seeded with 1; seeding it with `n` would yield `n * n!`.
    match (1..=n).try_fold(1u128, u128::checked_mul) {
        Some(exact) => T::from(exact).unwrap(),
        // n! = Γ(n + 1); Γ(n) alone would be (n - 1)!.
        None => gamma(T::from(x).unwrap() + T::one()),
    }
}
/// Binomial coefficient `C(n, k) = n! / (k! · (n - k)!)`, returned as the float type `T`.
///
/// Returns `0` when `k > n`. The value is computed exactly in `u128` whenever the intermediate
/// products fit. Otherwise the same multiplicative formula is evaluated in `T`, which rounds and
/// reaches infinity once the value exceeds `T`'s range.
///
/// # Panics
///
/// Panics if `n` or `k` cannot be converted to `u128` (the conversions are `unwrap`ped).
pub(crate) fn bincoeff<T, U>(n: U, k: U) -> T
where
    T: Float,
    U: Unsigned + Integer + ToPrimitive + Copy,
{
    let n: u128 = NumCast::from(n).unwrap();
    let k: u128 = NumCast::from(k).unwrap();
    // Checked up front so `n - k` below cannot underflow.
    if k > n {
        return T::zero();
    }
    // C(n, k) == C(n, n - k): iterate over the shorter side.
    let k = k.min(n - k);
    let base = n - k;

    // After step i the accumulator equals C(base + i, i), an integer, so `/ i` is exact.
    let exact = (1..=k).try_fold(1u128, |acc, i| acc.checked_mul(base + i).map(|p| p/i));
    match exact {
        Some(c) => T::from(c).unwrap(),
        None => {
            let mut acc = T::one();
            for i in 1..=k {
                acc = acc*(T::from(base + i).unwrap()/T::from(i).unwrap());
                // Every factor is >= 1, so once the product is infinite it stays infinite;
                // stopping here bounds the loop even for huge `k`.
                if acc.is_infinite() {
                    break;
                }
            }
            acc
        }
    }
}