cubecl_core/frontend/element/
base.rs1use super::{CubePrimitive, Numeric};
2use crate::{
3 ir::{ConstantValue, Operation, Scope, Variable, VariableKind},
4 prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use alloc::{boxed::Box, vec::Vec};
7use core::marker::PhantomData;
8use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
9use cubecl_ir::{ExpandElement, LineSize};
10use cubecl_runtime::runtime::Runtime;
11use half::{bf16, f16};
12use variadics_please::all_tuples;
13
14#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
27pub trait CubeType {
28 type ExpandType: Clone + IntoMut + CubeDebug;
29
30 fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
32 expand.into_mut(scope)
33 }
34}
35
36pub trait CloneExpand {
37 fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
38}
39
40impl<C: Clone> CloneExpand for C {
41 fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
42 self.clone()
43 }
44}
45
46pub trait IntoRuntime: CubeType + Sized {
48 fn runtime(self) -> Self {
49 self
50 }
51
52 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
53}
54
/// Identity conversion spelled `.comptime()`, available on every sized type.
pub trait IntoComptime: Sized {
    /// Returns `self` unchanged.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

// Blanket impl: type parameters are implicitly `Sized`, which is all the trait requires.
impl<T> IntoComptime for T {}
64
65pub trait IntoMut: Sized {
67 fn into_mut(self, scope: &mut Scope) -> Self;
68}
69
70pub trait CubeDebug: Sized {
71 #[allow(unused)]
74 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
75}
76
/// Marker trait bundling `Debug + Hash + Eq + Clone + Copy`.
///
/// It is implemented automatically for every type that satisfies those bounds.
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

// `Copy` already implies `Clone`, so listing `Copy` covers both supertraits.
impl<T: Copy + Eq + core::hash::Hash + core::fmt::Debug> CubeComptime for T {}
93
/// Compile-time description of a kernel argument.
///
/// Implementors must be cloneable, comparable, hashable, debuggable, thread-safe and
/// `'static`.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as another compilation-argument type `Arg`.
    ///
    /// This is a bitwise reinterpretation. It is only meaningful when `Self` and `Arg` share
    /// the same in-memory representation, for example the same struct instantiated with
    /// different generic parameters that do not affect its fields.
    ///
    /// NOTE(review): this safe function cannot verify that the bytes of `Self` form a valid
    /// `Arg`. Only size and alignment are checked.
    ///
    /// # Panics
    ///
    /// Panics if `Self` and `Arg` differ in size or in alignment. Both checks are required:
    /// the value travels through a heap allocation made for `Self`, which is read and then
    /// freed as an `Arg`. That is only sound when the two layouts are identical.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        assert!(
            size_of::<Arg>() == size_of::<Self>(),
            "dynamic_cast: `{}` and `{}` differ in size",
            core::any::type_name::<Self>(),
            core::any::type_name::<Arg>(),
        );
        assert!(
            core::mem::align_of::<Arg>() == core::mem::align_of::<Self>(),
            "dynamic_cast: `{}` and `{}` differ in alignment",
            core::any::type_name::<Self>(),
            core::any::type_name::<Arg>(),
        );
        let this = Box::new(self.clone());
        // SAFETY: `Self` and `Arg` have identical size and alignment (asserted above), so
        // the allocation is correctly aligned for `Arg` and is freed with the same layout
        // it was allocated with. Ownership of the clone moves into the returned `Arg`.
        unsafe { *Box::from_raw(Box::into_raw(this) as *mut Arg) }
    }
}

