use cubecl::prelude::*;
use cubecl_core as cubecl;
use cubecl_std::tensor::View;
use std::f32::consts::PI;

use super::{PrngArgs, PrngRuntime, random};

use crate::{RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_open};
9
10#[derive(CubeLaunch, CubeType)]
11pub(crate) struct Normal {
12 mean: f32,
13 std: f32,
14}
15
16#[derive(Debug)]
17struct NormalFamily;
18
19impl RandomFamily for NormalFamily {
20 type Runtime = Normal;
21}
22
23#[cube]
24impl PrngRuntime for Normal {
25 fn inner_loop<E: Numeric>(
26 args: Normal,
27 write_index_base: u32,
28 n_invocations: u32,
29 #[comptime] n_values_per_thread: u32,
30 #[comptime] line_size: u32,
31 state_0: &mut u32,
32 state_1: &mut u32,
33 state_2: &mut u32,
34 state_3: &mut u32,
35 output: &mut View<Line<E>, u32, ReadWrite>,
36 ) {
37 let mean = f32::cast_from(args.mean);
38 let std = f32::cast_from(args.std);
39
40 let mut output_line_0 = Line::empty(line_size);
41 let mut output_line_1 = Line::empty(line_size);
42
43 let num_iterations = n_values_per_thread / line_size / 2;
44 #[unroll(num_iterations <= 8)]
45 for line_index in 0..num_iterations {
46 #[unroll]
48 for i in 0..line_size {
49 *state_0 = taus_step_0(*state_0);
51 *state_1 = taus_step_1(*state_1);
52 *state_2 = taus_step_2(*state_2);
53 *state_3 = lcg_step(*state_3);
54
55 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56 let unit_0 = to_unit_interval_open(int_random);
57
58 *state_0 = taus_step_0(*state_0);
60 *state_1 = taus_step_1(*state_1);
61 *state_2 = taus_step_2(*state_2);
62 *state_3 = lcg_step(*state_3);
63
64 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
65 let unit_1 = to_unit_interval_open(int_random);
66
67 let coeff = Log::log(unit_0) * -2.0;
69 let coeff = Sqrt::sqrt(coeff) * std;
70 let trigo_arg = 2.0 * PI * unit_1;
71
72 let normal_0 = f32::cos(trigo_arg) * coeff + mean;
73 let normal_1 = f32::sin(trigo_arg) * coeff + mean;
74
75 output_line_0[i] = E::cast_from(normal_0);
76 output_line_1[i] = E::cast_from(normal_1);
77 }
78
79 let iteration_offset = line_index * n_invocations * 2;
80 let write_index_0 = write_index_base + iteration_offset;
81 let write_index_1 = write_index_0 + n_invocations;
82
83 output[write_index_0] = output_line_0;
84 output[write_index_1] = output_line_1;
85 }
86 }
87}
88
89impl PrngArgs for Normal {
90 type Args = Self;
91
92 fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, R> {
93 NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
94 }
95}
96
97pub fn random_normal<R: Runtime>(
99 client: &ComputeClient<R>,
100 mean: f32,
101 std: f32,
102 out: TensorHandleRef<R>,
103 dtype: StorageType,
104) -> Result<(), LaunchError> {
105 assert_eq!(
106 out.elem_size,
107 dtype.size(),
108 "Tensor element type must be the same as type E"
109 );
110
111 random::<NormalFamily, R>(client, Normal { mean, std }, out, dtype)
112}