burn-cubecl 0.21.0-pre.4

Generic backend that can be compiled just-in-time to any shader language target
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, Shape};

/// Allocates a tensor of `shape` and `dtype` on `device` and fills it with
/// pseudo-random samples from a Bernoulli distribution.
///
/// # Arguments
///
/// * `shape` - Shape of the tensor to create.
/// * `device` - Device on which the output tensor is allocated and the kernel runs.
/// * `probability` - Distribution parameter, forwarded unchanged to
///   `cubek::random::random_bernoulli`. Presumably this is the chance of drawing a one;
///   its valid range is not checked here (NOTE(review): confirm against cubek).
/// * `dtype` - Element type of the output tensor; also passed to the kernel.
///
/// # Panics
///
/// Panics if the kernel launch returns an error.
pub fn random_bernoulli<R: CubeRuntime>(
    shape: Shape,
    device: &R::Device,
    probability: f32,
    dtype: DType,
) -> CubeTensor<R> {
    let cube_client = R::client(device);
    let tensor = empty_device_dtype(cube_client.clone(), device.clone(), shape, dtype);

    // The kernel writes through a binding obtained from a clone of the handle,
    // so `tensor` itself stays available to be returned afterwards.
    let out_binding = tensor.clone().binding();

    cubek::random::random_bernoulli(&cube_client, probability, out_binding, dtype.into())
        .expect("Kernel to never fail");

    tensor
}