use ndarray::Array1;
use num_complex::Complex64;
use rand::Rng;
/// Optional analytical properties of a probability distribution.
///
/// Every method ships with a default body that panics via `unimplemented!`, and the
/// panic message names the concrete implementing type (through
/// `std::any::type_name`). Implementors therefore override only the quantities they
/// can actually provide. Calling one they did not override fails loudly at runtime
/// instead of returning a silent placeholder.
pub trait DistributionExt {
    /// Characteristic function evaluated at `t`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn characteristic_function(&self, _t: f64) -> Complex64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::characteristic_function is not implemented for {}",
            implementor
        )
    }

    /// Probability density evaluated at `x`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn pdf(&self, _x: f64) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::pdf is not implemented for {}",
            implementor
        )
    }

    /// Cumulative distribution function evaluated at `x`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn cdf(&self, _x: f64) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::cdf is not implemented for {}",
            implementor
        )
    }

    /// Inverse of the cumulative distribution function at probability `p`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn inv_cdf(&self, _p: f64) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::inv_cdf is not implemented for {}",
            implementor
        )
    }

    /// Mean of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn mean(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::mean is not implemented for {}",
            implementor
        )
    }

    /// Median of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn median(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::median is not implemented for {}",
            implementor
        )
    }

    /// Mode of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn mode(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::mode is not implemented for {}",
            implementor
        )
    }

    /// Variance of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn variance(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::variance is not implemented for {}",
            implementor
        )
    }

    /// Skewness of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn skewness(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::skewness is not implemented for {}",
            implementor
        )
    }

    /// Kurtosis of the distribution (the default body does not fix a convention;
    /// implementors should state whether it is excess kurtosis).
    ///
    /// # Panics
    /// The default implementation always panics.
    fn kurtosis(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::kurtosis is not implemented for {}",
            implementor
        )
    }

    /// Entropy of the distribution.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn entropy(&self) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::entropy is not implemented for {}",
            implementor
        )
    }

    /// Moment generating function evaluated at `t`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn moment_generating_function(&self, _t: f64) -> f64 {
        let implementor = std::any::type_name::<Self>();
        unimplemented!(
            "DistributionExt::moment_generating_function is not implemented for {}",
            implementor
        )
    }
}
/// Bulk sampling of values of type `T` from a distribution.
///
/// Implementors supply [`DistributionSampler::fill_slice`]. The provided
/// `sample_n` (1-D) and `sample_matrix` (2-D, optionally parallel) build `ndarray`
/// outputs on top of it.
///
/// # Type requirements
/// The provided methods require `T: Default`. The output buffer is allocated already
/// initialised, so `fill_slice` (a safe method taking `&mut [T]`) only ever sees valid
/// values. The previous implementation handed `fill_slice` uninitialised memory cast to
/// `&mut [T]`. That violates the safety contract of `slice::from_raw_parts_mut`, and is
/// undefined behaviour whenever an implementation:
/// - reads an element;
/// - assigns over one whose `T` has a `Drop` impl;
/// - leaves one unwritten before `assume_init`.
pub trait DistributionSampler<T> {
    /// Overwrites every element of `out` with a sample drawn using `rng`.
    ///
    /// The provided methods pass a buffer pre-filled with `T::default()`. An
    /// implementation that skips elements leaves those defaults in the result rather
    /// than causing undefined behaviour.
    fn fill_slice<R: Rng + ?Sized>(&self, rng: &mut R, out: &mut [T]);

    /// Draws `n` samples into a newly allocated [`Array1`].
    ///
    /// All samples are produced on the calling thread from a single
    /// `crate::simd_rng::SimdRng` created for this call. `fill_slice` is called even
    /// when `n == 0`, with an empty slice.
    ///
    /// # Panics
    /// Panics if the new array's storage is not contiguous (not expected for a freshly
    /// allocated 1-D array), or if `fill_slice` panics.
    #[inline]
    fn sample_n(&self, n: usize) -> Array1<T>
    where
        T: Default,
    {
        // Initialised storage: `fill_slice` is safe code and may read or drop elements.
        let mut out = Array1::<T>::default(n);
        let flat = out
            .as_slice_mut()
            .expect("distribution sample_n output must be contiguous");
        let mut rng = crate::simd_rng::SimdRng::new();
        self.fill_slice(&mut rng, flat);
        out
    }

    /// Draws an `m x n` matrix of samples into a newly allocated [`ndarray::Array2`].
    ///
    /// The matrix is filled through its contiguous backing slice, in memory order.
    /// - If `m == 0` or `n == 0`, the empty matrix is returned without calling
    ///   `fill_slice`.
    /// - Otherwise the slice is split across up to `rayon::current_num_threads()`
    ///   workers. The worker count is capped at `ceil(m * n / MIN_PAR_CHUNK)`, so small
    ///   matrices stay on the calling thread.
    /// - Each worker gets its own clone of `self` and its own `SimdRng`.
    ///
    /// NOTE(review): this assumes `SimdRng::new()` seeds each instance independently.
    /// Confirm this, otherwise different chunks could repeat the same random stream.
    ///
    /// # Panics
    /// Panics if the new array's storage is not contiguous (not expected for a freshly
    /// allocated array), or if `fill_slice` panics. A panic in a rayon task is
    /// propagated by `rayon::scope` after the other tasks finish.
    #[inline]
    fn sample_matrix(&self, m: usize, n: usize) -> ndarray::Array2<T>
    where
        Self: Clone + Send,
        T: Send + Default,
    {
        let mut out = ndarray::Array2::<T>::default((m, n));
        if m == 0 || n == 0 {
            return out;
        }
        let flat = out
            .as_slice_mut()
            .expect("distribution sample_matrix output must be contiguous");

        // Minimum number of elements per worker, so that tiny chunks do not pay
        // rayon task overhead.
        const MIN_PAR_CHUNK: usize = 16 * 1024;
        let total = flat.len();
        let max_workers_for_size = total.div_ceil(MIN_PAR_CHUNK).max(1);
        let workers = rayon::current_num_threads()
            .max(1)
            .min(max_workers_for_size);

        if workers == 1 {
            let mut rng = crate::simd_rng::SimdRng::new();
            self.fill_slice(&mut rng, flat);
            return out;
        }

        let chunk_len = total.div_ceil(workers);
        // `Self` is only `Send`, not `Sync`, so `&self` cannot be shared with the
        // scope. Move an owned clone in, and give every task its own clone of that.
        let base = self.clone();
        rayon::scope(move |scope| {
            for chunk in flat.chunks_mut(chunk_len) {
                let sampler = base.clone();
                scope.spawn(move |_| {
                    let mut rng = crate::simd_rng::SimdRng::new();
                    sampler.fill_slice(&mut rng, chunk);
                });
            }
        });
        out
    }
}