burn_jit/kernel/prng/
normal.rs

1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use burn_tensor::Shape;
5
6use crate::{
7    kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
8    tensor::JitTensor,
9    JitElement, JitRuntime,
10};
11
12use super::{random, PrngArgs, PrngRuntime};
13
14#[derive(CubeLaunch)]
15pub(crate) struct Normal<E: Numeric> {
16    mean: E,
17    std: E,
18}
19
20#[cube]
21impl<E: JitElement> PrngRuntime<E> for Normal<E> {
22    fn inner_loop(
23        args: Normal<E>,
24        write_index_base: u32,
25        n_invocations: u32,
26        #[comptime] n_values_per_thread: u32,
27        state_0: &mut u32,
28        state_1: &mut u32,
29        state_2: &mut u32,
30        state_3: &mut u32,
31        output: &mut Tensor<E>,
32    ) {
33        let mean = f32::cast_from(args.mean);
34        let std = f32::cast_from(args.std);
35
36        let should_unroll = n_values_per_thread <= 16;
37
38        #[unroll(should_unroll)]
39        for i in 0..n_values_per_thread / 2 {
40            // First random uniform integer
41            *state_0 = taus_step_0(*state_0);
42            *state_1 = taus_step_1(*state_1);
43            *state_2 = taus_step_2(*state_2);
44            *state_3 = lcg_step(*state_3);
45
46            let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
47            let unit_0 = cast_uint_to_float(int_random);
48
49            // Second random uniform integer
50            *state_0 = taus_step_0(*state_0);
51            *state_1 = taus_step_1(*state_1);
52            *state_2 = taus_step_2(*state_2);
53            *state_3 = lcg_step(*state_3);
54
55            let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56            let unit_1 = cast_uint_to_float(int_random);
57
58            // Box-Muller transform
59            let coeff = Log::log(unit_0) * -2.0;
60            let coeff = Sqrt::sqrt(coeff) * std;
61            let trigo_arg = 2.0 * PI * unit_1;
62
63            let normal_0 = f32::cos(trigo_arg) * coeff + mean;
64            let normal_1 = f32::sin(trigo_arg) * coeff + mean;
65
66            let iteration_offset = 2 * i * n_invocations;
67            let write_index_0 = write_index_base + iteration_offset;
68            let write_index_1 = write_index_0 + n_invocations;
69
70            output[write_index_0] = E::cast_from(normal_0);
71            output[write_index_1] = E::cast_from(normal_1);
72        }
73    }
74}
75
76impl<E: JitElement> PrngArgs<E> for Normal<E> {
77    type Args = Self;
78
79    fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, E, R> {
80        NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
81    }
82}
83
84/// Pseudo-random generator with uniform distribution
85pub fn random_normal<R: JitRuntime, E: JitElement>(
86    shape: Shape,
87    device: &R::Device,
88    mean: E,
89    std: E,
90) -> JitTensor<R> {
91    random(shape, device, Normal { mean, std })
92}