1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl_common::{rand::get_seeded_rng, stub::Mutex};
5use cubecl_std::tensor::{
6 View,
7 layout::{
8 Coords1d,
9 linear::{LinearView, linear_view},
10 },
11};
12use rand::{Rng, SeedableRng, rngs::StdRng};
13
/// Number of random values each unit is asked to generate in one launch.
///
/// Used on the host to size the launch grid and forwarded to the kernel as its
/// comptime `n_values_per_thread` parameter.
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
15
/// Global generator the kernel seeds are drawn from.
///
/// Starts as `None`; [`seed`] installs a deterministically seeded generator and
/// [`get_seeds`] creates one on first use when none has been installed.
static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
17
18pub fn seed(seed: u64) {
19 let rng = StdRng::seed_from_u64(seed);
20 let mut seed = SEED.lock().unwrap();
21 *seed = Some(rng);
22}
23
24pub(crate) fn random<F: RandomFamily, R: Runtime>(
26 client: &ComputeClient<R::Server>,
27 prng: F::Runtime,
28 output: TensorHandleRef<'_, R>,
29 dtype: StorageType,
30) {
31 let seeds = get_seeds();
32 let args = prng.args();
33
34 let cube_dim = CubeDim::default();
35 let cube_count = prng_cube_count(output.size(), cube_dim, N_VALUES_PER_THREAD);
36
37 let output_line_size = 1;
38 let output = linear_view(client, &output, output_line_size);
48
49 prng_kernel::launch::<F, R>(
50 client,
51 cube_count,
52 cube_dim,
53 output,
54 ScalarArg::new(seeds[0]),
55 ScalarArg::new(seeds[1]),
56 ScalarArg::new(seeds[2]),
57 ScalarArg::new(seeds[3]),
58 args,
59 N_VALUES_PER_THREAD as u32,
60 output_line_size as u32,
61 dtype,
62 );
63}
64
65fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
66 let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
67 let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
68 let cubes_x = f32::ceil(f32::sqrt(num_invocations));
69 let cubes_y = f32::ceil(num_invocations / cubes_x);
70
71 CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
72}
73
74pub(crate) fn get_seeds() -> [u32; 4] {
75 let mut seed = SEED.lock().unwrap();
76 let mut rng: StdRng = match seed.as_ref() {
77 Some(rng_seeded) => rng_seeded.clone(),
78 None => get_seeded_rng(),
79 };
80 let mut seeds: Vec<u32> = Vec::with_capacity(4);
81 for _ in 0..4 {
82 seeds.push(rng.random());
83 }
84 *seed = Some(rng);
85
86 seeds.try_into().unwrap()
87}
88
/// Converts a generator's host-side parameters into kernel launch arguments.
pub(crate) trait PrngArgs: Send + Sync + 'static {
    /// Kernel-side argument type the parameters are launched as.
    type Args: LaunchArg;

    /// Consumes `self` and builds the runtime launch argument for [`Self::Args`].
    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
}
94
/// Marker for a family of random generators, used as the generic parameter of
/// [`random`] and of the kernel.
pub(crate) trait RandomFamily: Send + Sync + 'static + std::fmt::Debug {
    /// Type that provides both the launch arguments and the in-kernel generation loop.
    type Runtime: PrngRuntime;
}
98
/// In-kernel part of a random generator: produces the values for one unit.
#[cube]
pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs {
    /// Generates this unit's values and writes them to `output`.
    ///
    /// As called from `prng_kernel`:
    /// * `args` - the family-specific launch arguments.
    /// * `write_index_base` - output index computed by the kernel for this unit.
    /// * `n_invocations` - receives `CUBE_DIM`.
    /// * `n_values_per_thread`, `line_size` - comptime launch parameters.
    /// * `state_0..state_3` - per-unit generator state, mutable across the loop.
    /// * `output` - view the generated lines are written into.
    ///
    /// NOTE(review): how implementations step between successive writes is not
    /// visible in this file — confirm against the implementors.
    #[allow(clippy::too_many_arguments)]
    fn inner_loop<E: Numeric>(
        args: Self::Args,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        #[comptime] line_size: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, Coords1d, ReadWrite>,
    );
}
115
/// Shorthand for the kernel argument type of the family `F`'s runtime.
type Args<F> = <<F as RandomFamily>::Runtime as PrngArgs>::Args;
117
/// Kernel entry point: derives a per-unit generator state from the launch seeds and
/// delegates value generation to the family's `inner_loop`.
///
/// * `output` - view the generated values are written into.
/// * `seed_0..seed_3` - seeds supplied by the host for this launch.
/// * `args` - family-specific arguments, forwarded unchanged.
/// * `n_values_per_thread`, `line_size` - comptime parameters used for the write
///   offset and forwarded to the inner loop.
/// * `_dtype` - not read in the body; presumably selects the element type `E`
///   through `#[define(E)]` — confirm against the cubecl macro documentation.
#[cube(launch)]
fn prng_kernel<F: RandomFamily, E: Numeric>(
    output: &mut LinearView<Line<E>, ReadWrite>,
    seed_0: u32,
    seed_1: u32,
    seed_2: u32,
    seed_3: u32,
    args: Args<F>,
    #[comptime] n_values_per_thread: u32,
    #[comptime] line_size: u32,
    #[define(E)] _dtype: StorageType,
) {
    // Number of units that belong to the cubes preceding this one.
    let cube_offset = CUBE_POS * CUBE_DIM;

    // Each preceding unit accounts for `n_values_per_thread / line_size` output lines;
    // this unit's own position inside the cube is added on top.
    let write_index_base = cube_offset * n_values_per_thread / line_size + UNIT_POS;

    // Per-unit offset so every unit starts from a different state. The product is
    // expected to overflow, hence the lint allowance.
    #[allow(arithmetic_overflow)]
    let thread_seed = 1000000007u32 * ABSOLUTE_POS;

    let mut state_0 = thread_seed + seed_0;
    let mut state_1 = thread_seed + seed_1;
    let mut state_2 = thread_seed + seed_2;
    let mut state_3 = thread_seed + seed_3;

    // Distribution-specific generation; `CUBE_DIM` is passed as `n_invocations`.
    F::Runtime::inner_loop(
        args,
        write_index_base,
        CUBE_DIM,
        n_values_per_thread,
        line_size,
        &mut state_0,
        &mut state_1,
        &mut state_2,
        &mut state_3,
        output,
    );
}
156
157#[cube]
158pub(crate) fn taus_step_0(z: u32) -> u32 {
159 taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
160}
161
162#[cube]
163pub(crate) fn taus_step_1(z: u32) -> u32 {
164 taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
165}
166
167#[cube]
168pub(crate) fn taus_step_2(z: u32) -> u32 {
169 taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
170}
171
172#[cube]
173fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
174 let b = z << s1;
175 let b = b ^ z;
176 let b = b >> s2;
177 let z = (z & m) << s3;
178 z ^ b
179}
180
181#[cube]
182pub(crate) fn lcg_step(z: u32) -> u32 {
183 let a = 1664525u32;
184 let b = 1013904223u32;
185
186 z * a + b
187}
188
189#[cube]
192pub fn to_unit_interval_closed_open(int_random: u32) -> f32 {
193 let shifted = int_random >> 8;
196 f32::cast_from(shifted) / 16777216.0 }
198
199#[cube]
202pub fn to_unit_interval_open(int_random: u32) -> f32 {
203 let shifted = int_random >> 9;
205 (f32::cast_from(shifted) + 1.0) / 8388609.0 }