// cubecl_core/frontend/element/numeric.rs

1use std::marker::PhantomData;
2
3use cubecl_ir::{ConstantValue, ExpandElement};
4use cubecl_runtime::runtime::Runtime;
5use num_traits::NumCast;
6
7use crate::ir::{Scope, Variable};
8use crate::{CubeScalar, compute::KernelBuilder};
9use crate::{compute::KernelLauncher, prelude::CompilationArg};
10use crate::{
11    frontend::{Abs, Remainder},
12    unexpanded,
13};
14use crate::{
15    frontend::{CubePrimitive, CubeType},
16    prelude::InputScalar,
17};
18
19use super::{ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg};
20
21/// Type that encompasses both (unsigned or signed) integers and floats
22/// Used in kernels that should work for both.
23pub trait Numeric:
24    Copy
25    + Abs
26    + Remainder
27    + CubePrimitive
28    + IntoRuntime
29    + ExpandElementIntoMut
30    + Into<ExpandElementTyped<Self>>
31    + Into<ConstantValue>
32    + num_traits::NumCast
33    + num_traits::NumAssign
34    + std::cmp::PartialOrd
35    + std::cmp::PartialEq
36    + std::fmt::Debug
37{
38    fn min_value() -> Self;
39    fn max_value() -> Self;
40
41    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
42        let elem = Self::as_type(scope).elem_type();
43        let var = elem.min_variable();
44        let expand = ExpandElement::Plain(var);
45        expand.into()
46    }
47
48    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
49        let elem = Self::as_type(scope).elem_type();
50        let var = elem.max_variable();
51        let expand = ExpandElement::Plain(var);
52        expand.into()
53    }
54
55    /// Create a new constant numeric.
56    ///
57    /// Note: since this must work for both integer and float
58    /// only the less expressive of both can be created (int)
59    /// If a number with decimals is needed, use Float::new.
60    ///
61    /// This method panics when unexpanded. For creating an element
62    /// with a val, use the new method of the sub type.
63    fn from_int(val: i64) -> Self {
64        <Self as NumCast>::from(val).unwrap()
65    }
66
67    /// Create a new constant numeric. Uses `i128` to be able to represent both signed integers, and
68    /// u64::MAX.
69    ///
70    /// Note: since this must work for both integer and float
71    /// only the less expressive of both can be created (int)
72    /// If a number with decimals is needed, use Float::new.
73    ///
74    /// This method panics when unexpanded. For creating an element
75    /// with a val, use the new method of the sub type.
76    fn from_int_128(val: i128) -> Self {
77        <Self as NumCast>::from(val).unwrap()
78    }
79
80    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
81        unexpanded!()
82    }
83
84    fn __expand_from_int(
85        scope: &mut Scope,
86        val: ExpandElementTyped<i64>,
87    ) -> <Self as CubeType>::ExpandType {
88        let elem = Self::as_type(scope).elem_type();
89        let var: Variable = elem.constant(val.constant().unwrap());
90
91        ExpandElement::Plain(var).into()
92    }
93}
94
95/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
96/// trait.
97pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
98    /// Register the information to the [KernelLauncher].
99    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
100    fn expand_scalar(
101        _: &ScalarCompilationArg<Self>,
102        builder: &mut KernelBuilder,
103    ) -> ExpandElementTyped<Self> {
104        builder.scalar(Self::as_type(&builder.scope)).into()
105    }
106}
107
108impl<E: CubeScalar> ScalarArgSettings for E {
109    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
110        launcher.register_scalar(*self);
111    }
112}
113
114impl ScalarArgSettings for usize {
115    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
116        InputScalar::new(*self, launcher.settings.address_type.unsigned_type()).register(launcher);
117    }
118}
119
120impl ScalarArgSettings for isize {
121    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
122        InputScalar::new(*self, launcher.settings.address_type.signed_type()).register(launcher);
123    }
124}
125
126#[derive(new, Clone, Copy)]
127pub struct ScalarArg<T: ScalarArgSettings> {
128    pub elem: T,
129}
130
131#[derive(new, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
132pub struct ScalarCompilationArg<T: ScalarArgSettings> {
133    _ty: PhantomData<T>,
134}
135
136impl<T: ScalarArgSettings> Eq for ScalarCompilationArg<T> {}
137impl<T: ScalarArgSettings> core::hash::Hash for ScalarCompilationArg<T> {
138    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139        self._ty.hash(state);
140    }
141}
142impl<T: ScalarArgSettings> core::fmt::Debug for ScalarCompilationArg<T> {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        f.write_str("Scalar")
145    }
146}
147
148impl<T: ScalarArgSettings> CompilationArg for ScalarCompilationArg<T> {}
149
150impl<T: ScalarArgSettings, R: Runtime> ArgSettings<R> for ScalarArg<T> {
151    fn register(&self, launcher: &mut KernelLauncher<R>) {
152        self.elem.register(launcher);
153    }
154}
155
156impl<T: ScalarArgSettings> LaunchArg for T {
157    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
158    type CompilationArg = ScalarCompilationArg<T>;
159
160    fn compilation_arg<'a, R: Runtime>(
161        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
162    ) -> Self::CompilationArg {
163        ScalarCompilationArg::new()
164    }
165    fn expand(
166        arg: &ScalarCompilationArg<T>,
167        builder: &mut KernelBuilder,
168    ) -> ExpandElementTyped<Self> {
169        T::expand_scalar(arg, builder)
170    }
171}