float4/lib.rs
1//! Four-bit floating point types and block formats for Rust.
2//!
3//! This crate provides low-precision floating-point types following the OCP MX specification,
4//! designed for efficient storage and computation in machine learning applications where
5//! extreme quantization is beneficial.
6//!
7//! # Available Types
8//!
9//! - [`F4E2M1`]: 4-bit floating-point with 2 exponent bits and 1 mantissa bit
10//! - [`E8M0`]: 8-bit scale factor representing powers of two (2^-127 to 2^127)
11//! - [`MXFP4Block`]: Block format storing 32 F4E2M1 values with a shared E8M0 scale
12//!
13//! # F4E2M1 Format Details
14//!
15//! The [`F4E2M1`] type implements the E2M1 format with:
16//! - 1 sign bit
17//! - 2 exponent bits
18//! - 1 mantissa bit
19//! - Exponent bias of 1
20//! - Round-to-nearest-even (roundTiesToEven) rounding mode
21//!
22//! This format can represent 16 distinct values ranging from -6.0 to 6.0, including:
23//! - Normal numbers: ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
24//! - Subnormal numbers: ±0.5
25//! - Zero: ±0.0
26//!
27//! # Examples
28//!
29//! Basic usage:
30//!
31//! ```
32//! use float4::F4E2M1;
33//!
34//! // Create from f64
35//! let a = F4E2M1::from_f64(1.5);
36//! assert_eq!(a.to_f64(), 1.5);
37//!
38//! // Create from raw bits
39//! let b = F4E2M1::from_bits(0x3); // 0b0011 = 1.5
40//! assert_eq!(b.to_f64(), 1.5);
41//!
42//! // Values outside representable range saturate
43//! let c = F4E2M1::from_f64(10.0);
44//! assert_eq!(c.to_f64(), 6.0); // Saturates to maximum
45//! ```
46//!
47//! # Rounding Behavior
48//!
49//! The type uses round-to-nearest-even as specified by IEEE 754:
50//!
51//! ```
52//! use float4::F4E2M1;
53//!
54//! // Rounding to nearest
55//! assert_eq!(F4E2M1::from_f64(1.75).to_f64(), 2.0);
56//! assert_eq!(F4E2M1::from_f64(2.25).to_f64(), 2.0);
57//!
58//! // Round-to-even when exactly halfway
59//! assert_eq!(F4E2M1::from_f64(1.25).to_f64(), 1.0); // Rounds to even
60//! assert_eq!(F4E2M1::from_f64(2.5).to_f64(), 2.0); // Rounds to even
61//! ```
62//!
63//! # Special Values
64//!
65//! Unlike standard floating point formats, F4E2M1 has no representation for infinity or NaN.
66//! These values saturate to the maximum representable value:
67//!
68//! ```
69//! use float4::F4E2M1;
70//!
71//! assert_eq!(F4E2M1::from_f64(f64::INFINITY).to_f64(), 6.0);
72//! assert_eq!(F4E2M1::from_f64(f64::NEG_INFINITY).to_f64(), -6.0);
73//! assert_eq!(F4E2M1::from_f64(f64::NAN).to_f64(), 6.0);
74//! ```
75//!
76//! # MXFP4 Block Format
77//!
78//! The [`MXFP4Block`] type provides efficient storage for multiple F4E2M1 values by sharing
79//! a common scale factor:
80//!
81//! ```
82//! use float4::{F4E2M1, E8M0, MXFP4Block};
83//!
84//! // Original f32 data
85//! let data = vec![1.5, -2.0, 0.5, 3.0];
86//!
87//! // Compute scale (rounds up to power of 2)
88//! let scale = E8M0::from_f32_slice(&data);
89//! assert_eq!(scale.to_f64(), 4.0); // 3.0 rounds up to 4.0
90//!
91//! // Quantize values
92//! let mut quantized = [F4E2M1::from_f64(0.0); 32];
93//! for i in 0..data.len() {
94//! quantized[i] = F4E2M1::from_f64(data[i] as f64 / scale.to_f64());
95//! }
96//!
97//! // Pack into block (17 bytes total for 32 values)
98//! let block = MXFP4Block::from_f32_slice(quantized, scale);
99//!
100//! // Convert back
101//! let restored = block.to_f32_array();
102//! // Note: Due to F4E2M1's limited precision, values may be quantized
103//! assert_eq!(restored[0], 2.0); // 1.5/4.0 = 0.375 -> rounds to 0.5 -> 0.5*4.0 = 2.0
104//! assert_eq!(restored[1], -2.0); // -2.0/4.0 = -0.5 is exactly representable
105//! ```
106//!
//! This format achieves roughly 7.5× compression compared to f32 (17 bytes instead of 128 bytes for 32 values), making it ideal for:
108//! - Neural network weight storage
109//! - Activation caching in quantized models
110//! - Memory-bandwidth limited applications
111
112mod block;
113mod cvt;
114mod m8e0;
115
116pub use block::MXFP4Block;
117pub use m8e0::E8M0;
118
/// Four-bit floating point number using the E2M1 layout (2 exponent bits, 1 mantissa bit).
///
/// This is the element type of the OCP MX specification's FP4 format. It is meant for
/// machine learning workloads that trade precision for a very small memory footprint.
/// Each value is stored in one byte, of which only the low nibble carries information.
///
/// # Format
///
/// Bit assignment within the low nibble:
/// - Bit 3: sign (0 = positive, 1 = negative)
/// - Bits 2-1: exponent, biased by 1
/// - Bit 0: mantissa
///
/// # Representable Values
///
/// Exactly these values can be stored:
/// - **Normal numbers**: ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
/// - **Subnormal numbers**: ±0.5
/// - **Zero**: ±0.0
///
/// # Equality and Hashing
///
/// `PartialEq`, `Eq` and `Hash` are derived, so they operate on the stored bit pattern.
/// Consequently `+0.0` (`0x0`) and `-0.0` (`0x8`) compare as *different* values, unlike
/// the primitive float types.
///
/// # Examples
///
/// ```
/// use float4::F4E2M1;
///
/// // Convert from an ordinary float; the input is rounded to a representable value
/// let x = F4E2M1::from_f64(2.5);
/// assert_eq!(x.to_f64(), 2.0);
///
/// // Inspect the underlying bit pattern
/// let bits = x.to_bits();
/// assert_eq!(bits, 0x4); // 0b0100 = +2.0
/// ```
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F4E2M1(u8);

