use std::cell::UnsafeCell;
use rand::Rng;
use rand_distr::Distribution;
use super::SimdFloatExt;
use super::chi_square::SimdChiSquared;
use super::normal::SimdNormal;
use crate::simd_rng::SimdRng;
/// Output slices shorter than this are filled one sample at a time in
/// `SimdStudentT::fill_slice_fast`; longer slices take the 8-lane SIMD path.
const SMALL_STUDENT_T_THRESHOLD: usize = 16;
/// Student's t distribution with `nu` degrees of freedom.
///
/// Samples are generated as `Z / sqrt(V / nu)`, where `Z` comes from `normal`
/// and `V` comes from the chi-squared sampler `chisq`.
///
/// Interior mutability (`UnsafeCell`) lets sampling work through `&self`. It covers
/// the pre-generated sample buffer, its read index and the SIMD RNG.
/// `UnsafeCell` is `!Sync`, so this type is not `Sync` by default.
pub struct SimdStudentT<T: SimdFloatExt> {
    // Degrees of freedom.
    nu: T,
    // Normal sampler supplying the numerator `Z`; constructed with parameters (0, 1).
    normal: SimdNormal<T>,
    // Chi-squared sampler with `nu` degrees of freedom supplying `V`.
    chisq: SimdChiSquared<T>,
    // Pre-generated samples served one at a time by the single-sample APIs.
    buffer: UnsafeCell<[T; 16]>,
    // Index of the next unread sample in `buffer`; a value >= 16 means "refill first".
    index: UnsafeCell<usize>,
    // RNG handed to the normal / chi-squared samplers when filling slices.
    simd_rng: UnsafeCell<SimdRng>,
}
impl<T: SimdFloatExt> SimdStudentT<T> {
    /// Creates a Student-t distribution with `nu` degrees of freedom.
    ///
    /// Uses the `crate::simd_rng::Unseeded` seed source.
    #[inline]
    pub fn new(nu: T) -> Self {
        Self::from_seed_source(nu, &crate::simd_rng::Unseeded)
    }

    /// Creates a Student-t distribution with `nu` degrees of freedom.
    ///
    /// Its generators are derived from `seed` via `crate::simd_rng::Deterministic`.
    #[inline]
    pub fn with_seed(nu: T, seed: u64) -> Self {
        Self::from_seed_source(nu, &crate::simd_rng::Deterministic::new(seed))
    }

    /// Builds the distribution from an arbitrary seed source.
    ///
    /// The same `seed` is passed to three consumers:
    /// - the normal sampler (parameters `0`, `1`);
    /// - the chi-squared sampler (`nu` degrees of freedom);
    /// - this distribution's own SIMD RNG.
    ///
    /// NOTE(review): whether those three get independent streams depends on the
    /// `SeedExt` implementation — confirm.
    ///
    /// The buffer index starts at 16 ("exhausted"), so the first single-sample
    /// draw triggers a refill.
    pub fn from_seed_source(nu: T, seed: &impl crate::simd_rng::SeedExt) -> Self {
        Self {
            nu,
            normal: SimdNormal::from_seed_source(T::zero(), T::one(), seed),
            chisq: SimdChiSquared::from_seed_source(nu, seed),
            buffer: UnsafeCell::new([T::zero(); 16]),
            index: UnsafeCell::new(16),
            simd_rng: UnsafeCell::new(seed.rng()),
        }
    }

    /// Returns the next sample from the internal 16-element buffer.
    ///
    /// The buffer is refilled first when it is exhausted.
    #[inline]
    pub fn sample_fast(&self) -> T {
        // The index is copied out rather than borrowed as `&mut usize`.
        // `refill_buffer` writes this same cell through its raw pointer, which
        // would invalidate a live `&mut` and make any later use of it UB.
        // SAFETY: the type is `!Sync` (it contains `UnsafeCell`), so no other
        // thread can access the cell. No reference to it is alive here.
        let mut idx = unsafe { *self.index.get() };
        if idx >= 16 {
            self.refill_buffer();
            idx = 0;
        }
        // SAFETY: `idx < 16` here, and no reference into the buffer is alive.
        let value = unsafe { (*self.buffer.get())[idx] };
        // SAFETY: plain write through the cell's raw pointer; no live borrows.
        unsafe { *self.index.get() = idx + 1 };
        value
    }

