use crate::rand::distributions::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{thread_rng, Rng, SeedableRng};
use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, DataOwned, RawData, Data, Dimension};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};
/// Re-exports everything from the `rand` crate.
///
/// Going through this path keeps callers on the same `rand` version that
/// this crate itself depends on.
pub mod rand {
    pub use ::rand::*;
}
/// Re-exports everything from the `rand_distr` crate.
///
/// Going through this path keeps callers on the same `rand_distr` version
/// that this crate itself depends on.
pub mod rand_distr {
    pub use ::rand_distr::*;
}
/// Random construction and random sampling for n-dimensional arrays.
///
/// Each operation comes in two flavours. The plain method creates its own
/// random number generator internally. The `_using` variant takes the
/// generator as an argument, so the caller controls how it is seeded.
pub trait RandomExt<S, A, D>
where
    D: Dimension,
    S: RawData<Elem = A>,
{
    /// Builds an owned array of the given `shape`.
    ///
    /// Every element is filled with a fresh sample from `distribution`. The
    /// generator is created internally; use [`RandomExt::random_using`] to
    /// supply one yourself.
    fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
    where
        Sh: ShapeBuilder<Dim = D>,
        S: DataOwned<Elem = A>,
        IdS: Distribution<S::Elem>;

    /// Same as [`RandomExt::random`], but every sample is drawn from the
    /// caller-supplied generator `rng`.
    fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        Sh: ShapeBuilder<Dim = D>,
        S: DataOwned<Elem = A>,
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized;

    /// Picks `n_samples` lanes along `axis` according to `strategy`.
    ///
    /// The selected lanes are copied into a new owned array. That array has
    /// the same number of dimensions, and its `axis` has length `n_samples`.
    /// The generator is created internally; use
    /// [`RandomExt::sample_axis_using`] to supply one yourself.
    ///
    /// # Panics
    ///
    /// - if `axis` is out of bounds for the array;
    /// - with [`SamplingStrategy::WithoutReplacement`], if `n_samples`
    ///   exceeds the length of `axis`;
    /// - with [`SamplingStrategy::WithReplacement`], if `axis` has length 0
    ///   and samples are requested.
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        D: RemoveAxis,
        S: Data<Elem = A>,
        A: Copy;

    /// Same as [`RandomExt::sample_axis`], but the indices are drawn from
    /// the caller-supplied generator `rng`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`RandomExt::sample_axis`].
    fn sample_axis_using<R>(
        &self,
        axis: Axis,
        n_samples: usize,
        strategy: SamplingStrategy,
        rng: &mut R,
    ) -> Array<A, D>
    where
        D: RemoveAxis,
        S: Data<Elem = A>,
        A: Copy,
        R: Rng + ?Sized;
}
/// Implements [`RandomExt`] for every `ndarray` array type.
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
    S: RawData<Elem = A>,
    D: Dimension,
{
    /// Builds an array of `shape`, using a generator obtained from `get_rng()`.
    fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::random_using(shape, dist, &mut get_rng())
    }

    /// Builds an array of `shape`.
    ///
    /// Each element is produced by one call to `dist.sample(rng)`.
    fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_shape_simple_fn(shape, move || dist.sample(rng))
    }

    /// Samples lanes along `axis`, using a generator obtained from `get_rng()`.
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis,
    {
        self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
    }

    /// Draws `n_samples` indices along `axis` from `rng` according to
    /// `strategy`, then copies the corresponding lanes via `select`.
    ///
    /// # Panics
    ///
    /// - if `axis` is out of bounds (from `len_of`);
    /// - with `WithoutReplacement`, if `n_samples` exceeds the axis length
    ///   (from `rand::seq::index::sample`);
    /// - with `WithReplacement`, if the axis has length 0 and
    ///   `n_samples > 0`.
    fn sample_axis_using<R>(
        &self,
        axis: Axis,
        n_samples: usize,
        strategy: SamplingStrategy,
        rng: &mut R,
    ) -> Array<A, D>
    where
        R: Rng + ?Sized,
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis,
    {
        let len = self.len_of(axis);
        let indices: Vec<usize> = match strategy {
            SamplingStrategy::WithReplacement => {
                if n_samples == 0 {
                    // `Uniform` cannot be built over an empty range, but zero
                    // draws need no distribution. Returning early makes an empty
                    // request on an empty axis behave like `WithoutReplacement`.
                    Vec::new()
                } else {
                    assert!(
                        len > 0,
                        "cannot sample {} element(s) with replacement from an axis of length 0",
                        n_samples
                    );
                    let distribution = Uniform::from(0..len);
                    (0..n_samples).map(|_| distribution.sample(rng)).collect()
                }
            }
            SamplingStrategy::WithoutReplacement => index::sample(rng, len, n_samples).into_vec(),
        };
        self.select(axis, &indices)
    }
}
/// Chooses how indices are drawn when sampling lanes along an axis.
///
/// This is a plain fieldless enum, so it is `Copy` and can be compared and
/// hashed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Each index is drawn independently, so the same index may be picked
    /// more than once.
    WithReplacement,
    /// Indices are drawn without repetition, so each index is picked at most
    /// once.
    WithoutReplacement,
}
#[cfg(feature = "quickcheck")]
impl Arbitrary for SamplingStrategy {
    /// Maps a single arbitrary `bool` from the generator onto one of the two
    /// strategies.
    fn arbitrary<G: Gen>(g: &mut G) -> Self {
        match bool::arbitrary(g) {
            true => SamplingStrategy::WithReplacement,
            false => SamplingStrategy::WithoutReplacement,
        }
    }
}
/// Returns a new `SmallRng` seeded from the thread-local generator.
///
/// # Panics
///
/// Panics if seeding the `SmallRng` from `thread_rng()` fails.
fn get_rng() -> SmallRng {
    let seed_source = thread_rng();
    let seeded = SmallRng::from_rng(seed_source);
    seeded.expect("create SmallRng from thread_rng failed")
}
/// Adapter that turns a distribution over `f64` into a distribution over
/// `f32`.
///
/// Each sample is drawn from the wrapped distribution as an `f64` and then
/// narrowed to `f32` with an `as` cast.
#[derive(Copy, Clone, Debug)]
pub struct F32<S>(pub S);

impl<S> Distribution<f32> for F32<S>
where
    S: Distribution<f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
        let inner = &self.0;
        let wide: f64 = inner.sample(rng);
        wide as f32
    }
}