// Copyright 2016 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//
// Vendored from ndarray-rand to support newer versions of the rand
// crate. This can be removed once ndarray-rand and rand are at
// version 1.

//! Constructors for randomized arrays. `rand` integration for `ndarray`.
//!
//! See [**`RandomExt`**](trait.RandomExt.html) for usage examples.
//!
//! **Note:** `ndarray-rand` depends on `rand` 0.7. If you use any other items
//! from `rand`, you need to specify a compatible version of `rand` in your
//! `Cargo.toml`. If you want to use an RNG or distribution from another crate
//! with `ndarray-rand`, you need to make sure that crate also depends on the
//! correct version of `rand`. Otherwise, the compiler will report that the
//! items are not compatible (e.g. that a type doesn't implement a necessary
//! trait).

use rand::distributions::Distribution;
use rand::rngs::SmallRng;
use rand::{thread_rng, Rng, SeedableRng};

use ndarray::ShapeBuilder;
use ndarray::{ArrayBase, DataOwned, Dimension};

/// Constructors for n-dimensional arrays with random elements.
///
/// This trait extends ndarray’s `ArrayBase` and cannot be implemented
/// for other types.
///
/// The default RNG is a fast, automatically seeded RNG (currently
/// [`rand::rngs::SmallRng`](https://docs.rs/rand/0.7/rand/rngs/struct.SmallRng.html)
/// seeded from [`rand::thread_rng`](https://docs.rs/rand/0.7/rand/fn.thread_rng.html)).
///
/// Note that `SmallRng` is cheap to initialize and fast, but it may generate
/// low-quality random numbers, and reproducibility is not guaranteed. See its
/// documentation for details. You can select a different RNG with
/// [`.random_using()`](#tymethod.random_using).
pub trait RandomExt<S, D>
where
    S: DataOwned,
    D: Dimension,
{
    /// Create an array with shape `shape` with elements drawn from
    /// `distribution` using the default RNG.
    ///
    /// ***Panics*** if creation of the RNG fails or if the number of elements
    /// overflows `usize`.
    fn random<Sh, IdS>(shape: Sh, distribution: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        Sh: ShapeBuilder<Dim = D>;

    /// Create an array with shape `shape` with elements drawn from
    /// `distribution`, using the given RNG `rng`.
    ///
    /// ***Panics*** if the number of elements overflows `usize`.
    fn random_using<Sh, IdS, R>(shape: Sh, distribution: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        Sh: ShapeBuilder<Dim = D>;
}

impl<S, D> RandomExt<S, D> for ArrayBase<S, D>
where
    S: DataOwned,
    D: Dimension,
{
    fn random<Sh, IdS>(shape: Sh, dist: IdS) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        Sh: ShapeBuilder<Dim = D>,
    {
        // Seed a fresh SmallRng from the thread-local RNG, then delegate.
        let mut rng =
            SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed");
        Self::random_using(shape, dist, &mut rng)
    }

    fn random_using<Sh, IdS, R>(shape: Sh, dist: IdS, rng: &mut R) -> ArrayBase<S, D>
    where
        IdS: Distribution<S::Elem>,
        R: Rng + ?Sized,
        Sh: ShapeBuilder<Dim = D>,
    {
        // Fill every element with one independent sample from `dist`.
        Self::from_shape_fn(shape, |_| dist.sample(rng))
    }
}

/// A wrapper type that adapts an `f64` distribution into an `f32` one by
/// casting each sample with `as f32`.
#[derive(Copy, Clone, Debug)]
pub struct F32<S>(pub S);

impl<S> Distribution<f32> for F32<S>
where
    S: Distribution<f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f32 {
        self.0.sample(rng) as f32
    }
}

#[cfg(test)]
mod tests {
    use ndarray::Array;
    use rand::distributions::Uniform;

    use crate::ndarray_rand::RandomExt;

    #[test]
    fn test_dim() {
        let (mm, nn) = (5, 5);
        for m in 0..mm {
            for n in 0..nn {
                let a = Array::random((m, n), Uniform::new(0., 2.));
                assert_eq!(a.shape(), &[m, n]);
                assert!(a.iter().all(|x| *x < 2.));
                assert!(a.iter().all(|x| *x >= 0.));
            }
        }
    }
}
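
// Hypothetical usage sketch (not part of upstream ndarray-rand): passing a
// seeded RNG to `random_using` makes array contents repeatable. Assumes `rand`
// 0.7 with its `small_rng` feature enabled (the `SmallRng` import above already
// needs it). Per the `SmallRng` docs, the concrete values are not guaranteed to
// be stable across platforms or `rand` releases, so this only compares two RNGs
// seeded identically within one run.
#[cfg(test)]
mod seeded_rng_sketch {
    use ndarray::Array;
    use rand::distributions::Uniform;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    use crate::ndarray_rand::RandomExt;

    #[test]
    fn random_using_with_same_seed_is_repeatable() {
        // Two independent RNGs built from the same seed yield the same stream,
        // so the arrays filled from them are element-wise equal.
        let mut rng_a = SmallRng::seed_from_u64(42);
        let mut rng_b = SmallRng::seed_from_u64(42);
        let a = Array::random_using((3, 4), Uniform::new(0., 1.), &mut rng_a);
        let b = Array::random_using((3, 4), Uniform::new(0., 1.), &mut rng_b);
        assert_eq!(a.shape(), &[3, 4]);
        assert_eq!(a, b);
    }
}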
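
// Hypothetical usage sketch (not part of upstream ndarray-rand): `F32` lets an
// `f64` distribution such as `Uniform<f64>` fill an `f32` array. The `as f32`
// cast rounds to nearest, so an `f64` sample just below 2.0 can become exactly
// 2.0; the upper bound is therefore checked inclusively. Assumes `rand` 0.7
// with the `small_rng` feature, as the imports above already do.
#[cfg(test)]
mod f32_wrapper_sketch {
    use ndarray::{Array, Array1};
    use rand::distributions::Uniform;
    use rand::rngs::SmallRng;
    use rand::SeedableRng;

    use crate::ndarray_rand::{RandomExt, F32};

    #[test]
    fn f32_wrapper_fills_f32_array() {
        let mut rng = SmallRng::seed_from_u64(7);
        // Each f64 sample drawn from `Uniform` is cast to f32 by `F32::sample`.
        let a: Array1<f32> =
            Array::random_using(100, F32(Uniform::new(0.0_f64, 2.0)), &mut rng);
        assert_eq!(a.len(), 100);
        assert!(a.iter().all(|&x| (0.0..=2.0).contains(&x)));
    }
}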