cubecl_random/
bernoulli.rs1use std::marker::PhantomData;
2
3use cubecl::prelude::*;
4use cubecl_core as cubecl;
5
6use cubecl::{CubeLaunch, CubeType, Runtime};
7
8use crate::RandomFamily;
9
10use super::{
11 PrngArgs, PrngRuntime, lcg_step, random, taus_step_0, taus_step_1, taus_step_2,
12 to_unit_interval_closed_open,
13};
14
15#[derive(CubeLaunch, CubeType)]
16pub(crate) struct Bernoulli<E: Numeric> {
17 probability: f32,
18 #[cube(comptime)]
19 _phantom: PhantomData<E>,
20}
21
/// Type-level tag selecting the Bernoulli distribution.
///
/// It carries no data; it is only used as the `RandomFamily` type parameter
/// when dispatching to the shared `random` entry point.
#[derive(Debug)]
struct BernoulliFamily;
24
25impl RandomFamily for BernoulliFamily {
26 type Runtime<E: Numeric> = Bernoulli<E>;
27}
28
29#[cube]
30impl<E: Numeric> PrngRuntime<E> for Bernoulli<E> {
31 fn inner_loop(
32 args: Bernoulli<E>,
33 write_index_base: u32,
34 n_invocations: u32,
35 #[comptime] n_values_per_thread: u32,
36 #[comptime] line_size: u32,
37 state_0: &mut u32,
38 state_1: &mut u32,
39 state_2: &mut u32,
40 state_3: &mut u32,
41 output: &mut Tensor<Line<E>>,
42 ) {
43 let prob = args.probability;
44
45 let mut output_line = Line::empty(line_size);
46
47 let num_iterations = n_values_per_thread / line_size;
48 #[unroll(num_iterations <=8)]
49 for line_index in 0..num_iterations {
50 #[unroll]
52 for i in 0..line_size {
53 *state_0 = taus_step_0(*state_0);
54 *state_1 = taus_step_1(*state_1);
55 *state_2 = taus_step_2(*state_2);
56 *state_3 = lcg_step(*state_3);
57
58 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
59 let float_random = to_unit_interval_closed_open(int_random);
60 output_line[i] = E::cast_from(float_random < prob);
61 }
62 let write_index = line_index * n_invocations + write_index_base;
63
64 output[write_index] = output_line;
65 }
66 }
67}
68
69impl<E: Numeric> PrngArgs<E> for Bernoulli<E> {
70 type Args = Self;
71
72 fn args<'a, R: Runtime>(self) -> BernoulliLaunch<'a, E, R> {
73 BernoulliLaunch::new(ScalarArg::new(self.probability), &PhantomData::<E>)
74 }
75}
76
77pub fn random_bernoulli<R: Runtime, E: Numeric>(
79 client: &ComputeClient<R::Server, R::Channel>,
80 probability: f32,
81 out: TensorHandleRef<R>,
82) {
83 assert_eq!(
84 out.elem_size as u32,
85 E::elem_size(),
86 "Tensor element type must be the same as type E"
87 );
88
89 random::<BernoulliFamily, E, R>(
90 client,
91 Bernoulli::<E> {
92 probability,
93 _phantom: PhantomData,
94 },
95 out,
96 )
97}