use crate::rand::distr::{Distribution, Uniform};
use crate::rand::rngs::SmallRng;
use crate::rand::seq::index;
use crate::rand::{rng, Rng, SeedableRng};
use kn0sys_ndarray::{
Array, Axis, RemoveAxis, ShapeBuilder,
ArrayBase, Data, DataOwned, Dimension, RawData
};
use std::convert::TryFrom;
/// Re-export of everything in the `rand` crate this file is compiled against.
///
/// The `use crate::rand::...` imports at the top of this file resolve through
/// this module.
// NOTE(review): presumably exposed so downstream code can name the exact `rand`
// version used here without a separate dependency; confirm intent.
pub mod rand
{
pub use rand::*;
}
/// Re-export of everything in the `rand_distr` crate this file is compiled against.
// NOTE(review): presumably exposed so downstream code can pick distributions that
// are compatible with the re-exported `rand`; confirm intent.
pub mod rand_distr
{
pub use rand_distr::*;
}
/// Extension trait adding random construction and random axis sampling to
/// array types.
///
/// Type parameters:
/// * `S` - the array's storage type; its element type is `A`.
/// * `A` - the element type.
/// * `D` - the array's dimension type.
///
/// This file implements the trait for `ArrayBase<S, D>`.
pub trait RandomExt<S, A, D>
where
S: RawData<Elem = A>,
D: Dimension,
{
/// Builds an array of shape `shape` whose elements are sampled from
/// `distribution`, without the caller supplying a random number generator.
///
/// See [`random_using`](RandomExt::random_using) for the variant that takes
/// an explicit generator.
fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>;
/// Builds an array of shape `shape` whose elements are sampled from
/// `distribution` using the caller-supplied generator `rng`.
///
/// Passing a seeded generator makes the result reproducible.
fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
where
IdS: Distribution<S::Elem>,
R: Rng + ?Sized,
S: DataOwned<Elem = A>,
Sh: ShapeBuilder<Dim = D>;
/// Returns a new owned array made of `n_samples` sub-arrays picked at random
/// along `axis`, without the caller supplying a random number generator.
///
/// `strategy` chooses whether the same index along `axis` may be picked more
/// than once. See [`sample_axis_using`](RandomExt::sample_axis_using) for the
/// variant that takes an explicit generator.
fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
where
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;
/// Returns a new owned array made of `n_samples` sub-arrays picked at random
/// along `axis`, drawing randomness from the caller-supplied generator `rng`.
///
/// `strategy` chooses whether the same index along `axis` may be picked more
/// than once.
///
/// # Panics
///
/// The `ArrayBase` implementation in this file panics when the uniform index
/// distribution cannot be built for `SamplingStrategy::WithReplacement`.
// NOTE(review): per the `rand` documentation that happens when `axis` has
// length 0, and `index::sample` (used for `WithoutReplacement`) panics when
// `n_samples` exceeds the axis length; confirm against the pinned version.
fn sample_axis_using<R>(
&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R,
) -> Array<A, D>
where
R: Rng + ?Sized,
A: Copy,
S: Data<Elem = A>,
D: RemoveAxis;
}
/// [`RandomExt`] for every `ArrayBase<S, D>`: random construction for owned
/// storage and random sampling along an axis for readable storage.
impl<S, A, D> RandomExt<S, A, D> for ArrayBase<S, D>
where
    S: RawData<Elem = A>,
    D: Dimension,
{
    /// Builds an array of shape `shape` with elements sampled from `dist`.
    ///
    /// Randomness comes from a fresh `SmallRng` obtained via `get_rng()`, so
    /// the output is not reproducible between calls. Use
    /// [`random_using`](RandomExt::random_using) with a seeded generator when
    /// reproducibility matters.
    fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::random_using(shape, dist, &mut get_rng())
    }

    /// Builds an array of shape `shape`, filling every element with one sample
    /// drawn from `dist` using `rng`.
    fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        S: DataOwned<Elem = A>,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_shape_simple_fn(shape, move || dist.sample(rng))
    }

    /// Samples `n_samples` sub-arrays along `axis` according to `strategy`.
    ///
    /// Randomness comes from a fresh `SmallRng` obtained via `get_rng()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as
    /// [`sample_axis_using`](RandomExt::sample_axis_using).
    fn sample_axis(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy) -> Array<A, D>
    where
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis,
    {
        self.sample_axis_using(axis, n_samples, strategy, &mut get_rng())
    }

    /// Samples `n_samples` sub-arrays along `axis` according to `strategy`,
    /// drawing randomness from `rng`, and returns them as a new owned array.
    ///
    /// * `SamplingStrategy::WithReplacement` draws each index independently
    ///   and uniformly from `0..len_of(axis)`, so an index may repeat.
    /// * `SamplingStrategy::WithoutReplacement` delegates to
    ///   `rand::seq::index::sample` to pick the indices.
    ///
    /// # Panics
    ///
    /// * With replacement: if `axis` has length 0, since no uniform index
    ///   distribution exists over an empty range. This holds even when
    ///   `n_samples` is 0.
    /// * Without replacement: if `n_samples` exceeds the length of `axis`
    ///   (the panic is raised by `index::sample`).
    fn sample_axis_using<R>(&self, axis: Axis, n_samples: usize, strategy: SamplingStrategy, rng: &mut R) -> Array<A, D>
    where
        R: Rng + ?Sized,
        A: Copy,
        S: Data<Elem = A>,
        D: RemoveAxis,
    {
        let axis_len = self.len_of(axis);
        let indices: Vec<_> = match strategy {
            SamplingStrategy::WithReplacement => {
                // `0..axis_len` is an empty range only when the axis is empty,
                // which is the one way building the distribution can fail.
                let distribution = Uniform::try_from(0..axis_len)
                    .expect("cannot sample with replacement along an axis of length 0");
                (0..n_samples).map(|_| distribution.sample(rng)).collect()
            }
            SamplingStrategy::WithoutReplacement => index::sample(rng, axis_len, n_samples).into_vec(),
        };
        self.select(axis, &indices)
    }
}
/// How indices are chosen when sampling sub-arrays along an axis with
/// `RandomExt::sample_axis` / `RandomExt::sample_axis_using`.
///
/// The enum is a plain tag with no payload, so it is `Copy` and comparable:
/// callers can reuse a value after passing it by value and can test which
/// strategy they hold with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStrategy {
    /// Each index is drawn independently, so the same sub-array may appear
    /// more than once in the result.
    WithReplacement,
    /// Indices are chosen through `rand::seq::index::sample`, so each
    /// sub-array appears at most once in the result.
    WithoutReplacement,
}
/// Creates the generator used by the `RandomExt` methods that do not take an
/// explicit `rng` argument.
///
/// A new `SmallRng` is built on every call, seeded from the generator returned
/// by `rand::rng()`.
fn get_rng() -> SmallRng {
    let mut seed_source = rng();
    SmallRng::from_rng(&mut seed_source)
}