cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantScalarValue, Scope, StorageType};
2use half::{bf16, f16};
3
4use crate::{
5 ir::{ElemType, ExpandElement, FloatKind},
6 prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20pub trait Float:
22 Numeric
23 + Exp
24 + Log
25 + Log1p
26 + Cos
27 + Sin
28 + Tan
29 + Tanh
30 + Sinh
31 + Cosh
32 + ArcCos
33 + ArcSin
34 + ArcTan
35 + ArcSinh
36 + ArcCosh
37 + ArcTanh
38 + Degrees
39 + Radians
40 + ArcTan2
41 + Powf
42 + Powi<i32>
43 + Hypot
44 + Rhypot
45 + Sqrt
46 + InverseSqrt
47 + Round
48 + Floor
49 + Ceil
50 + Trunc
51 + Erf
52 + Recip
53 + Magnitude
54 + Normalize
55 + Dot
56 + IsNan
57 + IsInf
58 + Into<Self::ExpandType>
59 + core::ops::Neg<Output = Self>
60 + core::ops::Add<Output = Self>
61 + core::ops::Sub<Output = Self>
62 + core::ops::Mul<Output = Self>
63 + core::ops::Div<Output = Self>
64 + std::ops::AddAssign
65 + std::ops::SubAssign
66 + std::ops::MulAssign
67 + std::ops::DivAssign
68 + std::cmp::PartialOrd
69 + std::cmp::PartialEq
70{
71 const DIGITS: u32;
72 const EPSILON: Self;
73 const INFINITY: Self;
74 const MANTISSA_DIGITS: u32;
75 const MAX_10_EXP: i32;
76 const MAX_EXP: i32;
77 const MIN_10_EXP: i32;
78 const MIN_EXP: i32;
79 const MIN_POSITIVE: Self;
80 const NAN: Self;
81 const NEG_INFINITY: Self;
82 const RADIX: u32;
83
84 fn new(val: f32) -> Self;
85 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
86 __expand_new(scope, val)
87 }
88}
89
/// Implements the CubeCL scalar traits (`CubeType`, `CubePrimitive`,
/// `IntoRuntime`, `Numeric`, `ExpandElementIntoMut`, `IntoMut` and `Float`)
/// for a host floating-point type.
///
/// Accepted forms:
/// - `impl_float!(half ty, Kind)` — for types built with `ty::from_f64`
///   (used here for the `half` crate's `f16`/`bf16`).
/// - `impl_float!(ty, Kind)` — for types built with an `as` cast from `f64`.
/// - `impl_float!(ty, Kind, ctor)` — general form; `ctor` is a callable
///   expression taking an `f64` and returning `ty`.
///
/// `Kind` must name a `FloatKind` variant. The first two forms only forward to
/// the general form, so arm order matters: the `half` marker is tried first.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| $primitive::from_f64(val));
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| val as $primitive);
    };
    ($primitive:ident, $kind:ident, $new:expr) => {
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<$primitive>;
        }

        impl CubePrimitive for $primitive {
            fn as_type_native() -> Option<StorageType> {
                Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind)))
            }

            fn from_const_value(value: ConstantScalarValue) -> Self {
                // Only the `Float` variant can be converted; any other variant is
                // treated as a broken invariant. Name the target type and the
                // offending value so the failure is traceable.
                let ConstantScalarValue::Float(value, _) = value else {
                    unreachable!(
                        "expected a float constant for `{}`, got {:?}",
                        stringify!($primitive),
                        value
                    )
                };
                $new(value)
            }
        }

        impl IntoRuntime for $primitive {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let elem: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, elem).into()
            }
        }

        impl Numeric for $primitive {
            // Smallest finite value, as defined by `num_traits::Float::min_value`.
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            // Largest finite value, as defined by `num_traits::Float::max_value`.
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementIntoMut for $primitive {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl IntoMut for $primitive {
            // A host value is returned unchanged; the scope is not used.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $primitive {
            const DIGITS: u32 = $primitive::DIGITS;
            const EPSILON: Self = $primitive::EPSILON;
            const INFINITY: Self = $primitive::INFINITY;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
            const NAN: Self = $primitive::NAN;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const RADIX: u32 = $primitive::RADIX;

            fn new(val: f32) -> Self {
                // `f32 -> f64` is lossless, so the shared `f64` constructor
                // performs the only rounding step.
                $new(val as f64)
            }
        }
    };
}
164
165impl_float!(half f16, F16);
166impl_float!(half bf16, BF16);
167impl_float!(f32, F32);
168impl_float!(f64, F64);