1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use burn_tensor::Shape;

use crate::{
    compute::{compute_client, StaticKernel},
    element::WgpuElement,
    kernel::{
        prng::base::{make_args_buffer, make_info_buffer},
        prng_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT,
    },
    ops::numeric::empty_device,
    tensor::WgpuTensor,
    GraphicsApi, WgpuDevice,
};

use super::base::Prng;

/// Marker type identifying the kernel source used to generate normally distributed values.
struct NormalPrng;

impl StaticKernelSource for NormalPrng {
    /// Assembles the kernel source template for the normal PRNG.
    ///
    /// Starts from the shared [`Prng`] template, registers `num_args` as `2`, registers the
    /// contents of `normal_inner_loop.wgsl` under the `prng_loop` key, and finally adds the
    /// `box_muller_transform.wgsl` template.
    fn source() -> SourceTemplate {
        let inner_loop = include_str!("../../template/prng/normal_inner_loop.wgsl");
        let box_muller = include_str!("../../template/prng/box_muller_transform.wgsl");

        let template = Prng::source();
        let template = template.register("num_args", "2");
        let template = template.register("prng_loop", inner_loop);
        template.add_template(box_muller)
    }
}

/// Pseudo-random generator for the normal distribution.
///
/// Allocates an output tensor of the requested `shape` on `device`, launches the
/// [`NormalPrng`] kernel on it and returns the tensor.
///
/// # Arguments
///
/// * `shape` - Shape of the tensor to create.
/// * `device` - Device on which the tensor is allocated and the kernel is executed.
/// * `mean` - First value written to the kernel arguments buffer.
/// * `std` - Second value written to the kernel arguments buffer.
pub fn random_normal<G: GraphicsApi, E: WgpuElement, const D: usize>(
    shape: Shape<D>,
    device: &WgpuDevice,
    mean: E,
    std: E,
) -> WgpuTensor<E, D> {
    // Must be even. NOTE(review): presumably because the Box-Muller transform yields values in
    // pairs; confirm against the WGSL template.
    const N_VALUES_PER_THREAD: usize = 128;

    let client = compute_client::<G>(device);
    // Read the element count first so the shape can be moved into the output tensor.
    let num_elements = shape.num_elements();

    let output = empty_device(client.clone(), device.clone(), shape);
    let info = make_info_buffer(client.clone(), N_VALUES_PER_THREAD);
    let args = make_args_buffer(client.clone(), &[mean, std]);

    let launch_grid = prng_workgroup(num_elements, WORKGROUP_DEFAULT, N_VALUES_PER_THREAD);
    let kernel = StaticKernel::<
        KernelSettings<NormalPrng, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
    >::new(launch_grid);

    // Buffer order seen by the kernel: output, info, args.
    client.execute(Box::new(kernel), &[&output.handle, &info, &args]);

    output
}

#[cfg(test)]
mod tests {

    use burn_tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
    use serial_test::serial;

    use crate::{kernel::prng::base::tests::calculate_bin_stats, tests::TestBackend, WgpuDevice};

    /// Two consecutive draws with the same parameters must differ element by element.
    #[test]
    #[serial]
    fn subsequent_calls_give_different_tensors() {
        TestBackend::seed(0);
        let shape = [4, 5];
        let device = WgpuDevice::default();

        let tensor_1 =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Normal(0., 1.), &device);
        let tensor_2 =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Normal(0., 1.), &device);

        // Read each tensor back once instead of once per compared element.
        let values_1 = tensor_1.into_data().value;
        let values_2 = tensor_2.into_data().value;

        assert_eq!(values_1.len(), 4 * 5);
        assert_eq!(values_2.len(), 4 * 5);
        for (a, b) in values_1.iter().zip(values_2.iter()) {
            assert!(a != b);
        }
    }

    /// The sample mean of a large draw must be close to the requested mean.
    #[test]
    #[serial]
    fn empirical_mean_close_to_expectation() {
        TestBackend::seed(0);
        let shape = [128, 128];
        let device = WgpuDevice::default();
        let mean = 10.;
        let tensor =
            Tensor::<TestBackend, 2>::random_device(shape, Distribution::Normal(mean, 2.), &device);
        let empirical_mean = tensor.mean().into_data();
        empirical_mean.assert_approx_eq(&Data::from([mean as f32]), 1);
    }

    /// Bins the samples over `[mu - 3s, mu + 3s]` into six one-sigma-wide bins and compares the
    /// counts against the 68-95-99.7 rule.
    #[test]
    #[serial]
    fn normal_respects_68_95_99_rule() {
        // https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule
        // Seed like the other tests so the outcome does not depend on the RNG state left
        // behind by whichever test ran before.
        TestBackend::seed(0);
        let shape: Shape<2> = [1000, 1000].into();
        let device = WgpuDevice::default();
        let mu = 0.;
        let s = 1.;
        let tensor = Tensor::<TestBackend, 2>::random_device(
            shape.clone(),
            Distribution::Normal(mu, s),
            &device,
        );
        let stats = calculate_bin_stats(
            tensor.into_data().value,
            6,
            (mu - 3. * s) as f32,
            (mu + 3. * s) as f32,
        );
        // Absolute tolerance of 2000 samples out of one million per bin.
        let assert_approx_eq = |count, percent| {
            let expected = percent * shape.num_elements() as f32 / 100.;
            assert!(f32::abs(count as f32 - expected) < 2000.);
        };
        assert_approx_eq(stats[0].count, 2.1);
        assert_approx_eq(stats[1].count, 13.6);
        assert_approx_eq(stats[2].count, 34.1);
        assert_approx_eq(stats[3].count, 34.1);
        assert_approx_eq(stats[4].count, 13.6);
        assert_approx_eq(stats[5].count, 2.1);
    }
}