cubecl_core/frontend/element/uint.rs

1use cubecl_ir::{ConstantScalarValue, ExpandElement, Scope, StorageType, UIntKind};
2
3use crate::Runtime;
4use crate::frontend::{CubePrimitive, CubeType, Numeric};
5use crate::ir::ElemType;
6use crate::prelude::KernelLauncher;
7
8use super::{
9    ExpandElementIntoMut, ExpandElementTyped, Int, IntoMut, IntoRuntime, ScalarArgSettings,
10    into_mut_expand_element, into_runtime_expand_element,
11};
12
/// Implements the CubeCL frontend traits for a native unsigned integer type.
///
/// * `$primitive` - the Rust unsigned integer type to implement the traits for.
/// * `$kind` - the `UIntKind` variant describing that type in the IR.
///
/// The expansion implements `CubeType`, `CubePrimitive`, `IntoRuntime`, `IntoMut`,
/// `ExpandElementIntoMut`, `Numeric` and `Int` for `$primitive`.
macro_rules! declare_uint {
    ($primitive:ident, $kind:ident) => {
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<Self>;
        }

        impl CubePrimitive for $primitive {
            fn as_type_native() -> Option<StorageType> {
                Some(ElemType::UInt(UIntKind::$kind).into())
            }

            /// Extracts the value from a `ConstantScalarValue::UInt`.
            ///
            /// # Panics
            ///
            /// Panics if `value` is any other `ConstantScalarValue` variant.
            fn from_const_value(value: ConstantScalarValue) -> Self {
                let ConstantScalarValue::UInt(value, _) = value else {
                    unreachable!(
                        "`{}::from_const_value` expects a `ConstantScalarValue::UInt`",
                        stringify!($primitive)
                    )
                };
                // The stored `UIntKind` is ignored and `as` keeps only the low bits, so a
                // value wider than `$primitive` is truncated rather than rejected.
                value as $primitive
            }
        }

        impl IntoRuntime for $primitive {
            /// Wraps the constant in a typed expand element and passes it through
            /// `into_runtime_expand_element` for `scope`.
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let elem: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, elem).into()
            }
        }

        impl IntoMut for $primitive {
            /// A plain native value is returned unchanged; `scope` is not touched.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl ExpandElementIntoMut for $primitive {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        impl Numeric for $primitive {
            /// The smallest value of the native type (`$primitive::MIN`).
            fn min_value() -> Self {
                $primitive::MIN
            }
            /// The largest value of the native type (`$primitive::MAX`).
            fn max_value() -> Self {
                $primitive::MAX
            }
        }

        impl Int for $primitive {
            const BITS: u32 = $primitive::BITS;

            /// Converts an `i64` with an `as` cast: negative or out-of-range inputs wrap
            /// (only the low bits are kept) instead of panicking.
            fn new(val: i64) -> Self {
                val as $primitive
            }
        }
    };
}
69
70declare_uint!(u8, U8);
71declare_uint!(u16, U16);
72declare_uint!(u32, U32);
73declare_uint!(u64, U64);
74
75impl ScalarArgSettings for u8 {
76    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
77        settings.register_u8(*self);
78    }
79}
80
81impl ScalarArgSettings for u16 {
82    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
83        settings.register_u16(*self);
84    }
85}
86
87impl ScalarArgSettings for u32 {
88    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
89        settings.register_u32(*self);
90    }
91}
92
93impl ScalarArgSettings for u64 {
94    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
95        settings.register_u64(*self);
96    }
97}