cubecl_random/
bernoulli.rs1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use cubecl::{CubeType, Runtime};
7use cubecl_std::tensor::View;
8
9use crate::RandomFamily;
10
11use super::{
12 PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
13 to_unit_interval_closed_open,
14};
15
16#[derive(CubeLaunch, CubeType)]
17pub(crate) struct Bernoulli<E: Numeric> {
18 probability: f32,
19 #[cube(comptime)]
20 _phantom: PhantomData<E>,
21}
22
23#[derive(Debug)]
24struct BernoulliFamily;
25
26impl RandomFamily for BernoulliFamily {
27 type Runtime<E: Numeric> = Bernoulli<E>;
28}
29
30#[cube]
31impl<E: Numeric> PrngRuntime<E> for Bernoulli<E> {
32 fn inner_loop(
33 args: Bernoulli<E>,
34 write_index_base: u32,
35 n_invocations: u32,
36 #[comptime] n_values_per_thread: u32,
37 #[comptime] line_size: u32,
38 state_0: &mut u32,
39 state_1: &mut u32,
40 state_2: &mut u32,
41 state_3: &mut u32,
42 output: &mut View<Line<E>, u32, ReadWrite>,
43 ) {
44 let prob = args.probability;
45
46 let mut output_line = Line::empty(line_size);
47
48 let num_iterations = n_values_per_thread / line_size;
49 #[unroll(num_iterations <=8)]
50 for line_index in 0..num_iterations {
51 #[unroll]
53 for i in 0..line_size {
54 *state_0 = taus_step_0(*state_0);
55 *state_1 = taus_step_1(*state_1);
56 *state_2 = taus_step_2(*state_2);
57 *state_3 = lcg_step(*state_3);
58
59 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
60 let float_random = to_unit_interval_closed_open(int_random);
61 output_line[i] = E::cast_from(float_random < prob);
62 }
63 let write_index = line_index * n_invocations + write_index_base;
64
65 output[write_index] = output_line;
66 }
67 }
68}
69
70impl<E: Numeric> PrngArgs<E> for Bernoulli<E> {
71 type Args = Self;
72
73 fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, E, R> {
74 BernoulliLaunch::new(ScalarArg::new(self.probability), PhantomData::<E>)
75 }
76}
77
78pub fn random_bernoulli<R: Runtime, E: Numeric>(
80 client: &ComputeClient<R::Server>,
81 probability: f32,
82 out: TensorHandleRef<R>,
83) {
84 assert_eq!(
85 out.elem_size as u32,
86 E::elem_size(),
87 "Tensor element type must be the same as type E"
88 );
89
90 random::<BernoulliFamily, E, R>(
91 client,
92 Bernoulli::<E> {
93 probability,
94 _phantom: PhantomData,
95 },
96 out,
97 )
98}