cubecl_core/frontend/element/
numeric.rs

1use std::marker::PhantomData;
2
3use cubecl_ir::ExpandElement;
4use cubecl_runtime::runtime::Runtime;
5use num_traits::NumCast;
6
7use crate::frontend::{CubePrimitive, CubeType};
8use crate::ir::{Scope, Variable};
9use crate::prelude::Clamp;
10use crate::{CubeScalar, compute::KernelBuilder};
11use crate::{compute::KernelLauncher, prelude::CompilationArg};
12use crate::{
13    frontend::{Abs, Max, Min, Remainder},
14    unexpanded,
15};
16
17use super::{ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg};
18
19/// Type that encompasses both (unsigned or signed) integers and floats
20/// Used in kernels that should work for both.
21pub trait Numeric:
22    Copy
23    + Abs
24    + Max
25    + Min
26    + Clamp
27    + Remainder
28    + CubePrimitive
29    + IntoRuntime
30    + ExpandElementIntoMut
31    + Into<ExpandElementTyped<Self>>
32    + num_traits::NumCast
33    + std::ops::AddAssign
34    + std::ops::SubAssign
35    + std::ops::MulAssign
36    + std::ops::DivAssign
37    + std::ops::Add<Output = Self>
38    + std::ops::Sub<Output = Self>
39    + std::ops::Mul<Output = Self>
40    + std::ops::Div<Output = Self>
41    + std::cmp::PartialOrd
42    + std::cmp::PartialEq
43{
44    fn min_value() -> Self;
45    fn max_value() -> Self;
46
47    fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
48        let elem = Self::as_type(scope).elem_type();
49        let var = elem.min_variable();
50        let expand = ExpandElement::Plain(var);
51        expand.into()
52    }
53
54    fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
55        let elem = Self::as_type(scope).elem_type();
56        let var = elem.max_variable();
57        let expand = ExpandElement::Plain(var);
58        expand.into()
59    }
60
61    /// Create a new constant numeric.
62    ///
63    /// Note: since this must work for both integer and float
64    /// only the less expressive of both can be created (int)
65    /// If a number with decimals is needed, use Float::new.
66    ///
67    /// This method panics when unexpanded. For creating an element
68    /// with a val, use the new method of the sub type.
69    fn from_int(val: i64) -> Self {
70        <Self as NumCast>::from(val).unwrap()
71    }
72
73    fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
74        unexpanded!()
75    }
76
77    fn __expand_from_int(
78        scope: &mut Scope,
79        val: ExpandElementTyped<i64>,
80    ) -> <Self as CubeType>::ExpandType {
81        let elem = Self::as_type(scope).elem_type();
82        let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
83
84        ExpandElement::Plain(var).into()
85    }
86}
87
88/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime]
89/// trait.
90pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
91    /// Register the information to the [KernelLauncher].
92    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
93    fn expand_scalar(
94        _: &ScalarCompilationArg<Self>,
95        builder: &mut KernelBuilder,
96    ) -> ExpandElementTyped<Self> {
97        builder.scalar(Self::as_type(&builder.scope)).into()
98    }
99}
100
101impl<E: CubeScalar> ScalarArgSettings for E {
102    fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>) {
103        launcher.register_scalar(*self);
104    }
105}
106
107#[derive(new, Clone, Copy)]
108pub struct ScalarArg<T: ScalarArgSettings> {
109    pub elem: T,
110}
111
112#[derive(new, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
113pub struct ScalarCompilationArg<T: ScalarArgSettings> {
114    _ty: PhantomData<T>,
115}
116
117impl<T: ScalarArgSettings> Eq for ScalarCompilationArg<T> {}
118impl<T: ScalarArgSettings> core::hash::Hash for ScalarCompilationArg<T> {
119    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
120        self._ty.hash(state);
121    }
122}
123impl<T: ScalarArgSettings> core::fmt::Debug for ScalarCompilationArg<T> {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        f.write_str("Scalar")
126    }
127}
128
129impl<T: ScalarArgSettings> CompilationArg for ScalarCompilationArg<T> {}
130
131impl<T: ScalarArgSettings, R: Runtime> ArgSettings<R> for ScalarArg<T> {
132    fn register(&self, launcher: &mut KernelLauncher<R>) {
133        self.elem.register(launcher);
134    }
135}
136
137impl<T: ScalarArgSettings> LaunchArg for T {
138    type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
139    type CompilationArg = ScalarCompilationArg<T>;
140
141    fn compilation_arg<'a, R: Runtime>(
142        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
143    ) -> Self::CompilationArg {
144        ScalarCompilationArg::new()
145    }
146    fn expand(
147        arg: &ScalarCompilationArg<T>,
148        builder: &mut KernelBuilder,
149    ) -> ExpandElementTyped<Self> {
150        T::expand_scalar(arg, builder)
151    }
152}