    /// Fills `out` with Student-t samples.
    ///
    /// `_rng` is ignored and exists only to mirror the `rand`-style signature.
    /// All randomness comes from the internal SIMD RNG via `fill_slice_fast`.
    pub fn fill_slice<R: Rng + ?Sized>(&self, _rng: &mut R, out: &mut [T]) {
        self.fill_slice_fast(out);
    }

    /// Fills `out` with samples `Z / sqrt(V / nu)` using the internal SIMD RNG.
    ///
    /// Slices shorter than `SMALL_STUDENT_T_THRESHOLD` are generated one value at a
    /// time. Longer slices are generated 8 lanes at a time. A trailing partial
    /// chunk is drawn as a full 8-lane batch and truncated.
    pub fn fill_slice_fast(&self, out: &mut [T]) {
        // SAFETY: in this file, only this method creates a reference to
        // `simd_rng`, and nothing below re-enters it, so this `&mut` is unique.
        let rng = unsafe { &mut *self.simd_rng.get() };
        if out.len() < SMALL_STUDENT_T_THRESHOLD {
            for x in out.iter_mut() {
                let z = self.normal.sample(rng);
                let v = self.chisq.sample(rng);
                *x = z / (v / self.nu).sqrt();
            }
            return;
        }
        // Multiply by 1/nu instead of dividing in every lane.
        let inv_nu = T::splat(T::one() / self.nu);
        let mut zbuf = [T::zero(); 8];
        let mut vbuf = [T::zero(); 8];
        let mut chunks = out.chunks_exact_mut(8);
        for chunk in &mut chunks {
            self.normal.fill_slice(rng, &mut zbuf);
            self.chisq.fill_slice(rng, &mut vbuf);
            let z = T::simd_from_array(zbuf);
            let v = T::simd_from_array(vbuf);
            let x = z / T::simd_sqrt(v * inv_nu);
            chunk.copy_from_slice(&T::simd_to_array(x));
        }
        let rem = chunks.into_remainder();
        if !rem.is_empty() {
            self.normal.fill_slice(rng, &mut zbuf);
            self.chisq.fill_slice(rng, &mut vbuf);
            let z = T::simd_from_array(zbuf);
            let v = T::simd_from_array(vbuf);
            let x = T::simd_to_array(z / T::simd_sqrt(v * inv_nu));
            rem.copy_from_slice(&x[..rem.len()]);
        }
    }

    /// Regenerates all 16 buffered samples and resets the read index to 0.
    fn refill_buffer(&self) {
        // SAFETY: no other reference into `buffer` is alive (callers copy the index
        // out and read the buffer only after this returns). `fill_slice_fast`
        // touches `simd_rng`, not `buffer`.
        let buf = unsafe { &mut *self.buffer.get() };
        self.fill_slice_fast(buf);
        // SAFETY: plain write through the cell's raw pointer; no live borrows.
        unsafe {
            *self.index.get() = 0;
        }
    }
}
impl<T: SimdFloatExt> Clone for SimdStudentT<T> {
fn clone(&self) -> Self {
Self::new(self.nu)
}
}
impl<T: SimdFloatExt> Distribution<T> for SimdStudentT<T> {
    /// Returns the next buffered Student-t sample.
    ///
    /// The supplied `_rng` is ignored. Samples come from the internal 16-element
    /// buffer, which `refill_buffer` regenerates from the distribution's own RNG
    /// when it is exhausted.
    fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> T {
        // The index is copied out rather than held as `&mut usize`.
        // `refill_buffer` writes the same cell through its raw pointer, which
        // would invalidate a live `&mut` and make its later use UB.
        // SAFETY: the type is `!Sync` (it contains `UnsafeCell`) and no reference
        // to the cell is alive here.
        let mut idx = unsafe { *self.index.get() };
        if idx >= 16 {
            self.refill_buffer();
            idx = 0;
        }
        // SAFETY: `idx < 16` here, and no reference into the buffer is alive.
        let val = unsafe { (*self.buffer.get())[idx] };
        // SAFETY: plain write through the cell's raw pointer; no live borrows.
        unsafe { *self.index.get() = idx + 1 };
        val
    }
}
impl<T: SimdFloatExt> crate::traits::DistributionExt for SimdStudentT<T> {
    /// Probability density at `x`, evaluated in log space:
    /// `f(x) = Γ((ν+1)/2) / (√(νπ)·Γ(ν/2)) · (1 + x²/ν)^(-(ν+1)/2)`.
    fn pdf(&self, x: f64) -> f64 {
        let nu = self.nu.to_f64().unwrap();
        student_t_pdf(nu, student_t_log_norm(nu), x)
    }

