use super::*;
use crate::math::tensor::{AsMutTensor, Tensor};
/// A distribution that produces zero with probability `prob_zero`.
///
/// For the integer types implemented in this module, sampling draws a uniform float and returns
/// zero when it falls below `prob_zero`; otherwise it returns a uniformly random value of the
/// sampled type.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct UniformWithZeros {
    /// Probability of forcing a zero. Presumably meant to lie in `[0, 1]`; it is not validated
    /// here.
    pub prob_zero: f32,
}
/// Implements `RandomGenerable<UniformWithZeros>` for an integer type `$T`.
///
/// `$bits` is the byte width of `$T`. The expansion does not use it; it is kept so existing
/// invocations remain valid.
macro_rules! implement_uniform_with_zeros {
    ($T:ty, $bits:literal) => {
        impl RandomGenerable<UniformWithZeros> for $T {
            fn sample(UniformWithZeros { prob_zero }: UniformWithZeros) -> Self {
                // Keep only the top 24 bits of a uniform `u32`. Every such value is exactly
                // representable as an `f32`, and dividing by 2^24 (a power of two) is exact.
                // The sample is therefore uniform over a grid in [0, 1) and never reaches 1.0.
                // A naive `u32 as f32 / u32::MAX as f32` rounds large inputs up to exactly
                // 1.0, which lets `prob_zero == 1.0` occasionally yield a non-zero value.
                let float_sample = (random_uniform::<u32>() >> 8) as f32 / (1u32 << 24) as f32;
                if float_sample < prob_zero {
                    <$T>::ZERO
                } else {
                    random_uniform::<$T>()
                }
            }
        }
    };
}
// `RandomGenerable<UniformWithZeros>` for every primitive integer width, listed as
// unsigned/signed pairs. The second argument is the type's size in bytes.
implement_uniform_with_zeros!(u8, 1);
implement_uniform_with_zeros!(i8, 1);
implement_uniform_with_zeros!(u16, 2);
implement_uniform_with_zeros!(i16, 2);
implement_uniform_with_zeros!(u32, 4);
implement_uniform_with_zeros!(i32, 4);
implement_uniform_with_zeros!(u64, 8);
implement_uniform_with_zeros!(i64, 8);
implement_uniform_with_zeros!(u128, 16);
implement_uniform_with_zeros!(i128, 16);
/// Draws a single `T` from [`UniformWithZeros`] with the given `prob_zero`.
///
/// This forwards to `T`'s `RandomGenerable<UniformWithZeros>` implementation. For the integer
/// impls in this module, that means zero with probability about `prob_zero`, and a uniform
/// value of `T` otherwise.
pub fn random_uniform_with_zeros<T: RandomGenerable<UniformWithZeros>>(prob_zero: f32) -> T {
    let distribution = UniformWithZeros { prob_zero };
    <T as RandomGenerable<UniformWithZeros>>::sample(distribution)
}
/// Overwrites every element of `output` with a fresh draw from [`UniformWithZeros`].
///
/// Each element is replaced in the tensor's iteration order, with one draw per element.
pub fn fill_with_random_uniform_with_zeros<Scalar, Tensorable>(
    output: &mut Tensorable,
    prob_zero: f32,
) where
    Scalar: RandomGenerable<UniformWithZeros>,
    Tensorable: AsMutTensor<Element = Scalar>,
{
    let tensor = output.as_mut_tensor();
    for slot in tensor.iter_mut() {
        *slot = random_uniform_with_zeros::<Scalar>(prob_zero);
    }
}
/// Builds a new tensor of `size` elements, each drawn from [`UniformWithZeros`] with the given
/// `prob_zero`.
///
/// Exactly `size` draws are made, in element order. A `size` of zero yields an empty tensor.
pub fn random_uniform_with_zeros_tensor<T: RandomGenerable<UniformWithZeros>>(
    size: usize,
    prob_zero: f32,
) -> Tensor<Vec<T>> {
    std::iter::repeat_with(|| random_uniform_with_zeros::<T>(prob_zero))
        .take(size)
        .collect()
}