1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator};
4use half::{bf16, f16};
5
6use crate::{
7 flex32,
8 ir::{Arithmetic, ManagedVariable, Scope},
9 prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret},
10 tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16 use super::*;
17
18 pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19 if x.expand.ty.is_bool() {
20 unary_expand(scope, x.into(), Operator::Not).into()
21 } else {
22 unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23 }
24 }
25}
26
27pub mod neg {
28 use super::*;
29
30 pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31 unary_expand(scope, x.into(), Arithmetic::Neg).into()
32 }
33}
34
// Generates a unary operation whose result has the same type as its operand.
//
// For `$trait_name` / `$method_name` this emits:
// - `${trait_name}Expand`: the expand-side trait exposing `__expand_${method_name}_method`;
// - `$trait_name`: the user-facing trait, with a `$method_name` stub that invokes
//   `unexpanded!()` and an `__expand_$method_name` entry point forwarding to the
//   expand-side method;
// - a blanket `${trait_name}Expand` impl for `NativeExpand<T>` that emits `$operator`
//   through `unary_expand`;
// - an empty `$trait_name` impl for every listed `$type`.
macro_rules! impl_unary_func {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    let output = unary_expand(scope, self.into(), $operator);
                    output.into()
                }
            }

            $(impl $trait_name for $type {})*
        }
    }
}
62
63impl Exp for f32 {
64 fn exp(self) -> Self {
65 self.exp()
66 }
67}
68
// Generates a unary operation whose result is the operand's `Scalar` type.
//
// For `$trait_name` / `$method_name` this emits:
// - `${trait_name}Expand`: the expand-side trait (a `CubePrimitiveExpand`) whose
//   `__expand_${method_name}_method` returns `Self::Scalar`;
// - `$trait_name`: the user-facing trait, with a `$method_name` stub that invokes
//   `unexpanded!()` and an `__expand_$method_name` entry point returning
//   `NativeExpand<Self::Scalar>`;
// - a blanket `${trait_name}Expand` impl for `NativeExpand<T>` that emits `$operator`
//   through `unary_expand_fixed_output`;
// - an empty `$trait_name` impl for every listed `$type`.
macro_rules! impl_unary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                #[allow(unused_variables)]
                fn $method_name(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let input: ManagedVariable = self.into();
                    // The output keeps the operand's type but with vector size 0
                    // (presumably the scalar form — confirm against `with_vector_size`).
                    let output_ty = input.ty.with_vector_size(0);
                    let output = unary_expand_fixed_output(scope, input, output_ty, $operator);
                    output.into()
                }
            }

            $(impl $trait_name for $type {})*
        }
    }
}
101
// Generates a unary operation whose result element type is fixed to `$out_ty`,
// keeping the operand's vector size.
//
// For `$trait_name` / `$method_name` this emits:
// - `${trait_name}Expand`: the expand-side trait (a `CubePrimitiveExpand`) whose
//   `__expand_${method_name}_method` returns `Self::WithScalar<$out_ty>`;
// - `$trait_name`: the user-facing trait, with a `$method_name` stub that invokes
//   `unexpanded!()` and an `__expand_$method_name` entry point;
// - a blanket `${trait_name}Expand` impl for `NativeExpand<T>` that emits `$operator`
//   through `unary_expand_fixed_output`;
// - an empty `$trait_name` impl for every listed `$type`.
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_name:ident, $method_name:ident, $out_ty: ty, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::WithScalar<$out_ty>;
            }

            pub trait $trait_name: CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<WithScalar<$out_ty> = NativeExpand<Self::WithScalar<$out_ty>>>> + Sized {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method_name(self) -> Self::WithScalar<$out_ty> {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::WithScalar<$out_ty>> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: $trait_name + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::WithScalar<$out_ty> {
                    let input: ManagedVariable = self.into();
                    // Output type: `$out_ty` elements, same vector size as the operand.
                    let element_ty = <$out_ty as CubePrimitive>::as_type(scope);
                    let output_ty = element_ty.with_vector_size(input.ty.vector_size());
                    let output = unary_expand_fixed_output(scope, input, output_ty, $operator);
                    output.into()
                }
            }

            $(impl $trait_name for $type {})*
        }
    }
}
132
// Generates the `!` support traits for `$trait_name` / `$method_name`
// (invoked in this file with `Not` / `not`):
// - `${trait_name}Expand`: the expand-side trait exposing `__expand_${method_name}_method`;
// - `Cube$trait_name`: requires `$trait_name<Output = Self>` and provides the
//   `__expand_$method_name` entry point forwarding to the expand-side method;
// - a blanket `${trait_name}Expand` impl for `NativeExpand<T>`;
// - an empty `Cube$trait_name` impl for every listed `$type`.
//
// NOTE: the blanket impl always lowers through `not::expand`, whatever names are passed in.
macro_rules! impl_not {
    ($trait_name:ident, $method_name:ident, $($type:ty),*) => {
        paste::paste! {
            pub trait [<$trait_name Expand>] {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait [<Cube $trait_name>]: $trait_name<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$trait_name Expand>]> {
                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            impl<T> [<$trait_name Expand>] for NativeExpand<T>
            where
                T: [<Cube $trait_name>] + CubePrimitive,
            {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self {
                    not::expand(scope, self)
                }
            }

            $(impl [<Cube $trait_name>] for $type {})*
        }
    }
}
156
157impl_not!(
158 Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
159);
160
161impl_unary_func!(
162 Abs,
163 abs,
164 Arithmetic::Abs,
165 e2m1,
166 e4m3,
167 e5m2,
168 ue8m0,
169 f16,
170 bf16,
171 flex32,
172 tf32,
173 f32,
174 f64,
175 i8,
176 i16,
177 i32,
178 i64,
179 u8,
180 u16,
181 u32,
182 u64,
183 usize,
184 isize
185);
186impl_unary_func!(
187 Exp,
188 exp,
189 Arithmetic::Exp,
190 f16,
191 bf16,
192 flex32,
193 tf32,
194 f64
196);
197impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
198impl_unary_func!(
199 Log1p,
200 log1p,
201 Arithmetic::Log1p,
202 f16,
203 bf16,
204 flex32,
205 tf32,
206 f32,
207 f64
208);
209impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
210impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
211impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
212impl_unary_func!(
213 Tanh,
214 tanh,
215 Arithmetic::Tanh,
216 f16,
217 bf16,
218 flex32,
219 tf32,
220 f32,
221 f64
222);
223impl_unary_func!(
224 Sinh,
225 sinh,
226 Arithmetic::Sinh,
227 f16,
228 bf16,
229 flex32,
230 tf32,
231 f32,
232 f64
233);
234impl_unary_func!(
235 Cosh,
236 cosh,
237 Arithmetic::Cosh,
238 f16,
239 bf16,
240 flex32,
241 tf32,
242 f32,
243 f64
244);
245impl_unary_func!(
246 ArcCos,
247 acos,
248 Arithmetic::ArcCos,
249 f16,
250 bf16,
251 flex32,
252 tf32,
253 f32,
254 f64
255);
256impl_unary_func!(
257 ArcSin,
258 asin,
259 Arithmetic::ArcSin,
260 f16,
261 bf16,
262 flex32,
263 tf32,
264 f32,
265 f64
266);
267impl_unary_func!(
268 ArcTan,
269 atan,
270 Arithmetic::ArcTan,
271 f16,
272 bf16,
273 flex32,
274 tf32,
275 f32,
276 f64
277);
278impl_unary_func!(
279 ArcSinh,
280 asinh,
281 Arithmetic::ArcSinh,
282 f16,
283 bf16,
284 flex32,
285 tf32,
286 f32,
287 f64
288);
289impl_unary_func!(
290 ArcCosh,
291 acosh,
292 Arithmetic::ArcCosh,
293 f16,
294 bf16,
295 flex32,
296 tf32,
297 f32,
298 f64
299);
300impl_unary_func!(
301 ArcTanh,
302 atanh,
303 Arithmetic::ArcTanh,
304 f16,
305 bf16,
306 flex32,
307 tf32,
308 f32,
309 f64
310);
311impl_unary_func!(
312 Degrees,
313 to_degrees,
314 Arithmetic::Degrees,
315 f16,
316 bf16,
317 flex32,
318 tf32,
319 f32,
320 f64
321);
322impl_unary_func!(
323 Radians,
324 to_radians,
325 Arithmetic::Radians,
326 f16,
327 bf16,
328 flex32,
329 tf32,
330 f32,
331 f64
332);
333impl_unary_func!(
334 Sqrt,
335 sqrt,
336 Arithmetic::Sqrt,
337 f16,
338 bf16,
339 flex32,
340 tf32,
341 f32,
342 f64
343);
344impl_unary_func!(
345 InverseSqrt,
346 inverse_sqrt,
347 Arithmetic::InverseSqrt,
348 f16,
349 bf16,
350 flex32,
351 tf32,
352 f32,
353 f64
354);
355impl_unary_func!(
356 Round,
357 round,
358 Arithmetic::Round,
359 f16,
360 bf16,
361 flex32,
362 tf32,
363 f32,
364 f64
365);
366impl_unary_func!(
367 Floor,
368 floor,
369 Arithmetic::Floor,
370 f16,
371 bf16,
372 flex32,
373 tf32,
374 f32,
375 f64
376);
377impl_unary_func!(
378 Ceil,
379 ceil,
380 Arithmetic::Ceil,
381 f16,
382 bf16,
383 flex32,
384 tf32,
385 f32,
386 f64
387);
388impl_unary_func!(
389 Trunc,
390 trunc,
391 Arithmetic::Trunc,
392 f16,
393 bf16,
394 flex32,
395 tf32,
396 f32,
397 f64
398);
399impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
400impl_unary_func!(
401 Recip,
402 recip,
403 Arithmetic::Recip,
404 f16,
405 bf16,
406 flex32,
407 tf32,
408 f32,
409 f64
410);
411impl_unary_func_scalar_out!(
412 Magnitude,
413 magnitude,
414 Arithmetic::Magnitude,
415 f16,
416 bf16,
417 flex32,
418 tf32,
419 f32,
420 f64
421);
422impl_unary_func_scalar_out!(
423 VectorSum,
424 vector_sum,
425 Arithmetic::VectorSum,
426 e2m1,
427 e4m3,
428 e5m2,
429 ue8m0,
430 f16,
431 bf16,
432 flex32,
433 tf32,
434 f32,
435 f64,
436 i8,
437 i16,
438 i32,
439 i64,
440 u8,
441 u16,
442 u32,
443 u64,
444 usize,
445 isize
446);
447impl_unary_func!(
448 Normalize,
449 normalize,
450 Arithmetic::Normalize,
451 f16,
452 bf16,
453 flex32,
454 tf32,
455 f32,
456 f64
457);
458impl_unary_func_fixed_out_ty!(
459 CountOnes,
460 count_ones,
461 u32,
462 Bitwise::CountOnes,
463 u8,
464 i8,
465 u16,
466 i16,
467 u32,
468 i32,
469 u64,
470 i64,
471 usize,
472 isize
473);
474impl_unary_func!(
475 ReverseBits,
476 reverse_bits,
477 Bitwise::ReverseBits,
478 u8,
479 i8,
480 u16,
481 i16,
482 u32,
483 i32,
484 u64,
485 i64,
486 usize,
487 isize
488);
489
490impl_unary_func_fixed_out_ty!(
491 LeadingZeros,
492 leading_zeros,
493 u32,
494 Bitwise::LeadingZeros,
495 u8,
496 i8,
497 u16,
498 i16,
499 u32,
500 i32,
501 u64,
502 i64,
503 usize,
504 isize
505);
506impl_unary_func_fixed_out_ty!(
507 TrailingZeros,
508 trailing_zeros,
509 u32,
510 Bitwise::TrailingZeros,
511 u8,
512 i8,
513 u16,
514 i16,
515 u32,
516 i32,
517 u64,
518 i64,
519 usize,
520 isize
521);
522impl_unary_func_fixed_out_ty!(
523 FindFirstSet,
524 find_first_set,
525 u32,
526 Bitwise::FindFirstSet,
527 u8,
528 i8,
529 u16,
530 i16,
531 u32,
532 i32,
533 u64,
534 i64,
535 usize,
536 isize
537);
538impl_unary_func_fixed_out_ty!(
539 IsNan,
540 is_nan,
541 bool,
542 Comparison::IsNan,
543 f16,
544 bf16,
545 flex32,
546 tf32,
547 f32,
548 f64
549);
550impl_unary_func_fixed_out_ty!(
551 IsInf,
552 is_inf,
553 bool,
554 Comparison::IsInf,
555 f16,
556 bf16,
557 flex32,
558 tf32,
559 f32,
560 f64
561);
562
563pub trait FloatBits:
564 CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
565{
566 type Bits: CubePrimitive;
567
568 fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
569 Self::__expand_reinterpret(scope, bits)
570 }
571
572 fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
573 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
574 }
575}
576
577pub trait FloatBitsExpand: Sized {
578 type Bits: CubePrimitive;
579
580 fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
581}
582
583impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
584 type Bits = F::Bits;
585
586 fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
587 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
588 }
589}
590
591impl FloatBits for e2m1x2 {
592 type Bits = u8;
593}
594
595impl FloatBits for e5m2 {
596 type Bits = u8;
597}
598
599impl FloatBits for e4m3 {
600 type Bits = u8;
601}
602
603impl FloatBits for f16 {
604 type Bits = u16;
605}
606
607impl FloatBits for bf16 {
608 type Bits = u16;
609}
610
611impl FloatBits for f32 {
612 type Bits = u32;
613}
614
615impl FloatBits for f64 {
616 type Bits = u64;
617}