// cubecl_core/frontend/operation/unary.rs

1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6    flex32,
7    ir::{Arithmetic, ExpandElement, Scope},
8    prelude::{CubePrimitive, ExpandElementTyped},
9    tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15    use super::*;
16
17    pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18        unary_expand(scope, x.into(), Operator::Not).into()
19    }
20}
21
22pub mod neg {
23    use super::*;
24
25    pub fn expand<E: CubePrimitive>(
26        scope: &mut Scope,
27        x: ExpandElementTyped<E>,
28    ) -> ExpandElementTyped<E> {
29        unary_expand(scope, x.into(), Arithmetic::Neg).into()
30    }
31}
32
/// Declares a public unary-operation trait whose output has the same type as its input,
/// then gives every listed type an empty `impl` (so each uses the trait defaults).
///
/// Arguments, in order:
/// - trait name;
/// - host-side function name: a placeholder whose body is `unexpanded!()`;
/// - expansion function name: forwards the input and the operator to `unary_expand`;
/// - the IR operator expression (e.g. `Arithmetic::Abs`);
/// - a comma-separated list of implementing types (no trailing comma).
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $($ty:ty),*) => {
        #[doc = concat!(
            "Unary operation `", stringify!($func), "`, expanded as `", stringify!($op), "`."
        )]
        pub trait $name: CubePrimitive + Sized {
            // Placeholder body only (`unexpanded!()`), so `x` is intentionally unused.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let output = unary_expand(scope, x.into(), $op);
                output.into()
            }
        }

        $(
            impl $name for $ty {}
        )*
    };
}
49
50impl Exp for f32 {
51    fn exp(x: Self) -> Self {
52        x.exp()
53    }
54}
55
/// Like `impl_unary_func!`, but the expansion's output type is derived from the input
/// type with `.line($out_line)` applied. This presumably keeps the element type and fixes
/// the output line size to the caller-supplied value; confirm against `Type::line`.
///
/// The operation is emitted through `unary_expand_fixed_output`, which receives that
/// explicit output type.
///
/// Arguments, in order: trait name, host-side function name (placeholder body
/// `unexpanded!()`), expansion function name, IR operator, output line argument, and the
/// comma-separated implementing types.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $out_line:expr, $($ty:ty),*) => {
        #[doc = concat!(
            "Unary operation `", stringify!($func), "`, expanded as `", stringify!($op),
            "` with output line argument `", stringify!($out_line), "`."
        )]
        pub trait $name: CubePrimitive + Sized {
            // Placeholder body only (`unexpanded!()`), so `x` is intentionally unused.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                // Output type: the input's type re-lined with the fixed argument.
                let out_ty = input.ty.line($out_line);
                unary_expand_fixed_output(scope, input, out_ty, $op).into()
            }
        }

        $(
            impl $name for $ty {}
        )*
    };
}
74
/// Declares a unary-operation trait whose result is always the fixed type `$out`,
/// independent of `Self`.
///
/// During expansion, the output IR type is built from `<$out as CubePrimitive>::as_type`
/// and given the same `line_size()` as the input. The operation is then emitted through
/// `unary_expand_fixed_output`.
///
/// Arguments, in order: trait name, host-side function name (placeholder body
/// `unexpanded!()`), expansion function name, output type, IR operator, and the
/// comma-separated implementing types.
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $func:ident, $func_expand:ident, $out:ty, $op:expr, $($ty:ty),*) => {
        #[doc = concat!(
            "Unary operation `", stringify!($func), "` returning `", stringify!($out),
            "`, expanded as `", stringify!($op), "`."
        )]
        pub trait $name: CubePrimitive + Sized {
            // Placeholder body only (`unexpanded!()`), so `x` is intentionally unused.
            #[allow(unused_variables)]
            fn $func(x: Self) -> $out {
                unexpanded!()
            }

            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out> {
                let input: ExpandElement = x.into();
                // Same evaluation order as a single nested expression:
                // element type first, then the input's line size.
                let out_elem = <$out as CubePrimitive>::as_type(scope);
                let out_ty = Type::new(out_elem).line(input.ty.line_size());
                unary_expand_fixed_output(scope, input, out_ty, $op).into()
            }
        }

        $(
            impl $name for $ty {}
        )*
    };
}
93
94impl_unary_func!(
95    Abs,
96    abs,
97    __expand_abs,
98    Arithmetic::Abs,
99    e2m1,
100    e4m3,
101    e5m2,
102    ue8m0,
103    f16,
104    bf16,
105    flex32,
106    tf32,
107    f32,
108    f64,
109    i8,
110    i16,
111    i32,
112    i64,
113    u8,
114    u16,
115    u32,
116    u64
117);
118impl_unary_func!(
119    Exp,
120    exp,
121    __expand_exp,
122    Arithmetic::Exp,
123    f16,
124    bf16,
125    flex32,
126    tf32,
127    // f32,
128    f64
129);
130impl_unary_func!(
131    Log,
132    log,
133    __expand_log,
134    Arithmetic::Log,
135    f16,
136    bf16,
137    flex32,
138    tf32,
139    f32,
140    f64
141);
142impl_unary_func!(
143    Log1p,
144    log1p,
145    __expand_log1p,
146    Arithmetic::Log1p,
147    f16,
148    bf16,
149    flex32,
150    tf32,
151    f32,
152    f64
153);
154impl_unary_func!(
155    Cos,
156    cos,
157    __expand_cos,
158    Arithmetic::Cos,
159    f16,
160    bf16,
161    flex32,
162    tf32,
163    f32,
164    f64
165);
166impl_unary_func!(
167    Sin,
168    sin,
169    __expand_sin,
170    Arithmetic::Sin,
171    f16,
172    bf16,
173    flex32,
174    tf32,
175    f32,
176    f64
177);
178impl_unary_func!(
179    Tanh,
180    tanh,
181    __expand_tanh,
182    Arithmetic::Tanh,
183    f16,
184    bf16,
185    flex32,
186    tf32,
187    f32,
188    f64
189);
190impl_unary_func!(
191    Sqrt,
192    sqrt,
193    __expand_sqrt,
194    Arithmetic::Sqrt,
195    f16,
196    bf16,
197    flex32,
198    tf32,
199    f32,
200    f64
201);
202impl_unary_func!(
203    InverseSqrt,
204    inverse_sqrt,
205    __expand_inverse_sqrt,
206    Arithmetic::InverseSqrt,
207    f16,
208    bf16,
209    flex32,
210    tf32,
211    f32,
212    f64
213);
214impl_unary_func!(
215    Round,
216    round,
217    __expand_round,
218    Arithmetic::Round,
219    f16,
220    bf16,
221    flex32,
222    tf32,
223    f32,
224    f64
225);
226impl_unary_func!(
227    Floor,
228    floor,
229    __expand_floor,
230    Arithmetic::Floor,
231    f16,
232    bf16,
233    flex32,
234    tf32,
235    f32,
236    f64
237);
238impl_unary_func!(
239    Ceil,
240    ceil,
241    __expand_ceil,
242    Arithmetic::Ceil,
243    f16,
244    bf16,
245    flex32,
246    tf32,
247    f32,
248    f64
249);
250impl_unary_func!(
251    Trunc,
252    trunc,
253    __expand_trunc,
254    Arithmetic::Trunc,
255    f16,
256    bf16,
257    flex32,
258    tf32,
259    f32,
260    f64
261);
262impl_unary_func!(
263    Erf,
264    erf,
265    __expand_erf,
266    Arithmetic::Erf,
267    f16,
268    bf16,
269    flex32,
270    tf32,
271    f32,
272    f64
273);
274impl_unary_func!(
275    Recip,
276    recip,
277    __expand_recip,
278    Arithmetic::Recip,
279    f16,
280    bf16,
281    flex32,
282    tf32,
283    f32,
284    f64
285);
286impl_unary_func_fixed_out_vectorization!(
287    Magnitude,
288    magnitude,
289    __expand_magnitude,
290    Arithmetic::Magnitude,
291    0,
292    f16,
293    bf16,
294    flex32,
295    tf32,
296    f32,
297    f64
298);
299impl_unary_func!(
300    Normalize,
301    normalize,
302    __expand_normalize,
303    Arithmetic::Normalize,
304    f16,
305    bf16,
306    flex32,
307    tf32,
308    f32,
309    f64
310);
311impl_unary_func_fixed_out_ty!(
312    CountOnes,
313    count_ones,
314    __expand_count_ones,
315    u32,
316    Bitwise::CountOnes,
317    u8,
318    i8,
319    u16,
320    i16,
321    u32,
322    i32,
323    u64,
324    i64
325);
326impl_unary_func!(
327    ReverseBits,
328    reverse_bits,
329    __expand_reverse_bits,
330    Bitwise::ReverseBits,
331    u8,
332    i8,
333    u16,
334    i16,
335    u32,
336    i32,
337    u64,
338    i64
339);
340
341impl_unary_func!(
342    BitwiseNot,
343    bitwise_not,
344    __expand_bitwise_not,
345    Bitwise::BitwiseNot,
346    u8,
347    i8,
348    u16,
349    i16,
350    u32,
351    i32,
352    u64,
353    i64
354);
355impl_unary_func_fixed_out_ty!(
356    LeadingZeros,
357    leading_zeros,
358    __expand_leading_zeros,
359    u32,
360    Bitwise::LeadingZeros,
361    u8,
362    i8,
363    u16,
364    i16,
365    u32,
366    i32,
367    u64,
368    i64
369);
370impl_unary_func_fixed_out_ty!(
371    FindFirstSet,
372    find_first_set,
373    __expand_find_first_set,
374    u32,
375    Bitwise::FindFirstSet,
376    u8,
377    i8,
378    u16,
379    i16,
380    u32,
381    i32,
382    u64,
383    i64
384);
385impl_unary_func_fixed_out_ty!(
386    IsNan,
387    is_nan,
388    __expand_is_nan,
389    bool,
390    Comparison::IsNan,
391    f16,
392    bf16,
393    flex32,
394    tf32,
395    f32,
396    f64
397);
398impl_unary_func_fixed_out_ty!(
399    IsInf,
400    is_inf,
401    __expand_is_inf,
402    bool,
403    Comparison::IsInf,
404    f16,
405    bf16,
406    flex32,
407    tf32,
408    f32,
409    f64
410);