use ndarray::{Array, Dimension, IxDyn};
use rand::prelude::*;
use rand::SeedableRng;
use rand_distr::{Distribution, Uniform};
use std::cell::RefCell;
/// Small wrapper that owns a random number generator of type `R` and exposes
/// convenience sampling methods on top of it.
///
/// `R` defaults to `rand::rngs::ThreadRng`, which is what `Random::default()`
/// produces; `Random::with_seed` yields a `StdRng`-backed instance instead.
pub struct Random<R: Rng + ?Sized = rand::rngs::ThreadRng> {
    /// The underlying generator every sampling method draws from.
    rng: R,
}
impl Default for Random {
fn default() -> Self {
Self { rng: rand::rng() }
}
}
// The struct allows `R: ?Sized`, and every rand API used below
// (`Distribution::sample`, `SliceRandom::shuffle`) takes `&mut R` with
// `R: Rng + ?Sized`, so these methods work for unsized generators too.
impl<R: Rng + ?Sized> Random<R> {
    /// Draws a single value from `distribution` using the wrapped generator.
    pub fn sample<D, T>(&mut self, distribution: D) -> T
    where
        D: Distribution<T>,
    {
        distribution.sample(&mut self.rng)
    }

    /// Draws a value uniformly from the half-open range `[min, max)`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Uniform::new` rejects the bounds, e.g. when
    /// `min >= max` or a floating-point bound is not finite.
    pub fn random_range<T: rand_distr::uniform::SampleUniform>(&mut self, min: T, max: T) -> T {
        let dist: rand_distr::Uniform<T> = rand_distr::Uniform::new(min, max)
            .expect("random_range: invalid range (requires min < max)");
        self.sample(dist)
    }

    /// Returns `true` with probability 0.5.
    pub fn random_bool(&mut self) -> bool {
        // Same Bernoulli(0.5) draw as before; 0.5 is always a valid probability.
        self.random_bool_with_chance(0.5)
    }

    /// Returns `true` with probability `prob`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Bernoulli::new` rejects `prob` (outside `[0, 1]`
    /// or NaN).
    pub fn random_bool_with_chance(&mut self, prob: f64) -> bool {
        let dist = rand_distr::Bernoulli::new(prob)
            .expect("random_bool_with_chance: probability must be within [0, 1]");
        dist.sample(&mut self.rng)
    }

    /// Shuffles `slice` in place using the wrapped generator.
    pub fn shuffle<T>(&mut self, slice: &mut [T]) {
        slice.shuffle(&mut self.rng);
    }

    /// Draws `size` independent values from `distribution`.
    ///
    /// The distribution is only borrowed while sampling, so it need not be `Copy`.
    pub fn sample_vec<D, T>(&mut self, distribution: D, size: usize) -> Vec<T>
    where
        D: Distribution<T>,
    {
        (0..size)
            .map(|_| distribution.sample(&mut self.rng))
            .collect()
    }

    /// Fills a dynamically-shaped array of the given `shape` with independent
    /// draws from `distribution`, in the array's standard (row-major) order.
    pub fn sample_array<D, T, Sh>(&mut self, distribution: D, shape: Sh) -> Array<T, IxDyn>
    where
        D: Distribution<T>,
        Sh: Into<IxDyn>,
    {
        let shape = shape.into();
        let size = shape.size();
        let values = self.sample_vec(distribution, size);
        // `values.len()` equals the shape's element count by construction.
        Array::from_shape_vec(shape, values)
            .expect("sample_array: sample count always matches the shape's element count")
    }
}
impl Random {
pub fn with_seed(seed: u64) -> Random<StdRng> {
Random {
rng: StdRng::seed_from_u64(seed),
}
}
}
thread_local! {
static THREAD_RNG: RefCell<Random> = RefCell::new(Random::default());
}
/// Runs `f` with mutable access to this thread's shared [`Random`] and returns
/// whatever `f` returns.
///
/// # Panics
///
/// - Panics if called re-entrantly from inside `f`. The per-thread `RefCell` is
///   already mutably borrowed at that point, so `borrow_mut` fails.
/// - Like any `thread_local!` access, may panic if used while the thread's
///   locals are being destroyed.
pub fn get_thread_rng<F, R>(f: F) -> R
where
    F: FnOnce(&mut Random) -> R,
{
    THREAD_RNG.with(|cell| {
        let mut guard = cell.borrow_mut();
        f(&mut guard)
    })
}
/// Extension methods that draw whole collections of samples from a
/// [`Distribution`] using a [`Random`] generator.
pub trait DistributionExt<T>: Distribution<T> + Sized {
    /// Returns an array of the given `shape` filled with samples from `self`.
    fn random_array<U, Sh>(&self, rng: &mut Random<U>, shape: Sh) -> Array<T, IxDyn>
    where
        U: Rng,
        Sh: Into<IxDyn>,
    {
        // `&Self` is itself a `Distribution<T>` (rand's blanket impl for `&D`) and
        // is always `Copy`, so the distribution can be handed over without
        // requiring `Self: Copy`.
        rng.sample_array(self, shape)
    }

