use ndarray::{Array, Array2, ArrayView1, Axis};
#[cfg(feature = "quickcheck")]
use ndarray_rand::rand::{distributions::Distribution, thread_rng};
use ndarray::ShapeBuilder;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::{RandomExt, SamplingStrategy};
use quickcheck::quickcheck;
/// `Array::random` with a row-major shape produces exactly the requested
/// `m x n` shape, draws every element from `[0, 2)`, and lays the data out
/// in standard (C) order, for all `m, n` in `0..5`.
#[test]
fn test_dim() {
    let (mm, nn): (usize, usize) = (5, 5);
    let shapes = (0..mm).flat_map(|m| (0..nn).map(move |n| (m, n)));
    for (m, n) in shapes {
        let a = Array::random((m, n), Uniform::new(0., 2.));
        assert_eq!(a.shape(), &[m, n]);
        assert!(a.iter().all(|x| *x < 2.));
        assert!(a.iter().all(|x| *x >= 0.));
        assert!(a.is_standard_layout());
    }
}
/// Column-major (`.f()`) variant of `test_dim`. Shape and value range are
/// checked the same way. Fortran order is confirmed by the transpose being
/// in standard layout.
#[test]
fn test_dim_f() {
    let (mm, nn): (usize, usize) = (5, 5);
    let shapes = (0..mm).flat_map(|m| (0..nn).map(move |n| (m, n)));
    for (m, n) in shapes {
        let a = Array::random((m, n).f(), Uniform::new(0., 2.));
        assert_eq!(a.shape(), &[m, n]);
        assert!(a.iter().all(|x| *x < 2.));
        assert!(a.iter().all(|x| *x >= 0.));
        // A column-major array's transpose is row-major.
        assert!(a.t().is_standard_layout());
    }
}
/// `sample_axis` can be called on an array view, not only on an owned array.
/// Drawing all `m` rows without replacement returns an array with the same
/// shape as the source.
#[test]
fn sample_axis_on_view() {
    let m = 5;
    let a = Array::random((m, 4), Uniform::new(0., 2.));
    let samples = a.view().sample_axis(Axis(0), m, SamplingStrategy::WithoutReplacement);
    // Check the result rather than discarding it. Otherwise the test only
    // proves that the call does not panic.
    assert_eq!(samples.shape(), &[m, 4]);
}
/// Drawing more rows than the axis holds is impossible without replacement,
/// so `sample_axis` is expected to panic.
#[test]
#[should_panic]
fn oversampling_without_replacement_should_panic() {
    let rows = 5;
    let too_many = rows + 1;
    let a = Array::random((rows, 4), Uniform::new(0., 2.));
    let _samples = a.sample_axis(Axis(0), too_many, SamplingStrategy::WithoutReplacement);
}
// Sampling *with* replacement may request more items than the axis holds,
// and every drawn lane must still come from the source array.
// The generated dimensions are `u8`, widened to `usize`. Quickcheck may
// generate arbitrarily large `usize` values, which would overflow
// `m + n + 1` and make `Array::random` attempt a huge allocation.
quickcheck! {
    fn oversampling_with_replacement_is_fine(m: u8, n: u8) -> bool {
        let (m, n) = (usize::from(m), usize::from(n));
        let a = Array::random((m, n), Uniform::new(0., 2.));
        // Strictly larger than the length of either axis.
        let n_samples = m + n + 1;
        // Sampling from a zero-length axis is expected to panic, so those axes are skipped.
        if m != 0 && !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(0), n_samples) {
            return false;
        }
        if n != 0 && !sampling_works(&a, SamplingStrategy::WithReplacement, Axis(1), n_samples) {
            return false;
        }
        true
    }
}
// For any generated `SamplingStrategy`, drawing between 1 and the axis length
// items from a non-empty axis only returns lanes present in the source array.
// The dimensions are `u8` (widened to `usize`). This keeps the arrays small
// and rules out overflow in `m + 1` / `n + 1` when quickcheck generates very
// large `usize` values.
#[cfg(feature = "quickcheck")]
quickcheck! {
    fn sampling_behaves_as_expected(m: u8, n: u8, strategy: SamplingStrategy) -> bool {
        let (m, n) = (usize::from(m), usize::from(n));
        let a = Array::random((m, n), Uniform::new(0., 2.));
        let mut rng = thread_rng();
        if m != 0 {
            // Between 1 and m samples inclusive: valid for both strategies.
            let n_row_samples = Uniform::from(1..m + 1).sample(&mut rng);
            if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) {
                return false;
            }
        }
        if n != 0 {
            let n_col_samples = Uniform::from(1..n + 1).sample(&mut rng);
            if !sampling_works(&a, strategy, Axis(1), n_col_samples) {
                return false;
            }
        }
        true
    }
}
/// Draws `n_samples` lanes along `axis` of `a` using `strategy`. Returns
/// `true` only if every drawn lane is equal to some lane of `a` along the
/// same axis. Stops at the first lane that is not found.
fn sampling_works(
    a: &Array2<f64>,
    strategy: SamplingStrategy,
    axis: Axis,
    n_samples: usize,
) -> bool {
    let drawn = a.sample_axis(axis, n_samples, strategy);
    for drawn_lane in drawn.axis_iter(axis) {
        if !is_subset(a, &drawn_lane, axis) {
            return false;
        }
    }
    true
}
/// Returns `true` if `b` equals, element by element, at least one lane of
/// `a` taken along `axis`.
fn is_subset(a: &Array2<f64>, b: &ArrayView1<f64>, axis: Axis) -> bool {
    a.axis_iter(axis).any(|candidate| candidate == *b)
}
/// An axis of length zero has nothing to draw from, so sampling one element
/// without replacement is expected to panic.
#[test]
#[should_panic]
fn sampling_without_replacement_from_a_zero_length_axis_should_panic() {
    let cols = 5;
    let empty = Array::random((0, cols), Uniform::new(0., 2.));
    let _samples = empty.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement);
}
/// Even with replacement, there is nothing to pick from an axis of length
/// zero, so sampling one element is expected to panic.
#[test]
#[should_panic]
fn sampling_with_replacement_from_a_zero_length_axis_should_panic() {
    let cols = 5;
    let empty = Array::random((0, cols), Uniform::new(0., 2.));
    let _samples = empty.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement);
}