// cubecl_random/uniform.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5    RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
6};
7
8use super::{PrngArgs, PrngRuntime, random};
9
10#[derive(CubeLaunch, CubeType)]
11pub(crate) struct Uniform<E: Numeric> {
12    lower_bound: E,
13    upper_bound: E,
14}
15
16#[derive(Debug)]
17struct UniformFamily;
18
19impl RandomFamily for UniformFamily {
20    type Runtime<E: Numeric> = Uniform<E>;
21}
22
23#[cube]
24impl<E: Numeric> PrngRuntime<E> for Uniform<E> {
25    fn inner_loop(
26        args: Uniform<E>,
27        write_index_base: u32,
28        n_invocations: u32,
29        #[comptime] n_values_per_thread: u32,
30        #[comptime] line_size: u32,
31        state_0: &mut u32,
32        state_1: &mut u32,
33        state_2: &mut u32,
34        state_3: &mut u32,
35        output: &mut Tensor<Line<E>>,
36    ) {
37        let lower_bound = args.lower_bound;
38        let upper_bound = args.upper_bound;
39
40        let scale = upper_bound - lower_bound;
41
42        let mut output_line = Line::empty(line_size);
43
44        let num_iterations = n_values_per_thread / line_size;
45        #[unroll(num_iterations <= 8)]
46        for line_index in 0..num_iterations {
47            // vectorization
48            #[unroll]
49            for i in 0..line_size {
50                *state_0 = taus_step_0(*state_0);
51                *state_1 = taus_step_1(*state_1);
52                *state_2 = taus_step_2(*state_2);
53                *state_3 = lcg_step(*state_3);
54
55                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56                let f32_random = to_unit_interval_closed_open(int_random);
57
58                let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);
59
60                let uniform = E::cast_from(f32_uniform);
61
62                output_line[i] = uniform;
63            }
64
65            let write_index = line_index * n_invocations + write_index_base;
66
67            output[write_index] = output_line;
68        }
69    }
70}
71
72impl<E: Numeric> PrngArgs<E> for Uniform<E> {
73    type Args = Self;
74
75    fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, E, R> {
76        UniformLaunch::new(
77            ScalarArg::new(self.lower_bound),
78            ScalarArg::new(self.upper_bound),
79        )
80    }
81}
82
83/// Pseudo-random generator with uniform distribution
84pub fn random_uniform<R: Runtime, E: Numeric>(
85    client: &ComputeClient<R::Server, R::Channel>,
86    lower_bound: E,
87    upper_bound: E,
88    out: TensorHandleRef<R>,
89) {
90    assert_eq!(
91        out.elem_size as u32,
92        E::elem_size(),
93        "Tensor element type must be the same as type E"
94    );
95
96    random::<UniformFamily, E, R>(
97        client,
98        Uniform {
99            lower_bound,
100            upper_bound,
101        },
102        out,
103    )
104}