use std::cmp::Ordering;
use std::fmt;
/// IEEE 754 binary16 ("half precision") value, stored as its raw 16-bit pattern.
///
/// Bit layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.
///
/// The derived `PartialEq`, `Eq` and `Hash` operate on the raw bits, so
/// `F16::ZERO != F16::NEG_ZERO` and a NaN compares equal to a NaN with the
/// same bit pattern.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F16(u16);

impl F16 {
    /// Positive zero.
    pub const ZERO: F16 = F16(0x0000);
    /// Negative zero.
    pub const NEG_ZERO: F16 = F16(0x8000);
    /// The value `1.0`.
    pub const ONE: F16 = F16(0x3C00);
    /// The value `-1.0`.
    pub const NEG_ONE: F16 = F16(0xBC00);
    /// Largest finite value (`65504.0`).
    pub const MAX: F16 = F16(0x7BFF);
    /// Smallest positive normal value (`2^-14`).
    pub const MIN_POSITIVE: F16 = F16(0x0400);
    /// Difference between `1.0` and the next larger representable value (`2^-10`).
    pub const EPSILON: F16 = F16(0x1400);
    /// A quiet NaN.
    pub const NAN: F16 = F16(0x7E00);
    /// Positive infinity.
    pub const INFINITY: F16 = F16(0x7C00);
    /// Negative infinity.
    pub const NEG_INFINITY: F16 = F16(0xFC00);

