cubecl_core/frontend/element/int.rs

1use cubecl_ir::ExpandElement;
2
3use crate::Runtime;
4use crate::frontend::{CubeType, Numeric};
5use crate::ir::{Elem, IntKind, Scope};
6use crate::prelude::BitwiseNot;
7use crate::prelude::{FindFirstSet, LeadingZeros};
8use crate::{
9    compute::{KernelBuilder, KernelLauncher},
10    prelude::{CountOnes, ReverseBits},
11};
12
13use super::{
14    __expand_new, CubePrimitive, ExpandElementIntoMut, ExpandElementTyped, IntoMut, IntoRuntime,
15    LaunchArgExpand, ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
16};
17
18mod typemap;
19pub use typemap::*;
20
21/// Signed or unsigned integer. Used as input in int kernels
22pub trait Int:
23    Numeric
24    + CountOnes
25    + ReverseBits
26    + BitwiseNot
27    + LeadingZeros
28    + FindFirstSet
29    + std::ops::Rem<Output = Self>
30    + core::ops::Add<Output = Self>
31    + core::ops::Sub<Output = Self>
32    + core::ops::Mul<Output = Self>
33    + core::ops::Div<Output = Self>
34    + core::ops::BitOr<Output = Self>
35    + core::ops::BitAnd<Output = Self>
36    + core::ops::BitXor<Output = Self>
37    + core::ops::Shl<Output = Self>
38    + core::ops::Shr<Output = Self>
39    + core::ops::Not<Output = Self>
40    + std::ops::RemAssign
41    + std::ops::AddAssign
42    + std::ops::SubAssign
43    + std::ops::MulAssign
44    + std::ops::DivAssign
45    + std::ops::BitOrAssign
46    + std::ops::BitAndAssign
47    + std::ops::BitXorAssign
48    + std::ops::ShlAssign<u32>
49    + std::ops::ShrAssign<u32>
50    + std::cmp::PartialOrd
51    + std::cmp::PartialEq
52{
53    const BITS: u32;
54
55    fn new(val: i64) -> Self;
56    fn __expand_new(scope: &mut Scope, val: i64) -> <Self as CubeType>::ExpandType {
57        __expand_new(scope, val)
58    }
59}
60
/// Wires a primitive integer type into the CubeCL frontend.
///
/// Generates implementations of `CubeType`, `CubePrimitive`, `Numeric`, `Int`,
/// `IntoRuntime`, `IntoMut`, `ExpandElementIntoMut` and `LaunchArgExpand`.
///
/// * `$prim` — the Rust primitive type (e.g. `i32`).
/// * `$variant` — the matching `IntKind` variant (e.g. `I32`).
macro_rules! impl_int {
    ($prim:ident, $variant:ident) => {
        // --- Type identity -------------------------------------------------------------

        impl CubeType for $prim {
            type ExpandType = ExpandElementTyped<Self>;
        }

        impl CubePrimitive for $prim {
            fn as_elem_native() -> Option<Elem> {
                let kind = IntKind::$variant;
                Some(Elem::Int(kind))
            }
        }

        // --- Numeric behaviour ---------------------------------------------------------

        impl Numeric for $prim {
            fn min_value() -> Self {
                $prim::MIN
            }

            fn max_value() -> Self {
                $prim::MAX
            }
        }

        impl Int for $prim {
            const BITS: u32 = $prim::BITS;

            fn new(val: i64) -> Self {
                // Plain `as` conversion: out-of-range inputs are truncated, not rejected.
                val as Self
            }
        }

        // --- Expansion / mutability helpers --------------------------------------------

        impl IntoRuntime for $prim {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let as_typed: ExpandElementTyped<Self> = self.into();
                let runtime = into_runtime_expand_element(scope, as_typed);
                runtime.into()
            }
        }

        impl IntoMut for $prim {
            // Primitive values are returned unchanged; the scope is not touched.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl ExpandElementIntoMut for $prim {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        // --- Kernel launch arguments ---------------------------------------------------

        impl LaunchArgExpand for $prim {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                let elem = $prim::as_elem(&builder.scope);
                builder.scalar(elem).into()
            }
        }
    };
}
121
122impl_int!(i8, I8);
123impl_int!(i16, I16);
124impl_int!(i32, I32);
125impl_int!(i64, I64);
126
127impl ScalarArgSettings for i8 {
128    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
129        settings.register_i8(*self);
130    }
131}
132
133impl ScalarArgSettings for i16 {
134    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
135        settings.register_i16(*self);
136    }
137}
138
139impl ScalarArgSettings for i32 {
140    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
141        settings.register_i32(*self);
142    }
143}
144
145impl ScalarArgSettings for i64 {
146    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
147        settings.register_i64(*self);
148    }
149}