cubecl_random/
normal.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use std::f32::consts::PI;
4
5use super::{PrngArgs, PrngRuntime, random};
6
7use crate::{RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_open};
8
9#[derive(CubeLaunch, CubeType)]
10pub(crate) struct Normal<E: Numeric> {
11    mean: E,
12    std: E,
13}
14
/// Zero-sized marker used to select the normal-distribution generator when
/// calling the generic `random` launcher.
#[derive(Debug)]
struct NormalFamily;
17
18impl RandomFamily for NormalFamily {
19    type Runtime<E: Numeric> = Normal<E>;
20}
21
22#[cube]
23impl<E: Numeric> PrngRuntime<E> for Normal<E> {
24    fn inner_loop(
25        args: Normal<E>,
26        write_index_base: u32,
27        n_invocations: u32,
28        #[comptime] n_values_per_thread: u32,
29        #[comptime] line_size: u32,
30        state_0: &mut u32,
31        state_1: &mut u32,
32        state_2: &mut u32,
33        state_3: &mut u32,
34        output: &mut Tensor<Line<E>>,
35    ) {
36        let mean = f32::cast_from(args.mean);
37        let std = f32::cast_from(args.std);
38
39        let mut output_line_0 = Line::empty(line_size);
40        let mut output_line_1 = Line::empty(line_size);
41
42        let num_iterations = n_values_per_thread / line_size / 2;
43        #[unroll(num_iterations <= 8)]
44        for line_index in 0..num_iterations {
45            // vectorization
46            #[unroll]
47            for i in 0..line_size {
48                // First random uniform integer
49                *state_0 = taus_step_0(*state_0);
50                *state_1 = taus_step_1(*state_1);
51                *state_2 = taus_step_2(*state_2);
52                *state_3 = lcg_step(*state_3);
53
54                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
55                let unit_0 = to_unit_interval_open(int_random);
56
57                // Second random uniform integer
58                *state_0 = taus_step_0(*state_0);
59                *state_1 = taus_step_1(*state_1);
60                *state_2 = taus_step_2(*state_2);
61                *state_3 = lcg_step(*state_3);
62
63                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
64                let unit_1 = to_unit_interval_open(int_random);
65
66                // Box-Muller transform
67                let coeff = Log::log(unit_0) * -2.0;
68                let coeff = Sqrt::sqrt(coeff) * std;
69                let trigo_arg = 2.0 * PI * unit_1;
70
71                let normal_0 = f32::cos(trigo_arg) * coeff + mean;
72                let normal_1 = f32::sin(trigo_arg) * coeff + mean;
73
74                output_line_0[i] = E::cast_from(normal_0);
75                output_line_1[i] = E::cast_from(normal_1);
76            }
77
78            let iteration_offset = line_index * n_invocations * 2;
79            let write_index_0 = write_index_base + iteration_offset;
80            let write_index_1 = write_index_0 + n_invocations;
81
82            output[write_index_0] = output_line_0;
83            output[write_index_1] = output_line_1;
84        }
85    }
86}
87
88impl<E: Numeric> PrngArgs<E> for Normal<E> {
89    type Args = Self;
90
91    fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, E, R> {
92        NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
93    }
94}
95
96/// Pseudo-random generator with uniform distribution
97pub fn random_normal<R: Runtime, E: Numeric>(
98    client: &ComputeClient<R::Server, R::Channel>,
99    mean: E,
100    std: E,
101    out: TensorHandleRef<R>,
102) {
103    assert_eq!(
104        out.elem_size as u32,
105        E::elem_size(),
106        "Tensor element type must be the same as type E"
107    );
108
109    random::<NormalFamily, E, R>(client, Normal { mean, std }, out)
110}