cubecl_core/frontend/element/
numeric.rs1use cubecl_ir::ExpandElement;
2use num_traits::NumCast;
3
4use crate::Runtime;
5use crate::compute::KernelLauncher;
6use crate::frontend::{CubePrimitive, CubeType};
7use crate::ir::{Scope, Variable};
8use crate::prelude::Clamp;
9use crate::{
10 frontend::{Abs, Max, Min, Remainder},
11 unexpanded,
12};
13
14use super::{
15 ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg, LaunchArgExpand,
16};
17
18pub trait Numeric:
21 Copy
22 + Abs
23 + Max
24 + Min
25 + Clamp
26 + Remainder
27 + CubePrimitive
28 + IntoRuntime
29 + LaunchArgExpand<CompilationArg = ()>
30 + ScalarArgSettings
31 + ExpandElementIntoMut
32 + Into<ExpandElementTyped<Self>>
33 + num_traits::NumCast
34 + std::ops::AddAssign
35 + std::ops::SubAssign
36 + std::ops::MulAssign
37 + std::ops::DivAssign
38 + std::ops::Add<Output = Self>
39 + std::ops::Sub<Output = Self>
40 + std::ops::Mul<Output = Self>
41 + std::ops::Div<Output = Self>
42 + std::cmp::PartialOrd
43 + std::cmp::PartialEq
44{
45 fn min_value() -> Self;
46 fn max_value() -> Self;
47
48 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
49 let elem = Self::as_elem(scope);
50 let var = elem.min_variable();
51 let expand = ExpandElement::Plain(var);
52 expand.into()
53 }
54
55 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
56 let elem = Self::as_elem(scope);
57 let var = elem.max_variable();
58 let expand = ExpandElement::Plain(var);
59 expand.into()
60 }
61
62 fn from_int(val: i64) -> Self {
71 <Self as NumCast>::from(val).unwrap()
72 }
73
74 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
75 unexpanded!()
76 }
77
78 fn __expand_from_int(
79 scope: &mut Scope,
80 val: ExpandElementTyped<i64>,
81 ) -> <Self as CubeType>::ExpandType {
82 let elem = Self::as_elem(scope);
83 let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
84
85 ExpandElement::Plain(var).into()
86 }
87}
88
89pub trait ScalarArgSettings: Send + Sync {
92 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
94}
95
96#[derive(new)]
97pub struct ScalarArg<T: Numeric> {
98 pub elem: T,
99}
100
101impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
102 fn register(&self, launcher: &mut KernelLauncher<R>) {
103 self.elem.register(launcher);
104 }
105}
106
107impl<T: Numeric> LaunchArg for T {
108 type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
109
110 fn compilation_arg<'a, R: Runtime>(
111 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
112 ) -> Self::CompilationArg {
113 }
114}