1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::View;
4
5use crate::{
6 RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
7};
8
9use super::{PrngArgs, PrngRuntime, random};
10
/// Launch arguments for the uniform-distribution PRNG kernel.
///
/// Inside `inner_loop`, each generated unit-interval sample `r` is mapped to
/// `r * (upper_bound - lower_bound) + lower_bound` (computed in `f32`) before
/// being cast to the output element type.
// NOTE(review): presumably the derived `UniformLaunch::new` takes its arguments in
// field order — `PrngArgs::args` passes `lower_bound` first and `upper_bound` second,
// so keep the field order and that call in sync.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Uniform {
    /// Lower end of the sampled range (the offset of the affine map).
    lower_bound: f32,
    /// Upper end of the sampled range; `upper_bound - lower_bound` is the scale.
    upper_bound: f32,
}
16
/// Type-level selector for the uniform distribution.
///
/// It is passed as the family parameter of the generic `random` launcher in
/// `random_uniform`; its associated `Runtime` type, `Uniform`, carries the kernel
/// arguments and implements the per-thread generation loop.
#[derive(Debug)]
struct UniformFamily;

impl RandomFamily for UniformFamily {
    // Kernel-side argument/implementation type for this distribution.
    type Runtime = Uniform;
}
23
24#[cube]
25impl PrngRuntime for Uniform {
26 fn inner_loop<E: Numeric>(
27 args: Uniform,
28 write_index_base: u32,
29 n_invocations: u32,
30 #[comptime] n_values_per_thread: u32,
31 #[comptime] line_size: u32,
32 state_0: &mut u32,
33 state_1: &mut u32,
34 state_2: &mut u32,
35 state_3: &mut u32,
36 output: &mut View<Line<E>, u32, ReadWrite>,
37 ) {
38 let lower_bound = args.lower_bound;
39 let upper_bound = args.upper_bound;
40
41 let scale = upper_bound - lower_bound;
42
43 let mut output_line = Line::empty(line_size);
44
45 let num_iterations = n_values_per_thread / line_size;
46 #[unroll(num_iterations <= 8)]
47 for line_index in 0..num_iterations {
48 #[unroll]
50 for i in 0..line_size {
51 *state_0 = taus_step_0(*state_0);
52 *state_1 = taus_step_1(*state_1);
53 *state_2 = taus_step_2(*state_2);
54 *state_3 = lcg_step(*state_3);
55
56 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
57 let f32_random = to_unit_interval_closed_open(int_random);
58
59 let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);
60
61 let uniform = E::cast_from(f32_uniform);
62
63 output_line[i] = uniform;
64 }
65
66 let write_index = line_index * n_invocations + write_index_base;
67
68 output[write_index] = output_line;
69 }
70 }
71}
72
73impl PrngArgs for Uniform {
74 type Args = Self;
75
76 fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, R> {
77 UniformLaunch::new(
78 ScalarArg::new(self.lower_bound),
79 ScalarArg::new(self.upper_bound),
80 )
81 }
82}
83
84pub fn random_uniform<R: Runtime>(
86 client: &ComputeClient<R::Server>,
87 lower_bound: f32,
88 upper_bound: f32,
89 out: TensorHandleRef<R>,
90 dtype: StorageType,
91) {
92 assert_eq!(
93 out.elem_size,
94 dtype.size(),
95 "Tensor element type must be the same as type E"
96 );
97
98 random::<UniformFamily, R>(
99 client,
100 Uniform {
101 lower_bound,
102 upper_bound,
103 },
104 out,
105 dtype,
106 )
107}