float4/
lib.rs

1//! Four-bit floating point types and block formats for Rust.
2//!
3//! This crate provides low-precision floating-point types following the OCP MX specification,
4//! designed for efficient storage and computation in machine learning applications where
5//! extreme quantization is beneficial.
6//!
7//! # Available Types
8//!
9//! - [`F4E2M1`]: 4-bit floating-point with 2 exponent bits and 1 mantissa bit
10//! - [`E8M0`]: 8-bit scale factor representing powers of two (2^-127 to 2^127)
11//! - [`MXFP4Block`]: Block format storing 32 F4E2M1 values with a shared E8M0 scale
12//!
13//! # F4E2M1 Format Details
14//!
15//! The [`F4E2M1`] type implements the E2M1 format with:
16//! - 1 sign bit
17//! - 2 exponent bits  
18//! - 1 mantissa bit
19//! - Exponent bias of 1
20//! - Round-to-nearest-even (roundTiesToEven) rounding mode
21//!
//! The four bits give 16 encodings ranging from -6.0 to 6.0. That is 15 distinct numeric
//! values, because +0.0 and -0.0 compare equal:
//! - Normal numbers: ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
//! - Subnormal numbers: ±0.5
//! - Zero: ±0.0
26//!
27//! # Examples
28//!
29//! Basic usage:
30//!
31//! ```
32//! use float4::F4E2M1;
33//!
34//! // Create from f64
35//! let a = F4E2M1::from_f64(1.5);
36//! assert_eq!(a.to_f64(), 1.5);
37//!
38//! // Create from raw bits
39//! let b = F4E2M1::from_bits(0x3); // 0b0011 = 1.5
40//! assert_eq!(b.to_f64(), 1.5);
41//!
42//! // Values outside representable range saturate
43//! let c = F4E2M1::from_f64(10.0);
44//! assert_eq!(c.to_f64(), 6.0); // Saturates to maximum
45//! ```
46//!
47//! # Rounding Behavior
48//!
49//! The type uses round-to-nearest-even as specified by IEEE 754:
50//!
51//! ```
52//! use float4::F4E2M1;
53//!
//! // Rounding to the nearest representable value
//! assert_eq!(F4E2M1::from_f64(2.25).to_f64(), 2.0); // 2.0 is closer than 3.0
//!
//! // Exact ties round to the neighbour whose mantissa bit is 0 ("even")
//! assert_eq!(F4E2M1::from_f64(1.75).to_f64(), 2.0); // halfway between 1.5 and 2.0
//! assert_eq!(F4E2M1::from_f64(1.25).to_f64(), 1.0); // halfway between 1.0 and 1.5
//! assert_eq!(F4E2M1::from_f64(2.5).to_f64(), 2.0);  // halfway between 2.0 and 3.0
61//! ```
62//!
63//! # Special Values
64//!
//! Unlike standard floating point formats, F4E2M1 has no representation for infinity or NaN.
//! Infinities saturate to the largest-magnitude value of the same sign (±6.0), and NaN is
//! mapped to the maximum positive value (6.0):
67//!
68//! ```
69//! use float4::F4E2M1;
70//!
71//! assert_eq!(F4E2M1::from_f64(f64::INFINITY).to_f64(), 6.0);
72//! assert_eq!(F4E2M1::from_f64(f64::NEG_INFINITY).to_f64(), -6.0);
73//! assert_eq!(F4E2M1::from_f64(f64::NAN).to_f64(), 6.0);
74//! ```
75//!
76//! # MXFP4 Block Format
77//!
78//! The [`MXFP4Block`] type provides efficient storage for multiple F4E2M1 values by sharing
79//! a common scale factor:
80//!
81//! ```
82//! use float4::{F4E2M1, E8M0, MXFP4Block};
83//!
84//! // Original f32 data
85//! let data = vec![1.5, -2.0, 0.5, 3.0];
86//!
87//! // Compute scale (rounds up to power of 2)
88//! let scale = E8M0::from_f32_slice(&data);
89//! assert_eq!(scale.to_f64(), 4.0); // 3.0 rounds up to 4.0
90//!
91//! // Quantize values
92//! let mut quantized = [F4E2M1::from_f64(0.0); 32];
93//! for i in 0..data.len() {
94//!     quantized[i] = F4E2M1::from_f64(data[i] as f64 / scale.to_f64());
95//! }
96//!
97//! // Pack into block (17 bytes total for 32 values)
98//! let block = MXFP4Block::from_f32_slice(quantized, scale);
99//!
100//! // Convert back
101//! let restored = block.to_f32_array();
102//! // Note: Due to F4E2M1's limited precision, values may be quantized
103//! assert_eq!(restored[0], 2.0);  // 1.5/4.0 = 0.375 -> rounds to 0.5 -> 0.5*4.0 = 2.0
104//! assert_eq!(restored[1], -2.0); // -2.0/4.0 = -0.5 is exactly representable
105//! ```
106//!
//! Each block stores 32 values in 17 bytes instead of the 128 bytes they would occupy as
//! `f32` (roughly 7.5× smaller), making it well suited to:
108//! - Neural network weight storage
109//! - Activation caching in quantized models
110//! - Memory-bandwidth limited applications
111
112mod block;
113mod cvt;
114mod m8e0;
115
116pub use block::MXFP4Block;
117pub use m8e0::E8M0;
118
/// Four-bit floating point number in the OCP MX "E2M1" layout: one sign bit,
/// two exponent bits (bias 1) and one mantissa bit.
///
/// Intended for machine-learning workloads that need extreme quantization.
///
/// # Format
///
/// Only the low nibble of the wrapped byte carries the encoding:
///
/// | bit 3 | bits 2..=1          | bit 0    |
/// |-------|---------------------|----------|
/// | sign  | exponent (biased 1) | mantissa |
///
/// A set sign bit means the value is negative.
///
/// # Representable Values
///
/// - **Normal numbers**: ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0
/// - **Subnormal numbers**: ±0.5
/// - **Zero**: ±0.0
///
/// # Equality
///
/// `PartialEq`, `Eq` and `Hash` are derived, so they compare the raw bit
/// pattern. `+0.0` (`0x0`) and `-0.0` (`0x8`) are therefore *not* equal as
/// `F4E2M1` values, even though they compare equal once converted to `f64`.
///
/// # Examples
///
/// ```
/// use float4::F4E2M1;
///
/// // 2.5 lies halfway between 2.0 and 3.0 and rounds to the even encoding.
/// let x = F4E2M1::from_f64(2.5);
/// assert_eq!(x.to_f64(), 2.0);
///
/// // The raw encoding of +2.0 is 0b0100.
/// assert_eq!(x.to_bits(), 0x4);
/// ```
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct F4E2M1(u8);

