#![warn(missing_docs)]
use crate::rand::distr::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{rng, Rng, SeedableRng};
use ndarray::{Array, ArrayRef, Axis, RemoveAxis, ShapeBuilder};
use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData};
#[cfg(feature = "quickcheck")]
use quickcheck::{Arbitrary, Gen};
/// Re-export of the `rand` crate.
///
/// Items such as RNGs and distributions can be reached through this path, so
/// they come from the same `rand` version this crate is built against.
pub mod rand
{
    pub use rand::*;
}
/// Re-export of the `rand_distr` crate.
///
/// Additional probability distributions can be reached through this path, so
/// they come from the same `rand_distr` version this crate is built against.
pub mod rand_distr
{
    pub use rand_distr::*;
}
/// Constructors for arrays filled with random values, and random sampling
/// along an axis, for [`ArrayBase`].
///
/// Methods without a `_using` suffix create a fresh [`SmallRng`] per call,
/// seeded from `rand::rng()`. The `_using` variants take a caller-supplied
/// [`Rng`] instead, e.g. a seeded one for reproducible results.
pub trait RandomExt<S, A, D>
where
    S: RawData<Elem = A>,
    D: Dimension,
{
    /// Create an array of shape `shape` whose elements are each drawn from
    /// `distribution`.
    ///
    /// ***Panics*** under the same conditions as
    /// [`ArrayBase::from_shape_simple_fn`] (notably if the number of elements
    /// overflows `isize`).
    fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>;

    /// Create an array of shape `shape` whose elements are each drawn from
    /// `distribution`, using `rng` as the source of randomness.
    ///
    /// ***Panics*** under the same conditions as
    /// [`ArrayBase::from_shape_simple_fn`] (notably if the number of elements
    /// overflows `isize`).
    fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>;

    /// Sample `n_samples` subviews of `self` along `axis` (e.g. rows, for
    /// `Axis(0)` of a 2-D array) and return them as a new owned array.
    ///
    /// With [`SamplingStrategy::WithReplacement`] every index is drawn
    /// independently and uniformly, so the same subview may appear more than
    /// once. With [`SamplingStrategy::WithoutReplacement`] the indices are
    /// distinct.
    ///
    /// The result has the same shape as `self`, except that the length of
    /// `axis` is `n_samples`.
    ///
    /// ***Panics*** when:
    /// - `axis` is out of bounds,
    /// - sampling without replacement and `n_samples` is greater than the
    ///   length of `axis`,
    /// - drawing at least one sample with replacement from an axis of length 0.
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis;

    /// Like [`RandomExt::sample_axis`], but uses `rng` as the source of
    /// randomness.
    ///
    /// ***Panics*** under the same conditions as [`RandomExt::sample_axis`].
    fn sample_axis_using<R>(
        &self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
    ) -> Array<A, D>
    where
        R: Rng + ?Sized,
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis;
}
/// Random sampling along an axis for [`ArrayRef`].
///
/// This is the reference-level counterpart of the sampling methods on
/// `RandomExt`. `sample_axis` creates a fresh [`SmallRng`] per call, seeded
/// from `rand::rng()`, while `sample_axis_using` takes a caller-supplied
/// [`Rng`].
pub trait RandomRefExt<A, D>
where D: Dimension
{
    /// Sample `n_samples` subviews of `self` along `axis` (e.g. rows, for
    /// `Axis(0)` of a 2-D array) and return them as a new owned array.
    ///
    /// With [`SamplingStrategy::WithReplacement`] every index is drawn
    /// independently and uniformly, so the same subview may appear more than
    /// once. With [`SamplingStrategy::WithoutReplacement`] the indices are
    /// distinct.
    ///
    /// The result has the same shape as `self`, except that the length of
    /// `axis` is `n_samples`.
    ///
    /// ***Panics*** when:
    /// - `axis` is out of bounds,
    /// - sampling without replacement and `n_samples` is greater than the
    ///   length of `axis`,
    /// - drawing at least one sample with replacement from an axis of length 0.
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        A: Copy,
        D: RemoveAxis;

    /// Like [`RandomRefExt::sample_axis`], but uses `rng` as the source of
    /// randomness.
    ///
    /// ***Panics*** under the same conditions as [`RandomRefExt::sample_axis`].
    fn sample_axis_using<R>(
        &self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
    ) -> Array<A, D>
    where
        R: Rng + ?Sized,
        A: Copy,
        D: RemoveAxis;
}
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
S: RawData<Elem = A>,
D: Dimension,
{
fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>,
{
Self::random_using(shape, dist, &mut get_rng())
}
fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>,
{
Self::from_shape_simple_fn(shape, move || dist.sample(rng))
}
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis,
{
(**self).sample_axis(axis, n_samples, strategy)
}
fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis,
{
(**self).sample_axis_using(axis, n_samples, strategy, rng)
}
}
impl<A, D> RandomRefExt<A, D> for ArrayRef<A, D>
where D: Dimension
{
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        A: Copy,
        D: RemoveAxis,
    {
        // Seed a fresh `SmallRng` for this call and defer to the `_using` variant.
        self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
    }

    fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
    where
        R: Rng + ?Sized,
        A: Copy,
        D: RemoveAxis,
    {
        // `len_of` panics if `axis` is out of bounds, whichever strategy is used.
        let axis_len = self.len_of(axis);
        let indices: Vec<usize> = match strategy {
            // Nothing to draw. Skip building `Uniform`, which rejects the empty
            // range `0..0` and would otherwise make an empty axis panic even
            // though no index is needed. This matches `WithoutReplacement`:
            // `index::sample` only panics when `n_samples > axis_len`.
            SamplingStrategy::WithReplacement if n_samples == 0 => Vec::new(),
            SamplingStrategy::WithReplacement => {
                let distribution = Uniform::new(0, axis_len)
                    .expect("cannot sample with replacement from an axis of length 0");
                (0..n_samples).map(|_| distribution.sample(rng)).collect()
            }
            // Panics inside `rand` if `n_samples > axis_len`.
            SamplingStrategy::WithoutReplacement => index::sample(rng, axis_len, n_samples).into_vec(),
        };
        self.select(axis, &indices)
    }
}
/// How indices are chosen when sampling subviews along an axis with
/// `sample_axis` / `sample_axis_using`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy
{
    /// Indices are drawn independently, so the same index may be picked more
    /// than once.
    WithReplacement,
    /// Each index is picked at most once.
    WithoutReplacement,
}
#[cfg(feature = "quickcheck")]
impl Arbitrary for SamplingStrategy
{
    /// Choose one of the two strategies from a single arbitrary `bool` drawn
    /// from `g` (`true` picks `WithReplacement`).
    fn arbitrary(g: &mut Gen) -> Self
    {
        match bool::arbitrary(g) {
            true => Self::WithReplacement,
            false => Self::WithoutReplacement,
        }
    }
}
/// Build a fresh [`SmallRng`] seeded from the generator returned by
/// `rand::rng()`.
fn get_rng() -> SmallRng
{
    let mut seed_source = rng();
    SmallRng::from_rng(&mut seed_source)
}