Skip to main content

burn_cubecl/kernel/prng/
uniform.rs

1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, Shape, TensorMetadata};
3
4/// Pseudo-random generator with uniform distribution
5pub fn random_uniform<R: CubeRuntime>(
6    shape: Shape,
7    device: &R::Device,
8    lower_bound: f32,
9    upper_bound: f32,
10    dtype: DType,
11) -> CubeTensor<R> {
12    let client = R::client(device);
13    let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
14
15    cubek::random::random_uniform(
16        &client,
17        lower_bound,
18        upper_bound,
19        output.clone().binding(),
20        dtype.into(),
21    )
22    .expect("Kernel to never fail");
23
24    output
25}
26
27/// Pseudo-random generator for uniform distribution, based on
28/// another tensor.
29pub fn random_like_uniform<R: CubeRuntime>(
30    tensor: &CubeTensor<R>,
31    lower_bound: f32,
32    upper_bound: f32,
33    dtype: DType,
34) -> CubeTensor<R> {
35    random_uniform(
36        tensor.shape(),
37        &tensor.device,
38        lower_bound,
39        upper_bound,
40        dtype,
41    )
42}