cubek_random/
bernoulli.rs1use cubecl::prelude::*;
2use cubecl::std::tensor::View;
3use cubecl::{CubeType, Runtime};
4
5use crate::RandomFamily;
6
7use super::{
8 PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
9 to_unit_interval_closed_open,
10};
11
12#[derive(CubeLaunch, CubeType)]
13pub(crate) struct Bernoulli {
14 probability: f32,
15}
16
17#[derive(Debug)]
18struct BernoulliFamily;
19
20impl RandomFamily for BernoulliFamily {
21 type Runtime = Bernoulli;
22}
23
24#[cube]
25impl PrngRuntime for Bernoulli {
26 fn inner_loop<E: Numeric>(
27 args: Bernoulli,
28 write_index_base: usize,
29 n_invocations: u32,
30 #[comptime] n_values_per_thread: usize,
31 #[comptime] line_size: LineSize,
32 state_0: &mut u32,
33 state_1: &mut u32,
34 state_2: &mut u32,
35 state_3: &mut u32,
36 output: &mut View<Line<E>, usize, ReadWrite>,
37 ) {
38 let prob = args.probability;
39
40 let mut output_line = Line::empty(line_size);
41
42 let num_iterations = n_values_per_thread / line_size;
43 #[unroll(num_iterations <=8)]
44 for line_index in 0..num_iterations {
45 #[unroll]
47 for i in 0..line_size {
48 *state_0 = taus_step_0(*state_0);
49 *state_1 = taus_step_1(*state_1);
50 *state_2 = taus_step_2(*state_2);
51 *state_3 = lcg_step(*state_3);
52
53 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
54 let float_random = to_unit_interval_closed_open(int_random);
55 output_line[i] = E::cast_from(float_random < prob);
56 }
57 let write_index = line_index * n_invocations as usize + write_index_base;
58
59 output[write_index] = output_line;
60 }
61 }
62}
63
64impl PrngArgs for Bernoulli {
65 type Args = Self;
66
67 fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, R> {
68 BernoulliLaunch::new(ScalarArg::new(self.probability))
69 }
70}
71
72pub fn random_bernoulli<R: Runtime>(
74 client: &ComputeClient<R>,
75 probability: f32,
76 out: TensorHandleRef<R>,
77 dtype: StorageType,
78) -> Result<(), LaunchError> {
79 assert_eq!(
80 out.elem_size,
81 dtype.size(),
82 "Tensor element type must be the same as type E"
83 );
84
85 random::<BernoulliFamily, R>(client, Bernoulli { probability }, out, dtype)
86}