cubecl_random/
normal.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::View;
4use std::f32::consts::PI;
5
6use super::{PrngArgs, PrngRuntime, random};
7
8use crate::{RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_open};
9
10#[derive(CubeLaunch, CubeType)]
11pub(crate) struct Normal {
12    mean: f32,
13    std: f32,
14}
15
/// Field-less marker type whose [`RandomFamily`] implementation names
/// [`Normal`] as its runtime.
#[derive(Debug)]
struct NormalFamily;
18
19impl RandomFamily for NormalFamily {
20    type Runtime = Normal;
21}
22
23#[cube]
24impl PrngRuntime for Normal {
25    fn inner_loop<E: Numeric>(
26        args: Normal,
27        write_index_base: u32,
28        n_invocations: u32,
29        #[comptime] n_values_per_thread: u32,
30        #[comptime] line_size: u32,
31        state_0: &mut u32,
32        state_1: &mut u32,
33        state_2: &mut u32,
34        state_3: &mut u32,
35        output: &mut View<Line<E>, u32, ReadWrite>,
36    ) {
37        let mean = f32::cast_from(args.mean);
38        let std = f32::cast_from(args.std);
39
40        let mut output_line_0 = Line::empty(line_size);
41        let mut output_line_1 = Line::empty(line_size);
42
43        let num_iterations = n_values_per_thread / line_size / 2;
44        #[unroll(num_iterations <= 8)]
45        for line_index in 0..num_iterations {
46            // vectorization
47            #[unroll]
48            for i in 0..line_size {
49                // First random uniform integer
50                *state_0 = taus_step_0(*state_0);
51                *state_1 = taus_step_1(*state_1);
52                *state_2 = taus_step_2(*state_2);
53                *state_3 = lcg_step(*state_3);
54
55                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56                let unit_0 = to_unit_interval_open(int_random);
57
58                // Second random uniform integer
59                *state_0 = taus_step_0(*state_0);
60                *state_1 = taus_step_1(*state_1);
61                *state_2 = taus_step_2(*state_2);
62                *state_3 = lcg_step(*state_3);
63
64                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
65                let unit_1 = to_unit_interval_open(int_random);
66
67                // Box-Muller transform
68                let coeff = Log::log(unit_0) * -2.0;
69                let coeff = Sqrt::sqrt(coeff) * std;
70                let trigo_arg = 2.0 * PI * unit_1;
71
72                let normal_0 = f32::cos(trigo_arg) * coeff + mean;
73                let normal_1 = f32::sin(trigo_arg) * coeff + mean;
74
75                output_line_0[i] = E::cast_from(normal_0);
76                output_line_1[i] = E::cast_from(normal_1);
77            }
78
79            let iteration_offset = line_index * n_invocations * 2;
80            let write_index_0 = write_index_base + iteration_offset;
81            let write_index_1 = write_index_0 + n_invocations;
82
83            output[write_index_0] = output_line_0;
84            output[write_index_1] = output_line_1;
85        }
86    }
87}
88
89impl PrngArgs for Normal {
90    type Args = Self;
91
92    fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, R> {
93        NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
94    }
95}
96
97/// Pseudo-random generator with uniform distribution
98pub fn random_normal<R: Runtime>(
99    client: &ComputeClient<R::Server>,
100    mean: f32,
101    std: f32,
102    out: TensorHandleRef<R>,
103    dtype: StorageType,
104) {
105    assert_eq!(
106        out.elem_size,
107        dtype.size(),
108        "Tensor element type must be the same as type E"
109    );
110
111    random::<NormalFamily, R>(client, Normal { mean, std }, out, dtype)
112}