Skip to main content

burn_cubecl/kernel/prng/
normal.rs

1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, Shape};
3
4/// Pseudo-random generator with uniform distribution
5pub fn random_normal<R: CubeRuntime>(
6    shape: Shape,
7    device: &R::Device,
8    mean: f32,
9    std: f32,
10    dtype: DType,
11) -> CubeTensor<R> {
12    let client = R::client(device);
13    let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
14    let output_handle = output.as_handle_ref();
15
16    cubek::random::random_normal(&client, mean, std, output_handle, dtype.into())
17        .expect("Kernel to never fail");
18
19    output
20}