Skip to main content

cubecl_core/frontend/element/
numeric.rs

1use core::marker::PhantomData;
2
3use cubecl_ir::{ConstantValue, ExpandElement};
4use cubecl_runtime::runtime::Runtime;
5use num_traits::NumCast;
6
7use crate::ir::{Scope, Variable};
8use crate::{CubeScalar, compute::KernelBuilder};
9use crate::{compute::KernelLauncher, prelude::CompilationArg};
10use crate::{
11    frontend::{Abs, Remainder},
12    unexpanded,
13};
14use crate::{
15    frontend::{CubePrimitive, CubeType},
16    prelude::InputScalar,
17};
18
19use super::{ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg};
20
21/// Type that encompasses both (unsigned or signed) integers and floats
22/// Used in kernels that should work for both.
23pub trait Numeric:
24    Copy
25    + Abs
26    + Remainder
27    + CubePrimitive
28    + IntoRuntime
29    + ExpandElementIntoMut
30    + Into<ExpandElementTyped<Self>>
31    + Into<ConstantValue>
32    + num_traits::NumCast
33    + num_traits::NumAssign
34    + core::cmp::PartialOrd
35    + core::cmp::PartialEq
36    + core::fmt::Debug
37    + Default
38{
39    fn min_value() -> Self;
40    fn max_value() -> Self;
41
42    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
43        let elem = Self::as_type(scope).elem_type();
44        let var = elem.min_variable();
45        let expand = ExpandElement::Plain(var);
46        expand.into()
47    }
48
49    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
50        let elem = Self::as_type(scope).elem_type();
51        let var = elem.max_variable();
52        let expand = ExpandElement::Plain(var);
53        expand.into()
54    }
55
56    /// Create a new constant numeric.
57    ///
58    /// Note: since this must work for both integer and float
59    /// only the less expressive of both can be created (int)
60    /// If a number with decimals is needed, use `Float::new`.
61    ///
62    /// This method panics when unexpanded. For creating an element
63    /// with a val, use the new method of the sub type.
64    fn from_int(val: i64) -> Self {
65        <Self as NumCast>::from(val).unwrap()
66    }
67
68    /// Create a new constant numeric. Uses `i128` to be able to represent both signed integers, and
69    /// `u64::MAX`.
70    ///
71    /// Note: since this must work for both integer and float
72    /// only the less expressive of both can be created (int)
73    /// If a number with decimals is needed, use `Float::new`.
74    ///
75    /// This method panics when unexpanded. For creating an element
76    /// with a val, use the new method of the sub type.
77    fn from_int_128(val: i128) -> Self {
78        <Self as NumCast>::from(val).unwrap()
79    }
80
81    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82        unexpanded!()
83    }
84
85    fn __expand_from_int(
86        scope: &mut Scope,
87        val: ExpandElementTyped<i64>,
88    ) -> <Self as CubeType>::ExpandType {
89        let elem = Self::as_type(scope).elem_type();
90        let var: Variable = elem.constant(val.constant().unwrap());
91
92        ExpandElement::Plain(var).into()
93    }
94}
95
96/// Similar to [`ArgSettings`], however only for scalar types that don't depend on the [Runtime]
97/// trait.
98pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
99    /// Register the information to the [`KernelLauncher`].
100    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
101    fn expand_scalar(
102        _: &ScalarCompilationArg<Self>,
103        builder: &mut KernelBuilder,
104    ) -> ExpandElementTyped<Self> {
105        builder.scalar(Self::as_type(&builder.scope)).into()
106    }
107}
108
109impl<E: CubeScalar> ScalarArgSettings for E {
110    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
111        launcher.register_scalar(*self);
112    }
113}
114
115impl ScalarArgSettings for usize {
116    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
117        InputScalar::new(*self, launcher.settings.address_type.unsigned_type()).register(launcher);
118    }
119}
120
121impl ScalarArgSettings for isize {
122    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
123        InputScalar::new(*self, launcher.settings.address_type.signed_type()).register(launcher);
124    }
125}
126
127#[derive(new, Clone, Copy)]
128pub struct ScalarArg<T: ScalarArgSettings> {
129    pub elem: T,
130}
131
132#[derive(new, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
133pub struct ScalarCompilationArg<T: ScalarArgSettings> {
134    _ty: PhantomData<T>,
135}
136
137impl<T: ScalarArgSettings> Eq for ScalarCompilationArg<T> {}
138impl<T: ScalarArgSettings> core::hash::Hash for ScalarCompilationArg<T> {
139    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
140        self._ty.hash(state);
141    }
142}
143impl<T: ScalarArgSettings> core::fmt::Debug for ScalarCompilationArg<T> {
144    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145        f.write_str("Scalar")
146    }
147}
148
149impl<T: ScalarArgSettings> CompilationArg for ScalarCompilationArg<T> {}
150
151impl<T: ScalarArgSettings, R: Runtime> ArgSettings<R> for ScalarArg<T> {
152    fn register(&self, launcher: &mut KernelLauncher<R>) {
153        self.elem.register(launcher);
154    }
155}
156
157impl<T: ScalarArgSettings> LaunchArg for T {
158    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
159    type CompilationArg = ScalarCompilationArg<T>;
160
161    fn compilation_arg<'a, R: Runtime>(
162        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
163    ) -> Self::CompilationArg {
164        ScalarCompilationArg::new()
165    }
166    fn expand(
167        arg: &ScalarCompilationArg<T>,
168        builder: &mut KernelBuilder,
169    ) -> ExpandElementTyped<Self> {
170        T::expand_scalar(arg, builder)
171    }
172}