oxicuda_quant/scheme/
fp8.rs1use crate::error::{QuantError, QuantResult};
27
/// The 8-bit floating-point layouts supported by this module.
///
/// Every layout is one sign bit followed by an exponent field and a
/// mantissa field that together occupy the remaining seven bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Fp8Format {
    /// 4 exponent bits, 3 mantissa bits.
    E4M3,
    /// 5 exponent bits, 2 mantissa bits.
    E5M2,
}

impl Fp8Format {
    /// Width of the exponent field in bits (4 for `E4M3`, 5 for `E5M2`).
    #[must_use]
    pub fn exp_bits(self) -> u32 {
        // One bit is the sign; exponent and mantissa share the other seven.
        7 - self.man_bits()
    }

    /// Width of the mantissa field in bits (3 for `E4M3`, 2 for `E5M2`).
    #[must_use]
    pub fn man_bits(self) -> u32 {
        match self {
            Fp8Format::E4M3 => 3,
            Fp8Format::E5M2 => 2,
        }
    }

    /// Exponent bias, `2^(exp_bits - 1) - 1` (7 for `E4M3`, 15 for `E5M2`).
    #[must_use]
    pub fn bias(self) -> i32 {
        let half_range = 1_i32 << (self.exp_bits() - 1);
        half_range - 1
    }

