use crate::error::{QuantError, QuantResult};
/// The two 8-bit floating-point layouts supported by this module.
///
/// Every code is one sign bit followed by an exponent field and a mantissa
/// field; the variants differ in how the remaining seven bits are split.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Fp8Format {
    /// 4 exponent bits, 3 mantissa bits.
    E4M3,
    /// 5 exponent bits, 2 mantissa bits.
    E5M2,
}

impl Fp8Format {
    /// Width of the exponent field in bits.
    #[must_use]
    pub fn exp_bits(self) -> u32 {
        match self {
            Self::E5M2 => 5,
            Self::E4M3 => 4,
        }
    }

    /// Width of the mantissa field in bits.
    #[must_use]
    pub fn man_bits(self) -> u32 {
        // One bit of the byte is the sign; whatever the exponent does not
        // take of the remaining seven belongs to the mantissa.
        let payload_bits = 7;
        payload_bits - self.exp_bits()
    }

    /// Exponent bias of the format.
    #[must_use]
    pub fn bias(self) -> i32 {
        // Both layouts use the conventional bias 2^(exp_bits - 1) - 1,
        // which yields 7 for E4M3 and 15 for E5M2.
        let half_range = 1_i32 << (self.exp_bits() - 1);
        half_range - 1
    }

    /// Largest finite magnitude the encoder will produce for this format.
    #[must_use]
    pub fn max_val(self) -> f32 {
        match self {
            Self::E5M2 => 57344.0_f32,
            Self::E4M3 => 448.0_f32,
        }
    }
}
/// Converts between `f32` values and 8-bit floating-point codes.
///
/// Each code is laid out, from the most significant bit down, as one sign
/// bit, `format.exp_bits()` exponent bits and `format.man_bits()` mantissa
/// bits.
#[derive(Debug, Clone, Copy)]
pub struct Fp8Codec {
    /// Target 8-bit layout.
    pub format: Fp8Format,
    /// Saturation flag; both constructors set it to `true`.
    ///
    /// NOTE(review): no method of this type reads the flag. Finite inputs are
    /// always clamped to `±format.max_val()` regardless of its value; confirm
    /// the intended non-saturating behavior before relying on it.
    pub saturate: bool,
}

impl Fp8Codec {
    /// Creates a codec for the E4M3 layout with `saturate` set.
    #[must_use]
    pub fn e4m3() -> Self {
        Self {
            format: Fp8Format::E4M3,
            saturate: true,
        }
    }

    /// Creates a codec for the E5M2 layout with `saturate` set.
    #[must_use]
    pub fn e5m2() -> Self {
        Self {
            format: Fp8Format::E5M2,
            saturate: true,
        }
    }

    /// Encodes one `f32` as an FP8 code.
    ///
    /// Finite inputs are clamped to `±format.max_val()` and then converted by
    /// truncating the mantissa (round toward zero). Magnitudes below the
    /// smallest FP8 subnormal become a signed zero.
    ///
    /// # Errors
    ///
    /// Returns `QuantError::NonFiniteFp8` when `v` is NaN or infinite.
    pub fn encode_f32(&self, v: f32) -> QuantResult<u8> {
        if !v.is_finite() {
            return Err(QuantError::NonFiniteFp8(v));
        }
        let max = self.format.max_val();
        Ok(self.fp32_to_fp8(v.clamp(-max, max)))
    }

    /// Decodes one FP8 code to `f32`.
    ///
    /// NaN codes (and, for E5M2, infinity codes) decode to the matching
    /// `f32` special value.
    #[must_use]
    pub fn decode_f32(&self, b: u8) -> f32 {
        self.fp8_to_fp32(b)
    }

    /// Encodes every element of `data`.
    ///
    /// # Errors
    ///
    /// Stops at, and returns the error for, the first non-finite element.
    pub fn encode(&self, data: &[f32]) -> QuantResult<Vec<u8>> {
        data.iter().map(|&v| self.encode_f32(v)).collect()
    }

    /// Decodes every code in `data`.
    pub fn decode(&self, data: &[u8]) -> Vec<f32> {
        data.iter().map(|&b| self.decode_f32(b)).collect()
    }

    /// Mean squared error introduced by an encode/decode round trip of `data`.
    ///
    /// Returns `0.0` for empty input instead of dividing by zero.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Self::encode`] when `data` contains a
    /// non-finite value.
    pub fn quantization_mse(&self, data: &[f32]) -> QuantResult<f32> {
        if data.is_empty() {
            return Ok(0.0);
        }
        let encoded = self.encode(data)?;
        let squared_error: f32 = data
            .iter()
            .zip(encoded.iter())
            .map(|(&original, &code)| (original - self.decode_f32(code)).powi(2))
            .sum();
        Ok(squared_error / data.len() as f32)
    }

