// cubek_random/uniform.rs

1use cubecl::prelude::*;
2use cubecl::std::tensor::View;
3
4use crate::{
5    RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
6};
7
8use super::{PrngArgs, PrngRuntime, random};
9
10#[derive(CubeLaunch, CubeType)]
11pub(crate) struct Uniform {
12    lower_bound: f32,
13    upper_bound: f32,
14}
15
/// Marker type that selects [`Uniform`] as the runtime of the generic `random` launcher.
#[derive(Debug)]
struct UniformFamily;
18
19impl RandomFamily for UniformFamily {
20    type Runtime = Uniform;
21}
22
23#[cube]
24impl PrngRuntime for Uniform {
25    fn inner_loop<E: Numeric>(
26        args: Uniform,
27        write_index_base: usize,
28        n_invocations: u32,
29        #[comptime] n_values_per_thread: usize,
30        #[comptime] line_size: LineSize,
31        state_0: &mut u32,
32        state_1: &mut u32,
33        state_2: &mut u32,
34        state_3: &mut u32,
35        output: &mut View<Line<E>, usize, ReadWrite>,
36    ) {
37        let lower_bound = args.lower_bound;
38        let upper_bound = args.upper_bound;
39
40        let scale = upper_bound - lower_bound;
41
42        let mut output_line = Line::empty(line_size);
43
44        let num_iterations = n_values_per_thread / line_size;
45        #[unroll(num_iterations <= 8)]
46        for line_index in 0..num_iterations {
47            // vectorization
48            #[unroll]
49            for i in 0..line_size {
50                *state_0 = taus_step_0(*state_0);
51                *state_1 = taus_step_1(*state_1);
52                *state_2 = taus_step_2(*state_2);
53                *state_3 = lcg_step(*state_3);
54
55                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56                let f32_random = to_unit_interval_closed_open(int_random);
57
58                let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);
59
60                let uniform = E::cast_from(f32_uniform);
61
62                output_line[i] = uniform;
63            }
64
65            let write_index = line_index * n_invocations as usize + write_index_base;
66
67            output[write_index] = output_line;
68        }
69    }
70}
71
72impl PrngArgs for Uniform {
73    type Args = Self;
74
75    fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, R> {
76        UniformLaunch::new(
77            ScalarArg::new(self.lower_bound),
78            ScalarArg::new(self.upper_bound),
79        )
80    }
81}
82
83/// Pseudo-random generator with uniform distribution
84pub fn random_uniform<R: Runtime>(
85    client: &ComputeClient<R>,
86    lower_bound: f32,
87    upper_bound: f32,
88    out: TensorHandleRef<R>,
89    dtype: StorageType,
90) -> Result<(), LaunchError> {
91    assert_eq!(
92        out.elem_size,
93        dtype.size(),
94        "Tensor element type must be the same as type E"
95    );
96
97    random::<UniformFamily, R>(
98        client,
99        Uniform {
100            lower_bound,
101            upper_bound,
102        },
103        out,
104        dtype,
105    )
106}