cubecl_random/
uniform.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::View;
4
5use crate::{
6    RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
7};
8
9use super::{PrngArgs, PrngRuntime, random};
10
/// Kernel arguments for the uniform PRNG: the bounds of the sampled range.
///
/// `inner_loop` maps each unit-interval sample `u` to
/// `u * (upper_bound - lower_bound) + lower_bound`.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Uniform {
    /// Lower end of the sampled range (the offset added after scaling).
    lower_bound: f32,
    /// Upper end of the sampled range; `upper_bound - lower_bound` is the scale factor.
    upper_bound: f32,
}
16
/// Marker type passed to the generic `random` launcher (as `random::<UniformFamily, R>`)
/// to select the uniform distribution.
#[derive(Debug)]
struct UniformFamily;
19
// Binds the uniform family to `Uniform`, which carries the kernel arguments and
// implements the per-invocation generation loop (`PrngRuntime::inner_loop`).
impl RandomFamily for UniformFamily {
    type Runtime = Uniform;
}
23
#[cube]
impl PrngRuntime for Uniform {
    /// Generates this invocation's share of uniform random values and writes them to `output`.
    ///
    /// The loop runs `n_values_per_thread / line_size` iterations. Each iteration fills a
    /// `Line` of `line_size` values and stores it at
    /// `output[line_index * n_invocations + write_index_base]`. Consecutive lines written
    /// by one invocation are therefore `n_invocations` apart in `output`.
    ///
    /// Each individual value is produced as follows:
    /// - the four state words are advanced (`taus_step_0/1/2` on `state_0..state_2`,
    ///   `lcg_step` on `state_3`) and XOR-combined;
    /// - the result is mapped to `f32` with `to_unit_interval_closed_open`;
    /// - the range is applied as `u * (upper_bound - lower_bound) + lower_bound`;
    /// - the value is cast to `E`.
    ///
    /// The state words are written back through their `&mut` references.
    ///
    /// NOTE(review): `n_values_per_thread / line_size` is integer division, so any remainder
    /// values are never generated. Presumably callers pass a multiple of `line_size` — confirm.
    fn inner_loop<E: Numeric>(
        args: Uniform,
        write_index_base: u32,
        n_invocations: u32,
        #[comptime] n_values_per_thread: u32,
        #[comptime] line_size: u32,
        state_0: &mut u32,
        state_1: &mut u32,
        state_2: &mut u32,
        state_3: &mut u32,
        output: &mut View<Line<E>, u32, ReadWrite>,
    ) {
        let lower_bound = args.lower_bound;
        let upper_bound = args.upper_bound;

        // Width of the target range; each unit-interval sample is multiplied by this.
        let scale = upper_bound - lower_bound;

        // Buffer for one vectorized store. The inner loop overwrites every one of its
        // `line_size` lanes before each write to `output`.
        let mut output_line = Line::empty(line_size);

        // Both operands are comptime parameters, so this count is known when the kernel is built.
        let num_iterations = n_values_per_thread / line_size;
        // Unroll the outer loop only when the comptime condition `num_iterations <= 8` holds.
        #[unroll(num_iterations <= 8)]
        for line_index in 0..num_iterations {
            // Fill each of the `line_size` lanes of the line with a fresh random value.
            #[unroll]
            for i in 0..line_size {
                // Advance each component generator by one step. The order and the in-place
                // updates define the random stream, so these must not be reordered.
                *state_0 = taus_step_0(*state_0);
                *state_1 = taus_step_1(*state_1);
                *state_2 = taus_step_2(*state_2);
                *state_3 = lcg_step(*state_3);

                // Combine the four component streams into one 32-bit sample.
                let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
                let f32_random = to_unit_interval_closed_open(int_random);

                // Affine map onto the requested range. Assuming `f32_random` is in [0, 1),
                // as the helper's name suggests, the result is in [lower_bound, upper_bound).
                let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);

                // Generation happens in f32; convert to the output element type last.
                let uniform = E::cast_from(f32_uniform);

                output_line[i] = uniform;
            }

            // Lines from one invocation are strided by `n_invocations`, offset by this
            // invocation's `write_index_base`.
            let write_index = line_index * n_invocations + write_index_base;

            output[write_index] = output_line;
        }
    }
}
72
73impl PrngArgs for Uniform {
74    type Args = Self;
75
76    fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, R> {
77        UniformLaunch::new(
78            ScalarArg::new(self.lower_bound),
79            ScalarArg::new(self.upper_bound),
80        )
81    }
82}
83
84/// Pseudo-random generator with uniform distribution
85pub fn random_uniform<R: Runtime>(
86    client: &ComputeClient<R::Server>,
87    lower_bound: f32,
88    upper_bound: f32,
89    out: TensorHandleRef<R>,
90    dtype: StorageType,
91) {
92    assert_eq!(
93        out.elem_size,
94        dtype.size(),
95        "Tensor element type must be the same as type E"
96    );
97
98    random::<UniformFamily, R>(
99        client,
100        Uniform {
101            lower_bound,
102            upper_bound,
103        },
104        out,
105        dtype,
106    )
107}