    /// Converts a finite `f32` to its FP8 code, truncating excess mantissa bits.
    ///
    /// Callers are expected to have clamped `v` to the representable range;
    /// anything that would still land on a NaN/infinity code is replaced by
    /// the largest finite code of the format.
    fn fp32_to_fp8(&self, v: f32) -> u8 {
        let bits = v.to_bits();
        let sign = ((bits >> 31) as u8) << 7;
        let exp32 = ((bits >> 23) & 0xFF) as i32;
        let man32 = bits & 0x007F_FFFF;
        let exp_bits = self.format.exp_bits();
        let man_bits = self.format.man_bits();
        let man_mask: u8 = (1 << man_bits) - 1;

        // `-0.0 == 0.0`, so this covers both zeros; the sign comes from the bits.
        if v == 0.0 {
            return sign;
        }

        // Re-bias the f32 exponent into the FP8 exponent field.
        let exp8_raw = exp32 - 127 + self.format.bias();
        let man_shift = 23 - man_bits;

        if exp8_raw <= 0 {
            // FP8 subnormal: value = man8 * 2^(1 - bias - man_bits). With the
            // 24-bit significand M (implicit one restored) the input equals
            // M * 2^(e - 23), hence man8 = M >> (man_shift + 1 - exp8_raw).
            let significand = man32 | 0x0080_0000;
            let shift = (1 - exp8_raw) as u32 + man_shift;
            if shift >= 24 {
                // Below the smallest subnormal: flush to signed zero.
                return sign;
            }
            let man8 = (significand >> shift) as u8 & man_mask;
            return sign | man8;
        }

        let max_exp = (1 << exp_bits) - 1;
        let man8 = (man32 >> man_shift) as u8 & man_mask;
        let overflows = match self.format {
            // E4M3 keeps finite values in the top exponent; only the
            // all-ones mantissa there is reserved (NaN).
            Fp8Format::E4M3 => exp8_raw > max_exp || (exp8_raw == max_exp && man8 == man_mask),
            // E5M2 reserves the whole top exponent for infinity/NaN.
            Fp8Format::E5M2 => exp8_raw >= max_exp,
        };
        if overflows {
            return match self.format {
                // 0x7E = 448.0, the largest finite E4M3 magnitude.
                Fp8Format::E4M3 => sign | 0x7E,
                // 0x7B = 57344.0, the largest finite E5M2 magnitude.
                Fp8Format::E5M2 => sign | 0x7B,
            };
        }

        sign | ((exp8_raw as u8) << man_bits) | man8
    }

