// cubecl_core/frontend/element/uint.rs

1use crate::ir::Elem;
2use crate::prelude::{KernelBuilder, KernelLauncher};
3use crate::Runtime;
4use crate::{
5    frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric},
6    ir::UIntKind,
7};
8
9use super::{
10    init_expand_element, ExpandElementBaseInit, ExpandElementTyped, Init, Int, IntoRuntime,
11    LaunchArgExpand, ScalarArgSettings,
12};
13
/// Implements the CubeCL frontend traits for one native unsigned integer type.
///
/// * `$ty` – the Rust primitive being wired up (e.g. `u32`).
/// * `$variant` – the `UIntKind` variant reported as its native IR element.
///
/// The generated impls cover `CubeType`, `CubePrimitive`, `ExpandElementBaseInit`,
/// `IntoRuntime`, `LaunchArgExpand`, `Numeric` and `Int`.
macro_rules! declare_uint {
    ($ty:ident, $variant:ident) => {
        // Expanded (kernel-side) values of this type are typed expand elements.
        impl CubeType for $ty {
            type ExpandType = ExpandElementTyped<Self>;
        }

        // Native IR element: `Elem::UInt` tagged with the matching width kind.
        impl CubePrimitive for $ty {
            fn as_elem_native() -> Option<Elem> {
                let kind = UIntKind::$variant;
                Some(Elem::UInt(kind))
            }
        }

        // Initialisation is delegated unchanged to the shared helper.
        impl ExpandElementBaseInit for $ty {
            fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement {
                init_expand_element(context, elem)
            }
        }

        impl IntoRuntime for $ty {
            fn __expand_runtime_method(
                self,
                context: &mut CubeContext,
            ) -> ExpandElementTyped<Self> {
                // Convert the host value into a typed expand element, then run
                // `Init::init` on it in the current context.
                Init::init(Into::<ExpandElementTyped<Self>>::into(self), context)
            }
        }

        impl LaunchArgExpand for $ty {
            type CompilationArg = ();

            fn expand(
                _: &Self::CompilationArg,
                builder: &mut KernelBuilder,
            ) -> ExpandElementTyped<Self> {
                // Resolve the element kind against the builder's context first,
                // then declare a scalar kernel argument of that kind.
                let elem = $ty::as_elem(&builder.context);
                builder.scalar(elem).into()
            }
        }

        // Numeric bounds are the primitive's own MIN/MAX constants.
        impl Numeric for $ty {
            fn min_value() -> Self {
                $ty::MIN
            }

            fn max_value() -> Self {
                $ty::MAX
            }
        }

        impl Int for $ty {
            const BITS: u32 = $ty::BITS;

            fn new(value: i64) -> Self {
                // NOTE(review): an `as` cast wraps/truncates out-of-range input
                // (e.g. -1 becomes the type's MAX) instead of rejecting it —
                // confirm callers rely on this.
                value as $ty
            }
        }
    };
}
71
72declare_uint!(u8, U8);
73declare_uint!(u16, U16);
74declare_uint!(u32, U32);
75declare_uint!(u64, U64);
76
77impl ScalarArgSettings for u8 {
78    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
79        settings.register_u8(*self);
80    }
81}
82
83impl ScalarArgSettings for u16 {
84    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
85        settings.register_u16(*self);
86    }
87}
88
89impl ScalarArgSettings for u32 {
90    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
91        settings.register_u32(*self);
92    }
93}
94
95impl ScalarArgSettings for u64 {
96    fn register<R: Runtime>(&self, settings: &mut KernelLauncher<R>) {
97        settings.register_u64(*self);
98    }
99}