use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
use burn_backend::{DType, Shape, TensorMetadata};
/// Allocates a tensor of `shape` and `dtype` on `device` and fills it with values drawn
/// from a uniform distribution parameterized by `lower_bound` and `upper_bound`.
///
/// The sampling is done by `cubek::random::random_uniform`. Whether each bound is
/// inclusive or exclusive is defined by that kernel, not here.
///
/// # Panics
///
/// Panics with `"Kernel to never fail"` if the kernel launch returns an error.
pub fn random_uniform<R: CubeRuntime>(
    shape: Shape,
    device: &R::Device,
    lower_bound: f32,
    upper_bound: f32,
    dtype: DType,
) -> CubeTensor<R> {
    let compute_client = R::client(device);
    let tensor = empty_device_dtype(compute_client.clone(), device.clone(), shape, dtype);

    // The kernel writes through a binding taken from a clone of the handle; the clone
    // presumably shares the underlying buffer, so `tensor` sees the sampled values.
    let out_binding = tensor.clone().binding();
    let launch_result = cubek::random::random_uniform(
        &compute_client,
        lower_bound,
        upper_bound,
        out_binding,
        dtype.into(),
    );
    launch_result.expect("Kernel to never fail");

    tensor
}
/// Creates a uniformly-sampled tensor with the same shape and on the same device as
/// `tensor`, delegating to [`random_uniform`].
///
/// Only the shape and device are taken from `tensor`. The element type is always the
/// `dtype` argument, and the contents of `tensor` are not read.
///
/// # Panics
///
/// Panics under the same conditions as [`random_uniform`].
pub fn random_like_uniform<R: CubeRuntime>(
    tensor: &CubeTensor<R>,
    lower_bound: f32,
    upper_bound: f32,
    dtype: DType,
) -> CubeTensor<R> {
    let shape = tensor.shape();
    let device = &tensor.device;
    random_uniform(shape, device, lower_bound, upper_bound, dtype)
}