burn_jit/kernel/prng/
base.rs

1use cubecl::prelude::*;
2
3use crate::{ops::numeric::empty_device, tensor::JitTensor, JitElement, JitRuntime, SEED};
4use burn_common::rand::get_seeded_rng;
5use burn_tensor::Shape;
6use rand::Rng;
7
/// Number of random values that every kernel thread is asked to produce.
///
/// The value is used on the host to size the launch grid and is forwarded to the kernel as a
/// compile-time constant.
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
9
10/// Pseudo-random generator
11pub(crate) fn random<P: PrngRuntime<E>, R: JitRuntime, E: JitElement>(
12    shape: Shape,
13    device: &R::Device,
14    prng: P,
15) -> JitTensor<R> {
16    let client = R::client(device);
17    let output = empty_device::<R, E>(client.clone(), device.clone(), shape);
18    let seeds = get_seeds();
19    let args = prng.args();
20
21    let cube_dim = CubeDim::default();
22    let cube_count = prng_cube_count(output.shape.num_elements(), cube_dim, N_VALUES_PER_THREAD);
23
24    prng_kernel::launch::<P, E, R>(
25        &client,
26        cube_count,
27        cube_dim,
28        output.as_tensor_arg::<E>(1),
29        ScalarArg::new(seeds[0]),
30        ScalarArg::new(seeds[1]),
31        ScalarArg::new(seeds[2]),
32        ScalarArg::new(seeds[3]),
33        args,
34        N_VALUES_PER_THREAD as u32,
35    );
36
37    output
38}
39
40fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
41    let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
42    let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
43    let cubes_x = f32::ceil(f32::sqrt(num_invocations));
44    let cubes_y = f32::ceil(num_invocations / cubes_x);
45
46    CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
47}
48
49pub(crate) fn get_seeds() -> [u32; 4] {
50    let mut seed = SEED.lock().unwrap();
51    let mut rng = match seed.as_ref() {
52        Some(rng_seeded) => rng_seeded.clone(),
53        None => get_seeded_rng(),
54    };
55    let mut seeds: Vec<u32> = Vec::with_capacity(4);
56    for _ in 0..4 {
57        seeds.push(rng.gen());
58    }
59    *seed = Some(rng);
60
61    seeds.try_into().unwrap()
62}
63
64pub(crate) trait PrngArgs<E: JitElement>: Send + Sync + 'static {
65    type Args: LaunchArg;
66
67    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
68}
69
70#[cube]
71pub(crate) trait PrngRuntime<E: JitElement>: Send + Sync + 'static + PrngArgs<E> {
72    #[allow(clippy::too_many_arguments)]
73    fn inner_loop(
74        args: Self::Args,
75        write_index_base: u32,
76        n_invocations: u32,
77        #[comptime] n_values_per_thread: u32,
78        state_0: &mut u32,
79        state_1: &mut u32,
80        state_2: &mut u32,
81        state_3: &mut u32,
82        output: &mut Tensor<E>,
83    );
84}
85
86#[cube(launch)]
87fn prng_kernel<P: PrngRuntime<E>, E: JitElement>(
88    output: &mut Tensor<E>,
89    seed_0: u32,
90    seed_1: u32,
91    seed_2: u32,
92    seed_3: u32,
93    args: P::Args,
94    #[comptime] n_values_per_thread: u32,
95) {
96    let cube_offset = CUBE_POS * CUBE_DIM;
97
98    let write_index_base = cube_offset * n_values_per_thread + UNIT_POS;
99
100    #[allow(arithmetic_overflow)]
101    let thread_seed = 1000000007u32 * ABSOLUTE_POS;
102
103    let mut state_0 = thread_seed + seed_0;
104    let mut state_1 = thread_seed + seed_1;
105    let mut state_2 = thread_seed + seed_2;
106    let mut state_3 = thread_seed + seed_3;
107
108    // Creation of n_values_per_thread values, specific to the distribution
109    P::inner_loop(
110        args,
111        write_index_base,
112        CUBE_DIM,
113        n_values_per_thread,
114        &mut state_0,
115        &mut state_1,
116        &mut state_2,
117        &mut state_3,
118        output,
119    );
120}
121
122#[cube]
123pub(crate) fn taus_step_0(z: u32) -> u32 {
124    taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
125}
126
127#[cube]
128pub(crate) fn taus_step_1(z: u32) -> u32 {
129    taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
130}
131
132#[cube]
133pub(crate) fn taus_step_2(z: u32) -> u32 {
134    taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
135}
136
137#[cube]
138fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
139    let b = z << s1;
140    let b = b ^ z;
141    let b = b >> s2;
142    let z = (z & m) << s3;
143    z ^ b
144}
145
146#[cube]
147pub(crate) fn lcg_step(z: u32) -> u32 {
148    let a = 1664525u32;
149    let b = 1013904223u32;
150
151    z * a + b
152}
153
154#[cube]
155pub(crate) fn cast_uint_to_float(int_random: u32) -> f32 {
156    let tmp = 2.328_306_4e-10f32;
157    f32::cast_from(int_random) * tmp
158}
159
160#[allow(missing_docs)]
161pub mod tests_utils {
162    use burn_tensor::Element;
163
164    #[derive(Default, Copy, Clone)]
165    pub struct BinStats {
166        pub count: usize,
167        pub n_runs: usize, // Number of sequences of same bin
168    }
169
170    #[allow(unused)]
171    pub fn calculate_bin_stats<E: Element>(
172        numbers: &[E],
173        number_of_bins: usize,
174        low: f32,
175        high: f32,
176    ) -> Vec<BinStats> {
177        let range = (high - low) / number_of_bins as f32;
178        let mut output: Vec<BinStats> = (0..number_of_bins).map(|_| Default::default()).collect();
179        let mut initialized = false;
180        let mut current_runs = number_of_bins; // impossible value for starting point
181        for number in numbers {
182            let num = number.elem::<f32>();
183            if num < low || num > high {
184                continue;
185            }
186            let index = f32::floor((num - low) / range) as usize;
187            output[index].count += 1;
188            if initialized && index != current_runs {
189                output[current_runs].n_runs += 1;
190            }
191            initialized = true;
192            current_runs = index;
193        }
194        output[current_runs].n_runs += 1;
195        output
196    }
197}