cubecl_core/frontend/element/
float.rs1use cubecl_ir::Scope;
2use half::{bf16, f16};
3
4use crate::{
5 ir::{Elem, ExpandElement, FloatKind},
6 prelude::*,
7};
8
9use super::Numeric;
10
11mod relaxed;
12mod tensor_float;
13mod typemap;
14
15pub use typemap::*;
16
17pub trait Float:
19 Numeric
20 + Exp
21 + Log
22 + Log1p
23 + Cos
24 + Sin
25 + Tanh
26 + Powf
27 + Sqrt
28 + Round
29 + Floor
30 + Ceil
31 + Erf
32 + Recip
33 + Magnitude
34 + Normalize
35 + Dot
36 + Into<Self::ExpandType>
37 + core::ops::Neg<Output = Self>
38 + core::ops::Add<Output = Self>
39 + core::ops::Sub<Output = Self>
40 + core::ops::Mul<Output = Self>
41 + core::ops::Div<Output = Self>
42 + std::ops::AddAssign
43 + std::ops::SubAssign
44 + std::ops::MulAssign
45 + std::ops::DivAssign
46 + std::cmp::PartialOrd
47 + std::cmp::PartialEq
48{
49 const DIGITS: u32;
50 const EPSILON: Self;
51 const INFINITY: Self;
52 const MANTISSA_DIGITS: u32;
53 const MAX_10_EXP: i32;
54 const MAX_EXP: i32;
55 const MIN_10_EXP: i32;
56 const MIN_EXP: i32;
57 const MIN_POSITIVE: Self;
58 const NAN: Self;
59 const NEG_INFINITY: Self;
60 const RADIX: u32;
61
62 fn new(val: f32) -> Self;
63 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
64 __expand_new(scope, val)
65 }
66}
67
/// Implements the cube frontend traits for one floating-point primitive:
/// `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
/// `ExpandElementBaseInit`, `Float` and `LaunchArgExpand`.
///
/// Accepted forms:
/// - `impl_float!(half $ty, $kind)`: `Float::new` converts with
///   `$ty::from_f32` (used in this module for the `half` crate's types).
/// - `impl_float!($ty, $kind)`: `Float::new` converts with an `as` cast
///   (used in this module for `f32` / `f64`).
/// - `impl_float!($ty, $kind, $new)`: `$new` is any expression callable as
///   `$new(f32) -> $ty`; the two forms above forward to this one.
///
/// `$kind` names the `FloatKind` variant returned by `as_elem_native`.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |raw: f32| $primitive::from_f32(raw));
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |raw: f32| raw as $primitive);
    };
    ($primitive:ident, $kind:ident, $new:expr) => {
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<$primitive>;
        }

        impl CubePrimitive for $primitive {
            // The element type is fixed by the `$kind` passed to the macro.
            fn as_elem_native() -> Option<Elem> {
                Some(Elem::Float(FloatKind::$kind))
            }
        }

        impl IntoRuntime for $primitive {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                // Convert the host value into its expand form, then hand it to
                // `Init::init` together with the scope.
                let constant: ExpandElementTyped<Self> = self.into();
                Init::init(constant, scope)
            }
        }

        impl Numeric for $primitive {
            // Both bounds are taken from the `num_traits::Float` implementation.
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }

            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl ExpandElementBaseInit for $primitive {
            fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                init_expand_element(scope, elem)
            }
        }

        impl Float for $primitive {
            // Representation parameters.
            const RADIX: u32 = $primitive::RADIX;
            const DIGITS: u32 = $primitive::DIGITS;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;

            // Exponent range.
            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;

            // Distinguished values.
            const EPSILON: Self = $primitive::EPSILON;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
            const INFINITY: Self = $primitive::INFINITY;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const NAN: Self = $primitive::NAN;

            fn new(value: f32) -> Self {
                $new(value)
            }
        }

        impl LaunchArgExpand for $primitive {
            // No compilation argument is carried for float scalars.
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element type from the builder's context, then ask
                // the builder for a scalar of that type.
                let elem = $primitive::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }
    };
}
140
141impl_float!(half f16, F16);
142impl_float!(half bf16, BF16);
143impl_float!(f32, F32);
144impl_float!(f64, F64);
145
146impl ScalarArgSettings for f16 {
147 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
148 settings.register_f16(*self);
149 }
150}
151
152impl ScalarArgSettings for bf16 {
153 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
154 settings.register_bf16(*self);
155 }
156}
157
158impl ScalarArgSettings for f32 {
159 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
160 settings.register_f32(*self);
161 }
162}
163
164impl ScalarArgSettings for f64 {
165 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
166 settings.register_f64(*self);
167 }
168}