// Compile-time guarantee that the newtype is exactly one byte wide.
const _: () = assert!(std::mem::size_of::<F4E2M1>() == 1);
157
158impl F4E2M1 {
159    /// Creates a new `F4E2M1` value from a 64-bit floating point number.
160    ///
161    /// This function converts the input to the nearest representable F4E2M1 value
162    /// using round-to-nearest-even. Values outside the
163    /// representable range will saturate to the maximum or minimum values.
164    ///
165    /// # Examples
166    ///
167    /// ```
168    /// use float4::F4E2M1;
169    ///
170    /// // Exact representable values
171    /// assert_eq!(F4E2M1::from_f64(2.0).to_f64(), 2.0);
172    /// assert_eq!(F4E2M1::from_f64(-3.0).to_f64(), -3.0);
173    ///
174    /// // Rounding
175    /// assert_eq!(F4E2M1::from_f64(2.7).to_f64(), 3.0);
176    /// assert_eq!(F4E2M1::from_f64(1.25).to_f64(), 1.0); // Round to even
177    ///
178    /// // Saturation
179    /// assert_eq!(F4E2M1::from_f64(10.0).to_f64(), 6.0);
180    /// assert_eq!(F4E2M1::from_f64(-10.0).to_f64(), -6.0);
181    /// ```
182    ///
183    /// # Special Values
184    ///
185    /// - `NaN` → 6.0 (maximum positive value)
186    /// - `+Infinity` → 6.0  
187    /// - `-Infinity` → -6.0
188    #[inline(always)]
189    pub const fn from_f64(x: f64) -> Self {
190        Self(cvt::f64_to_fp4(x))
191    }
192
193    /// Converts this `F4E2M1` value to a 64-bit floating point number.
194    ///
195    /// This conversion is exact - the returned f64 will precisely represent
196    /// the value stored in the F4E2M1.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use float4::F4E2M1;
202    ///
203    /// let x = F4E2M1::from_f64(1.5);
204    /// assert_eq!(x.to_f64(), 1.5);
205    ///
206    /// // All 16 possible values can be converted
207    /// for i in 0..16 {
208    ///     let fp4 = F4E2M1::from_bits(i);
209    ///     let _ = fp4.to_f64(); // Always succeeds
210    /// }
211    /// ```
212    #[inline(always)]
213    pub fn to_f64(&self) -> f64 {
214        cvt::fp4_to_f64(self.0)
215    }
216
217    /// Creates a new `F4E2M1` value from its raw 4-bit representation.
218    ///
219    /// The bits are interpreted as:
220    /// - Bit 3: Sign (0 = positive, 1 = negative)
221    /// - Bits 2-1: Exponent (biased by 1)
222    /// - Bit 0: Mantissa
223    ///
224    /// Only the lower 4 bits of the input are used.
225    ///
226    /// # Examples
227    ///
228    /// ```
229    /// use float4::F4E2M1;
230    ///
231    /// // 0x0 = 0b0000 = +0.0
232    /// assert_eq!(F4E2M1::from_bits(0x0).to_f64(), 0.0);
233    ///
234    /// // 0x3 = 0b0011 = +1.5
235    /// assert_eq!(F4E2M1::from_bits(0x3).to_f64(), 1.5);
236    ///
237    /// // 0xF = 0b1111 = -6.0
238    /// assert_eq!(F4E2M1::from_bits(0xF).to_f64(), -6.0);
239    /// ```
240    ///
241    /// # Bit Patterns
242    ///
243    /// | Bits | Decimal | Value |
244    /// |------|---------|-------|
245    /// | 0000 |    0    |  0.0  |
246    /// | 0001 |    1    |  0.5  |
247    /// | 0010 |    2    |  1.0  |
248    /// | 0011 |    3    |  1.5  |
249    /// | 0100 |    4    |  2.0  |
250    /// | 0101 |    5    |  3.0  |
251    /// | 0110 |    6    |  4.0  |
252    /// | 0111 |    7    |  6.0  |
253    /// | 1000 |    8    | -0.0  |
254    /// | 1001 |    9    | -0.5  |
255    /// | 1010 |   10    | -1.0  |
256    /// | 1011 |   11    | -1.5  |
257    /// | 1100 |   12    | -2.0  |
258    /// | 1101 |   13    | -3.0  |
259    /// | 1110 |   14    | -4.0  |
260    /// | 1111 |   15    | -6.0  |
261    #[inline(always)]
262    pub const fn from_bits(bits: u8) -> Self {
263        Self(bits)
264    }
265
266    /// Returns the raw 4-bit representation of this `F4E2M1` value.
267    ///
268    /// The returned byte contains the 4-bit value in its lower nibble.
269    /// The upper 4 bits are always zero.
270    ///
271    /// # Examples
272    ///
273    /// ```
274    /// use float4::F4E2M1;
275    ///
276    /// let x = F4E2M1::from_f64(1.5);
277    /// assert_eq!(x.to_bits(), 0x3); // 0b0011
278    ///
279    /// let y = F4E2M1::from_f64(-2.0);
280    /// assert_eq!(y.to_bits(), 0xC); // 0b1100
281    /// ```
282    #[inline(always)]
283    pub const fn to_bits(&self) -> u8 {
284        self.0
285    }
286}
287
288impl F4E2M1 {
289    /// The smallest positive normal F4E2M1 value (1.0).
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// use float4::F4E2M1;
295    /// assert_eq!(F4E2M1::MIN_POSITIVE_NORMAL.to_f64(), 1.0);
296    /// ```
297    pub const MIN_POSITIVE_NORMAL: F4E2M1 = F4E2M1(0x2);
298
299    /// The smallest positive F4E2M1 value (0.5).
300    ///
301    /// # Examples
302    ///
303    /// ```
304    /// use float4::F4E2M1;
305    /// assert_eq!(F4E2M1::MIN_POSITIVE.to_f64(), 0.5);
306    /// ```
307    pub const MIN_POSITIVE: F4E2M1 = F4E2M1(0x1);
308
309    /// The largest F4E2M1 value (6.0).
310    ///
311    /// # Examples
312    ///
313    /// ```
314    /// use float4::F4E2M1;
315    /// assert_eq!(F4E2M1::MAX.to_f64(), 6.0);
316    /// ```
317    pub const MAX: F4E2M1 = F4E2M1(0x7);
318
319    /// The smallest (most negative) F4E2M1 value (-6.0).
320    ///
321    /// # Examples
322    ///
323    /// ```
324    /// use float4::F4E2M1;
325    /// assert_eq!(F4E2M1::MIN.to_f64(), -6.0);
326    /// ```
327    pub const MIN: F4E2M1 = F4E2M1(0xF);
328
329    /// Positive zero.
330    ///
331    /// # Examples
332    ///
333    /// ```
334    /// use float4::F4E2M1;
335    /// assert_eq!(F4E2M1::ZERO.to_f64(), 0.0);
336    /// ```
337    pub const ZERO: F4E2M1 = F4E2M1(0x0);
338
339    /// Negative zero.
340    ///
341    /// # Examples
342    ///
343    /// ```
344    /// use float4::F4E2M1;
345    /// assert_eq!(F4E2M1::NEG_ZERO.to_f64(), -0.0);
346    /// ```
347    pub const NEG_ZERO: F4E2M1 = F4E2M1(0x8);
348
349    /// One.
350    ///
351    /// # Examples
352    ///
353    /// ```
354    /// use float4::F4E2M1;
355    /// assert_eq!(F4E2M1::ONE.to_f64(), 1.0);
356    /// ```
357    pub const ONE: F4E2M1 = F4E2M1(0x2);
358
359    /// Negative one.
360    ///
361    /// # Examples
362    ///
363    /// ```
364    /// use float4::F4E2M1;
365    /// assert_eq!(F4E2M1::NEG_ONE.to_f64(), -1.0);
366    /// ```
367    pub const NEG_ONE: F4E2M1 = F4E2M1(0xA);
368
369    /// The machine epsilon for F4E2M1 (0.5).
370    ///
371    /// This is the difference between 1.0 and the next representable value.
372    ///
373    /// # Examples
374    ///
375    /// ```
376    /// use float4::F4E2M1;
377    /// assert_eq!(F4E2M1::EPSILON.to_f64(), 0.5);
378    /// ```
379    pub const EPSILON: F4E2M1 = F4E2M1(0x1);
380}
381
382impl Default for F4E2M1 {
383    /// Returns the default value of 0.0.
384    ///
385    /// # Examples
386    ///
387    /// ```
388    /// use float4::F4E2M1;
389    /// assert_eq!(F4E2M1::default().to_f64(), 0.0);
390    /// ```
391    #[inline]
392    fn default() -> Self {
393        F4E2M1::ZERO
394    }
395}
396
397impl From<f32> for F4E2M1 {
398    /// Converts a 32-bit float to F4E2M1.
399    ///
400    /// This is equivalent to converting via f64.
401    ///
402    /// # Examples
403    ///
404    /// ```
405    /// use float4::F4E2M1;
406    ///
407    /// let x: F4E2M1 = 2.5f32.into();
408    /// assert_eq!(x.to_f64(), 2.0); // Rounded to nearest
409    /// ```
410    #[inline]
411    fn from(value: f32) -> Self {
412        F4E2M1::from_f64(value as f64)
413    }
414}
415
416impl From<F4E2M1> for f32 {
417    /// Converts F4E2M1 to a 32-bit float.
418    ///
419    /// # Examples
420    ///
421    /// ```
422    /// use float4::F4E2M1;
423    ///
424    /// let x = F4E2M1::from_f64(1.5);
425    /// let y: f32 = x.into();
426    /// assert_eq!(y, 1.5);
427    /// ```
428    #[inline]
429    fn from(value: F4E2M1) -> Self {
430        value.to_f64() as f32
431    }
432}
433
434impl From<F4E2M1> for f64 {
435    /// Converts F4E2M1 to a 64-bit float.
436    ///
437    /// # Examples
438    ///
439    /// ```
440    /// use float4::F4E2M1;
441    ///
442    /// let x = F4E2M1::from_f64(3.0);
443    /// let y: f64 = x.into();
444    /// assert_eq!(y, 3.0);
445    /// ```
446    #[inline]
447    fn from(value: F4E2M1) -> Self {
448        value.to_f64()
449    }
450}
451
452impl std::fmt::Display for F4E2M1 {
453    /// Formats the F4E2M1 value for display.
454    ///
455    /// # Examples
456    ///
457    /// ```
458    /// use float4::F4E2M1;
459    ///
460    /// let x = F4E2M1::from_f64(1.5);
461    /// assert_eq!(format!("{}", x), "1.5");
462    /// ```
463    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
464        write!(f, "{}", self.to_f64())
465    }
466}
467
468impl std::fmt::LowerExp for F4E2M1 {
469    /// Formats the F4E2M1 value in scientific notation.
470    ///
471    /// # Examples
472    ///
473    /// ```
474    /// use float4::F4E2M1;
475    ///
476    /// let x = F4E2M1::from_f64(6.0);
477    /// assert_eq!(format!("{:e}", x), "6e0");
478    /// ```
479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480        write!(f, "{:e}", self.to_f64())
481    }
482}
483
484impl std::fmt::UpperExp for F4E2M1 {
485    /// Formats the F4E2M1 value in scientific notation with uppercase E.
486    ///
487    /// # Examples
488    ///
489    /// ```
490    /// use float4::F4E2M1;
491    ///
492    /// let x = F4E2M1::from_f64(6.0);
493    /// assert_eq!(format!("{:E}", x), "6E0");
494    /// ```
495    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496        write!(f, "{:E}", self.to_f64())
497    }
498}
499
#[cfg(test)]
mod test {
    use crate::F4E2M1;

