float8/
lib.rs

1//! Eight bit floating point types in Rust.
2//!
3//! This crate provides 2 types:
4//! - [`F8E4M3`]: Sign + 4-bit exponent + 3-bit mantissa. More precise but less dynamic range.
5//! - [`F8E5M2`]: Sign + 5-bit exponent + 2-bit mantissa. Less precise but more dynamic range (same exponent as [`struct@f16`]).
6//!
7//! Generally, this crate is modelled after the [`half`] crate, so it can be
8//! used alongside and with minimal code changes.
9//!
10//! # Serialization
11//!
//! When the `serde` feature is enabled, [`F8E4M3`] and [`F8E5M2`] will be serialized as a newtype of
//! [`u8`] by default. In binary formats this is ideal, as it will generally use just one byte for
//! storage. For string formats like JSON, however, this isn't as useful, and due to design
//! limitations of serde, it's not possible for the default `Serialize` implementation to support
//! different serialization for different formats.
17//!
18//! It is up to the container type of the floats to control how it is serialized. This can
19//! easily be controlled when using the derive macros using `#[serde(serialize_with="")]`
20//! attributes. For both [`F8E4M3`] and [`F8E5M2`], a `serialize_as_f32` and `serialize_as_string` are
21//! provided for use with this attribute.
22//!
23//! Deserialization of both float types supports deserializing from the default serialization,
24//! strings, and `f32`/`f64` values, so no additional work is required.
25//!
26//! # Cargo Features
27//!
28//! This crate supports a number of optional cargo features. None of these features are enabled by
29//! default, even `std`.
30//!
31//! - **`std`** — Enable features that depend on the Rust [`std`] library.
32//!
33//! - **`serde`** — Adds support for the [`serde`] crate by implementing [`Serialize`] and
34//!   [`Deserialize`] traits for both [`F8E4M3`] and [`F8E5M2`].
35//!
36//! - **`num-traits`** — Adds support for the [`num-traits`] crate by implementing [`ToPrimitive`],
37//!   [`FromPrimitive`], [`AsPrimitive`], [`Num`], [`Float`], [`FloatCore`], and [`Bounded`] traits
38//!   for both [`F8E4M3`] and [`F8E5M2`].
39//!
40//! - **`bytemuck`** — Adds support for the [`bytemuck`] crate by implementing [`Zeroable`] and
41//!   [`Pod`] traits for both [`F8E4M3`] and [`F8E5M2`].
42//!
43//! - **`zerocopy`** — Adds support for the [`zerocopy`] crate by implementing [`AsBytes`] and
44//!   [`FromBytes`] traits for both [`F8E4M3`] and [`F8E5M2`].
45//!
46//! - **`rand_distr`** — Adds support for the [`rand_distr`] crate by implementing [`Distribution`]
47//!   and other traits for both [`F8E4M3`] and [`F8E5M2`].
48//!
49//! - **`rkyv`** -- Enable zero-copy deserialization with [`rkyv`] crate.
50//!
51//! [`alloc`]: https://doc.rust-lang.org/alloc/
52//! [`std`]: https://doc.rust-lang.org/std/
53//! [`binary16`]: https://en.wikipedia.org/wiki/Half-precision_floating-point_format
54//! [`bfloat16`]: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
55//! [`serde`]: https://crates.io/crates/serde
56//! [`bytemuck`]: https://crates.io/crates/bytemuck
57//! [`num-traits`]: https://crates.io/crates/num-traits
58//! [`zerocopy`]: https://crates.io/crates/zerocopy
59//! [`rand_distr`]: https://crates.io/crates/rand_distr
60//! [`rkyv`]: https://crates.io/crates/rkyv
61//! [`FromBytes`]: https://docs.rs/zerocopy/latest/zerocopy/trait.FromBytes.html
62//! [`Distribution`]: https://docs.rs/rand/latest/rand/distributions/trait.Distribution.html
63//! [`AsBytes`]: https://docs.rs/zerocopy/0.6.6/zerocopy/trait.AsBytes.html
64//! [`Pod`]: https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html
65//! [`Zeroable`]: https://docs.rs/bytemuck/latest/bytemuck/trait.Zeroable.html
66//! [`Bounded`]: https://docs.rs/num-traits/latest/num_traits/bounds/trait.Bounded.html
67//! [`FloatCore`]: https://docs.rs/num-traits/latest/num_traits/float/trait.FloatCore.html
68//! [`Float`]: https://docs.rs/num-traits/latest/num_traits/float/trait.Float.html
69//! [`Num`]: https://docs.rs/num-traits/latest/num_traits/trait.Num.html
70//! [`AsPrimitive`]: https://docs.rs/num-traits/latest/num_traits/cast/trait.AsPrimitive.html
71//! [`ToPrimitive`]: https://docs.rs/num-traits/latest/num_traits/cast/trait.ToPrimitive.html
72//! [`FromPrimitive`]: https://docs.rs/num-traits/latest/num_traits/cast/trait.FromPrimitive.html
73//! [`Deserialize`]: https://docs.rs/serde/latest/serde/trait.Deserialize.html
74//! [`Serialize`]: https://docs.rs/serde/latest/serde/trait.Serialize.html
75
76#![no_std]
77
78#[cfg(feature = "num-traits")]
79mod num_traits;
80#[cfg(feature = "rand_distr")]
81mod rand_distr;
82
83use core::{
84    cmp::Ordering,
85    f64,
86    fmt::{self, Debug, Display, LowerExp, LowerHex, UpperExp, UpperHex},
87    mem,
88    num::{FpCategory, ParseFloatError},
89    ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, RemAssign, Sub, SubAssign},
90    str::FromStr,
91};
92use half::f16;
93
94#[cfg(feature = "bytemuck")]
95use bytemuck::{Pod, Zeroable};
96#[cfg(feature = "serde")]
97use serde::{Deserialize, Serialize};
98#[cfg(feature = "zerocopy")]
99use zerocopy::{AsBytes, FromBytes};
100
/// Selects which 8-bit floating point layout a conversion routine targets.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Kind {
    /// Sign + 4-bit exponent + 3-bit mantissa.
    E4M3,
    /// Sign + 5-bit exponent + 2-bit mantissa.
    E5M2,
}
106
/// Saturation type: the overflow policy applied when a value is too large for
/// the target 8-bit format. If `NoSat`, allow NaN and inf.
// `NoSat` is kept for parity with the reference conversion even though the
// constructors visible in this file only pass `SatFinite`.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum SaturationType {
    /// Do not saturate: overflow produces NaN (E4M3) or infinity (E5M2).
    NoSat,
    /// Saturate overflow to the largest finite magnitude of the target format.
    #[default]
    SatFinite,
}
115
116// https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp?ref_type=heads#L97
117const fn convert_to_fp8(x: f64, saturate: SaturationType, fp8_interpretation: Kind) -> u8 {
118    // TODO: use x.to_bits() with MSRV 1.83
119    #[allow(unknown_lints, unnecessary_transmutes)]
120    let xbits: u64 = unsafe { mem::transmute::<f64, u64>(x) };
121
122    let (
123        fp8_maxnorm,
124        fp8_mantissa_mask,
125        fp8_exp_bias,
126        fp8_significand_bits,
127        fp8_mindenorm_o2,
128        fp8_overflow_threshold,
129        fp8_minnorm,
130    ) = match fp8_interpretation {
131        Kind::E4M3 => (
132            0x7E_u8,
133            0x7_u8,
134            7_u16,
135            4_u64,
136            0x3F50000000000000_u64,
137            0x407D000000000000_u64,
138            0x3F90000000000000_u64,
139        ),
140        Kind::E5M2 => (
141            0x7B_u8,
142            0x3_u8,
143            15_u16,
144            3_u64,
145            0x3EE0000000000000_u64,
146            0x40EE000000000000_u64 - 1,
147            0x3F10000000000000_u64,
148        ),
149    };
150
151    const DP_INF_BITS: u64 = 0x7FF0000000000000;
152    let fp8_dp_half_ulp: u64 = 1 << (53 - fp8_significand_bits - 1);
153    let sign: u8 = ((xbits >> 63) << 7) as u8;
154    let exp: u8 = ((((xbits >> 52) as u16) & 0x7FF)
155        .wrapping_sub(1023)
156        .wrapping_add(fp8_exp_bias)) as u8;
157    let mantissa: u8 = ((xbits >> (53 - fp8_significand_bits)) & (fp8_mantissa_mask as u64)) as u8;
158    let absx: u64 = xbits & 0x7FFFFFFFFFFFFFFF;
159
160    let res = if absx <= fp8_mindenorm_o2 {
161        // Zero or underflow
162        0
163    } else if absx > DP_INF_BITS {
164        // Preserve NaNs
165        match fp8_interpretation {
166            Kind::E4M3 => 0x7F,
167            Kind::E5M2 => 0x7E | mantissa,
168        }
169    } else if absx > fp8_overflow_threshold {
170        // Saturate
171        match saturate {
172            SaturationType::SatFinite => fp8_maxnorm,
173            SaturationType::NoSat => match fp8_interpretation {
174                Kind::E4M3 => 0x7F, // NaN
175                Kind::E5M2 => 0x7C, // Inf in E5M2
176            },
177        }
178    } else if absx >= fp8_minnorm {
179        // Round, normal range
180        let mut res = (exp << (fp8_significand_bits - 1)) | mantissa;
181
182        // Round off bits and round-to-nearest-even adjustment
183        let round = xbits & ((fp8_dp_half_ulp << 1) - 1);
184        if (round > fp8_dp_half_ulp) || ((round == fp8_dp_half_ulp) && (mantissa & 1 != 0)) {
185            res = res.wrapping_add(1);
186        }
187        res
188    } else {
189        // Denormal numbers
190        let shift = 1_u8.wrapping_sub(exp);
191        let mantissa = mantissa | (1 << (fp8_significand_bits - 1));
192        let mut res = mantissa >> shift;
193
194        // Round off bits and round-to-nearest-even adjustment
195        let round = (xbits | (1 << (53 - 1))) & ((fp8_dp_half_ulp << (shift as u64 + 1)) - 1);
196        if (round > (fp8_dp_half_ulp << shift as u64))
197            || ((round == (fp8_dp_half_ulp << shift as u64)) && (res & 1 != 0))
198        {
199            res = res.wrapping_add(1);
200        }
201        res
202    };
203
204    res | sign
205}
206
207// https://gitlab.com/nvidia/headers/cuda-individual/cudart/-/blob/main/cuda_fp8.hpp?ref_type=heads#L463
208const fn convert_fp8_to_fp16(x: u8, fp8_interpretation: Kind) -> u16 {
209    let mut ur = (x as u16) << 8;
210
211    match fp8_interpretation {
212        Kind::E5M2 => {
213            if (ur & 0x7FFF) > 0x7C00 {
214                // If NaN, return canonical NaN
215                ur = 0x7FFF;
216            }
217        }
218        Kind::E4M3 => {
219            let sign = ur & 0x8000;
220            let mut exponent = ((ur & 0x7800) >> 1).wrapping_add(0x2000);
221            let mut mantissa = (ur & 0x0700) >> 1;
222            let absx = 0x7F & x;
223
224            if absx == 0x7F {
225                // FP16 canonical NaN, discard sign
226                ur = 0x7FFF;
227            } else if exponent == 0x2000 {
228                // Zero or denormal
229                if mantissa != 0 {
230                    // Normalize
231                    mantissa <<= 1;
232                    while (mantissa & 0x0400) == 0 {
233                        mantissa <<= 1;
234                        exponent = exponent.wrapping_sub(0x0400);
235                    }
236                    // Discard implicit leading bit
237                    mantissa &= 0x03FF;
238                } else {
239                    // Zero
240                    exponent = 0;
241                }
242                ur = sign | exponent | mantissa;
243            } else {
244                ur = sign | exponent | mantissa;
245            }
246        }
247    };
248
249    ur
250}
251
/// Eight bit floating point type with 4-bit exponent and 3-bit mantissa.
///
/// The value is stored as its raw `u8` bit pattern; the default value has all
/// bits clear.
#[repr(transparent)]
#[derive(Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    archive(resolver = "F8E4M3Resolver")
)]
#[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))]
#[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))]
pub struct F8E4M3(u8);
264
265impl F8E4M3 {
266    const INTERPRETATION: Kind = Kind::E4M3;
267
268    /// Construct an 8-bit floating point value from the raw bits.
269    pub const fn from_bits(bits: u8) -> Self {
270        Self(bits)
271    }
272
273    /// Return the raw bits.
274    pub const fn to_bits(&self) -> u8 {
275        self.0
276    }
277
278    /// Convert a [`prim@f64`] type into [`F8E4M3`].
279    ///
280    /// This operation is lossy.
281    ///
    /// - If the 64-bit value is too large to fit in 8-bits, ±∞ will result.
283    /// - NaN values are preserved.
284    /// - 64-bit subnormal values are too tiny to be represented in 8-bits and result in ±0.
285    /// - Exponents that underflow the minimum 8-bit exponent will result in 8-bit subnormals or ±0.
286    /// - All other values are truncated and rounded to the nearest representable  8-bit value.
287    pub const fn from_f64(x: f64) -> Self {
288        Self(convert_to_fp8(
289            x,
290            SaturationType::SatFinite,
291            Self::INTERPRETATION,
292        ))
293    }
294
295    /// Convert a [`f32`] type into [`F8E4M3`].
296    ///
297    /// This operation is lossy.
298    ///
    /// - If the 32-bit value is too large to fit in 8-bits, ±∞ will result.
300    /// - NaN values are preserved.
301    /// - 32-bit subnormal values are too tiny to be represented in 8-bits and result in ±0.
302    /// - Exponents that underflow the minimum 8-bit exponent will result in 8-bit subnormals or ±0.
303    /// - All other values are truncated and rounded to the nearest representable  8-bit value.
304    pub const fn from_f32(x: f32) -> Self {
305        Self::from_f64(x as f64)
306    }
307
308    /// Convert this [`F8E4M3`] type into a [`struct@f16`] type.
309    ///
310    /// This operation may be lossy.
311    ///
312    /// - NaN and zero values are preserved.
313    /// - Subnormal values are normalized.
314    /// - Otherwise, the values are mapped to the appropriate 16-bit value.
315    pub const fn to_f16(&self) -> f16 {
316        f16::from_bits(convert_fp8_to_fp16(self.0, Self::INTERPRETATION))
317    }
318
319    /// Convert this [`F8E4M3`] type into a [`f32`] type.
320    ///
321    /// This operation may be lossy.
322    ///
323    /// - NaN and zero values are preserved.
324    /// - Subnormal values are normalized.
    /// - Otherwise, the values are mapped to the appropriate 32-bit value.
326    pub const fn to_f32(&self) -> f32 {
327        self.to_f16().to_f32_const()
328    }
329
330    /// Convert this [`F8E4M3`] type into a [`prim@f64`] type.
331    ///
332    /// This operation may be lossy.
333    ///
334    /// - NaN and zero values are preserved.
335    /// - Subnormal values are normalized.
    /// - Otherwise, the values are mapped to the appropriate 64-bit value.
337    pub const fn to_f64(&self) -> f64 {
338        self.to_f16().to_f64_const()
339    }
340
341    /// Returns the ordering between `self` and `other`.
342    ///
343    /// - negative quiet NaN
344    /// - negative signaling NaN
345    /// - negative infinity
346    /// - negative numbers
347    /// - negative subnormal numbers
348    /// - negative zero
349    /// - positive zero
350    /// - positive subnormal numbers
351    /// - positive numbers
352    /// - positive infinity
353    /// - positive signaling NaN
354    /// - positive quiet NaN.
355    ///
356    /// The ordering established by this function does not always agree with the
357    /// [`PartialOrd`] and [`PartialEq`] implementations. For example,
358    /// they consider negative and positive zero equal, while `total_cmp`
359    /// doesn't.
360    ///
361    /// # Example
362    /// ```
363    /// # use float8::F8E4M3;
364    ///
365    /// let mut v: Vec<F8E4M3> = vec![];
366    /// v.push(F8E4M3::ONE);
367    /// v.push(F8E4M3::INFINITY);
368    /// v.push(F8E4M3::NEG_INFINITY);
369    /// v.push(F8E4M3::NAN);
370    /// v.push(F8E4M3::MAX_SUBNORMAL);
371    /// v.push(-F8E4M3::MAX_SUBNORMAL);
372    /// v.push(F8E4M3::ZERO);
373    /// v.push(F8E4M3::NEG_ZERO);
374    /// v.push(F8E4M3::NEG_ONE);
375    /// v.push(F8E4M3::MIN_POSITIVE);
376    ///
377    /// v.sort_by(|a, b| a.total_cmp(&b));
378    ///
379    /// assert!(v
380    ///     .into_iter()
381    ///     .zip(
382    ///         [
383    ///             F8E4M3::NEG_INFINITY,
384    ///             F8E4M3::NEG_ONE,
385    ///             -F8E4M3::MAX_SUBNORMAL,
386    ///             F8E4M3::NEG_ZERO,
387    ///             F8E4M3::ZERO,
388    ///             F8E4M3::MAX_SUBNORMAL,
389    ///             F8E4M3::MIN_POSITIVE,
390    ///             F8E4M3::ONE,
391    ///             F8E4M3::INFINITY,
392    ///             F8E4M3::NAN
393    ///         ]
394    ///         .iter()
395    ///     )
396    ///     .all(|(a, b)| a.to_bits() == b.to_bits()));
397    /// ```
398    pub fn total_cmp(&self, other: &Self) -> Ordering {
399        let mut left = self.to_bits() as i8;
400        let mut right = other.to_bits() as i8;
401        left ^= (((left >> 7) as u8) >> 1) as i8;
402        right ^= (((right >> 7) as u8) >> 1) as i8;
403        left.cmp(&right)
404    }
405
406    /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a
407    /// positive sign bit and +∞.
408    pub const fn is_sign_positive(&self) -> bool {
409        self.0 & 0x80u8 == 0
410    }
411
412    /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a
413    /// negative sign bit and −∞.
414    pub const fn is_sign_negative(&self) -> bool {
415        self.0 & 0x80u8 != 0
416    }
417
418    /// Returns `true` if this value is NaN and `false` otherwise.
419    ///
420    /// # Examples
421    ///
422    /// ```rust
423    /// # use float8::*;
424    ///
425    /// let nan = F8E4M3::NAN;
426    /// let f = F8E4M3::from_f32(7.0_f32);
427    ///
428    /// assert!(nan.is_nan());
429    /// assert!(!f.is_nan());
430    /// ```
431    pub const fn is_nan(&self) -> bool {
432        self.0 == 0x7Fu8 || self.0 == 0xFFu8
433    }
434
435    /// Returns `true` if this value is ±∞ and `false` otherwise.
436    ///
437    /// # Examples
438    ///
439    /// ```rust
440    /// # use float8::*;
441    ///
442    /// let f = F8E4M3::from_f32(7.0f32);
443    /// let inf = F8E4M3::INFINITY;
444    /// let neg_inf = F8E4M3::NEG_INFINITY;
445    /// let nan = F8E4M3::NAN;
446    ///
447    /// assert!(!f.is_infinite());
448    /// assert!(!nan.is_infinite());
449    ///
450    /// assert!(inf.is_infinite());
451    /// assert!(neg_inf.is_infinite());
452    /// ```
453    pub const fn is_infinite(&self) -> bool {
454        self.0 & 0x7Fu8 == 0x7Eu8
455    }
456
457    /// Returns true if this number is neither infinite nor NaN.
458    ///
459    /// # Examples
460    ///
461    /// ```rust
462    /// # use float8::*;
463    ///
464    /// let f = F8E4M3::from_f32(7.0f32);
465    /// let inf = F8E4M3::INFINITY;
466    /// let neg_inf = F8E4M3::NEG_INFINITY;
467    /// let nan = F8E4M3::NAN;
468    ///
469    /// assert!(f.is_finite());
470    ///
471    /// assert!(!nan.is_finite());
472    /// assert!(!inf.is_finite());
473    /// assert!(!neg_inf.is_finite());
474    /// ```
475    pub const fn is_finite(&self) -> bool {
476        !(self.is_infinite() || self.is_nan())
477    }
478
479    /// Returns `true` if the number is neither zero, infinite, subnormal, or `NaN` and `false` otherwise.
480    ///
481    /// # Examples
482    ///
483    /// ```rust
484    /// # use float8::*;
485    ///
486    /// let min = F8E4M3::MIN_POSITIVE;
487    /// let max = F8E4M3::MAX;
488    /// let lower_than_min = F8E4M3::from_f32(1.0e-10_f32);
489    /// let zero = F8E4M3::from_f32(0.0_f32);
490    ///
491    /// assert!(min.is_normal());
492    /// assert!(max.is_normal());
493    ///
494    /// assert!(!zero.is_normal());
495    /// assert!(!F8E4M3::NAN.is_normal());
496    /// assert!(!F8E4M3::INFINITY.is_normal());
497    /// // Values between `0` and `min` are Subnormal.
498    /// assert!(!lower_than_min.is_normal());
499    /// ```
500    pub const fn is_normal(&self) -> bool {
501        #[allow(clippy::unusual_byte_groupings)]
502        let exp = self.0 & 0b0_1111_000;
503        exp != 0 && self.is_finite()
504    }
505
506    /// Returns the minimum of the two numbers.
507    ///
508    /// If one of the arguments is NaN, then the other argument is returned.
509    ///
510    /// # Examples
511    ///
512    /// ```
513    /// # use float8::*;
514    /// let x = F8E4M3::from_f32(1.0);
515    /// let y = F8E4M3::from_f32(2.0);
516    ///
517    /// assert_eq!(x.min(y), x);
518    /// ```
519    pub fn min(self, other: Self) -> Self {
520        if other < self && !other.is_nan() {
521            other
522        } else {
523            self
524        }
525    }
526
    /// Returns the maximum of the two numbers.
    ///
    /// If one of the arguments is NaN, then the other argument is returned.
    ///
    /// # Examples
    ///
    /// ```
    /// # use float8::*;
    /// let x = F8E4M3::from_f32(1.0);
    /// let y = F8E4M3::from_f32(2.0);
    ///
    /// assert_eq!(x.max(y), y);
    /// ```
540    pub fn max(self, other: Self) -> Self {
541        if other > self && !other.is_nan() {
542            other
543        } else {
544            self
545        }
546    }
547
548    /// Restrict a value to a certain interval unless it is NaN.
549    ///
550    /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`.
551    /// Otherwise this returns `self`.
552    ///
553    /// Note that this function returns NaN if the initial value was NaN as well.
554    ///
555    /// # Panics
556    /// Panics if `min > max`, `min` is NaN, or `max` is NaN.
557    ///
558    /// # Examples
559    ///
560    /// ```
561    /// # use float8::*;
562    /// assert!(F8E4M3::from_f32(-3.0).clamp(F8E4M3::from_f32(-2.0), F8E4M3::from_f32(1.0)) == F8E4M3::from_f32(-2.0));
563    /// assert!(F8E4M3::from_f32(0.0).clamp(F8E4M3::from_f32(-2.0), F8E4M3::from_f32(1.0)) == F8E4M3::from_f32(0.0));
564    /// assert!(F8E4M3::from_f32(2.0).clamp(F8E4M3::from_f32(-2.0), F8E4M3::from_f32(1.0)) == F8E4M3::from_f32(1.0));
565    /// assert!(F8E4M3::NAN.clamp(F8E4M3::from_f32(-2.0), F8E4M3::from_f32(1.0)).is_nan());
566    /// ```
567    pub fn clamp(self, min: Self, max: Self) -> Self {
568        assert!(min <= max);
569        let mut x = self;
570        if x < min {
571            x = min;
572        }
573        if x > max {
574            x = max;
575        }
576        x
577    }
578
579    /// Returns a number composed of the magnitude of `self` and the sign of `sign`.
580    ///
581    /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`.
582    /// If `self` is NaN, then NaN with the sign of `sign` is returned.
583    ///
584    /// # Examples
585    ///
586    /// ```
587    /// # use float8::*;
588    /// let f = F8E4M3::from_f32(3.5);
589    ///
590    /// assert_eq!(f.copysign(F8E4M3::from_f32(0.42)), F8E4M3::from_f32(3.5));
591    /// assert_eq!(f.copysign(F8E4M3::from_f32(-0.42)), F8E4M3::from_f32(-3.5));
592    /// assert_eq!((-f).copysign(F8E4M3::from_f32(0.42)), F8E4M3::from_f32(3.5));
593    /// assert_eq!((-f).copysign(F8E4M3::from_f32(-0.42)), F8E4M3::from_f32(-3.5));
594    ///
595    /// assert!(F8E4M3::NAN.copysign(F8E4M3::from_f32(1.0)).is_nan());
596    /// ```
597    pub const fn copysign(self, sign: Self) -> Self {
598        Self((sign.0 & 0x80u8) | (self.0 & 0x7Fu8))
599    }
600
601    /// Returns a number that represents the sign of `self`.
602    ///
603    /// * `1.0` if the number is positive, `+0.0` or [`INFINITY`][Self::INFINITY]
604    /// * `-1.0` if the number is negative, `-0.0` or [`NEG_INFINITY`][Self::NEG_INFINITY]
605    /// * [`NAN`][Self::NAN] if the number is `NaN`
606    ///
607    /// # Examples
608    ///
609    /// ```rust
610    /// # use float8::*;
611    ///
612    /// let f = F8E4M3::from_f32(3.5_f32);
613    ///
614    /// assert_eq!(f.signum(), F8E4M3::from_f32(1.0));
615    /// assert_eq!(F8E4M3::NEG_INFINITY.signum(), F8E4M3::from_f32(-1.0));
616    ///
617    /// assert!(F8E4M3::NAN.signum().is_nan());
618    /// ```
619    pub const fn signum(self) -> Self {
620        if self.is_nan() {
621            self
622        } else if self.0 & 0x80u8 != 0 {
623            Self::NEG_ONE
624        } else {
625            Self::ONE
626        }
627    }
628
629    /// Returns the floating point category of the number.
630    ///
631    /// If only one property is going to be tested, it is generally faster to use the specific
632    /// predicate instead.
633    ///
634    /// # Examples
635    ///
636    /// ```rust
637    /// use std::num::FpCategory;
638    /// # use float8::*;
639    ///
640    /// let num = F8E4M3::from_f32(12.4_f32);
641    /// let inf = F8E4M3::INFINITY;
642    ///
643    /// assert_eq!(num.classify(), FpCategory::Normal);
644    /// assert_eq!(inf.classify(), FpCategory::Infinite);
645    /// ```
646    pub const fn classify(&self) -> FpCategory {
647        if self.is_infinite() {
648            FpCategory::Infinite
649        } else if !self.is_normal() {
650            FpCategory::Subnormal
651        } else if self.is_nan() {
652            FpCategory::Nan
653        } else if self.0 & 0x7Fu8 == 0 {
654            FpCategory::Zero
655        } else {
656            FpCategory::Normal
657        }
658    }
659}
660
/// Serde visitor that builds an [`F8E4M3`] from a newtype-wrapped `u8`, a
/// string, or an `f32`/`f64` value.
#[cfg(feature = "serde")]
struct VisitorF8E4M3;

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for F8E4M3 {
    fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        deserializer.deserialize_newtype_struct("f8e4m3", VisitorF8E4M3)
    }
}