    /// Cumulative distribution `P(T <= x)`.
    ///
    /// Computed from the one-sided tail `P(T > |x|)` and symmetry about zero.
    fn cdf(&self, x: f64) -> f64 {
        let nu = self.nu.to_f64().unwrap();
        let tail = student_t_upper_tail(nu, x);
        if x >= 0.0 {
            1.0 - tail
        } else {
            tail
        }
    }

    /// Quantile function: returns `x` with `P(T <= x) = p`.
    ///
    /// Special values: `p <= 0` maps to `-inf`, `p >= 1` to `+inf`, and NaN `p`
    /// (or an invalid `nu`) to NaN.
    ///
    /// The solve runs on the smaller tail `q = min(p, 1 - p)`, so the residual
    /// `tail(x) - q` has no catastrophic cancellation when `p` is near 1.
    /// Note that `1 - p` is exact for `p >= 0.5` (Sterbenz).
    ///
    /// - ν = 1 (Cauchy) and ν = 2 use exact closed forms.
    /// - Other ν start from a Cornish–Fisher guess and refine it with Newton
    ///   steps. These are safeguarded by a bracket: bisection inside it, or
    ///   doubling while the upper end is still unbounded.
    fn inv_cdf(&self, p: f64) -> f64 {
        const MAX_ITER: usize = 200;
        if p.is_nan() {
            return f64::NAN;
        }
        if p <= 0.0 {
            return f64::NEG_INFINITY;
        }
        if p >= 1.0 {
            return f64::INFINITY;
        }
        if p == 0.5 {
            return 0.0;
        }
        let nu = self.nu.to_f64().unwrap();
        if nu.is_nan() || nu <= 0.0 {
            return f64::NAN;
        }
        // Positive quantile for the tail probability q, then mirror by `sign`.
        let (q, sign) = if p < 0.5 { (p, -1.0) } else { (1.0 - p, 1.0) };

        if nu == 1.0 {
            // Cauchy: P(T > x) = 1/2 - atan(x)/π = q  =>  x = cot(πq).
            return sign / (std::f64::consts::PI * q).tan();
        }
        if nu == 2.0 {
            // F(x) = 1/2 + x / (2√(2 + x²))  =>  x = (1 - 2q) / √(2q(1 - q)).
            return sign * (1.0 - 2.0 * q) / (2.0 * q * (1.0 - q)).sqrt();
        }

        let log_norm = student_t_log_norm(nu);
        // ndtri(q) < 0 for q < 0.5, so z is the positive normal quantile.
        let z = -crate::special::ndtri(q);
        let mut x = z * (1.0 + (z * z + 1.0) / (4.0 * nu));
        // The root is always inside (lo, hi); the upper tail decreases in x.
        let (mut lo, mut hi) = (0.0_f64, f64::INFINITY);
        for _ in 0..MAX_ITER {
            let resid = student_t_upper_tail(nu, x) - q;
            if resid.is_nan() {
                return f64::NAN;
            }
            if resid == 0.0 {
                return sign * x;
            }
            // resid > 0: tail too heavy at x, so the root lies to the right.
            if resid > 0.0 {
                lo = x;
            } else {
                hi = x;
            }
            if hi - lo <= 1e-14 * (1.0 + lo) {
                return sign * 0.5 * (lo + hi);
            }
            // d/dx tail(x) = -pdf(x), so the Newton step is +resid / pdf.
            let newton = x + resid / student_t_pdf(nu, log_norm, x);
            if newton.is_finite() && (newton - x).abs() <= 1e-14 * (1.0 + x.abs()) {
                return sign * newton;
            }
            x = if newton.is_finite() && newton > lo && newton < hi {
                newton
            } else if hi.is_finite() {
                0.5 * (lo + hi)
            } else {
                2.0 * x
            };
        }
        sign * x
    }

    /// Mean: 0 for ν > 1, otherwise undefined (NaN).
    fn mean(&self) -> f64 {
        if self.nu.to_f64().unwrap() > 1.0 {
            0.0
        } else {
            f64::NAN
        }
    }

