1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3
4use crate::{
5 RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
6};
7
8use super::{PrngArgs, PrngRuntime, random};
9
10#[derive(CubeLaunch, CubeType)]
11pub(crate) struct Uniform<E: Numeric> {
12 lower_bound: E,
13 upper_bound: E,
14}
15
/// Zero-sized marker used as the `RandomFamily` type parameter when launching
/// the uniform generator.
#[derive(Debug)]
struct UniformFamily;
18
19impl RandomFamily for UniformFamily {
20 type Runtime<E: Numeric> = Uniform<E>;
21}
22
23#[cube]
24impl<E: Numeric> PrngRuntime<E> for Uniform<E> {
25 fn inner_loop(
26 args: Uniform<E>,
27 write_index_base: u32,
28 n_invocations: u32,
29 #[comptime] n_values_per_thread: u32,
30 #[comptime] line_size: u32,
31 state_0: &mut u32,
32 state_1: &mut u32,
33 state_2: &mut u32,
34 state_3: &mut u32,
35 output: &mut Tensor<Line<E>>,
36 ) {
37 let lower_bound = args.lower_bound;
38 let upper_bound = args.upper_bound;
39
40 let scale = upper_bound - lower_bound;
41
42 let mut output_line = Line::empty(line_size);
43
44 let num_iterations = n_values_per_thread / line_size;
45 #[unroll(num_iterations <= 8)]
46 for line_index in 0..num_iterations {
47 #[unroll]
49 for i in 0..line_size {
50 *state_0 = taus_step_0(*state_0);
51 *state_1 = taus_step_1(*state_1);
52 *state_2 = taus_step_2(*state_2);
53 *state_3 = lcg_step(*state_3);
54
55 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56 let f32_random = to_unit_interval_closed_open(int_random);
57
58 let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);
59
60 let uniform = E::cast_from(f32_uniform);
61
62 output_line[i] = uniform;
63 }
64
65 let write_index = line_index * n_invocations + write_index_base;
66
67 output[write_index] = output_line;
68 }
69 }
70}
71
72impl<E: Numeric> PrngArgs<E> for Uniform<E> {
73 type Args = Self;
74
75 fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, E, R> {
76 UniformLaunch::new(
77 ScalarArg::new(self.lower_bound),
78 ScalarArg::new(self.upper_bound),
79 )
80 }
81}
82
83pub fn random_uniform<R: Runtime, E: Numeric>(
85 client: &ComputeClient<R::Server, R::Channel>,
86 lower_bound: E,
87 upper_bound: E,
88 out: TensorHandleRef<R>,
89) {
90 assert_eq!(
91 out.elem_size as u32,
92 E::elem_size(),
93 "Tensor element type must be the same as type E"
94 );
95
96 random::<UniformFamily, E, R>(
97 client,
98 Uniform {
99 lower_bound,
100 upper_bound,
101 },
102 out,
103 )
104}