1use core::ops::Not;
2use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
3use cubecl_ir::{Bitwise, Comparison, Operator};
4use half::{bf16, f16};
5
6use crate::{
7 flex32,
8 ir::{Arithmetic, ManagedVariable, Scope},
9 prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret},
10 tf32, unexpanded,
11};
12
13use super::base::{unary_expand, unary_expand_fixed_output};
14
15pub mod not {
16 use super::*;
17
18 pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19 if x.expand.ty.is_bool() {
20 unary_expand(scope, x.into(), Operator::Not).into()
21 } else {
22 unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23 }
24 }
25}
26
27pub mod neg {
28 use super::*;
29
30 pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31 unary_expand(scope, x.into(), Arithmetic::Neg).into()
32 }
33}
34
// Generates a unary operation whose result has the same type as its operand.
//
// For `impl_unary_func!(Trait, method, op, T1, T2, ...)` the expansion contains:
//   * `TraitExpand`, the expand-side trait with `__expand_<method>_method`;
//   * `Trait`, the frontend trait; its `method` body is `unexpanded!()` and its
//     `__expand_<method>` forwards to the expand-side method;
//   * a blanket `TraitExpand` impl for `NativeExpand<T>` that calls
//     `unary_expand` with `op`;
//   * an empty `Trait` impl for each listed type.
macro_rules! impl_unary_func {
    ($op_trait:ident, $method:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            pub trait [<$op_trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait $op_trait:
                CubePrimitive + CubeType<ExpandType: [<$op_trait Expand>]> + Sized
            {
                #[allow(unused_variables)]
                fn $method(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method _method>](scope)
                }
            }

            impl<T: $op_trait + CubePrimitive> [<$op_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self {
                    let operand = self.into();
                    unary_expand(scope, operand, $op).into()
                }
            }

            $(impl $op_trait for $prim {})*
        }
    }
}
62
63impl Exp for f32 {
64 fn exp(self) -> Self {
65 self.exp()
66 }
67}
68
// Generates a unary operation whose expansion yields the operand's `Scalar`
// expand type instead of the operand's own type.
//
// The layout mirrors `impl_unary_func!`, except that the expand-side method
// returns `Self::Scalar` and the lowering goes through
// `unary_expand_fixed_output` with an explicitly chosen output type.
//
// NOTE(review): the frontend `$method` is declared to return `Self` while
// `__expand_<method>` returns `NativeExpand<Self::Scalar>` — confirm that the
// mismatch is intentional.
macro_rules! impl_unary_func_scalar_out {
    ($op_trait:ident, $method:ident, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            pub trait [<$op_trait Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            pub trait $op_trait:
                CubePrimitive
                + CubeType<
                    ExpandType: [<$op_trait Expand>]
                        + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>
                >
                + Sized
            {
                #[allow(unused_variables)]
                fn $method(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method _method>](scope)
                }
            }

            impl<T: $op_trait + CubePrimitive> [<$op_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let operand: ManagedVariable = self.into();
                    // Output type is the operand's type re-sized with vector size 0.
                    // NOTE(review): 0 presumably selects the scalar form — confirm
                    // against `with_vector_size`.
                    let scalar_ty = operand.ty.with_vector_size(0);
                    unary_expand_fixed_output(scope, operand, scalar_ty, $op).into()
                }
            }

            $(impl $op_trait for $prim {})*
        }
    }
}
101
// Generates a unary operation whose result element type is fixed to `$out`,
// independent of the operand's element type.
//
// The generated frontend method and both expand entry points return
// `WithScalar<$out>`. The lowering builds the output type from `$out` and
// gives it the operand's vector size before calling
// `unary_expand_fixed_output`.
macro_rules! impl_unary_func_fixed_out_ty {
    ($op_trait:ident, $method:ident, $out:ty, $op:expr, $($prim:ty),*) => {
        paste::paste! {
            pub trait [<$op_trait Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self::WithScalar<$out>;
            }

            pub trait $op_trait:
                CubePrimitive
                + CubeType<
                    ExpandType: [<$op_trait Expand>]
                        + CubePrimitiveExpand<WithScalar<$out> = NativeExpand<Self::WithScalar<$out>>>
                >
                + Sized
            {
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $method(self) -> Self::WithScalar<$out> {
                    unexpanded!()
                }

                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self::WithScalar<$out>> {
                    x.[<__expand_ $method _method>](scope)
                }
            }

            impl<T: $op_trait + CubePrimitive> [<$op_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self::WithScalar<$out> {
                    let operand: ManagedVariable = self.into();
                    // Element type comes from `$out`; the vector size is copied
                    // from the operand.
                    let element_ty = <$out as CubePrimitive>::as_type(scope);
                    let out_ty = element_ty.with_vector_size(operand.ty.vector_size());
                    unary_expand_fixed_output(scope, operand, out_ty, $op).into()
                }
            }

            $(impl $op_trait for $prim {})*
        }
    }
}
132
// Wraps an existing operator trait (`$std_trait`, required with
// `Output = Self`) for use in expanded code.
//
// For `impl_not!(Not, not, T1, T2, ...)` the expansion contains:
//   * `NotExpand`, the expand-side trait with `__expand_not_method`;
//   * `CubeNot`, whose `__expand_not` forwards to the expand-side method;
//   * a blanket `NotExpand` impl for `NativeExpand<T>` that calls `not::expand`;
//   * an empty `CubeNot` impl for each listed type.
macro_rules! impl_not {
    ($std_trait:ident, $method:ident, $($prim:ty),*) => {
        paste::paste! {
            pub trait [<$std_trait Expand>] {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait [<Cube $std_trait>]:
                $std_trait<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$std_trait Expand>]>
            {
                fn [<__expand_ $method>](
                    scope: &mut Scope,
                    x: NativeExpand<Self>,
                ) -> NativeExpand<Self> {
                    x.[<__expand_ $method _method>](scope)
                }
            }

            impl<T: [<Cube $std_trait>] + CubePrimitive> [<$std_trait Expand>] for NativeExpand<T> {
                fn [<__expand_ $method _method>](self, scope: &mut Scope) -> Self {
                    // `self` already has the parameter type `not::expand` expects.
                    not::expand(scope, self)
                }
            }

            $(impl [<Cube $std_trait>] for $prim {})*
        }
    }
}
156
// Generates `CubeNot` / `NotExpand` on top of `core::ops::Not`, implemented for
// `bool` and the listed integer types.
impl_not!(
    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
);
160
// `Abs::abs`, lowered to `Arithmetic::Abs`. Unlike the other arithmetic traits
// in this file, it is implemented for the minifloat types and for both signed
// and unsigned integers in addition to the regular float types.
impl_unary_func!(
    Abs,
    abs,
    Arithmetic::Abs,
    e2m1,
    e4m3,
    e5m2,
    ue8m0,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64,
    i8,
    i16,
    i32,
    i64,
    u8,
    u16,
    u32,
    u64,
    usize,
    isize
);
// `Exp::exp`, lowered to `Arithmetic::Exp`.
// `f32` is deliberately absent from this list: `Exp` is implemented for `f32`
// by hand elsewhere in this file, and listing it here would generate a second,
// conflicting impl.
impl_unary_func!(
    Exp,
    exp,
    Arithmetic::Exp,
    f16,
    bf16,
    flex32,
    tf32,
    f64
);
// Logarithms. Note that the `Log` trait exposes its method as `ln`.
impl_unary_func!(Log, ln, Arithmetic::Log, f16, bf16, flex32, tf32, f32, f64);
// `Log1p::log1p`, lowered to `Arithmetic::Log1p`.
impl_unary_func!(
    Log1p,
    log1p,
    Arithmetic::Log1p,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Trigonometric and hyperbolic functions, each implemented for the float types
// f16, bf16, flex32, tf32, f32 and f64.
impl_unary_func!(Cos, cos, Arithmetic::Cos, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(Sin, sin, Arithmetic::Sin, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(
    Tanh,
    tanh,
    Arithmetic::Tanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Sinh,
    sinh,
    Arithmetic::Sinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Cosh,
    cosh,
    Arithmetic::Cosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Inverse trigonometric and inverse hyperbolic functions. The trait names use
// the `Arc*` spelling while the methods use the short `a*` spelling
// (`ArcCos::acos`, `ArcSinh::asinh`, ...).
impl_unary_func!(
    ArcCos,
    acos,
    Arithmetic::ArcCos,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcSin,
    asin,
    Arithmetic::ArcSin,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcTan,
    atan,
    Arithmetic::ArcTan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcSinh,
    asinh,
    Arithmetic::ArcSinh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcCosh,
    acosh,
    Arithmetic::ArcCosh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    ArcTanh,
    atanh,
    Arithmetic::ArcTanh,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Angle unit conversions: `Degrees::to_degrees` and `Radians::to_radians`.
impl_unary_func!(
    Degrees,
    to_degrees,
    Arithmetic::Degrees,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Radians,
    to_radians,
    Arithmetic::Radians,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Square root (`Sqrt::sqrt`) and `InverseSqrt::inverse_sqrt`, lowered to
// `Arithmetic::Sqrt` and `Arithmetic::InverseSqrt` respectively.
impl_unary_func!(
    Sqrt,
    sqrt,
    Arithmetic::Sqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    InverseSqrt,
    inverse_sqrt,
    Arithmetic::InverseSqrt,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Rounding family: `Round`, `Floor`, `Ceil` and `Trunc`, each lowered to the
// `Arithmetic` variant of the same name.
impl_unary_func!(
    Round,
    round,
    Arithmetic::Round,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Floor,
    floor,
    Arithmetic::Floor,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Ceil,
    ceil,
    Arithmetic::Ceil,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func!(
    Trunc,
    trunc,
    Arithmetic::Trunc,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Erf::erf` and `Recip::recip`, lowered to `Arithmetic::Erf` and
// `Arithmetic::Recip`.
impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
impl_unary_func!(
    Recip,
    recip,
    Arithmetic::Recip,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Magnitude` is generated with `impl_unary_func_scalar_out!`, so its expansion
// returns the operand's `Scalar` expand type rather than `NativeExpand<Self>`.
impl_unary_func_scalar_out!(
    Magnitude,
    magnitude,
    Arithmetic::Magnitude,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// `Normalize` uses the plain same-type macro: operand and result share a type.
impl_unary_func!(
    Normalize,
    normalize,
    Arithmetic::Normalize,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
// Bit-level operations on the integer types.
//
// The `impl_unary_func_fixed_out_ty!` invocations pass `u32` as the third
// argument, which fixes the result element type to `u32` whatever the operand
// type is. `ReverseBits` uses `impl_unary_func!` and therefore keeps the
// operand type.
impl_unary_func_fixed_out_ty!(
    CountOnes,
    count_ones,
    u32,
    Bitwise::CountOnes,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func!(
    ReverseBits,
    reverse_bits,
    Bitwise::ReverseBits,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);

impl_unary_func_fixed_out_ty!(
    LeadingZeros,
    leading_zeros,
    u32,
    Bitwise::LeadingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func_fixed_out_ty!(
    TrailingZeros,
    trailing_zeros,
    u32,
    Bitwise::TrailingZeros,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
impl_unary_func_fixed_out_ty!(
    FindFirstSet,
    find_first_set,
    u32,
    Bitwise::FindFirstSet,
    u8,
    i8,
    u16,
    i16,
    u32,
    i32,
    u64,
    i64,
    usize,
    isize
);
// Float classification predicates. The fixed output element type is `bool`,
// and the operations are lowered to `Comparison::IsNan` / `Comparison::IsInf`.
impl_unary_func_fixed_out_ty!(
    IsNan,
    is_nan,
    bool,
    Comparison::IsNan,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
impl_unary_func_fixed_out_ty!(
    IsInf,
    is_inf,
    bool,
    Comparison::IsInf,
    f16,
    bf16,
    flex32,
    tf32,
    f32,
    f64
);
537
538pub trait FloatBits:
539 CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
540{
541 type Bits: CubePrimitive;
542
543 fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
544 Self::__expand_reinterpret(scope, bits)
545 }
546
547 fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
548 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
549 }
550}
551
/// Expand-side counterpart of [`FloatBits`]: offers the to-bits conversion as a
/// method on the expanded value.
pub trait FloatBitsExpand: Sized {
    /// Primitive type the value is converted to.
    type Bits: CubePrimitive;

    /// Consumes the expanded value and returns it as an expanded `Self::Bits`.
    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
}
557
558impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
559 type Bits = F::Bits;
560
561 fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
562 <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
563 }
564}
565
// `FloatBits` implementations. Each one only selects the unsigned integer type
// used as `Bits`; the conversion functions come from the trait's defaults.

// `u8` carriers.
impl FloatBits for e2m1x2 {
    type Bits = u8;
}

impl FloatBits for e5m2 {
    type Bits = u8;
}

impl FloatBits for e4m3 {
    type Bits = u8;
}

// `u16` carriers.
impl FloatBits for f16 {
    type Bits = u16;
}

impl FloatBits for bf16 {
    type Bits = u16;
}

// `u32` carrier.
impl FloatBits for f32 {
    type Bits = u32;
}

// `u64` carrier.
impl FloatBits for f64 {
    type Bits = u64;
}