Skip to main content

cubecl_core/frontend/element/
numeric.rs

1use cubecl_ir::{ConstantValue, ManagedVariable};
2use cubecl_runtime::runtime::Runtime;
3use num_traits::{NumCast, One, Zero};
4
5use crate::compute::KernelLauncher;
6use crate::{IntoRuntime, ScalarArgType, compute::KernelBuilder};
7use crate::{
8    frontend::{Abs, Remainder, VectorSum},
9    unexpanded,
10};
11use crate::{
12    frontend::{CubePrimitive, CubeType},
13    prelude::InputScalar,
14};
15use crate::{
16    ir::{Scope, Variable},
17    prelude::Scalar,
18};
19
20use super::{LaunchArg, NativeAssign, NativeExpand};
21
22/// Type that encompasses both (unsigned or signed) integers and floats
23/// Used in kernels that should work for both.
24pub trait Numeric:
25    Copy
26    + Abs
27    + VectorSum
28    + Remainder
29    + Scalar
30    + NativeAssign
31    + Into<NativeExpand<Self>>
32    + Into<ConstantValue>
33    + num_traits::NumCast
34    + num_traits::NumAssign
35    + core::cmp::PartialOrd
36    + core::cmp::PartialEq
37    + core::fmt::Debug
38    + bytemuck::Zeroable
39{
40    fn min_value() -> Self;
41    fn max_value() -> Self;
42
43    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
44        let elem = Self::as_type(scope).elem_type();
45        let var = elem.min_variable();
46        let expand = ManagedVariable::Plain(var);
47        expand.into()
48    }
49
50    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
51        let elem = Self::as_type(scope).elem_type();
52        let var = elem.max_variable();
53        let expand = ManagedVariable::Plain(var);
54        expand.into()
55    }
56
57    /// Create a new constant numeric.
58    ///
59    /// Note: since this must work for both integer and float
60    /// only the less expressive of both can be created (int)
61    /// If a number with decimals is needed, use `Float::new`.
62    ///
63    /// This method panics when unexpanded. For creating an element
64    /// with a val, use the new method of the sub type.
65    fn from_int(val: i64) -> Self {
66        <Self as NumCast>::from(val).unwrap()
67    }
68
69    /// Create a new constant numeric. Uses `i128` to be able to represent both signed integers, and
70    /// `u64::MAX`.
71    ///
72    /// Note: since this must work for both integer and float
73    /// only the less expressive of both can be created (int)
74    /// If a number with decimals is needed, use `Float::new`.
75    ///
76    /// This method panics when unexpanded. For creating an element
77    /// with a val, use the new method of the sub type.
78    fn from_int_128(val: i128) -> Self {
79        <Self as NumCast>::from(val).unwrap()
80    }
81
82    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
83        unexpanded!()
84    }
85
86    fn __expand_from_int(
87        scope: &mut Scope,
88        val: NativeExpand<i64>,
89    ) -> <Self as CubeType>::ExpandType {
90        let elem = Self::as_type(scope).elem_type();
91        let var: Variable = elem.constant(val.constant().unwrap());
92
93        ManagedVariable::Plain(var).into()
94    }
95}
96
97/// Similar to [`ArgSettings`], however only for scalar types that don't depend on the [Runtime]
98/// trait.
99pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
100    /// Register the information to the [`KernelLauncher`].
101    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
102    fn expand_scalar(builder: &mut KernelBuilder) -> NativeExpand<Self> {
103        builder
104            .scalar(Self::as_type(&builder.scope).storage_type())
105            .into()
106    }
107}
108
109impl<E: ScalarArgType> ScalarArgSettings for E {
110    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
111        launcher.register_scalar(*self);
112    }
113}
114
115impl ScalarArgSettings for usize {
116    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
117        let value = InputScalar::new(*self, launcher.settings.address_type.unsigned_type());
118        InputScalar::register(value, launcher);
119    }
120}
121
122impl ScalarArgSettings for isize {
123    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
124        let value = InputScalar::new(*self, launcher.settings.address_type.signed_type());
125        InputScalar::register(value, launcher);
126    }
127}
128
129impl<T: ScalarArgSettings> LaunchArg for T {
130    type RuntimeArg<R: Runtime> = T;
131    type CompilationArg = ();
132
133    fn register<R: Runtime>(arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) {
134        arg.register(launcher);
135    }
136
137    fn expand(_: &(), builder: &mut KernelBuilder) -> NativeExpand<Self> {
138        T::expand_scalar(builder)
139    }
140}
141
142pub trait ZeroExpand: CubeType + Zero {
143    fn __expand_zero(scope: &mut Scope) -> Self::ExpandType;
144}
145
146pub trait OneExpand: CubeType + One {
147    fn __expand_one(scope: &mut Scope) -> Self::ExpandType;
148}
149
150impl<T: CubeType + Zero + IntoRuntime> ZeroExpand for T {
151    fn __expand_zero(scope: &mut Scope) -> Self::ExpandType {
152        T::zero().__expand_runtime_method(scope)
153    }
154}
155
156impl<T: CubeType + One + IntoRuntime> OneExpand for T {
157    fn __expand_one(scope: &mut Scope) -> Self::ExpandType {
158        T::one().__expand_runtime_method(scope)
159    }
160}