cubecl_core/frontend/element/
numeric.rs1use std::num::NonZero;
2
3use num_traits::NumCast;
4
5use crate::compute::KernelLauncher;
6use crate::ir::{Item, Variable};
7use crate::prelude::Clamp;
8use crate::Runtime;
9use crate::{
10 frontend::{index_assign, Abs, Max, Min, Remainder},
11 unexpanded,
12};
13use crate::{
14 frontend::{CubeContext, CubePrimitive, CubeType},
15 prelude::CubeIndexMut,
16};
17
18use super::{
19 ArgSettings, ExpandElement, ExpandElementBaseInit, ExpandElementTyped, LaunchArg,
20 LaunchArgExpand,
21};
22
23pub trait Numeric:
26 Copy
27 + Abs
28 + Max
29 + Min
30 + Clamp
31 + Remainder
32 + CubePrimitive
33 + LaunchArgExpand<CompilationArg = ()>
34 + ScalarArgSettings
35 + ExpandElementBaseInit
36 + Into<ExpandElementTyped<Self>>
37 + CubeIndexMut<u32, Output = Self>
38 + CubeIndexMut<ExpandElementTyped<u32>, Output = Self>
39 + num_traits::NumCast
40 + std::ops::AddAssign
41 + std::ops::SubAssign
42 + std::ops::MulAssign
43 + std::ops::DivAssign
44 + std::ops::Add<Output = Self>
45 + std::ops::Sub<Output = Self>
46 + std::ops::Mul<Output = Self>
47 + std::ops::Div<Output = Self>
48 + std::cmp::PartialOrd
49 + std::cmp::PartialEq
50{
51 fn min_value() -> Self;
52 fn max_value() -> Self;
53
54 fn __expand_min_value(context: &mut CubeContext) -> <Self as CubeType>::ExpandType {
55 let elem = Self::as_elem(context);
56 let var = elem.min_variable();
57 let expand = ExpandElement::Plain(var);
58 expand.into()
59 }
60
61 fn __expand_max_value(context: &mut CubeContext) -> <Self as CubeType>::ExpandType {
62 let elem = Self::as_elem(context);
63 let var = elem.max_variable();
64 let expand = ExpandElement::Plain(var);
65 expand.into()
66 }
67
68 fn from_int(val: i64) -> Self {
77 <Self as NumCast>::from(val).unwrap()
78 }
79
80 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
81 unexpanded!()
82 }
83
84 fn __expand_from_int(
85 context: &mut CubeContext,
86 val: ExpandElementTyped<i64>,
87 ) -> <Self as CubeType>::ExpandType {
88 let elem = Self::as_elem(context);
89 let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
90
91 ExpandElement::Plain(var).into()
92 }
93
94 fn __expand_from_vec<const D: usize>(
95 context: &mut CubeContext,
96 vec: [u32; D],
97 ) -> <Self as CubeType>::ExpandType {
98 let new_var = context.create_local(Item::vectorized(
99 Self::as_elem(context),
100 NonZero::new(vec.len() as u8),
101 ));
102 let elem = Self::as_elem(context);
103
104 for (i, element) in vec.iter().enumerate() {
105 let var: Variable = elem.constant_from_i64(*element as i64);
106 let expand = ExpandElement::Plain(var);
107
108 index_assign::expand::<u32>(
109 context,
110 new_var.clone().into(),
111 ExpandElementTyped::from_lit(context, i),
112 expand.into(),
113 );
114 }
115
116 new_var.into()
117 }
118}
119
120pub trait ScalarArgSettings: Send + Sync {
123 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
125}
126
127#[derive(new)]
128pub struct ScalarArg<T: Numeric> {
129 elem: T,
130}
131
132impl<T: Numeric, R: Runtime> ArgSettings<R> for ScalarArg<T> {
133 fn register(&self, launcher: &mut crate::compute::KernelLauncher<R>) {
134 self.elem.register(launcher);
135 }
136}
137
138impl<T: Numeric> LaunchArg for T {
139 type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
140
141 fn compilation_arg<'a, R: Runtime>(
142 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
143 ) -> Self::CompilationArg {
144 }
145}