// Compile-time guard: the wrapper must occupy exactly one byte.
const _: () = assert!(std::mem::size_of::<F4E2M1>() == 1);
157
158impl F4E2M1 {
159 /// Creates a new `F4E2M1` value from a 64-bit floating point number.
160 ///
161 /// This function converts the input to the nearest representable F4E2M1 value
162 /// using round-to-nearest-even. Values outside the
163 /// representable range will saturate to the maximum or minimum values.
164 ///
165 /// # Examples
166 ///
167 /// ```
168 /// use float4::F4E2M1;
169 ///
170 /// // Exact representable values
171 /// assert_eq!(F4E2M1::from_f64(2.0).to_f64(), 2.0);
172 /// assert_eq!(F4E2M1::from_f64(-3.0).to_f64(), -3.0);
173 ///
174 /// // Rounding
175 /// assert_eq!(F4E2M1::from_f64(2.7).to_f64(), 3.0);
176 /// assert_eq!(F4E2M1::from_f64(1.25).to_f64(), 1.0); // Round to even
177 ///
178 /// // Saturation
179 /// assert_eq!(F4E2M1::from_f64(10.0).to_f64(), 6.0);
180 /// assert_eq!(F4E2M1::from_f64(-10.0).to_f64(), -6.0);
181 /// ```
182 ///
183 /// # Special Values
184 ///
185 /// - `NaN` → 6.0 (maximum positive value)
186 /// - `+Infinity` → 6.0
187 /// - `-Infinity` → -6.0
188 #[inline(always)]
189 pub const fn from_f64(x: f64) -> Self {
190 Self(cvt::f64_to_fp4(x))
191 }
192
193 /// Converts this `F4E2M1` value to a 64-bit floating point number.
194 ///
195 /// This conversion is exact - the returned f64 will precisely represent
196 /// the value stored in the F4E2M1.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// use float4::F4E2M1;
202 ///
203 /// let x = F4E2M1::from_f64(1.5);
204 /// assert_eq!(x.to_f64(), 1.5);
205 ///
206 /// // All 16 possible values can be converted
207 /// for i in 0..16 {
208 /// let fp4 = F4E2M1::from_bits(i);
209 /// let _ = fp4.to_f64(); // Always succeeds
210 /// }
211 /// ```
212 #[inline(always)]
213 pub fn to_f64(&self) -> f64 {
214 cvt::fp4_to_f64(self.0)
215 }
216
217 /// Creates a new `F4E2M1` value from its raw 4-bit representation.
218 ///
219 /// The bits are interpreted as:
220 /// - Bit 3: Sign (0 = positive, 1 = negative)
221 /// - Bits 2-1: Exponent (biased by 1)
222 /// - Bit 0: Mantissa
223 ///
224 /// Only the lower 4 bits of the input are used.
225 ///
226 /// # Examples
227 ///
228 /// ```
229 /// use float4::F4E2M1;
230 ///
231 /// // 0x0 = 0b0000 = +0.0
232 /// assert_eq!(F4E2M1::from_bits(0x0).to_f64(), 0.0);
233 ///
234 /// // 0x3 = 0b0011 = +1.5
235 /// assert_eq!(F4E2M1::from_bits(0x3).to_f64(), 1.5);
236 ///
237 /// // 0xF = 0b1111 = -6.0
238 /// assert_eq!(F4E2M1::from_bits(0xF).to_f64(), -6.0);
239 /// ```
240 ///
241 /// # Bit Patterns
242 ///
243 /// | Bits | Decimal | Value |
244 /// |------|---------|-------|
245 /// | 0000 | 0 | 0.0 |
246 /// | 0001 | 1 | 0.5 |
247 /// | 0010 | 2 | 1.0 |
248 /// | 0011 | 3 | 1.5 |
249 /// | 0100 | 4 | 2.0 |
250 /// | 0101 | 5 | 3.0 |
251 /// | 0110 | 6 | 4.0 |
252 /// | 0111 | 7 | 6.0 |
253 /// | 1000 | 8 | -0.0 |
254 /// | 1001 | 9 | -0.5 |
255 /// | 1010 | 10 | -1.0 |
256 /// | 1011 | 11 | -1.5 |
257 /// | 1100 | 12 | -2.0 |
258 /// | 1101 | 13 | -3.0 |
259 /// | 1110 | 14 | -4.0 |
260 /// | 1111 | 15 | -6.0 |
261 #[inline(always)]
262 pub const fn from_bits(bits: u8) -> Self {
263 Self(bits)
264 }
265
266 /// Returns the raw 4-bit representation of this `F4E2M1` value.
267 ///
268 /// The returned byte contains the 4-bit value in its lower nibble.
269 /// The upper 4 bits are always zero.
270 ///
271 /// # Examples
272 ///
273 /// ```
274 /// use float4::F4E2M1;
275 ///
276 /// let x = F4E2M1::from_f64(1.5);
277 /// assert_eq!(x.to_bits(), 0x3); // 0b0011
278 ///
279 /// let y = F4E2M1::from_f64(-2.0);
280 /// assert_eq!(y.to_bits(), 0xC); // 0b1100
281 /// ```
282 #[inline(always)]
283 pub const fn to_bits(&self) -> u8 {
284 self.0
285 }
286}
287
288impl F4E2M1 {
289 /// The smallest positive normal F4E2M1 value (1.0).
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// use float4::F4E2M1;
295 /// assert_eq!(F4E2M1::MIN_POSITIVE_NORMAL.to_f64(), 1.0);
296 /// ```
297 pub const MIN_POSITIVE_NORMAL: F4E2M1 = F4E2M1(0x2);
298
299 /// The smallest positive F4E2M1 value (0.5).
300 ///
301 /// # Examples
302 ///
303 /// ```
304 /// use float4::F4E2M1;
305 /// assert_eq!(F4E2M1::MIN_POSITIVE.to_f64(), 0.5);
306 /// ```
307 pub const MIN_POSITIVE: F4E2M1 = F4E2M1(0x1);
308
309 /// The largest F4E2M1 value (6.0).
310 ///
311 /// # Examples
312 ///
313 /// ```
314 /// use float4::F4E2M1;
315 /// assert_eq!(F4E2M1::MAX.to_f64(), 6.0);
316 /// ```
317 pub const MAX: F4E2M1 = F4E2M1(0x7);
318
319 /// The smallest (most negative) F4E2M1 value (-6.0).
320 ///
321 /// # Examples
322 ///
323 /// ```
324 /// use float4::F4E2M1;
325 /// assert_eq!(F4E2M1::MIN.to_f64(), -6.0);
326 /// ```
327 pub const MIN: F4E2M1 = F4E2M1(0xF);
328
329 /// Positive zero.
330 ///
331 /// # Examples
332 ///
333 /// ```
334 /// use float4::F4E2M1;
335 /// assert_eq!(F4E2M1::ZERO.to_f64(), 0.0);
336 /// ```
337 pub const ZERO: F4E2M1 = F4E2M1(0x0);
338
339 /// Negative zero.
340 ///
341 /// # Examples
342 ///
343 /// ```
344 /// use float4::F4E2M1;
345 /// assert_eq!(F4E2M1::NEG_ZERO.to_f64(), -0.0);
346 /// ```
347 pub const NEG_ZERO: F4E2M1 = F4E2M1(0x8);
348
349 /// One.
350 ///
351 /// # Examples
352 ///
353 /// ```
354 /// use float4::F4E2M1;
355 /// assert_eq!(F4E2M1::ONE.to_f64(), 1.0);
356 /// ```
357 pub const ONE: F4E2M1 = F4E2M1(0x2);
358
359 /// Negative one.
360 ///
361 /// # Examples
362 ///
363 /// ```
364 /// use float4::F4E2M1;
365 /// assert_eq!(F4E2M1::NEG_ONE.to_f64(), -1.0);
366 /// ```
367 pub const NEG_ONE: F4E2M1 = F4E2M1(0xA);
368
369 /// The machine epsilon for F4E2M1 (0.5).
370 ///
371 /// This is the difference between 1.0 and the next representable value.
372 ///
373 /// # Examples
374 ///
375 /// ```
376 /// use float4::F4E2M1;
377 /// assert_eq!(F4E2M1::EPSILON.to_f64(), 0.5);
378 /// ```
379 pub const EPSILON: F4E2M1 = F4E2M1(0x1);
380}
381
382impl Default for F4E2M1 {
383 /// Returns the default value of 0.0.
384 ///
385 /// # Examples
386 ///
387 /// ```
388 /// use float4::F4E2M1;
389 /// assert_eq!(F4E2M1::default().to_f64(), 0.0);
390 /// ```
391 #[inline]
392 fn default() -> Self {
393 F4E2M1::ZERO
394 }
395}
396
397impl From<f32> for F4E2M1 {
398 /// Converts a 32-bit float to F4E2M1.
399 ///
400 /// This is equivalent to converting via f64.
401 ///
402 /// # Examples
403 ///
404 /// ```
405 /// use float4::F4E2M1;
406 ///
407 /// let x: F4E2M1 = 2.5f32.into();
408 /// assert_eq!(x.to_f64(), 2.0); // Rounded to nearest
409 /// ```
410 #[inline]
411 fn from(value: f32) -> Self {
412 F4E2M1::from_f64(value as f64)
413 }
414}
415
416impl From<F4E2M1> for f32 {
417 /// Converts F4E2M1 to a 32-bit float.
418 ///
419 /// # Examples
420 ///
421 /// ```
422 /// use float4::F4E2M1;
423 ///
424 /// let x = F4E2M1::from_f64(1.5);
425 /// let y: f32 = x.into();
426 /// assert_eq!(y, 1.5);
427 /// ```
428 #[inline]
429 fn from(value: F4E2M1) -> Self {
430 value.to_f64() as f32
431 }
432}
433
434impl From<F4E2M1> for f64 {
435 /// Converts F4E2M1 to a 64-bit float.
436 ///
437 /// # Examples
438 ///
439 /// ```
440 /// use float4::F4E2M1;
441 ///
442 /// let x = F4E2M1::from_f64(3.0);
443 /// let y: f64 = x.into();
444 /// assert_eq!(y, 3.0);
445 /// ```
446 #[inline]
447 fn from(value: F4E2M1) -> Self {
448 value.to_f64()
449 }
450}
451
452impl std::fmt::Display for F4E2M1 {
453 /// Formats the F4E2M1 value for display.
454 ///
455 /// # Examples
456 ///
457 /// ```
458 /// use float4::F4E2M1;
459 ///
460 /// let x = F4E2M1::from_f64(1.5);
461 /// assert_eq!(format!("{}", x), "1.5");
462 /// ```
463 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464 write!(f, "{}", self.to_f64())
465 }
466}
467
468impl std::fmt::LowerExp for F4E2M1 {
469 /// Formats the F4E2M1 value in scientific notation.
470 ///
471 /// # Examples
472 ///
473 /// ```
474 /// use float4::F4E2M1;
475 ///
476 /// let x = F4E2M1::from_f64(6.0);
477 /// assert_eq!(format!("{:e}", x), "6e0");
478 /// ```
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 write!(f, "{:e}", self.to_f64())
481 }
482}
483
484impl std::fmt::UpperExp for F4E2M1 {
485 /// Formats the F4E2M1 value in scientific notation with uppercase E.
486 ///
487 /// # Examples
488 ///
489 /// ```
490 /// use float4::F4E2M1;
491 ///
492 /// let x = F4E2M1::from_f64(6.0);
493 /// assert_eq!(format!("{:E}", x), "6E0");
494 /// ```
495 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496 write!(f, "{:E}", self.to_f64())
497 }
498}
499
#[cfg(test)]
mod test {
    use crate::F4E2M1;

