cubecl_core/frontend/element/
numeric.rs

1use std::num::NonZero;
2
3use cubecl_ir::ExpandElement;
4use num_traits::NumCast;
5
6use crate::Runtime;
7use crate::compute::KernelLauncher;
8use crate::ir::{Item, Scope, Variable};
9use crate::prelude::Clamp;
10use crate::{
11    frontend::{Abs, Max, Min, Remainder, index_assign},
12    unexpanded,
13};
14use crate::{
15    frontend::{CubePrimitive, CubeType},
16    prelude::CubeIndexMut,
17};
18
19use super::{
20    ArgSettings, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime, LaunchArg, LaunchArgExpand,
21};
22
23/// Type that encompasses both (unsigned or signed) integers and floats
24/// Used in kernels that should work for both.
25pub trait Numeric:
26    Copy
27    + Abs
28    + Max
29    + Min
30    + Clamp
31    + Remainder
32    + CubePrimitive
33    + IntoRuntime
34    + LaunchArgExpand<CompilationArg = ()>
35    + ScalarArgSettings
36    + ExpandElementBaseInit
37    + Into<ExpandElementTyped<Self>>
38    + CubeIndexMut<u32, Output = Self>
39    + CubeIndexMut<ExpandElementTyped<u32>, Output = Self>
40    + num_traits::NumCast
41    + std::ops::AddAssign
42    + std::ops::SubAssign
43    + std::ops::MulAssign
44    + std::ops::DivAssign
45    + std::ops::Add<Output = Self>
46    + std::ops::Sub<Output = Self>
47    + std::ops::Mul<Output = Self>
48    + std::ops::Div<Output = Self>
49    + std::cmp::PartialOrd
50    + std::cmp::PartialEq
51{
52    fn min_value() -> Self;
53    fn max_value() -> Self;
54
55    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
56        let elem = Self::as_elem(scope);
57        let var = elem.min_variable();
58        let expand = ExpandElement::Plain(var);
59        expand.into()
60    }
61
62    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
63        let elem = Self::as_elem(scope);
64        let var = elem.max_variable();
65        let expand = ExpandElement::Plain(var);
66        expand.into()
67    }
68
69    /// Create a new constant numeric.
70    ///
71    /// Note: since this must work for both integer and float
72    /// only the less expressive of both can be created (int)
73    /// If a number with decimals is needed, use Float::new.
74    ///
75    /// This method panics when unexpanded. For creating an element
76    /// with a val, use the new method of the sub type.
77    fn from_int(val: i64) -> Self {
78        <Self as NumCast>::from(val).unwrap()
79    }
80
81    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82        unexpanded!()
83    }
84
85    fn __expand_from_int(
86        scope: &mut Scope,
87        val: ExpandElementTyped<i64>,
88    ) -> <Self as CubeType>::ExpandType {
89        let elem = Self::as_elem(scope);
90        let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
91
92        ExpandElement::Plain(var).into()
93    }
94
95    fn __expand_from_vec<const D: usize>(
96        scope: &mut Scope,
97        vec: [u32; D],
98    ) -> <Self as CubeType>::ExpandType {
99        let new_var = scope.create_local(Item::vectorized(
100            Self::as_elem(scope),
101            NonZero::new(vec.len() as u8),
102        ));
103        let elem = Self::as_elem(scope);
104
105        for (i, element) in vec.iter().enumerate() {
106            let var: Variable = elem.constant_from_i64(*element as i64);
107            let expand = ExpandElement::Plain(var);
108
109            index_assign::expand::<u32>(
110                scope,
111                new_var.clone().into(),
112                ExpandElementTyped::from_lit(scope, i),
113                expand.into(),
114            );
115        }
116
117        new_var.into()
118    }
119}
120
121/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
122/// trait.
123pub trait ScalarArgSettings: Send + Sync {
124    /// Register the information to the [KernelLauncher].
125    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
126}
127
128#[derive(new)]
129pub struct ScalarArg<T: Numeric> {
130    pub elem: T,
131}
132
133impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
134    fn register(&self, launcher: &mut KernelLauncher<R>) {
135        self.elem.register(launcher);
136    }
137}
138
139impl<T: Numeric> LaunchArg for T {
140    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
141
142    fn compilation_arg<'a, R: Runtime>(
143        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
144    ) -> Self::CompilationArg {
145    }
146}