burn_jit/kernel/prng/
base.rs1use cubecl::prelude::*;
2
3use crate::{ops::numeric::empty_device, tensor::JitTensor, JitElement, JitRuntime, SEED};
4use burn_common::rand::get_seeded_rng;
5use burn_tensor::Shape;
6use rand::Rng;
7
/// Number of values each kernel unit is asked to generate.
///
/// Used both to size the launch grid (`prng_cube_count`) and as the comptime
/// `n_values_per_thread` argument of `prng_kernel`.
pub(crate) const N_VALUES_PER_THREAD: usize = 128;
9
10pub(crate) fn random<P: PrngRuntime<E>, R: JitRuntime, E: JitElement>(
12 shape: Shape,
13 device: &R::Device,
14 prng: P,
15) -> JitTensor<R> {
16 let client = R::client(device);
17 let output = empty_device::<R, E>(client.clone(), device.clone(), shape);
18 let seeds = get_seeds();
19 let args = prng.args();
20
21 let cube_dim = CubeDim::default();
22 let cube_count = prng_cube_count(output.shape.num_elements(), cube_dim, N_VALUES_PER_THREAD);
23
24 prng_kernel::launch::<P, E, R>(
25 &client,
26 cube_count,
27 cube_dim,
28 output.as_tensor_arg::<E>(1),
29 ScalarArg::new(seeds[0]),
30 ScalarArg::new(seeds[1]),
31 ScalarArg::new(seeds[2]),
32 ScalarArg::new(seeds[3]),
33 args,
34 N_VALUES_PER_THREAD as u32,
35 );
36
37 output
38}
39
40fn prng_cube_count(num_elems: usize, cube_dim: CubeDim, n_values_per_thread: usize) -> CubeCount {
41 let num_threads = f32::ceil(num_elems as f32 / n_values_per_thread as f32);
42 let num_invocations = f32::ceil(num_threads / cube_dim.num_elems() as f32);
43 let cubes_x = f32::ceil(f32::sqrt(num_invocations));
44 let cubes_y = f32::ceil(num_invocations / cubes_x);
45
46 CubeCount::Static(cubes_x as u32, cubes_y as u32, 1)
47}
48
49pub(crate) fn get_seeds() -> [u32; 4] {
50 let mut seed = SEED.lock().unwrap();
51 let mut rng = match seed.as_ref() {
52 Some(rng_seeded) => rng_seeded.clone(),
53 None => get_seeded_rng(),
54 };
55 let mut seeds: Vec<u32> = Vec::with_capacity(4);
56 for _ in 0..4 {
57 seeds.push(rng.gen());
58 }
59 *seed = Some(rng);
60
61 seeds.try_into().unwrap()
62}
63
/// Host-side arguments of a PRNG distribution.
///
/// `random` calls `args` and forwards the result to `prng_kernel`. The kernel passes the
/// corresponding `Self::Args` value on to `PrngRuntime::inner_loop`.
pub(crate) trait PrngArgs<E: JitElement>: Send + Sync + 'static {
    /// Kernel-side argument type received by `prng_kernel` as `args: P::Args`.
    type Args: LaunchArg;

    /// Consume `self` and produce the launch-time representation of `Self::Args` for runtime `R`.
    fn args<'a, R: Runtime>(self) -> <Self::Args as LaunchArg>::RuntimeArg<'a, R>;
}
69
/// Device-side part of a PRNG distribution: how each unit steps its four generator states
/// and writes values into `output`.
///
/// `prng_kernel` calls `inner_loop` once per unit, after seeding the four states.
#[cube]
pub(crate) trait PrngRuntime<E: JitElement>: Send + Sync + 'static + PrngArgs<E> {
    // Arguments, as supplied by `prng_kernel`:
    // - `args`: distribution parameters (see `PrngArgs::Args`).
    // - `write_index_base`: first output index assigned to this unit.
    // - `n_invocations`: the cube size (`CUBE_DIM`). Presumably this is the stride between
    //   successive writes of one unit; TODO confirm against implementations.
    // - `n_values_per_thread`: comptime number of values this unit should generate.
    // - `state_0`..`state_3`: per-unit generator state, updated in place.
    // - `output`: destination tensor.
    #[allow(clippy::too_many_arguments)]
    fn inner_loop(
        args: Self::Args,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut Tensor<E>,
    );
}
85
/// Kernel entry point: derive per-unit generator states from the four host seeds and delegate
/// value generation to `P::inner_loop`.
#[cube(launch)]
fn prng_kernel<P: PrngRuntime<E>, E: JitElement>(
    output: &mut Tensor<E>,
    seed_0: u32,
    seed_1: u32,
    seed_2: u32,
    seed_3: u32,
    args: P::Args,
    #[comptime] n_values_per_thread: u32,
) {
    // Linear index of this cube's first unit.
    let cube_offset = CUBE_POS * CUBE_DIM;

    // Cube `c` starts writing at `c * CUBE_DIM * n_values_per_thread`. Units inside the cube
    // start at consecutive indices (offset by `UNIT_POS`). With `CUBE_DIM` passed below as
    // `n_invocations`, this presumably interleaves the units of one cube over a block of
    // `CUBE_DIM * n_values_per_thread` slots; confirm against `inner_loop` implementations.
    let write_index_base = cube_offset * n_values_per_thread + UNIT_POS;

    // Give every unit a distinct starting state by scaling its global index with a large
    // constant. Overflow is intended here, hence the lint allowance.
    #[allow(arithmetic_overflow)]
    let thread_seed = 1000000007u32 * ABSOLUTE_POS;

    // Mix the shared host seeds with the per-unit offset (overflow presumably wraps on device).
    let mut state_0 = thread_seed + seed_0;
    let mut state_1 = thread_seed + seed_1;
    let mut state_2 = thread_seed + seed_2;
    let mut state_3 = thread_seed + seed_3;

    P::inner_loop(
        args,
        write_index_base,
        CUBE_DIM,
        n_values_per_thread,
        &mut state_0,
        &mut state_1,
        &mut state_2,
        &mut state_3,
        output,
    );
}
121
122#[cube]
123pub(crate) fn taus_step_0(z: u32) -> u32 {
124 taus_step(z, 13u32, 19u32, 12u32, 4294967294u32)
125}
126
127#[cube]
128pub(crate) fn taus_step_1(z: u32) -> u32 {
129 taus_step(z, 2u32, 25u32, 4u32, 4294967288u32)
130}
131
132#[cube]
133pub(crate) fn taus_step_2(z: u32) -> u32 {
134 taus_step(z, 3u32, 11u32, 17u32, 4294967280u32)
135}
136
137#[cube]
138fn taus_step(z: u32, s1: u32, s2: u32, s3: u32, m: u32) -> u32 {
139 let b = z << s1;
140 let b = b ^ z;
141 let b = b >> s2;
142 let z = (z & m) << s3;
143 z ^ b
144}
145
146#[cube]
147pub(crate) fn lcg_step(z: u32) -> u32 {
148 let a = 1664525u32;
149 let b = 1013904223u32;
150
151 z * a + b
152}
153
154#[cube]
155pub(crate) fn cast_uint_to_float(int_random: u32) -> f32 {
156 let tmp = 2.328_306_4e-10f32;
157 f32::cast_from(int_random) * tmp
158}
159
160#[allow(missing_docs)]
161pub mod tests_utils {
162 use burn_tensor::Element;
163
164 #[derive(Default, Copy, Clone)]
165 pub struct BinStats {
166 pub count: usize,
167 pub n_runs: usize, }
169
170 #[allow(unused)]
171 pub fn calculate_bin_stats<E: Element>(
172 numbers: &[E],
173 number_of_bins: usize,
174 low: f32,
175 high: f32,
176 ) -> Vec<BinStats> {
177 let range = (high - low) / number_of_bins as f32;
178 let mut output: Vec<BinStats> = (0..number_of_bins).map(|_| Default::default()).collect();
179 let mut initialized = false;
180 let mut current_runs = number_of_bins; for number in numbers {
182 let num = number.elem::<f32>();
183 if num < low || num > high {
184 continue;
185 }
186 let index = f32::floor((num - low) / range) as usize;
187 output[index].count += 1;
188 if initialized && index != current_runs {
189 output[current_runs].n_runs += 1;
190 }
191 initialized = true;
192 current_runs = index;
193 }
194 output[current_runs].n_runs += 1;
195 output
196 }
197}