cubecl_core/frontend/element/
numeric.rs1use cubecl_ir::{ConstantValue, ManagedVariable};
2use cubecl_runtime::runtime::Runtime;
3use num_traits::{NumCast, One, Zero};
4
5use crate::compute::KernelLauncher;
6use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder};
7use crate::{
8 frontend::{Abs, Remainder},
9 unexpanded,
10};
11use crate::{
12 frontend::{CubePrimitive, CubeType},
13 prelude::InputScalar,
14};
15use crate::{
16 ir::{Scope, Variable},
17 prelude::Scalar,
18};
19
20use super::{LaunchArg, NativeAssign, NativeExpand};
21
22pub trait Numeric:
25 Copy
26 + Abs
27 + Remainder
28 + Scalar
29 + NativeAssign
30 + Into<NativeExpand<Self>>
31 + Into<ConstantValue>
32 + num_traits::NumCast
33 + num_traits::NumAssign
34 + core::cmp::PartialOrd
35 + core::cmp::PartialEq
36 + core::fmt::Debug
37 + bytemuck::Zeroable
38{
39 fn min_value() -> Self;
40 fn max_value() -> Self;
41
42 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
43 let elem = Self::as_type(scope).elem_type();
44 let var = elem.min_variable();
45 let expand = ManagedVariable::Plain(var);
46 expand.into()
47 }
48
49 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
50 let elem = Self::as_type(scope).elem_type();
51 let var = elem.max_variable();
52 let expand = ManagedVariable::Plain(var);
53 expand.into()
54 }
55
56 fn from_int(val: i64) -> Self {
65 <Self as NumCast>::from(val).unwrap()
66 }
67
68 fn from_int_128(val: i128) -> Self {
78 <Self as NumCast>::from(val).unwrap()
79 }
80
81 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82 unexpanded!()
83 }
84
85 fn __expand_from_int(
86 scope: &mut Scope,
87 val: NativeExpand<i64>,
88 ) -> <Self as CubeType>::ExpandType {
89 let elem = Self::as_type(scope).elem_type();
90 let var: Variable = elem.constant(val.constant().unwrap());
91
92 ManagedVariable::Plain(var).into()
93 }
94}
95
96pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
99 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
101 fn expand_scalar(builder: &mut KernelBuilder) -> NativeExpand<Self> {
102 builder
103 .scalar(Self::as_type(&builder.scope).storage_type())
104 .into()
105 }
106}
107
108impl<E: ScalarArgType> ScalarArgSettings for E {
109 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
110 launcher.register_scalar(*self);
111 }
112}
113
114impl ScalarArgSettings for usize {
115 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
116 let value = InputScalar::new(*self, launcher.settings.address_type.unsigned_type());
117 InputScalar::register(value, launcher);
118 }
119}
120
121impl ScalarArgSettings for isize {
122 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
123 let value = InputScalar::new(*self, launcher.settings.address_type.signed_type());
124 InputScalar::register(value, launcher);
125 }
126}
127
128impl<T: ScalarArgSettings> LaunchArg for T {
129 type RuntimeArg<R: Runtime> = T;
130 type CompilationArg = ();
131
132 fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {
133 arg.register(launcher);
134 }
135
136 fn expand(_: &(), builder: &mut KernelBuilder) -> NativeExpand<Self> {
137 T::expand_scalar(builder)
138 }
139}
140
141pub trait ZeroExpand: CubeType + Zero {
142 fn __expand_zero(scope: &mut Scope) -> Self::ExpandType;
143}
144
145pub trait OneExpand: CubeType + One {
146 fn __expand_one(scope: &mut Scope) -> Self::ExpandType;
147}
148
149impl<T: CubeType + Zero + IntoRuntime> ZeroExpand for T {
150 fn __expand_zero(scope: &mut Scope) -> Self::ExpandType {
151 T::zero().__expand_runtime_method(scope)
152 }
153}
154
155impl<T: CubeType + One + IntoRuntime> OneExpand for T {
156 fn __expand_one(scope: &mut Scope) -> Self::ExpandType {
157 T::one().__expand_runtime_method(scope)
158 }
159}