a_half/
lib.rs

1//! A crate that implements a half-precision floating-point type `f16` and its associated methods.
2
3use std::cmp::Ordering;
4use std::ops::{Add, Sub, Mul, Div};
5use std::fmt;
6
// Half <-> single precision conversion routines implemented in C and linked
// into the final binary. The Rust compiler cannot check these declarations,
// so they must match the C signatures exactly.
// NOTE(review): no `#[link]` attribute is visible here, so linking is
// presumably configured elsewhere (e.g. a build script) — confirm. The `hs_`
// prefix suggests the C helpers of Haskell's `half` package — confirm.
extern "C" {
    /// Narrows an `f32` to a half-precision bit pattern
    /// (used by `From<f32> for f16`).
    fn hs_floatToHalf(f: f32) -> u16;
    /// Widens a half-precision bit pattern to an `f32`
    /// (used by `From<f16> for f32`).
    fn hs_halfToFloat(c: u16) -> f32;
}
11
/// Half-precision floating-point number stored in the IEEE 754 binary16 layout.
///
/// Only the raw 16-bit pattern is kept: bit 15 is the sign, bits 10..=14 hold
/// the biased exponent and bits 0..=9 the significand. [`f16::from_bits`] and
/// [`f16::to_bits`] give bit-level access; the classification predicates below
/// work on the bits directly.
#[derive(Debug, Copy, Clone)]
#[allow(non_camel_case_types)] // deliberately named like the primitive `f32` / `f64`
pub struct f16 {
    bit_repr: u16,
}

/// Width of the exponent field, in bits.
const EXP_BW: u16 = 5;
/// Width of the significand field, in bits.
const SIG_BW: u16 = 10;
/// Mask for the significand field (low 10 bits).
const SIG_MASK: u16 = 0x03FF;
/// Mask for the exponent field once shifted down by `SIG_BW` (5 bits).
const EXP_MASK: u16 = 0x001F;
/// Bit pattern of `+inf`: all-ones exponent, zero significand, sign clear.
const POS_INF_BR: u16 = 0x7C00;
/// Bit pattern of `-inf`: the `+inf` pattern with the sign bit set.
const NEG_INF_BR: u16 = 0xFC00;
/// Bit pattern of `+0.0`.
const POS_ZERO_BR: u16 = 0x0000;
/// Bit pattern of `-0.0` (only the sign bit set).
const NEG_ZERO_BR: u16 = 0x8000;

impl f16 {
    // ---- bit conversions ----

    /// Builds an `f16` from its raw bit pattern. No validation is performed,
    /// so every `u16` (including any NaN payload) is accepted.
    #[inline]
    pub fn from_bits(x: u16) -> Self {
        Self { bit_repr: x }
    }

    /// Returns the raw bit pattern of this `f16`.
    #[inline]
    pub fn to_bits(self) -> u16 {
        self.bit_repr
    }

    /// Biased exponent field, in `0..=EXP_MASK`.
    fn get_exp_bits(self) -> u16 {
        (self.to_bits() >> SIG_BW) & EXP_MASK
    }

    /// Significand field, without the implicit leading bit.
    fn get_sig_bits(self) -> u16 {
        self.to_bits() & SIG_MASK
    }

    /// Sign bit: `0` when the sign is positive, `1` when negative.
    fn get_sign_bit(self) -> u16 {
        self.to_bits() >> (SIG_BW + EXP_BW)
    }

    // ---- predicates ----

    /// Returns `true` if `self` is neither infinite nor NaN.
    #[inline]
    pub fn is_finite(self) -> bool {
        self.get_exp_bits() < EXP_MASK
    }

    /// Returns `true` if `self` is `+inf` or `-inf`.
    #[inline]
    pub fn is_infinite(self) -> bool {
        let b = self.to_bits();
        b == POS_INF_BR || b == NEG_INF_BR
    }

    /// Returns `true` if `self` is NaN (all-ones exponent, non-zero significand).
    ///
    /// Decided from the bit pattern alone, so it depends neither on the `f32`
    /// conversion routine nor on the `PartialEq` implementation.
    #[inline]
    pub fn is_nan(self) -> bool {
        self.get_exp_bits() == EXP_MASK && self.get_sig_bits() != 0
    }

    /// Returns `true` if `self` is neither zero, subnormal, infinite nor NaN.
    #[inline]
    pub fn is_normal(self) -> bool {
        let exp = self.get_exp_bits();
        exp > 0 && exp < EXP_MASK
    }

    /// Returns `true` if the sign bit is clear (also for `+0.0` and positive NaNs).
    #[inline]
    pub fn is_sign_positive(self) -> bool {
        self.get_sign_bit() == 0
    }

    /// Returns `true` if the sign bit is set (also for `-0.0` and negative NaNs).
    #[inline]
    pub fn is_sign_negative(self) -> bool {
        !self.is_sign_positive()
    }

    /// Returns `true` if `self` is subnormal: zero exponent, non-zero significand.
    ///
    /// Public, mirroring `f32::is_subnormal`; a private version would be unused
    /// outside the tests and trip `dead_code` in normal builds.
    #[inline]
    pub fn is_subnormal(self) -> bool {
        self.get_exp_bits() == 0 && !self.is_zero()
    }

