Skip to main content

oxicuda_quant/scheme/
fp8.rs

//! # FP8 Floating-Point Quantization
//!
//! FP8 is an 8-bit floating-point format used by the tensor cores of NVIDIA
//! Hopper and Blackwell GPUs. Two variants are defined by the OCP 8-bit
//! floating-point specification (OFP8):
//!
//! | Format | Sign | Exponent | Mantissa | Range          | Use case |
//! |--------|------|----------|----------|----------------|----------|
//! | E4M3   | 1    | 4        | 3        | ±448           | Weights / forward activations |
//! | E5M2   | 1    | 5        | 2        | ±57344         | Gradients |
//!
//! ## E4M3 Special Values
//!
//! * Exponent all-1s + mantissa all-1s → NaN (there is no ±Inf)
//! * Exponent all-0s → zero / subnormals: value = (mantissa / 2^3) × 2^(-6)
//!
//! ## E5M2 Special Values
//!
//! * Exponent all-1s + mantissa = 00 → ±Inf
//! * Exponent all-1s + mantissa ≠ 00 → NaN
//! * Exponent all-0s → zero / subnormals: value = (mantissa / 2^2) × 2^(-14)
//!
//! ## Encoding
//!
//! We store FP8 values as `u8` bit patterns. Conversion is done in pure Rust
//! via IEEE 754 bit manipulation of `f32`.

26use crate::error::{QuantError, QuantResult};
27
// ─── Format ───────────────────────────────────────────────────────────────────

/// The two OFP8 encodings supported by this module.
///
/// Both spend one bit on the sign; they differ in how the remaining seven
/// bits are split between exponent and mantissa.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Fp8Format {
    /// Sign, 4 exponent bits, 3 mantissa bits; largest finite value 448.
    E4M3,
    /// Sign, 5 exponent bits, 2 mantissa bits; largest finite value 57344.
    E5M2,
}
38
39impl Fp8Format {
40    /// Number of exponent bits.
41    #[must_use]
42    pub fn exp_bits(self) -> u32 {
43        match self {
44            Self::E4M3 => 4,
45            Self::E5M2 => 5,
46        }
47    }
48
49    /// Number of mantissa bits.
50    #[must_use]
51    pub fn man_bits(self) -> u32 {
52        match self {
53            Self::E4M3 => 3,
54            Self::E5M2 => 2,
55        }
56    }
57
58    /// Exponent bias.
59    #[must_use]
60    pub fn bias(self) -> i32 {
61        match self {
62            Self::E4M3 => 7,  // 2^(4-1) - 1
63            Self::E5M2 => 15, // 2^(5-1) - 1
64        }
65    }
66
67    /// Maximum finite representable value.
68    #[must_use]
69    pub fn max_val(self) -> f32 {
70        match self {
71            Self::E4M3 => 448.0,
72            Self::E5M2 => 57344.0,
73        }
74    }
75}
76
// ─── Fp8Codec ─────────────────────────────────────────────────────────────────