    /// Largest finite magnitude of the format (448 for `E4M3`, 57344 for
    /// `E5M2`).
    #[must_use]
    pub fn max_val(self) -> f32 {
        match self {
            Fp8Format::E4M3 => 448.0,
            Fp8Format::E5M2 => 57344.0,
        }
    }
}
76
77#[derive(Debug, Clone, Copy)]
84pub struct Fp8Codec {
85 pub format: Fp8Format,
87 pub saturate: bool,
89}
90
91impl Fp8Codec {
92 #[must_use]
94 pub fn e4m3() -> Self {
95 Self {
96 format: Fp8Format::E4M3,
97 saturate: true,
98 }
99 }
100
101 #[must_use]
103 pub fn e5m2() -> Self {
104 Self {
105 format: Fp8Format::E5M2,
106 saturate: true,
107 }
108 }
109
110 pub fn encode_f32(&self, v: f32) -> QuantResult<u8> {
116 if !v.is_finite() {
117 return Err(QuantError::NonFiniteFp8(v));
118 }
119 let max = self.format.max_val();
120 let v_sat = v.clamp(-max, max);
121 Ok(self.fp32_to_fp8(v_sat))
122 }
123
124 #[must_use]
126 pub fn decode_f32(&self, b: u8) -> f32 {
127 self.fp8_to_fp32(b)
128 }
129
130 pub fn encode(&self, data: &[f32]) -> QuantResult<Vec<u8>> {
136 data.iter().map(|&v| self.encode_f32(v)).collect()
137 }
138
139 pub fn decode(&self, data: &[u8]) -> Vec<f32> {
141 data.iter().map(|&b| self.decode_f32(b)).collect()
142 }
143
144 pub fn quantization_mse(&self, data: &[f32]) -> QuantResult<f32> {
150 let encoded = self.encode(data)?;
151 let decoded = self.decode(&encoded);
152 let mse = data
153 .iter()
154 .zip(decoded.iter())
155 .map(|(a, b)| (a - b).powi(2))
156 .sum::<f32>()
157 / data.len() as f32;
158 Ok(mse)
159 }
160
161 fn fp32_to_fp8(&self, v: f32) -> u8 {
164 let bits = v.to_bits();
166 let sign = (bits >> 31) as u8;
167 let exp32 = ((bits >> 23) & 0xFF) as i32; let man32 = bits & 0x007F_FFFF; let exp_bits = self.format.exp_bits();
171 let man_bits = self.format.man_bits();
172 let bias8 = self.format.bias();
173
174 if v == 0.0 || v == -0.0 {
175 return sign << 7;
176 }
177
178 let exp_unbiased = exp32 - 127;
180 let exp8_raw = exp_unbiased + bias8;
181
182 let man_shift = 23 - man_bits; if exp8_raw <= 0 {
185 let full_man = (man32 | 0x0080_0000) >> 1; let shift = (1 - exp8_raw) as u32 + man_shift;
190 if shift >= 24 {
191 return sign << 7;
192 } let man8 = (full_man >> shift) as u8 & ((1 << man_bits) - 1);
194 return (sign << 7) | man8;
195 }
196
197 let max_exp = (1 << exp_bits) - 1;
198 if exp8_raw >= max_exp {
199 return match self.format {
201 Fp8Format::E4M3 => (sign << 7) | 0x7E, Fp8Format::E5M2 => (sign << 7) | 0x7B, };
204 }
205
206 let man8 = (man32 >> man_shift) as u8 & ((1 << man_bits) - 1);
207 (sign << 7) | ((exp8_raw as u8) << man_bits) | man8
208 }
209
210 fn fp8_to_fp32(&self, b: u8) -> f32 {
211 let sign = (b >> 7) as u32;
212 let exp_bits = self.format.exp_bits();
213 let man_bits = self.format.man_bits();
214 let bias8 = self.format.bias();
215
216 let exp8 = ((b >> man_bits) & ((1 << exp_bits) - 1)) as u32;
217 let man8 = (b & ((1 << man_bits) - 1)) as u32;
218
219 let all_exp = (1 << exp_bits) - 1;
221 match self.format {
222 Fp8Format::E4M3 => {
223 if exp8 == all_exp as u32 && man8 == (1 << man_bits) - 1 {
224 return f32::NAN; }
226 }
227 Fp8Format::E5M2 => {
228 if exp8 == all_exp as u32 {
229 if man8 == 0 {
230 return if sign == 0 {
231 f32::INFINITY
232 } else {
233 f32::NEG_INFINITY
234 };
235 }
236 return f32::NAN;
237 }
238 }
239 }
240
241 if exp8 == 0 {
243 if man8 == 0 {
244 return f32::from_bits(sign << 31); }
246 let man_shift = 23 - man_bits;
248 let exp32 = (127 + 1 - bias8) as u32;
249 let leading = man_bits - 1 - man8.leading_zeros().min(man_bits - 1);
251 let exp32_adj = exp32.wrapping_sub(leading);
252 let man32 = ((man8 << leading) & ((1 << man_bits) - 1)) << man_shift;
253 return f32::from_bits((sign << 31) | (exp32_adj << 23) | man32);
254 }
255
256 let exp32 = (exp8 as i32 - bias8 + 127) as u32;
258 let man_shift = 23 - man_bits;
259 let man32 = man8 << man_shift;
260 f32::from_bits((sign << 31) | (exp32 << 23) | man32)
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Encodes then decodes each value and checks that the relative error
    /// stays strictly below `tol`.
    fn assert_round_trip_within(c: Fp8Codec, values: &[f32], tol: f32) {
        for v in values.iter().copied() {
            let dec = c.decode_f32(c.encode_f32(v).unwrap());
            let rel_err = (v - dec).abs() / v.abs().max(1e-6);
            assert!(rel_err < tol, "v={v}, dec={dec}, rel_err={rel_err}");
        }
    }

    #[test]
    fn e4m3_format_params() {
        let fmt = Fp8Format::E4M3;
        assert_eq!(fmt.exp_bits(), 4);
        assert_eq!(fmt.man_bits(), 3);
        assert_eq!(fmt.bias(), 7);
        assert_abs_diff_eq!(fmt.max_val(), 448.0, epsilon = 1.0);
    }

    #[test]
    fn e5m2_format_params() {
        let fmt = Fp8Format::E5M2;
        assert_eq!(fmt.exp_bits(), 5);
        assert_eq!(fmt.man_bits(), 2);
        assert_eq!(fmt.bias(), 15);
        assert_abs_diff_eq!(fmt.max_val(), 57344.0, epsilon = 100.0);
    }

    #[test]
    fn e4m3_zero_encodes_to_zero() {
        let codec = Fp8Codec::e4m3();
        // The sign of zero must survive encoding.
        let positive = codec.encode_f32(0.0).unwrap();
        let negative = codec.encode_f32(-0.0).unwrap();
        assert_eq!(positive, 0x00);
        assert_eq!(negative, 0x80);
    }

    #[test]
    fn e4m3_round_trip_basic() {
        assert_round_trip_within(
            Fp8Codec::e4m3(),
            &[1.0_f32, -1.0, 2.0, 0.5, 0.25, -0.25],
            0.15,
        );
    }

    #[test]
    fn e5m2_round_trip_basic() {
        assert_round_trip_within(Fp8Codec::e5m2(), &[1.0_f32, -1.0, 4.0, 16.0, -8.0], 0.25);
    }

    #[test]
    fn e4m3_saturates_large_values() {
        let codec = Fp8Codec::e4m3();
        let dec = codec.decode_f32(codec.encode_f32(1000.0).unwrap());
        assert!(dec <= 448.0 + 1.0, "should saturate, got {dec}");
        assert!(dec > 0.0, "positive saturation should be positive");
    }

    #[test]
    fn nan_input_errors() {
        let codec = Fp8Codec {
            format: Fp8Format::E4M3,
            saturate: false,
        };
        // Non-finite inputs are rejected even with saturation switched off.
        let nan_result = codec.encode_f32(f32::NAN);
        assert!(matches!(nan_result, Err(QuantError::NonFiniteFp8(_))));
        let inf_result = codec.encode_f32(f32::INFINITY);
        assert!(matches!(inf_result, Err(QuantError::NonFiniteFp8(_))));
    }

    #[test]
    fn mse_within_tolerance() {
        // 256 evenly spaced samples covering [-1, 1).
        let samples: Vec<f32> = (0..256).map(|i| (i as f32 / 128.0) - 1.0).collect();
        let mse = Fp8Codec::e4m3().quantization_mse(&samples).unwrap();
        assert!(mse < 0.01, "E4M3 MSE unexpectedly large: {mse}");
    }

    #[test]
    fn batch_encode_decode() {
        let codec = Fp8Codec::e4m3();
        let input = vec![0.0_f32, 1.0, -1.0, 0.5, 2.0, -2.0];
        let encoded = codec.encode(&input).unwrap();
        assert_eq!(encoded.len(), input.len());
        let decoded = codec.decode(&encoded);
        assert_eq!(decoded.len(), input.len());
    }
}