    /// Returns `true` for `+0.0` and `-0.0`.
    fn is_zero(self) -> bool {
        let b = self.to_bits();
        b == POS_ZERO_BR || b == NEG_ZERO_BR
    }
}
104
105// conversions from/to f32
106impl From<f16> for f32 {
107    fn from(x: f16) -> Self {
108        unsafe {
109            hs_halfToFloat(f16::to_bits(x))
110        }
111    }
112}
113
114impl From<f32> for f16 {
115    fn from(x: f32) -> Self {
116        unsafe {
117            f16::from_bits(hs_floatToHalf(x))
118        }
119    }
120}
121
122
/// Expands to a comparison method `$op_name(&self, &Self) -> $ret_type` that
/// widens both operands to `f32` (left operand first) and delegates to the
/// `f32` method of the same name.
macro_rules! bin_op {
    ($op_name:ident, $ret_type:ty) => {
        fn $op_name(&self, other: &Self) -> $ret_type {
            let (a, b) = (f32::from(*self), f32::from(*other));
            a.$op_name(&b)
        }
    };
}
132
/// Implements the binary operator trait `$op_trait` (method `$op_name`) for
/// `f16`: both operands are widened to `f32`, the `f32` operation is applied,
/// and the result is narrowed back to `f16`.
macro_rules! bin_arith {
    ($op_trait:ident, $op_name:ident) => {
        impl $op_trait for f16 {
            type Output = Self;

            fn $op_name(self, other: Self) -> Self {
                let wide = f32::from(self).$op_name(f32::from(other));
                Self::from(wide)
            }
        }
    };
}
146
147// partial equality
148impl PartialEq for f16 {
149    bin_op!(eq, bool);
150    bin_op!(ne, bool);
151}
152
153// partial order
154impl PartialOrd for f16 {
155    bin_op!(partial_cmp, Option<Ordering>);
156}
157
158// arithmetic
159bin_arith!(Add, add);
160bin_arith!(Sub, sub);
161bin_arith!(Mul, mul);
162bin_arith!(Div, div);
163
164// printer
165impl fmt::Display for f16 {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        f32::from(*self).fmt(f)
168    }
169}
170
171
#[cfg(test)]
mod tests {
    use crate::f16;

    #[test]
    fn size_check() {
        use std::mem;
        // `f16` must be exactly as large as its `u16` bit pattern.
        assert_eq!(mem::size_of::<f16>(), 2);
    }

    #[test]
    fn identity_check() {
        let x = 1;
        let xs = f16 { bit_repr: x };
        assert_eq!(x, xs.bit_repr);
    }

    #[test]
    fn bits_and_half() {
        let x = 1;
        assert_eq!(x, f16::from_bits(x).to_bits());
    }

    #[test]
    fn float_to_half() {
        let x0: f32 = 0.0;
        let x1: f32 = 1.0;
        let x2: f32 = 2.0;
        assert_eq!(0, f16::to_bits(f16::from(x0)));
        assert_eq!(15360, f16::from(x1).to_bits());
        assert_eq!(16384, f16::from(x2).to_bits());
    }

    #[test]
    fn half_to_float() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(15360);
        let xmin: f16 = f16::from_bits(1);
        assert_eq!(0.0, f32::from(x0));
        assert_eq!(1.0, f32::from(x1));
        // Smallest positive subnormal half is 2^-24.
        assert_eq!(5.9604645e-8, f32::from(xmin));
    }

    #[test]
    fn partial_eq_check() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert!(x0 != x1);
    }

    #[test]
    fn partial_ord_check() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert!(x0 < x1);
    }

    #[test]
    fn arith_ops() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert_eq!(x1, x0 + x1);
        assert_eq!(x0, x0 * x1);
    }

    #[test]
    fn predicates() {
        // Named constants instead of `0.0 / 0.0` etc.: clearer, pinned to
        // `f32`, and clean under clippy `zero_divided_by_zero`.
        let nan: f16 = f16::from(f32::NAN);
        let pos_inf: f16 = f16::from(f32::INFINITY);
        let neg_inf: f16 = f16::from(f32::NEG_INFINITY);
        let xmin: f16 = f16::from_bits(1);
        assert!(nan.is_nan());
        assert!(pos_inf.is_infinite());
        assert!(neg_inf.is_infinite());
        assert!(pos_inf.is_sign_positive());
        assert!(neg_inf.is_sign_negative());
        assert!(!xmin.is_normal());
        assert!(xmin.is_subnormal());
    }

    #[test]
    fn signed_zero_and_nan_comparisons() {
        let pos_zero = f16::from_bits(0x0000);
        let neg_zero = f16::from_bits(0x8000);
        let nan = f16::from_bits(0x7E00);
        // IEEE 754: the two zeros compare equal, NaN is unordered.
        assert_eq!(pos_zero, neg_zero);
        assert!(neg_zero.is_sign_negative());
        assert!(!neg_zero.is_subnormal());
        assert!(nan.is_nan());
        assert!(nan != nan);
        assert_eq!(nan.partial_cmp(&nan), None);
    }
}