use bytemuck::{Pod, Zeroable};
use std::cmp::Ordering;
use std::fmt;
use std::ops::{Add, Div, Mul, Sub};
#[cfg(feature = "cuda")]
use cudarc::driver::DeviceRepr;
#[cfg(feature = "cuda")]
use cudarc::types::CudaTypeName;
/// 8-bit float in the E4M3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
///
/// The format has no infinity encoding; the all-ones exponent with all-ones mantissa
/// (`S.1111.111`, i.e. `0x7F`/`0xFF`) is NaN, and the largest finite magnitude is 448.
/// The wrapped `u8` is the raw bit pattern. The derived `PartialEq`/`Eq`/`Hash` compare those
/// raw bits, so `+0` (`0x00`) and `-0` (`0x80`) are unequal and a NaN pattern equals itself.
/// `Default` is `FP8E4M3(0)`, i.e. `+0.0`.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct FP8E4M3(pub u8);
// SAFETY: `#[repr(transparent)]` over a single `u8`: no padding, and every bit pattern is a
// valid value, so the type may be freely cast to and from bytes.
unsafe impl Pod for FP8E4M3 {}
// SAFETY: the all-zero byte is a valid value (`+0.0`).
unsafe impl Zeroable for FP8E4M3 {}
impl FP8E4M3 {
    /// `+0.0`.
    pub const ZERO: Self = Self(0x00);
    /// `1.0` (exponent field equal to the bias, zero mantissa).
    pub const ONE: Self = Self(0x38);
    /// `-1.0`.
    pub const NEG_ONE: Self = Self(0xB8);
    /// Largest finite value, `448.0`.
    pub const MAX: Self = Self(0x7E);
    /// Smallest positive normal value, `2^-6`.
    pub const MIN_POSITIVE: Self = Self(0x08);
    /// E4M3 cannot encode infinity; this aliases [`Self::MAX`], the value that
    /// infinite inputs saturate to.
    pub const INFINITY: Self = Self::MAX;
    /// The NaN pattern with a clear sign bit (`S.1111.111`).
    pub const NAN: Self = Self(0x7F);

    // Format parameters kept for reference; the conversion routines hard-code these
    // values, which is why the constants are otherwise unused.
    #[allow(dead_code)]
    const BIAS: i32 = 7;
    #[allow(dead_code)]
    const MANTISSA_BITS: u32 = 3;
    #[allow(dead_code)]
    const EXPONENT_BITS: u32 = 4;