    /// Reinterprets a raw 16-bit pattern as an `F16`.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        F16(bits)
    }

    /// Returns the raw 16-bit pattern.
    #[inline]
    pub const fn bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to half precision, rounding to nearest with ties to even.
    ///
    /// * NaN stays NaN; the sign and the top 10 payload bits are kept, and the
    ///   payload is forced non-zero so the result cannot turn into infinity.
    /// * Magnitudes too large for a finite half become infinity of the same sign.
    /// * Magnitudes of `2^-25` or less (which includes every `f32` subnormal)
    ///   become zero of the same sign.
    /// * Everything in between is rounded to the nearest normal or subnormal half.
    #[inline]
    pub fn from_f32(value: f32) -> Self {
        let bits = value.to_bits();
        // Move the f32 sign bit (31) straight into the half sign position (15).
        let h_sign = ((bits >> 16) & 0x8000) as u16;
        let exp = ((bits >> 23) & 0xFF) as i32;
        let mantissa = bits & 0x007F_FFFF;

        if exp == 0xFF {
            if mantissa == 0 {
                return F16(h_sign | 0x7C00);
            }
            let payload = (mantissa >> 13) as u16;
            return F16(h_sign | 0x7C00 | if payload == 0 { 1 } else { payload });
        }
        if exp == 0 {
            // Zero, or an f32 subnormal (< 2^-126), far below half's range.
            return F16(h_sign);
        }

        // Re-bias the exponent from f32 (127) to half (15).
        let new_exp = exp - 127 + 15;
        if new_exp > 30 {
            return F16(h_sign | 0x7C00);
        }
        if new_exp < -10 {
            // |value| < 2^-25: less than half of the smallest subnormal.
            return F16(h_sign);
        }

        if new_exp < 1 {
            // Subnormal result: make the implicit leading 1 explicit, then shift
            // it into place. `total_shift` lies in 14..=24. At 24 nothing is
            // kept, but the leading 1 becomes the round bit, so anything above
            // 2^-25 still rounds up to the smallest subnormal.
            let full_mantissa = mantissa | 0x0080_0000;
            let total_shift = (14 - new_exp) as u32;
            let kept = full_mantissa >> total_shift;
            let round_bit = (full_mantissa >> (total_shift - 1)) & 1 != 0;
            let sticky = full_mantissa & ((1u32 << (total_shift - 1)) - 1) != 0;
            let round_up = round_bit && (sticky || kept & 1 != 0);
            // A carry out of the mantissa yields 0x0400, the smallest normal.
            return F16(h_sign | (kept + u32::from(round_up)) as u16);
        }

        let h_exp = (new_exp as u16) << 10;
        let kept = (mantissa >> 13) as u16;
        let round_bit = mantissa & 0x1000 != 0;
        let sticky = mantissa & 0x0FFF != 0;
        let round_up = round_bit && (sticky || kept & 1 != 0);
        // Adding the carry to the packed exponent|mantissa lets a mantissa
        // overflow bump the exponent; overflowing the largest exponent lands
        // exactly on 0x7C00 (infinity) and never reaches the sign bit.
        F16(h_sign | ((h_exp | kept) + u16::from(round_up)))
    }

    /// Converts to `f32`. Every half value is exactly representable, so this is lossless.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let sign = u32::from(self.0 & 0x8000) << 16;
        let exp = u32::from((self.0 >> 10) & 0x1F);
        let mantissa = u32::from(self.0 & 0x03FF);

        let magnitude = match (exp, mantissa) {
            (0, 0) => 0,
            (0, m) => {
                // Subnormal: shift the highest set bit up to bit 10 (the
                // implicit-one position), then drop it.
                let shift = m.leading_zeros() - 21;
                let normalized = (m << shift) & 0x03FF;
                ((113 - shift) << 23) | (normalized << 13)
            }
            // Infinity (m == 0) or NaN with its payload moved to the top bits.
            (0x1F, m) => 0x7F80_0000 | (m << 13),
            // Normal: re-bias the exponent by 127 - 15 = 112.
            (e, m) => ((e + 112) << 23) | (m << 13),
        };
        f32::from_bits(sign | magnitude)
    }

    /// Converts an `f64` to half precision, rounding to nearest with ties to even.
    ///
    /// The value is first narrowed to `f32` with round-to-odd, so that the
    /// final rounding in [`F16::from_f32`] sees whether the input was above or
    /// below a tie and the result is rounded only once.
    #[inline]
    pub fn from_f64(value: f64) -> Self {
        F16::from_f32(Self::f64_to_f32_round_to_odd(value))
    }

    /// Narrows `value` to `f32`; an inexact result is the neighbour whose bit pattern is odd.
    fn f64_to_f32_round_to_odd(value: f64) -> f32 {
        let nearest = value as f32;
        // NaN, infinities, overflow to infinity and exact conversions are kept.
        if !nearest.is_finite() || f64::from(nearest) == value {
            return nearest;
        }
        let bits = nearest.to_bits();
        if bits & 1 == 1 {
            return nearest;
        }
        // `nearest` is the even neighbour, so the odd one lies on the other side
        // of `value`. Bit patterns grow with magnitude, so +1 / -1 moves one
        // step away from / toward zero.
        let odd = if f64::from(nearest).abs() < value.abs() {
            bits + 1
        } else {
            bits - 1
        };
        f32::from_bits(odd)
    }

    /// Converts to `f64` by widening the `f32` conversion.
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// Returns `true` for any NaN (all exponent bits set, non-zero mantissa).
    #[inline]
    pub fn is_nan(self) -> bool {
        self.0 & 0x7FFF > 0x7C00
    }

    /// Returns `true` for positive or negative infinity.
    #[inline]
    pub fn is_infinite(self) -> bool {
        self.0 & 0x7FFF == 0x7C00
    }

    /// Returns `true` for non-zero values with an all-zero exponent field.
    #[inline]
    pub fn is_subnormal(self) -> bool {
        let magnitude = self.0 & 0x7FFF;
        magnitude != 0 && magnitude < 0x0400
    }

    /// Returns `true` for positive or negative zero.
    #[inline]
    pub fn is_zero(self) -> bool {
        self.0 & 0x7FFF == 0
    }

    /// Returns `true` when the value is neither infinite nor NaN.
    #[inline]
    pub fn is_finite(self) -> bool {
        self.0 & 0x7C00 != 0x7C00
    }

    /// Clears the sign bit.
    #[inline]
    pub fn abs(self) -> Self {
        F16(self.0 & 0x7FFF)
    }

    /// Flips the sign bit (this also applies to zeros, infinities and NaNs).
    // Kept as an inherent method, so the clippy hint to implement
    // `std::ops::Neg` instead is silenced.
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        F16(self.0 ^ 0x8000)
    }
}
/// Formats the value exactly as its `f32` conversion would be formatted.
impl fmt::Display for F16 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate with the caller's formatter so width, precision, sign and
        // alignment flags (e.g. `{:.2}` or `{:>8}`) are honoured; going through
        // `write!(f, "{}", ..)` would silently discard them.
        fmt::Display::fmt(&self.to_f32(), f)
    }
}
impl From<f32> for F16 {
#[inline]
fn from(value: f32) -> Self {
F16::from_f32(value)
}
}
impl From<f64> for F16 {
#[inline]
fn from(value: f64) -> Self {
F16::from_f64(value)
}
}
impl From<F16> for f32 {
#[inline]
fn from(value: F16) -> Self {
value.to_f32()
}
}
impl From<F16> for f64 {
#[inline]
fn from(value: F16) -> Self {
value.to_f64()
}
}
/// Numeric ordering, obtained by comparing the `f32` conversions.
///
/// NOTE(review): the derived `PartialEq` compares raw bits while this compares
/// numeric values, so the two disagree for `+0.0` vs `-0.0` (ordered as equal,
/// but `!=`) and for identical NaN patterns (unordered, but `==`). Confirm that
/// callers do not rely on the two being consistent.
impl PartialOrd for F16 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs = self.to_f32();
        let rhs = other.to_f32();
        lhs.partial_cmp(&rhs)
    }
}
impl Default for F16 {
fn default() -> Self {
F16::ZERO
}
}
/// bfloat16 value, stored as its raw 16-bit pattern.
///
/// Bit layout: 1 sign bit, 8 exponent bits, 7 mantissa bits, which is the upper
/// half of an `f32` bit pattern.
///
/// The derived `PartialEq`, `Eq` and `Hash` operate on the raw bits, so
/// `BF16::ZERO != BF16::NEG_ZERO` and a NaN compares equal to a NaN with the
/// same bit pattern.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct BF16(u16);

