// cubecl_random/bernoulli.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use cubecl::{CubeLaunch, CubeType, Runtime};
7
8use crate::RandomFamily;
9
10use super::{
11    PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
12    to_unit_interval_closed_open,
13};
14
15#[derive(CubeLaunch, CubeType)]
16pub(crate) struct Bernoulli<E: Numeric> {
17    probability: f32,
18    #[cube(comptime)]
19    _phantom: PhantomData<E>,
20}
21
22#[derive(Debug)]
23struct BernoulliFamily;
24
25impl RandomFamily for BernoulliFamily {
26    type Runtime<E: Numeric> = Bernoulli<E>;
27}
28
29#[cube]
30impl<E: Numeric> PrngRuntime<E> for Bernoulli<E> {
31    fn inner_loop(
32        args: Bernoulli<E>,
33        write_index_base: u32,
34        n_invocations: u32,
35        #[comptime] n_values_per_thread: u32,
36        #[comptime] line_size: u32,
37        state_0: &mut u32,
38        state_1: &mut u32,
39        state_2: &mut u32,
40        state_3: &mut u32,
41        output: &mut Tensor<Line<E>>,
42    ) {
43        let prob = args.probability;
44
45        let mut output_line = Line::empty(line_size);
46
47        let num_iterations = n_values_per_thread / line_size;
48        #[unroll(num_iterations <=8)]
49        for line_index in 0..num_iterations {
50            // vectorization
51            #[unroll]
52            for i in 0..line_size {
53                *state_0 = taus_step_0(*state_0);
54                *state_1 = taus_step_1(*state_1);
55                *state_2 = taus_step_2(*state_2);
56                *state_3 = lcg_step(*state_3);
57
58                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
59                let float_random = to_unit_interval_closed_open(int_random);
60                output_line[i] = E::cast_from(float_random < prob);
61            }
62            let write_index = line_index * n_invocations + write_index_base;
63
64            output[write_index] = output_line;
65        }
66    }
67}
68
69impl<E: Numeric> PrngArgs<E> for Bernoulli<E> {
70    type Args = Self;
71
72    fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, E, R> {
73        BernoulliLaunch::new(ScalarArg::new(self.probability), &PhantomData::<E>)
74    }
75}
76
77/// Pseudo-random generator with bernoulli distribution
78pub fn random_bernoulli<R: Runtime, E: Numeric>(
79    client: &ComputeClient<R::Server, R::Channel>,
80    probability: f32,
81    out: TensorHandleRef<R>,
82) {
83    assert_eq!(
84        out.elem_size as u32,
85        E::elem_size(),
86        "Tensor element type must be the same as type E"
87    );
88
89    random::<BernoulliFamily, E, R>(
90        client,
91        Bernoulli::<E> {
92            probability,
93            _phantom: PhantomData,
94        },
95        out,
96    )
97}