cubecl_core/frontend/element/
numeric.rs1use cubecl_ir::{ConstantValue, ManagedVariable};
2use cubecl_runtime::runtime::Runtime;
3use num_traits::{NumCast, One, Zero};
4
5use crate::compute::KernelLauncher;
6use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder};
7use crate::{
8 frontend::{Abs, Remainder, VectorSum},
9 unexpanded,
10};
11use crate::{
12 frontend::{CubePrimitive, CubeType},
13 prelude::InputScalar,
14};
15use crate::{
16 ir::{Scope, Variable},
17 prelude::Scalar,
18};
19
20use super::{LaunchArg, NativeAssign, NativeExpand};
21
22pub trait Numeric:
25 Copy
26 + Abs
27 + VectorSum
28 + Remainder
29 + Scalar
30 + NativeAssign
31 + Into<NativeExpand<Self>>
32 + Into<ConstantValue>
33 + num_traits::NumCast
34 + num_traits::NumAssign
35 + core::cmp::PartialOrd
36 + core::cmp::PartialEq
37 + core::fmt::Debug
38 + bytemuck::Zeroable
39{
40 fn min_value() -> Self;
41 fn max_value() -> Self;
42
43 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
44 let elem = Self::as_type(scope).elem_type();
45 let var = elem.min_variable();
46 let expand = ManagedVariable::Plain(var);
47 expand.into()
48 }
49
50 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
51 let elem = Self::as_type(scope).elem_type();
52 let var = elem.max_variable();
53 let expand = ManagedVariable::Plain(var);
54 expand.into()
55 }
56
57 fn from_int(val: i64) -> Self {
66 <Self as NumCast>::from(val).unwrap()
67 }
68
69 fn from_int_128(val: i128) -> Self {
79 <Self as NumCast>::from(val).unwrap()
80 }
81
82 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
83 unexpanded!()
84 }
85
86 fn __expand_from_int(
87 scope: &mut Scope,
88 val: NativeExpand<i64>,
89 ) -> <Self as CubeType>::ExpandType {
90 let elem = Self::as_type(scope).elem_type();
91 let var: Variable = elem.constant(val.constant().unwrap());
92
93 ManagedVariable::Plain(var).into()
94 }
95}
96
97pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
100 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
102 fn expand_scalar(builder: &mut KernelBuilder) -> NativeExpand<Self> {
103 builder
104 .scalar(Self::as_type(&builder.scope).storage_type())
105 .into()
106 }
107}
108
109impl<E: ScalarArgType> ScalarArgSettings for E {
110 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
111 launcher.register_scalar(*self);
112 }
113}
114
115impl ScalarArgSettings for usize {
116 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
117 let value = InputScalar::new(*self, launcher.settings.address_type.unsigned_type());
118 InputScalar::register(value, launcher);
119 }
120}
121
122impl ScalarArgSettings for isize {
123 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
124 let value = InputScalar::new(*self, launcher.settings.address_type.signed_type());
125 InputScalar::register(value, launcher);
126 }
127}
128
129impl<T: ScalarArgSettings> LaunchArg for T {
130 type RuntimeArg<R: Runtime> = T;
131 type CompilationArg = ();
132
133 fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {
134 arg.register(launcher);
135 }
136
137 fn expand(_: &(), builder: &mut KernelBuilder) -> NativeExpand<Self> {
138 T::expand_scalar(builder)
139 }
140}
141
142pub trait ZeroExpand: CubeType + Zero {
143 fn __expand_zero(scope: &mut Scope) -> Self::ExpandType;
144}
145
146pub trait OneExpand: CubeType + One {
147 fn __expand_one(scope: &mut Scope) -> Self::ExpandType;
148}
149
150impl<T: CubeType + Zero + IntoRuntime> ZeroExpand for T {
151 fn __expand_zero(scope: &mut Scope) -> Self::ExpandType {
152 T::zero().__expand_runtime_method(scope)
153 }
154}
155
156impl<T: CubeType + One + IntoRuntime> OneExpand for T {
157 fn __expand_one(scope: &mut Scope) -> Self::ExpandType {
158 T::one().__expand_runtime_method(scope)
159 }
160}