cubecl_core/frontend/element/
numeric.rs1use std::num::NonZero;
2
3use cubecl_ir::ExpandElement;
4use num_traits::NumCast;
5
6use crate::Runtime;
7use crate::compute::KernelLauncher;
8use crate::ir::{Item, Scope, Variable};
9use crate::prelude::Clamp;
10use crate::{
11 frontend::{Abs, Max, Min, Remainder, index_assign},
12 unexpanded,
13};
14use crate::{
15 frontend::{CubePrimitive, CubeType},
16 prelude::CubeIndexMut,
17};
18
19use super::{
20 ArgSettings, ExpandElementBaseInit, ExpandElementTyped, IntoRuntime, LaunchArg, LaunchArgExpand,
21};
22
23pub trait Numeric:
26 Copy
27 + Abs
28 + Max
29 + Min
30 + Clamp
31 + Remainder
32 + CubePrimitive
33 + IntoRuntime
34 + LaunchArgExpand<CompilationArg = ()>
35 + ScalarArgSettings
36 + ExpandElementBaseInit
37 + Into<ExpandElementTyped<Self>>
38 + CubeIndexMut<u32, Output = Self>
39 + CubeIndexMut<ExpandElementTyped<u32>, Output = Self>
40 + num_traits::NumCast
41 + std::ops::AddAssign
42 + std::ops::SubAssign
43 + std::ops::MulAssign
44 + std::ops::DivAssign
45 + std::ops::Add<Output = Self>
46 + std::ops::Sub<Output = Self>
47 + std::ops::Mul<Output = Self>
48 + std::ops::Div<Output = Self>
49 + std::cmp::PartialOrd
50 + std::cmp::PartialEq
51{
52 fn min_value() -> Self;
53 fn max_value() -> Self;
54
55 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
56 let elem = Self::as_elem(scope);
57 let var = elem.min_variable();
58 let expand = ExpandElement::Plain(var);
59 expand.into()
60 }
61
62 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
63 let elem = Self::as_elem(scope);
64 let var = elem.max_variable();
65 let expand = ExpandElement::Plain(var);
66 expand.into()
67 }
68
69 fn from_int(val: i64) -> Self {
78 <Self as NumCast>::from(val).unwrap()
79 }
80
81 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
82 unexpanded!()
83 }
84
85 fn __expand_from_int(
86 scope: &mut Scope,
87 val: ExpandElementTyped<i64>,
88 ) -> <Self as CubeType>::ExpandType {
89 let elem = Self::as_elem(scope);
90 let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
91
92 ExpandElement::Plain(var).into()
93 }
94
95 fn __expand_from_vec<const D: usize>(
96 scope: &mut Scope,
97 vec: [u32; D],
98 ) -> <Self as CubeType>::ExpandType {
99 let new_var = scope.create_local(Item::vectorized(
100 Self::as_elem(scope),
101 NonZero::new(vec.len() as u8),
102 ));
103 let elem = Self::as_elem(scope);
104
105 for (i, element) in vec.iter().enumerate() {
106 let var: Variable = elem.constant_from_i64(*element as i64);
107 let expand = ExpandElement::Plain(var);
108
109 index_assign::expand::<u32>(
110 scope,
111 new_var.clone().into(),
112 ExpandElementTyped::from_lit(scope, i),
113 expand.into(),
114 );
115 }
116
117 new_var.into()
118 }
119}
120
121pub trait ScalarArgSettings: Send + Sync {
124 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
126}
127
128#[derive(new)]
129pub struct ScalarArg<T: Numeric> {
130 pub elem: T,
131}
132
133impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
134 fn register(&self, launcher: &mut KernelLauncher<R>) {
135 self.elem.register(launcher);
136 }
137}
138
139impl<T: Numeric> LaunchArg for T {
140 type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
141
142 fn compilation_arg<'a, R: Runtime>(
143 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
144 ) -> Self::CompilationArg {
145 }
146}