1use std::cmp::Ordering;
4use std::ops::{Add, Sub, Mul, Div};
5use std::fmt;
6
// Half <-> single precision conversion routines supplied by an external C library.
// NOTE(review): the providing library/link directive is not visible in this file — confirm
// it is configured in the build.
extern "C" {
    /// Converts the binary16 bit pattern `bits` to an `f32`.
    fn hs_halfToFloat(bits: u16) -> f32;

    /// Converts `value` to a binary16 bit pattern.
    fn hs_floatToHalf(value: f32) -> u16;
}
11
/// An IEEE 754 binary16 ("half precision") number, stored as its raw bit pattern.
///
/// Conversions, comparisons and arithmetic go through `f32`.
// Lowercase on purpose: the name mirrors Rust's primitive float types (`f32`, `f64`).
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug)]
pub struct f16 {
    /// Raw binary16 encoding: 1 sign bit, 5 exponent bits, 10 significand bits.
    bit_repr: u16,
}
18
/// Width of the binary16 exponent field, in bits.
const EXP_BW: u16 = 5;
/// Width of the binary16 significand (fraction) field, in bits.
const SIG_BW: u16 = 10;
/// Mask selecting the significand bits (`0x03FF`).
const SIG_MASK: u16 = (1 << SIG_BW) - 1;
/// Mask for the exponent field once shifted down by `SIG_BW` (`0x001F`).
const EXP_MASK: u16 = (1 << EXP_BW) - 1;
/// Bit pattern of positive zero.
const POS_ZERO_BR: u16 = 0x0000;
/// Bit pattern of negative zero: only the sign bit set (`0x8000`).
const NEG_ZERO_BR: u16 = 1 << (EXP_BW + SIG_BW);
/// Bit pattern of positive infinity: all-ones exponent, zero significand (`0x7C00`).
const POS_INF_BR: u16 = EXP_MASK << SIG_BW;
/// Bit pattern of negative infinity: positive infinity with the sign bit set (`0xFC00`).
const NEG_INF_BR: u16 = NEG_ZERO_BR | POS_INF_BR;
27
28impl f16 {
29 #[inline]
32 pub fn from_bits(x: u16) -> Self {
33 Self {
34 bit_repr: x,
35 }
36 }
37
38 #[inline]
40 pub fn to_bits(self) -> u16 {
41 self.bit_repr
42 }
43
44 fn get_exp_bits(self) -> u16 {
45 (self.to_bits() >> SIG_BW) & EXP_MASK
46 }
47
48 fn get_sig_bits(self) -> u16 {
49 self.to_bits() & SIG_MASK
50 }
51
52 fn get_sign_bit(self) -> u16 {
53 self.to_bits() >> (SIG_BW + EXP_BW)
54 }
55
56 #[inline]
59 pub fn is_finite(self) -> bool {
60 self.get_exp_bits() < EXP_MASK
61 }
62
63 #[inline]
65 pub fn is_infinite(self) -> bool {
66 let b = self.to_bits();
67 b == POS_INF_BR || b == NEG_INF_BR
68 }
69
70 #[inline]
72 pub fn is_nan(self) -> bool {
73 self != self
74 }
75
76 #[inline]
78 pub fn is_normal(self) -> bool {
79 let exp = self.get_exp_bits();
80 exp > 0 && exp < EXP_MASK
81 }
82
83 #[inline]
85 pub fn is_sign_positive(self) -> bool {
86 self.get_sign_bit() == 0
87 }
88
89 #[inline]
91 pub fn is_sign_negative(self) -> bool {
92 !self.is_sign_positive()
93 }
94
95 fn is_subnormal(self) -> bool {
96 self.get_exp_bits() == 0 && !self.is_zero()
97 }
98
99 fn is_zero(self) -> bool {
100 let b = self.to_bits();
101 b == POS_ZERO_BR || b == NEG_ZERO_BR
102 }
103}
104
105impl From<f16> for f32 {
107 fn from(x: f16) -> Self {
108 unsafe {
109 hs_halfToFloat(f16::to_bits(x))
110 }
111 }
112}
113
114impl From<f32> for f16 {
115 fn from(x: f32) -> Self {
116 unsafe {
117 f16::from_bits(hs_floatToHalf(x))
118 }
119 }
120}
121
122
/// Expands to a comparison method `$op_name(&self, &Self) -> $ret_type`.
///
/// The method widens both operands to `f32` and defers to `f32`'s method of the same name.
/// Meant to be used inside `impl PartialEq` / `impl PartialOrd` blocks for a `Copy` type
/// that converts into `f32`.
macro_rules! bin_op {
    ($op_name:ident, $ret_type:ty) => {
        fn $op_name(&self, other: &Self) -> $ret_type {
            f32::from(*self).$op_name(&f32::from(*other))
        }
    };
}
132
/// Implements the binary operator trait `$op_trait` (method `$op_name`) for `f16`.
///
/// Both operands are widened to `f32`, combined with `f32`'s operator, and the `f32`
/// result is converted back to `f16`.
macro_rules! bin_arith {
    ($op_trait:ident, $op_name:ident) => {
        impl $op_trait for f16 {
            type Output = f16;

            fn $op_name(self, other: Self) -> Self::Output {
                let wide = f32::from(self).$op_name(f32::from(other));
                f16::from(wide)
            }
        }
    };
}
146
147impl PartialEq for f16 {
149 bin_op!(eq, bool);
150 bin_op!(ne, bool);
151}
152
153impl PartialOrd for f16 {
155 bin_op!(partial_cmp, Option<Ordering>);
156}
157
158bin_arith!(Add, add);
160bin_arith!(Sub, sub);
161bin_arith!(Mul, mul);
162bin_arith!(Div, div);
163
164impl fmt::Display for f16 {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 f32::from(*self).fmt(f)
168 }
169}
170
171
#[cfg(test)]
mod tests {
    // `super` rather than `crate`, so these tests still resolve if this file is compiled
    // as a submodule instead of the crate root.
    use super::f16;

