/// Largest finite E4M3 magnitude (`1.75 * 2^8`); `e4m3_round_to_nearest` clamps larger
/// magnitudes (including infinities) to this value.
pub const E4M3_MAX_FINITE: f32 = 448.0;
/// Smallest positive E4M3 subnormal (`2^-9`); every subnormal is an integer multiple of it.
pub const E4M3_MIN_POS_SUBNORMAL: f32 = 1.0 / 512.0;
/// Relative Frobenius error threshold (1%) for FP8 round-trips, used as the `threshold`
/// argument of `classify_frobenius_error`.
pub const FP8_MAX_FROBENIUS_REL_ERR: f64 = 0.01;
/// Looser relative Frobenius error threshold (5%) used for the synthetic uniform-data
/// round-trip check.
pub const FP8_MAX_FROBENIUS_REL_ERR_SYNTHETIC: f64 = 0.05;
/// Minimum SM version number (90 = `sm_90`, Hopper) that `classify_sm_capability` treats
/// as FP8-capable; the comparison is inclusive.
pub const FP8_MIN_HOPPER_SM: u32 = 90;
/// Exit code stored in `CapabilityOutcome::Incapable`.
/// NOTE(review): presumably used as the process exit status by the caller — confirm there.
pub const FP8_INCAPABLE_EXIT_CODE: u8 = 2;
/// Enumerates every positive finite E4M3 value, sorted strictly ascending (126 entries).
///
/// The values are generated binade by binade:
/// * 7 subnormals `m * 2^-9` for `m = 1..=7`;
/// * 14 full normal binades `(1 + m/8) * 2^(e-7)` for `e = 1..=14`, `m = 0..=7`;
/// * the top binade `(1 + m/8) * 2^8` for `m = 0..=6` only, ending at [`E4M3_MAX_FINITE`]
///   (in the OCP E4M3 encoding the `m = 7` pattern of this binade is NaN).
///
/// Zero is not included. A fresh `Vec` is built on every call.
pub fn e4m3_positive_finites() -> Vec<f32> {
    let significand = |m: u32| 1.0 + (m as f32) / 8.0;

    let subnormals = (1..=7u32).map(|m| (m as f32) * E4M3_MIN_POS_SUBNORMAL);
    let normals = (1..=14i32).flat_map(move |e| {
        let binade = 2f32.powi(e - 7);
        (0..=7u32).map(move |m| significand(m) * binade)
    });
    let top_binade = 2f32.powi(8);
    let top = (0..=6u32).map(move |m| significand(m) * top_binade);

    let mut values: Vec<f32> = Vec::with_capacity(126);
    values.extend(subnormals.chain(normals).chain(top));
    // Generation order is already ascending; the sort keeps that guarantee explicit.
    values.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).expect("finite"));
    values
}
/// Rounds `x` to the nearest E4M3 value and returns it as an `f32`.
///
/// * NaN maps to NaN.
/// * Magnitudes above [`E4M3_MAX_FINITE`], including `±inf`, saturate to `±448`.
/// * An exact tie between two neighbouring E4M3 values resolves to the one whose mantissa
///   is even (round-half-to-even, matching IEEE 754 `roundTiesToEven`).
/// * The sign of `x` is kept, so a negative input that rounds to zero yields `-0.0`.
///
/// The sorted value table is built once and cached, so per-element callers such as
/// [`fp8_quantize_dequantize`] do not allocate on every call.
pub fn e4m3_round_to_nearest(x: f32) -> f32 {
    // Index `i` holds the E4M3 value whose 7 magnitude bits (EEEEMMM) equal `i`: 0.0 at
    // index 0 through 448.0 at index 126. The index's low bit is therefore the mantissa LSB
    // used to break ties.
    static CODES: std::sync::OnceLock<Vec<f32>> = std::sync::OnceLock::new();

    if x.is_nan() {
        return f32::NAN;
    }
    let codes = CODES.get_or_init(|| {
        std::iter::once(0.0f32)
            .chain(e4m3_positive_finites())
            .collect()
    });
    let mag = x.abs().min(E4M3_MAX_FINITE);

    // First code >= mag. Since mag <= 448 (the last entry) this is always in range; the
    // `min` only guards against E4M3_MAX_FINITE ever drifting from the table.
    let hi = codes.partition_point(|&c| c < mag).min(codes.len() - 1);
    let rounded = if hi == 0 {
        // Only reachable for mag == 0.
        codes[0]
    } else {
        let lo = hi - 1;
        let (below, above) = (codes[lo], codes[hi]);
        let (d_below, d_above) = (mag - below, above - mag);
        if d_below < d_above {
            below
        } else if d_above < d_below {
            above
        } else if lo % 2 == 0 {
            // Exact tie: prefer the even mantissa.
            below
        } else {
            above
        }
    };

    if x.is_sign_negative() {
        -rounded
    } else {
        rounded
    }
}
/// Per-tensor scale that maps the largest magnitude in `w` onto [`E4M3_MAX_FINITE`].
///
/// Dividing every element by the returned scale puts `max |w_i|` at 448. NaN elements are
/// ignored, because `f32::max` returns the non-NaN operand.
///
/// Returns `1.0` when no usable scale exists: for an empty or all-zero slice, and when
/// `max |w_i| / 448` underflows to zero (absmax below roughly `3e-43`). A zero scale would
/// make the later `x / scale` in the quantizer produce `±inf`/NaN.
pub fn fp8_absmax_scale(w: &[f32]) -> f32 {
    let absmax = w.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    let scale = absmax / E4M3_MAX_FINITE;
    // Covers both the all-zero case and subnormal-absmax underflow.
    if scale == 0.0 {
        1.0
    } else {
        scale
    }
}
/// Simulates an E4M3 round-trip of `w` using a single per-tensor absmax scale.
///
/// Each element is divided by the scale from [`fp8_absmax_scale`], rounded with
/// [`e4m3_round_to_nearest`], and multiplied back by the same scale.
///
/// Returns the dequantized values (same length and order as `w`) together with the scale.
pub fn fp8_quantize_dequantize(w: &[f32]) -> (Vec<f32>, f32) {
    let scale = fp8_absmax_scale(w);
    let mut dequantized = Vec::with_capacity(w.len());
    for &value in w {
        let on_grid = e4m3_round_to_nearest(value / scale);
        dequantized.push(on_grid * scale);
    }
    (dequantized, scale)
}
/// Euclidean (Frobenius) norm of `w`, accumulated in `f64`.
///
/// Each element is widened to `f64` before squaring, so squares of large `f32` values do not
/// overflow. An empty slice has norm zero.
pub fn frobenius_norm(w: &[f32]) -> f64 {
    let sum_of_squares = w
        .iter()
        .map(|&x| f64::from(x))
        .map(|v| v * v)
        .sum::<f64>();
    sum_of_squares.sqrt()
}
/// Relative Frobenius error `||original - reconstructed|| / ||original||`, computed in `f64`.
///
/// Returns `None` when the slices differ in length, or when `original` has zero norm (the
/// ratio is undefined). Non-finite elements are not filtered and may yield a NaN or
/// infinite result.
pub fn relative_frobenius_error(original: &[f32], reconstructed: &[f32]) -> Option<f64> {
    if original.len() != reconstructed.len() {
        return None;
    }
    let denominator = frobenius_norm(original);
    (denominator != 0.0).then(|| {
        let residual_sq = original
            .iter()
            .copied()
            .zip(reconstructed.iter().copied())
            .map(|(a, b)| f64::from(a) - f64::from(b))
            .map(|d| d * d)
            .sum::<f64>();
        residual_sq.sqrt() / denominator
    })
}
/// Verdict produced by `classify_frobenius_error` for a reconstruction versus its original.
#[derive(Debug, Clone, PartialEq)]
pub enum FrobeniusOutcome {
/// Relative Frobenius error is at or below the threshold.
Ok { rel_err: f64 },
/// Relative Frobenius error did not satisfy `rel_err <= threshold`. A NaN error also lands
/// here, because NaN never compares `<=`.
Degraded { rel_err: f64, threshold: f64 },
/// No relative error could be computed: the lengths differ or the original has zero norm.
InvalidInput,
}
/// Classifies how faithfully `reconstructed` matches `original`.
///
/// The relative Frobenius error from [`relative_frobenius_error`] is compared against
/// `threshold`:
/// * `InvalidInput` if the error is undefined (length mismatch or zero-norm original);
/// * `Ok` if `rel_err <= threshold`;
/// * `Degraded` otherwise, which also covers a NaN error.
pub fn classify_frobenius_error(
    original: &[f32],
    reconstructed: &[f32],
    threshold: f64,
) -> FrobeniusOutcome {
    let rel_err = match relative_frobenius_error(original, reconstructed) {
        Some(rel_err) => rel_err,
        None => return FrobeniusOutcome::InvalidInput,
    };
    if rel_err <= threshold {
        FrobeniusOutcome::Ok { rel_err }
    } else {
        FrobeniusOutcome::Degraded { rel_err, threshold }
    }
}
/// Result of `classify_sm_capability`: whether a GPU's SM version supports FP8 quantization.
#[derive(Debug, Clone, PartialEq)]
pub enum CapabilityOutcome {
/// `sm` is at or above the required minimum.
Capable {
sm: u32,
},
/// `sm` is below `required`. `message` is a human-readable explanation and `exit_code`
/// the status recorded for this failure (presumably the process exit code — confirm at caller).
Incapable {
sm: u32,
required: u32,
message: String,
exit_code: u8,
},
}
/// Decides whether an SM version (e.g. `90` for `sm_90`) can run FP8 quantization.
///
/// Any `sm >= FP8_MIN_HOPPER_SM` is `Capable`. Anything lower is `Incapable`, carrying the
/// required minimum, an explanatory message and [`FP8_INCAPABLE_EXIT_CODE`].
pub fn classify_sm_capability(sm: u32) -> CapabilityOutcome {
    if sm < FP8_MIN_HOPPER_SM {
        let message = format!(
            "FP8 quantization requires Hopper/Blackwell (sm_{}+), detected sm_{}",
            FP8_MIN_HOPPER_SM, sm
        );
        return CapabilityOutcome::Incapable {
            sm,
            required: FP8_MIN_HOPPER_SM,
            message,
            exit_code: FP8_INCAPABLE_EXIT_CODE,
        };
    }
    CapabilityOutcome::Capable { sm }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn e4m3_max_finite_is_exactly_448() {
        assert_eq!(E4M3_MAX_FINITE, 448.0);
    }

    #[test]
    fn e4m3_positive_finites_has_exactly_126_entries() {
        assert_eq!(e4m3_positive_finites().len(), 126);
    }

    #[test]
    fn e4m3_positive_finites_are_sorted_ascending_and_unique() {
        let v = e4m3_positive_finites();
        for pair in v.windows(2) {
            assert!(pair[0] < pair[1], "not strictly ascending: {:?}", pair);
        }
    }

    #[test]
    fn e4m3_positive_finites_endpoints_match_spec() {
        let v = e4m3_positive_finites();
        assert_eq!(*v.first().unwrap(), E4M3_MIN_POS_SUBNORMAL);
        assert_eq!(*v.last().unwrap(), E4M3_MAX_FINITE);
    }

    #[test]
    fn e4m3_round_on_exactly_representable_is_identity() {
        for v in [0.0f32, 0.5, 1.0, 2.0, 4.0, 448.0, -448.0, -1.0] {
            assert_eq!(e4m3_round_to_nearest(v), v, "{}", v);
        }
    }

    // Every representable magnitude, and its negation, must be a fixed point of rounding.
    #[test]
    fn e4m3_round_is_identity_on_every_table_entry() {
        for c in e4m3_positive_finites() {
            assert_eq!(e4m3_round_to_nearest(c), c, "{}", c);
            assert_eq!(e4m3_round_to_nearest(-c), -c, "{}", -c);
        }
    }

    #[test]
    fn e4m3_round_clamps_above_max() {
        assert_eq!(e4m3_round_to_nearest(500.0), 448.0);
        assert_eq!(e4m3_round_to_nearest(1.0e6), 448.0);
        assert_eq!(e4m3_round_to_nearest(-1.0e6), -448.0);
    }

    #[test]
    fn e4m3_round_saturates_infinities() {
        assert_eq!(e4m3_round_to_nearest(f32::INFINITY), 448.0);
        assert_eq!(e4m3_round_to_nearest(f32::NEG_INFINITY), -448.0);
    }

    #[test]
    fn e4m3_round_preserves_sign_through_zero() {
        // `-0.0 == 0.0` under IEEE comparison, so the sign bit must be checked explicitly.
        let pos = e4m3_round_to_nearest(0.0);
        let neg = e4m3_round_to_nearest(-0.0);
        assert_eq!(pos, 0.0);
        assert_eq!(neg, 0.0);
        assert!(pos.is_sign_positive(), "+0.0 rounded to {}", pos);
        assert!(neg.is_sign_negative(), "-0.0 lost its sign: {}", neg);
    }

    #[test]
    fn e4m3_round_nan_is_nan() {
        assert!(e4m3_round_to_nearest(f32::NAN).is_nan());
    }

    #[test]
    fn e4m3_round_tiny_value_goes_to_zero_or_smallest_subnormal() {
        let r = e4m3_round_to_nearest(E4M3_MIN_POS_SUBNORMAL / 4.0);
        assert!(r == 0.0 || r == E4M3_MIN_POS_SUBNORMAL, "got {}", r);
    }

    #[test]
    fn absmax_scale_zero_vector_is_one() {
        let w = vec![0.0f32; 8];
        assert_eq!(fp8_absmax_scale(&w), 1.0);
    }

    #[test]
    fn absmax_scale_puts_max_at_448() {
        let w = vec![1.0f32, -2.0, 3.5, -0.5];
        let s = fp8_absmax_scale(&w);
        let scaled_max = w.iter().map(|x| (x / s).abs()).fold(0.0f32, f32::max);
        assert!((scaled_max - 448.0).abs() < 1e-3);
    }

    #[test]
    fn roundtrip_zero_vector_is_zero() {
        let w = vec![0.0f32; 16];
        let (out, _) = fp8_quantize_dequantize(&w);
        assert!(out.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn roundtrip_preserves_sign() {
        let w: Vec<f32> = (0..16).map(|i| ((i as f32) - 7.5) * 0.1).collect();
        let (out, _) = fp8_quantize_dequantize(&w);
        for (a, b) in w.iter().zip(out.iter()) {
            if *a != 0.0 {
                assert_eq!(a.signum(), b.signum(), "sign flip: {} → {}", a, b);
            }
        }
    }

    #[test]
    fn frobenius_norm_zero_vector_is_zero() {
        assert_eq!(frobenius_norm(&[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn frobenius_norm_unit_basis_is_one() {
        let n = frobenius_norm(&[1.0, 0.0, 0.0]);
        assert!((n - 1.0).abs() < 1e-12);
    }

    #[test]
    fn relative_frobenius_error_identical_is_zero() {
        let w = vec![1.0f32, 2.0, 3.0];
        assert_eq!(relative_frobenius_error(&w, &w), Some(0.0));
    }

    #[test]
    fn relative_frobenius_error_length_mismatch_is_none() {
        let a = vec![1.0f32, 2.0];
        let b = vec![1.0f32, 2.0, 3.0];
        assert_eq!(relative_frobenius_error(&a, &b), None);
    }

    #[test]
    fn relative_frobenius_error_zero_original_is_none() {
        let a = vec![0.0f32; 4];
        let b = vec![1e-6f32; 4];
        assert_eq!(relative_frobenius_error(&a, &b), None);
    }

    #[test]
    fn classify_frobenius_under_threshold_is_ok() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![1.001f32, 2.001, 3.001, 4.001];
        match classify_frobenius_error(&a, &b, FP8_MAX_FROBENIUS_REL_ERR) {
            FrobeniusOutcome::Ok { rel_err } => assert!(rel_err <= FP8_MAX_FROBENIUS_REL_ERR),
            other => panic!("expected Ok, got {:?}", other),
        }
    }

    #[test]
    fn classify_frobenius_above_threshold_is_degraded() {
        let a = vec![1.0f32, 2.0, 3.0, 4.0];
        let b = vec![2.0f32, 3.0, 4.0, 5.0];
        match classify_frobenius_error(&a, &b, FP8_MAX_FROBENIUS_REL_ERR) {
            FrobeniusOutcome::Degraded { rel_err, threshold } => {
                assert!(rel_err > threshold);
                assert_eq!(threshold, FP8_MAX_FROBENIUS_REL_ERR);
            }
            other => panic!("expected Degraded, got {:?}", other),
        }
    }

    #[test]
    fn classify_frobenius_length_mismatch_is_invalid() {
        let a = vec![1.0f32, 2.0];
        let b = vec![1.0f32, 2.0, 3.0];
        assert_eq!(
            classify_frobenius_error(&a, &b, FP8_MAX_FROBENIUS_REL_ERR),
            FrobeniusOutcome::InvalidInput
        );
    }

    #[test]
    fn fp8_mean_roundtrip_error_under_synthetic_uniform_bound() {
        // Deterministic 64-bit LCG so the synthetic uniform [-1, 1] data is reproducible.
        let mut state: u64 = 0xA5A5_D00D_1234_5678;
        let mut w = Vec::with_capacity(4096);
        for _ in 0..4096 {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let bits = (state >> 32) as u32;
            let u = (bits as f64) / (u32::MAX as f64);
            w.push(((u * 2.0) - 1.0) as f32);
        }
        let (dequant, _scale) = fp8_quantize_dequantize(&w);
        let outcome = classify_frobenius_error(&w, &dequant, FP8_MAX_FROBENIUS_REL_ERR_SYNTHETIC);
        match outcome {
            FrobeniusOutcome::Ok { rel_err } => {
                assert!(
                    rel_err < FP8_MAX_FROBENIUS_REL_ERR_SYNTHETIC,
                    "FP8 round-trip rel Frobenius {} ≥ synthetic guard {}",
                    rel_err,
                    FP8_MAX_FROBENIUS_REL_ERR_SYNTHETIC
                );
                // A near-zero error would mean no quantization actually happened.
                assert!(
                    rel_err > 1e-4,
                    "suspiciously small error ({}) — test may be broken",
                    rel_err
                );
            }
            other => panic!("expected Ok, got {:?}", other),
        }
    }

    #[test]
    fn sm_90_is_capable() {
        match classify_sm_capability(90) {
            CapabilityOutcome::Capable { sm } => assert_eq!(sm, 90),
            other => panic!("expected Capable, got {:?}", other),
        }
    }

    #[test]
    fn sm_100_blackwell_is_capable() {
        match classify_sm_capability(100) {
            CapabilityOutcome::Capable { .. } => {}
            other => panic!("expected Capable, got {:?}", other),
        }
    }

    #[test]
    fn sm_120_next_gen_is_capable() {
        match classify_sm_capability(120) {
            CapabilityOutcome::Capable { .. } => {}
            other => panic!("expected Capable, got {:?}", other),
        }
    }

    #[test]
    fn sm_80_ampere_is_incapable_exit_2() {
        match classify_sm_capability(80) {
            CapabilityOutcome::Incapable {
                sm,
                required,
                exit_code,
                ..
            } => {
                assert_eq!(sm, 80);
                assert_eq!(required, 90);
                assert_eq!(exit_code, 2);
            }
            other => panic!("expected Incapable, got {:?}", other),
        }
    }

    #[test]
    fn sm_86_ampere_incapable_message_mentions_hopper() {
        match classify_sm_capability(86) {
            CapabilityOutcome::Incapable { message, .. } => {
                let lower = message.to_lowercase();
                assert!(
                    lower.contains("hopper")
                        || lower.contains("sm_90")
                        || lower.contains("capability"),
                    "message must mention hopper/sm_90/capability; got {:?}",
                    message
                );
            }
            other => panic!("expected Incapable, got {:?}", other),
        }
    }

    #[test]
    fn sm_0_unknown_gpu_is_incapable() {
        match classify_sm_capability(0) {
            CapabilityOutcome::Incapable { .. } => {}
            other => panic!("expected Incapable, got {:?}", other),
        }
    }

    #[test]
    fn sm_capability_is_deterministic() {
        let a = classify_sm_capability(75);
        let b = classify_sm_capability(75);
        assert_eq!(a, b);
    }

    #[test]
    fn capability_boundary_is_sm_90_inclusive() {
        assert!(matches!(
            classify_sm_capability(89),
            CapabilityOutcome::Incapable { .. }
        ));
        assert!(matches!(
            classify_sm_capability(90),
            CapabilityOutcome::Capable { .. }
        ));
    }
}