rand_half/
lib.rs

1// src/lib.rs
2use half::{bf16, f16};
3use rand::distributions::{Distribution, Standard};
4use rand::prelude::*;
5use rand_distr::StandardNormal;
6use std::fmt;
7
8/// Wrapper type for bf16 to allow implementing foreign traits
9#[derive(Debug, Clone, Copy)]
10pub struct Bf16Wrapper(pub bf16);
11
12/// Wrapper type for f16 to allow implementing foreign traits
13#[derive(Debug, Clone, Copy)]
14pub struct F16Wrapper(pub f16);
15
16// Implement conversion methods
17impl From<bf16> for Bf16Wrapper {
18    fn from(value: bf16) -> Self {
19        Bf16Wrapper(value)
20    }
21}
22
23impl From<Bf16Wrapper> for bf16 {
24    fn from(wrapper: Bf16Wrapper) -> Self {
25        wrapper.0
26    }
27}
28
29impl From<f16> for F16Wrapper {
30    fn from(value: f16) -> Self {
31        F16Wrapper(value)
32    }
33}
34
35impl From<F16Wrapper> for f16 {
36    fn from(wrapper: F16Wrapper) -> Self {
37        wrapper.0
38    }
39}
40
41// Implement display for convenience
42impl fmt::Display for Bf16Wrapper {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        write!(f, "{}", f32::from(self.0))
45    }
46}
47
48impl fmt::Display for F16Wrapper {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        write!(f, "{}", f32::from(self.0))
51    }
52}
53
54// Now implement Distribution for our wrapper types
55impl Distribution<Bf16Wrapper> for Standard {
56    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Bf16Wrapper {
57        // Generate an f32 and convert to bf16
58        let val: f32 = rng.gen();
59        Bf16Wrapper(bf16::from_f32(val))
60    }
61}
62
63impl Distribution<F16Wrapper> for Standard {
64    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F16Wrapper {
65        // Generate an f32 and convert to f16
66        let val: f32 = rng.gen();
67        F16Wrapper(f16::from_f32(val))
68    }
69}
70
71impl Distribution<Bf16Wrapper> for StandardNormal {
72    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Bf16Wrapper {
73        let x: f32 = StandardNormal.sample(rng);
74        Bf16Wrapper(bf16::from_f32(x))
75    }
76}
77
78impl Distribution<F16Wrapper> for StandardNormal {
79    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F16Wrapper {
80        let x: f32 = StandardNormal.sample(rng);
81        F16Wrapper(f16::from_f32(x))
82    }
83}
84
// Helper functions for sampling uniformly and normally distributed half-precision values
86pub fn rand_uniform_bf16<R: Rng + ?Sized>(rng: &mut R, min: f32, max: f32) -> bf16 {
87    let range = rand::distributions::Uniform::new(min, max);
88    let val: f32 = range.sample(rng);
89    bf16::from_f32(val)
90}
91
92pub fn rand_uniform_f16<R: Rng + ?Sized>(rng: &mut R, min: f32, max: f32) -> f16 {
93    let range = rand::distributions::Uniform::new(min, max);
94    let val: f32 = range.sample(rng);
95    f16::from_f32(val)
96}
97
98pub fn rand_normal_bf16<R: Rng + ?Sized>(rng: &mut R, mean: f32, std: f32) -> bf16 {
99    let normal = rand_distr::Normal::new(mean, std).unwrap();
100    let val: f32 = normal.sample(rng);
101    bf16::from_f32(val)
102}
103
104pub fn rand_normal_f16<R: Rng + ?Sized>(rng: &mut R, mean: f32, max: f32) -> f16 {
105    let normal = rand_distr::Normal::new(mean, max).unwrap();
106    let val: f32 = normal.sample(rng);
107    f16::from_f32(val)
108}
109
110// Extension trait for RngCore to add convenience methods
111pub trait HalfRngExt {
112    fn gen_bf16(&mut self) -> bf16;
113    fn gen_f16(&mut self) -> f16;
114    fn gen_range_bf16(&mut self, min: f32, max: f32) -> bf16;
115    fn gen_range_f16(&mut self, min: f32, max: f32) -> f16;
116}
117
118impl<R: RngCore + ?Sized> HalfRngExt for R {
119    fn gen_bf16(&mut self) -> bf16 {
120        let wrapper: Bf16Wrapper = self.gen();
121        wrapper.0
122    }
123    
124    fn gen_f16(&mut self) -> f16 {
125        let wrapper: F16Wrapper = self.gen();
126        wrapper.0
127    }
128    
129    fn gen_range_bf16(&mut self, min: f32, max: f32) -> bf16 {
130        rand_uniform_bf16(self, min, max)
131    }
132    
133    fn gen_range_f16(&mut self, min: f32, max: f32) -> f16 {
134        rand_uniform_f16(self, min, max)
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of samples drawn by the range tests.
    const SAMPLES: usize = 100;

    /// `true` when `x` lies in the half-open interval `[0, 1)`.
    fn in_unit_interval(x: f32) -> bool {
        (0.0..1.0).contains(&x)
    }

    /// `true` when `x` lies in the closed interval `[-1, 1]`.
    fn in_closed_pm_one(x: f32) -> bool {
        (-1.0..=1.0).contains(&x)
    }

    #[test]
    fn test_bf16_standard_distribution() {
        let mut rng = rand::thread_rng();
        let Bf16Wrapper(sample) = rng.gen::<Bf16Wrapper>();
        assert!(in_unit_interval(f32::from(sample)));
    }

    #[test]
    fn test_f16_standard_distribution() {
        let mut rng = rand::thread_rng();
        let F16Wrapper(sample) = rng.gen::<F16Wrapper>();
        assert!(in_unit_interval(f32::from(sample)));
    }

    #[test]
    fn test_rand_uniform_bf16() {
        let mut rng = rand::thread_rng();
        (0..SAMPLES).for_each(|_| {
            let sample = rand_uniform_bf16(&mut rng, -1.0, 1.0);
            assert!(in_closed_pm_one(f32::from(sample)));
        });
    }

    #[test]
    fn test_rand_uniform_f16() {
        let mut rng = rand::thread_rng();
        (0..SAMPLES).for_each(|_| {
            let sample = rand_uniform_f16(&mut rng, -1.0, 1.0);
            assert!(in_closed_pm_one(f32::from(sample)));
        });
    }

    #[test]
    fn test_half_rng_ext() {
        let mut rng = rand::thread_rng();

        // Standard-distribution helpers must yield finite values.
        let standard_bf16 = rng.gen_bf16();
        let standard_f16 = rng.gen_f16();
        assert!(f32::from(standard_bf16).is_finite());
        assert!(f32::from(standard_f16).is_finite());

        // Range helpers must stay within the requested bounds (inclusive).
        let ranged_bf16 = rng.gen_range_bf16(-1.0, 1.0);
        let ranged_f16 = rng.gen_range_f16(-1.0, 1.0);
        assert!(in_closed_pm_one(f32::from(ranged_bf16)));
        assert!(in_closed_pm_one(f32::from(ranged_f16)));
    }
}