    /// Decodes each of the 16 bit patterns and checks it against the E2M1 value table
    /// (bias = 1; bit 3 is the sign, so 0x8..=0xF mirror 0x0..=0x7 with a minus sign).
    ///
    /// Note: float `==` treats `-0.0` and `0.0` as equal, so the 0x8 entry checks the
    /// magnitude of negative zero but not its sign.
    #[test]
    fn test_full_range() {
        // Index into this table == raw bit pattern.
        let expected_values = [
            // 0x0..=0x7: non-negative half (0x1 is the only subnormal)
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            // 0x8..=0xF: sign bit set (0x9 is the only subnormal)
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ];

        for (expected, bits) in expected_values.iter().zip(0u8..) {
            // Path 1: the public constructor.
            let converted = F4E2M1::from_bits(bits).to_f64();
            assert_eq!(
                converted, *expected,
                "Failed for bits 0x{bits:X}: got {converted}, expected {expected}"
            );

            // Path 2: building the tuple struct directly.
            let fp4 = F4E2M1(bits);
            assert_eq!(
                fp4.to_f64(),
                *expected,
                "Failed for F4E2M1(0x{:X}): got {}, expected {}",
                bits,
                fp4.to_f64(),
                expected
            );
        }
    }

    /// Every exactly representable value must survive an encode/decode round trip.
    #[test]
    fn test_roundtrip() {
        let test_values = [
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ];

        test_values.iter().for_each(|&x| {
            let roundtrip = F4E2M1::from_f64(x).to_f64();
            assert_eq!(roundtrip, x, "Roundtrip failed for {x}: got {roundtrip}");
        });
    }