#[cfg(feature = "serde")]
impl<'de> serde::de::Visitor<'de> for VisitorF8E4M3 {
    type Value = F8E4M3;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("tuple struct f8e4m3")
    }

    /// Default representation: the raw bits as a `u8`.
    fn visit_newtype_struct<D: serde::Deserializer<'de>>(
        self,
        deserializer: D,
    ) -> Result<Self::Value, D::Error> {
        <u8 as Deserialize>::deserialize(deserializer).map(F8E4M3)
    }

    /// String representation: parsed through the type's `FromStr` impl.
    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match v.parse::<F8E4M3>() {
            Ok(value) => Ok(value),
            Err(_) => Err(E::invalid_value(
                serde::de::Unexpected::Str(v),
                &"a float string",
            )),
        }
    }

    /// Numeric representation: converted with `from_f32`.
    fn visit_f32<E: serde::de::Error>(self, v: f32) -> Result<Self::Value, E> {
        let value = F8E4M3::from_f32(v);
        Ok(value)
    }

    /// Numeric representation: converted with `from_f64`.
    fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
        let value = F8E4M3::from_f64(v);
        Ok(value)
    }
}
712
/// Eight bit floating point type with 5-bit exponent and 2-bit mantissa.
///
/// The value is stored as its raw `u8` bit pattern; the default value has all
/// bits clear.
#[repr(transparent)]
#[derive(Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(Serialize))]
#[cfg_attr(
    feature = "rkyv",
    derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize),
    archive(resolver = "F8E5M2Resolver")
)]
#[cfg_attr(feature = "bytemuck", derive(Zeroable, Pod))]
#[cfg_attr(feature = "zerocopy", derive(AsBytes, FromBytes))]
pub struct F8E5M2(u8);
725
726impl F8E5M2 {
727    const INTERPRETATION: Kind = Kind::E5M2;
728
729    /// Construct an 8-bit floating point value from the raw bits.
730    pub const fn from_bits(bits: u8) -> Self {
731        Self(bits)
732    }
733
734    /// Return the raw bits.
735    pub const fn to_bits(&self) -> u8 {
736        self.0
737    }
738
739    /// Convert a [`prim@f64`] type into [`F8E5M2`].
740    ///
741    /// This operation is lossy.
742    ///
    /// - If the 64-bit value is too large to fit in 8-bits, ±∞ will result.
744    /// - NaN values are preserved.
745    /// - 64-bit subnormal values are too tiny to be represented in 8-bits and result in ±0.
746    /// - Exponents that underflow the minimum 8-bit exponent will result in 8-bit subnormals or ±0.
747    /// - All other values are truncated and rounded to the nearest representable  8-bit value.
748    pub const fn from_f64(x: f64) -> Self {
749        Self(convert_to_fp8(
750            x,
751            SaturationType::SatFinite,
752            Self::INTERPRETATION,
753        ))
754    }
755
756    /// Convert a [`f32`] type into [`F8E5M2`].
757    ///
758    /// This operation is lossy.
759    ///
    /// - If the 32-bit value is too large to fit in 8-bits, ±∞ will result.
761    /// - NaN values are preserved.
762    /// - 32-bit subnormal values are too tiny to be represented in 8-bits and result in ±0.
763    /// - Exponents that underflow the minimum 8-bit exponent will result in 8-bit subnormals or ±0.
764    /// - All other values are truncated and rounded to the nearest representable  8-bit value.
765    pub const fn from_f32(x: f32) -> Self {
766        Self::from_f64(x as f64)
767    }
768
769    /// Convert this [`F8E5M2`] type into a [`struct@f16`] type.
770    ///
771    /// This operation may be lossy.
772    ///
773    /// - NaN and zero values are preserved.
774    /// - Subnormal values are normalized.
775    /// - Otherwise, the values are mapped to the appropriate 16-bit value.
776    pub const fn to_f16(&self) -> f16 {
777        f16::from_bits(convert_fp8_to_fp16(self.0, Self::INTERPRETATION))
778    }
779
780    /// Convert this [`F8E5M2`] type into a [`prim@f32`] type.
781    ///
782    /// This operation may be lossy.
783    ///
784    /// - NaN and zero values are preserved.
785    /// - Subnormal values are normalized.
    /// - Otherwise, the values are mapped to the appropriate 32-bit value.
787    pub const fn to_f32(&self) -> f32 {
788        self.to_f16().to_f32_const()
789    }
790
791    /// Convert this [`F8E5M2`] type into a [`prim@f64`] type.
792    ///
793    /// This operation may be lossy.
794    ///
795    /// - NaN and zero values are preserved.
796    /// - Subnormal values are normalized.
    /// - Otherwise, the values are mapped to the appropriate 64-bit value.
798    pub const fn to_f64(&self) -> f64 {
799        self.to_f16().to_f64_const()
800    }
801
802    /// Returns the ordering between `self` and `other`.
803    ///
804    /// - negative quiet NaN
805    /// - negative signaling NaN
806    /// - negative infinity
807    /// - negative numbers
808    /// - negative subnormal numbers
809    /// - negative zero
810    /// - positive zero
811    /// - positive subnormal numbers
812    /// - positive numbers
813    /// - positive infinity
814    /// - positive signaling NaN
815    /// - positive quiet NaN.
816    ///
817    /// The ordering established by this function does not always agree with the
818    /// [`PartialOrd`] and [`PartialEq`] implementations. For example,
819    /// they consider negative and positive zero equal, while `total_cmp`
820    /// doesn't.
821    ///
822    /// # Example
823    /// ```
824    /// # use float8::F8E5M2;
825    ///
826    /// let mut v: Vec<F8E5M2> = vec![];
827    /// v.push(F8E5M2::ONE);
828    /// v.push(F8E5M2::INFINITY);
829    /// v.push(F8E5M2::NEG_INFINITY);
830    /// v.push(F8E5M2::NAN);
831    /// v.push(F8E5M2::MAX_SUBNORMAL);
832    /// v.push(-F8E5M2::MAX_SUBNORMAL);
833    /// v.push(F8E5M2::ZERO);
834    /// v.push(F8E5M2::NEG_ZERO);
835    /// v.push(F8E5M2::NEG_ONE);
836    /// v.push(F8E5M2::MIN_POSITIVE);
837    ///
838    /// v.sort_by(|a, b| a.total_cmp(&b));
839    ///
840    /// assert!(v
841    ///     .into_iter()
842    ///     .zip(
843    ///         [
844    ///             F8E5M2::NEG_INFINITY,
845    ///             F8E5M2::NEG_ONE,
846    ///             -F8E5M2::MAX_SUBNORMAL,
847    ///             F8E5M2::NEG_ZERO,
848    ///             F8E5M2::ZERO,
849    ///             F8E5M2::MAX_SUBNORMAL,
850    ///             F8E5M2::MIN_POSITIVE,
851    ///             F8E5M2::ONE,
852    ///             F8E5M2::INFINITY,
853    ///             F8E5M2::NAN
854    ///         ]
855    ///         .iter()
856    ///     )
857    ///     .all(|(a, b)| a.to_bits() == b.to_bits()));
858    /// ```
859    pub fn total_cmp(&self, other: &Self) -> Ordering {
860        let mut left = self.to_bits() as i8;
861        let mut right = other.to_bits() as i8;
862        left ^= (((left >> 7) as u8) >> 1) as i8;
863        right ^= (((right >> 7) as u8) >> 1) as i8;
864        left.cmp(&right)
865    }
866
867    /// Returns `true` if and only if `self` has a positive sign, including +0.0, NaNs with a
868    /// positive sign bit and +∞.
869    pub const fn is_sign_positive(&self) -> bool {
870        self.0 & 0x80u8 == 0
871    }
872
873    /// Returns `true` if and only if `self` has a negative sign, including −0.0, NaNs with a
874    /// negative sign bit and −∞.
875    pub const fn is_sign_negative(&self) -> bool {
876        self.0 & 0x80u8 != 0
877    }
878
879    /// Returns `true` if this value is NaN and `false` otherwise.
880    ///
881    /// # Examples
882    ///
883    /// ```rust
884    /// # use float8::*;
885    ///
886    /// let nan = F8E5M2::NAN;
887    /// let f = F8E5M2::from_f32(7.0_f32);
888    ///
889    /// assert!(nan.is_nan());
890    /// assert!(!f.is_nan());
891    /// ```
892    pub const fn is_nan(&self) -> bool {
893        self.0 == 0x7Eu8 || self.0 == 0xFEu8
894    }
895
896    /// Returns `true` if this value is ±∞ and `false` otherwise.
897    ///
898    /// # Examples
899    ///
900    /// ```rust
901    /// # use float8::*;
902    ///
903    /// let f = F8E5M2::from_f32(7.0f32);
904    /// let inf = F8E5M2::INFINITY;
905    /// let neg_inf = F8E5M2::NEG_INFINITY;
906    /// let nan = F8E5M2::NAN;
907    ///
908    /// assert!(!f.is_infinite());
909    /// assert!(!nan.is_infinite());
910    ///
911    /// assert!(inf.is_infinite());
912    /// assert!(neg_inf.is_infinite());
913    /// ```
914    pub const fn is_infinite(&self) -> bool {
915        self.0 & 0x7Fu8 == 0x7Bu8
916    }
917
918    /// Returns true if this number is neither infinite nor NaN.
919    ///
920    /// # Examples
921    ///
922    /// ```rust
923    /// # use float8::*;
924    ///
925    /// let f = F8E5M2::from_f32(7.0f32);
926    /// let inf = F8E5M2::INFINITY;
927    /// let neg_inf = F8E5M2::NEG_INFINITY;
928    /// let nan = F8E5M2::NAN;
929    ///
930    /// assert!(f.is_finite());
931    ///
932    /// assert!(!nan.is_finite());
933    /// assert!(!inf.is_finite());
934    /// assert!(!neg_inf.is_finite());
935    /// ```
936    pub const fn is_finite(&self) -> bool {
937        !(self.is_infinite() || self.is_nan())
938    }
939
940    /// Returns `true` if the number is neither zero, infinite, subnormal, or `NaN` and `false` otherwise.
941    ///
942    /// # Examples
943    ///
944    /// ```rust
945    /// # use float8::*;
946    ///
947    /// let min = F8E5M2::MIN_POSITIVE;
948    /// let max = F8E5M2::MAX;
949    /// let lower_than_min = F8E5M2::from_f32(1.0e-10_f32);
950    /// let zero = F8E5M2::from_f32(0.0_f32);
951    ///
952    /// assert!(min.is_normal());
953    /// assert!(max.is_normal());
954    ///
955    /// assert!(!zero.is_normal());
956    /// assert!(!F8E5M2::NAN.is_normal());
957    /// assert!(!F8E5M2::INFINITY.is_normal());
958    /// // Values between `0` and `min` are Subnormal.
959    /// assert!(!lower_than_min.is_normal());
960    /// ```
961    pub const fn is_normal(&self) -> bool {
962        #[allow(clippy::unusual_byte_groupings)]
963        let exp = self.0 & 0b0_11111_00;
964        exp != 0 && self.is_finite()
965    }
966
967    /// Returns the minimum of the two numbers.
968    ///
969    /// If one of the arguments is NaN, then the other argument is returned.
970    ///
971    /// # Examples
972    ///
973    /// ```
974    /// # use float8::*;
975    /// let x = F8E5M2::from_f32(1.0);
976    /// let y = F8E5M2::from_f32(2.0);
977    ///
978    /// assert_eq!(x.min(y), x);
979    /// ```
980    pub fn min(self, other: Self) -> Self {
981        if other < self && !other.is_nan() {
982            other
983        } else {
984            self
985        }
986    }
987
988    /// Returns the minimum of the two numbers.
989    ///
990    /// If one of the arguments is NaN, then the other argument is returned.
991    ///
992    /// # Examples
993    ///
994    /// ```
995    /// # use float8::*;
996    /// let x = F8E5M2::from_f32(1.0);
997    /// let y = F8E5M2::from_f32(2.0);
998    ///
999    /// assert_eq!(x.min(y), x);
1000    /// ```
1001    pub fn max(self, other: Self) -> Self {
1002        if other > self && !other.is_nan() {
1003            other
1004        } else {
1005            self
1006        }
1007    }
1008
1009    /// Restrict a value to a certain interval unless it is NaN.
1010    ///
1011    /// Returns `max` if `self` is greater than `max`, and `min` if `self` is less than `min`.
1012    /// Otherwise this returns `self`.
1013    ///
1014    /// Note that this function returns NaN if the initial value was NaN as well.
1015    ///
1016    /// # Panics
1017    /// Panics if `min > max`, `min` is NaN, or `max` is NaN.
1018    ///
1019    /// # Examples
1020    ///
1021    /// ```
1022    /// # use float8::*;
1023    /// assert!(F8E5M2::from_f32(-3.0).clamp(F8E5M2::from_f32(-2.0), F8E5M2::from_f32(1.0)) == F8E5M2::from_f32(-2.0));
1024    /// assert!(F8E5M2::from_f32(0.0).clamp(F8E5M2::from_f32(-2.0), F8E5M2::from_f32(1.0)) == F8E5M2::from_f32(0.0));
1025    /// assert!(F8E5M2::from_f32(2.0).clamp(F8E5M2::from_f32(-2.0), F8E5M2::from_f32(1.0)) == F8E5M2::from_f32(1.0));
1026    /// assert!(F8E5M2::NAN.clamp(F8E5M2::from_f32(-2.0), F8E5M2::from_f32(1.0)).is_nan());
1027    /// ```
1028    pub fn clamp(self, min: Self, max: Self) -> Self {
1029        assert!(min <= max);
1030        let mut x = self;
1031        if x < min {
1032            x = min;
1033        }
1034        if x > max {
1035            x = max;
1036        }
1037        x
1038    }
1039
1040    /// Returns a number composed of the magnitude of `self` and the sign of `sign`.
1041    ///
1042    /// Equal to `self` if the sign of `self` and `sign` are the same, otherwise equal to `-self`.
1043    /// If `self` is NaN, then NaN with the sign of `sign` is returned.
1044    ///
1045    /// # Examples
1046    ///
1047    /// ```
1048    /// # use float8::*;
1049    /// let f = F8E5M2::from_f32(3.5);
1050    ///
1051    /// assert_eq!(f.copysign(F8E5M2::from_f32(0.42)), F8E5M2::from_f32(3.5));
1052    /// assert_eq!(f.copysign(F8E5M2::from_f32(-0.42)), F8E5M2::from_f32(-3.5));
1053    /// assert_eq!((-f).copysign(F8E5M2::from_f32(0.42)), F8E5M2::from_f32(3.5));
1054    /// assert_eq!((-f).copysign(F8E5M2::from_f32(-0.42)), F8E5M2::from_f32(-3.5));
1055    ///
1056    /// assert!(F8E5M2::NAN.copysign(F8E5M2::from_f32(1.0)).is_nan());
1057    /// ```
1058    pub const fn copysign(self, sign: Self) -> Self {
1059        Self((sign.0 & 0x80u8) | (self.0 & 0x7Fu8))
1060    }
1061
1062    /// Returns a number that represents the sign of `self`.
1063    ///
1064    /// * `1.0` if the number is positive, `+0.0` or [`INFINITY`][Self::INFINITY]
1065    /// * `-1.0` if the number is negative, `-0.0` or [`NEG_INFINITY`][Self::NEG_INFINITY]
1066    /// * [`NAN`][Self::NAN] if the number is `NaN`
1067    ///
1068    /// # Examples
1069    ///
1070    /// ```rust
1071    /// # use float8::*;
1072    ///
1073    /// let f = F8E5M2::from_f32(3.5_f32);
1074    ///
1075    /// assert_eq!(f.signum(), F8E5M2::from_f32(1.0));
1076    /// assert_eq!(F8E5M2::NEG_INFINITY.signum(), F8E5M2::from_f32(-1.0));
1077    ///
1078    /// assert!(F8E5M2::NAN.signum().is_nan());
1079    /// ```
1080    pub const fn signum(self) -> Self {
1081        if self.is_nan() {
1082            self
1083        } else if self.0 & 0x80u8 != 0 {
1084            Self::NEG_ONE
1085        } else {
1086            Self::ONE
1087        }
1088    }
1089
1090    /// Returns the floating point category of the number.
1091    ///
1092    /// If only one property is going to be tested, it is generally faster to use the specific
1093    /// predicate instead.
1094    ///
1095    /// # Examples
1096    ///
1097    /// ```rust
1098    /// use std::num::FpCategory;
1099    /// # use float8::*;
1100    ///
1101    /// let num = F8E5M2::from_f32(12.4_f32);
1102    /// let inf = F8E5M2::INFINITY;
1103    ///
1104    /// assert_eq!(num.classify(), FpCategory::Normal);
1105    /// assert_eq!(inf.classify(), FpCategory::Infinite);
1106    /// ```
1107    pub const fn classify(&self) -> FpCategory {
1108        if self.is_infinite() {
1109            FpCategory::Infinite
1110        } else if !self.is_normal() {
1111            FpCategory::Subnormal
1112        } else if self.is_nan() {
1113            FpCategory::Nan
1114        } else if self.0 & 0x7Fu8 == 0 {
1115            FpCategory::Zero
1116        } else {
1117            FpCategory::Normal
1118        }
1119    }
1120}
1121
/// Serde visitor that produces [`F8E5M2`] values.
#[cfg(feature = "serde")]
struct VisitorF8E5M2;

