cubecl_common/float/
tensor_float.rs

1#![allow(clippy::transmute_int_to_float)] // Not yet stable in previous version. To be removed when
2#![allow(clippy::transmute_float_to_int)] // prev=1.83.
3
4use bytemuck::{Pod, Zeroable};
5use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
6use core::{fmt::Display, mem::transmute};
7use num_traits::{NumCast, ToPrimitive};
8
9/// A 19-bit floating point type implementing the [`tfloat32`] format.
10///
11/// The [`tfloat32`] floating point format is a truncated 19-bit version of the IEEE 754 standard
12/// `binary32`, a.k.a [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] but a
13/// a lower precision equal to [`f16`][half::f16].
14///
15/// [`tfloat32`]: https://en.wikipedia.org/wiki/TensorFloat-32
16#[allow(non_camel_case_types)]
17#[repr(transparent)]
18#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
19#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
20pub struct tf32(f32);
21
22impl tf32 {
23    /// Constructs a [`tf32`] value from the raw bits.
24    #[inline]
25    #[must_use]
26    pub const fn from_bits(bits: u32) -> tf32 {
27        tf32(unsafe { transmute::<u32, f32>(bits) })
28    }
29
30    /// Constructs a [`tf32`] value from a 32-bit floating point value.
31    ///
32    /// This operation is lossy. If the 32-bit value is too large to fit, ±∞ will result. NaN values
33    /// are preserved. Subnormal values that are too tiny to be represented will result in ±0. All
34    /// other values are truncated and rounded to the nearest representable value.
35    #[inline]
36    #[must_use]
37    pub const fn from_f32(value: f32) -> tf32 {
38        tf32(value)
39    }
40
41    /// Constructs a [`tf32`] value from a 64-bit floating point value.
42    ///
43    /// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will result. NaN values
44    /// are preserved. 64-bit subnormal values are too tiny to be represented and result in ±0.
45    /// Exponents that underflow the minimum exponent will result in subnormals or ±0. All other
46    /// values are truncated and rounded to the nearest representable value.
47    #[inline]
48    #[must_use]
49    pub const fn from_f64(value: f64) -> tf32 {
50        tf32(value as f32)
51    }
52
53    /// Converts a [`tf32`] into the underlying bit representation.
54    #[inline]
55    #[must_use]
56    pub const fn to_bits(self) -> u32 {
57        unsafe { transmute(self.0) }
58    }
59
60    /// Converts a [`tf32`] value into an [`f32`] value.
61    ///
62    /// This conversion is lossless as all values can be represented exactly in [`f32`].
63    #[inline]
64    #[must_use]
65    pub const fn to_f32(self) -> f32 {
66        self.0
67    }
68
69    /// Converts a [`tf32`] value into an [`f64`] value.
70    ///
71    /// This conversion is lossless as all values can be represented exactly in [`f64`].
72    #[inline]
73    #[must_use]
74    pub const fn to_f64(self) -> f64 {
75        self.0 as f64
76    }
77}
78
79impl Neg for tf32 {
80    type Output = Self;
81
82    fn neg(self) -> Self::Output {
83        Self::from_f32(self.to_f32().neg())
84    }
85}
86
87impl Mul for tf32 {
88    type Output = Self;
89
90    fn mul(self, rhs: Self) -> Self::Output {
91        Self::from_f32(self.to_f32() * rhs.to_f32())
92    }
93}
94
95impl MulAssign for tf32 {
96    fn mul_assign(&mut self, rhs: Self) {
97        *self = *self * rhs;
98    }
99}
100
101impl Div for tf32 {
102    type Output = Self;
103
104    fn div(self, rhs: Self) -> Self::Output {
105        Self::from_f32(self.to_f32() / rhs.to_f32())
106    }
107}
108
109impl DivAssign for tf32 {
110    fn div_assign(&mut self, rhs: Self) {
111        *self = *self / rhs;
112    }
113}
114
115impl Add for tf32 {
116    type Output = Self;
117
118    fn add(self, rhs: Self) -> Self::Output {
119        Self::from_f32(self.to_f32() + rhs.to_f32())
120    }
121}
122
123impl AddAssign for tf32 {
124    fn add_assign(&mut self, rhs: Self) {
125        *self = *self + rhs;
126    }
127}
128
129impl Sub for tf32 {
130    type Output = Self;
131
132    fn sub(self, rhs: Self) -> Self::Output {
133        Self::from_f32(self.to_f32() - rhs.to_f32())
134    }
135}
136
137impl SubAssign for tf32 {
138    fn sub_assign(&mut self, rhs: Self) {
139        *self = *self - rhs;
140    }
141}
142
143impl ToPrimitive for tf32 {
144    fn to_i64(&self) -> Option<i64> {
145        Some(tf32::to_f32(*self) as i64)
146    }
147
148    fn to_u64(&self) -> Option<u64> {
149        Some(tf32::to_f64(*self) as u64)
150    }
151
152    fn to_f32(&self) -> Option<f32> {
153        Some(tf32::to_f32(*self))
154    }
155
156    fn to_f64(&self) -> Option<f64> {
157        Some(tf32::to_f64(*self))
158    }
159}
160
161impl NumCast for tf32 {
162    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
163        Some(Self::from_f32(n.to_f32()?))
164    }
165}
166
167impl Display for tf32 {
168    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
169        write!(f, "{}", self.0)
170    }
171}