cubecl_core/frontend/element/
numeric.rs

1use std::num::NonZero;
2
3use num_traits::NumCast;
4
5use crate::compute::KernelLauncher;
6use crate::ir::{Item, Variable};
7use crate::prelude::Clamp;
8use crate::Runtime;
9use crate::{
10    frontend::{index_assign, Abs, Max, Min, Remainder},
11    unexpanded,
12};
13use crate::{
14    frontend::{CubeContext, CubePrimitive, CubeType},
15    prelude::CubeIndexMut,
16};
17
18use super::{
19    ArgSettings, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, LaunchArg,
20    LaunchArgExpand,
21};
22
23/// Type that encompasses both (unsigned or signed) integers and floats
24/// Used in kernels that should work for both.
25pub trait Numeric:
26    Copy
27    + Abs
28    + Max
29    + Min
30    + Clamp
31    + Remainder
32    + CubePrimitive
33    + LaunchArgExpand<CompilationArg = ()>
34    + ScalarArgSettings
35    + ExpandElementBaseInit
36    + Into<ExpandElementTyped<Self>>
37    + CubeIndexMut<u32, Output = Self>
38    + CubeIndexMut<ExpandElementTyped<u32>, Output = Self>
39    + num_traits::NumCast
40    + std::ops::AddAssign
41    + std::ops::SubAssign
42    + std::ops::MulAssign
43    + std::ops::DivAssign
44    + std::ops::Add<Output = Self>
45    + std::ops::Sub<Output = Self>
46    + std::ops::Mul<Output = Self>
47    + std::ops::Div<Output = Self>
48    + std::cmp::PartialOrd
49    + std::cmp::PartialEq
50{
51    fn min_value() -> Self;
52    fn max_value() -> Self;
53
54    fn __expand_min_value(context: &mut CubeContext) -> <Self as CubeType>::ExpandType {
55        let elem = Self::as_elem(context);
56        let var = elem.min_variable();
57        let expand = ExpandElement::Plain(var);
58        expand.into()
59    }
60
61    fn __expand_max_value(context: &mut CubeContext) -> <Self as CubeType>::ExpandType {
62        let elem = Self::as_elem(context);
63        let var = elem.max_variable();
64        let expand = ExpandElement::Plain(var);
65        expand.into()
66    }
67
68    /// Create a new constant numeric.
69    ///
70    /// Note: since this must work for both integer and float
71    /// only the less expressive of both can be created (int)
72    /// If a number with decimals is needed, use Float::new.
73    ///
74    /// This method panics when unexpanded. For creating an element
75    /// with a val, use the new method of the sub type.
76    fn from_int(val: i64) -> Self {
77        <Self as NumCast>::from(val).unwrap()
78    }
79
80    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
81        unexpanded!()
82    }
83
84    fn __expand_from_int(
85        context: &mut CubeContext,
86        val: ExpandElementTyped<i64>,
87    ) -> <Self as CubeType>::ExpandType {
88        let elem = Self::as_elem(context);
89        let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
90
91        ExpandElement::Plain(var).into()
92    }
93
94    fn __expand_from_vec<const D: usize>(
95        context: &mut CubeContext,
96        vec: [u32; D],
97    ) -> <Self as CubeType>::ExpandType {
98        let new_var = context.create_local(Item::vectorized(
99            Self::as_elem(context),
100            NonZero::new(vec.len() as u8),
101        ));
102        let elem = Self::as_elem(context);
103
104        for (i, element) in vec.iter().enumerate() {
105            let var: Variable = elem.constant_from_i64(*element as i64);
106            let expand = ExpandElement::Plain(var);
107
108            index_assign::expand::<u32>(
109                context,
110                new_var.clone().into(),
111                ExpandElementTyped::from_lit(context, i),
112                expand.into(),
113            );
114        }
115
116        new_var.into()
117    }
118}
119
120/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
121/// trait.
122pub trait ScalarArgSettings: Send + Sync {
123    /// Register the information to the [KernelLauncher].
124    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
125}
126
127#[derive(new)]
128pub struct ScalarArg<T: Numeric> {
129    elem: T,
130}
131
132impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
133    fn register(&self, launcher: &mut crate::compute::KernelLauncher<R>) {
134        self.elem.register(launcher);
135    }
136}
137
138impl<T: Numeric> LaunchArg for T {
139    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
140
141    fn compilation_arg<'a, R: Runtime>(
142        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
143    ) -> Self::CompilationArg {
144    }
145}