cubecl_core/frontend/element/
int.rs

1use cubecl_ir::ExpandElement;
2
3use crate::Runtime;
4use crate::frontend::{CubeType, Numeric};
5use crate::ir::{Elem, IntKind, Scope};
6use crate::prelude::BitwiseNot;
7use crate::prelude::{FindFirstSet, LeadingZeros};
8use crate::{
9    compute::{KernelBuilder, KernelLauncher},
10    prelude::{CountOnes, ReverseBits},
11};
12
13use super::{
14    __expand_new, CubePrimitive, ExpandElementBaseInit, ExpandElementTyped, Init, IntoRuntime,
15    LaunchArgExpand, ScalarArgSettings, init_expand_element,
16};
17
18mod typemap;
19pub use typemap::*;
20
21/// Signed or unsigned integer. Used as input in int kernels
22pub trait Int:
23    Numeric
24    + CountOnes
25    + ReverseBits
26    + BitwiseNot
27    + LeadingZeros
28    + FindFirstSet
29    + std::ops::Rem<Output = Self>
30    + core::ops::Add<Output = Self>
31    + core::ops::Sub<Output = Self>
32    + core::ops::Mul<Output = Self>
33    + core::ops::Div<Output = Self>
34    + core::ops::BitOr<Output = Self>
35    + core::ops::BitAnd<Output = Self>
36    + core::ops::BitXor<Output = Self>
37    + core::ops::Shl<Output = Self>
38    + core::ops::Shr<Output = Self>
39    + core::ops::Not<Output = Self>
40    + std::ops::RemAssign
41    + std::ops::AddAssign
42    + std::ops::SubAssign
43    + std::ops::MulAssign
44    + std::ops::DivAssign
45    + std::ops::BitOrAssign
46    + std::ops::BitAndAssign
47    + std::ops::BitXorAssign
48    + std::ops::ShlAssign<u32>
49    + std::ops::ShrAssign<u32>
50    + std::cmp::PartialOrd
51    + std::cmp::PartialEq
52{
53    const BITS: u32;
54
55    fn new(val: i64) -> Self;
56    fn __expand_new(scope: &mut Scope, val: i64) -> <Self as CubeType>::ExpandType {
57        __expand_new(scope, val)
58    }
59}
60
61macro_rules! impl_int {
62    ($type:ident, $kind:ident) => {
63        impl CubeType for $type {
64            type ExpandType = ExpandElementTyped<Self>;
65        }
66
67        impl CubePrimitive for $type {
68            fn as_elem_native() -> Option<Elem> {
69                Some(Elem::Int(IntKind::$kind))
70            }
71        }
72
73        impl IntoRuntime for $type {
74            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
75                let expand: ExpandElementTyped<Self> = self.into();
76                Init::init(expand, scope)
77            }
78        }
79
80        impl Numeric for $type {
81            fn min_value() -> Self {
82                $type::MIN
83            }
84            fn max_value() -> Self {
85                $type::MAX
86            }
87        }
88
89        impl ExpandElementBaseInit for $type {
90            fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
91                init_expand_element(scope, elem)
92            }
93        }
94
95        impl Int for $type {
96            const BITS: u32 = $type::BITS;
97
98            fn new(val: i64) -> Self {
99                val as $type
100            }
101        }
102
103        impl LaunchArgExpand for $type {
104            type CompilationArg = ();
105
106            fn expand(
107                _: &Self::CompilationArg,
108                builder: &mut KernelBuilder,
109            ) -> ExpandElementTyped<Self> {
110                builder.scalar($type::as_elem(&builder.context)).into()
111            }
112        }
113    };
114}
115
116impl_int!(i8, I8);
117impl_int!(i16, I16);
118impl_int!(i32, I32);
119impl_int!(i64, I64);
120
121impl ScalarArgSettings for i8 {
122    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
123        settings.register_i8(*self);
124    }
125}
126
127impl ScalarArgSettings for i16 {
128    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
129        settings.register_i16(*self);
130    }
131}
132
133impl ScalarArgSettings for i32 {
134    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
135        settings.register_i32(*self);
136    }
137}
138
139impl ScalarArgSettings for i64 {
140    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
141        settings.register_i64(*self);
142    }
143}