1use crate::{bf16, f16};
2
3use rand::{distributions::Distribution, Rng};
4use rand_distr::uniform::UniformFloat;
5
6macro_rules! impl_distribution_via_f32 {
7 ($Ty:ty, $Distr:ty) => {
8 impl Distribution<$Ty> for $Distr {
9 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> $Ty {
10 <$Ty>::from_f32(<Self as Distribution<f32>>::sample(self, rng))
11 }
12 }
13 };
14}
15
16impl_distribution_via_f32!(f16, rand_distr::Standard);
17impl_distribution_via_f32!(f16, rand_distr::StandardNormal);
18impl_distribution_via_f32!(f16, rand_distr::Exp1);
19impl_distribution_via_f32!(f16, rand_distr::Open01);
20impl_distribution_via_f32!(f16, rand_distr::OpenClosed01);
21
22impl_distribution_via_f32!(bf16, rand_distr::Standard);
23impl_distribution_via_f32!(bf16, rand_distr::StandardNormal);
24impl_distribution_via_f32!(bf16, rand_distr::Exp1);
25impl_distribution_via_f32!(bf16, rand_distr::Open01);
26impl_distribution_via_f32!(bf16, rand_distr::OpenClosed01);
27
28#[derive(Debug, Clone, Copy)]
29pub struct Float16Sampler(UniformFloat<f32>);
30
31impl rand_distr::uniform::SampleUniform for f16 {
32 type Sampler = Float16Sampler;
33}
34
35impl rand_distr::uniform::UniformSampler for Float16Sampler {
36 type X = f16;
37 fn new<B1, B2>(low: B1, high: B2) -> Self
38 where
39 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
40 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
41 {
42 Self(UniformFloat::new(
43 low.borrow().to_f32(),
44 high.borrow().to_f32(),
45 ))
46 }
47 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
48 where
49 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
50 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
51 {
52 Self(UniformFloat::new_inclusive(
53 low.borrow().to_f32(),
54 high.borrow().to_f32(),
55 ))
56 }
57 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
58 f16::from_f32(self.0.sample(rng))
59 }
60}
61
62#[derive(Debug, Clone, Copy)]
63pub struct BFloat16Sampler(UniformFloat<f32>);
64
65impl rand_distr::uniform::SampleUniform for bf16 {
66 type Sampler = BFloat16Sampler;
67}
68
69impl rand_distr::uniform::UniformSampler for BFloat16Sampler {
70 type X = bf16;
71 fn new<B1, B2>(low: B1, high: B2) -> Self
72 where
73 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
74 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
75 {
76 Self(UniformFloat::new(
77 low.borrow().to_f32(),
78 high.borrow().to_f32(),
79 ))
80 }
81 fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
82 where
83 B1: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
84 B2: rand_distr::uniform::SampleBorrow<Self::X> + Sized,
85 {
86 Self(UniformFloat::new_inclusive(
87 low.borrow().to_f32(),
88 high.borrow().to_f32(),
89 ))
90 }
91 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
92 bf16::from_f32(self.0.sample(rng))
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): presumably allowed because these names can overlap with what
    // `super::*` already brings in, which would otherwise warn.
    #[allow(unused_imports)]
    use rand::{thread_rng, Rng};
    use rand_distr::{Standard, StandardNormal, Uniform};

    /// Smoke test: each supported distribution can produce an `f16` value.
    #[test]
    fn test_sample_f16() {
        let (lo, hi) = (f16::from_f32(0.0), f16::from_f32(1.0));
        let mut generator = thread_rng();

        let _: f16 = generator.sample(Standard);
        let _: f16 = generator.sample(StandardNormal);
        let _: f16 = generator.sample(Uniform::new(lo, hi));

        // `Normal<f16>` additionally needs the `num-traits` float traits.
        #[cfg(feature = "num-traits")]
        let _: f16 = generator.sample(rand_distr::Normal::new(lo, hi).unwrap());
    }

    /// Smoke test: each supported distribution can produce a `bf16` value.
    #[test]
    fn test_sample_bf16() {
        let (lo, hi) = (bf16::from_f32(0.0), bf16::from_f32(1.0));
        let mut generator = thread_rng();

        let _: bf16 = generator.sample(Standard);
        let _: bf16 = generator.sample(StandardNormal);
        let _: bf16 = generator.sample(Uniform::new(lo, hi));

        // `Normal<bf16>` additionally needs the `num-traits` float traits.
        #[cfg(feature = "num-traits")]
        let _: bf16 = generator.sample(rand_distr::Normal::new(lo, hi).unwrap());
    }
}