#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for F8E5M2 {
    /// Asks the deserializer for a newtype struct named `"f8e5m2"` and lets
    /// [`VisitorF8E5M2`] build the value from whatever the format hands over.
    fn deserialize<D: serde::de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let visitor = VisitorF8E5M2;
        deserializer.deserialize_newtype_struct("f8e5m2", visitor)
    }
}
1134
#[cfg(feature = "serde")]
impl<'de> serde::de::Visitor<'de> for VisitorF8E5M2 {
    type Value = F8E5M2;

    /// Describes the expected input for serde's error messages.
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("tuple struct f8e5m2")
    }

    /// Default representation: the raw `u8` bit pattern wrapped in a newtype.
    fn visit_newtype_struct<D: serde::Deserializer<'de>>(
        self,
        deserializer: D,
    ) -> Result<Self::Value, D::Error> {
        u8::deserialize(deserializer).map(F8E5M2)
    }

    /// String representation: parsed through the type's `FromStr` implementation.
    fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<Self::Value, E> {
        match v.parse::<F8E5M2>() {
            Ok(parsed) => Ok(parsed),
            Err(_) => Err(E::invalid_value(
                serde::de::Unexpected::Str(v),
                &"a float string",
            )),
        }
    }

    /// `f32` representation: converted with [`F8E5M2::from_f32`].
    fn visit_f32<E: serde::de::Error>(self, v: f32) -> Result<Self::Value, E> {
        let converted = F8E5M2::from_f32(v);
        Ok(converted)
    }

    /// `f64` representation: converted with [`F8E5M2::from_f64`].
    fn visit_f64<E: serde::de::Error>(self, v: f64) -> Result<Self::Value, E> {
        let converted = F8E5M2::from_f64(v);
        Ok(converted)
    }
}
1173
1174macro_rules! comparison {
1175    ($t:ident) => {
1176        impl PartialEq for $t {
1177            fn eq(&self, other: &Self) -> bool {
1178                if self.is_nan() || other.is_nan() {
1179                    false
1180                } else {
1181                    (self.0 == other.0) || ((self.0 | other.0) & 0x7Fu8 == 0)
1182                }
1183            }
1184        }
1185
1186        impl PartialOrd for $t {
1187            fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1188                if self.is_nan() || other.is_nan() {
1189                    None
1190                } else {
1191                    let neg = self.0 & 0x80u8 != 0;
1192                    let other_neg = other.0 & 0x80u8 != 0;
1193                    match (neg, other_neg) {
1194                        (false, false) => Some(self.0.cmp(&other.0)),
1195                        (false, true) => {
1196                            if (self.0 | other.0) & 0x7Fu8 == 0 {
1197                                Some(Ordering::Equal)
1198                            } else {
1199                                Some(Ordering::Greater)
1200                            }
1201                        }
1202                        (true, false) => {
1203                            if (self.0 | other.0) & 0x7Fu8 == 0 {
1204                                Some(Ordering::Equal)
1205                            } else {
1206                                Some(Ordering::Less)
1207                            }
1208                        }
1209                        (true, true) => Some(other.0.cmp(&self.0)),
1210                    }
1211                }
1212            }
1213
1214            fn lt(&self, other: &Self) -> bool {
1215                if self.is_nan() || other.is_nan() {
1216                    false
1217                } else {
1218                    let neg = self.0 & 0x80u8 != 0;
1219                    let other_neg = other.0 & 0x80u8 != 0;
1220                    match (neg, other_neg) {
1221                        (false, false) => self.0 < other.0,
1222                        (false, true) => false,
1223                        (true, false) => (self.0 | other.0) & 0x7Fu8 != 0,
1224                        (true, true) => self.0 > other.0,
1225                    }
1226                }
1227            }
1228
1229            fn le(&self, other: &Self) -> bool {
1230                if self.is_nan() || other.is_nan() {
1231                    false
1232                } else {
1233                    let neg = self.0 & 0x80u8 != 0;
1234                    let other_neg = other.0 & 0x80u8 != 0;
1235                    match (neg, other_neg) {
1236                        (false, false) => self.0 <= other.0,
1237                        (false, true) => (self.0 | other.0) & 0x7Fu8 == 0,
1238                        (true, false) => true,
1239                        (true, true) => self.0 >= other.0,
1240                    }
1241                }
1242            }
1243
1244            fn gt(&self, other: &Self) -> bool {
1245                if self.is_nan() || other.is_nan() {
1246                    false
1247                } else {
1248                    let neg = self.0 & 0x80u8 != 0;
1249                    let other_neg = other.0 & 0x80u8 != 0;
1250                    match (neg, other_neg) {
1251                        (false, false) => self.0 > other.0,
1252                        (false, true) => (self.0 | other.0) & 0x7Fu8 != 0,
1253                        (true, false) => false,
1254                        (true, true) => self.0 < other.0,
1255                    }
1256                }
1257            }
1258
1259            fn ge(&self, other: &Self) -> bool {
1260                if self.is_nan() || other.is_nan() {
1261                    false
1262                } else {
1263                    let neg = self.0 & 0x80u8 != 0;
1264                    let other_neg = other.0 & 0x80u8 != 0;
1265                    match (neg, other_neg) {
1266                        (false, false) => self.0 >= other.0,
1267                        (false, true) => true,
1268                        (true, false) => (self.0 | other.0) & 0x7Fu8 == 0,
1269                        (true, true) => self.0 <= other.0,
1270                    }
1271                }
1272            }
1273        }
1274    };
1275}
1276
1277comparison!(F8E4M3);
1278comparison!(F8E5M2);
1279
1280macro_rules! constants {
1281    ($t:ident) => {
1282        impl $t {
1283            /// π
1284            pub const PI: Self = Self::from_f64(f64::consts::PI);
1285
1286            /// The full circle constant (τ)
1287            ///
1288            /// Equal to 2π.
1289            pub const TAU: Self = Self::from_f64(f64::consts::TAU);
1290
1291            /// π/2
1292            pub const FRAC_PI_2: Self = Self::from_f64(f64::consts::FRAC_PI_2);
1293
1294            /// π/3
1295            pub const FRAC_PI_3: Self = Self::from_f64(f64::consts::FRAC_PI_3);
1296
1297            /// π/4
1298            pub const FRAC_PI_4: Self = Self::from_f64(f64::consts::FRAC_PI_4);
1299
1300            /// π/6
1301            pub const FRAC_PI_6: Self = Self::from_f64(f64::consts::FRAC_PI_6);
1302
1303            /// π/8
1304            pub const FRAC_PI_8: Self = Self::from_f64(f64::consts::FRAC_PI_8);
1305
1306            /// 1/π
1307            pub const FRAC_1_PI: Self = Self::from_f64(f64::consts::FRAC_1_PI);
1308
1309            /// 2/π
1310            pub const FRAC_2_PI: Self = Self::from_f64(f64::consts::FRAC_2_PI);
1311
1312            /// 2/sqrt(π)
1313            pub const FRAC_2_SQRT_PI: Self = Self::from_f64(f64::consts::FRAC_2_SQRT_PI);
1314
1315            /// sqrt(2)
1316            pub const SQRT_2: Self = Self::from_f64(f64::consts::SQRT_2);
1317
1318            /// 1/sqrt(2)
1319            pub const FRAC_1_SQRT_2: Self = Self::from_f64(f64::consts::FRAC_1_SQRT_2);
1320
1321            /// Euler's number (e)
1322            pub const E: Self = Self::from_f64(f64::consts::E);
1323
1324            /// log<sub>2</sub>(10)
1325            pub const LOG2_10: Self = Self::from_f64(f64::consts::LOG2_10);
1326
1327            /// log<sub>2</sub>(e)
1328            pub const LOG2_E: Self = Self::from_f64(f64::consts::LOG2_E);
1329
1330            /// log<sub>10</sub>(2)
1331            pub const LOG10_2: Self = Self::from_f64(f64::consts::LOG10_2);
1332
1333            /// log<sub>10</sub>(e)
1334            pub const LOG10_E: Self = Self::from_f64(f64::consts::LOG10_E);
1335
1336            /// ln(2)
1337            pub const LN_2: Self = Self::from_f64(f64::consts::LN_2);
1338
1339            /// ln(10)
1340            pub const LN_10: Self = Self::from_f64(f64::consts::LN_10);
1341        }
1342    };
1343}
1344
1345constants!(F8E4M3);
1346constants!(F8E5M2);
1347
1348#[allow(clippy::unusual_byte_groupings)]
1349impl F8E4M3 {
1350    /// Number of mantissa digits
1351    pub const MANTISSA_DIGITS: u32 = 3;
1352    /// Maximum possible value
1353    pub const MAX: Self = Self::from_bits(0x7E - 1);
1354    /// Minimum possible value
1355    pub const MIN: Self = Self::from_bits(0xFE - 1);
1356    /// Positive infinity ∞
1357    pub const INFINITY: Self = Self::from_bits(0x7E);
1358    /// Negative infinity -∞
1359    pub const NEG_INFINITY: Self = Self::from_bits(0xFE);
1360    /// Smallest possible normal value
1361    pub const MIN_POSITIVE: Self = Self::from_bits(0b0_0001_000);
1362    /// Smallest possible subnormal value
1363    pub const MIN_POSITIVE_SUBNORMAL: Self = Self::from_bits(0b0_0000_001);
1364    /// Smallest possible subnormal value
1365    pub const MAX_SUBNORMAL: Self = Self::from_bits(0b0_0000_111);
1366    /// This is the difference between 1.0 and the next largest representable number.
1367    pub const EPSILON: Self = Self::from_bits(0b0_0100_000);
1368    /// NaN value
1369    pub const NAN: Self = Self::from_bits(0x7F);
1370    /// 1
1371    pub const ONE: Self = Self::from_bits(0b0_0111_000);
1372    /// 0
1373    pub const ZERO: Self = Self::from_bits(0b0_0000_000);
1374    /// -1
1375    pub const NEG_ONE: Self = Self::from_bits(0b1_0111_000);
1376    /// -0
1377    pub const NEG_ZERO: Self = Self::from_bits(0b1_0000_000);
1378    /// One greater than the minimum possible normal power of 2 exponent
1379    pub const MIN_EXP: i32 = -5;
1380    /// Minimum possible normal power of 10 exponent
1381    pub const MIN_10_EXP: i32 = -1;
1382    /// Maximum possible normal power of 2 exponent
1383    pub const MAX_EXP: i32 = 7;
1384    /// Maximum possible normal power of 10 exponent
1385    pub const MAX_10_EXP: i32 = 2;
1386    /// Approximate number of significant digits in base 10
1387    pub const DIGITS: u32 = 0;
1388}
1389
1390#[allow(clippy::unusual_byte_groupings)]
1391impl F8E5M2 {
1392    /// Number of mantissa digits
1393    pub const MANTISSA_DIGITS: u32 = 2;
1394    /// Maximum possible value
1395    pub const MAX: Self = Self::from_bits(0x7B - 1);
1396    /// Minimum possible value
1397    pub const MIN: Self = Self::from_bits(0xFB - 1);
1398    /// Positive infinity ∞
1399    pub const INFINITY: Self = Self::from_bits(0x7B);
1400    /// Negative infinity -∞
1401    pub const NEG_INFINITY: Self = Self::from_bits(0xFB);
1402    /// Smallest possible normal value
1403    pub const MIN_POSITIVE: Self = Self::from_bits(0b0_00001_00);
1404    /// Smallest possible subnormal value
1405    pub const MIN_POSITIVE_SUBNORMAL: Self = Self::from_bits(0b0_00000_01);
1406    /// Smallest possible subnormal value
1407    pub const MAX_SUBNORMAL: Self = Self::from_bits(0b0_00000_11);
1408    /// This is the difference between 1.0 and the next largest representable number.
1409    pub const EPSILON: Self = Self::from_bits(0b0_01101_00);
1410    /// NaN value
1411    pub const NAN: Self = Self::from_bits(0x7E);
1412    /// 1
1413    pub const ONE: Self = Self::from_bits(0b0_01111_00);
1414    /// 0
1415    pub const ZERO: Self = Self::from_bits(0b0_00000_00);
1416    /// -1
1417    pub const NEG_ONE: Self = Self::from_bits(0b1_01111_00);
1418    /// -0
1419    pub const NEG_ZERO: Self = Self::from_bits(0b1_00000_00);
1420    /// One greater than the minimum possible normal power of 2 exponent
1421    pub const MIN_EXP: i32 = -13;
1422    /// Minimum possible normal power of 10 exponent
1423    pub const MIN_10_EXP: i32 = -4;
1424    /// Maximum possible normal power of 2 exponent
1425    pub const MAX_EXP: i32 = 15;
1426    /// Maximum possible normal power of 10 exponent
1427    pub const MAX_10_EXP: i32 = 4;
1428    /// Approximate number of significant digits in base 10
1429    pub const DIGITS: u32 = 0;
1430}
1431
1432macro_rules! io {
1433    ($t:ident) => {
1434        impl Display for $t {
1435            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1436                Display::fmt(&self.to_f32(), f)
1437            }
1438        }
1439        impl Debug for $t {
1440            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1441                Debug::fmt(&self.to_f32(), f)
1442            }
1443        }
1444        impl FromStr for $t {
1445            type Err = ParseFloatError;
1446            fn from_str(src: &str) -> Result<$t, ParseFloatError> {
1447                f32::from_str(src).map($t::from_f32)
1448            }
1449        }
1450        impl From<f16> for $t {
1451            fn from(x: f16) -> $t {
1452                Self::from_f32(x.to_f32())
1453            }
1454        }
1455        impl From<f32> for $t {
1456            fn from(x: f32) -> $t {
1457                Self::from_f32(x)
1458            }
1459        }
1460        impl From<f64> for $t {
1461            fn from(x: f64) -> $t {
1462                Self::from_f64(x)
1463            }
1464        }
1465        impl LowerExp for $t {
1466            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1467                write!(f, "{:e}", self.to_f32())
1468            }
1469        }
1470        impl LowerHex for $t {
1471            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1472                write!(f, "{:x}", self.0)
1473            }
1474        }
1475        impl UpperExp for $t {
1476            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1477                write!(f, "{:E}", self.to_f32())
1478            }
1479        }
1480        impl UpperHex for $t {
1481            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1482                write!(f, "{:X}", self.0)
1483            }
1484        }
1485    };
1486}
1487
1488io!(F8E4M3);
1489io!(F8E5M2);
1490
1491macro_rules! binary {
1492    ($trait:ident, $fn_name:ident, $t:ident, $op:tt) => {
1493        impl $trait for $t {
1494            type Output = Self;
1495
1496            fn $fn_name(self, rhs: Self) -> Self::Output {
1497                Self::from_f32(self.to_f32() $op rhs.to_f32())
1498            }
1499        }
1500    };
1501}
1502
1503macro_rules! assign_binary {
1504    ($trait:ident, $fn_name:ident, $t:ident, $op:tt) => {
1505        impl $trait for $t {
1506            fn $fn_name(&mut self, rhs: Self) {
1507                *self = Self::from_f32(self.to_f32() $op rhs.to_f32())
1508            }
1509        }
1510    };
1511}
1512
1513macro_rules! unary {
1514    ($trait:ident, $fn_name:ident, $t:ident, $op:tt) => {
1515        impl $trait for $t {
1516            type Output = Self;
1517
1518            fn $fn_name(self) -> Self::Output {
1519                Self::from_f32($op self.to_f32())
1520            }
1521        }
1522    };
1523}
1524
1525binary!(Add, add, F8E4M3, +);
1526binary!(Sub, sub, F8E4M3, -);
1527binary!(Mul, mul, F8E4M3, *);
1528binary!(Div, div, F8E4M3, /);
1529binary!(Rem, rem, F8E4M3, %);
1530assign_binary!(AddAssign, add_assign, F8E4M3, +);
1531assign_binary!(SubAssign, sub_assign, F8E4M3, -);
1532assign_binary!(MulAssign, mul_assign, F8E4M3, *);
1533assign_binary!(DivAssign, div_assign, F8E4M3, /);
1534assign_binary!(RemAssign, rem_assign, F8E4M3, %);
1535unary!(Neg, neg, F8E4M3, -);
1536
1537binary!(Add, add, F8E5M2, +);
1538binary!(Sub, sub, F8E5M2, -);
1539binary!(Mul, mul, F8E5M2, *);
1540binary!(Div, div, F8E5M2, /);
1541binary!(Rem, rem, F8E5M2, %);
1542assign_binary!(AddAssign, add_assign, F8E5M2, +);
1543assign_binary!(SubAssign, sub_assign, F8E5M2, -);
1544assign_binary!(MulAssign, mul_assign, F8E5M2, *);
1545assign_binary!(DivAssign, div_assign, F8E5M2, /);
1546assign_binary!(RemAssign, rem_assign, F8E5M2, %);
1547unary!(Neg, neg, F8E5M2, -);
1548
1549macro_rules! from_t {
1550    ($t:ident) => {
1551        impl From<$t> for f64 {
1552            fn from(value: $t) -> Self {
1553                value.to_f64()
1554            }
1555        }
1556
1557        impl From<$t> for f32 {
1558            fn from(value: $t) -> Self {
1559                value.to_f32()
1560            }
1561        }
1562
1563        impl From<$t> for f16 {
1564            fn from(value: $t) -> Self {
1565                value.to_f16()
1566            }
1567        }
1568    };
1569}
1570
1571from_t!(F8E4M3);
1572from_t!(F8E5M2);
1573
// SAFETY: NOTE(review) — relies on `F8E4M3` being a plain wrapper around a single `u8`
// with no padding or pointers; the struct definition is outside this section, so confirm
// its layout.
#[cfg(feature = "cuda")]
unsafe impl cudarc::driver::DeviceRepr for F8E4M3 {}
// SAFETY: the all-zero byte is the bit pattern of `F8E4M3::ZERO`.
#[cfg(feature = "cuda")]
unsafe impl cudarc::driver::ValidAsZeroBits for F8E4M3 {}

// SAFETY: NOTE(review) — same layout assumption as for `F8E4M3` above; confirm against
// the struct definition.
#[cfg(feature = "cuda")]
unsafe impl cudarc::driver::DeviceRepr for F8E5M2 {}
// SAFETY: the all-zero byte is the bit pattern of `F8E5M2::ZERO`.
#[cfg(feature = "cuda")]
unsafe impl cudarc::driver::ValidAsZeroBits for F8E5M2 {}