cubecl_random/
base.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl_common::{rand::get_seeded_rng, stub::Mutex};
5use cubecl_std::tensor::{
6    View,
7    layout::{
8        Coords1d,
9        linear::{LinearView, linear_view},
10    },
11};
12use rand::{Rng, SeedableRng, rngs::StdRng};
13
14pub(crate) const N_VALUES_PER_THREAD: usize = 128;
15
16static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
17
18pub fn seed(seed: u64) {
19    let rng = StdRng::seed_from_u64(seed);
20    let mut seed = SEED.lock().unwrap();
21    *seed = Some(rng);
22}
23
24/// Pseudo-random generator
25pub(crate) fn random<F: RandomFamily, R: Runtime>(
26    client: &ComputeClient<R::Server>,
27    prng: F::Runtime,
28    output: TensorHandleRef<'_, R>,
29    dtype: StorageType,
30) {
31    let seeds = get_seeds();
32    let args = prng.args();
33
34    let cube_dim = CubeDim::default();
35    let cube_count = prng_cube_count(output.size(), cube_dim, N_VALUES_PER_THREAD);
36
37    let output_line_size = 1;
38    // TODO: Higher vectorization can add some correlation locally.
39    //
40    // let output_line_size = tensor_line_size_parallel(
41    //     R::line_size_elem(&E::as_elem_native_unchecked()),
42    //     output.shape,
43    //     output.strides,
44    //     output.strides.len() - 1,
45    // );
46
47    let output = linear_view(client, &output, output_line_size);
48
49    prng_kernel::launch::<F, R>(
50        client,
51        cube_count,
52        cube_dim,
53        output,
54        ScalarArg::new(seeds[0]),
55        ScalarArg::new(seeds[1]),
56        ScalarArg::new(seeds[2]),
57        ScalarArg::new(seeds[3]),
58        args,
59        N_VALUES_PER_THREAD as u32,
60        output_line_size as u32,
61        dtype,
62    );
63}
64
65fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
66    let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
67    let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
68    let cubes_x = f32::ceil(f32::sqrt(num_invocations));
69    let cubes_y = f32::ceil(num_invocations / cubes_x);
70
71    CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
72}
73
74pub(crate) fn get_seeds() -> [u32; 4] {
75    let mut seed = SEED.lock().unwrap();
76    let mut rng: StdRng = match seed.as_ref() {
77        Some(rng_seeded) => rng_seeded.clone(),
78        None => get_seeded_rng(),
79    };
80    let mut seeds: Vec<u32> = Vec::with_capacity(4);
81    for _ in 0..4 {
82        seeds.push(rng.random());
83    }
84    *seed = Some(rng);
85
86    seeds.try_into().unwrap()
87}
88
89pub(crate) trait PrngArgs: Send + Sync + 'static {
90    type Args: LaunchArg;
91
92    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
93}
94
95pub(crate) trait RandomFamily: Send + Sync + 'static + std::fmt::Debug {
96    type Runtime: PrngRuntime;
97}
98
99#[cube]
100pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs {
101    #[allow(clippy::too_many_arguments)]
102    fn inner_loop<E: Numeric>(
103        args: Self::Args,
104        write_index_base: u32,
105        n_invocations: u32,
106        #[comptime] n_values_per_thread: u32,
107        #[comptime] line_size: u32,
108        state_0: &mut u32,
109        state_1: &mut u32,
110        state_2: &mut u32,
111        state_3: &mut u32,
112        output: &mut View<Line<E>, Coords1d, ReadWrite>,
113    );
114}
115
116type Args<F> = <<F as RandomFamily>::Runtime as PrngArgs>::Args;
117
118#[cube(launch)]
119fn prng_kernel<F: RandomFamily, E: Numeric>(
120    output: &mut LinearView<Line<E>, ReadWrite>,
121    seed_0: u32,
122    seed_1: u32,
123    seed_2: u32,
124    seed_3: u32,
125    args: Args<F>,
126    #[comptime] n_values_per_thread: u32,
127    #[comptime] line_size: u32,
128    #[define(E)] _dtype: StorageType,
129) {
130    let cube_offset = CUBE_POS * CUBE_DIM;
131
132    let write_index_base = cube_offset * n_values_per_thread / line_size + UNIT_POS;
133
134    #[allow(arithmetic_overflow)]
135    let thread_seed = 1000000007u32 * ABSOLUTE_POS;
136
137    let mut state_0 = thread_seed + seed_0;
138    let mut state_1 = thread_seed + seed_1;
139    let mut state_2 = thread_seed + seed_2;
140    let mut state_3 = thread_seed + seed_3;
141
142    // Creation of n_values_per_thread values, specific to the distribution
143    F::Runtime::inner_loop(
144        args,
145        write_index_base,
146        CUBE_DIM,
147        n_values_per_thread,
148        line_size,
149        &mut state_0,
150        &mut state_1,
151        &mut state_2,
152        &mut state_3,
153        output,
154    );
155}
156
157#[cube]
158pub(crate) fn taus_step_0(z: u32) -> u32 {
159    taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
160}
161
162#[cube]
163pub(crate) fn taus_step_1(z: u32) -> u32 {
164    taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
165}
166
167#[cube]
168pub(crate) fn taus_step_2(z: u32) -> u32 {
169    taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
170}
171
172#[cube]
173fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
174    let b = z << s1;
175    let b = b ^ z;
176    let b = b >> s2;
177    let z = (z & m) << s3;
178    z ^ b
179}
180
181#[cube]
182pub(crate) fn lcg_step(z: u32) -> u32 {
183    let a = 1664525u32;
184    let b = 1013904223u32;
185
186    z * a + b
187}
188
189/// Converts a `u32` into a `f32` in the unit interval `[0.0, 1.0)`.
190/// Used for generating random floats.
191#[cube]
192pub fn to_unit_interval_closed_open(int_random: u32) -> f32 {
193    // Use upper 24 bits for f32 precision
194    // https://lemire.me/blog/2017/02/28/how-many-floating-point-numbers-are-in-the-interval-01/
195    let shifted = int_random >> 8;
196    f32::cast_from(shifted) / 16777216.0 // 2^24
197}
198
199/// Converts a `u32` into a `f32` in the unit interval `(0.0, 1.0)`.
200/// Used for generating random floats.
201#[cube]
202pub fn to_unit_interval_open(int_random: u32) -> f32 {
203    // Use upper 23 bits to leave room for the offset
204    let shifted = int_random >> 9;
205    (f32::cast_from(shifted) + 1.0) / 8388609.0 // 2^23 + 1
206}