// cubecl_core/frontend/operation/unary.rs

1use cubecl_ir::{Bitwise, Operator};
2use half::{bf16, f16};
3
4use crate::{
5    flex32,
6    ir::{Arithmetic, ExpandElement, Scope},
7    prelude::{CubePrimitive, ExpandElementTyped},
8    tf32, unexpanded,
9};
10
11use super::base::{unary_expand, unary_expand_fixed_output};
12
13pub mod not {
14    use super::*;
15
16    pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
17        unary_expand(scope, x.into(), Operator::Not).into()
18    }
19}
20
21pub mod neg {
22    use super::*;
23
24    pub fn expand<E: CubePrimitive>(
25        scope: &mut Scope,
26        x: ExpandElementTyped<E>,
27    ) -> ExpandElementTyped<E> {
28        unary_expand(scope, x.into(), Arithmetic::Neg).into()
29    }
30}
31
/// Declares a public unary-operation trait and implements it for a list of primitive types.
///
/// Arguments, in order:
/// - optional outer attributes (e.g. `///` doc comments), forwarded to the generated trait;
/// - `$trait_name`: the name of the generated trait;
/// - `$method_name`: the host-side method, whose body is `unexpanded!()`;
/// - `$method_name_expand`: the expansion method, which forwards to `unary_expand`;
/// - `$operator`: the operator handed to `unary_expand`;
/// - the types that receive an empty `impl` of the trait (a trailing comma is accepted).
///
/// The expanded result is typed as `ExpandElementTyped<Self>`, i.e. the same primitive as the
/// input.
macro_rules! impl_unary_func {
    (
        $(#[$meta:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        $(#[$meta])*
        pub trait $trait_name: CubePrimitive + Sized {
            /// Host-side placeholder for this operation; its body is `unexpanded!()`.
            // `x` belongs to the signature but is deliberately unused by the placeholder body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            /// Expands this operation on `x` into `scope` via `unary_expand`.
            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                unary_expand(scope, x.into(), $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
48
/// Like `impl_unary_func!`, but the expanded output's vectorization is replaced by a fixed value.
///
/// Arguments, in order:
/// - optional outer attributes (e.g. `///` doc comments), forwarded to the generated trait;
/// - `$trait_name`, `$method_name`, `$method_name_expand`, `$operator`: as in `impl_unary_func!`;
/// - `$out_vectorization`: the value assigned to the output item's `vectorization` field;
/// - the types that receive an empty `impl` of the trait (a trailing comma is accepted).
///
/// The output item starts as a copy of the input's item, so everything except `vectorization`
/// (including the element type) is carried over unchanged.
macro_rules! impl_unary_func_fixed_out_vectorization {
    (
        $(#[$meta:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $out_vectorization:expr,
        $($type:ty),* $(,)?
    ) => {
        $(#[$meta])*
        pub trait $trait_name: CubePrimitive + Sized {
            /// Host-side placeholder for this operation; its body is `unexpanded!()`.
            // `x` belongs to the signature but is deliberately unused by the placeholder body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            /// Expands this operation on `x` into `scope` via `unary_expand_fixed_output`,
            /// overriding only the output vectorization.
            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let expand_element: ExpandElement = x.into();
                // Copy the input's item so that only the vectorization differs in the output.
                let mut item = expand_element.item;
                item.vectorization = $out_vectorization;
                unary_expand_fixed_output(scope, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
68
/// Like `impl_unary_func!`, but the operation returns a fixed primitive type `$out_ty`.
///
/// Arguments, in order:
/// - optional outer attributes (e.g. `///` doc comments), forwarded to the generated trait;
/// - `$trait_name`, `$method_name`, `$method_name_expand`: as in `impl_unary_func!`;
/// - `$out_ty`: the primitive type of the result (both host and expanded forms);
/// - `$operator`: the operator handed to `unary_expand_fixed_output`;
/// - the types that receive an empty `impl` of the trait (a trailing comma is accepted).
///
/// The output item starts as a copy of the input's item with only `elem` replaced by
/// `$out_ty`'s element, so the input's vectorization is carried over.
macro_rules! impl_unary_func_fixed_out_ty {
    (
        $(#[$meta:meta])*
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $out_ty:ty,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        $(#[$meta])*
        pub trait $trait_name: CubePrimitive + Sized {
            /// Host-side placeholder for this operation; its body is `unexpanded!()`.
            // `x` belongs to the signature but is deliberately unused by the placeholder body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> $out_ty {
                unexpanded!()
            }

            /// Expands this operation on `x` into `scope` via `unary_expand_fixed_output`,
            /// producing an element of type `$out_ty`.
            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<$out_ty> {
                let expand_element: ExpandElement = x.into();
                // Copy the input's item so that only the element type differs in the output.
                let mut item = expand_element.item;
                item.elem = <$out_ty as CubePrimitive>::as_elem(scope);
                unary_expand_fixed_output(scope, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
88
89impl_unary_func!(
90    Abs,
91    abs,
92    __expand_abs,
93    Arithmetic::Abs,
94    f16,
95    bf16,
96    flex32,
97    tf32,
98    f32,
99    f64,
100    i8,
101    i16,
102    i32,
103    i64,
104    u8,
105    u16,
106    u32,
107    u64
108);
109impl_unary_func!(
110    Exp,
111    exp,
112    __expand_exp,
113    Arithmetic::Exp,
114    f16,
115    bf16,
116    flex32,
117    tf32,
118    f32,
119    f64
120);
121impl_unary_func!(
122    Log,
123    log,
124    __expand_log,
125    Arithmetic::Log,
126    f16,
127    bf16,
128    flex32,
129    tf32,
130    f32,
131    f64
132);
133impl_unary_func!(
134    Log1p,
135    log1p,
136    __expand_log1p,
137    Arithmetic::Log1p,
138    f16,
139    bf16,
140    flex32,
141    tf32,
142    f32,
143    f64
144);
145impl_unary_func!(
146    Cos,
147    cos,
148    __expand_cos,
149    Arithmetic::Cos,
150    f16,
151    bf16,
152    flex32,
153    tf32,
154    f32,
155    f64
156);
157impl_unary_func!(
158    Sin,
159    sin,
160    __expand_sin,
161    Arithmetic::Sin,
162    f16,
163    bf16,
164    flex32,
165    tf32,
166    f32,
167    f64
168);
169impl_unary_func!(
170    Tanh,
171    tanh,
172    __expand_tanh,
173    Arithmetic::Tanh,
174    f16,
175    bf16,
176    flex32,
177    tf32,
178    f32,
179    f64
180);
181impl_unary_func!(
182    Sqrt,
183    sqrt,
184    __expand_sqrt,
185    Arithmetic::Sqrt,
186    f16,
187    bf16,
188    flex32,
189    tf32,
190    f32,
191    f64
192);
193impl_unary_func!(
194    Round,
195    round,
196    __expand_round,
197    Arithmetic::Round,
198    f16,
199    bf16,
200    flex32,
201    tf32,
202    f32,
203    f64
204);
205impl_unary_func!(
206    Floor,
207    floor,
208    __expand_floor,
209    Arithmetic::Floor,
210    f16,
211    bf16,
212    flex32,
213    tf32,
214    f32,
215    f64
216);
217impl_unary_func!(
218    Ceil,
219    ceil,
220    __expand_ceil,
221    Arithmetic::Ceil,
222    f16,
223    bf16,
224    flex32,
225    tf32,
226    f32,
227    f64
228);
229impl_unary_func!(
230    Erf,
231    erf,
232    __expand_erf,
233    Arithmetic::Erf,
234    f16,
235    bf16,
236    flex32,
237    tf32,
238    f32,
239    f64
240);
241impl_unary_func!(
242    Recip,
243    recip,
244    __expand_recip,
245    Arithmetic::Recip,
246    f16,
247    bf16,
248    flex32,
249    tf32,
250    f32,
251    f64
252);
253impl_unary_func_fixed_out_vectorization!(
254    Magnitude,
255    magnitude,
256    __expand_magnitude,
257    Arithmetic::Magnitude,
258    None,
259    f16,
260    bf16,
261    flex32,
262    tf32,
263    f32,
264    f64
265);
266impl_unary_func!(
267    Normalize,
268    normalize,
269    __expand_normalize,
270    Arithmetic::Normalize,
271    f16,
272    bf16,
273    flex32,
274    tf32,
275    f32,
276    f64
277);
278impl_unary_func_fixed_out_ty!(
279    CountOnes,
280    count_ones,
281    __expand_count_ones,
282    u32,
283    Bitwise::CountOnes,
284    u8,
285    i8,
286    u16,
287    i16,
288    u32,
289    i32,
290    u64,
291    i64
292);
293impl_unary_func!(
294    ReverseBits,
295    reverse_bits,
296    __expand_reverse_bits,
297    Bitwise::ReverseBits,
298    u8,
299    i8,
300    u16,
301    i16,
302    u32,
303    i32,
304    u64,
305    i64
306);
307
308impl_unary_func!(
309    BitwiseNot,
310    bitwise_not,
311    __expand_bitwise_not,
312    Bitwise::BitwiseNot,
313    u8,
314    i8,
315    u16,
316    i16,
317    u32,
318    i32,
319    u64,
320    i64
321);
322impl_unary_func_fixed_out_ty!(
323    LeadingZeros,
324    leading_zeros,
325    __expand_leading_zeros,
326    u32,
327    Bitwise::LeadingZeros,
328    u8,
329    i8,
330    u16,
331    i16,
332    u32,
333    i32,
334    u64,
335    i64
336);
337impl_unary_func_fixed_out_ty!(
338    FindFirstSet,
339    find_first_set,
340    __expand_find_first_set,
341    u32,
342    Bitwise::FindFirstSet,
343    u8,
344    i8,
345    u16,
346    i16,
347    u32,
348    i32,
349    u64,
350    i64
351);