burn_cubecl/kernel/prng/
uniform.rs1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, Shape, TensorMetadata};
3
4pub fn random_uniform<R: CubeRuntime>(
6 shape: Shape,
7 device: &R::Device,
8 lower_bound: f32,
9 upper_bound: f32,
10 dtype: DType,
11) -> CubeTensor<R> {
12 let client = R::client(device);
13 let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
14
15 cubek::random::random_uniform(
16 &client,
17 lower_bound,
18 upper_bound,
19 output.clone().binding(),
20 dtype.into(),
21 )
22 .expect("Kernel to never fail");
23
24 output
25}
26
27pub fn random_like_uniform<R: CubeRuntime>(
30 tensor: &CubeTensor<R>,
31 lower_bound: f32,
32 upper_bound: f32,
33 dtype: DType,
34) -> CubeTensor<R> {
35 random_uniform(
36 tensor.shape(),
37 &tensor.device,
38 lower_bound,
39 upper_bound,
40 dtype,
41 )
42}