impl BF16 {
    /// Positive zero.
    pub const ZERO: BF16 = BF16(0x0000);
    /// Negative zero.
    pub const NEG_ZERO: BF16 = BF16(0x8000);
    /// The value `1.0`.
    pub const ONE: BF16 = BF16(0x3F80);
    /// The value `-1.0`.
    pub const NEG_ONE: BF16 = BF16(0xBF80);
    /// Largest finite value.
    pub const MAX: BF16 = BF16(0x7F7F);
    /// Smallest positive normal value (`2^-126`).
    pub const MIN_POSITIVE: BF16 = BF16(0x0080);
    /// Difference between `1.0` and the next larger representable value (`2^-7`).
    pub const EPSILON: BF16 = BF16(0x3C00);
    /// A quiet NaN.
    pub const NAN: BF16 = BF16(0x7FC0);
    /// Positive infinity.
    pub const INFINITY: BF16 = BF16(0x7F80);
    /// Negative infinity.
    pub const NEG_INFINITY: BF16 = BF16(0xFF80);

    /// Reinterprets a raw 16-bit pattern as a `BF16`.
    #[inline]
    pub const fn from_bits(bits: u16) -> Self {
        BF16(bits)
    }

    /// Returns the raw 16-bit pattern.
    #[inline]
    pub const fn bits(self) -> u16 {
        self.0
    }

    /// Converts an `f32` to bfloat16, rounding to nearest with ties to even.
    ///
    /// Every NaN input maps to [`BF16::NAN`]; its sign and payload are not kept.
    #[inline]
    pub fn from_f32(value: f32) -> Self {
        if value.is_nan() {
            return BF16::NAN;
        }
        let bits = value.to_bits();
        // Adding 0x7FFF plus the lowest kept bit carries into the upper half
        // exactly when the discarded half is above a tie, or is a tie and the
        // kept value is odd.
        let rounding_bias = 0x7FFF + ((bits >> 16) & 1);
        BF16((bits.wrapping_add(rounding_bias) >> 16) as u16)
    }

