1use half::{bf16, f16};
3use rand::distributions::{Distribution, Standard};
4use rand::prelude::*;
5use rand_distr::StandardNormal;
6use std::fmt;
7
8#[derive(Debug, Clone, Copy)]
10pub struct Bf16Wrapper(pub bf16);
11
12#[derive(Debug, Clone, Copy)]
14pub struct F16Wrapper(pub f16);
15
16impl From<bf16> for Bf16Wrapper {
18 fn from(value: bf16) -> Self {
19 Bf16Wrapper(value)
20 }
21}
22
23impl From<Bf16Wrapper> for bf16 {
24 fn from(wrapper: Bf16Wrapper) -> Self {
25 wrapper.0
26 }
27}
28
29impl From<f16> for F16Wrapper {
30 fn from(value: f16) -> Self {
31 F16Wrapper(value)
32 }
33}
34
35impl From<F16Wrapper> for f16 {
36 fn from(wrapper: F16Wrapper) -> Self {
37 wrapper.0
38 }
39}
40
41impl fmt::Display for Bf16Wrapper {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 write!(f, "{}", f32::from(self.0))
45 }
46}
47
48impl fmt::Display for F16Wrapper {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 write!(f, "{}", f32::from(self.0))
51 }
52}
53
54impl Distribution<Bf16Wrapper> for Standard {
56 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Bf16Wrapper {
57 let val: f32 = rng.gen();
59 Bf16Wrapper(bf16::from_f32(val))
60 }
61}
62
63impl Distribution<F16Wrapper> for Standard {
64 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F16Wrapper {
65 let val: f32 = rng.gen();
67 F16Wrapper(f16::from_f32(val))
68 }
69}
70
71impl Distribution<Bf16Wrapper> for StandardNormal {
72 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Bf16Wrapper {
73 let x: f32 = StandardNormal.sample(rng);
74 Bf16Wrapper(bf16::from_f32(x))
75 }
76}
77
78impl Distribution<F16Wrapper> for StandardNormal {
79 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F16Wrapper {
80 let x: f32 = StandardNormal.sample(rng);
81 F16Wrapper(f16::from_f32(x))
82 }
83}
84
85pub fn rand_uniform_bf16<R: Rng + ?Sized>(rng: &mut R, min: f32, max: f32) -> bf16 {
87 let range = rand::distributions::Uniform::new(min, max);
88 let val: f32 = range.sample(rng);
89 bf16::from_f32(val)
90}
91
92pub fn rand_uniform_f16<R: Rng + ?Sized>(rng: &mut R, min: f32, max: f32) -> f16 {
93 let range = rand::distributions::Uniform::new(min, max);
94 let val: f32 = range.sample(rng);
95 f16::from_f32(val)
96}
97
98pub fn rand_normal_bf16<R: Rng + ?Sized>(rng: &mut R, mean: f32, std: f32) -> bf16 {
99 let normal = rand_distr::Normal::new(mean, std).unwrap();
100 let val: f32 = normal.sample(rng);
101 bf16::from_f32(val)
102}
103
104pub fn rand_normal_f16<R: Rng + ?Sized>(rng: &mut R, mean: f32, max: f32) -> f16 {
105 let normal = rand_distr::Normal::new(mean, max).unwrap();
106 let val: f32 = normal.sample(rng);
107 f16::from_f32(val)
108}
109
110pub trait HalfRngExt {
112 fn gen_bf16(&mut self) -> bf16;
113 fn gen_f16(&mut self) -> f16;
114 fn gen_range_bf16(&mut self, min: f32, max: f32) -> bf16;
115 fn gen_range_f16(&mut self, min: f32, max: f32) -> f16;
116}
117
118impl<R: RngCore + ?Sized> HalfRngExt for R {
119 fn gen_bf16(&mut self) -> bf16 {
120 let wrapper: Bf16Wrapper = self.gen();
121 wrapper.0
122 }
123
124 fn gen_f16(&mut self) -> f16 {
125 let wrapper: F16Wrapper = self.gen();
126 wrapper.0
127 }
128
129 fn gen_range_bf16(&mut self, min: f32, max: f32) -> bf16 {
130 rand_uniform_bf16(self, min, max)
131 }
132
133 fn gen_range_f16(&mut self, min: f32, max: f32) -> f16 {
134 rand_uniform_f16(self, min, max)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// A single `Standard` draw of `Bf16Wrapper` must fall in `[0, 1)`.
    #[test]
    fn test_bf16_standard_distribution() {
        let sample: Bf16Wrapper = rand::thread_rng().gen();
        let as_f32 = f32::from(sample.0);
        assert!((0.0..1.0).contains(&as_f32));
    }

    /// A single `Standard` draw of `F16Wrapper` must fall in `[0, 1)`.
    #[test]
    fn test_f16_standard_distribution() {
        let sample: F16Wrapper = rand::thread_rng().gen();
        let as_f32 = f32::from(sample.0);
        assert!((0.0..1.0).contains(&as_f32));
    }

    /// One hundred uniform bf16 draws all stay within the closed range [-1, 1].
    #[test]
    fn test_rand_uniform_bf16() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let as_f32 = f32::from(rand_uniform_bf16(&mut rng, -1.0, 1.0));
            assert!((-1.0..=1.0).contains(&as_f32));
        }
    }

    /// One hundred uniform f16 draws all stay within the closed range [-1, 1].
    #[test]
    fn test_rand_uniform_f16() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let as_f32 = f32::from(rand_uniform_f16(&mut rng, -1.0, 1.0));
            assert!((-1.0..=1.0).contains(&as_f32));
        }
    }

    /// The extension trait yields finite values and respects the range bounds.
    #[test]
    fn test_half_rng_ext() {
        let mut rng = rand::thread_rng();

        // Unbounded helpers: only finiteness is checked.
        let plain_bf16 = f32::from(rng.gen_bf16());
        let plain_f16 = f32::from(rng.gen_f16());
        assert!(plain_bf16.is_finite());
        assert!(plain_f16.is_finite());

        // Ranged helpers: values must lie within the closed range [-1, 1].
        let ranged_bf16 = f32::from(rng.gen_range_bf16(-1.0, 1.0));
        let ranged_f16 = f32::from(rng.gen_range_f16(-1.0, 1.0));
        assert!((-1.0..=1.0).contains(&ranged_bf16));
        assert!((-1.0..=1.0).contains(&ranged_f16));
    }
}