1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use std::f32::consts::PI;
4
5use super::{PrngArgs, PrngRuntime, random};
6
7use crate::{RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_open};
8
9#[derive(CubeLaunch, CubeType)]
10pub(crate) struct Normal<E: Numeric> {
11 mean: E,
12 std: E,
13}
14
/// Zero-sized marker type used to select the normal distribution at the type
/// level; it carries no data of its own.
#[derive(Debug)]
struct NormalFamily;
17
18impl RandomFamily for NormalFamily {
19 type Runtime<E: Numeric> = Normal<E>;
20}
21
22#[cube]
23impl<E: Numeric> PrngRuntime<E> for Normal<E> {
24 fn inner_loop(
25 args: Normal<E>,
26 write_index_base: u32,
27 n_invocations: u32,
28 #[comptime] n_values_per_thread: u32,
29 #[comptime] line_size: u32,
30 state_0: &mut u32,
31 state_1: &mut u32,
32 state_2: &mut u32,
33 state_3: &mut u32,
34 output: &mut Tensor<Line<E>>,
35 ) {
36 let mean = f32::cast_from(args.mean);
37 let std = f32::cast_from(args.std);
38
39 let mut output_line_0 = Line::empty(line_size);
40 let mut output_line_1 = Line::empty(line_size);
41
42 let num_iterations = n_values_per_thread / line_size / 2;
43 #[unroll(num_iterations <= 8)]
44 for line_index in 0..num_iterations {
45 #[unroll]
47 for i in 0..line_size {
48 *state_0 = taus_step_0(*state_0);
50 *state_1 = taus_step_1(*state_1);
51 *state_2 = taus_step_2(*state_2);
52 *state_3 = lcg_step(*state_3);
53
54 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
55 let unit_0 = to_unit_interval_open(int_random);
56
57 *state_0 = taus_step_0(*state_0);
59 *state_1 = taus_step_1(*state_1);
60 *state_2 = taus_step_2(*state_2);
61 *state_3 = lcg_step(*state_3);
62
63 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
64 let unit_1 = to_unit_interval_open(int_random);
65
66 let coeff = Log::log(unit_0) * -2.0;
68 let coeff = Sqrt::sqrt(coeff) * std;
69 let trigo_arg = 2.0 * PI * unit_1;
70
71 let normal_0 = f32::cos(trigo_arg) * coeff + mean;
72 let normal_1 = f32::sin(trigo_arg) * coeff + mean;
73
74 output_line_0[i] = E::cast_from(normal_0);
75 output_line_1[i] = E::cast_from(normal_1);
76 }
77
78 let iteration_offset = line_index * n_invocations * 2;
79 let write_index_0 = write_index_base + iteration_offset;
80 let write_index_1 = write_index_0 + n_invocations;
81
82 output[write_index_0] = output_line_0;
83 output[write_index_1] = output_line_1;
84 }
85 }
86}
87
88impl<E: Numeric> PrngArgs<E> for Normal<E> {
89 type Args = Self;
90
91 fn args<'a, R: Runtime>(self) -> NormalLaunch<'a, E, R> {
92 NormalLaunch::new(ScalarArg::new(self.mean), ScalarArg::new(self.std))
93 }
94}
95
96pub fn random_normal<R: Runtime, E: Numeric>(
98 client: &ComputeClient<R::Server, R::Channel>,
99 mean: E,
100 std: E,
101 out: TensorHandleRef<R>,
102) {
103 assert_eq!(
104 out.elem_size as u32,
105 E::elem_size(),
106 "Tensor element type must be the same as type E"
107 );
108
109 random::<NormalFamily, E, R>(client, Normal { mean, std }, out)
110}