1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator, Type};
4use half::{bf16, f16};
5
6use crate::{
7 flex32,
8 ir::{Arithmetic, ExpandElement, Scope},
9 prelude::{CubePrimitive, CubeType, ExpandElementTyped, Reinterpret},
10 tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16 use super::*;
17
18 pub fn expand<T: CubeNot>(
19 scope: &mut Scope,
20 x: ExpandElementTyped<T>,
21 ) -> ExpandElementTyped<T> {
22 if x.expand.ty.is_bool() {
23 unary_expand(scope, x.into(), Operator::Not).into()
24 } else {
25 unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
26 }
27 }
28}
29
30pub mod neg {
31 use super::*;
32
33 pub fn expand<E: CubePrimitive>(
34 scope: &mut Scope,
35 x: ExpandElementTyped<E>,
36 ) -> ExpandElementTyped<E> {
37 unary_expand(scope, x.into(), Arithmetic::Neg).into()
38 }
39}
40
/// Declares a unary operation whose output has the same type as its input.
///
/// Arguments:
/// * `$trait_name`  - name of the user-facing trait (e.g. `Abs`); a companion
///   `<$trait_name>Expand` trait is generated alongside it.
/// * `$method_name` - name of the method (e.g. `abs`); the generated expand
///   entry points are `__expand_<method>` and `__expand_<method>_method`.
/// * `$operator`    - value handed to `unary_expand` when the call is expanded.
/// * `$type,*`      - primitive types that receive an empty `impl` of the trait
///   (no trailing comma is accepted).
macro_rules! impl_unary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            // Expand-side trait: the method-call form of the operation.
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            // User-facing trait. The plain method is only a placeholder
            // (`unexpanded!()`); the associated `__expand_*` function forwards
            // to the method-call form declared above.
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            // Opt every listed primitive into the operation.
            $(impl $trait_name for $type {})*

            // Single blanket implementation of the expand-side trait.
            impl<T> [<$trait_name Expand>] for ExpandElementTyped<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    let expanded = unary_expand(scope, self.into(), $operator);
                    expanded.into()
                }
            }
        }
    };
}
68
69impl Exp for f32 {
70 fn exp(self) -> Self {
71 self.exp()
72 }
73}
74
/// Declares a unary operation whose output keeps the input's type but uses a
/// fixed line size.
///
/// Arguments:
/// * `$trait_name`        - name of the user-facing trait; a companion
///   `<$trait_name>Expand` trait is generated alongside it.
/// * `$method_name`       - name of the method; the expand entry points are
///   `__expand_<method>` and `__expand_<method>_method`.
/// * `$operator`          - value handed to `unary_expand_fixed_output`.
/// * `$out_vectorization` - value passed to `.line(..)` on the input's type to
///   build the output type.
/// * `$type,*`            - primitive types that receive an empty `impl` of the
///   trait (no trailing comma is accepted).
macro_rules! impl_unary_func_fixed_out_vectorization {
    ($trait_name:ident, $method_name:ident, $operator:expr, $out_vectorization: expr, $($type:ty),*) => {
        paste::paste! {
            // Expand-side trait: the method-call form of the operation.
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            // User-facing trait; the plain method is only a placeholder.
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            // Opt every listed primitive into the operation.
            $(impl $trait_name for $type {})*

            impl<T> [<$trait_name Expand>] for ExpandElementTyped<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    let input: ExpandElement = self.into();
                    // Output type: the input's type re-lined with the fixed size.
                    let output_type = input.ty.line($out_vectorization);
                    let expanded = unary_expand_fixed_output(scope, input, output_type, $operator);
                    expanded.into()
                }
            }
        }
    };
}
104
/// Declares a unary operation whose output element type is fixed to `$out_ty`,
/// independent of the input type, while keeping the input's line size.
///
/// Arguments:
/// * `$trait_name`  - name of the user-facing trait; a companion
///   `<$trait_name>Expand` trait is generated alongside it.
/// * `$method_name` - name of the method; the expand entry points are
///   `__expand_<method>` and `__expand_<method>_method`.
/// * `$out_ty`      - primitive type returned by the operation.
/// * `$operator`    - value handed to `unary_expand_fixed_output`.
/// * `$type,*`      - primitive types that receive an empty `impl` of the trait
///   (no trailing comma is accepted).
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_name:ident, $method_name:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            // Expand-side trait: the method-call form of the operation.
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> ExpandElementTyped<$out_ty>;
            }

            // User-facing trait; the plain method is only a placeholder.
            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method_name(self) -> $out_ty {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<$out_ty> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            // Opt every listed primitive into the operation.
            $(impl $trait_name for $type {})*

            impl<T> [<$trait_name Expand>] for ExpandElementTyped<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> ExpandElementTyped<$out_ty> {
                    let input: ExpandElement = self.into();
                    // Output type: `$out_ty` as the element, with the input's line size.
                    let element = <$out_ty as CubePrimitive>::as_type(scope);
                    let output_type = Type::new(element).line(input.ty.line_size());
                    let expanded = unary_expand_fixed_output(scope, input, output_type, $operator);
                    expanded.into()
                }
            }
        }
    };
}
134
/// Declares the cube-side companion of an operator trait that is already in
/// scope (used in this file for `core::ops::Not`).
///
/// Generates `Cube<$trait_name>` (requiring `$trait_name<Output = Self>`), the
/// `<$trait_name>Expand` trait, empty impls for every listed type, and a
/// blanket expand implementation.
///
/// NOTE: the blanket implementation always delegates to `not::expand`, so the
/// macro is only meaningful when invoked with `Not, not`.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($type:ty),*) => {
        paste::paste! {
            // Expand-side trait: the method-call form of the operation.
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            // Cube-side trait; forwards the associated expand function to the
            // method-call form declared above.
            pub trait [<Cube $trait_name>]: $trait_name<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> {
                fn [<__expand_ $method_name>](
                    scope: &mut Scope,
                    x: ExpandElementTyped<Self>,
                ) -> ExpandElementTyped<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            // Opt every listed primitive into the operation.
            $(impl [<Cube $trait_name>] for $type {})*

            impl<T> [<$trait_name Expand>] for ExpandElementTyped<T>
            where
                T: [<Cube $trait_name>] + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    // `self` already has the type `not::expand` expects.
                    not::expand(scope, self)
                }
            }
        }
    };
}
158
// Generates `CubeNot` / `NotExpand` on top of `core::ops::Not` and implements
// `CubeNot` for `bool` and the listed integer types. Expansion goes through
// `not::expand`.
impl_not!(
    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
);
162
// Generates `Abs` / `AbsExpand` with method `abs`, expanded with
// `Arithmetic::Abs`. Implemented for the listed float types as well as the
// signed and unsigned integer types.
impl_unary_func!(
    Abs,
    abs,
    Arithmetic::Abs,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// Generates `Exp` / `ExpExpand` with method `exp`, expanded with
// `Arithmetic::Exp`. `f32` is intentionally absent from this list: `Exp` is
// implemented for `f32` by a hand-written `impl` elsewhere in this file, and
// listing it here as well would produce a conflicting implementation.
impl_unary_func!(
    Exp,
    exp,
    Arithmetic::Exp,
    f16,
    bf16,
    flex32,
    tf32,
    f64
);
// `Log` exposes the method `ln` (not `log`), expanded with `Arithmetic::Log`.
impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
// `Log1p` exposes `log1p`, expanded with `Arithmetic::Log1p`.
impl_unary_func!(
    Log1p,
    log1p,
    Arithmetic::Log1p,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Each invocation below generates a `<Name>` / `<Name>Expand` trait pair via
// `impl_unary_func!`, wiring the method to the `Arithmetic` variant of the same
// name. All of them are implemented for f16, bf16, flex32, tf32, f32 and f64.
impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(
    Tanh,
    tanh,
    Arithmetic::Tanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Sinh,
    sinh,
    Arithmetic::Sinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Cosh,
    cosh,
    Arithmetic::Cosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// The `Arc*` traits expose the short std-style method names (`acos`, `asin`, ...).
impl_unary_func!(
    ArcCos,
    acos,
    Arithmetic::ArcCos,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcSin,
    asin,
    Arithmetic::ArcSin,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcTan,
    atan,
    Arithmetic::ArcTan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcSinh,
    asinh,
    Arithmetic::ArcSinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcCosh,
    acosh,
    Arithmetic::ArcCosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcTanh,
    atanh,
    Arithmetic::ArcTanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Degrees` exposes `to_degrees`, expanded with `Arithmetic::Degrees`.
impl_unary_func!(
    Degrees,
    to_degrees,
    Arithmetic::Degrees,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Radians` exposes `to_radians`, expanded with `Arithmetic::Radians`.
impl_unary_func!(
    Radians,
    to_radians,
    Arithmetic::Radians,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Sqrt` exposes `sqrt`, expanded with `Arithmetic::Sqrt`.
impl_unary_func!(
    Sqrt,
    sqrt,
    Arithmetic::Sqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `InverseSqrt` exposes `inverse_sqrt`, expanded with `Arithmetic::InverseSqrt`.
impl_unary_func!(
    InverseSqrt,
    inverse_sqrt,
    Arithmetic::InverseSqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Round`, `Floor`, `Ceil` and `Trunc`: one trait pair each, with the method
// wired to the `Arithmetic` variant of the same name, for the six float types.
impl_unary_func!(
    Round,
    round,
    Arithmetic::Round,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Floor,
    floor,
    Arithmetic::Floor,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Ceil,
    ceil,
    Arithmetic::Ceil,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Trunc,
    trunc,
    Arithmetic::Trunc,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Erf` exposes `erf`, expanded with `Arithmetic::Erf`.
impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
// `Recip` exposes `recip`, expanded with `Arithmetic::Recip`.
impl_unary_func!(
    Recip,
    recip,
    Arithmetic::Recip,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Magnitude` uses the fixed-output-vectorization macro: the output type is the
// input's type re-lined with `0`, then expanded with `Arithmetic::Magnitude`.
// NOTE(review): a line size of `0` presumably denotes a scalar (non-lined)
// result - confirm against the `line` method of the IR type.
impl_unary_func_fixed_out_vectorization!(
    Magnitude,
    magnitude,
    Arithmetic::Magnitude,
    0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Normalize` keeps the input's type unchanged (plain `impl_unary_func!`),
// expanded with `Arithmetic::Normalize`.
impl_unary_func!(
    Normalize,
    normalize,
    Arithmetic::Normalize,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Bit-level operations on the integer primitives. `CountOnes`, `LeadingZeros`
// and `FindFirstSet` always produce `u32` elements (fixed output type), whereas
// `ReverseBits` returns the same type as its input.
impl_unary_func_fixed_out_ty!(
    CountOnes,
    count_ones,
    u32,
    Bitwise::CountOnes,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func!(
    ReverseBits,
    reverse_bits,
    Bitwise::ReverseBits,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);

impl_unary_func_fixed_out_ty!(
    LeadingZeros,
    leading_zeros,
    u32,
    Bitwise::LeadingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func_fixed_out_ty!(
    FindFirstSet,
    find_first_set,
    u32,
    Bitwise::FindFirstSet,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// Float classification predicates: both produce `bool` elements (fixed output
// type) with the input's line size, expanded with the matching `Comparison`
// variant.
impl_unary_func_fixed_out_ty!(
    IsNan,
    is_nan,
    bool,
    Comparison::IsNan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func_fixed_out_ty!(
    IsInf,
    is_inf,
    bool,
    Comparison::IsInf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
524
525pub trait FloatBits:
526 CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
527{
528 type Bits: CubePrimitive;
529
530 fn __expand_from_bits(
531 scope: &mut Scope,
532 bits: ExpandElementTyped<Self::Bits>,
533 ) -> ExpandElementTyped<Self> {
534 Self::__expand_reinterpret(scope, bits)
535 }
536
537 fn __expand_to_bits(
538 scope: &mut Scope,
539 this: ExpandElementTyped<Self>,
540 ) -> ExpandElementTyped<Self::Bits> {
541 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
542 }
543}
544
/// Expand-side counterpart of [`FloatBits`], providing the method-call form of
/// the to-bits conversion.
pub trait FloatBitsExpand: Sized {
    /// Primitive type produced by [`FloatBitsExpand::__expand_to_bits_method`].
    type Bits: CubePrimitive;

    /// Consumes `self` and returns an expand element typed as `Self::Bits`.
    fn __expand_to_bits_method(self, scope: &mut Scope) -> ExpandElementTyped<Self::Bits>;
}
550
551impl<F: FloatBits> FloatBitsExpand for ExpandElementTyped<F> {
552 type Bits = F::Bits;
553
554 fn __expand_to_bits_method(self, scope: &mut Scope) -> ExpandElementTyped<Self::Bits> {
555 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
556 }
557}
558
/// `e2m1x2` is reinterpreted to and from `u8`.
impl FloatBits for e2m1x2 {
    type Bits = u8;
}
562
/// `e5m2` is reinterpreted to and from `u8`.
impl FloatBits for e5m2 {
    type Bits = u8;
}
566
/// `e4m3` is reinterpreted to and from `u8`.
impl FloatBits for e4m3 {
    type Bits = u8;
}
570
/// `f16` is reinterpreted to and from `u16`.
impl FloatBits for f16 {
    type Bits = u16;
}
574
/// `bf16` is reinterpreted to and from `u16`.
impl FloatBits for bf16 {
    type Bits = u16;
}
578
/// `f32` is reinterpreted to and from `u32`.
impl FloatBits for f32 {
    type Bits = u32;
}
582
/// `f64` is reinterpreted to and from `u64`.
impl FloatBits for f64 {
    type Bits = u64;
}