    /// Wraps a raw E4M3 bit pattern without interpreting it.
    #[inline]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns the raw E4M3 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u8 {
        self.0
    }

    /// Rounds an `f32` to E4M3 via [`f32_to_fp8_e4m3`].
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        Self::from_bits(f32_to_fp8_e4m3(x))
    }

    /// Decodes to `f32` via [`fp8_e4m3_to_f32`].
    #[inline]
    pub fn to_f32(self) -> f32 {
        fp8_e4m3_to_f32(self.to_bits())
    }

    /// Narrows the `f64` to `f32` with `as`, then rounds that to E4M3.
    #[inline]
    pub fn from_f64(x: f64) -> Self {
        let narrowed = x as f32;
        Self::from_f32(narrowed)
    }

    /// Decodes to `f32` and widens losslessly to `f64`.
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// True for both NaN patterns, `0x7F` and `0xFF` (all seven magnitude bits set).
    #[inline]
    pub fn is_nan(self) -> bool {
        (self.0 | 0x80) == 0xFF
    }

    /// True for `+0` and `-0` (all seven magnitude bits clear).
    #[inline]
    pub fn is_zero(self) -> bool {
        (self.0 | 0x80) == 0x80
    }

    /// True when the sign bit is set and the value is not a zero (so `-0` is not negative).
    #[inline]
    pub fn is_negative(self) -> bool {
        !self.is_zero() && (self.0 & 0x80) == 0x80
    }

    /// Largest finite value as an `f32` (the decoding of [`Self::MAX`]).
    pub const fn max_value() -> f32 {
        448.0
    }

    /// Most negative finite value as an `f32`.
    pub const fn min_value() -> f32 {
        -448.0
    }
}
impl fmt::Debug for FP8E4M3 {
    /// Prints the decoded value wrapped in the type name, e.g. `FP8E4M3(1)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoded = self.to_f32();
        f.write_fmt(format_args!("FP8E4M3({})", decoded))
    }
}
/// Formats the decoded `f32` value.
///
/// The caller's formatter options (width, precision, sign, alignment, fill) are forwarded to
/// `f32`'s `Display`, so `format!("{:.2}", x)` behaves as it would for `x.to_f32()`. Without
/// options the output is the same as `format!("{}", x.to_f32())`.
impl fmt::Display for FP8E4M3 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate with the *same* formatter. Re-formatting via `write!(f, "{}", ..)` would
        // build a fresh format spec and silently drop the caller's flags.
        fmt::Display::fmt(&self.to_f32(), f)
    }
}
/// Orders values by their decoded `f32` magnitude; any comparison involving NaN is `None`.
///
/// NOTE(review): `PartialEq` is derived on the raw bits, so this ordering is not fully
/// consistent with `==`. `+0`/`-0` compare `Equal` here but are `!=`, and a NaN pattern is
/// `==` to itself but unordered here.
impl PartialOrd for FP8E4M3 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let (lhs, rhs) = (self.to_f32(), other.to_f32());
        lhs.partial_cmp(&rhs)
    }
}
// Each operator decodes both operands to `f32`, computes in `f32`, and rounds the result
// back to E4M3 with `from_f32`.
impl Add for FP8E4M3 {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let sum = self.to_f32() + rhs.to_f32();
        Self::from_f32(sum)
    }
}
impl Sub for FP8E4M3 {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let difference = self.to_f32() - rhs.to_f32();
        Self::from_f32(difference)
    }
}
impl Mul for FP8E4M3 {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let product = self.to_f32() * rhs.to_f32();
        Self::from_f32(product)
    }
}
impl Div for FP8E4M3 {
    type Output = Self;
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let quotient = self.to_f32() / rhs.to_f32();
        Self::from_f32(quotient)
    }
}
// Compound assignment delegates to the corresponding binary operator impl, so `a op= b`
// is exactly `a = a op b` (f32 arithmetic, then rounding back to E4M3).
impl std::ops::AddAssign for FP8E4M3 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
impl std::ops::SubAssign for FP8E4M3 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}
impl std::ops::MulAssign for FP8E4M3 {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
impl std::ops::DivAssign for FP8E4M3 {
    #[inline]
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs;
    }
}
/// 8-bit float in the E5M2 layout: 1 sign bit, 5 exponent bits (bias 15), 2 mantissa bits.
///
/// Special values follow the IEEE-754 pattern: exponent 31 with a zero mantissa is ±infinity
/// (`0x7C`/`0xFC`), and exponent 31 with a nonzero mantissa is NaN. The largest finite
/// magnitude is 57344. The wrapped `u8` is the raw bit pattern. The derived
/// `PartialEq`/`Eq`/`Hash` compare those raw bits, so `+0` and `-0` are unequal and a NaN
/// pattern equals itself. `Default` is `FP8E5M2(0)`, i.e. `+0.0`.
#[derive(Copy, Clone, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct FP8E5M2(pub u8);
// SAFETY: `#[repr(transparent)]` over a single `u8`: no padding, and every bit pattern is a
// valid value, so the type may be freely cast to and from bytes.
unsafe impl Pod for FP8E5M2 {}
// SAFETY: the all-zero byte is a valid value (`+0.0`).
unsafe impl Zeroable for FP8E5M2 {}
impl FP8E5M2 {
    /// `+0.0`.
    pub const ZERO: Self = Self(0x00);
    /// `1.0` (exponent field equal to the bias, zero mantissa).
    pub const ONE: Self = Self(0x3C);
    /// `-1.0`.
    pub const NEG_ONE: Self = Self(0xBC);
    /// Largest finite value, `57344.0`.
    pub const MAX: Self = Self(0x7B);
    /// Smallest positive normal value, `2^-14`.
    pub const MIN_POSITIVE: Self = Self(0x04);
    /// `+inf` (all-ones exponent, zero mantissa).
    pub const INFINITY: Self = Self(0x7C);
    /// `-inf`.
    pub const NEG_INFINITY: Self = Self(0xFC);
    /// A NaN pattern (all-ones exponent, nonzero mantissa).
    pub const NAN: Self = Self(0x7F);

