1use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
2use cubecl_ir::{Bitwise, Comparison, Operator, Type};
3use half::{bf16, f16};
4
5use crate::{
6 flex32,
7 ir::{Arithmetic, ExpandElement, Scope},
8 prelude::{CubePrimitive, ExpandElementTyped},
9 tf32, unexpanded,
10};
11
12use super::base::{unary_expand, unary_expand_fixed_output};
13
14pub mod not {
15 use super::*;
16
17 pub fn expand(scope: &mut Scope, x: ExpandElementTyped<bool>) -> ExpandElementTyped<bool> {
18 unary_expand(scope, x.into(), Operator::Not).into()
19 }
20}
21
22pub mod neg {
23 use super::*;
24
25 pub fn expand<E: CubePrimitive>(
26 scope: &mut Scope,
27 x: ExpandElementTyped<E>,
28 ) -> ExpandElementTyped<E> {
29 unary_expand(scope, x.into(), Arithmetic::Neg).into()
30 }
31}
32
/// Declares a public unary trait whose expand function emits `$op` with an output of the
/// same type (and line size) as the input, then implements that trait for every listed type.
///
/// Arguments: trait name, placeholder function name, expand function name, the operator
/// constructor passed to `unary_expand`, then a comma-separated list of implementing types.
macro_rules! impl_unary_func {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $($ty:ty),*) => {
        /// Unary operation whose expanded output has the same type as its input.
        pub trait $name: CubePrimitive + Sized {
            // The placeholder body is `unexpanded!()`, so `x` only exists for the signature.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            /// Emits the operation for `x` into `scope` and returns the typed result.
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                let output = unary_expand(scope, input, $op);
                output.into()
            }
        }

        $(impl $name for $ty {})*
    };
}
49
50impl Exp for f32 {
51 fn exp(x: Self) -> Self {
52 x.exp()
53 }
54}
55
/// Like `impl_unary_func!`, but the expand function emits `$op` with an output type whose
/// line size is fixed to `$out_vectorization` instead of inheriting the input's line size.
///
/// Arguments: trait name, placeholder function name, expand function name, operator
/// constructor, output line size, then a comma-separated list of implementing types.
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($name:ident, $func:ident, $func_expand:ident, $op:expr, $out_vectorization:expr, $($ty:ty),*) => {
        /// Unary operation whose expanded output uses a fixed line size.
        pub trait $name: CubePrimitive + Sized {
            // The placeholder body is `unexpanded!()`, so `x` only exists for the signature.
            #[allow(unused_variables)]
            fn $func(x: Self) -> Self {
                unexpanded!()
            }

            /// Emits the operation for `x` into `scope`, re-lining the input type for the output.
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<Self> {
                let input: ExpandElement = x.into();
                let output_type = input.ty.line($out_vectorization);
                unary_expand_fixed_output(scope, input, output_type, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
74
/// Like `impl_unary_func!`, but the result has the fixed primitive type `$out_ty` while
/// keeping the input's line size.
///
/// Arguments: trait name, placeholder function name, expand function name, output type,
/// operator constructor, then a comma-separated list of implementing types.
macro_rules! impl_unary_func_fixed_out_ty {
    ($name:ident, $func:ident, $func_expand:ident, $out_ty:ty, $op:expr, $($ty:ty),*) => {
        /// Unary operation whose expanded output has a fixed primitive type.
        pub trait $name: CubePrimitive + Sized {
            // The placeholder body is `unexpanded!()`, so `x` only exists for the signature.
            #[allow(unused_variables)]
            fn $func(x: Self) -> $out_ty {
                unexpanded!()
            }

            /// Emits the operation for `x` into `scope`; the output element type is the fixed
            /// output type and the output line size matches the input's.
            fn $func_expand(scope: &mut Scope, x: Self::ExpandType) -> ExpandElementTyped<$out_ty> {
                let input: ExpandElement = x.into();
                let output_elem = <$out_ty as CubePrimitive>::as_type(scope);
                let line_size = input.ty.line_size();
                let output_type = Type::new(output_elem).line(line_size);
                unary_expand_fixed_output(scope, input, output_type, $op).into()
            }
        }

        $(impl $name for $ty {})*
    };
}
93
94impl_unary_func!(
95 Abs,
96 abs,
97 __expand_abs,
98 Arithmetic::Abs,
99 e2m1,
100 e4m3,
101 e5m2,
102 ue8m0,
103 f16,
104 bf16,
105 flex32,
106 tf32,
107 f32,
108 f64,
109 i8,
110 i16,
111 i32,
112 i64,
113 u8,
114 u16,
115 u32,
116 u64
117);
118impl_unary_func!(
119 Exp,
120 exp,
121 __expand_exp,
122 Arithmetic::Exp,
123 f16,
124 bf16,
125 flex32,
126 tf32,
127 f64
129);
130impl_unary_func!(
131 Log,
132 log,
133 __expand_log,
134 Arithmetic::Log,
135 f16,
136 bf16,
137 flex32,
138 tf32,
139 f32,
140 f64
141);
142impl_unary_func!(
143 Log1p,
144 log1p,
145 __expand_log1p,
146 Arithmetic::Log1p,
147 f16,
148 bf16,
149 flex32,
150 tf32,
151 f32,
152 f64
153);
154impl_unary_func!(
155 Cos,
156 cos,
157 __expand_cos,
158 Arithmetic::Cos,
159 f16,
160 bf16,
161 flex32,
162 tf32,
163 f32,
164 f64
165);
166impl_unary_func!(
167 Sin,
168 sin,
169 __expand_sin,
170 Arithmetic::Sin,
171 f16,
172 bf16,
173 flex32,
174 tf32,
175 f32,
176 f64
177);
178impl_unary_func!(
179 Tanh,
180 tanh,
181 __expand_tanh,
182 Arithmetic::Tanh,
183 f16,
184 bf16,
185 flex32,
186 tf32,
187 f32,
188 f64
189);
190impl_unary_func!(
191 Sqrt,
192 sqrt,
193 __expand_sqrt,
194 Arithmetic::Sqrt,
195 f16,
196 bf16,
197 flex32,
198 tf32,
199 f32,
200 f64
201);
202impl_unary_func!(
203 Round,
204 round,
205 __expand_round,
206 Arithmetic::Round,
207 f16,
208 bf16,
209 flex32,
210 tf32,
211 f32,
212 f64
213);
214impl_unary_func!(
215 Floor,
216 floor,
217 __expand_floor,
218 Arithmetic::Floor,
219 f16,
220 bf16,
221 flex32,
222 tf32,
223 f32,
224 f64
225);
226impl_unary_func!(
227 Ceil,
228 ceil,
229 __expand_ceil,
230 Arithmetic::Ceil,
231 f16,
232 bf16,
233 flex32,
234 tf32,
235 f32,
236 f64
237);
238impl_unary_func!(
239 Trunc,
240 trunc,
241 __expand_trunc,
242 Arithmetic::Trunc,
243 f16,
244 bf16,
245 flex32,
246 tf32,
247 f32,
248 f64
249);
250impl_unary_func!(
251 Erf,
252 erf,
253 __expand_erf,
254 Arithmetic::Erf,
255 f16,
256 bf16,
257 flex32,
258 tf32,
259 f32,
260 f64
261);
262impl_unary_func!(
263 Recip,
264 recip,
265 __expand_recip,
266 Arithmetic::Recip,
267 f16,
268 bf16,
269 flex32,
270 tf32,
271 f32,
272 f64
273);
274impl_unary_func_fixed_out_vectorization!(
275 Magnitude,
276 magnitude,
277 __expand_magnitude,
278 Arithmetic::Magnitude,
279 0,
280 f16,
281 bf16,
282 flex32,
283 tf32,
284 f32,
285 f64
286);
287impl_unary_func!(
288 Normalize,
289 normalize,
290 __expand_normalize,
291 Arithmetic::Normalize,
292 f16,
293 bf16,
294 flex32,
295 tf32,
296 f32,
297 f64
298);
299impl_unary_func_fixed_out_ty!(
300 CountOnes,
301 count_ones,
302 __expand_count_ones,
303 u32,
304 Bitwise::CountOnes,
305 u8,
306 i8,
307 u16,
308 i16,
309 u32,
310 i32,
311 u64,
312 i64
313);
314impl_unary_func!(
315 ReverseBits,
316 reverse_bits,
317 __expand_reverse_bits,
318 Bitwise::ReverseBits,
319 u8,
320 i8,
321 u16,
322 i16,
323 u32,
324 i32,
325 u64,
326 i64
327);
328
329impl_unary_func!(
330 BitwiseNot,
331 bitwise_not,
332 __expand_bitwise_not,
333 Bitwise::BitwiseNot,
334 u8,
335 i8,
336 u16,
337 i16,
338 u32,
339 i32,
340 u64,
341 i64
342);
343impl_unary_func_fixed_out_ty!(
344 LeadingZeros,
345 leading_zeros,
346 __expand_leading_zeros,
347 u32,
348 Bitwise::LeadingZeros,
349 u8,
350 i8,
351 u16,
352 i16,
353 u32,
354 i32,
355 u64,
356 i64
357);
358impl_unary_func_fixed_out_ty!(
359 FindFirstSet,
360 find_first_set,
361 __expand_find_first_set,
362 u32,
363 Bitwise::FindFirstSet,
364 u8,
365 i8,
366 u16,
367 i16,
368 u32,
369 i32,
370 u64,
371 i64
372);
373impl_unary_func_fixed_out_ty!(
374 IsNan,
375 is_nan,
376 __expand_is_nan,
377 bool,
378 Comparison::IsNan,
379 f16,
380 bf16,
381 flex32,
382 tf32,
383 f32,
384 f64
385);
386impl_unary_func_fixed_out_ty!(
387 IsInf,
388 is_inf,
389 __expand_is_inf,
390 bool,
391 Comparison::IsInf,
392 f16,
393 bf16,
394 flex32,
395 tf32,
396 f32,
397 f64
398);