// cubecl_common/float/tensor_float.rs

1use bytemuck::{Pod, Zeroable};
2use core::fmt::Display;
3use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
4use num_traits::{NumCast, ToPrimitive};
5
6/// A 19-bit floating point type implementing the [`tfloat32`] format.
7///
8/// The [`tfloat32`] floating point format is a truncated 19-bit version of the IEEE 754 standard
9/// `binary32`, a.k.a [`f32`]. [`bf16`] has approximately the same dynamic range as [`f32`] but a
10/// a lower precision equal to [`f16`][half::f16].
11///
12/// [`tfloat32`]: https://en.wikipedia.org/wiki/TensorFloat-32
13#[allow(non_camel_case_types)]
14#[repr(transparent)]
15#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
16#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
17pub struct tf32(f32);
18
19impl tf32 {
20    /// Constructs a [`tf32`] value from the raw bits.
21    #[inline]
22    #[must_use]
23    pub const fn from_bits(bits: u32) -> tf32 {
24        tf32(f32::from_bits(bits))
25    }
26
27    /// Constructs a [`tf32`] value from a 32-bit floating point value.
28    ///
29    /// This operation is lossy. If the 32-bit value is too large to fit, ±∞ will result. NaN values
30    /// are preserved. Subnormal values that are too tiny to be represented will result in ±0. All
31    /// other values are truncated and rounded to the nearest representable value.
32    #[inline]
33    #[must_use]
34    pub const fn from_f32(value: f32) -> tf32 {
35        tf32(value)
36    }
37
38    /// Constructs a [`tf32`] value from a 64-bit floating point value.
39    ///
40    /// This operation is lossy. If the 64-bit value is to large to fit, ±∞ will result. NaN values
41    /// are preserved. 64-bit subnormal values are too tiny to be represented and result in ±0.
42    /// Exponents that underflow the minimum exponent will result in subnormals or ±0. All other
43    /// values are truncated and rounded to the nearest representable value.
44    #[inline]
45    #[must_use]
46    pub const fn from_f64(value: f64) -> tf32 {
47        tf32(value as f32)
48    }
49
50    /// Converts a [`tf32`] into the underlying bit representation.
51    #[inline]
52    #[must_use]
53    pub const fn to_bits(self) -> u32 {
54        f32::to_bits(self.0)
55    }
56
57    /// Converts a [`tf32`] value into an [`f32`] value.
58    ///
59    /// This conversion is lossless as all values can be represented exactly in [`f32`].
60    #[inline]
61    #[must_use]
62    pub const fn to_f32(self) -> f32 {
63        self.0
64    }
65
66    /// Converts a [`tf32`] value into an [`f64`] value.
67    ///
68    /// This conversion is lossless as all values can be represented exactly in [`f64`].
69    #[inline]
70    #[must_use]
71    pub const fn to_f64(self) -> f64 {
72        self.0 as f64
73    }
74}
75
76impl Neg for tf32 {
77    type Output = Self;
78
79    fn neg(self) -> Self::Output {
80        Self::from_f32(self.to_f32().neg())
81    }
82}
83
84impl Mul for tf32 {
85    type Output = Self;
86
87    fn mul(self, rhs: Self) -> Self::Output {
88        Self::from_f32(self.to_f32() * rhs.to_f32())
89    }
90}
91
92impl MulAssign for tf32 {
93    fn mul_assign(&mut self, rhs: Self) {
94        *self = *self * rhs;
95    }
96}
97
98impl Div for tf32 {
99    type Output = Self;
100
101    fn div(self, rhs: Self) -> Self::Output {
102        Self::from_f32(self.to_f32() / rhs.to_f32())
103    }
104}
105
106impl DivAssign for tf32 {
107    fn div_assign(&mut self, rhs: Self) {
108        *self = *self / rhs;
109    }
110}
111
112impl Add for tf32 {
113    type Output = Self;
114
115    fn add(self, rhs: Self) -> Self::Output {
116        Self::from_f32(self.to_f32() + rhs.to_f32())
117    }
118}
119
120impl AddAssign for tf32 {
121    fn add_assign(&mut self, rhs: Self) {
122        *self = *self + rhs;
123    }
124}
125
126impl Sub for tf32 {
127    type Output = Self;
128
129    fn sub(self, rhs: Self) -> Self::Output {
130        Self::from_f32(self.to_f32() - rhs.to_f32())
131    }
132}
133
134impl SubAssign for tf32 {
135    fn sub_assign(&mut self, rhs: Self) {
136        *self = *self - rhs;
137    }
138}
139
140impl ToPrimitive for tf32 {
141    fn to_i64(&self) -> Option<i64> {
142        Some(tf32::to_f32(*self) as i64)
143    }
144
145    fn to_u64(&self) -> Option<u64> {
146        Some(tf32::to_f64(*self) as u64)
147    }
148
149    fn to_f32(&self) -> Option<f32> {
150        Some(tf32::to_f32(*self))
151    }
152
153    fn to_f64(&self) -> Option<f64> {
154        Some(tf32::to_f64(*self))
155    }
156}
157
158impl NumCast for tf32 {
159    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
160        Some(Self::from_f32(n.to_f32()?))
161    }
162}
163
164impl Display for tf32 {
165    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
166        write!(f, "{}", self.0)
167    }
168}