cubecl_core/frontend/element/
numeric.rs1use std::marker::PhantomData;
2
3use cubecl_ir::ExpandElement;
4use num_traits::NumCast;
5
6use crate::frontend::{CubePrimitive, CubeType};
7use crate::ir::{Scope, Variable};
8use crate::prelude::Clamp;
9use crate::{Runtime, compute::KernelBuilder};
10use crate::{compute::KernelLauncher, prelude::CompilationArg};
11use crate::{
12 frontend::{Abs, Max, Min, Remainder},
13 unexpanded,
14};
15
16use super::{ArgSettings, ExpandElementIntoMut, ExpandElementTyped, IntoRuntime, LaunchArg};
17
18pub trait Numeric:
21 Copy
22 + Abs
23 + Max
24 + Min
25 + Clamp
26 + Remainder
27 + CubePrimitive
28 + IntoRuntime
29 + ScalarArgSettings
30 + ExpandElementIntoMut
31 + Into<ExpandElementTyped<Self>>
32 + num_traits::NumCast
33 + std::ops::AddAssign
34 + std::ops::SubAssign
35 + std::ops::MulAssign
36 + std::ops::DivAssign
37 + std::ops::Add<Output = Self>
38 + std::ops::Sub<Output = Self>
39 + std::ops::Mul<Output = Self>
40 + std::ops::Div<Output = Self>
41 + std::cmp::PartialOrd
42 + std::cmp::PartialEq
43{
44 fn min_value() -> Self;
45 fn max_value() -> Self;
46
47 fn __expand_min_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
48 let elem = Self::as_type(scope).elem_type();
49 let var = elem.min_variable();
50 let expand = ExpandElement::Plain(var);
51 expand.into()
52 }
53
54 fn __expand_max_value(scope: &mut Scope) -> <Self as CubeType>::ExpandType {
55 let elem = Self::as_type(scope).elem_type();
56 let var = elem.max_variable();
57 let expand = ExpandElement::Plain(var);
58 expand.into()
59 }
60
61 fn from_int(val: i64) -> Self {
70 <Self as NumCast>::from(val).unwrap()
71 }
72
73 fn from_vec<const D: usize>(_vec: [u32; D]) -> Self {
74 unexpanded!()
75 }
76
77 fn __expand_from_int(
78 scope: &mut Scope,
79 val: ExpandElementTyped<i64>,
80 ) -> <Self as CubeType>::ExpandType {
81 let elem = Self::as_type(scope).elem_type();
82 let var: Variable = elem.constant_from_i64(val.constant().unwrap().as_i64());
83
84 ExpandElement::Plain(var).into()
85 }
86}
87
88pub trait ScalarArgSettings: Send + Sync + CubePrimitive {
91 fn register<R: Runtime>(&self, launcher: &mut KernelLauncher<R>);
93 fn expand_scalar(
94 _: &ScalarCompilationArg<Self>,
95 builder: &mut KernelBuilder,
96 ) -> ExpandElementTyped<Self> {
97 builder.scalar(Self::as_type(&builder.scope)).into()
98 }
99}
100
101#[derive(new, Clone, Copy)]
102pub struct ScalarArg<T: ScalarArgSettings> {
103 pub elem: T,
104}
105
106#[derive(new, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
107pub struct ScalarCompilationArg<T: ScalarArgSettings> {
108 _ty: PhantomData<T>,
109}
110
111impl<T: ScalarArgSettings> Eq for ScalarCompilationArg<T> {}
112impl<T: ScalarArgSettings> core::hash::Hash for ScalarCompilationArg<T> {
113 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
114 self._ty.hash(state);
115 }
116}
117impl<T: ScalarArgSettings> core::fmt::Debug for ScalarCompilationArg<T> {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 f.write_str("Scalar")
120 }
121}
122
123impl<T: ScalarArgSettings> CompilationArg for ScalarCompilationArg<T> {}
124
125impl<T: ScalarArgSettings, R: Runtime> ArgSettings<R> for ScalarArg<T> {
126 fn register(&self, launcher: &mut KernelLauncher<R>) {
127 self.elem.register(launcher);
128 }
129}
130
131impl<T: ScalarArgSettings> LaunchArg for T {
132 type RuntimeArg<'a, R: Runtime> = ScalarArg<T>;
133 type CompilationArg = ScalarCompilationArg<T>;
134
135 fn compilation_arg<'a, R: Runtime>(
136 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
137 ) -> Self::CompilationArg {
138 ScalarCompilationArg::new()
139 }
140 fn expand(
141 arg: &ScalarCompilationArg<T>,
142 builder: &mut KernelBuilder,
143 ) -> ExpandElementTyped<Self> {
144 T::expand_scalar(arg, builder)
145 }
146}