cubecl_random/
bernoulli.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl::{CubeType, Runtime};
5use cubecl_std::tensor::View;
6
7use crate::RandomFamily;
8
9use super::{
10 PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
11 to_unit_interval_closed_open,
12};
13
/// Launch-time parameters of the Bernoulli random generator.
///
/// `#[derive(CubeLaunch)]` is what provides the `BernoulliLaunch` argument type
/// constructed in the `PrngArgs` impl of this file, where `probability` is
/// passed to the kernel as a scalar argument.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Bernoulli {
    // Threshold each uniform sample is compared against: a sample strictly
    // below it is stored as `true` cast to the output element type, otherwise
    // `false` cast to it (presumably 1 / 0 — depends on `E::cast_from`).
    probability: f32,
}
18
/// Marker type that selects [`Bernoulli`] as the kernel runtime when the
/// generic `random` launcher is instantiated (see `random_bernoulli`).
#[derive(Debug)]
struct BernoulliFamily;

impl RandomFamily for BernoulliFamily {
    // The per-thread sampling logic lives in `Bernoulli`'s `PrngRuntime` impl.
    type Runtime = Bernoulli;
}
25
26#[cube]
27impl PrngRuntime for Bernoulli {
28 fn inner_loop<E: Numeric>(
29 args: Bernoulli,
30 write_index_base: u32,
31 n_invocations: u32,
32 #[comptime] n_values_per_thread: u32,
33 #[comptime] line_size: u32,
34 state_0: &mut u32,
35 state_1: &mut u32,
36 state_2: &mut u32,
37 state_3: &mut u32,
38 output: &mut View<Line<E>, u32, ReadWrite>,
39 ) {
40 let prob = args.probability;
41
42 let mut output_line = Line::empty(line_size);
43
44 let num_iterations = n_values_per_thread / line_size;
45 #[unroll(num_iterations <=8)]
46 for line_index in 0..num_iterations {
47 #[unroll]
49 for i in 0..line_size {
50 *state_0 = taus_step_0(*state_0);
51 *state_1 = taus_step_1(*state_1);
52 *state_2 = taus_step_2(*state_2);
53 *state_3 = lcg_step(*state_3);
54
55 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56 let float_random = to_unit_interval_closed_open(int_random);
57 output_line[i] = E::cast_from(float_random < prob);
58 }
59 let write_index = line_index * n_invocations + write_index_base;
60
61 output[write_index] = output_line;
62 }
63 }
64}
65
66impl PrngArgs for Bernoulli {
67 type Args = Self;
68
69 fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, R> {
70 BernoulliLaunch::new(ScalarArg::new(self.probability))
71 }
72}
73
74pub fn random_bernoulli<R: Runtime>(
76 client: &ComputeClient<R::Server>,
77 probability: f32,
78 out: TensorHandleRef<R>,
79 dtype: StorageType,
80) {
81 assert_eq!(
82 out.elem_size,
83 dtype.size(),
84 "Tensor element type must be the same as type E"
85 );
86
87 random::<BernoulliFamily, R>(client, Bernoulli { probability }, out, dtype)
88}