    #[test]
    fn size_check() {
        use std::mem;
        // The wrapper must add no overhead over the raw `u16` encoding.
        assert_eq!(mem::size_of::<f16>(), 2);
    }

    #[test]
    fn identity_check() {
        let x = 1;
        let xs = f16 { bit_repr: x };
        assert_eq!(x, xs.bit_repr);
    }

    #[test]
    fn bits_and_half() {
        let x = 1;
        assert_eq!(x, f16::from_bits(x).to_bits());
    }

    #[test]
    fn float_to_half() {
        let x0: f32 = 0.0;
        let x1: f32 = 1.0;
        let x2: f32 = 2.0;
        assert_eq!(0, f16::from(x0).to_bits());
        assert_eq!(15360, f16::from(x1).to_bits());
        assert_eq!(16384, f16::from(x2).to_bits());
    }

    #[test]
    fn half_to_float() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(15360);
        let xmin: f16 = f16::from_bits(1);
        assert_eq!(0.0, f32::from(x0));
        assert_eq!(1.0, f32::from(x1));
        // Smallest positive subnormal half: 2^-24.
        assert_eq!(5.9604645e-8, f32::from(xmin));
    }

    #[test]
    fn partial_eq_check() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert!(x0 != x1);
    }

    #[test]
    fn signed_zeros_compare_equal() {
        let pos_zero = f16::from_bits(0x0000);
        let neg_zero = f16::from_bits(0x8000);
        assert_eq!(pos_zero, neg_zero);
        assert!(pos_zero.is_sign_positive());
        assert!(neg_zero.is_sign_negative());
    }

    #[test]
    fn nan_is_unequal_and_unordered() {
        let nan = f16::from_bits(0x7E00);
        let other_nan = f16::from_bits(0x7E00);
        assert!(nan != other_nan);
        assert_eq!(nan.partial_cmp(&other_nan), None);
    }

    #[test]
    fn partial_ord_check() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert!(x0 < x1);
    }

    #[test]
    fn arith_ops() {
        let x0: f16 = f16::from_bits(0);
        let x1: f16 = f16::from_bits(1);
        assert_eq!(x1, x0 + x1);
        assert_eq!(x0, x0 * x1);
    }

    #[test]
    fn display_matches_f32() {
        assert_eq!(f16::from(1.5f32).to_string(), "1.5");
    }

    #[test]
    fn predicates() {
        let nan: f16 = f16::from(f32::NAN);
        let pos_inf: f16 = f16::from(f32::INFINITY);
        let neg_inf: f16 = f16::from(f32::NEG_INFINITY);
        let xmin: f16 = f16::from_bits(1);
        assert!(nan.is_nan());
        assert!(!nan.is_finite());
        assert!(pos_inf.is_infinite());
        assert!(neg_inf.is_infinite());
        assert!(!pos_inf.is_finite());
        assert!(pos_inf.is_sign_positive());
        assert!(neg_inf.is_sign_negative());
        assert!(xmin.is_finite());
        assert!(!xmin.is_normal());
        assert!(xmin.is_subnormal());
    }
}