1use cubecl::prelude::*;
2use cubecl::std::tensor::View;
3
4use crate::{
5 RandomFamily, lcg_step, taus_step_0, taus_step_1, taus_step_2, to_unit_interval_closed_open,
6};
7
8use super::{PrngArgs, PrngRuntime, random};
9
/// Runtime arguments of the uniform random kernel.
///
/// The kernel maps each unit-interval sample `u` to
/// `u * (upper_bound - lower_bound) + lower_bound` (see `inner_loop`).
// NOTE(review): `CubeLaunch` presumably generates `UniformLaunch::new` with its
// arguments in field order; `PrngArgs::args` passes (lower, upper), so keep the
// field order in sync with that call.
#[derive(CubeLaunch, CubeType)]
pub(crate) struct Uniform {
    /// Lower end of the sampled range; used as the additive offset.
    lower_bound: f32,
    /// Upper end of the sampled range; only used via `upper_bound - lower_bound`.
    upper_bound: f32,
}
15
/// Type-level marker that selects [`Uniform`] as the runtime argument type when
/// instantiating the generic `random` launcher (`random::<UniformFamily, R>`).
#[derive(Debug)]
struct UniformFamily;

// Ties the family marker to the per-launch runtime arguments it uses.
impl RandomFamily for UniformFamily {
    type Runtime = Uniform;
}
22
23#[cube]
24impl PrngRuntime for Uniform {
25 fn inner_loop<E: Numeric>(
26 args: Uniform,
27 write_index_base: usize,
28 n_invocations: u32,
29 #[comptime] n_values_per_thread: usize,
30 #[comptime] line_size: LineSize,
31 state_0: &mut u32,
32 state_1: &mut u32,
33 state_2: &mut u32,
34 state_3: &mut u32,
35 output: &mut View<Line<E>, usize, ReadWrite>,
36 ) {
37 let lower_bound = args.lower_bound;
38 let upper_bound = args.upper_bound;
39
40 let scale = upper_bound - lower_bound;
41
42 let mut output_line = Line::empty(line_size);
43
44 let num_iterations = n_values_per_thread / line_size;
45 #[unroll(num_iterations <= 8)]
46 for line_index in 0..num_iterations {
47 #[unroll]
49 for i in 0..line_size {
50 *state_0 = taus_step_0(*state_0);
51 *state_1 = taus_step_1(*state_1);
52 *state_2 = taus_step_2(*state_2);
53 *state_3 = lcg_step(*state_3);
54
55 let int_random = *state_0 ^ *state_1 ^ *state_2 ^ *state_3;
56 let f32_random = to_unit_interval_closed_open(int_random);
57
58 let f32_uniform = f32_random * f32::cast_from(scale) + f32::cast_from(lower_bound);
59
60 let uniform = E::cast_from(f32_uniform);
61
62 output_line[i] = uniform;
63 }
64
65 let write_index = line_index * n_invocations as usize + write_index_base;
66
67 output[write_index] = output_line;
68 }
69 }
70}
71
72impl PrngArgs for Uniform {
73 type Args = Self;
74
75 fn args<'a, R: Runtime>(self) -> UniformLaunch<'a, R> {
76 UniformLaunch::new(
77 ScalarArg::new(self.lower_bound),
78 ScalarArg::new(self.upper_bound),
79 )
80 }
81}
82
83pub fn random_uniform<R: Runtime>(
85 client: &ComputeClient<R>,
86 lower_bound: f32,
87 upper_bound: f32,
88 out: TensorHandleRef<R>,
89 dtype: StorageType,
90) -> Result<(), LaunchError> {
91 assert_eq!(
92 out.elem_size,
93 dtype.size(),
94 "Tensor element type must be the same as type E"
95 );
96
97 random::<UniformFamily, R>(
98 client,
99 Uniform {
100 lower_bound,
101 upper_bound,
102 },
103 out,
104 dtype,
105 )
106}