1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright 2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// Vendored from ndarray-rand to support newer versions of the rand
// crate. This can be removed once ndarray-rand and rand are at
// version 1.

//! Constructors for randomized arrays. `rand` integration for `ndarray`.
//!
//! See [**`RandomExt`**](trait.RandomExt.html) for usage examples.
//!
//! **Note:** `ndarray-rand` depends on `rand` 0.7. If you use any other items
//! from `rand`, you need to specify a compatible version of `rand` in your
//! `Cargo.toml`. If you want to use an RNG or distribution from another crate
//! with `ndarray-rand`, you need to make sure that crate also depends on the
//! correct version of `rand`. Otherwise, the compiler will report errors
//! saying that the items are not compatible (e.g. that a type doesn't
//! implement a necessary trait).

use rand::distributions::Distribution;
use rand::rngs::SmallRng;
use rand::{thread_rng, Rng, SeedableRng};

use ndarray::ShapeBuilder;
use ndarray::{ArrayBase, DataOwned, Dimension};

/// Constructors for n-dimensional arrays with random elements.
///
/// This trait extends ndarray's `ArrayBase` and is not intended to be
/// implemented for other types.
///
/// The default RNG is a fast automatically seeded rng (currently
/// [`rand::rngs::SmallRng`](https://docs.rs/rand/0.5/rand/rngs/struct.SmallRng.html)
/// seeded from [`rand::thread_rng`](https://docs.rs/rand/0.5/rand/fn.thread_rng.html)).
///
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for information. You can select a different RNG with
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, D>
where
    S: DataOwned,
    D: Dimension,
{
    /// Create an array with shape `shape` with elements drawn from
    /// `distribution` using the default RNG.
    ///
    /// `Sh` is any shape builder whose dimension type is `D`, and `IdS` is
    /// any distribution that produces values of the array's element type
    /// (`S::Elem`).
    ///
    /// ***Panics*** if creation of the RNG fails or if the number of elements
    /// overflows usize.
    fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        Sh: ShapeBuilder<Dim = D>;

    /// Create an array with shape `shape` with elements drawn from
    /// `distribution`, using a specific Rng `rng`.
    ///
    /// The RNG is borrowed mutably and may be unsized (`R: ?Sized`), so a
    /// `&mut dyn` RNG reference is accepted as well as a concrete type.
    ///
    /// ***Panics*** if the number of elements overflows usize.
    fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        Sh: ShapeBuilder<Dim = D>;
}

impl<S, D> RandomExt<S, D> for ArrayBase<S, D>
where
    S: DataOwned,
    D: Dimension,
{
    fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        Sh: ShapeBuilder<Dim = D>,
    {
        let mut rng =
            SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed");
        Self::random_using(shape, dist, &mut rng)
    }

    fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        Sh: ShapeBuilder<Dim = D>,
    {
        Self::from_shape_fn(shape, |_| dist.sample(rng))
    }
}

/// A wrapper type that adapts a distribution over `f64` into a distribution
/// over `f32`.
///
/// The wrapped distribution is stored in the public tuple field. The
/// `Distribution<f32>` implementation for this type draws an `f64` from the
/// wrapped distribution and converts it with an `as f32` cast.
#[derive(Copy, Clone, Debug)]
pub struct F32<S>(pub S);

impl<S> Distribution<f32> for F32<S>
where
    S: Distribution<f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
        self.0.sample(rng) as f32
    }
}

#[cfg(test)]
mod tests {
    use ndarray::Array;
    use rand::distributions::Uniform;

    use crate::ndarray_rand::RandomExt;

    /// For every 2-D shape with each axis length in `0..5`, a randomly
    /// filled array must report the requested shape and contain only values
    /// inside the half-open range the `Uniform` was built with.
    #[test]
    fn test_dim() {
        let row_limit: usize = 5;
        let col_limit: usize = 5;
        let shapes = (0..row_limit).flat_map(|m| (0..col_limit).map(move |n| (m, n)));

        for (m, n) in shapes {
            let sampled = Array::random((m, n), Uniform::new(0., 2.));
            assert_eq!(sampled.shape(), &[m, n]);
            assert!(sampled.iter().all(|&value| value < 2.));
            assert!(sampled.iter().all(|&value| value >= 0.));
        }
    }
}