/// The unit type carries no compile-time information.
impl CompilationArg for () {}
114
115#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
124pub trait LaunchArg: CubeType + Send + Sync + 'static {
125 type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
127 type CompilationArg: CompilationArg;
129
130 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
131
132 fn expand(
134 arg: &Self::CompilationArg,
135 builder: &mut KernelBuilder,
136 ) -> <Self as CubeType>::ExpandType;
137
138 fn expand_output(
140 arg: &Self::CompilationArg,
141 builder: &mut KernelBuilder,
142 ) -> <Self as CubeType>::ExpandType {
143 Self::expand(arg, builder)
144 }
145}
146
147pub trait ArgSettings<R: Runtime>: Send + Sync {
149 fn register(&self, launcher: &mut KernelLauncher<R>);
151}
152
153macro_rules! launch_tuple {
154 ($(($T:ident, $t:ident)),*) => {
155 impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
156 type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
157 type CompilationArg = ($($T::CompilationArg),*);
158
159 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
160 let ($($t),*) = runtime_arg;
161 ($($T::compilation_arg($t)),*)
162 }
163
164 fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
165 let ($($t),*) = arg;
166 ($($T::expand($t, builder)),*)
167 }
168
169 fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
170 let ($($t),*) = arg;
171 ($($T::expand_output($t, builder)),*)
172 }
173 }
174
175 impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
176
177 impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
178 fn register(&self, launcher: &mut KernelLauncher<R>) {
179 let ($($t),*) = self;
180 $($t.register(launcher);)*
181 }
182 }
183 };
184}
185
186all_tuples!(launch_tuple, 2, 12, T, t);
187
188#[derive(new)]
190pub struct ExpandElementTyped<T: CubeType> {
191 pub expand: ExpandElement,
192 pub(crate) _type: PhantomData<T>,
193}
194
195impl<T: CubeType> ExpandElementTyped<T> {
196 pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
200 unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
201 }
202
203 pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
207 unsafe {
208 core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
209 }
210 }
211}
212
213impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
214 fn from(value: &ExpandElementTyped<T>) -> Self {
215 value.clone()
216 }
217}
218
219impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
220 fn from(value: ExpandElementTyped<T>) -> Self {
221 value.expand.into()
222 }
223}
224
225impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
226 fn from(value: &mut ExpandElementTyped<T>) -> Self {
227 value.clone()
228 }
229}
230
231macro_rules! from_const {
232 ($lit:ty) => {
233 impl From<$lit> for ExpandElementTyped<$lit> {
234 fn from(value: $lit) -> Self {
235 let variable: Variable = value.into();
236
237 ExpandElement::Plain(variable).into()
238 }
239 }
240 };
241}
242
243from_const!(u8);
244from_const!(u16);
245from_const!(u32);
246from_const!(u64);
247from_const!(usize);
248from_const!(isize);
249from_const!(i64);
250from_const!(i8);
251from_const!(i16);
252from_const!(i32);
253from_const!(f64);
254from_const!(f16);
255from_const!(bf16);
256from_const!(flex32);
257from_const!(tf32);
258from_const!(f32);
259from_const!(e2m1);
260from_const!(e2m1x2);
261from_const!(e2m3);
262from_const!(e3m2);
263from_const!(e4m3);
264from_const!(e5m2);
265from_const!(ue8m0);
266from_const!(bool);
267
268macro_rules! tuple_cube_type {
269 ($($P:ident),*) => {
270 impl<$($P: CubeType),*> CubeType for ($($P,)*) {
271 type ExpandType = ($($P::ExpandType,)*);
272 }
273 }
274}
275macro_rules! tuple_init {
276 ($($P:ident),*) => {
277 impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
278 #[allow(non_snake_case, unused, clippy::unused_unit)]
279 fn into_mut(self, scope: &mut Scope) -> Self {
280 let ($($P,)*) = self;
281 ($(
282 $P.into_mut(scope),
283 )*)
284 }
285 }
286 }
287}
288macro_rules! tuple_debug {
289 ($($P:ident),*) => {
290 impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
291 }
292}
293macro_rules! tuple_runtime {
294 ($($P:ident),*) => {
295 impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
296 #[allow(non_snake_case, unused, clippy::unused_unit)]
297 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
298 let ($($P,)*) = self;
299 ($(
300 $P.__expand_runtime_method(scope),
301 )*)
302 }
303 }
304 }
305}
306
307all_tuples!(tuple_cube_type, 0, 12, P);
308all_tuples!(tuple_debug, 0, 12, P);
309all_tuples!(tuple_init, 0, 12, P);
310all_tuples!(tuple_runtime, 0, 12, P);
311
312impl<P: CubePrimitive> CubeDebug for P {}
313
314pub trait ExpandElementIntoMut: CubeType {
315 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
316}
317
318impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
319 fn into_mut(self, scope: &mut Scope) -> Self {
320 <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
321 }
322}
323
324impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
325 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
326 scope.update_variable_name(*self.expand, name);
327 }
328}
329
330impl<T: CubeType> CubeDebug for &ExpandElementTyped<T> {
331 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
332 scope.update_variable_name(*self.expand, name);
333 }
334}
335
336impl<T: CubeType> CubeDebug for &mut ExpandElementTyped<T> {
337 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
338 scope.update_variable_name(*self.expand, name);
339 }
340}
341
342impl<T: CubeType> ExpandElementTyped<T> {
343 pub fn line_size(&self) -> LineSize {
345 self.expand.ty.line_size()
346 }
347
348 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
350 self.expand.ty.line_size()
351 }
352
353 pub fn into_variable(self) -> Variable {
354 self.expand.consume()
355 }
356}
357
358impl<T: CubeType> Clone for ExpandElementTyped<T> {
359 fn clone(&self) -> Self {
360 Self {
361 expand: self.expand.clone(),
362 _type: PhantomData,
363 }
364 }
365}
366
367impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
368 fn from(expand: ExpandElement) -> Self {
369 Self {
370 expand,
371 _type: PhantomData,
372 }
373 }
374}
375
376impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
377 fn from(value: ExpandElementTyped<T>) -> Self {
378 value.expand
379 }
380}
381
382impl<T: CubePrimitive> ExpandElementTyped<T> {
383 pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
385 let variable: ConstantValue = lit.into();
386 let variable = T::as_type(scope).constant(variable);
387
388 ExpandElementTyped::new(ExpandElement::Plain(variable))
389 }
390
391 pub fn constant(&self) -> Option<ConstantValue> {
393 match self.expand.kind {
394 VariableKind::Constant(val) => Some(val),
395 _ => None,
396 }
397 }
398
399 pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
400 let value = self.constant().unwrap();
401 T::from_const_value(value)
402 }
403}
404
405pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
406 scope: &mut Scope,
407 element: E,
408) -> ExpandElement {
409 let elem = element.into();
410
411 match elem.kind {
412 VariableKind::Constant { .. } => init_expand(scope, elem, false, Operation::Copy),
413 _ => elem,
414 }
415}
416
417pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
418 scope: &mut Scope,
419 element: E,
420) -> ExpandElement {
421 let elem = element.into();
422
423 let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
424
425 match elem.kind {
426 VariableKind::GlobalScalar { .. } => init(elem),
427 VariableKind::Constant { .. } => init(elem),
428 VariableKind::LocalMut { .. } => init(elem),
429 VariableKind::Versioned { .. } => init(elem),
430 VariableKind::LocalConst { .. } => init(elem),
431 VariableKind::Builtin(_) => init(elem),
432 VariableKind::Shared { .. }
433 | VariableKind::SharedArray { .. }
434 | VariableKind::GlobalInputArray { .. }
435 | VariableKind::GlobalOutputArray { .. }
436 | VariableKind::LocalArray { .. }
437 | VariableKind::ConstantArray { .. }
438 | VariableKind::Matrix { .. }
439 | VariableKind::BarrierToken { .. }
440 | VariableKind::Pipeline { .. }
441 | VariableKind::TensorMapOutput(_)
442 | VariableKind::TensorMapInput(_) => elem,
443 }
444}
445
446impl IntoMut for ExpandElement {
447 fn into_mut(self, scope: &mut Scope) -> Self {
448 into_mut_expand_element(scope, self)
449 }
450}
451
452impl<T: IntoMut> IntoMut for Option<T> {
453 fn into_mut(self, scope: &mut Scope) -> Self {
454 self.map(|o| IntoMut::into_mut(o, scope))
455 }
456}
457
458impl<T: CubeType> CubeType for Vec<T> {
459 type ExpandType = Vec<T::ExpandType>;
460}
461
462impl<T: CubeType> CubeType for &mut Vec<T> {
463 type ExpandType = Vec<T::ExpandType>;
464}
465
466impl<T: IntoMut> IntoMut for Vec<T> {
467 fn into_mut(self, scope: &mut Scope) -> Self {
468 self.into_iter().map(|e| e.into_mut(scope)).collect()
469 }
470}
471impl<T: CubeDebug> CubeDebug for Vec<T> {}
472
473pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
475 scope: &mut Scope,
476 val: C,
477) -> ExpandElementTyped<Out> {
478 let input: ConstantValue = val.into();
479 let var = Out::as_type(scope).constant(input);
480 ExpandElement::Plain(var).into()
481}
482
483impl LaunchArg for () {
484 type RuntimeArg<'a, R: Runtime> = ();
485 type CompilationArg = ();
486
487 fn compilation_arg<'a, R: Runtime>(
488 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
489 ) -> Self::CompilationArg {
490 }
491
492 fn expand(
493 _: &Self::CompilationArg,
494 _builder: &mut KernelBuilder,
495 ) -> <Self as CubeType>::ExpandType {
496 }
497}
498
499impl<R: Runtime> ArgSettings<R> for () {
500 fn register(&self, _launcher: &mut KernelLauncher<R>) {
501 }
503}