// cubecl_core/frontend/element/uint.rs

1use cubecl_ir::{ExpandElement, Scope, UIntKind};
2
3use crate::Runtime;
4use crate::frontend::{CubePrimitive, CubeType, Numeric};
5use crate::ir::Elem;
6use crate::prelude::{KernelBuilder, KernelLauncher};
7
8use super::{
9    ExpandElementIntoMut, ExpandElementTyped, Int, IntoMut, IntoRuntime, LaunchArgExpand,
10    ScalarArgSettings, into_mut_expand_element, into_runtime_expand_element,
11};
12
/// Implements the CubeCL frontend traits for one native unsigned integer type.
///
/// * `$primitive` — the Rust integer type (e.g. `u32`).
/// * `$kind` — the `UIntKind` variant that describes that type to the IR.
macro_rules! declare_uint {
    ($primitive:ident, $kind:ident) => {
        // During kernel expansion the value is carried as a typed expand element.
        impl CubeType for $primitive {
            type ExpandType = ExpandElementTyped<Self>;
        }

        // The native IR element is the unsigned-int element tagged with `$kind`.
        impl CubePrimitive for $primitive {
            fn as_elem_native() -> Option<Elem> {
                let kind = UIntKind::$kind;
                Some(Elem::UInt(kind))
            }
        }

        // Lifts a compile-time Rust value into a runtime value inside `scope`.
        impl IntoRuntime for $primitive {
            fn __expand_runtime_method(self, scope: &mut Scope) -> ExpandElementTyped<Self> {
                let constant: ExpandElementTyped<Self> = self.into();
                into_runtime_expand_element(scope, constant).into()
            }
        }

        // A plain Rust value is returned unchanged; the scope is not touched.
        impl IntoMut for $primitive {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        // Expanded elements delegate to the shared `into_mut_expand_element` helper.
        impl ExpandElementIntoMut for $primitive {
            fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement {
                into_mut_expand_element(scope, elem)
            }
        }

        // As a launch argument the type needs no compilation-time data; it is
        // declared as a scalar of this type's element on the kernel builder.
        impl LaunchArgExpand for $primitive {
            type CompilationArg = ();

            fn expand(
                _arg: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                let elem = Self::as_elem(&builder.scope);
                builder.scalar(elem).into()
            }
        }

        // Bounds come straight from the primitive's inherent MIN / MAX constants.
        impl Numeric for $primitive {
            fn min_value() -> Self {
                Self::MIN
            }

            fn max_value() -> Self {
                Self::MAX
            }
        }

        impl Int for $primitive {
            const BITS: u32 = $primitive::BITS;

            // NOTE: plain `as` cast — an `i64` outside this type's range is
            // truncated to its low bits (negative values wrap) rather than rejected.
            fn new(val: i64) -> Self {
                val as Self
            }
        }
    };
}
73
74declare_uint!(u8, U8);
75declare_uint!(u16, U16);
76declare_uint!(u32, U32);
77declare_uint!(u64, U64);
78
79impl ScalarArgSettings for u8 {
80    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
81        settings.register_u8(*self);
82    }
83}
84
85impl ScalarArgSettings for u16 {
86    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
87        settings.register_u16(*self);
88    }
89}
90
91impl ScalarArgSettings for u32 {
92    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
93        settings.register_u32(*self);
94    }
95}
96
97impl ScalarArgSettings for u64 {
98    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
99        settings.register_u64(*self);
100    }
101}