burn_jit/kernel/prng/
normal.rs1use cubecl::prelude::*;
2use std::f32::consts::PI;
3
4use burn_tensor::Shape;
5
6use crate::{
7 kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2},
8 tensor::JitTensor,
9 JitElement, JitRuntime,
10};
11
12use super::{random, PrngArgs, PrngRuntime};
13
14#[derive(CubeLaunch)]
15pub(crate) struct Normal<E: Numeric> {
16 mean: E,
17 std: E,
18}
19
20#[cube]
21impl<E: JitElement> PrngRuntime<E> for Normal<E> {
22 fn inner_loop(
23 args: Normal<E>,
24 write_index_base: u32,
25 n_invocations: u32,
26 #[comptime] n_values_per_thread: u32,
27 state_0: &mut u32,
28 state_1: &mut u32,
29 state_2: &mut u32,
30 state_3: &mut u32,
31 output: &mut Tensor<E>,
32 ) {
33 let mean = f32::cast_from(args.mean);
34 let std = f32::cast_from(args.std);
35
36 let should_unroll = n_values_per_thread <= 16;
37
38 #[unroll(should_unroll)]
39 for i in 0..n_values_per_thread / 2 {
40 *state_0 = taus_step_0(*state_0);
42 *state_1 = taus_step_1(*state_1);
43 *state_2 = taus_step_2(*state_2);
44 *state_3 = lcg_step(*state_3);
45
46 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
47 let unit_0 = cast_uint_to_float(int_random);
48
49 *state_0 = taus_step_0(*state_0);
51 *state_1 = taus_step_1(*state_1);
52 *state_2 = taus_step_2(*state_2);
53 *state_3 = lcg_step(*state_3);
54
55 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56 let unit_1 = cast_uint_to_float(int_random);
57
58 let coeff = Log::log(unit_0) * -2.0;
60 let coeff = Sqrt::sqrt(coeff) * std;
61 let trigo_arg = 2.0 * PI * unit_1;
62
63 let normal_0 = f32::cos(trigo_arg) * coeff + mean;
64 let normal_1 = f32::sin(trigo_arg) * coeff + mean;
65
66 let iteration_offset = 2 * i * n_invocations;
67 let write_index_0 = write_index_base + iteration_offset;
68 let write_index_1 = write_index_0 + n_invocations;
69
70 output[write_index_0] = E::cast_from(normal_0);
71 output[write_index_1] = E::cast_from(normal_1);
72 }
73 }
74}
75
76impl<E: JitElement> PrngArgs<E> for Normal<E> {
77 type Args = Self;
78
79 fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, E, R> {
80 NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
81 }
82}
83
84pub fn random_normal<R: JitRuntime, E: JitElement>(
86 shape: Shape,
87 device: &R::Device,
88 mean: E,
89 std: E,
90) -> JitTensor<R> {
91 random(shape, device, Normal { mean, std })
92}