    #[test]
    fn test_full_range() {
        // All 16 encodings (0x0..=0xF) of E2M1 with exponent bias 1.
        // Bit 3 is the sign, so entries 0x8..=0xF mirror 0x0..=0x7 with the
        // sign flipped. 0x1 and 0x9 are the subnormals ±0.5.
        let expected_values = [
            0.0,  // 0x0
            0.5,  // 0x1 (subnormal)
            1.0,  // 0x2
            1.5,  // 0x3
            2.0,  // 0x4
            3.0,  // 0x5
            4.0,  // 0x6
            6.0,  // 0x7
            -0.0, // 0x8
            -0.5, // 0x9 (subnormal)
            -1.0, // 0xA
            -1.5, // 0xB
            -2.0, // 0xC
            -3.0, // 0xD
            -4.0, // 0xE
            -6.0, // 0xF
        ];

        for (bits, &expected) in (0u8..16).zip(expected_values.iter()) {
            let converted = F4E2M1::from_bits(bits).to_f64();
            assert_eq!(
                converted, expected,
                "Failed for bits 0x{bits:X}: got {converted}, expected {expected}"
            );

            // Also check a value built directly through the tuple field.
            let fp4 = F4E2M1(bits);
            assert_eq!(
                fp4.to_f64(),
                expected,
                "Failed for F4E2M1(0x{:X}): got {}, expected {}",
                bits,
                fp4.to_f64(),
                expected
            );
        }
    }

