cubecl_random/
base.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl_common::{rand::get_seeded_rng, stub::Mutex};
5use rand::{Rng, SeedableRng, rngs::StdRng};
6
7pub(crate) const N_VALUES_PER_THREAD: usize = 128;
8
9static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
10
11pub fn seed(seed: u64) {
12    let rng = StdRng::seed_from_u64(seed);
13    let mut seed = SEED.lock().unwrap();
14    *seed = Some(rng);
15}
16
17/// Pseudo-random generator
18pub(crate) fn random<F: RandomFamily, E: Numeric, R: Runtime>(
19    client: &ComputeClient<R::Server, R::Channel>,
20    prng: F::Runtime<E>,
21    output: TensorHandleRef<'_, R>,
22) {
23    let seeds = get_seeds();
24    let args = prng.args();
25
26    let cube_dim = CubeDim::default();
27    let cube_count = prng_cube_count(output.size(), cube_dim, N_VALUES_PER_THREAD);
28
29    let output_line_size = 1;
30    // TODO: Higher vectorization can add some correlation locally.
31    //
32    // let output_line_size = tensor_line_size_parallel(
33    //     R::line_size_elem(&E::as_elem_native_unchecked()),
34    //     output.shape,
35    //     output.strides,
36    //     output.strides.len() - 1,
37    // );
38
39    let output = output.as_tensor_arg(output_line_size);
40
41    prng_kernel::launch::<F, E, R>(
42        client,
43        cube_count,
44        cube_dim,
45        output,
46        ScalarArg::new(seeds[0]),
47        ScalarArg::new(seeds[1]),
48        ScalarArg::new(seeds[2]),
49        ScalarArg::new(seeds[3]),
50        args,
51        N_VALUES_PER_THREAD as u32,
52        output_line_size as u32,
53    );
54}
55
56fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
57    let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
58    let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
59    let cubes_x = f32::ceil(f32::sqrt(num_invocations));
60    let cubes_y = f32::ceil(num_invocations / cubes_x);
61
62    CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
63}
64
65pub(crate) fn get_seeds() -> [u32; 4] {
66    let mut seed = SEED.lock().unwrap();
67    let mut rng: StdRng = match seed.as_ref() {
68        Some(rng_seeded) => rng_seeded.clone(),
69        None => get_seeded_rng(),
70    };
71    let mut seeds: Vec<u32> = Vec::with_capacity(4);
72    for _ in 0..4 {
73        seeds.push(rng.random());
74    }
75    *seed = Some(rng);
76
77    seeds.try_into().unwrap()
78}
79
80pub(crate) trait PrngArgs<E: Numeric>: Send + Sync + 'static {
81    type Args: LaunchArg;
82
83    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
84}
85
86pub(crate) trait RandomFamily: Send + Sync + 'static + std::fmt::Debug {
87    type Runtime<E: Numeric>: PrngRuntime<E>;
88}
89
/// Device-side generation logic of a PRNG distribution.
///
/// `prng_kernel` calls `inner_loop` once per unit, after setting up that unit's
/// generator state.
#[cube]
pub(crate) trait PrngRuntime<E: Numeric>: Send + Sync + 'static + PrngArgs<E> {
    /// Generates this unit's `n_values_per_thread` values into `output`.
    ///
    /// - `args`: distribution parameters produced by [`PrngArgs::args`].
    /// - `write_index_base`: first line index belonging to this unit.
    /// - `n_invocations`: `prng_kernel` passes `CUBE_DIM` here. NOTE(review):
    ///   this is presumably the stride between a unit's successive writes;
    ///   confirm against implementations.
    /// - `n_values_per_thread`, `line_size`: compile-time value count and output
    ///   line width.
    /// - `state_0`..`state_3`: mutable per-unit generator state.
    #[allow(clippy::too_many_arguments)]
    fn inner_loop(
        args: Self::Args,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        #[comptime] line_size: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut Tensor<Line<E>>,
    );
}
106
/// Shorthand for the kernel-side argument type of family `F` at element type `E`.
type Args<F, E> = <<F as RandomFamily>::Runtime<E> as PrngArgs<E>>::Args;
108
/// Kernel that fills `output` with pseudo-random values.
///
/// Each unit derives its own four-word generator state from the four launch
/// seeds and its `ABSOLUTE_POS`. It then delegates generation of
/// `n_values_per_thread` values to the distribution's
/// `PrngRuntime::inner_loop`.
#[cube(launch)]
fn prng_kernel<F: RandomFamily, E: Numeric>(
    output: &mut Tensor<Line<E>>,
    seed_0: u32,
    seed_1: u32,
    seed_2: u32,
    seed_3: u32,
    args: Args<F, E>,
    #[comptime] n_values_per_thread: u32,
    #[comptime] line_size: u32,
) {
    // Index, counted in units, of this cube's first unit.
    let cube_offset = CUBE_POS * CUBE_DIM;

    // cube_offset * n_values_per_thread is the number of values produced by the
    // units of preceding cubes. Dividing by line_size turns it into a line index,
    // and UNIT_POS offsets it to this unit.
    // NOTE(review): presumably inner_loop writes at
    // write_index_base + i * n_invocations; confirm in the implementations.
    let write_index_base = cube_offset * n_values_per_thread / line_size + UNIT_POS;

    // 1000000007 is odd, so multiplying by it is invertible modulo 2^32. Even
    // when the product wraps (which the allow anticipates), distinct
    // ABSOLUTE_POS values still map to distinct thread seeds.
    #[allow(arithmetic_overflow)]
    let thread_seed = 1000000007u32 * ABSOLUTE_POS;

    // Per-unit generator state: the thread seed offset by each launch seed.
    let mut state_0 = thread_seed + seed_0;
    let mut state_1 = thread_seed + seed_1;
    let mut state_2 = thread_seed + seed_2;
    let mut state_3 = thread_seed + seed_3;

    // Creation of n_values_per_thread values, specific to the distribution.
    // CUBE_DIM is forwarded as `n_invocations`.
    F::Runtime::inner_loop(
        args,
        write_index_base,
        CUBE_DIM,
        n_values_per_thread,
        line_size,
        &mut state_0,
        &mut state_1,
        &mut state_2,
        &mut state_3,
        output,
    );
}
146
147#[cube]
148pub(crate) fn taus_step_0(z: u32) -> u32 {
149    taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
150}
151
152#[cube]
153pub(crate) fn taus_step_1(z: u32) -> u32 {
154    taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
155}
156
157#[cube]
158pub(crate) fn taus_step_2(z: u32) -> u32 {
159    taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
160}
161
162#[cube]
163fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
164    let b = z << s1;
165    let b = b ^ z;
166    let b = b >> s2;
167    let z = (z & m) << s3;
168    z ^ b
169}
170
171#[cube]
172pub(crate) fn lcg_step(z: u32) -> u32 {
173    let a = 1664525u32;
174    let b = 1013904223u32;
175
176    z * a + b
177}
178
179/// Converts a `u32` into a `f32` in the unit interval `[0.0, 1.0)`.
180/// Used for generating random floats.
181#[cube]
182pub fn to_unit_interval_closed_open(int_random: u32) -> f32 {
183    // Use upper 24 bits for f32 precision
184    // https://lemire.me/blog/2017/02/28/how-many-floating-point-numbers-are-in-the-interval-01/
185    let shifted = int_random >> 8;
186    f32::cast_from(shifted) / 16777216.0 // 2^24
187}
188
189/// Converts a `u32` into a `f32` in the unit interval `(0.0, 1.0)`.
190/// Used for generating random floats.
191#[cube]
192pub fn to_unit_interval_open(int_random: u32) -> f32 {
193    // Use upper 23 bits to leave room for the offset
194    let shifted = int_random >> 9;
195    (f32::cast_from(shifted) + 1.0) / 8388609.0 // 2^23 + 1
196}