1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6    flex32,
7    ir::{Arithmetic, ExpandElement, Scope},
8    prelude::{CubePrimitive, ExpandElementTyped},
9    tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15    use super::*;
16
17    pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18        unary_expand(scope, x.into(), Operator::Not).into()
19    }
20}
21
22pub mod neg {
23    use super::*;
24
25    pub fn expand<E: CubePrimitive>(
26        scope: &mut Scope,
27        x: ExpandElementTyped<E>,
28    ) -> ExpandElementTyped<E> {
29        unary_expand(scope, x.into(), Arithmetic::Neg).into()
30    }
31}
32
/// Declares a unary-operation trait whose expanded result is typed as `Self`, and implements
/// it for a list of primitive types.
///
/// Arguments, in order:
/// - `$trait_name`: the generated trait, bounded by `CubePrimitive + Sized`.
/// - `$method_name`: the user-facing function; its default body is `unexpanded!()`.
/// - `$method_name_expand`: the expansion hook; it lowers the call to `$operator` through
///   `unary_expand` and converts the result into `ExpandElementTyped<Self>`.
/// - `$operator`: the IR operation constructor handed to `unary_expand`
///   (e.g. `Arithmetic::Abs`).
/// - `$type`: every primitive that receives an empty impl relying on the trait's defaults.
///
/// A trailing comma after the last type is accepted.
macro_rules! impl_unary_func {
    (
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        pub trait $trait_name: CubePrimitive + Sized {
            // `x` is unused by the `unexpanded!()` stub body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                unary_expand(scope, x.into(), $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
49
// `f32` implements `Exp` by hand instead of being listed in the `impl_unary_func!(Exp, ...)`
// invocation (listing it there as well would produce a conflicting impl). Only `exp` is
// overridden: it evaluates directly instead of the macro's `unexpanded!()` stub (presumably so
// host-side calls on plain `f32` values work — confirm). `__expand_exp` keeps the trait's
// default, so expansion still emits `Arithmetic::Exp`.
impl Exp for f32 {
    fn exp(x: Self) -> Self {
        // Method-call syntax resolves to `f32`'s own `exp` method: the trait's `exp` takes no
        // `self` receiver, so this call cannot recurse into `Exp::exp`.
        x.exp()
    }
}
55
/// Like `impl_unary_func!`, but the expanded result's line size is forced to
/// `$out_vectorization` instead of following the input.
///
/// Arguments, in order:
/// - `$trait_name`, `$method_name`, `$method_name_expand`: the generated trait, its
///   `unexpanded!()` user-facing function, and its expansion hook.
/// - `$operator`: the IR operation constructor passed to `unary_expand_fixed_output`.
/// - `$out_vectorization`: the line size applied to the input's type via `Type::line` to build
///   the output type.
/// - `$type`: every primitive that receives an empty impl relying on the trait's defaults.
///
/// A trailing comma after the last type is accepted.
macro_rules! impl_unary_func_fixed_out_vectorization {
    (
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $operator:expr,
        $out_vectorization:expr,
        $($type:ty),* $(,)?
    ) => {
        pub trait $trait_name: CubePrimitive + Sized {
            // `x` is unused by the `unexpanded!()` stub body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> Self {
                unexpanded!()
            }

            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<Self> {
                let expand_element: ExpandElement = x.into();
                // Start from the input's type and override only its line size.
                let item = expand_element.ty.line($out_vectorization);
                unary_expand_fixed_output(scope, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
74
/// Like `impl_unary_func!`, but the operation returns the fixed element type `$out_ty`
/// (for example `u32` or `bool`) while keeping the input's line size.
///
/// Arguments, in order:
/// - `$trait_name`, `$method_name`, `$method_name_expand`: the generated trait, its
///   `unexpanded!()` user-facing function, and its expansion hook.
/// - `$out_ty`: the element type of the result; both the user-facing function and the
///   expansion hook are typed with it.
/// - `$operator`: the IR operation constructor passed to `unary_expand_fixed_output`.
/// - `$type`: every primitive that receives an empty impl relying on the trait's defaults.
///
/// A trailing comma after the last type is accepted.
macro_rules! impl_unary_func_fixed_out_ty {
    (
        $trait_name:ident,
        $method_name:ident,
        $method_name_expand:ident,
        $out_ty:ty,
        $operator:expr,
        $($type:ty),* $(,)?
    ) => {
        pub trait $trait_name: CubePrimitive + Sized {
            // `x` is unused by the `unexpanded!()` stub body.
            #[allow(unused_variables)]
            fn $method_name(x: Self) -> $out_ty {
                unexpanded!()
            }

            fn $method_name_expand(
                scope: &mut Scope,
                x: Self::ExpandType,
            ) -> ExpandElementTyped<$out_ty> {
                let expand_element: ExpandElement = x.into();
                // Element type comes from `$out_ty`; the line size is copied from the input.
                let item = Type::new(<$out_ty as CubePrimitive>::as_type(scope))
                    .line(expand_element.ty.line_size());
                unary_expand_fixed_output(scope, expand_element, item, $operator).into()
            }
        }

        $(impl $trait_name for $type {})*
    };
}
93
94impl_unary_func!(
95    Abs,
96    abs,
97    __expand_abs,
98    Arithmetic::Abs,
99    e2m1,
100    e4m3,
101    e5m2,
102    ue8m0,
103    f16,
104    bf16,
105    flex32,
106    tf32,
107    f32,
108    f64,
109    i8,
110    i16,
111    i32,
112    i64,
113    u8,
114    u16,
115    u32,
116    u64
117);
118impl_unary_func!(
119    Exp,
120    exp,
121    __expand_exp,
122    Arithmetic::Exp,
123    f16,
124    bf16,
125    flex32,
126    tf32,
127    f64
129);
130impl_unary_func!(
131    Log,
132    log,
133    __expand_log,
134    Arithmetic::Log,
135    f16,
136    bf16,
137    flex32,
138    tf32,
139    f32,
140    f64
141);
142impl_unary_func!(
143    Log1p,
144    log1p,
145    __expand_log1p,
146    Arithmetic::Log1p,
147    f16,
148    bf16,
149    flex32,
150    tf32,
151    f32,
152    f64
153);
154impl_unary_func!(
155    Cos,
156    cos,
157    __expand_cos,
158    Arithmetic::Cos,
159    f16,
160    bf16,
161    flex32,
162    tf32,
163    f32,
164    f64
165);
166impl_unary_func!(
167    Sin,
168    sin,
169    __expand_sin,
170    Arithmetic::Sin,
171    f16,
172    bf16,
173    flex32,
174    tf32,
175    f32,
176    f64
177);
178impl_unary_func!(
179    Tanh,
180    tanh,
181    __expand_tanh,
182    Arithmetic::Tanh,
183    f16,
184    bf16,
185    flex32,
186    tf32,
187    f32,
188    f64
189);
190impl_unary_func!(
191    Sqrt,
192    sqrt,
193    __expand_sqrt,
194    Arithmetic::Sqrt,
195    f16,
196    bf16,
197    flex32,
198    tf32,
199    f32,
200    f64
201);
202impl_unary_func!(
203    Round,
204    round,
205    __expand_round,
206    Arithmetic::Round,
207    f16,
208    bf16,
209    flex32,
210    tf32,
211    f32,
212    f64
213);
214impl_unary_func!(
215    Floor,
216    floor,
217    __expand_floor,
218    Arithmetic::Floor,
219    f16,
220    bf16,
221    flex32,
222    tf32,
223    f32,
224    f64
225);
226impl_unary_func!(
227    Ceil,
228    ceil,
229    __expand_ceil,
230    Arithmetic::Ceil,
231    f16,
232    bf16,
233    flex32,
234    tf32,
235    f32,
236    f64
237);
238impl_unary_func!(
239    Trunc,
240    trunc,
241    __expand_trunc,
242    Arithmetic::Trunc,
243    f16,
244    bf16,
245    flex32,
246    tf32,
247    f32,
248    f64
249);
250impl_unary_func!(
251    Erf,
252    erf,
253    __expand_erf,
254    Arithmetic::Erf,
255    f16,
256    bf16,
257    flex32,
258    tf32,
259    f32,
260    f64
261);
262impl_unary_func!(
263    Recip,
264    recip,
265    __expand_recip,
266    Arithmetic::Recip,
267    f16,
268    bf16,
269    flex32,
270    tf32,
271    f32,
272    f64
273);
274impl_unary_func_fixed_out_vectorization!(
275    Magnitude,
276    magnitude,
277    __expand_magnitude,
278    Arithmetic::Magnitude,
279    0,
280    f16,
281    bf16,
282    flex32,
283    tf32,
284    f32,
285    f64
286);
287impl_unary_func!(
288    Normalize,
289    normalize,
290    __expand_normalize,
291    Arithmetic::Normalize,
292    f16,
293    bf16,
294    flex32,
295    tf32,
296    f32,
297    f64
298);
299impl_unary_func_fixed_out_ty!(
300    CountOnes,
301    count_ones,
302    __expand_count_ones,
303    u32,
304    Bitwise::CountOnes,
305    u8,
306    i8,
307    u16,
308    i16,
309    u32,
310    i32,
311    u64,
312    i64
313);
314impl_unary_func!(
315    ReverseBits,
316    reverse_bits,
317    __expand_reverse_bits,
318    Bitwise::ReverseBits,
319    u8,
320    i8,
321    u16,
322    i16,
323    u32,
324    i32,
325    u64,
326    i64
327);
328
329impl_unary_func!(
330    BitwiseNot,
331    bitwise_not,
332    __expand_bitwise_not,
333    Bitwise::BitwiseNot,
334    u8,
335    i8,
336    u16,
337    i16,
338    u32,
339    i32,
340    u64,
341    i64
342);
343impl_unary_func_fixed_out_ty!(
344    LeadingZeros,
345    leading_zeros,
346    __expand_leading_zeros,
347    u32,
348    Bitwise::LeadingZeros,
349    u8,
350    i8,
351    u16,
352    i16,
353    u32,
354    i32,
355    u64,
356    i64
357);
358impl_unary_func_fixed_out_ty!(
359    FindFirstSet,
360    find_first_set,
361    __expand_find_first_set,
362    u32,
363    Bitwise::FindFirstSet,
364    u8,
365    i8,
366    u16,
367    i16,
368    u32,
369    i32,
370    u64,
371    i64
372);
373impl_unary_func_fixed_out_ty!(
374    IsNan,
375    is_nan,
376    __expand_is_nan,
377    bool,
378    Comparison::IsNan,
379    f16,
380    bf16,
381    flex32,
382    tf32,
383    f32,
384    f64
385);
386impl_unary_func_fixed_out_ty!(
387    IsInf,
388    is_inf,
389    __expand_is_inf,
390    bool,
391    Comparison::IsInf,
392    f16,
393    bf16,
394    flex32,
395    tf32,
396    f32,
397    f64
398);