half_2/
rand_distr.rs

1use crate::{bf16, f16};
2
3use rand::{distributions::Distribution, Rng};
4use rand_distr::uniform::UniformFloat;
5
6macro_rules! impl_distribution_via_f32 {
7    ($Ty:ty, $Distr:ty) => {
8        impl Distribution<$Ty> for $Distr {
9            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $Ty {
10                <$Ty>::from_f32(<Self as Distribution<f32>>::sample(self, rng))
11            }
12        }
13    };
14}
15
16impl_distribution_via_f32!(f16, rand_distr::Standard);
17impl_distribution_via_f32!(f16, rand_distr::StandardNormal);
18impl_distribution_via_f32!(f16, rand_distr::Exp1);
19impl_distribution_via_f32!(f16, rand_distr::Open01);
20impl_distribution_via_f32!(f16, rand_distr::OpenClosed01);
21
22impl_distribution_via_f32!(bf16, rand_distr::Standard);
23impl_distribution_via_f32!(bf16, rand_distr::StandardNormal);
24impl_distribution_via_f32!(bf16, rand_distr::Exp1);
25impl_distribution_via_f32!(bf16, rand_distr::Open01);
26impl_distribution_via_f32!(bf16, rand_distr::OpenClosed01);
27
28#[derive(Debug, Clone, Copy)]
29pub struct Float16Sampler(UniformFloat<f32>);
30
31impl rand_distr::uniform::SampleUniform for f16 {
32    type Sampler = Float16Sampler;
33}
34
35impl rand_distr::uniform::UniformSampler for Float16Sampler {
36    type X = f16;
37    fn new<B1, B2>(low: B1, high: B2) -> Self
38    where
39        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
40        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
41    {
42        Self(UniformFloat::new(
43            low.borrow().to_f32(),
44            high.borrow().to_f32(),
45        ))
46    }
47    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
48    where
49        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
50        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
51    {
52        Self(UniformFloat::new_inclusive(
53            low.borrow().to_f32(),
54            high.borrow().to_f32(),
55        ))
56    }
57    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
58        f16::from_f32(self.0.sample(rng))
59    }
60}
61
62#[derive(Debug, Clone, Copy)]
63pub struct BFloat16Sampler(UniformFloat<f32>);
64
65impl rand_distr::uniform::SampleUniform for bf16 {
66    type Sampler = BFloat16Sampler;
67}
68
69impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
70    type X = bf16;
71    fn new<B1, B2>(low: B1, high: B2) -> Self
72    where
73        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
74        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
75    {
76        Self(UniformFloat::new(
77            low.borrow().to_f32(),
78            high.borrow().to_f32(),
79        ))
80    }
81    fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
82    where
83        B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
84        B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
85    {
86        Self(UniformFloat::new_inclusive(
87            low.borrow().to_f32(),
88            high.borrow().to_f32(),
89        ))
90    }
91    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
92        bf16::from_f32(self.0.sample(rng))
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    // `Rng` may already be reachable through `super::*`; the explicit import documents the
    // dependency, and the allow keeps the resulting redundant-import lint quiet.
    #[allow(unused_imports)]
    use rand::{thread_rng, Rng};
    use rand_distr::{Exp1, Open01, OpenClosed01, Standard, StandardNormal, Uniform};

    /// Draws per distribution in the range checks; small enough to keep tests fast.
    const DRAWS: usize = 1_000;

    #[test]
    fn test_sample_f16() {
        let mut rng = thread_rng();
        let _: f16 = rng.sample(Standard);
        let _: f16 = rng.sample(StandardNormal);
        let _: f16 = rng.sample(Uniform::new(f16::from_f32(0.0), f16::from_f32(1.0)));
        #[cfg(feature = "num-traits")]
        let _: f16 =
            rng.sample(rand_distr::Normal::new(f16::from_f32(0.0), f16::from_f32(1.0)).unwrap());
    }

    #[test]
    fn test_sample_bf16() {
        let mut rng = thread_rng();
        let _: bf16 = rng.sample(Standard);
        let _: bf16 = rng.sample(StandardNormal);
        let _: bf16 = rng.sample(Uniform::new(bf16::from_f32(0.0), bf16::from_f32(1.0)));
        #[cfg(feature = "num-traits")]
        let _: bf16 =
            rng.sample(rand_distr::Normal::new(bf16::from_f32(0.0), bf16::from_f32(1.0)).unwrap());
    }

    /// Generates a test that repeatedly samples every distribution implemented for `$Ty`
    /// and checks each value against a bound that sampling in `f32` and then rounding
    /// always satisfies.
    macro_rules! range_test {
        ($name:ident, $Ty:ty) => {
            #[test]
            fn $name() {
                let mut rng = thread_rng();
                let low = <$Ty>::from_f32(-2.0);
                let high = <$Ty>::from_f32(3.0);
                let half_open = Uniform::new(low, high);
                let inclusive = Uniform::new_inclusive(low, high);
                let bounds = low.to_f32()..=high.to_f32();
                let unit = 0.0_f32..=1.0;

                for _ in 0..DRAWS {
                    let x: $Ty = rng.sample(Standard);
                    assert!(unit.contains(&x.to_f32()), "Standard: {:?}", x);

                    let x: $Ty = rng.sample(Open01);
                    assert!(x.to_f32() > 0.0 && x.to_f32() <= 1.0, "Open01: {:?}", x);

                    let x: $Ty = rng.sample(OpenClosed01);
                    assert!(x.to_f32() > 0.0 && x.to_f32() <= 1.0, "OpenClosed01: {:?}", x);

                    let x: $Ty = rng.sample(Exp1);
                    assert!(x.to_f32() >= 0.0, "Exp1: {:?}", x);

                    let x: $Ty = rng.sample(StandardNormal);
                    assert!(x.to_f32().is_finite(), "StandardNormal: {:?}", x);

                    let x: $Ty = rng.sample(&half_open);
                    assert!(bounds.contains(&x.to_f32()), "Uniform::new: {:?}", x);

                    let x: $Ty = rng.sample(&inclusive);
                    assert!(bounds.contains(&x.to_f32()), "Uniform::new_inclusive: {:?}", x);
                }
            }
        };
    }

    range_test!(test_sample_ranges_f16, f16);
    range_test!(test_sample_ranges_bf16, bf16);
}