cubecl_core/frontend/element/
numeric.rs1use core::marker::PhantomData;
2
3use cubecl_ir::{ConstantValue, ExpandElement};
4use cubecl_runtime::runtime::Runtime;
5use num_traits::NumCast;
6
7use crate::ir::{Scope, Variable};
8use crate::{CubeScalar, compute::KernelBuilder};
9use crate::{compute::KernelLauncher, prelude::CompilationArg};
10use crate::{
11 frontend::{Abs, Remainder},
12 unexpanded,
13};
14use crate::{
15 frontend::{CubePrimitive, CubeType},
16 prelude::InputScalar,
17};
18
19use super::{ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg};
20
21pub trait Numeric:
24 Copy
25 + Abs
26 + Remainder
27 + CubePrimitive
28 + IntoRuntime
29 + ExpandElementIntoMut
30 + Into<ExpandElementTyped<Self>>
31 + Into<ConstantValue>
32 + num_traits::NumCast
33 + num_traits::NumAssign
34 + core::cmp::PartialOrd
35 + core::cmp::PartialEq
36 + core::fmt::Debug
37 + Default
38{
39 fn min_value() -> Self;
40 fn max_value() -> Self;
41
42 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
43 let elem = Self::as_type(scope).elem_type();
44 let var = elem.min_variable();
45 let expand = ExpandElement::Plain(var);
46 expand.into()
47 }
48
49 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
50 let elem = Self::as_type(scope).elem_type();
51 let var = elem.max_variable();
52 let expand = ExpandElement::Plain(var);
53 expand.into()
54 }
55
56 fn from_int(val: i64) -> Self {
65 <Self as NumCast>::from(val).unwrap()
66 }
67
68 fn from_int_128(val: i128) -> Self {
78 <Self as NumCast>::from(val).unwrap()
79 }
80
81 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82 unexpanded!()
83 }
84
85 fn __expand_from_int(
86 scope: &mut Scope,
87 val: ExpandElementTyped<i64>,
88 ) -> <Self as CubeType>::ExpandType {
89 let elem = Self::as_type(scope).elem_type();
90 let var: Variable = elem.constant(val.constant().unwrap());
91
92 ExpandElement::Plain(var).into()
93 }
94}
95
96pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
99 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
101 fn expand_scalar(
102 _: &ScalarCompilationArg<Self>,
103 builder: &mut KernelBuilder,
104 ) -> ExpandElementTyped<Self> {
105 builder.scalar(Self::as_type(&builder.scope)).into()
106 }
107}
108
109impl<E: CubeScalar> ScalarArgSettings for E {
110 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
111 launcher.register_scalar(*self);
112 }
113}
114
115impl ScalarArgSettings for usize {
116 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
117 InputScalar::new(*self, launcher.settings.address_type.unsigned_type()).register(launcher);
118 }
119}
120
121impl ScalarArgSettings for isize {
122 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
123 InputScalar::new(*self, launcher.settings.address_type.signed_type()).register(launcher);
124 }
125}
126
127#[derive(new, Clone, Copy)]
128pub struct ScalarArg<T: ScalarArgSettings> {
129 pub elem: T,
130}
131
132#[derive(new, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
133pub struct ScalarCompilationArg<T: ScalarArgSettings> {
134 _ty: PhantomData<T>,
135}
136
137impl<T: ScalarArgSettings> Eq for ScalarCompilationArg<T> {}
138impl<T: ScalarArgSettings> core::hash::Hash for ScalarCompilationArg<T> {
139 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
140 self._ty.hash(state);
141 }
142}
143impl<T: ScalarArgSettings> core::fmt::Debug for ScalarCompilationArg<T> {
144 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145 f.write_str("Scalar")
146 }
147}
148
149impl<T: ScalarArgSettings> CompilationArg for ScalarCompilationArg<T> {}
150
151impl<T: ScalarArgSettings, R: Runtime> ArgSettings<R> for ScalarArg<T> {
152 fn register(&self, launcher: &mut KernelLauncher<R>) {
153 self.elem.register(launcher);
154 }
155}
156
157impl<T: ScalarArgSettings> LaunchArg for T {
158 type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
159 type CompilationArg = ScalarCompilationArg<T>;
160
161 fn compilation_arg<'a, R: Runtime>(
162 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
163 ) -> Self::CompilationArg {
164 ScalarCompilationArg::new()
165 }
166 fn expand(
167 arg: &ScalarCompilationArg<T>,
168 builder: &mut KernelBuilder,
169 ) -> ExpandElementTyped<Self> {
170 T::expand_scalar(arg, builder)
171 }
172}