burn_cubecl/kernel/prng/
normal.rs1use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
2use burn_backend::{DType, Shape};
3
4pub fn random_normal<R: CubeRuntime>(
6 shape: Shape,
7 device: &R::Device,
8 mean: f32,
9 std: f32,
10 dtype: DType,
11) -> CubeTensor<R> {
12 let client = R::client(device);
13 let output = empty_device_dtype(client.clone(), device.clone(), shape, dtype);
14 let output_handle = output.as_handle_ref();
15
16 cubek::random::random_normal(&client, mean, std, output_handle, dtype.into())
17 .expect("Kernel to never fail");
18
19 output
20}