1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl_common::{rand::get_seeded_rng, stub::Mutex};
5use cubecl_std::tensor::{
6 View,
7 layout::{
8 Coords1d,
9 linear::{LinearView, linear_view},
10 },
11};
12use rand::{Rng, SeedableRng, rngs::StdRng};
13
/// Number of values assigned to each unit (thread) of the PRNG kernel.
///
/// Used both to size the launch grid and as the comptime `n_values_per_thread`
/// argument passed to the kernel.
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
15
/// Global generator from which kernel seeds are drawn.
///
/// Starts as `None`; a generator is stored by `seed` or on the first call to
/// `get_seeds`.
static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
17
18pub fn seed(seed: u64) {
19 let rng = StdRng::seed_from_u64(seed);
20 let mut seed = SEED.lock().unwrap();
21 *seed = Some(rng);
22}
23
24pub(crate) fn random<F: RandomFamily, E: Numeric, R: Runtime>(
26 client: &ComputeClient<R::Server>,
27 prng: F::Runtime<E>,
28 output: TensorHandleRef<'_, R>,
29) {
30 let seeds = get_seeds();
31 let args = prng.args();
32
33 let cube_dim = CubeDim::default();
34 let cube_count = prng_cube_count(output.size(), cube_dim, N_VALUES_PER_THREAD);
35
36 let output_line_size = 1;
37 let output = linear_view(client, &output, output_line_size);
47
48 prng_kernel::launch::<F, E, R>(
49 client,
50 cube_count,
51 cube_dim,
52 output,
53 ScalarArg::new(seeds[0]),
54 ScalarArg::new(seeds[1]),
55 ScalarArg::new(seeds[2]),
56 ScalarArg::new(seeds[3]),
57 args,
58 N_VALUES_PER_THREAD as u32,
59 output_line_size as u32,
60 );
61}
62
63fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
64 let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
65 let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
66 let cubes_x = f32::ceil(f32::sqrt(num_invocations));
67 let cubes_y = f32::ceil(num_invocations / cubes_x);
68
69 CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
70}
71
72pub(crate) fn get_seeds() -> [u32; 4] {
73 let mut seed = SEED.lock().unwrap();
74 let mut rng: StdRng = match seed.as_ref() {
75 Some(rng_seeded) => rng_seeded.clone(),
76 None => get_seeded_rng(),
77 };
78 let mut seeds: Vec<u32> = Vec::with_capacity(4);
79 for _ in 0..4 {
80 seeds.push(rng.random());
81 }
82 *seed = Some(rng);
83
84 seeds.try_into().unwrap()
85}
86
/// Host-side description of the extra launch arguments a PRNG needs.
pub(crate) trait PrngArgs<E: Numeric>: Send + Sync + 'static {
    /// Kernel argument type carrying the PRNG-specific parameters.
    type Args: LaunchArg;

    /// Consumes `self` and produces the runtime launch argument for `Self::Args`.
    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
}
92
/// Type-level family that selects a PRNG implementation for each element type.
pub(crate) trait RandomFamily: Send + Sync + 'static + std::fmt::Debug {
    /// PRNG implementation used for element type `E`.
    type Runtime<E: Numeric>: PrngRuntime<E>;
}
96
/// Kernel-side part of a PRNG: the per-unit generation loop invoked by
/// `prng_kernel`.
#[cube]
pub(crate) trait PrngRuntime<E: Numeric>: Send + Sync + 'static + PrngArgs<E> {
    /// Per-unit generation routine called once by `prng_kernel`.
    ///
    /// `prng_kernel` passes:
    /// - `args`: the PRNG-specific kernel arguments;
    /// - `write_index_base`: the unit's starting index into `output`;
    /// - `n_invocations`: `CUBE_DIM` (presumably the stride between
    ///   consecutive writes of one unit; confirm against implementors);
    /// - `n_values_per_thread` and `line_size`: comptime launch constants;
    /// - `state_0` to `state_3`: the unit's four mutable `u32` state words,
    ///   initialized from the launch seeds;
    /// - `output`: the read-write view of the output.
    #[allow(clippy::too_many_arguments)]
    fn inner_loop(
        args: Self::Args,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        #[comptime] line_size: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, Coords1d, ReadWrite>,
    );
}
113
/// Kernel argument type of the PRNG that family `F` selects for element type `E`.
type Args<F, E> = <<F as RandomFamily>::Runtime<E> as PrngArgs<E>>::Args;
115
116#[cube(launch)]
117fn prng_kernel<F: RandomFamily, E: Numeric>(
118 output: &mut LinearView<Line<E>, ReadWrite>,
119 seed_0: u32,
120 seed_1: u32,
121 seed_2: u32,
122 seed_3: u32,
123 args: Args<F, E>,
124 #[comptime] n_values_per_thread: u32,
125 #[comptime] line_size: u32,
126) {
127 let cube_offset = CUBE_POS * CUBE_DIM;
128
129 let write_index_base = cube_offset * n_values_per_thread / line_size + UNIT_POS;
130
131 #[allow(arithmetic_overflow)]
132 let thread_seed = 1000000007u32 * ABSOLUTE_POS;
133
134 let mut state_0 = thread_seed + seed_0;
135 let mut state_1 = thread_seed + seed_1;
136 let mut state_2 = thread_seed + seed_2;
137 let mut state_3 = thread_seed + seed_3;
138
139 F::Runtime::inner_loop(
141 args,
142 write_index_base,
143 CUBE_DIM,
144 n_values_per_thread,
145 line_size,
146 &mut state_0,
147 &mut state_1,
148 &mut state_2,
149 &mut state_3,
150 output,
151 );
152}
153
154#[cube]
155pub(crate) fn taus_step_0(z: u32) -> u32 {
156 taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
157}
158
159#[cube]
160pub(crate) fn taus_step_1(z: u32) -> u32 {
161 taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
162}
163
164#[cube]
165pub(crate) fn taus_step_2(z: u32) -> u32 {
166 taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
167}
168
169#[cube]
170fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
171 let b = z << s1;
172 let b = b ^ z;
173 let b = b >> s2;
174 let z = (z & m) << s3;
175 z ^ b
176}
177
178#[cube]
179pub(crate) fn lcg_step(z: u32) -> u32 {
180 let a = 1664525u32;
181 let b = 1013904223u32;
182
183 z * a + b
184}
185
186#[cube]
189pub fn to_unit_interval_closed_open(int_random: u32) -> f32 {
190 let shifted = int_random >> 8;
193 f32::cast_from(shifted) / 16777216.0 }
195
196#[cube]
199pub fn to_unit_interval_open(int_random: u32) -> f32 {
200 let shifted = int_random >> 9;
202 (f32::cast_from(shifted) + 1.0) / 8388609.0 }