cubecl_random/
bernoulli.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use cubecl::{CubeType, Runtime};
5use cubecl_std::tensor::View;
6
7use crate::RandomFamily;
8
9use super::{
10    PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
11    to_unit_interval_closed_open,
12};
13
/// Launch arguments of the Bernoulli random kernel.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Bernoulli {
    /// Threshold each generated float is compared against: an output element is the
    /// cast of `random < probability`.
    probability: f32,
}
18
/// Marker type used as the family type parameter of `random` to select the
/// [`Bernoulli`] runtime.
#[derive(Debug)]
struct BernoulliFamily;
21
impl RandomFamily for BernoulliFamily {
    // The kernel arguments / per-thread generation logic for this family live on `Bernoulli`.
    type Runtime = Bernoulli;
}
25
#[cube]
impl PrngRuntime for Bernoulli {
    // Generates the Bernoulli samples of one invocation and stores them in `output`.
    //
    // `n_values_per_thread / line_size` lines are produced. For every lane of a line the
    // four PRNG state words are advanced, XOR-ed into one integer and mapped to a float by
    // `to_unit_interval_closed_open`; the lane receives `E::cast_from(float < probability)`.
    // The line with index `line_index` is written at
    // `line_index * n_invocations + write_index_base`.
    //
    // `state_0..state_3` are updated in place, so the caller observes the advanced state.
    fn inner_loop<E: Numeric>(
        args: Bernoulli,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        #[comptime] line_size: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, u32, ReadWrite>,
    ) {
        let prob = args.probability;

        // Reused for every output line; each lane is overwritten by the inner loop
        // before the line is stored.
        let mut output_line = Line::empty(line_size);

        // Integer division of two comptime values.
        // NOTE(review): a remainder of `n_values_per_thread % line_size` values is not
        // generated by this loop — presumably callers guarantee divisibility; confirm.
        let num_iterations = n_values_per_thread / line_size;
        // Unrolling of the outer loop is requested only for small iteration counts (<= 8).
        #[unroll(num_iterations <=8)]
        for line_index in 0..num_iterations {
            // vectorization: fill one lane of the line per inner iteration
            #[unroll]
            for i in 0..line_size {
                // Advance each of the four generator states once per produced value.
                *state_0 = taus_step_0(*state_0);
                *state_1 = taus_step_1(*state_1);
                *state_2 = taus_step_2(*state_2);
                *state_3 = lcg_step(*state_3);

                // Combine the four states into a single random integer.
                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
                // Presumably maps the integer into [0, 1) (going by the helper's name).
                let float_random = to_unit_interval_closed_open(int_random);
                // The comparison result is cast to the output element type.
                output_line[i] = E::cast_from(float_random < prob);
            }
            // Consecutive lines of this invocation are `n_invocations` positions apart.
            let write_index = line_index * n_invocations + write_index_base;

            output[write_index] = output_line;
        }
    }
}
65
66impl PrngArgs for Bernoulli {
67    type Args = Self;
68
69    fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, R> {
70        BernoulliLaunch::new(ScalarArg::new(self.probability))
71    }
72}
73
74/// Pseudo-random generator with bernoulli distribution
75pub fn random_bernoulli<R: Runtime>(
76    client: &ComputeClient<R::Server>,
77    probability: f32,
78    out: TensorHandleRef<R>,
79    dtype: StorageType,
80) {
81    assert_eq!(
82        out.elem_size,
83        dtype.size(),
84        "Tensor element type must be the same as type E"
85    );
86
87    random::<BernoulliFamily, R>(client, Bernoulli { probability }, out, dtype)
88}