// cubecl_core/frontend/element/int.rs

1use cubecl_ir::{ConstantScalarValue, ExpandElement, StorageType};
2
3use crate::Runtime;
4use crate::frontend::{CubeType, Numeric};
5use crate::ir::{ElemType, IntKind, Scope};
6use crate::prelude::BitwiseNot;
7use crate::prelude::{FindFirstSet, LeadingZeros, SaturatingAdd, SaturatingSub};
8use crate::{
9    compute::KernelLauncher,
10    prelude::{CountOnes, ReverseBits},
11};
12
13use super::{
14    __expand_new, CubePrimitive, ExpandElementIntoMut, ExpandElementTyped, IntoMut, IntoRuntime,
15    ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
16};
17
18mod typemap;
19pub use typemap::*;
20
21/// Signed or unsigned integer. Used as input in int kernels
22pub trait Int:
23    Numeric
24    + CountOnes
25    + ReverseBits
26    + BitwiseNot
27    + LeadingZeros
28    + FindFirstSet
29    + SaturatingAdd
30    + SaturatingSub
31    + std::ops::Rem<Output = Self>
32    + core::ops::Add<Output = Self>
33    + core::ops::Sub<Output = Self>
34    + core::ops::Mul<Output = Self>
35    + core::ops::Div<Output = Self>
36    + core::ops::BitOr<Output = Self>
37    + core::ops::BitAnd<Output = Self>
38    + core::ops::BitXor<Output = Self>
39    + core::ops::Shl<Output = Self>
40    + core::ops::Shr<Output = Self>
41    + core::ops::Not<Output = Self>
42    + std::ops::RemAssign
43    + std::ops::AddAssign
44    + std::ops::SubAssign
45    + std::ops::MulAssign
46    + std::ops::DivAssign
47    + std::ops::BitOrAssign
48    + std::ops::BitAndAssign
49    + std::ops::BitXorAssign
50    + std::ops::ShlAssign<u32>
51    + std::ops::ShrAssign<u32>
52    + std::cmp::PartialOrd
53    + std::cmp::PartialEq
54{
55    const BITS: u32;
56
57    fn new(val: i64) -> Self;
58    fn __expand_new(scope: &mut Scope, val: i64) -> <Self as CubeType>::ExpandType {
59        __expand_new(scope, val)
60    }
61}
62
/// Generates every frontend trait implementation a primitive integer needs in
/// order to be used as an [`Int`] inside kernels.
///
/// Arguments:
/// * first — the Rust primitive type name (e.g. `i32`);
/// * second — the name of the matching [`IntKind`] variant (e.g. `I32`).
macro_rules! impl_int {
    ($ty:ident, $variant:ident) => {
        // Expansion uses a typed expand element wrapping this primitive.
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<Self>;
        }

        impl CubePrimitive for $ty {
            fn as_type_native() -> Option<StorageType> {
                let elem = ElemType::Int(IntKind::$variant);
                Some(elem.into())
            }

            fn from_const_value(value: ConstantScalarValue) -> Self {
                // Only integer constants may be turned back into this type.
                match value {
                    ConstantScalarValue::Int(raw, _) => raw as $ty,
                    _ => unreachable!(),
                }
            }
        }

        impl Numeric for $ty {
            fn min_value() -> Self {
                <$ty>::MIN
            }
            fn max_value() -> Self {
                <$ty>::MAX
            }
        }

        impl Int for $ty {
            // Inherent associated consts take precedence, so this reads the
            // primitive's own `BITS` rather than recursing into the trait const.
            const BITS: u32 = <$ty>::BITS;

            fn new(val: i64) -> Self {
                val as $ty
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let constant: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, constant).into()
            }
        }

        // A plain primitive is already a value; nothing to register in the scope.
        impl IntoMut for $ty {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl ExpandElementIntoMut for $ty {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }
    };
}
119
120impl_int!(i8, I8);
121impl_int!(i16, I16);
122impl_int!(i32, I32);
123impl_int!(i64, I64);
124
125impl ScalarArgSettings for i8 {
126    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
127        settings.register_i8(*self);
128    }
129}
130
131impl ScalarArgSettings for i16 {
132    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
133        settings.register_i16(*self);
134    }
135}
136
137impl ScalarArgSettings for i32 {
138    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
139        settings.register_i32(*self);
140    }
141}
142
143impl ScalarArgSettings for i64 {
144    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
145        settings.register_i64(*self);
146    }
147}