cubecl_core/frontend/element/
base.rs1use super::{CubePrimitive, Numeric};
2use crate::{
3 Runtime,
4 ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
5 prelude::{KernelBuilder, KernelLauncher, init_expand},
6};
7use cubecl_common::{e2m1, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
8use cubecl_ir::ExpandElement;
9use half::{bf16, f16};
10use std::{
11 any::{Any, TypeId},
12 marker::PhantomData,
13};
14use variadics_please::all_tuples;
15
16#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
29pub trait CubeType {
30 type ExpandType: Clone + IntoMut + CubeDebug;
31
32 fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
34 expand.into_mut(scope)
35 }
36}
37
38pub trait IntoRuntime: CubeType + Sized {
40 fn runtime(self) -> Self {
41 self
42 }
43
44 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
45}
46
47pub trait IntoMut: Sized {
49 fn into_mut(self, scope: &mut Scope) -> Self;
50}
51
52pub trait CubeDebug: Sized {
53 #[allow(unused)]
56 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
57}
58
/// Marker trait bundling `Debug + Hash + Eq + Clone + Copy`.
///
/// It is implemented automatically for every type that satisfies those bounds.
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}

76pub trait CubeLaunch: CubeType + LaunchArg + LaunchArgExpand {}
78impl<T: CubeType + LaunchArg + LaunchArgExpand> CubeLaunch for T {}
79
80pub trait CompilationArg:
82 serde::Serialize
83 + serde::de::DeserializeOwned
84 + Clone
85 + PartialEq
86 + Eq
87 + core::hash::Hash
88 + core::fmt::Debug
89 + Send
90 + Sync
91 + 'static
92{
93 fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
100 if TypeId::of::<Arg>() == TypeId::of::<Self>() {
101 let tmp: Box<dyn Any> = Box::new(self.clone());
102 *tmp.downcast().unwrap()
103 } else {
104 let val = serde_json::to_string(self).unwrap();
105 serde_json::from_str(&val)
106 .expect("Compilation argument should be the same even with different element types")
107 }
108 }
109}
110
111impl CompilationArg for () {}
112
113#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
122pub trait LaunchArgExpand: CubeType {
123 type CompilationArg: CompilationArg;
125
126 fn expand(
128 arg: &Self::CompilationArg,
129 builder: &mut KernelBuilder,
130 ) -> <Self as CubeType>::ExpandType;
131
132 fn expand_output(
134 arg: &Self::CompilationArg,
135 builder: &mut KernelBuilder,
136 ) -> <Self as CubeType>::ExpandType {
137 Self::expand(arg, builder)
138 }
139}
140
141pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
143 type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
145
146 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
147}
148
149pub trait ArgSettings<R: Runtime>: Send + Sync {
151 fn register(&self, launcher: &mut KernelLauncher<R>);
153}
154
155#[derive(new)]
157pub struct ExpandElementTyped<T: CubeType> {
158 pub(crate) expand: ExpandElement,
159 pub(crate) _type: PhantomData<T>,
160}
161
162impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
163 fn from(value: &ExpandElementTyped<T>) -> Self {
164 value.clone()
165 }
166}
167
168impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
169 fn from(value: ExpandElementTyped<T>) -> Self {
170 value.expand.into()
171 }
172}
173
174impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
175 fn from(value: &mut ExpandElementTyped<T>) -> Self {
176 value.clone()
177 }
178}
179
180macro_rules! from_const {
181 ($lit:ty) => {
182 impl From<$lit> for ExpandElementTyped<$lit> {
183 fn from(value: $lit) -> Self {
184 let variable: Variable = value.into();
185
186 ExpandElement::Plain(variable).into()
187 }
188 }
189 };
190}
191
192from_const!(u8);
193from_const!(u16);
194from_const!(u32);
195from_const!(u64);
196from_const!(i64);
197from_const!(i8);
198from_const!(i16);
199from_const!(i32);
200from_const!(f64);
201from_const!(f16);
202from_const!(bf16);
203from_const!(flex32);
204from_const!(tf32);
205from_const!(f32);
206from_const!(e2m1);
207from_const!(e2m3);
208from_const!(e3m2);
209from_const!(e4m3);
210from_const!(e5m2);
211from_const!(ue8m0);
212from_const!(bool);
213
214macro_rules! tuple_cube_type {
215 ($($P:ident),*) => {
216 impl<$($P: CubeType),*> CubeType for ($($P,)*) {
217 type ExpandType = ($($P::ExpandType,)*);
218 }
219 }
220}
221macro_rules! tuple_init {
222 ($($P:ident),*) => {
223 impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
224 #[allow(non_snake_case, unused, clippy::unused_unit)]
225 fn into_mut(self, scope: &mut Scope) -> Self {
226 let ($($P,)*) = self;
227 ($(
228 $P.into_mut(scope),
229 )*)
230 }
231 }
232 }
233}
234macro_rules! tuple_debug {
235 ($($P:ident),*) => {
236 impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
237 }
238}
239macro_rules! tuple_runtime {
240 ($($P:ident),*) => {
241 impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
242 #[allow(non_snake_case, unused, clippy::unused_unit)]
243 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
244 let ($($P,)*) = self;
245 ($(
246 $P.__expand_runtime_method(scope),
247 )*)
248 }
249 }
250 }
251}
252
253all_tuples!(tuple_cube_type, 0, 12, P);
254all_tuples!(tuple_debug, 0, 12, P);
255all_tuples!(tuple_init, 0, 12, P);
256all_tuples!(tuple_runtime, 0, 12, P);
257
258impl<P: CubePrimitive> CubeDebug for P {}
259
260pub trait ExpandElementIntoMut: CubeType {
261 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
262}
263
264impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
265 fn into_mut(self, scope: &mut Scope) -> Self {
266 <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
267 }
268}
269
270impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
271 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
272 scope.update_variable_name(*self.expand, name);
273 }
274}
275
276impl<T: CubeType> ExpandElementTyped<T> {
277 pub fn __expand_vectorization_factor_method(self, _scope: &mut Scope) -> u32 {
279 self.expand
280 .item
281 .vectorization
282 .map(|it| it.get())
283 .unwrap_or(1) as u32
284 }
285
286 pub fn into_variable(self) -> Variable {
287 self.expand.consume()
288 }
289}
290
291impl<T: CubeType> Clone for ExpandElementTyped<T> {
292 fn clone(&self) -> Self {
293 Self {
294 expand: self.expand.clone(),
295 _type: PhantomData,
296 }
297 }
298}
299
300impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
301 fn from(expand: ExpandElement) -> Self {
302 Self {
303 expand,
304 _type: PhantomData,
305 }
306 }
307}
308
309impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
310 fn from(value: ExpandElementTyped<T>) -> Self {
311 value.expand
312 }
313}
314
315impl<T: CubePrimitive> ExpandElementTyped<T> {
316 pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
318 let variable: Variable = lit.into();
319 let variable = T::as_elem(scope).from_constant(variable);
320
321 ExpandElementTyped::new(ExpandElement::Plain(variable))
322 }
323
324 pub fn constant(&self) -> Option<ConstantScalarValue> {
326 match self.expand.kind {
327 VariableKind::ConstantScalar(val) => Some(val),
328 _ => None,
329 }
330 }
331}
332
333pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
334 scope: &mut Scope,
335 element: E,
336) -> ExpandElement {
337 let elem = element.into();
338
339 match elem.kind {
340 VariableKind::ConstantScalar { .. } => init_expand(scope, elem, false, Operation::Copy),
341 _ => elem,
342 }
343}
344
345pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
346 scope: &mut Scope,
347 element: E,
348) -> ExpandElement {
349 let elem = element.into();
350
351 let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
352
353 match elem.kind {
354 VariableKind::GlobalScalar { .. } => init(elem),
355 VariableKind::ConstantScalar { .. } => init(elem),
356 VariableKind::LocalMut { .. } => init(elem),
357 VariableKind::Versioned { .. } => init(elem),
358 VariableKind::LocalConst { .. } => init(elem),
359 VariableKind::Builtin(_) => init(elem),
360 VariableKind::SharedMemory { .. }
361 | VariableKind::GlobalInputArray { .. }
362 | VariableKind::GlobalOutputArray { .. }
363 | VariableKind::LocalArray { .. }
364 | VariableKind::ConstantArray { .. }
365 | VariableKind::Matrix { .. }
366 | VariableKind::Barrier { .. }
367 | VariableKind::Pipeline { .. }
368 | VariableKind::TensorMap(_) => elem,
369 }
370}
371
372impl IntoMut for ExpandElement {
373 fn into_mut(self, scope: &mut Scope) -> Self {
374 into_mut_expand_element(scope, self)
375 }
376}
377
378impl<T: IntoMut> IntoMut for Option<T> {
379 fn into_mut(self, scope: &mut Scope) -> Self {
380 self.map(|o| IntoMut::into_mut(o, scope))
381 }
382}
383
384impl<T: CubeType> CubeType for Vec<T> {
385 type ExpandType = Vec<T::ExpandType>;
386}
387
388impl<T: CubeType> CubeType for &mut Vec<T> {
389 type ExpandType = Vec<T::ExpandType>;
390}
391
392impl<T: IntoMut> IntoMut for Vec<T> {
393 fn into_mut(self, scope: &mut Scope) -> Self {
394 self.into_iter().map(|e| e.into_mut(scope)).collect()
395 }
396}
397impl<T: CubeDebug> CubeDebug for Vec<T> {}
398
399pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
401 scope: &mut Scope,
402 val: C,
403) -> ExpandElementTyped<Out> {
404 let input: ExpandElementTyped<C> = val.into();
405 let const_val = input.expand.as_const().unwrap();
406 let var = Variable::constant(const_val.cast_to(Out::as_elem(scope)));
407 ExpandElement::Plain(var).into()
408}
409
410impl LaunchArg for () {
411 type RuntimeArg<'a, R: Runtime> = ();
412
413 fn compilation_arg<'a, R: Runtime>(
414 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
415 ) -> Self::CompilationArg {
416 }
417}
418
419impl<R: Runtime> ArgSettings<R> for () {
420 fn register(&self, _launcher: &mut KernelLauncher<R>) {
421 }
423}
424
425impl LaunchArgExpand for () {
426 type CompilationArg = ();
427
428 fn expand(
429 _: &Self::CompilationArg,
430 _builder: &mut KernelBuilder,
431 ) -> <Self as CubeType>::ExpandType {
432 }
433}