    /// Converts to `f32` by placing the 16 bits in the upper half of the `f32` pattern.
    #[inline]
    pub fn to_f32(self) -> f32 {
        f32::from_bits(u32::from(self.0) << 16)
    }

    /// Converts an `f64` to bfloat16, rounding to nearest with ties to even.
    ///
    /// The value is first narrowed to `f32` with round-to-odd, so that the
    /// final rounding in [`BF16::from_f32`] sees whether the input was above or
    /// below a tie and the result is rounded only once.
    #[inline]
    pub fn from_f64(value: f64) -> Self {
        BF16::from_f32(Self::f64_to_f32_round_to_odd(value))
    }

    /// Narrows `value` to `f32`; an inexact result is the neighbour whose bit pattern is odd.
    fn f64_to_f32_round_to_odd(value: f64) -> f32 {
        let nearest = value as f32;
        // NaN, infinities, overflow to infinity and exact conversions are kept.
        if !nearest.is_finite() || f64::from(nearest) == value {
            return nearest;
        }
        let bits = nearest.to_bits();
        if bits & 1 == 1 {
            return nearest;
        }
        // `nearest` is the even neighbour, so the odd one lies on the other side
        // of `value`. Bit patterns grow with magnitude, so +1 / -1 moves one
        // step away from / toward zero.
        let odd = if f64::from(nearest).abs() < value.abs() {
            bits + 1
        } else {
            bits - 1
        };
        f32::from_bits(odd)
    }

    /// Converts to `f64` by widening the `f32` conversion.
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// Returns `true` for any NaN (all exponent bits set, non-zero mantissa).
    #[inline]
    pub fn is_nan(self) -> bool {
        self.0 & 0x7FFF > 0x7F80
    }

    /// Returns `true` for positive or negative infinity.
    #[inline]
    pub fn is_infinite(self) -> bool {
        self.0 & 0x7FFF == 0x7F80
    }

    /// Returns `true` for non-zero values with an all-zero exponent field.
    #[inline]
    pub fn is_subnormal(self) -> bool {
        let magnitude = self.0 & 0x7FFF;
        magnitude != 0 && magnitude < 0x0080
    }

    /// Returns `true` for positive or negative zero.
    #[inline]
    pub fn is_zero(self) -> bool {
        self.0 & 0x7FFF == 0
    }

    /// Returns `true` when the value is neither infinite nor NaN.
    #[inline]
    pub fn is_finite(self) -> bool {
        self.0 & 0x7F80 != 0x7F80
    }

    /// Clears the sign bit.
    #[inline]
    pub fn abs(self) -> Self {
        BF16(self.0 & 0x7FFF)
    }

