cubecl_core/frontend/element/
numeric.rs

1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::Runtime;
5use crate::compute::KernelLauncher;
6use crate::frontend::{CubePrimitive, CubeType};
7use crate::ir::{Scope, Variable};
8use crate::prelude::Clamp;
9use crate::{
10    frontend::{Abs, Max, Min, Remainder},
11    unexpanded,
12};
13
14use super::{
15    ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg, LaunchArgExpand,
16};
17
18/// Type that encompasses both (unsigned or signed) integers and floats
19/// Used in kernels that should work for both.
20pub trait Numeric:
21    Copy
22    + Abs
23    + Max
24    + Min
25    + Clamp
26    + Remainder
27    + CubePrimitive
28    + IntoRuntime
29    + LaunchArgExpand<CompilationArg = ()>
30    + ScalarArgSettings
31    + ExpandElementIntoMut
32    + Into<ExpandElementTyped<Self>>
33    + num_traits::NumCast
34    + std::ops::AddAssign
35    + std::ops::SubAssign
36    + std::ops::MulAssign
37    + std::ops::DivAssign
38    + std::ops::Add<Output = Self>
39    + std::ops::Sub<Output = Self>
40    + std::ops::Mul<Output = Self>
41    + std::ops::Div<Output = Self>
42    + std::cmp::PartialOrd
43    + std::cmp::PartialEq
44{
45    fn min_value() -> Self;
46    fn max_value() -> Self;
47
48    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
49        let elem = Self::as_elem(scope);
50        let var = elem.min_variable();
51        let expand = ExpandElement::Plain(var);
52        expand.into()
53    }
54
55    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
56        let elem = Self::as_elem(scope);
57        let var = elem.max_variable();
58        let expand = ExpandElement::Plain(var);
59        expand.into()
60    }
61
62    /// Create a new constant numeric.
63    ///
64    /// Note: since this must work for both integer and float
65    /// only the less expressive of both can be created (int)
66    /// If a number with decimals is needed, use Float::new.
67    ///
68    /// This method panics when unexpanded. For creating an element
69    /// with a val, use the new method of the sub type.
70    fn from_int(val: i64) -> Self {
71        <Self as NumCast>::from(val).unwrap()
72    }
73
74    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
75        unexpanded!()
76    }
77
78    fn __expand_from_int(
79        scope: &mut Scope,
80        val: ExpandElementTyped<i64>,
81    ) -> <Self as CubeType>::ExpandType {
82        let elem = Self::as_elem(scope);
83        let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
84
85        ExpandElement::Plain(var).into()
86    }
87}
88
89/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
90/// trait.
91pub trait ScalarArgSettings: Send + Sync {
92    /// Register the information to the [KernelLauncher].
93    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
94}
95
96#[derive(new)]
97pub struct ScalarArg<T: Numeric> {
98    pub elem: T,
99}
100
101impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
102    fn register(&self, launcher: &mut KernelLauncher<R>) {
103        self.elem.register(launcher);
104    }
105}
106
107impl<T: Numeric> LaunchArg for T {
108    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
109
110    fn compilation_arg<'a, R: Runtime>(
111        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
112    ) -> Self::CompilationArg {
113    }
114}