    /// Expands an FP8 code to the `f32` it represents.
    fn fp8_to_fp32(&self, b: u8) -> f32 {
        let exp_bits = self.format.exp_bits();
        let man_bits = self.format.man_bits();
        let bias8 = self.format.bias();
        let sign_bit = u32::from(b >> 7) << 31;
        let exp_mask: u32 = (1 << exp_bits) - 1;
        let man_mask: u32 = (1 << man_bits) - 1;
        let exp8 = (u32::from(b) >> man_bits) & exp_mask;
        let man8 = u32::from(b) & man_mask;

        match self.format {
            Fp8Format::E4M3 => {
                // Only exponent and mantissa all ones is special (NaN).
                if exp8 == exp_mask && man8 == man_mask {
                    return f32::NAN;
                }
            }
            Fp8Format::E5M2 => {
                // All-ones exponent: infinity for a zero mantissa, NaN otherwise.
                if exp8 == exp_mask {
                    if man8 != 0 {
                        return f32::NAN;
                    }
                    return if sign_bit == 0 {
                        f32::INFINITY
                    } else {
                        f32::NEG_INFINITY
                    };
                }
            }
        }

        if exp8 == 0 {
            // Zero and subnormals have no implicit leading one:
            // value = man8 * 2^(1 - bias - man_bits). The scale is built from
            // its bit pattern so the product is exact; man8 == 0 yields ±0.0.
            let scale = f32::from_bits(((127 + 1 - bias8 - man_bits as i32) as u32) << 23);
            let magnitude = man8 as f32 * scale;
            return f32::from_bits(sign_bit | magnitude.to_bits());
        }

        // Normal number: re-bias the exponent and left-align the mantissa.
        let exp32 = (exp8 as i32 - bias8 + 127) as u32;
        f32::from_bits(sign_bit | (exp32 << 23) | (man8 << (23 - man_bits)))
    }
}
// Unit tests for the FP8 format parameters and the codec's encode/decode paths.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
// E4M3 reports 4 exponent bits, 3 mantissa bits, bias 7 and a maximum
// within 1.0 of 448.
#[test]
fn e4m3_format_params() {
assert_eq!(Fp8Format::E4M3.exp_bits(), 4);
assert_eq!(Fp8Format::E4M3.man_bits(), 3);
assert_eq!(Fp8Format::E4M3.bias(), 7);
assert_abs_diff_eq!(Fp8Format::E4M3.max_val(), 448.0, epsilon = 1.0);
}
// E5M2 reports 5 exponent bits, 2 mantissa bits, bias 15 and a maximum
// within 100.0 of 57344.
#[test]
fn e5m2_format_params() {
assert_eq!(Fp8Format::E5M2.exp_bits(), 5);
assert_eq!(Fp8Format::E5M2.man_bits(), 2);
assert_eq!(Fp8Format::E5M2.bias(), 15);
assert_abs_diff_eq!(Fp8Format::E5M2.max_val(), 57344.0, epsilon = 100.0);
}
// Positive and negative zero keep their sign bit: 0x00 and 0x80.
#[test]
fn e4m3_zero_encodes_to_zero() {
let c = Fp8Codec::e4m3();
assert_eq!(c.encode_f32(0.0).unwrap(), 0x00);
assert_eq!(c.encode_f32(-0.0).unwrap(), 0x80);
}
// A handful of powers of two survive an E4M3 round trip with a relative
// error below 15%.
#[test]
fn e4m3_round_trip_basic() {
let c = Fp8Codec::e4m3();
for &v in &[1.0_f32, -1.0, 2.0, 0.5, 0.25, -0.25] {
let enc = c.encode_f32(v).unwrap();
let dec = c.decode_f32(enc);
// The divisor is floored at 1e-6 so the ratio stays defined for tiny inputs.
let rel_err = (v - dec).abs() / v.abs().max(1e-6);
assert!(rel_err < 0.15, "v={v}, dec={dec}, rel_err={rel_err}");
}
}
// A handful of powers of two survive an E5M2 round trip with a relative
// error below 25%.
#[test]
fn e5m2_round_trip_basic() {
let c = Fp8Codec::e5m2();
for &v in &[1.0_f32, -1.0, 4.0, 16.0, -8.0] {
let enc = c.encode_f32(v).unwrap();
let dec = c.decode_f32(enc);
let rel_err = (v - dec).abs() / v.abs().max(1e-6);
assert!(rel_err < 0.25, "v={v}, dec={dec}, rel_err={rel_err}");
}
}
// 1000.0 is outside the E4M3 range; the decoded result must stay positive
// and must not exceed 449.0.
#[test]
fn e4m3_saturates_large_values() {
let c = Fp8Codec::e4m3();
let enc = c.encode_f32(1000.0).unwrap();
let dec = c.decode_f32(enc);
assert!(dec <= 448.0 + 1.0, "should saturate, got {dec}");
assert!(dec > 0.0, "positive saturation should be positive");
}
// NaN and +infinity are rejected with `QuantError::NonFiniteFp8`; the codec
// here is built with `saturate: false`.
#[test]
fn nan_input_errors() {
let c = Fp8Codec {
format: Fp8Format::E4M3,
saturate: false,
};
assert!(matches!(
c.encode_f32(f32::NAN),
Err(QuantError::NonFiniteFp8(_))
));
assert!(matches!(
c.encode_f32(f32::INFINITY),
Err(QuantError::NonFiniteFp8(_))
));
}
// The round-trip MSE over 256 evenly spaced samples in [-1.0, 1.0) stays
// below 0.01 for E4M3.
#[test]
fn mse_within_tolerance() {
let c = Fp8Codec::e4m3();
let data: Vec<f32> = (0..256).map(|i| (i as f32 / 128.0) - 1.0).collect();
let mse = c.quantization_mse(&data).unwrap();
assert!(mse < 0.01, "E4M3 MSE unexpectedly large: {mse}");
}
// Batch encode and decode each return one output element per input element.
#[test]
fn batch_encode_decode() {
let c = Fp8Codec::e4m3();
let data = vec![0.0_f32, 1.0, -1.0, 0.5, 2.0, -2.0];
let enc = c.encode(&data).unwrap();
assert_eq!(enc.len(), data.len());
let dec = c.decode(&enc);
assert_eq!(dec.len(), data.len());
}
}