cubecl_random/
bernoulli.rs

1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use cubecl::{CubeType, Runtime};
7use cubecl_std::tensor::View;
8
9use crate::RandomFamily;
10
11use super::{
12    PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
13    to_unit_interval_closed_open,
14};
15
16#[derive(CubeLaunch, CubeType)]
17pub(crate) struct Bernoulli<E: Numeric> {
18    probability: f32,
19    #[cube(comptime)]
20    _phantom: PhantomData<E>,
21}
22
/// Marker type naming the Bernoulli distribution for the generic `random` launcher.
///
/// Its `RandomFamily` implementation resolves to `Bernoulli<E>` for each element type.
#[derive(Debug)]
struct BernoulliFamily;
25
26impl RandomFamily for BernoulliFamily {
27    type Runtime<E: Numeric> = Bernoulli<E>;
28}
29
30#[cube]
31impl<E: Numeric> PrngRuntime<E> for Bernoulli<E> {
32    fn inner_loop(
33        args: Bernoulli<E>,
34        write_index_base: u32,
35        n_invocations: u32,
36        #[comptime] n_values_per_thread: u32,
37        #[comptime] line_size: u32,
38        state_0: &mut u32,
39        state_1: &mut u32,
40        state_2: &mut u32,
41        state_3: &mut u32,
42        output: &mut View<Line<E>, u32, ReadWrite>,
43    ) {
44        let prob = args.probability;
45
46        let mut output_line = Line::empty(line_size);
47
48        let num_iterations = n_values_per_thread / line_size;
49        #[unroll(num_iterations <=8)]
50        for line_index in 0..num_iterations {
51            // vectorization
52            #[unroll]
53            for i in 0..line_size {
54                *state_0 = taus_step_0(*state_0);
55                *state_1 = taus_step_1(*state_1);
56                *state_2 = taus_step_2(*state_2);
57                *state_3 = lcg_step(*state_3);
58
59                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
60                let float_random = to_unit_interval_closed_open(int_random);
61                output_line[i] = E::cast_from(float_random < prob);
62            }
63            let write_index = line_index * n_invocations + write_index_base;
64
65            output[write_index] = output_line;
66        }
67    }
68}
69
70impl<E: Numeric> PrngArgs<E> for Bernoulli<E> {
71    type Args = Self;
72
73    fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, E, R> {
74        BernoulliLaunch::new(ScalarArg::new(self.probability), PhantomData::<E>)
75    }
76}
77
78/// Pseudo-random generator with bernoulli distribution
79pub fn random_bernoulli<R: Runtime, E: Numeric>(
80    client: &ComputeClient<R::Server>,
81    probability: f32,
82    out: TensorHandleRef<R>,
83) {
84    assert_eq!(
85        out.elem_size as u32,
86        E::elem_size(),
87        "Tensor element type must be the same as type E"
88    );
89
90    random::<BernoulliFamily, E, R>(
91        client,
92        Bernoulli::<E> {
93            probability,
94            _phantom: PhantomData,
95        },
96        out,
97    )
98}