/// FP8 encoder/decoder that converts between `f32` values and FP8 bit
/// patterns stored as `u8`.
///
/// When saturating, finite inputs beyond `max_val` are clamped to `±max_val`;
/// NaN and ±Inf inputs are rejected with an error. [`Fp8Codec::e4m3`] and
/// [`Fp8Codec::e5m2`] construct codecs with `saturate = true`.
#[derive(Debug, Clone, Copy)]
pub struct Fp8Codec {
    /// Target FP8 format.
    pub format: Fp8Format,
    /// Intended to choose between clamping out-of-range finite values to
    /// `±max_val` (`true`) and rejecting them (`false`). See
    /// [`Fp8Codec::encode_f32`] for how the flag is actually applied.
    pub saturate: bool,
}
90
91impl Fp8Codec {
92    /// E4M3 codec with saturation.
93    #[must_use]
94    pub fn e4m3() -> Self {
95        Self {
96            format: Fp8Format::E4M3,
97            saturate: true,
98        }
99    }
100
101    /// E5M2 codec with saturation.
102    #[must_use]
103    pub fn e5m2() -> Self {
104        Self {
105            format: Fp8Format::E5M2,
106            saturate: true,
107        }
108    }
109
110    /// Encode a single f32 to FP8 u8.
111    ///
112    /// # Errors
113    ///
114    /// * [`QuantError::NonFiniteFp8`] if `v` is NaN or Inf and `saturate = false`.
115    pub fn encode_f32(&self, v: f32) -> QuantResult<u8> {
116        if !v.is_finite() {
117            return Err(QuantError::NonFiniteFp8(v));
118        }
119        let max = self.format.max_val();
120        let v_sat = v.clamp(-max, max);
121        Ok(self.fp32_to_fp8(v_sat))
122    }
123
124    /// Decode a FP8 u8 to f32.
125    #[must_use]
126    pub fn decode_f32(&self, b: u8) -> f32 {
127        self.fp8_to_fp32(b)
128    }
129
130    /// Encode a slice of f32 to FP8.
131    ///
132    /// # Errors
133    ///
134    /// Propagates [`QuantError::NonFiniteFp8`] on first NaN/Inf when not saturating.
135    pub fn encode(&self, data: &[f32]) -> QuantResult<Vec<u8>> {
136        data.iter().map(|&v| self.encode_f32(v)).collect()
137    }
138
139    /// Decode a slice of FP8 to f32.
140    pub fn decode(&self, data: &[u8]) -> Vec<f32> {
141        data.iter().map(|&b| self.decode_f32(b)).collect()
142    }
143
144    /// Mean squared error of the FP8 round-trip.
145    ///
146    /// # Errors
147    ///
148    /// Propagates [`encode`](Self::encode) errors.
149    pub fn quantization_mse(&self, data: &[f32]) -> QuantResult<f32> {
150        let encoded = self.encode(data)?;
151        let decoded = self.decode(&encoded);
152        let mse = data
153            .iter()
154            .zip(decoded.iter())
155            .map(|(a, b)| (a - b).powi(2))
156            .sum::<f32>()
157            / data.len() as f32;
158        Ok(mse)
159    }
160
161    // ── Private: bit-level encode/decode ─────────────────────────────────────
162
163    fn fp32_to_fp8(&self, v: f32) -> u8 {
164        // Extract f32 components.
165        let bits = v.to_bits();
166        let sign = (bits >> 31) as u8;
167        let exp32 = ((bits >> 23) & 0xFF) as i32; // biased exponent (bias=127)
168        let man32 = bits & 0x007F_FFFF; // 23-bit mantissa
169
170        let exp_bits = self.format.exp_bits();
171        let man_bits = self.format.man_bits();
172        let bias8 = self.format.bias();
173
174        if v == 0.0 || v == -0.0 {
175            return sign << 7;
176        }
177
178        // Re-bias the exponent.
179        let exp_unbiased = exp32 - 127;
180        let exp8_raw = exp_unbiased + bias8;
181
182        let man_shift = 23 - man_bits; // how many mantissa bits to drop
183
184        if exp8_raw <= 0 {
185            // Denormal territory: right-shift mantissa into denormal position.
186            // Value = (1 + man/2^23) * 2^exp_unbiased
187            //       ≈ 2^exp8_raw * man_denorm / 2^man_bits
188            let full_man = (man32 | 0x0080_0000) >> 1; // include implicit 1
189            let shift = (1 - exp8_raw) as u32 + man_shift;
190            if shift >= 24 {
191                return sign << 7;
192            } // underflow → ±0
193            let man8 = (full_man >> shift) as u8 & ((1 << man_bits) - 1);
194            return (sign << 7) | man8;
195        }
196
197        let max_exp = (1 << exp_bits) - 1;
198        if exp8_raw >= max_exp {
199            // Saturate to max finite value (E4M3 has no Inf, E5M2 has Inf).
200            return match self.format {
201                Fp8Format::E4M3 => (sign << 7) | 0x7E, // 01111110 = max finite
202                Fp8Format::E5M2 => (sign << 7) | 0x7B, // 01111011 = max finite
203            };
204        }
205
206        let man8 = (man32 >> man_shift) as u8 & ((1 << man_bits) - 1);
207        (sign << 7) | ((exp8_raw as u8) << man_bits) | man8
208    }
209
210    fn fp8_to_fp32(&self, b: u8) -> f32 {
211        let sign = (b >> 7) as u32;
212        let exp_bits = self.format.exp_bits();
213        let man_bits = self.format.man_bits();
214        let bias8 = self.format.bias();
215
216        let exp8 = ((b >> man_bits) & ((1 << exp_bits) - 1)) as u32;
217        let man8 = (b & ((1 << man_bits) - 1)) as u32;
218
219        // Check for special values.
220        let all_exp = (1 << exp_bits) - 1;
221        match self.format {
222            Fp8Format::E4M3 => {
223                if exp8 == all_exp as u32 && man8 == (1 << man_bits) - 1 {
224                    return f32::NAN; // NaN in E4M3
225                }
226            }
227            Fp8Format::E5M2 => {
228                if exp8 == all_exp as u32 {
229                    if man8 == 0 {
230                        return if sign == 0 {
231                            f32::INFINITY
232                        } else {
233                            f32::NEG_INFINITY
234                        };
235                    }
236                    return f32::NAN;
237                }
238            }
239        }
240
241        // Zero / denormal.
242        if exp8 == 0 {
243            if man8 == 0 {
244                return f32::from_bits(sign << 31); // ±0
245            }
246            // Denormal: value = man8 / 2^man_bits * 2^(1-bias8)
247            let man_shift = 23 - man_bits;
248            let exp32 = (127 + 1 - bias8) as u32;
249            // Find leading bit position in man8.
250            let leading = man_bits - 1 - man8.leading_zeros().min(man_bits - 1);
251            let exp32_adj = exp32.wrapping_sub(leading);
252            let man32 = ((man8 << leading) & ((1 << man_bits) - 1)) << man_shift;
253            return f32::from_bits((sign << 31) | (exp32_adj << 23) | man32);
254        }
255
256        // Normal: re-bias.
257        let exp32 = (exp8 as i32 - bias8 + 127) as u32;
258        let man_shift = 23 - man_bits;
259        let man32 = man8 << man_shift;
260        f32::from_bits((sign << 31) | (exp32 << 23) | man32)
261    }
262}
263
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn e4m3_format_params() {
        assert_eq!(Fp8Format::E4M3.exp_bits(), 4);
        assert_eq!(Fp8Format::E4M3.man_bits(), 3);
        assert_eq!(Fp8Format::E4M3.bias(), 7);
        // 448 is an exact f32 constant, so no tolerance is needed.
        assert_eq!(Fp8Format::E4M3.max_val(), 448.0);
    }

    #[test]
    fn e5m2_format_params() {
        assert_eq!(Fp8Format::E5M2.exp_bits(), 5);
        assert_eq!(Fp8Format::E5M2.man_bits(), 2);
        assert_eq!(Fp8Format::E5M2.bias(), 15);
        assert_eq!(Fp8Format::E5M2.max_val(), 57344.0);
    }

    #[test]
    fn e4m3_zero_encodes_to_zero() {
        let c = Fp8Codec::e4m3();
        assert_eq!(c.encode_f32(0.0).unwrap(), 0x00);
        assert_eq!(c.encode_f32(-0.0).unwrap(), 0x80);
        // Decoding must preserve the sign of zero.
        let pos = c.decode_f32(0x00);
        let neg = c.decode_f32(0x80);
        assert!(pos == 0.0 && pos.is_sign_positive());
        assert!(neg == 0.0 && neg.is_sign_negative());
    }

    #[test]
    fn e4m3_round_trip_basic() {
        // Every value here is exactly representable in E4M3, so the round
        // trip must be lossless.
        let c = Fp8Codec::e4m3();
        for &v in &[1.0_f32, -1.0, 2.0, 0.5, 0.25, -0.25] {
            let dec = c.decode_f32(c.encode_f32(v).unwrap());
            assert_eq!(dec, v, "v={v}, dec={dec}");
        }
    }

    #[test]
    fn e5m2_round_trip_basic() {
        // Powers of two are exactly representable in E5M2.
        let c = Fp8Codec::e5m2();
        for &v in &[1.0_f32, -1.0, 4.0, 16.0, -8.0] {
            let dec = c.decode_f32(c.encode_f32(v).unwrap());
            assert_eq!(dec, v, "v={v}, dec={dec}");
        }
    }

    #[test]
    fn e4m3_saturates_large_values() {
        let c = Fp8Codec::e4m3();
        assert_eq!(c.decode_f32(c.encode_f32(1000.0).unwrap()), 448.0);
        assert_eq!(c.decode_f32(c.encode_f32(-1000.0).unwrap()), -448.0);
    }

    #[test]
    fn e5m2_saturates_large_values() {
        let c = Fp8Codec::e5m2();
        assert_eq!(c.decode_f32(c.encode_f32(1.0e6).unwrap()), 57344.0);
        assert_eq!(c.decode_f32(c.encode_f32(-1.0e6).unwrap()), -57344.0);
    }

    #[test]
    fn special_codes_decode() {
        // E4M3: only S.1111.111 is NaN.
        let e4 = Fp8Codec::e4m3();
        assert!(e4.decode_f32(0x7F).is_nan());
        assert!(e4.decode_f32(0xFF).is_nan());
        // E5M2: all-ones exponent is ±Inf (mantissa 00) or NaN.
        let e5 = Fp8Codec::e5m2();
        assert_eq!(e5.decode_f32(0x7C), f32::INFINITY);
        assert_eq!(e5.decode_f32(0xFC), f32::NEG_INFINITY);
        assert!(e5.decode_f32(0x7D).is_nan());
        assert!(e5.decode_f32(0x7F).is_nan());
    }

    #[test]
    fn nan_input_errors() {
        let c = Fp8Codec {
            format: Fp8Format::E4M3,
            saturate: false,
        };
        assert!(matches!(
            c.encode_f32(f32::NAN),
            Err(QuantError::NonFiniteFp8(_))
        ));
        assert!(matches!(
            c.encode_f32(f32::INFINITY),
            Err(QuantError::NonFiniteFp8(_))
        ));
        assert!(matches!(
            c.encode_f32(f32::NEG_INFINITY),
            Err(QuantError::NonFiniteFp8(_))
        ));
        // The default saturating codec rejects NaN/Inf as well.
        let sat = Fp8Codec::e4m3();
        assert!(matches!(
            sat.encode_f32(f32::NAN),
            Err(QuantError::NonFiniteFp8(_))
        ));
        assert!(matches!(
            sat.encode_f32(f32::INFINITY),
            Err(QuantError::NonFiniteFp8(_))
        ));
    }

    #[test]
    fn mse_within_tolerance() {
        let c = Fp8Codec::e4m3();
        let data: Vec<f32> = (0..256).map(|i| (i as f32 / 128.0) - 1.0).collect();
        let mse = c.quantization_mse(&data).unwrap();
        assert!(mse < 0.01, "E4M3 MSE unexpectedly large: {mse}");
    }

    #[test]
    fn batch_encode_decode() {
        let c = Fp8Codec::e4m3();
        // All values are exactly representable, so decode(encode(x)) == x.
        let data = vec![0.0_f32, 1.0, -1.0, 0.5, 2.0, -2.0];
        let enc = c.encode(&data).unwrap();
        assert_eq!(enc.len(), data.len());
        let dec = c.decode(&enc);
        assert_eq!(dec, data);
        // A single non-finite element fails the whole batch.
        assert!(c.encode(&[1.0, f32::NAN]).is_err());
    }
}