    /// Returns a vector of `size` samples drawn from `self`.
    fn random_vec<U>(&self, rng: &mut Random<U>, size: usize) -> Vec<T>
    where
        U: Rng,
    {
        rng.sample_vec(self, size)
    }
}
/// Every `Distribution<T>` automatically gets the [`DistributionExt`] helpers.
impl<T, D: Distribution<T>> DistributionExt<T> for D {}
/// Free-function helpers for common sampling tasks on top of [`Random`].
pub mod sampling {
    use super::*;
    use rand_distr as rdistr;

    /// Draws an `f64` uniformly from the half-open interval `[0, 1)`.
    pub fn random_uniform01<R: Rng>(rng: &mut Random<R>) -> f64 {
        Uniform::new(0.0_f64, 1.0_f64)
            .expect("[0, 1) is always a valid uniform range")
            .sample(&mut rng.rng)
    }

    /// Draws from the standard normal distribution (mean 0, standard deviation 1).
    pub fn random_standard_normal<R: Rng>(rng: &mut Random<R>) -> f64 {
        rdistr::Normal::new(0.0_f64, 1.0_f64)
            .expect("N(0, 1) parameters are always valid")
            .sample(&mut rng.rng)
    }

    /// Draws from a normal distribution with the given `mean` and `std_dev`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Normal::new` rejects the parameters (for example
    /// a negative or NaN `std_dev`).
    pub fn random_normal<R: Rng>(rng: &mut Random<R>, mean: f64, std_dev: f64) -> f64 {
        rdistr::Normal::new(mean, std_dev)
            .expect("random_normal: invalid parameters for Normal::new")
            .sample(&mut rng.rng)
    }

    /// Draws from a log-normal distribution.
    ///
    /// `mean` and `std_dev` are passed straight to `rand_distr::LogNormal::new`.
    /// They are the parameters of the underlying normal distribution of `ln X`,
    /// not the mean and standard deviation of the result itself.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::LogNormal::new` rejects the parameters (for
    /// example a negative or NaN `std_dev`).
    pub fn random_lognormal<R: Rng>(rng: &mut Random<R>, mean: f64, std_dev: f64) -> f64 {
        rdistr::LogNormal::new(mean, std_dev)
            .expect("random_lognormal: invalid parameters for LogNormal::new")
            .sample(&mut rng.rng)
    }

    /// Draws from an exponential distribution with rate `lambda`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Exp::new` rejects `lambda` (for example a negative
    /// or NaN rate).
    pub fn random_exponential<R: Rng>(rng: &mut Random<R>, lambda: f64) -> f64 {
        rdistr::Exp::new(lambda)
            .expect("random_exponential: invalid rate for Exp::new")
            .sample(&mut rng.rng)
    }

    /// Returns an array of the given `shape` with integers drawn uniformly from
    /// the inclusive range `[min, max]`.
    ///
    /// # Panics
    ///
    /// Panics if `min > max`.
    pub fn random_integers<R: Rng, Sh>(
        rng: &mut Random<R>,
        min: i64,
        max: i64,
        shape: Sh,
    ) -> Array<i64, IxDyn>
    where
        Sh: Into<IxDyn>,
    {
        let dist = Uniform::new_inclusive(min, max)
            .expect("random_integers: min must not exceed max");
        rng.sample_array(dist, shape)
    }

    /// Returns an array of the given `shape` with floats drawn uniformly from
    /// the half-open range `[min, max)`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr::Uniform::new` rejects the bounds (e.g.
    /// `min >= max` or a non-finite bound).
    pub fn random_floats<R: Rng, Sh>(
        rng: &mut Random<R>,
        min: f64,
        max: f64,
        shape: Sh,
    ) -> Array<f64, IxDyn>
    where
        Sh: Into<IxDyn>,
    {
        let dist = Uniform::new(min, max)
            .expect("random_floats: invalid range for Uniform::new (requires min < max, both finite)");
        rng.sample_array(dist, shape)
    }

