// cubek_random/base.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::{
3    View,
4    layout::{
5        Coords1d,
6        linear::{LinearView, linear_view},
7    },
8};
9use cubecl_common::{rand::get_seeded_rng, stub::Mutex};
10use rand::{Rng, SeedableRng, rngs::StdRng};
11
/// Number of scalar values each unit generates in `prng_kernel`.
///
/// Also used on the host to size the cube dimension and the cube count in `random`.
pub(crate) const N_VALUES_PER_THREAD: usize = 128;

/// Global host-side RNG from which per-launch kernel seeds are drawn.
///
/// `None` until `seed` installs a generator or `get_seeds` lazily creates one.
static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
15
16pub fn seed(seed: u64) {
17    let rng = StdRng::seed_from_u64(seed);
18    let mut seed = SEED.lock().unwrap();
19    *seed = Some(rng);
20}
21
22/// Pseudo-random generator
23pub(crate) fn random<F: RandomFamily, R: Runtime>(
24    client: &ComputeClient<R>,
25    prng: F::Runtime,
26    output: TensorHandleRef<'_, R>,
27    dtype: StorageType,
28) -> Result<(), LaunchError> {
29    let seeds = get_seeds();
30    let args = prng.args();
31
32    let cube_dim = CubeDim::new(client, output.size().div_ceil(N_VALUES_PER_THREAD));
33    let cube_count = prng_cube_count(output.size(), cube_dim, N_VALUES_PER_THREAD);
34
35    let output_line_size = 1;
36    // TODO: Higher vectorization can add some correlation locally.
37    //
38    // let output_line_size = tensor_line_size_parallel(
39    //     R::line_size_elem(&E::as_elem_native_unchecked()),
40    //     output.shape,
41    //     output.strides,
42    //     output.strides.len() - 1,
43    // );
44
45    let output = linear_view(client, &output, output_line_size);
46
47    prng_kernel::launch::<F, R>(
48        client,
49        cube_count,
50        cube_dim,
51        output,
52        ScalarArg::new(seeds[0]),
53        ScalarArg::new(seeds[1]),
54        ScalarArg::new(seeds[2]),
55        ScalarArg::new(seeds[3]),
56        args,
57        N_VALUES_PER_THREAD,
58        output_line_size,
59        dtype,
60    )
61}
62
63fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
64    let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
65    let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
66    let cubes_x = f32::ceil(f32::sqrt(num_invocations));
67    let cubes_y = f32::ceil(num_invocations / cubes_x);
68
69    CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
70}
71
72pub(crate) fn get_seeds() -> [u32; 4] {
73    let mut seed = SEED.lock().unwrap();
74    let mut rng: StdRng = match seed.as_ref() {
75        Some(rng_seeded) => rng_seeded.clone(),
76        None => get_seeded_rng(),
77    };
78    let mut seeds: Vec<u32> = Vec::with_capacity(4);
79    for _ in 0..4 {
80        seeds.push(rng.random());
81    }
82    *seed = Some(rng);
83
84    seeds.try_into().unwrap()
85}
86
/// Host-side conversion of a distribution's parameters into a kernel launch argument.
pub(crate) trait PrngArgs: Send + Sync + 'static {
    /// Launch-argument type carrying the distribution parameters to the kernel.
    type Args: LaunchArg;

    /// Consumes the distribution value and produces its runtime launch argument
    /// (called by `random` right before `prng_kernel::launch`).
    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
}
92
/// Type-level selector for a random distribution.
pub(crate) trait RandomFamily: Send + Sync + 'static + std::fmt::Debug {
    /// Device-side generator whose `inner_loop` is invoked by `prng_kernel`.
    type Runtime: PrngRuntime;
}
96
/// Device-side value generation for one random distribution.
///
/// `prng_kernel` seeds four per-unit state words and then delegates to `inner_loop`,
/// which produces the unit's values and writes them to `output`.
#[cube]
pub(crate) trait PrngRuntime: Send + Sync + 'static + PrngArgs {
    /// Generates and writes this unit's `n_values_per_thread` values.
    ///
    /// - `args`: distribution parameters produced on the host by `PrngArgs::args`.
    /// - `write_index_base`: first index (in lines) of `output` written by this unit.
    /// - `n_invocations`: `prng_kernel` passes `CUBE_DIM`; presumably the stride between a
    ///   unit's successive writes — confirm against the implementors.
    /// - `n_values_per_thread`, `line_size`: comptime sizes forwarded from the launch.
    /// - `state_0`..`state_3`: generator state words, updated in place.
    /// - `output`: destination view.
    #[allow(clippy::too_many_arguments)]
    fn inner_loop<E: Numeric>(
        args: Self::Args,
        write_index_base: usize,
        n_invocations: u32,
        #[comptime] n_values_per_thread: usize,
        #[comptime] line_size: usize,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, Coords1d, ReadWrite>,
    );
}
113
/// Shorthand for the launch-argument type of distribution family `F`.
type Args<F> = <<F as RandomFamily>::Runtime as PrngArgs>::Args;
115
/// Kernel entry point shared by all random distributions.
///
/// Each unit derives four generator state words from the launch seeds and its absolute
/// position, then delegates value generation and writing to `F::Runtime::inner_loop`.
#[cube(launch)]
fn prng_kernel<F: RandomFamily, E: Numeric>(
    output: &mut LinearView<Line<E>, ReadWrite>,
    seed_0: u32,
    seed_1: u32,
    seed_2: u32,
    seed_3: u32,
    args: Args<F>,
    #[comptime] n_values_per_thread: usize,
    #[comptime] line_size: usize,
    #[define(E)] _dtype: StorageType,
) {
    // Number of units in the cubes preceding this one.
    let cube_offset = CUBE_POS * CUBE_DIM as usize;

    // First line index written by this unit: the start of this cube's chunk of
    // `CUBE_DIM * n_values_per_thread` scalars (converted to lines), plus the unit's
    // position inside the cube.
    let write_index_base = cube_offset * n_values_per_thread / line_size + UNIT_POS as usize;

    // Truncating position should be fine here, it's no issue if the seed repeats.
    // 1000000007 is a large prime; multiplying by it presumably spreads the initial
    // states of neighbouring units apart.
    #[allow(arithmetic_overflow)]
    let thread_seed = 1000000007u32 * ABSOLUTE_POS as u32;

    // Per-unit generator state: launch-wide seed offset by the unit-specific value.
    let mut state_0 = thread_seed + seed_0;
    let mut state_1 = thread_seed + seed_1;
    let mut state_2 = thread_seed + seed_2;
    let mut state_3 = thread_seed + seed_3;

    // Creation of n_values_per_thread values, specific to the distribution
    F::Runtime::inner_loop(
        args,
        write_index_base,
        CUBE_DIM,
        n_values_per_thread,
        line_size,
        &mut state_0,
        &mut state_1,
        &mut state_2,
        &mut state_3,
        output,
    );
}
155
156#[cube]
157pub(crate) fn taus_step_0(z: u32) -> u32 {
158    taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
159}
160
161#[cube]
162pub(crate) fn taus_step_1(z: u32) -> u32 {
163    taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
164}
165
166#[cube]
167pub(crate) fn taus_step_2(z: u32) -> u32 {
168    taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
169}
170
171#[cube]
172fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
173    let b = z << s1;
174    let b = b ^ z;
175    let b = b >> s2;
176    let z = (z & m) << s3;
177    z ^ b
178}
179
180#[cube]
181pub(crate) fn lcg_step(z: u32) -> u32 {
182    let a = 1664525u32;
183    let b = 1013904223u32;
184
185    z * a + b
186}
187
188/// Converts a `u32` into a `f32` in the unit interval `[0.0, 1.0)`.
189/// Used for generating random floats.
190#[cube]
191pub fn to_unit_interval_closed_open(int_random: u32) -> f32 {
192    // Use upper 24 bits for f32 precision
193    // https://lemire.me/blog/2017/02/28/how-many-floating-point-numbers-are-in-the-interval-01/
194    let shifted = int_random >> 8;
195    f32::cast_from(shifted) / 16777216.0 // 2^24
196}
197
198/// Converts a `u32` into a `f32` in the unit interval `(0.0, 1.0)`.
199/// Used for generating random floats.
200#[cube]
201pub fn to_unit_interval_open(int_random: u32) -> f32 {
202    // Use upper 23 bits to leave room for the offset
203    let shifted = int_random >> 9;
204    (f32::cast_from(shifted) + 1.0) / 8388609.0 // 2^23 + 1
205}