    #[test]
    fn test_roundtrip() {
        // Every representable value must survive from_f64 -> to_f64 unchanged.
        let test_values = [
            0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
        ];

        for &x in &test_values {
            let mxfp4 = F4E2M1::from_f64(x);
            let roundtrip = mxfp4.to_f64();
            assert_eq!(roundtrip, x, "Roundtrip failed for {x}: got {roundtrip}");
        }
    }

    #[test]
    fn test_rounding() {
        // Round-to-nearest-even: inputs between two representable values go
        // to the closer one. Inputs exactly halfway ("tie") go to the
        // neighbour whose mantissa bit (bit 0) is 0.
        let test_cases = [
            // (input, expected)
            (0.75, 1.0), // tie between 0.5 (0b0001) and 1.0 (0b0010) -> even 1.0
            (1.25, 1.0), // tie between 1.0 (0b0010) and 1.5 (0b0011) -> even 1.0
            (1.75, 2.0), // tie between 1.5 (0b0011) and 2.0 (0b0100) -> even 2.0
            (2.25, 2.0), // nearest
            (2.5, 2.0),  // tie between 2.0 (0b0100) and 3.0 (0b0101) -> even 2.0
            (2.75, 3.0), // nearest
            (3.25, 3.0), // nearest
            (3.5, 4.0),  // tie between 3.0 (0b0101) and 4.0 (0b0110) -> even 4.0
            (4.5, 4.0),  // nearest
            (5.0, 4.0),  // tie between 4.0 (0b0110) and 6.0 (0b0111) -> even 4.0
            (5.5, 6.0),  // nearest
            (7.0, 6.0),  // saturate to max
            (10.0, 6.0), // saturate to max
            // Negative values mirror the positive ones.
            (-0.75, -1.0), // tie -> even
            (-1.25, -1.0), // tie -> even
            (-1.75, -2.0), // tie -> even
            (-2.25, -2.0), // nearest
            (-2.5, -2.0),  // tie -> even
            (-2.75, -3.0), // nearest
            (-3.25, -3.0), // nearest
            (-3.5, -4.0),  // tie -> even
            (-4.5, -4.0),  // nearest
            (-5.0, -4.0),  // tie -> even
            (-5.5, -6.0),  // nearest
            (-7.0, -6.0),  // saturate to min
            (-10.0, -6.0), // saturate to min
        ];

        for &(input, expected) in &test_cases {
            let fp4 = F4E2M1::from_f64(input);
            let result = fp4.to_f64();
            assert_eq!(
                result, expected,
                "Rounding failed for {input}: got {result}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_special_values() {
        // F4E2M1 has no infinity or NaN encodings; check how they are mapped.

        // Positive infinity saturates to the max positive value (6.0).
        let fp4 = F4E2M1::from_f64(f64::INFINITY);
        assert_eq!(fp4.to_f64(), 6.0);

        // Negative infinity saturates to the most negative value (-6.0).
        let fp4 = F4E2M1::from_f64(f64::NEG_INFINITY);
        assert_eq!(fp4.to_f64(), -6.0);

        // NaN maps to the max positive value (6.0), as documented on from_f64.
        let fp4 = F4E2M1::from_f64(f64::NAN);
        assert_eq!(fp4.to_f64(), 6.0);
    }
}