cubecl_core/frontend/element/
int.rs1use crate::frontend::{
2 CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped,
3 Numeric,
4};
5use crate::ir::{Elem, IntKind};
6use crate::Runtime;
7use crate::{
8 compute::{KernelBuilder, KernelLauncher},
9 prelude::{CountOnes, ReverseBits},
10};
11
12use super::{
13 init_expand_element, Init, IntoRuntime, LaunchArgExpand, ScalarArgSettings, __expand_new,
14};
15
16pub trait Int:
18 Numeric
19 + CountOnes
20 + ReverseBits
21 + std::ops::Rem<Output = Self>
22 + core::ops::Add<Output = Self>
23 + core::ops::Sub<Output = Self>
24 + core::ops::Mul<Output = Self>
25 + core::ops::Div<Output = Self>
26 + core::ops::BitOr<Output = Self>
27 + core::ops::BitAnd<Output = Self>
28 + core::ops::BitXor<Output = Self>
29 + core::ops::Shl<Output = Self>
30 + core::ops::Shr<Output = Self>
31 + core::ops::Not<Output = Self>
32 + std::ops::RemAssign
33 + std::ops::AddAssign
34 + std::ops::SubAssign
35 + std::ops::MulAssign
36 + std::ops::DivAssign
37 + std::ops::BitOrAssign
38 + std::ops::BitAndAssign
39 + std::ops::BitXorAssign
40 + std::ops::ShlAssign<u32>
41 + std::ops::ShrAssign<u32>
42 + std::cmp::PartialOrd
43 + std::cmp::PartialEq
44{
45 const BITS: u32;
46
47 fn new(val: i64) -> Self;
48 fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
49 __expand_new(context, val)
50 }
51}
52
/// Wires a primitive integer type into the CubeCL frontend.
///
/// `$prim` is the Rust primitive (e.g. `i32`) and `$variant` is the matching
/// `IntKind` variant used as its native IR element type.
macro_rules! impl_int {
    ($prim:ident, $variant:ident) => {
        impl CubeType for $prim {
            type ExpandType = ExpandElementTyped<Self>;
        }

        impl CubePrimitive for $prim {
            fn as_elem_native() -> Option<Elem> {
                let native = Elem::Int(IntKind::$variant);
                Some(native)
            }
        }

        impl Numeric for $prim {
            fn min_value() -> Self {
                Self::MIN
            }

            fn max_value() -> Self {
                Self::MAX
            }
        }

        impl Int for $prim {
            // Bit width taken straight from the primitive.
            const BITS: u32 = $prim::BITS;

            // `as` truncates out-of-range inputs instead of failing.
            fn new(val: i64) -> Self {
                val as $prim
            }
        }

        impl IntoRuntime for $prim {
            fn __expand_runtime_method(
                self,
                context: &mut CubeContext,
            ) -> ExpandElementTyped<Self> {
                let typed: ExpandElementTyped<Self> = self.into();
                Init::init(typed, context)
            }
        }

        impl ExpandElementBaseInit for $prim {
            fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
                init_expand_element(context, elem)
            }
        }

        impl LaunchArgExpand for $prim {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element type first, then register the scalar input.
                let elem = $prim::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }
    };
}
110
111impl_int!(i8, I8);
112impl_int!(i16, I16);
113impl_int!(i32, I32);
114impl_int!(i64, I64);
115
116impl ScalarArgSettings for i8 {
117 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
118 settings.register_i8(*self);
119 }
120}
121
122impl ScalarArgSettings for i16 {
123 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
124 settings.register_i16(*self);
125 }
126}
127
128impl ScalarArgSettings for i32 {
129 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
130 settings.register_i32(*self);
131 }
132}
133
134impl ScalarArgSettings for i64 {
135 fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
136 settings.register_i64(*self);
137 }
138}