Skip to main content

cubecl_core/frontend/element/numeric.rs

1use cubecl_ir::{ConstantValue, ManagedVariable};
2use cubecl_runtime::runtime::Runtime;
3use num_traits::{NumCast, One, Zero};
4
5use crate::compute::KernelLauncher;
6use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder};
7use crate::{
8    frontend::{Abs, Remainder},
9    unexpanded,
10};
11use crate::{
12    frontend::{CubePrimitive, CubeType},
13    prelude::InputScalar,
14};
15use crate::{
16    ir::{Scope, Variable},
17    prelude::Scalar,
18};
19
20use super::{LaunchArg, NativeAssign, NativeExpand};
21
22/// Type that encompasses both (unsigned or signed) integers and floats
23/// Used in kernels that should work for both.
24pub trait Numeric:
25    Copy
26    + Abs
27    + Remainder
28    + Scalar
29    + NativeAssign
30    + Into<NativeExpand<Self>>
31    + Into<ConstantValue>
32    + num_traits::NumCast
33    + num_traits::NumAssign
34    + core::cmp::PartialOrd
35    + core::cmp::PartialEq
36    + core::fmt::Debug
37    + bytemuck::Zeroable
38{
39    fn min_value() -> Self;
40    fn max_value() -> Self;
41
42    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
43        let elem = Self::as_type(scope).elem_type();
44        let var = elem.min_variable();
45        let expand = ManagedVariable::Plain(var);
46        expand.into()
47    }
48
49    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
50        let elem = Self::as_type(scope).elem_type();
51        let var = elem.max_variable();
52        let expand = ManagedVariable::Plain(var);
53        expand.into()
54    }
55
56    /// Create a new constant numeric.
57    ///
58    /// Note: since this must work for both integer and float
59    /// only the less expressive of both can be created (int)
60    /// If a number with decimals is needed, use `Float::new`.
61    ///
62    /// This method panics when unexpanded. For creating an element
63    /// with a val, use the new method of the sub type.
64    fn from_int(val: i64) -> Self {
65        <Self as NumCast>::from(val).unwrap()
66    }
67
68    /// Create a new constant numeric. Uses `i128` to be able to represent both signed integers, and
69    /// `u64::MAX`.
70    ///
71    /// Note: since this must work for both integer and float
72    /// only the less expressive of both can be created (int)
73    /// If a number with decimals is needed, use `Float::new`.
74    ///
75    /// This method panics when unexpanded. For creating an element
76    /// with a val, use the new method of the sub type.
77    fn from_int_128(val: i128) -> Self {
78        <Self as NumCast>::from(val).unwrap()
79    }
80
81    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82        unexpanded!()
83    }
84
85    fn __expand_from_int(
86        scope: &mut Scope,
87        val: NativeExpand<i64>,
88    ) -> <Self as CubeType>::ExpandType {
89        let elem = Self::as_type(scope).elem_type();
90        let var: Variable = elem.constant(val.constant().unwrap());
91
92        ManagedVariable::Plain(var).into()
93    }
94}
95
96/// Similar to [`ArgSettings`], however only for scalar types that don't depend on the [Runtime]
97/// trait.
98pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
99    /// Register the information to the [`KernelLauncher`].
100    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
101    fn expand_scalar(builder: &mut KernelBuilder) -> NativeExpand<Self> {
102        builder
103            .scalar(Self::as_type(&builder.scope).storage_type())
104            .into()
105    }
106}
107
108impl<E: ScalarArgType> ScalarArgSettings for E {
109    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
110        launcher.register_scalar(*self);
111    }
112}
113
114impl ScalarArgSettings for usize {
115    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
116        let value = InputScalar::new(*self, launcher.settings.address_type.unsigned_type());
117        InputScalar::register(value, launcher);
118    }
119}
120
121impl ScalarArgSettings for isize {
122    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
123        let value = InputScalar::new(*self, launcher.settings.address_type.signed_type());
124        InputScalar::register(value, launcher);
125    }
126}
127
128impl<T: ScalarArgSettings> LaunchArg for T {
129    type RuntimeArg<R: Runtime> = T;
130    type CompilationArg = ();
131
132    fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {
133        arg.register(launcher);
134    }
135
136    fn expand(_: &(), builder: &mut KernelBuilder) -> NativeExpand<Self> {
137        T::expand_scalar(builder)
138    }
139}
140
141pub trait ZeroExpand: CubeType + Zero {
142    fn __expand_zero(scope: &mut Scope) -> Self::ExpandType;
143}
144
145pub trait OneExpand: CubeType + One {
146    fn __expand_one(scope: &mut Scope) -> Self::ExpandType;
147}
148
149impl<T: CubeType + Zero + IntoRuntime> ZeroExpand for T {
150    fn __expand_zero(scope: &mut Scope) -> Self::ExpandType {
151        T::zero().__expand_runtime_method(scope)
152    }
153}
154
155impl<T: CubeType + One + IntoRuntime> OneExpand for T {
156    fn __expand_one(scope: &mut Scope) -> Self::ExpandType {
157        T::one().__expand_runtime_method(scope)
158    }
159}