    /// Flips the sign bit (this also applies to zeros, infinities and NaNs).
    // Kept as an inherent method, so the clippy hint to implement
    // `std::ops::Neg` instead is silenced.
    #[inline]
    #[allow(clippy::should_implement_trait)]
    pub fn neg(self) -> Self {
        BF16(self.0 ^ 0x8000)
    }
}
/// Formats the value exactly as its `f32` conversion would be formatted.
impl fmt::Display for BF16 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate with the caller's formatter so width, precision, sign and
        // alignment flags (e.g. `{:.2}` or `{:>8}`) are honoured; going through
        // `write!(f, "{}", ..)` would silently discard them.
        fmt::Display::fmt(&self.to_f32(), f)
    }
}
impl From<f32> for BF16 {
#[inline]
fn from(value: f32) -> Self {
BF16::from_f32(value)
}
}
impl From<f64> for BF16 {
#[inline]
fn from(value: f64) -> Self {
BF16::from_f64(value)
}
}
impl From<BF16> for f32 {
#[inline]
fn from(value: BF16) -> Self {
value.to_f32()
}
}
impl From<BF16> for f64 {
#[inline]
fn from(value: BF16) -> Self {
value.to_f64()
}
}
/// Numeric ordering, obtained by comparing the `f32` conversions.
///
/// NOTE(review): the derived `PartialEq` compares raw bits while this compares
/// numeric values, so the two disagree for `+0.0` vs `-0.0` (ordered as equal,
/// but `!=`) and for identical NaN patterns (unordered, but `==`). Confirm that
/// callers do not rely on the two being consistent.
impl PartialOrd for BF16 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let lhs = self.to_f32();
        let rhs = other.to_f32();
        lhs.partial_cmp(&rhs)
    }
}
impl Default for BF16 {
fn default() -> Self {
BF16::ZERO
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- F16 -----

    /// Positive zero maps to the all-zero pattern and back.
    #[test]
    fn test_f16_zero() {
        let zero = F16::from_f32(0.0);
        assert_eq!(zero.bits(), 0x0000);
        assert_eq!(zero.to_f32(), 0.0);
        assert!(zero.is_zero());
        assert!(!zero.is_nan());
        assert!(!zero.is_infinite());
    }

    /// Negative zero keeps its sign bit in both directions.
    #[test]
    fn test_f16_neg_zero() {
        let neg_zero = F16::from_f32(-0.0);
        assert_eq!(neg_zero.bits(), 0x8000);
        assert!(neg_zero.is_zero());

        let widened = neg_zero.to_f32();
        assert!(widened.is_sign_negative());
        assert_eq!(widened, 0.0);
    }

    /// `1.0` converts to the `ONE` constant.
    #[test]
    fn test_f16_one() {
        let converted = F16::from_f32(1.0);
        assert_eq!(converted.bits(), F16::ONE.bits());
        assert!((converted.to_f32() - 1.0).abs() < 1e-10);
    }

    /// Values that are exactly representable survive a round trip unchanged.
    #[test]
    fn test_f16_roundtrip_representable() {
        for v in [0.0f32, 1.0, -1.0, 0.5, -0.5, 2.0, 0.25, 65504.0, -65504.0] {
            let back = F16::from_f32(v).to_f32();
            assert_eq!(back, v, "f16 round-trip failed for {v}: got {back}");
        }
    }

    /// NaN stays NaN and is neither infinite nor finite.
    #[test]
    fn test_f16_nan() {
        let converted = F16::from_f32(f32::NAN);
        assert!(converted.is_nan());
        assert!(!converted.is_infinite());
        assert!(!converted.is_finite());
        assert!(converted.to_f32().is_nan());
    }

    /// Both infinities map to their constants and back.
    #[test]
    fn test_f16_infinity() {
        let positive = F16::from_f32(f32::INFINITY);
        assert!(positive.is_infinite());
        assert!(!positive.is_nan());
        assert_eq!(positive.bits(), F16::INFINITY.bits());
        assert_eq!(positive.to_f32(), f32::INFINITY);

        let negative = F16::from_f32(f32::NEG_INFINITY);
        assert!(negative.is_infinite());
        assert_eq!(negative.bits(), F16::NEG_INFINITY.bits());
        assert_eq!(negative.to_f32(), f32::NEG_INFINITY);
    }

    /// A value beyond the largest finite half overflows to infinity.
    #[test]
    fn test_f16_overflow() {
        assert!(F16::from_f32(100000.0).is_infinite());
    }

    /// The smallest subnormal bit pattern decodes to `2^-24`.
    #[test]
    fn test_f16_subnormal() {
        let smallest = F16::from_bits(0x0001);
        assert!(smallest.is_subnormal());
        assert!(!smallest.is_zero());

        let val = smallest.to_f32();
        let expected = 2.0f32.powi(-24);
        assert!(
            (val - expected).abs() < 1e-10,
            "smallest subnormal: expected {expected}, got {val}"
        );
    }

    /// Formatting produces some output.
    #[test]
    fn test_f16_display() {
        let v = F16::from_f32(3.25);
        assert!(!format!("{v}").is_empty());
    }

    /// Ordering follows the numeric value.
    #[test]
    fn test_f16_partial_ord() {
        let one = F16::from_f32(1.0);
        let two = F16::from_f32(2.0);
        assert!(one < two);
        assert!(two > one);

        let one_again = F16::from_f32(1.0);
        assert_eq!(one.partial_cmp(&one_again), Some(Ordering::Equal));
    }

    /// Conversion from `f64` and back stays close to the input.
    #[test]
    fn test_f16_from_f64() {
        let converted = F16::from_f64(1.5);
        assert!((converted.to_f64() - 1.5).abs() < 1e-3);
    }

    /// `neg` flips the sign and `abs` removes it again.
    #[test]
    fn test_f16_abs_neg() {
        let positive = F16::from_f32(3.0);
        let negated = positive.neg();
        assert!((negated.to_f32() + 3.0).abs() < 1e-3);
        assert_eq!(negated.abs().bits(), positive.bits());
    }

    // ----- BF16 -----

    /// Positive zero maps to the all-zero pattern and back.
    #[test]
    fn test_bf16_zero() {
        let zero = BF16::from_f32(0.0);
        assert_eq!(zero.bits(), 0x0000);
        assert_eq!(zero.to_f32(), 0.0);
        assert!(zero.is_zero());
    }

    /// `1.0` converts to the `ONE` constant.
    #[test]
    fn test_bf16_one() {
        let converted = BF16::from_f32(1.0);
        assert_eq!(converted.bits(), BF16::ONE.bits());
        assert_eq!(converted.to_f32(), 1.0);
    }

    /// Values that are exactly representable survive a round trip unchanged.
    #[test]
    fn test_bf16_roundtrip() {
        for v in [0.0f32, 1.0, -1.0, 2.0, -2.0, 0.5, 128.0] {
            let back = BF16::from_f32(v).to_f32();
            assert_eq!(back, v, "bf16 round-trip failed for {v}: got {back}");
        }
    }

    /// A value needing one more mantissa bit than bfloat16 stores stays close.
    #[test]
    fn test_bf16_precision_loss() {
        let original = 1.0f32 + (1.0 / 256.0);
        let narrowed = BF16::from_f32(original);
        let widened = narrowed.to_f32();
        assert!((widened - original).abs() < 0.01);
    }

    /// NaN stays NaN.
    #[test]
    fn test_bf16_nan() {
        let converted = BF16::from_f32(f32::NAN);
        assert!(converted.is_nan());
        assert!(converted.to_f32().is_nan());
    }

    /// Both infinities survive the round trip.
    #[test]
    fn test_bf16_infinity() {
        let positive = BF16::from_f32(f32::INFINITY);
        assert!(positive.is_infinite());
        assert_eq!(positive.to_f32(), f32::INFINITY);

        let negative = BF16::from_f32(f32::NEG_INFINITY);
        assert!(negative.is_infinite());
        assert_eq!(negative.to_f32(), f32::NEG_INFINITY);
    }

    /// Large magnitudes stay finite.
    #[test]
    fn test_bf16_large_range() {
        let huge = BF16::from_f32(1.0e38);
        assert!(huge.is_finite());
        let widened = huge.to_f32();
        assert!(widened > 9.0e37);
        assert!(widened.is_finite());

        let max_val = BF16::MAX.to_f32();
        assert!(BF16::MAX.is_finite());
        assert!(max_val > 3.0e38, "BF16 max should be > 3e38, got {max_val}");
    }

    /// `to_f32` is a plain 16-bit left shift of the stored pattern.
    #[test]
    fn test_bf16_to_f32_fast_path() {
        let source = BF16::ONE;
        let widened = source.to_f32();
        assert_eq!(widened, 1.0f32);
        assert_eq!(widened.to_bits(), (source.bits() as u32) << 16);
    }

    /// Formatting shows the numeric value.
    #[test]
    fn test_bf16_display() {
        let v = BF16::from_f32(42.0);
        assert!(format!("{v}").contains("42"));
    }

    /// Ordering follows the numeric value.
    #[test]
    fn test_bf16_partial_ord() {
        let one = BF16::from_f32(1.0);
        let two = BF16::from_f32(2.0);
        assert!(one < two);
    }

    /// Negative zero keeps its sign bit in both directions.
    #[test]
    fn test_bf16_neg_zero() {
        let neg_zero = BF16::from_f32(-0.0);
        assert!(neg_zero.is_zero());
        assert!(neg_zero.to_f32().is_sign_negative());
    }
}