cubek_random/
bernoulli.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::View;
3use cubecl::{CubeType, Runtime};
4
5use crate::RandomFamily;
6
7use super::{
8    PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
9    to_unit_interval_closed_open,
10};
11
/// Launch arguments for the Bernoulli sampler kernel.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Bernoulli {
    /// Success probability: each generated sample is one when the kernel's
    /// unit-interval draw is below this value, and zero otherwise.
    probability: f32,
}
16
/// Marker type that plugs the `Bernoulli` sampler into the generic `random`
/// launcher (see `random_bernoulli`).
#[derive(Debug)]
struct BernoulliFamily;

// Selects `Bernoulli` as the per-invocation sampling routine for this family.
impl RandomFamily for BernoulliFamily {
    type Runtime = Bernoulli;
}
23
// Device-side sampling routine for the Bernoulli distribution; `#[cube]`
// compiles this impl into kernel code.
#[cube]
impl PrngRuntime for Bernoulli {
    // Generates this invocation's `n_values_per_thread` samples and stores them
    // in `output` as `Line<E>` values of `line_size` lanes each.
    //
    // For each lane:
    // - The four generator state words are advanced in place, three via
    //   `taus_step_0/1/2` and one via `lcg_step`.
    // - They are XOR-combined into one 32-bit value.
    // - That value is mapped to a float with `to_unit_interval_closed_open`
    //   (presumably `[0, 1)`, as the name says).
    // - The lane stores `E::cast_from(draw < probability)`.
    //
    // Line `k` produced by this invocation is written at
    // `k * n_invocations + write_index_base`, so its consecutive lines are
    // `n_invocations` positions apart in `output`.
    //
    // NOTE(review): `n_values_per_thread / line_size` truncates, so any
    // remainder would be silently skipped. This assumes callers always pass a
    // multiple of `line_size`; confirm.
    fn inner_loop<E: Numeric>(
        args: Bernoulli,
        write_index_base: usize,
        n_invocations: u32,
        #[comptime] n_values_per_thread: usize,
        #[comptime] line_size: LineSize,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, usize, ReadWrite>,
    ) {
        let prob = args.probability;

        // Scratch line reused for every output line; all `line_size` lanes are
        // overwritten before each write below.
        let mut output_line = Line::empty(line_size);

        let num_iterations = n_values_per_thread / line_size;
        // Unroll the outer loop only when it has at most 8 iterations.
        #[unroll(num_iterations <=8)]
        for line_index in 0..num_iterations {
            // Fill each lane of the current line with one sample.
            #[unroll]
            for i in 0..line_size {
                *state_0 = taus_step_0(*state_0);
                *state_1 = taus_step_1(*state_1);
                *state_2 = taus_step_2(*state_2);
                *state_3 = lcg_step(*state_3);

                // Combine the four independent streams into one random word.
                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
                let float_random = to_unit_interval_closed_open(int_random);
                // The comparison is the Bernoulli sample: true with probability
                // `prob` for a uniform draw.
                output_line[i] = E::cast_from(float_random < prob);
            }
            let write_index = line_index * n_invocations as usize + write_index_base;

            output[write_index] = output_line;
        }
    }
}
63
64impl PrngArgs for Bernoulli {
65    type Args = Self;
66
67    fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, R> {
68        BernoulliLaunch::new(ScalarArg::new(self.probability))
69    }
70}
71
72/// Pseudo-random generator with bernoulli distribution
73pub fn random_bernoulli<R: Runtime>(
74    client: &ComputeClient<R>,
75    probability: f32,
76    out: TensorHandleRef<R>,
77    dtype: StorageType,
78) -> Result<(), LaunchError> {
79    assert_eq!(
80        out.elem_size,
81        dtype.size(),
82        "Tensor element type must be the same as type E"
83    );
84
85    random::<BernoulliFamily, R>(client, Bernoulli { probability }, out, dtype)
86}