// cubecl_core/frontend/element/int.rs

1use crate::frontend::{
2    CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped,
3    Numeric,
4};
5use crate::ir::{Elem, IntKind};
6use crate::Runtime;
7use crate::{
8    compute::{KernelBuilder, KernelLauncher},
9    prelude::{CountOnes, ReverseBits},
10};
11
12use super::{
13    init_expand_element, Init, IntoRuntime, LaunchArgExpand, ScalarArgSettings, __expand_new,
14};
15
16/// Signed or unsigned integer. Used as input in int kernels
17pub trait Int:
18    Numeric
19    + CountOnes
20    + ReverseBits
21    + std::ops::Rem<Output = Self>
22    + core::ops::Add<Output = Self>
23    + core::ops::Sub<Output = Self>
24    + core::ops::Mul<Output = Self>
25    + core::ops::Div<Output = Self>
26    + core::ops::BitOr<Output = Self>
27    + core::ops::BitAnd<Output = Self>
28    + core::ops::BitXor<Output = Self>
29    + core::ops::Shl<Output = Self>
30    + core::ops::Shr<Output = Self>
31    + core::ops::Not<Output = Self>
32    + std::ops::RemAssign
33    + std::ops::AddAssign
34    + std::ops::SubAssign
35    + std::ops::MulAssign
36    + std::ops::DivAssign
37    + std::ops::BitOrAssign
38    + std::ops::BitAndAssign
39    + std::ops::BitXorAssign
40    + std::ops::ShlAssign<u32>
41    + std::ops::ShrAssign<u32>
42    + std::cmp::PartialOrd
43    + std::cmp::PartialEq
44{
45    const BITS: u32;
46
47    fn new(val: i64) -> Self;
48    fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
49        __expand_new(context, val)
50    }
51}
52
/// Implements the frontend traits a primitive integer needs in order to act as an
/// [`Int`] element.
///
/// * `$int` - the primitive integer type (for example `i32`).
/// * `$variant` - the [`IntKind`] variant that describes `$int` in the IR.
///
/// The expansion provides `CubeType`, `CubePrimitive`, `IntoRuntime`, `Numeric`,
/// `ExpandElementBaseInit`, `Int` and `LaunchArgExpand` for `$int`.
macro_rules! impl_int {
    ($int:ident, $variant:ident) => {
        // The expanded form of the primitive is a typed expand element.
        impl CubeType for $int {
            type ExpandType = ExpandElementTyped<Self>;
        }

        // Maps the primitive onto its IR integer element.
        impl CubePrimitive for $int {
            fn as_elem_native() -> Option<Elem> {
                let kind = IntKind::$variant;
                Some(Elem::Int(kind))
            }
        }

        impl IntoRuntime for $int {
            fn __expand_runtime_method(
                self,
                context: &mut CubeContext,
            ) -> ExpandElementTyped<Self> {
                // Convert the host value into its expand type, then initialise it in
                // the given context.
                Init::init(Into::<ExpandElementTyped<Self>>::into(self), context)
            }
        }

        // Numeric bounds come straight from the primitive's own constants.
        impl Numeric for $int {
            fn min_value() -> Self {
                <$int>::MIN
            }

            fn max_value() -> Self {
                <$int>::MAX
            }
        }

        impl ExpandElementBaseInit for $int {
            fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
                init_expand_element(context, elem)
            }
        }

        impl Int for $int {
            const BITS: u32 = <$int>::BITS;

            fn new(val: i64) -> Self {
                // Plain `as` cast: for types narrower than `i64` the value is
                // truncated rather than range-checked.
                val as $int
            }
        }

        impl LaunchArgExpand for $int {
            // No compilation-time configuration is attached to the argument.
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element against the builder's context first, then
                // declare the scalar input for it.
                let elem = <$int>::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }
    };
}
110
// Instantiate the integer trait implementations for each signed primitive width,
// pairing every type with its `IntKind` variant of the same name.
impl_int!(i8, I8);
impl_int!(i16, I16);
impl_int!(i32, I32);
impl_int!(i64, I64);
115
116impl ScalarArgSettings for i8 {
117    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
118        settings.register_i8(*self);
119    }
120}
121
122impl ScalarArgSettings for i16 {
123    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
124        settings.register_i16(*self);
125    }
126}
127
128impl ScalarArgSettings for i32 {
129    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
130        settings.register_i32(*self);
131    }
132}
133
134impl ScalarArgSettings for i64 {
135    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
136        settings.register_i64(*self);
137    }
138}