cubecl_common/float/
tensor_float.rs1use bytemuck::{Pod, Zeroable};
2use core::fmt::Display;
3use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
4use num_traits::{NumCast, ToPrimitive};
5
6#[allow(non_camel_case_types)]
14#[repr(transparent)]
15#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
16#[derive(Clone, Copy, Default, Zeroable, Pod, PartialEq, PartialOrd)]
17pub struct tf32(f32);
18
19impl tf32 {
20 #[inline]
22 #[must_use]
23 pub const fn from_bits(bits: u32) -> tf32 {
24 tf32(f32::from_bits(bits))
25 }
26
27 #[inline]
33 #[must_use]
34 pub const fn from_f32(value: f32) -> tf32 {
35 tf32(value)
36 }
37
38 #[inline]
45 #[must_use]
46 pub const fn from_f64(value: f64) -> tf32 {
47 tf32(value as f32)
48 }
49
50 #[inline]
52 #[must_use]
53 pub const fn to_bits(self) -> u32 {
54 f32::to_bits(self.0)
55 }
56
57 #[inline]
61 #[must_use]
62 pub const fn to_f32(self) -> f32 {
63 self.0
64 }
65
66 #[inline]
70 #[must_use]
71 pub const fn to_f64(self) -> f64 {
72 self.0 as f64
73 }
74}
75
76impl Neg for tf32 {
77 type Output = Self;
78
79 fn neg(self) -> Self::Output {
80 Self::from_f32(self.to_f32().neg())
81 }
82}
83
84impl Mul for tf32 {
85 type Output = Self;
86
87 fn mul(self, rhs: Self) -> Self::Output {
88 Self::from_f32(self.to_f32() * rhs.to_f32())
89 }
90}
91
92impl MulAssign for tf32 {
93 fn mul_assign(&mut self, rhs: Self) {
94 *self = *self * rhs;
95 }
96}
97
98impl Div for tf32 {
99 type Output = Self;
100
101 fn div(self, rhs: Self) -> Self::Output {
102 Self::from_f32(self.to_f32() / rhs.to_f32())
103 }
104}
105
106impl DivAssign for tf32 {
107 fn div_assign(&mut self, rhs: Self) {
108 *self = *self / rhs;
109 }
110}
111
112impl Add for tf32 {
113 type Output = Self;
114
115 fn add(self, rhs: Self) -> Self::Output {
116 Self::from_f32(self.to_f32() + rhs.to_f32())
117 }
118}
119
120impl AddAssign for tf32 {
121 fn add_assign(&mut self, rhs: Self) {
122 *self = *self + rhs;
123 }
124}
125
126impl Sub for tf32 {
127 type Output = Self;
128
129 fn sub(self, rhs: Self) -> Self::Output {
130 Self::from_f32(self.to_f32() - rhs.to_f32())
131 }
132}
133
134impl SubAssign for tf32 {
135 fn sub_assign(&mut self, rhs: Self) {
136 *self = *self - rhs;
137 }
138}
139
140impl ToPrimitive for tf32 {
141 fn to_i64(&self) -> Option<i64> {
142 Some(tf32::to_f32(*self) as i64)
143 }
144
145 fn to_u64(&self) -> Option<u64> {
146 Some(tf32::to_f64(*self) as u64)
147 }
148
149 fn to_f32(&self) -> Option<f32> {
150 Some(tf32::to_f32(*self))
151 }
152
153 fn to_f64(&self) -> Option<f64> {
154 Some(tf32::to_f64(*self))
155 }
156}
157
158impl NumCast for tf32 {
159 fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
160 Some(Self::from_f32(n.to_f32()?))
161 }
162}
163
164impl Display for tf32 {
165 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
166 write!(f, "{}", self.0)
167 }
168}