cubecl_core/frontend/element/
float.rs1use cubecl_ir::Scope;
2use half::{bf16, f16};
3
4use crate::{
5 ir::{Elem, ExpandElement, FloatKind},
6 prelude::*,
7};
8
9use super::Numeric;
10
11mod fp4;
12mod fp6;
13mod fp8;
14mod relaxed;
15mod tensor_float;
16mod typemap;
17
18pub use typemap::*;
19
20pub trait Float:
22 Numeric
23 + Exp
24 + Log
25 + Log1p
26 + Cos
27 + Sin
28 + Tanh
29 + Powf
30 + Sqrt
31 + Round
32 + Floor
33 + Ceil
34 + Erf
35 + Recip
36 + Magnitude
37 + Normalize
38 + Dot
39 + Into<Self::ExpandType>
40 + core::ops::Neg<Output = Self>
41 + core::ops::Add<Output = Self>
42 + core::ops::Sub<Output = Self>
43 + core::ops::Mul<Output = Self>
44 + core::ops::Div<Output = Self>
45 + std::ops::AddAssign
46 + std::ops::SubAssign
47 + std::ops::MulAssign
48 + std::ops::DivAssign
49 + std::cmp::PartialOrd
50 + std::cmp::PartialEq
51{
52 const DIGITS: u32;
53 const EPSILON: Self;
54 const INFINITY: Self;
55 const MANTISSA_DIGITS: u32;
56 const MAX_10_EXP: i32;
57 const MAX_EXP: i32;
58 const MIN_10_EXP: i32;
59 const MIN_EXP: i32;
60 const MIN_POSITIVE: Self;
61 const NAN: Self;
62 const NEG_INFINITY: Self;
63 const RADIX: u32;
64
65 fn new(val: f32) -> Self;
66 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
67 __expand_new(scope, val)
68 }
69}
70
/// Implements the cubecl element traits for a host float type.
///
/// Accepted forms:
/// - `impl_float!(half T, Kind)`: `T` is a `half` type; `Float::new` calls `T::from_f32`.
/// - `impl_float!(T, Kind)`: `T` is a native float; `Float::new` casts the `f32` with `as`.
/// - `impl_float!(T, Kind, ctor)`: `ctor` is any expression callable as `ctor(f32) -> T`.
///
/// `Kind` is the `FloatKind` variant returned by `CubePrimitive::as_elem_native`.
macro_rules! impl_float {
    (half $prim:ident, $kind:ident) => {
        impl_float!($prim, $kind, |x: f32| $prim::from_f32(x));
    };
    ($prim:ident, $kind:ident) => {
        impl_float!($prim, $kind, |x: f32| x as $prim);
    };
    ($prim:ident, $kind:ident, $from_f32:expr) => {
        // --- Type-level description -------------------------------------
        impl CubeType for $prim {
            type ExpandType = ExpandElementTyped<$prim>;
        }

        impl CubePrimitive for $prim {
            fn as_elem_native() -> Option<Elem> {
                Some(Elem::Float(FloatKind::$kind))
            }
        }

        impl Numeric for $prim {
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl Float for $prim {
            // Every constant is forwarded from the primitive's own associated constant.
            const DIGITS: u32 = $prim::DIGITS;
            const EPSILON: Self = $prim::EPSILON;
            const INFINITY: Self = $prim::INFINITY;
            const MANTISSA_DIGITS: u32 = $prim::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $prim::MAX_10_EXP;
            const MAX_EXP: i32 = $prim::MAX_EXP;
            const MIN_10_EXP: i32 = $prim::MIN_10_EXP;
            const MIN_EXP: i32 = $prim::MIN_EXP;
            const MIN_POSITIVE: Self = $prim::MIN_POSITIVE;
            const NAN: Self = $prim::NAN;
            const NEG_INFINITY: Self = $prim::NEG_INFINITY;
            const RADIX: u32 = $prim::RADIX;

            fn new(value: f32) -> Self {
                ($from_f32)(value)
            }
        }

        // --- Expansion-time conversions ---------------------------------
        impl IntoRuntime for $prim {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let typed: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, typed).into()
            }
        }

        impl ExpandElementIntoMut for $prim {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl IntoMut for $prim {
            // Host-side values are already plain copies; nothing to do.
            fn into_mut(self, _: &mut Scope) -> Self {
                self
            }
        }

        // --- Kernel launch ----------------------------------------------
        impl LaunchArgExpand for $prim {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element type first, then declare the scalar input with it.
                let elem = $prim::as_elem(&builder.scope);
                builder.scalar(elem).into()
            }
        }
    };
}
150impl_float!(half f16, F16);
151impl_float!(half bf16, BF16);
152impl_float!(f32, F32);
153impl_float!(f64, F64);
154
155impl ScalarArgSettings for f16 {
156 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
157 settings.register_f16(*self);
158 }
159}
160
161impl ScalarArgSettings for bf16 {
162 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
163 settings.register_bf16(*self);
164 }
165}
166
167impl ScalarArgSettings for f32 {
168 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
169 settings.register_f32(*self);
170 }
171}
172
173impl ScalarArgSettings for f64 {
174 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
175 settings.register_f64(*self);
176 }
177}