    // Format parameters kept for reference; the conversion routines hard-code these
    // values, which is why the constants are otherwise unused.
    #[allow(dead_code)]
    const BIAS: i32 = 15;
    #[allow(dead_code)]
    const MANTISSA_BITS: u32 = 2;
    #[allow(dead_code)]
    const EXPONENT_BITS: u32 = 5;

    /// Wraps a raw E5M2 bit pattern without interpreting it.
    #[inline]
    pub const fn from_bits(bits: u8) -> Self {
        Self(bits)
    }

    /// Returns the raw E5M2 bit pattern.
    #[inline]
    pub const fn to_bits(self) -> u8 {
        self.0
    }

    /// Rounds an `f32` to E5M2 via [`f32_to_fp8_e5m2`].
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        Self::from_bits(f32_to_fp8_e5m2(x))
    }

    /// Decodes to `f32` via [`fp8_e5m2_to_f32`].
    #[inline]
    pub fn to_f32(self) -> f32 {
        fp8_e5m2_to_f32(self.to_bits())
    }

    /// Narrows the `f64` to `f32` with `as`, then rounds that to E5M2.
    #[inline]
    pub fn from_f64(x: f64) -> Self {
        let narrowed = x as f32;
        Self::from_f32(narrowed)
    }

    /// Decodes to `f32` and widens losslessly to `f64`.
    #[inline]
    pub fn to_f64(self) -> f64 {
        f64::from(self.to_f32())
    }

    /// True for exponent 31 with a nonzero mantissa.
    #[inline]
    pub fn is_nan(self) -> bool {
        // The low 7 bits are `exponent << 2 | mantissa`; anything above 0x7C has exponent 31
        // and a nonzero mantissa.
        (self.0 & 0x7F) > 0x7C
    }

    /// True for `+inf` (`0x7C`) and `-inf` (`0xFC`).
    #[inline]
    pub fn is_infinite(self) -> bool {
        (self.0 | 0x80) == 0xFC
    }

    /// True for `+0` and `-0` (all seven magnitude bits clear).
    #[inline]
    pub fn is_zero(self) -> bool {
        (self.0 | 0x80) == 0x80
    }

    /// True when the sign bit is set and the value is not a zero (so `-0` is not negative).
    #[inline]
    pub fn is_negative(self) -> bool {
        !self.is_zero() && (self.0 & 0x80) == 0x80
    }

    /// Largest finite value as an `f32` (the decoding of [`Self::MAX`]).
    pub const fn max_value() -> f32 {
        57344.0
    }

    /// Most negative finite value as an `f32`.
    pub const fn min_value() -> f32 {
        -57344.0
    }
}
impl fmt::Debug for FP8E5M2 {
    /// Prints the decoded value wrapped in the type name, e.g. `FP8E5M2(1)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoded = self.to_f32();
        f.write_fmt(format_args!("FP8E5M2({})", decoded))
    }
}
/// Formats the decoded `f32` value.
///
/// The caller's formatter options (width, precision, sign, alignment, fill) are forwarded to
/// `f32`'s `Display`, so `format!("{:.2}", x)` behaves as it would for `x.to_f32()`. Without
/// options the output is the same as `format!("{}", x.to_f32())`.
impl fmt::Display for FP8E5M2 {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate with the *same* formatter. Re-formatting via `write!(f, "{}", ..)` would
        // build a fresh format spec and silently drop the caller's flags.
        fmt::Display::fmt(&self.to_f32(), f)
    }
}
/// Orders values by their decoded `f32` value (±inf included); any comparison involving NaN
/// is `None`.
///
/// NOTE(review): `PartialEq` is derived on the raw bits, so this ordering is not fully
/// consistent with `==`. `+0`/`-0` compare `Equal` here but are `!=`, and a NaN pattern is
/// `==` to itself but unordered here.
impl PartialOrd for FP8E5M2 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let (lhs, rhs) = (self.to_f32(), other.to_f32());
        lhs.partial_cmp(&rhs)
    }
}
// Each operator decodes both operands to `f32`, computes in `f32`, and rounds the result
// back to E5M2 with `from_f32`.
impl Add for FP8E5M2 {
    type Output = Self;
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let sum = self.to_f32() + rhs.to_f32();
        Self::from_f32(sum)
    }
}
impl Sub for FP8E5M2 {
    type Output = Self;
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let difference = self.to_f32() - rhs.to_f32();
        Self::from_f32(difference)
    }
}
impl Mul for FP8E5M2 {
    type Output = Self;
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let product = self.to_f32() * rhs.to_f32();
        Self::from_f32(product)
    }
}
impl Div for FP8E5M2 {
    type Output = Self;
    #[inline]
    fn div(self, rhs: Self) -> Self {
        let quotient = self.to_f32() / rhs.to_f32();
        Self::from_f32(quotient)
    }
}
// Compound assignment delegates to the corresponding binary operator impl, so `a op= b`
// is exactly `a = a op b` (f32 arithmetic, then rounding back to E5M2).
impl std::ops::AddAssign for FP8E5M2 {
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
impl std::ops::SubAssign for FP8E5M2 {
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}
impl std::ops::MulAssign for FP8E5M2 {
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
impl std::ops::DivAssign for FP8E5M2 {
    #[inline]
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs;
    }
}
// CUDA interop (only compiled with the `cuda` feature).
// NOTE(review): these names are presumably the device-side type names used by this crate's
// CUDA kernel sources for the two FP8 formats — confirm against the kernels.
#[cfg(feature = "cuda")]
impl CudaTypeName for FP8E4M3 {
    const NAME: &'static str = "numr_fp8_e4m3";
}
#[cfg(feature = "cuda")]
impl CudaTypeName for FP8E5M2 {
    const NAME: &'static str = "numr_fp8_e5m2";
}
// SAFETY: `FP8E4M3` is `#[repr(transparent)]` over `u8`, so it has the layout of a single
// plain byte with no padding or pointers. NOTE(review): confirm this satisfies cudarc's
// `DeviceRepr` contract for the cudarc version in use.
#[cfg(feature = "cuda")]
unsafe impl DeviceRepr for FP8E4M3 {}
// SAFETY: same reasoning as above — `FP8E5M2` is `#[repr(transparent)]` over `u8`.
#[cfg(feature = "cuda")]
unsafe impl DeviceRepr for FP8E5M2 {}
/// Rounds an `f32` to the nearest FP8 E4M3 value and returns its bit pattern.
///
/// * Rounding is round-to-nearest, ties-to-even, for normal **and** subnormal results.
///   Subnormals are `m * 2^-9` with `m` in `1..=7`; a subnormal that rounds up to `m = 8`
///   carries into the smallest normal, `0x08` (`2^-6`).
/// * E4M3 has no infinities. `±inf` and finite values that round past `448` saturate to
///   `±448` (`0x7E` / `0xFE`), because `0x7F` is the NaN pattern.
/// * NaN maps to `0x7F` carrying the input's sign bit.
/// * Magnitudes up to and including `2^-10` (half the smallest subnormal) become a signed
///   zero. This includes every `f32` subnormal.
#[inline]
pub fn f32_to_fp8_e4m3(x: f32) -> u8 {
    let bits = x.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x7F_FFFF;

    if exp == 0xFF {
        // NaN stays NaN; infinity saturates because E4M3 cannot encode it.
        let payload = if mant != 0 { 0x7F } else { 0x7E };
        return sign | payload;
    }
    if exp == 0 {
        // ±0 and f32 subnormals (< 2^-126) are far below 2^-10, so they round to zero.
        return sign;
    }

    let unbiased_exp = exp - 127;
    if unbiased_exp > 8 {
        // >= 2^9 is well past 448: saturate.
        return sign | 0x7E;
    }
    if unbiased_exp < -10 {
        // < 2^-10 is below half the smallest subnormal (2^-9): rounds to zero.
        return sign;
    }

    // 24-bit significand including the implicit leading one.
    let sig = 0x80_0000 | mant;
    let fp8_exp = unbiased_exp + 7;
    // `shift` is the number of low significand bits to discard. `base` is added to the
    // surviving bits to form the 7-bit magnitude code.
    let (base, shift) = if fp8_exp >= 1 {
        // Normal: keep the leading one plus 3 mantissa bits (values 8..=15). Adding them to
        // `(fp8_exp - 1) << 3` yields `(fp8_exp << 3) | mantissa`, and a rounding carry out
        // of the mantissa (15 -> 16) bumps the exponent field automatically.
        ((((fp8_exp - 1) as u32) << 3), 20u32)
    } else {
        // Subnormal: value = sig * 2^(e - 23) = m * 2^-9  =>  m = sig >> (14 - e),
        // and 14 - e == 21 - fp8_exp. Here fp8_exp is in -3..=0, so the shift is 21..=24.
        (0, (21 - fp8_exp) as u32)
    };

    let kept = sig >> shift;
    let rem = sig & ((1u32 << shift) - 1);
    let half = 1u32 << (shift - 1);
    // Round to nearest; on an exact tie round to the even (LSB == 0) neighbour.
    let round_up = rem > half || (rem == half && (kept & 1) == 1);
    let code = base + kept + u32::from(round_up);

    // 0x7F is NaN in E4M3, so anything at or past it saturates to the largest finite value.
    if code > 0x7E {
        sign | 0x7E
    } else {
        sign | code as u8
    }
}
/// Decodes an FP8 E4M3 bit pattern into the exactly equal `f32`.
///
/// `S.1111.111` (`0x7F` / `0xFF`) decodes to NaN; the format has no infinities. When the
/// exponent field is 0 (zero or subnormal), the value is `mantissa * 2^-9` with the sign
/// applied, so zero keeps its sign. Normal values are re-biased (7 -> 127) and their bit
/// fields are packed directly into an `f32`.
#[inline]
pub fn fp8_e4m3_to_f32(x: u8) -> f32 {
    // 2^-9: one unit in the last place of an E4M3 subnormal.
    const SUBNORMAL_STEP: f32 = 1.0 / 512.0;

    let negative = (x & 0x80) != 0;
    let magnitude = x & 0x7F;
    if magnitude == 0x7F {
        return f32::NAN;
    }

    let exp_field = u32::from(magnitude >> 3);
    let mant_field = u32::from(magnitude & 0x07);

    if exp_field == 0 {
        // Also covers ±0: 0 * step is 0.0, and negating it yields -0.0.
        let value = mant_field as f32 * SUBNORMAL_STEP;
        return if negative { -value } else { value };
    }

    let sign_bit = u32::from(negative) << 31;
    f32::from_bits(sign_bit | ((exp_field + 120) << 23) | (mant_field << 20))
}
/// Rounds an `f32` to the nearest FP8 E5M2 value and returns its bit pattern.
///
/// * Rounding is round-to-nearest, ties-to-even, for normal **and** subnormal results.
///   Subnormals are `m * 2^-16` with `m` in `1..=3`; a subnormal that rounds up to `m = 4`
///   carries into the smallest normal, `0x04` (`2^-14`).
/// * Values that round past the largest finite value (`57344`) become `±inf`
///   (`0x7C` / `0xFC`), as do `±inf` inputs.
/// * NaN maps to `0x7F` carrying the input's sign bit.
/// * Magnitudes up to and including `2^-17` (half the smallest subnormal) become a signed
///   zero. This includes every `f32` subnormal.
#[inline]
pub fn f32_to_fp8_e5m2(x: f32) -> u8 {
    let bits = x.to_bits();
    let sign = ((bits >> 31) as u8) << 7;
    let exp = ((bits >> 23) & 0xFF) as i32;
    let mant = bits & 0x7F_FFFF;

    if exp == 0xFF {
        let payload = if mant != 0 { 0x7F } else { 0x7C };
        return sign | payload;
    }
    if exp == 0 {
        // ±0 and f32 subnormals (< 2^-126) are far below 2^-17, so they round to zero.
        return sign;
    }

    let unbiased_exp = exp - 127;
    if unbiased_exp > 15 {
        // >= 2^16 is past the largest finite value: overflow to infinity.
        return sign | 0x7C;
    }
    if unbiased_exp < -17 {
        // < 2^-17 is below half the smallest subnormal (2^-16): rounds to zero.
        return sign;
    }

    // 24-bit significand including the implicit leading one.
    let sig = 0x80_0000 | mant;
    let fp8_exp = unbiased_exp + 15;
    // `shift` is the number of low significand bits to discard. `base` is added to the
    // surviving bits to form the 7-bit magnitude code.
    let (base, shift) = if fp8_exp >= 1 {
        // Normal: keep the leading one plus 2 mantissa bits (values 4..=7). Adding them to
        // `(fp8_exp - 1) << 2` yields `(fp8_exp << 2) | mantissa`, and a rounding carry out
        // of the mantissa (7 -> 8) bumps the exponent field automatically.
        ((((fp8_exp - 1) as u32) << 2), 21u32)
    } else {
        // Subnormal: value = sig * 2^(e - 23) = m * 2^-16  =>  m = sig >> (7 - e),
        // and 7 - e == 22 - fp8_exp. Here fp8_exp is in -2..=0, so the shift is 22..=24.
        (0, (22 - fp8_exp) as u32)
    };

    let kept = sig >> shift;
    let rem = sig & ((1u32 << shift) - 1);
    let half = 1u32 << (shift - 1);
    // Round to nearest; on an exact tie round to the even (LSB == 0) neighbour.
    let round_up = rem > half || (rem == half && (kept & 1) == 1);
    let code = base + kept + u32::from(round_up);

    // 0x7C is the infinity pattern; reaching it means the value overflowed the finite range.
    if code >= 0x7C {
        sign | 0x7C
    } else {
        sign | code as u8
    }
}
/// Decodes an FP8 E5M2 bit pattern into the exactly equal `f32`.
///
/// Exponent 31 encodes `±inf` (zero mantissa) or NaN (nonzero mantissa). When the exponent
/// field is 0 (zero or subnormal), the value is `mantissa * 2^-16` with the sign applied, so
/// zero keeps its sign. Normal values are re-biased (15 -> 127) and their bit fields are
/// packed directly into an `f32`.
#[inline]
pub fn fp8_e5m2_to_f32(x: u8) -> f32 {
    // 2^-16: one unit in the last place of an E5M2 subnormal.
    const SUBNORMAL_STEP: f32 = 1.0 / 65536.0;

    let negative = (x & 0x80) != 0;
    let exp_field = u32::from((x >> 2) & 0x1F);
    let mant_field = u32::from(x & 0x03);

    match (exp_field, mant_field) {
        (31, 0) => {
            if negative {
                f32::NEG_INFINITY
            } else {
                f32::INFINITY
            }
        }
        (31, _) => f32::NAN,
        (0, m) => {
            // Also covers ±0: 0 * step is 0.0, and negating it yields -0.0.
            let value = m as f32 * SUBNORMAL_STEP;
            if negative {
                -value
            } else {
                value
            }
        }
        (e, m) => {
            let sign_bit = u32::from(negative) << 31;
            f32::from_bits(sign_bit | ((e + 112) << 23) | (m << 21))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative error of `actual` against `expected`. The denominator is clamped to `1e-6`
    /// so that expected values near zero do not divide by zero.
    fn relative_error(expected: f32, actual: f32) -> f32 {
        (actual - expected).abs() / expected.abs().max(1e-6)
    }

    // ---------------------------------------------------------------- FP8E4M3

    #[test]
    fn test_fp8_e4m3_zero() {
        assert_eq!(FP8E4M3::ZERO.to_f32(), 0.0);
        assert_eq!(FP8E4M3::from_f32(0.0).to_bits(), 0x00);
        assert_eq!(FP8E4M3::from_f32(-0.0).to_bits(), 0x80);
    }

    #[test]
    fn test_fp8_e4m3_one() {
        // Both the constant and the converted value must decode to ~1.0.
        for one in [FP8E4M3::ONE, FP8E4M3::from_f32(1.0)] {
            assert!((one.to_f32() - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_fp8_e4m3_roundtrip() {
        for val in [0.5f32, 1.0, 1.5, 2.0, 4.0, 8.0, 16.0, 100.0, 448.0] {
            let back = FP8E4M3::from_f32(val).to_f32();
            let rel_error = relative_error(val, back);
            assert!(
                rel_error < 0.2,
                "FP8E4M3 roundtrip failed for {}: got {}, rel_error={}",
                val,
                back,
                rel_error
            );
        }
    }

    #[test]
    fn test_fp8_e4m3_negative() {
        let decoded = FP8E4M3::from_f32(-1.0).to_f32();
        assert!(decoded < 0.0);
        assert!((decoded + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_fp8_e4m3_overflow() {
        // Out-of-range finite input saturates at the maximum (448).
        let saturated = FP8E4M3::from_f32(1000.0).to_f32();
        assert!((saturated - 448.0).abs() < 1.0);
    }

    #[test]
    fn test_fp8_e4m3_underflow() {
        assert_eq!(FP8E4M3::from_f32(1e-10).to_f32(), 0.0);
    }

    #[test]
    fn test_fp8_e4m3_nan() {
        let nan = FP8E4M3::from_f32(f32::NAN);
        assert!(nan.is_nan());
        assert!(nan.to_f32().is_nan());
    }

    #[test]
    fn test_fp8_e4m3_inf() {
        // E4M3 has no infinity, so +inf saturates at the maximum (448).
        let saturated = FP8E4M3::from_f32(f32::INFINITY).to_f32();
        assert!((saturated - 448.0).abs() < 1.0);
    }

    #[test]
    fn test_fp8_e4m3_arithmetic() {
        let (a, b) = (FP8E4M3::from_f32(2.0), FP8E4M3::from_f32(3.0));
        assert!(((a + b).to_f32() - 5.0).abs() < 0.5);
        assert!(((b - a).to_f32() - 1.0).abs() < 0.5);
        assert!(((a * b).to_f32() - 6.0).abs() < 0.5);
        assert!(((b / a).to_f32() - 1.5).abs() < 0.3);
    }

    // ---------------------------------------------------------------- FP8E5M2

    #[test]
    fn test_fp8_e5m2_zero() {
        assert_eq!(FP8E5M2::ZERO.to_f32(), 0.0);
        assert_eq!(FP8E5M2::from_f32(0.0).to_bits(), 0x00);
    }

    #[test]
    fn test_fp8_e5m2_one() {
        assert!((FP8E5M2::ONE.to_f32() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_fp8_e5m2_roundtrip() {
        for val in [0.5f32, 1.0, 2.0, 4.0, 8.0, 100.0, 1000.0, 10000.0] {
            let back = FP8E5M2::from_f32(val).to_f32();
            let rel_error = relative_error(val, back);
            assert!(
                rel_error < 0.35,
                "FP8E5M2 roundtrip failed for {}: got {}, rel_error={}",
                val,
                back,
                rel_error
            );
        }
    }

    #[test]
    fn test_fp8_e5m2_large_range() {
        assert!(FP8E5M2::from_f32(50000.0).to_f32() > 30000.0);
    }

    #[test]
    fn test_fp8_e5m2_inf() {
        let inf = FP8E5M2::from_f32(f32::INFINITY);
        assert!(inf.is_infinite());
        assert!(inf.to_f32().is_infinite());
    }

    #[test]
    fn test_fp8_e5m2_nan() {
        let nan = FP8E5M2::from_f32(f32::NAN);
        assert!(nan.is_nan());
        assert!(nan.to_f32().is_nan());
    }

    #[test]
    fn test_fp8_e5m2_arithmetic() {
        let (a, b) = (FP8E5M2::from_f32(100.0), FP8E5M2::from_f32(200.0));
        assert!(((a + b).to_f32() - 300.0).abs() < 50.0);
        assert!(((b - a).to_f32() - 100.0).abs() < 30.0);
    }

    // ---------------------------------------------------------------- bytemuck casts

    #[test]
    fn test_fp8_e4m3_bytemuck() {
        let original = [FP8E4M3::ZERO, FP8E4M3::ONE];
        let bytes: &[u8] = bytemuck::cast_slice(&original);
        assert_eq!(bytes.len(), 2);
        let restored: &[FP8E4M3] = bytemuck::cast_slice(bytes);
        assert_eq!(restored, &original[..]);
    }

    #[test]
    fn test_fp8_e5m2_bytemuck() {
        let original = [FP8E5M2::ZERO, FP8E5M2::ONE];
        let bytes: &[u8] = bytemuck::cast_slice(&original);
        assert_eq!(bytes.len(), 2);
        let restored: &[FP8E5M2] = bytemuck::cast_slice(bytes);
        assert_eq!(restored, &original[..]);
    }
}