cubecl_core/frontend/element/
float.rs1use half::{bf16, f16};
2
3use crate::{
4 ir::{Elem, FloatKind},
5 prelude::*,
6};
7
8use super::Numeric;
9
10mod relaxed;
11mod tensor_float;
12mod typemap;
13
14pub use relaxed::*;
15pub use tensor_float::*;
16pub use typemap::*;
17
18pub trait Float:
20 Numeric
21 + Exp
22 + Log
23 + Log1p
24 + Cos
25 + Sin
26 + Tanh
27 + Powf
28 + Sqrt
29 + Round
30 + Floor
31 + Ceil
32 + Erf
33 + Recip
34 + Magnitude
35 + Normalize
36 + Dot
37 + Into<Self::ExpandType>
38 + core::ops::Neg<Output = Self>
39 + core::ops::Add<Output = Self>
40 + core::ops::Sub<Output = Self>
41 + core::ops::Mul<Output = Self>
42 + core::ops::Div<Output = Self>
43 + std::ops::AddAssign
44 + std::ops::SubAssign
45 + std::ops::MulAssign
46 + std::ops::DivAssign
47 + std::cmp::PartialOrd
48 + std::cmp::PartialEq
49{
50 const DIGITS: u32;
51 const EPSILON: Self;
52 const INFINITY: Self;
53 const MANTISSA_DIGITS: u32;
54 const MAX_10_EXP: i32;
55 const MAX_EXP: i32;
56 const MIN_10_EXP: i32;
57 const MIN_EXP: i32;
58 const MIN_POSITIVE: Self;
59 const NAN: Self;
60 const NEG_INFINITY: Self;
61 const RADIX: u32;
62
63 fn new(val: f32) -> Self;
64 fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
65 __expand_new(context, val)
66 }
67}
68
69macro_rules! impl_float {
70 (half $primitive:ident, $kind:ident) => {
71 impl_float!($primitive, $kind, |val| $primitive::from_f32(val));
72 };
73 ($primitive:ident, $kind:ident) => {
74 impl_float!($primitive, $kind, |val| val as $primitive);
75 };
76 ($primitive:ident, $kind:ident, $new:expr) => {
77 impl CubeType for $primitive {
78 type ExpandType = ExpandElementTyped<$primitive>;
79 }
80
81 impl CubePrimitive for $primitive {
82 fn as_elem_native() -> Option<Elem> {
84 Some(Elem::Float(FloatKind::$kind))
85 }
86 }
87
88 impl IntoRuntime for $primitive {
89 fn __expand_runtime_method(
90 self,
91 context: &mut CubeContext,
92 ) -> ExpandElementTyped<Self> {
93 let expand: ExpandElementTyped<Self> = self.into();
94 Init::init(expand, context)
95 }
96 }
97
98 impl Numeric for $primitive {
99 fn min_value() -> Self {
100 <Self as num_traits::Float>::min_value()
101 }
102 fn max_value() -> Self {
103 <Self as num_traits::Float>::max_value()
104 }
105 }
106
107 impl ExpandElementBaseInit for $primitive {
108 fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
109 init_expand_element(context, elem)
110 }
111 }
112
113 impl Float for $primitive {
114 const DIGITS: u32 = $primitive::DIGITS;
115 const EPSILON: Self = $primitive::EPSILON;
116 const INFINITY: Self = $primitive::INFINITY;
117 const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
118 const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
119 const MAX_EXP: i32 = $primitive::MAX_EXP;
120 const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
121 const MIN_EXP: i32 = $primitive::MIN_EXP;
122 const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
123 const NAN: Self = $primitive::NAN;
124 const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
125 const RADIX: u32 = $primitive::RADIX;
126
127 fn new(val: f32) -> Self {
128 $new(val)
129 }
130 }
131
132 impl LaunchArgExpand for $primitive {
133 type CompilationArg = ();
134
135 fn expand(
136 _: &Self::CompilationArg,
137 builder: &mut KernelBuilder,
138 ) -> ExpandElementTyped<Self> {
139 builder.scalar($primitive::as_elem(&builder.context)).into()
140 }
141 }
142 };
143}
144
145impl_float!(half f16, F16);
146impl_float!(half bf16, BF16);
147impl_float!(f32, F32);
148impl_float!(f64, F64);
149
150impl ScalarArgSettings for f16 {
151 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
152 settings.register_f16(*self);
153 }
154}
155
156impl ScalarArgSettings for bf16 {
157 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
158 settings.register_bf16(*self);
159 }
160}
161
162impl ScalarArgSettings for f32 {
163 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
164 settings.register_f32(*self);
165 }
166}
167
168impl ScalarArgSettings for f64 {
169 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
170 settings.register_f64(*self);
171 }
172}