    /// Median: 0 (the distribution is symmetric about zero).
    fn median(&self) -> f64 {
        0.0
    }

    /// Mode: 0.
    fn mode(&self) -> f64 {
        0.0
    }

    /// Variance.
    ///
    /// - ν > 2: `ν / (ν - 2)`
    /// - 1 < ν <= 2: infinite
    /// - otherwise: undefined (NaN)
    fn variance(&self) -> f64 {
        let nu = self.nu.to_f64().unwrap();
        if nu > 2.0 {
            nu / (nu - 2.0)
        } else if nu > 1.0 {
            f64::INFINITY
        } else {
            f64::NAN
        }
    }

    /// Skewness: 0 for ν > 3, otherwise undefined (NaN).
    fn skewness(&self) -> f64 {
        if self.nu.to_f64().unwrap() > 3.0 {
            0.0
        } else {
            f64::NAN
        }
    }

    /// Excess kurtosis.
    ///
    /// - ν > 4: `6 / (ν - 4)`
    /// - 2 < ν <= 4: infinite
    /// - otherwise: undefined (NaN)
    fn kurtosis(&self) -> f64 {
        let nu = self.nu.to_f64().unwrap();
        if nu > 4.0 {
            6.0 / (nu - 4.0)
        } else if nu > 2.0 {
            f64::INFINITY
        } else {
            f64::NAN
        }
    }

    /// Differential entropy (nats):
    /// `(ν+1)/2·[ψ((ν+1)/2) - ψ(ν/2)] + ln(√ν · B(ν/2, 1/2))`.
    ///
    /// The Beta term is expanded via `ln Γ(1/2) = ½ ln π`.
    fn entropy(&self) -> f64 {
        let nu = self.nu.to_f64().unwrap();
        let half_nu = 0.5 * nu;
        let half_nu_p1 = 0.5 * (nu + 1.0);
        half_nu_p1 * (crate::special::digamma(half_nu_p1) - crate::special::digamma(half_nu))
            + 0.5 * nu.ln()
            + crate::special::ln_gamma(half_nu)
            - crate::special::ln_gamma(half_nu_p1)
            + 0.5 * std::f64::consts::PI.ln()
    }

    /// Moment generating function `E[e^{tT}]`.
    ///
    /// It diverges for every `t != 0`, so NaN is returned there. At `t == 0` it is
    /// trivially 1, as for any distribution.
    fn moment_generating_function(&self, t: f64) -> f64 {
        if t == 0.0 {
            1.0
        } else {
            f64::NAN
        }
    }
}

/// Natural log of the Student-t normalising constant `Γ((ν+1)/2) / (√(νπ)·Γ(ν/2))`.
fn student_t_log_norm(nu: f64) -> f64 {
    crate::special::ln_gamma(0.5 * (nu + 1.0))
        - 0.5 * (nu * std::f64::consts::PI).ln()
        - crate::special::ln_gamma(0.5 * nu)
}

/// Student-t density at `x`, given `log_norm` from `student_t_log_norm(nu)`.
///
/// `ln_1p` keeps the kernel accurate when `x²/ν` is small.
fn student_t_pdf(nu: f64, log_norm: f64, x: f64) -> f64 {
    (log_norm - 0.5 * (nu + 1.0) * (x * x / nu).ln_1p()).exp()
}

/// One-sided tail probability `P(T > |x|) = ½ · I_{ν/(ν+x²)}(ν/2, ½)`.
///
/// This assumes `crate::special::beta_i(a, b, x)` is the regularised incomplete
/// beta `I_x(a, b)`, as the CDF formula requires.
fn student_t_upper_tail(nu: f64, x: f64) -> f64 {
    0.5 * crate::special::beta_i(0.5 * nu, 0.5, nu / (nu + x * x))
}
// Declares `PyStudentT` through the crate's `py_distribution!` macro around
// `SimdStudentT`. Constructor signature: `(nu, seed=None, dtype=None)`, with a
// single distribution parameter `nu: f64`.
// NOTE(review): the generated API (how `seed` selects `new` vs `with_seed`, what
// `dtype` switches between) is defined by the macro elsewhere in the crate — confirm there.
py_distribution!(PyStudentT, SimdStudentT,
sig: (nu, seed=None, dtype=None),
params: (nu: f64)
);