    /// Draws `sample_size` indices uniformly from `0..data_size` *with*
    /// replacement, e.g. for a bootstrap resample.
    ///
    /// # Panics
    ///
    /// Panics if `data_size == 0` while `sample_size > 0`, because there is
    /// nothing to sample from.
    pub fn bootstrap_indices<R: Rng>(
        rng: &mut Random<R>,
        data_size: usize,
        sample_size: usize,
    ) -> Vec<usize> {
        // An empty request is always satisfiable, even from empty data;
        // `Uniform::new(0, 0)` would otherwise fail here.
        if sample_size == 0 {
            return Vec::new();
        }
        let dist = Uniform::new(0, data_size)
            .expect("bootstrap_indices: data_size must be > 0 to draw a non-empty sample");
        rng.sample_vec(dist, sample_size)
    }

    /// Returns distinct indices from `0..data_size` in random order *without*
    /// replacement.
    ///
    /// The result has `min(sample_size, data_size)` elements: asking for more
    /// indices than exist silently returns all of them. The whole index range
    /// is materialised and shuffled, so the cost is O(`data_size`) regardless of
    /// `sample_size`.
    pub fn sample_without_replacement<R: Rng>(
        rng: &mut Random<R>,
        data_size: usize,
        sample_size: usize,
    ) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..data_size).collect();
        indices.shuffle(&mut rng.rng);
        indices.truncate(sample_size);
        indices
    }
}
/// Reproducible stream of pseudo-random `f64` values derived from a seed.
///
/// Each value depends only on `seed` and how many values have been drawn since
/// construction or the last reset, so the same seed always replays the same
/// sequence. Not intended for cryptographic use.
#[derive(Debug, Clone)]
pub struct DeterministicSequence {
    /// Seed mixed into every generated value.
    seed: u64,
    /// Number of values drawn since construction or the last reset.
    counter: u64,
}
impl DeterministicSequence {
    /// Creates a sequence positioned at its first value.
    pub fn new(seed: u64) -> Self {
        Self { seed, counter: 0 }
    }

    /// Returns the next value, in `[0, 1)`, and advances the sequence.
    ///
    /// This is SplitMix64 evaluated at position `counter`. The state
    /// `seed + counter * GOLDEN_GAMMA` goes through the SplitMix64 finalizer, and
    /// the top 53 bits of the result become the returned fraction.
    ///
    /// A full 64-bit mixer is required here. A 32-bit-style hash applied to a
    /// `u64` leaves the high bits nearly linear in `seed + counter`, so
    /// consecutive values form a slow ramp and adjacent seeds give shifted
    /// copies of the same sequence.
    pub fn next_f64(&mut self) -> f64 {
        // SplitMix64's Weyl-sequence increment (odd, ~2^64 / golden ratio).
        const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
        // 2^-53: scales a 53-bit integer into [0, 1) exactly.
        const INV_2_POW_53: f64 = 1.0 / (1u64 << 53) as f64;

        self.counter = self.counter.wrapping_add(1);
        let mut z = self
            .seed
            .wrapping_add(self.counter.wrapping_mul(GOLDEN_GAMMA));
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        // Keep 53 bits so the conversion is exact and the result can never round
        // up to 1.0 (dividing a full u64 by `u64::MAX as f64` can).
        (z >> 11) as f64 * INV_2_POW_53
    }

    /// Rewinds to the first value, so the same sequence is produced again.
    pub fn reset(&mut self) {
        self.counter = 0;
    }

    /// Returns the next `size` values of the sequence, in order.
    pub fn get_vec(&mut self, size: usize) -> Vec<f64> {
        (0..size).map(|_| self.next_f64()).collect()
    }

    /// Fills a dynamically-shaped array of the given `shape` with the next values
    /// of the sequence, in the array's standard (row-major) order.
    pub fn get_array<Sh>(&mut self, shape: Sh) -> Array<f64, IxDyn>
    where
        Sh: Into<IxDyn>,
    {
        let shape = shape.into();
        let size = shape.size();
        let values = self.get_vec(size);
        // `values.len()` equals the shape's element count by construction.
        Array::from_shape_vec(shape, values)
            .expect("get_array: value count always matches the shape's element count")
    }
}