/// An IEEE 754 binary16 ("half-precision") floating-point number, stored as its raw bits.
///
/// Layout: 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
///
/// Equality and hashing are *bitwise* because they are derived from the `u16`
/// storage: `+0.0` and `-0.0` compare unequal, and identical NaN bit patterns
/// compare equal. Widen with [`f16::to_f32`] for IEEE numeric comparison.
// The lowercase name intentionally mirrors the primitive `f32`/`f64` types.
#[allow(non_camel_case_types)]
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct f16(u16);

impl f16 {
    /// Positive zero (bit pattern `0x0000`).
    pub const ZERO: Self = Self(0);

    /// Creates a value from its raw binary16 bit pattern.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw binary16 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u16 {
        self.0
    }

    /// Widens to `f32`.
    ///
    /// Every finite binary16 value (including subnormals and signed zeros) and both
    /// infinities are exactly representable in `f32`, so this is lossless for non-NaN
    /// inputs. NaNs stay NaN with their sign and payload kept and the quiet bit set.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let sign = u32::from(self.0 >> 15) << 31;
        let exp = u32::from((self.0 >> 10) & 0x1F);
        let frac = u32::from(self.0 & 0x03FF);
        let magnitude = match (exp, frac) {
            (0, 0) => 0,
            (0, _) => {
                // Subnormal: value = frac * 2^-24. Shift the leading one up to bit 10
                // (the implicit-bit position) and fold that shift into the exponent.
                let shift = frac.leading_zeros() - 21;
                let normalized = (frac << shift) & 0x03FF;
                ((127 - 14 - shift) << 23) | (normalized << 13)
            }
            (0x1F, 0) => 0x7F80_0000,
            // Keep the payload and force the f32 quiet bit.
            (0x1F, _) => 0x7FC0_0000 | (frac << 13),
            // Normal: rebias the exponent from 15 to 127.
            _ => ((exp + 112) << 23) | (frac << 13),
        };
        f32::from_bits(sign | magnitude)
    }

    /// Narrows an `f32` to binary16 using round-to-nearest, ties-to-even.
    ///
    /// - Values that round past the largest finite value (65504) become signed infinity.
    /// - Values whose magnitude is at most 2^-25 (half the smallest subnormal) become
    ///   signed zero.
    /// - Values in between that are too small to be normal become correctly rounded
    ///   subnormals.
    /// - Every NaN maps to the quiet NaN `0x7E00` with the input's sign; the payload
    ///   is not kept.
    #[inline]
    pub fn from_f32(f: f32) -> Self {
        let bits = f.to_bits();
        let sign = (bits >> 31) as u16;
        let exp = ((bits >> 23) & 0xFF) as i32;
        let frac = u64::from(bits & 0x007F_FFFF);
        Self(match exp {
            // Zero or an f32 subnormal (< 2^-126): far below 2^-25, so it flushes
            // to signed zero.
            0 => sign << 15,
            0xFF => Self::non_finite_bits(sign, frac != 0),
            _ => Self::round_normal(sign, exp - 127, frac, 23),
        })
    }

    /// Narrows an `f64` to binary16 with a single round-to-nearest, ties-to-even step.
    ///
    /// Converting directly (instead of via `f32`) avoids double rounding. Going through
    /// `f32` can turn a value just above a binary16 halfway point into an exact tie,
    /// which then rounds the wrong way. Range and NaN handling match [`f16::from_f32`].
    #[inline]
    pub fn from_f64(f: f64) -> Self {
        let bits = f.to_bits();
        let sign = (bits >> 63) as u16;
        let exp = ((bits >> 52) & 0x7FF) as i32;
        let frac = bits & 0x000F_FFFF_FFFF_FFFF;
        Self(match exp {
            // Zero or an f64 subnormal (< 2^-1022): flushes to signed zero.
            0 => sign << 15,
            0x7FF => Self::non_finite_bits(sign, frac != 0),
            _ => Self::round_normal(sign, exp - 1023, frac, 52),
        })
    }

    /// Widens to `f64` (exact for every non-NaN value).
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// Returns `true` unless the value is an infinity or a NaN (all-ones exponent).
    #[inline]
    pub fn is_finite(self) -> bool {
        (self.0 & 0x7C00) != 0x7C00
    }

    /// Returns the raw bits as little-endian bytes.
    #[inline]
    pub const fn to_le_bytes(self) -> [u8; 2] {
        self.0.to_le_bytes()
    }

    /// Returns the raw bits as big-endian bytes.
    #[inline]
    pub const fn to_be_bytes(self) -> [u8; 2] {
        self.0.to_be_bytes()
    }

    /// Bit pattern for a signed infinity, or for the canonical quiet NaN
    /// (`0x7E00` plus sign) when `is_nan` is set.
    #[inline]
    fn non_finite_bits(sign: u16, is_nan: bool) -> u16 {
        let inf = (sign << 15) | 0x7C00;
        if is_nan {
            inf | 0x0200
        } else {
            inf
        }
    }

    /// Rounds the finite value `(-1)^sign * 1.frac * 2^unbiased` to the nearest
    /// binary16 bit pattern, breaking ties toward an even fraction.
    ///
    /// `frac` holds the `frac_bits` explicit fraction bits of the source format,
    /// without the implicit leading one. `frac_bits` must be in `11..=62`; the callers
    /// pass 23 for `f32` and 52 for `f64`.
    fn round_normal(sign: u16, unbiased: i32, frac: u64, frac_bits: u32) -> u16 {
        let sign = sign << 15;
        if unbiased > 15 {
            // Above the largest binary16 exponent: overflow to infinity.
            return sign | 0x7C00;
        }
        if unbiased < -25 {
            // Below 2^-25 (half the smallest subnormal): underflow to zero.
            return sign;
        }
        // `base` is the truncated result (exponent and fraction fields together).
        // `src` and `shift` describe which low bits were discarded to produce it.
        let (base, src, shift) = if unbiased >= -14 {
            // Normal result: keep the top 10 fraction bits.
            let shift = frac_bits - 10;
            let exp_field = ((unbiased + 15) as u16) << 10;
            (exp_field | (frac >> shift) as u16, frac, shift)
        } else {
            // Subnormal result: the unit is 2^-24, so the kept value is
            // (1.frac * 2^unbiased) / 2^-24 = sig >> (frac_bits - 24 - unbiased).
            let sig = frac | (1u64 << frac_bits);
            let shift = (frac_bits as i32 - 24 - unbiased) as u32;
            ((sig >> shift) as u16, sig, shift)
        };
        let rem = src & ((1u64 << shift) - 1);
        let half = 1u64 << (shift - 1);
        let round_up = rem > half || (rem == half && (base & 1) == 1);
        // A carry out of the fraction correctly bumps the exponent. The largest
        // subnormal becomes the smallest normal, and 0x7BFF rounding up becomes infinity.
        sign | (base + u16::from(round_up))
    }
}
impl From<f16> for f32 {
#[inline]
fn from(f: f16) -> f32 {
f.to_f32()
}
}
impl From<f16> for f64 {
#[inline]
fn from(f: f16) -> f64 {
f.to_f64()
}
}
/// Formats the value as its widened `f32` in decimal form. This is the same text
/// `Display` produces, e.g. `1` rather than `1.0`.
///
/// The caller's formatting options (width, fill, alignment, precision, `+`) are
/// forwarded to the `f32` formatter instead of being discarded.
impl core::fmt::Debug for f16 {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate with the caller's `Formatter`. `write!(f, "{}", ..)` would
        // re-format with default options and silently ignore e.g. `{:8.2?}`.
        core::fmt::Display::fmt(&self.to_f32(), f)
    }
}
/// Formats the value exactly as its widened `f32` would be displayed.
///
/// Formatting options such as width, fill, alignment, precision (`{:.2}`) and
/// `+` are honoured because the caller's `Formatter` is passed straight through.
impl core::fmt::Display for f16 {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Forward instead of `write!(f, "{}", ..)`, which drops the caller's options.
        core::fmt::Display::fmt(&self.to_f32(), f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero() {
        let z = f16::ZERO;
        assert_eq!(z.to_bits(), 0);
        assert_eq!(z.to_f32(), 0.0);
        assert!(z.to_f32().is_sign_positive());
        assert!(z.is_finite());
    }

    #[test]
    fn test_one() {
        // 1.0 is exactly representable, so compare exactly rather than with a tolerance.
        let one = f16::from_bits(0x3C00);
        assert_eq!(one.to_f32(), 1.0);
        assert_eq!(f16::from_f32(1.0).to_bits(), 0x3C00);
        assert!(one.is_finite());
    }

    #[test]
    fn test_negative_one() {
        let neg_one = f16::from_bits(0xBC00);
        assert_eq!(neg_one.to_f32(), -1.0);
        assert_eq!(f16::from_f32(-1.0).to_bits(), 0xBC00);
    }

    #[test]
    fn test_infinity() {
        let inf = f16::from_bits(0x7C00);
        assert_eq!(inf.to_f32(), f32::INFINITY);
        assert!(!inf.is_finite());
        let neg_inf = f16::from_bits(0xFC00);
        assert_eq!(neg_inf.to_f32(), f32::NEG_INFINITY);
        assert!(!neg_inf.is_finite());
    }

    #[test]
    fn test_nan() {
        let nan = f16::from_bits(0x7C01);
        assert!(nan.to_f32().is_nan());
        assert!(!nan.is_finite());
        // The sign bit survives widening.
        assert!(f16::from_bits(0xFE00).to_f32().is_sign_negative());
    }

    #[test]
    fn test_denormal() {
        let tiny = f16::from_bits(0x0001);
        // The smallest positive subnormal is exactly 2^-24.
        assert_eq!(tiny.to_f32(), 1.0 / 16_777_216.0);
        assert!(tiny.is_finite());
        // The largest subnormal is exactly 1023 * 2^-24.
        assert_eq!(f16::from_bits(0x03FF).to_f32(), 1023.0 / 16_777_216.0);
    }

    #[test]
    fn test_roundtrip_normal() {
        let test_values: [f32; 8] = [0.5, 1.0, 2.0, 100.0, 0.001, -0.5, -1.0, -100.0];
        for &v in &test_values {
            let h = f16::from_f32(v);
            let back = h.to_f32();
            let rel_err = ((v - back) / v).abs();
            assert!(
                rel_err < 0.002,
                "Roundtrip failed for {}: got {}, rel_err {}",
                v,
                back,
                rel_err
            );
        }
    }

    #[test]
    fn test_exhaustive_normal_roundtrip() {
        // Every normal f16 (biased exponent 1..=30, either sign) is exact in f32,
        // so narrowing its widened value must reproduce the original bits.
        let normals = (0u16..=u16::MAX).filter(|&b| {
            let exp = (b >> 10) & 0x1F;
            exp != 0 && exp != 0x1F
        });
        for bits in normals {
            let back = f16::from_f32(f16::from_bits(bits).to_f32()).to_bits();
            assert_eq!(back, bits, "bits {:#06x}", bits);
        }
    }

    #[test]
    fn test_roundtrip_special() {
        assert_eq!(f16::from_f32(0.0).to_f32(), 0.0);
        assert_eq!(f16::from_f32(0.0).to_bits(), 0x0000);
        assert_eq!(f16::from_f32(-0.0).to_bits(), 0x8000);
        assert_eq!(f16::from_f32(f32::INFINITY).to_f32(), f32::INFINITY);
        assert_eq!(f16::from_f32(f32::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
        assert!(f16::from_f32(f32::NAN).to_f32().is_nan());
    }

    #[test]
    fn test_overflow_to_infinity() {
        assert_eq!(f16::from_f32(100000.0).to_f32(), f32::INFINITY);
        assert_eq!(f16::from_f32(-100000.0).to_f32(), f32::NEG_INFINITY);
        // 65504 is the largest finite f16; 65520 is the halfway point to the next
        // exponent and ties to even, which is infinity.
        assert_eq!(f16::from_f32(65504.0).to_bits(), 0x7BFF);
        assert_eq!(f16::from_f32(65520.0).to_bits(), 0x7C00);
    }

    #[test]
    fn test_underflow_to_zero() {
        let tiny = f16::from_f32(1e-10);
        assert_eq!(tiny.to_f32(), 0.0);
        assert_eq!(f16::from_f32(-1e-10).to_bits(), 0x8000);
    }

    #[test]
    fn test_bytes() {
        let h = f16::from_bits(0x1234);
        assert_eq!(h.to_le_bytes(), [0x34, 0x12]);
        assert_eq!(h.to_be_bytes(), [0x12, 0x34]);
    }
}