    /// Inputs that fall between representable values must round to the nearest one,
    /// with exact ties going to the neighbour whose mantissa bit is 0 (round-to-even);
    /// inputs beyond ±6.0 must saturate.
    #[test]
    fn test_rounding() {
        // (input, expected value after conversion)
        let test_cases = [
            // Positive inputs
            (0.75, 1.0),  // tie between 0.5 and 1.0 -> even
            (1.25, 1.0),  // tie between 1.0 and 1.5 -> even
            (1.75, 2.0),  // tie between 1.5 and 2.0 -> even
            (2.25, 2.0),  // nearest
            (2.5, 2.0),   // tie between 2.0 and 3.0 -> even
            (2.75, 3.0),  // nearest
            (3.25, 3.0),  // nearest
            (3.5, 4.0),   // tie between 3.0 and 4.0 -> even
            (4.5, 4.0),   // nearest
            (5.0, 4.0),   // tie between 4.0 and 6.0 -> even
            (5.5, 6.0),   // nearest
            (7.0, 6.0),   // saturates to max
            (10.0, 6.0),  // saturates to max
            // Negative inputs mirror the positive ones
            (-0.75, -1.0),
            (-1.25, -1.0),
            (-1.75, -2.0),
            (-2.25, -2.0),
            (-2.5, -2.0),
            (-2.75, -3.0),
            (-3.25, -3.0),
            (-3.5, -4.0),
            (-4.5, -4.0),
            (-5.0, -4.0),
            (-5.5, -6.0),
            (-7.0, -6.0), // saturates to min
        ];

        for (input, expected) in test_cases.iter().copied() {
            let result = F4E2M1::from_f64(input).to_f64();
            assert_eq!(
                result, expected,
                "Rounding failed for {input}: got {result}, expected {expected}"
            );
        }
    }

    /// Non-finite inputs have no F4E2M1 encoding and must saturate.
    #[test]
    fn test_special_values() {
        // +inf -> largest value
        let from_pos_inf = F4E2M1::from_f64(f64::INFINITY);
        assert_eq!(from_pos_inf.to_f64(), 6.0);

        // -inf -> most negative value
        let from_neg_inf = F4E2M1::from_f64(f64::NEG_INFINITY);
        assert_eq!(from_neg_inf.to_f64(), -6.0);

        // NaN -> positive maximum, per the current conversion routine
        let from_nan = F4E2M1::from_f64(f64::NAN);
        assert_eq!(from_nan.to_f64(), 6.0);
    }
}