1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
use crate::prelude::*;
use rand::{distributions::Distribution, Rng};

/// Something whose parameters can be randomized by sampling from an arbitrary
/// [`Distribution`] over `T`.
pub trait Randomize<T> {
    /// Randomizes the parameters of `self` with values drawn from `dist`,
    /// using `rng` as the source of randomness.
    fn randomize<R, D>(&mut self, rng: &mut R, dist: &D)
    where
        R: Rng,
        D: Distribution<T>;
}

/// Generates `impl Randomize<f32>` for a tensor type.
///
/// The first argument is the tensor type's name. Any further arguments name its
/// `const usize` dimension parameters, in declaration order. The trailing tape
/// parameter `H` is always appended.
macro_rules! tensor_impl {
    ($tensor:ident $(, $dim:ident)*) => {
        impl<$(const $dim: usize,)* H> Randomize<f32> for $tensor<$($dim,)* H> {
            /// Fills `self.mut_data()` with data from the distribution `D`
            fn randomize<R, D>(&mut self, rng: &mut R, dist: &D)
            where
                R: Rng,
                D: Distribution<f32>,
            {
                // Every element is overwritten with an independent sample from `dist`.
                <Self as HasDevice>::Device::fill(self.mut_data(), &mut |slot| {
                    *slot = dist.sample(rng);
                });
            }
        }
    };
}

tensor_impl!(Tensor0D);
tensor_impl!(Tensor1D, M);
tensor_impl!(Tensor2D, M, N);
tensor_impl!(Tensor3D, M, N, O);
tensor_impl!(Tensor4D, M, N, O, P);

#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, SeedableRng};
    use rand_distr::Standard;

    /// Randomizing a zeroed tensor must overwrite its contents with samples
    /// that lie inside the support of the distribution used.
    #[test]
    fn test_randomize() {
        let mut t = Tensor1D::<100>::zeros();
        assert_eq!(t.data(), &[0.0; 100]);

        // A fixed seed keeps the test reproducible: any failure can be replayed.
        let mut rng = StdRng::seed_from_u64(0);
        t.randomize(&mut rng, &Standard);
        assert_ne!(t.data(), &[0.0; 100]);

        // `Standard` samples `f32` uniformly from the half-open range [0, 1),
        // so every element must land inside it.
        assert!(t.data().iter().all(|v| (0.0..1.0).contains(v)));
    }
}