cubecl_core/frontend/operation/unary.rs

1use half::{bf16, f16};
2
3use crate::{
4    flex32,
5    frontend::CubeContext,
6    ir::Operator,
7    prelude::{CubePrimitive, ExpandElement, ExpandElementTyped},
8    tf32, unexpanded,
9};
10
11use super::base::{unary_expand, unary_expand_fixed_output};
12
13pub mod not {
14    use super::*;
15
16    pub fn expand(
17        context: &mut CubeContext,
18        x: ExpandElementTyped<bool>,
19    ) -> ExpandElementTyped<bool> {
20        unary_expand(context, x.into(), Operator::Not).into()
21    }
22}
23
24pub mod neg {
25    use crate::prelude::Numeric;
26
27    use super::*;
28
29    pub fn expand<N: Numeric>(
30        context: &mut CubeContext,
31        x: ExpandElementTyped<N>,
32    ) -> ExpandElementTyped<N> {
33        unary_expand(context, x.into(), Operator::Neg).into()
34    }
35}
36
/// Declares a public unary-operation trait and implements it (using the default methods) for
/// every listed type.
///
/// Syntax: `impl_unary_func!($(#[attr])* Trait, method, __expand_method, operator, Type, ...)`.
/// Optional leading attributes (e.g. `///` doc comments) are forwarded onto the generated trait,
/// after a generated one-line summary; invocations without attributes behave exactly as before.
///
/// The generated `__expand_*` function converts its argument into an `ExpandElement`, hands it to
/// `unary_expand` together with `operator`, and converts the result back into
/// `ExpandElementTyped<Self>`, so the output type is the input type.
macro_rules! impl_unary_func {
    (
        $(#[$attr:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        #[doc = concat!(
            "Unary `", stringify!($method_name), "` operation, expanded with `",
            stringify!($operator), "`."
        )]
        $(#[$attr])*
        pub trait $trait_name: CubePrimitive + Sized {
            #[doc = concat!(
                "Placeholder for `", stringify!($method_name), "`; its body is `unexpanded!()`. ",
                "See `", stringify!($method_name_expand), "` for the expansion counterpart."
            )]
            // `x` is intentionally unused: the body only invokes `unexpanded!()`.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!(
                "Expands `", stringify!($method_name), "` by calling `unary_expand` with `",
                stringify!($operator), "`; the output has the same type as the input."
            )]
            fn $method_name_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                unary_expand(context, x.into(), $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
53
/// Like `impl_unary_func!`, but the output item's `vectorization` is overwritten with a fixed
/// value instead of being copied from the input.
///
/// Syntax: `impl_unary_func_fixed_out_vectorization!($(#[attr])* Trait, method,
/// __expand_method, operator, out_vectorization, Type, ...)`. Optional leading attributes
/// (e.g. `///` doc comments) are forwarded onto the generated trait; invocations without
/// attributes behave exactly as before.
///
/// The generated `__expand_*` function copies the input's item, replaces only its
/// `vectorization` with `out_vectorization`, and calls `unary_expand_fixed_output` with that
/// item and `operator`.
macro_rules! impl_unary_func_fixed_out_vectorization {
    (
        $(#[$attr:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $out_vectorization:expr,
        $($type:ty),* $(,)?
    ) => {
        #[doc = concat!(
            "Unary `", stringify!($method_name), "` operation, expanded with `",
            stringify!($operator), "`; the output vectorization is fixed to `",
            stringify!($out_vectorization), "`."
        )]
        $(#[$attr])*
        pub trait $trait_name: CubePrimitive + Sized {
            #[doc = concat!(
                "Placeholder for `", stringify!($method_name), "`; its body is `unexpanded!()`. ",
                "See `", stringify!($method_name_expand), "` for the expansion counterpart."
            )]
            // `x` is intentionally unused: the body only invokes `unexpanded!()`.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            #[doc = concat!(
                "Expands `", stringify!($method_name), "` by calling `unary_expand_fixed_output` ",
                "with `", stringify!($operator), "` and an output item whose vectorization is `",
                stringify!($out_vectorization), "`."
            )]
            fn $method_name_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let expand_element: ExpandElement = x.into();
                // Start from the input's item; only the vectorization is replaced.
                let mut item = expand_element.item;
                item.vectorization = $out_vectorization;
                unary_expand_fixed_output(context, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
73
/// Like `impl_unary_func!`, but the operation returns a fixed output type `out_ty` instead of
/// `Self`.
///
/// Syntax: `impl_unary_func_fixed_out_ty!($(#[attr])* Trait, method, __expand_method, out_ty,
/// operator, Type, ...)`. Optional leading attributes (e.g. `///` doc comments) are forwarded
/// onto the generated trait; invocations without attributes behave exactly as before.
///
/// The generated `__expand_*` function copies the input's item, replaces only its `elem` with
/// `<out_ty as CubePrimitive>::as_elem(context)`, and calls `unary_expand_fixed_output` with
/// that item and `operator`.
macro_rules! impl_unary_func_fixed_out_ty {
    (
        $(#[$attr:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $out_ty:ty,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        #[doc = concat!(
            "Unary `", stringify!($method_name), "` operation returning `", stringify!($out_ty),
            "`, expanded with `", stringify!($operator), "`."
        )]
        $(#[$attr])*
        pub trait $trait_name: CubePrimitive + Sized {
            #[doc = concat!(
                "Placeholder for `", stringify!($method_name), "`; its body is `unexpanded!()`. ",
                "See `", stringify!($method_name_expand), "` for the expansion counterpart."
            )]
            // `x` is intentionally unused: the body only invokes `unexpanded!()`.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> $out_ty {
                unexpanded!()
            }

            #[doc = concat!(
                "Expands `", stringify!($method_name), "` by calling `unary_expand_fixed_output` ",
                "with `", stringify!($operator), "` and an output element of type `",
                stringify!($out_ty), "`."
            )]
            fn $method_name_expand(
                context: &mut CubeContext,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<$out_ty> {
                let expand_element: ExpandElement = x.into();
                // Start from the input's item; only the element type is replaced.
                let mut item = expand_element.item;
                item.elem = <$out_ty as CubePrimitive>::as_elem(context);
                unary_expand_fixed_output(context, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
93
94impl_unary_func!(
95    Abs,
96    abs,
97    __expand_abs,
98    Operator::Abs,
99    f16,
100    bf16,
101    flex32,
102    tf32,
103    f32,
104    f64,
105    i8,
106    i16,
107    i32,
108    i64,
109    u8,
110    u16,
111    u32,
112    u64
113);
114impl_unary_func!(
115    Exp,
116    exp,
117    __expand_exp,
118    Operator::Exp,
119    f16,
120    bf16,
121    flex32,
122    tf32,
123    f32,
124    f64
125);
126impl_unary_func!(
127    Log,
128    log,
129    __expand_log,
130    Operator::Log,
131    f16,
132    bf16,
133    flex32,
134    tf32,
135    f32,
136    f64
137);
138impl_unary_func!(
139    Log1p,
140    log1p,
141    __expand_log1p,
142    Operator::Log1p,
143    f16,
144    bf16,
145    flex32,
146    tf32,
147    f32,
148    f64
149);
150impl_unary_func!(
151    Cos,
152    cos,
153    __expand_cos,
154    Operator::Cos,
155    f16,
156    bf16,
157    flex32,
158    tf32,
159    f32,
160    f64
161);
162impl_unary_func!(
163    Sin,
164    sin,
165    __expand_sin,
166    Operator::Sin,
167    f16,
168    bf16,
169    flex32,
170    tf32,
171    f32,
172    f64
173);
174impl_unary_func!(
175    Tanh,
176    tanh,
177    __expand_tanh,
178    Operator::Tanh,
179    f16,
180    bf16,
181    flex32,
182    tf32,
183    f32,
184    f64
185);
186impl_unary_func!(
187    Sqrt,
188    sqrt,
189    __expand_sqrt,
190    Operator::Sqrt,
191    f16,
192    bf16,
193    flex32,
194    tf32,
195    f32,
196    f64
197);
198impl_unary_func!(
199    Round,
200    round,
201    __expand_round,
202    Operator::Round,
203    f16,
204    bf16,
205    flex32,
206    tf32,
207    f32,
208    f64
209);
210impl_unary_func!(
211    Floor,
212    floor,
213    __expand_floor,
214    Operator::Floor,
215    f16,
216    bf16,
217    flex32,
218    tf32,
219    f32,
220    f64
221);
222impl_unary_func!(
223    Ceil,
224    ceil,
225    __expand_ceil,
226    Operator::Ceil,
227    f16,
228    bf16,
229    flex32,
230    tf32,
231    f32,
232    f64
233);
234impl_unary_func!(
235    Erf,
236    erf,
237    __expand_erf,
238    Operator::Erf,
239    f16,
240    bf16,
241    flex32,
242    tf32,
243    f32,
244    f64
245);
246impl_unary_func!(
247    Recip,
248    recip,
249    __expand_recip,
250    Operator::Recip,
251    f16,
252    bf16,
253    flex32,
254    tf32,
255    f32,
256    f64
257);
258impl_unary_func_fixed_out_vectorization!(
259    Magnitude,
260    magnitude,
261    __expand_magnitude,
262    Operator::Magnitude,
263    None,
264    f16,
265    bf16,
266    flex32,
267    tf32,
268    f32,
269    f64
270);
271impl_unary_func!(
272    Normalize,
273    normalize,
274    __expand_normalize,
275    Operator::Normalize,
276    f16,
277    bf16,
278    flex32,
279    tf32,
280    f32,
281    f64
282);
283impl_unary_func_fixed_out_ty!(
284    CountOnes,
285    count_ones,
286    __expand_count_ones,
287    u32,
288    Operator::CountOnes,
289    u8,
290    i8,
291    u16,
292    i16,
293    u32,
294    i32,
295    u64,
296    i64
297);
298impl_unary_func!(
299    ReverseBits,
300    reverse_bits,
301    __expand_reverse_bits,
302    Operator::ReverseBits,
303    u8,
304    i8,
305    u16,
306    i16,
307    u32,
308    i32,
309    u64,
310    i64
311);
312impl_unary_func!(
313    Not,
314    not,
315    __expand_not,
316    Operator::Not,
317    u8,
318    i8,
319    u16,
320    i16,
321    u32,
322    i32,
323    u64,
324    i64
325);