Skip to main content

cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantValue, Scope, Variable, VariableKind},
4    prelude::{DynamicSize, KernelBuilder, KernelLauncher, assign},
5    unexpanded,
6};
7use alloc::{boxed::Box, vec::Vec};
8use core::marker::PhantomData;
9use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
10use cubecl_ir::{ManagedVariable, VectorSize};
11use cubecl_runtime::runtime::Runtime;
12use half::{bf16, f16};
13use variadics_please::{all_tuples, all_tuples_enumerated};
14
15/// Types used in a cube function must implement this trait
16///
17/// Variables whose values will be known at runtime must
18/// have `ManagedVariable` as associated type
19/// Variables whose values will be known at compile time
20/// must have the primitive type as associated type
21///
22/// Note: Cube functions should be written using `CubeTypes`,
23/// so that the code generated uses the associated `ExpandType`.
24/// This allows Cube code to not necessitate cloning, which is cumbersome
25/// in algorithmic code. The necessary cloning will automatically appear in
26/// the generated code.
27#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
28pub trait CubeType {
29    type ExpandType: Clone + IntoMut + CubeDebug;
30}
31
32pub trait CubeEnum: Sized {
33    type RuntimeValue: Clone + CubeDebug;
34
35    fn discriminant(&self) -> NativeExpand<i32>;
36
37    /// Return the runtime value of this enum, if only one variant has a value.
38    /// Should return () for all other cases.
39    fn runtime_value(self) -> Self::RuntimeValue;
40
41    fn discriminant_of_value(&self, variant_name: &'static str) -> i32 {
42        Self::discriminant_of(variant_name)
43    }
44
45    fn discriminant_of(variant_name: &'static str) -> i32;
46}
47
48pub trait Assign {
49    /// Assign `value` to `self` in `scope`.
50    fn expand_assign(&mut self, scope: &mut Scope, value: Self);
51    /// Create a new mutable variable of this type in `scope`.
52    fn init_mut(&self, scope: &mut Scope) -> Self;
53}
54
55impl<T: CubePrimitive> Assign for T {
56    fn expand_assign(&mut self, _scope: &mut Scope, value: Self) {
57        *self = value;
58    }
59    fn init_mut(&self, _scope: &mut Scope) -> Self {
60        *self
61    }
62}
63
64impl<T: NativeAssign> Assign for NativeExpand<T> {
65    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
66        assign::expand(scope, value, self.clone());
67    }
68    fn init_mut(&self, scope: &mut Scope) -> Self {
69        T::elem_init_mut(scope, self.expand.clone()).into()
70    }
71}
72
73impl<T: Assign> Assign for Option<T> {
74    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
75        match (self, value) {
76            (Some(this), Some(other)) => this.expand_assign(scope, other),
77            (None, None) => {}
78            _ => panic!("Can't assign mismatched enum variants"),
79        }
80    }
81    fn init_mut(&self, scope: &mut Scope) -> Self {
82        self.as_ref().map(|value| value.init_mut(scope))
83    }
84}
85
86impl<T: Assign> Assign for Vec<T> {
87    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
88        assert!(
89            self.len() == value.len(),
90            "Can't assign mismatched vector lengths"
91        );
92        for (this, other) in self.iter_mut().zip(value) {
93            this.expand_assign(scope, other);
94        }
95    }
96    fn init_mut(&self, scope: &mut Scope) -> Self {
97        self.iter().map(|it| it.init_mut(scope)).collect()
98    }
99}
100
101pub trait CloneExpand {
102    fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
103}
104
105impl<C: Clone> CloneExpand for C {
106    fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
107        self.clone()
108    }
109}
110
111/// Trait useful to convert a comptime value into runtime value.
112pub trait IntoRuntime: CubeType + Sized {
113    fn runtime(self) -> Self {
114        self
115    }
116
117    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
118}
119
/// Marks a function return value as comptime when the compiler can't infer it.
///
/// Implemented for every sized type.
pub trait IntoComptime: Sized {
    /// Returns `self` unchanged.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

impl<T> IntoComptime for T {}
129
130/// Convert an expand type to a version with mutable registers when necessary.
131pub trait IntoMut: Sized {
132    /// Convert the variable into a potentially new mutable variable in `scope`, copying if needed.
133    fn into_mut(self, scope: &mut Scope) -> Self;
134}
135
136pub fn into_mut_assign<T: Assign>(value: T, scope: &mut Scope) -> T {
137    let mut out = value.init_mut(scope);
138    out.expand_assign(scope, value);
139    out
140}
141
142pub trait CubeDebug: Sized {
143    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
144    /// at runtime
145    #[allow(unused)]
146    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
147}
148
/// A type that can be used as a kernel comptime argument.
///
/// A type doesn't need to implement `CubeComptime` to be used as a comptime argument. The
/// trait exists to make declaring generic cube types easier, and it is implemented
/// automatically for every type meeting its bounds.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
165
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as the compilation argument type `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, the
    /// type system can't enforce this, so the args are cast dynamically instead.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` don't have the same size.
    ///
    /// # Soundness
    ///
    /// Only the sizes are checked. The bytes of `self` must also form a valid `Arg`. That is
    /// the caller's responsibility, even though this method is not marked `unsafe`.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Unlike `transmute`, this does not require statically proving that both types have the
        // same size. The runtime assert prevents reading past the end of `Self` and is trivially
        // optimized away once the types are known.
        assert_eq!(
            core::mem::size_of::<Arg>(),
            core::mem::size_of::<Self>(),
            "dynamic_cast requires types of equal size"
        );
        // The clone is never dropped as a `Self`. Ownership of its contents moves into the
        // returned `Arg`, so nothing is released twice.
        let this = core::mem::ManuallyDrop::new(self.clone());
        // SAFETY: the sizes are equal (asserted above), so exactly the bytes of `this` are read.
        // `transmute_copy` performs an unaligned read and involves no allocation, so `Arg` may
        // have a different alignment than `Self`. Re-boxing the clone as a `Box<Arg>` would
        // instead deallocate with a layout the allocation was not made with. Validity of the
        // bytes as an `Arg` is the caller's contract (see "Soundness").
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

impl<T: Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static>
    CompilationArg for T
{
}
189
190/// Defines how a [launch argument](LaunchArg) can be expanded.
191///
192/// TODO Verify the accuracy of the next comment.
193///
194/// Normally this type should be implemented two times for an argument.
195/// Once for the reference and the other for the mutable reference. Often time, the reference
196/// should expand the argument as an input while the mutable reference should expand the argument
197/// as an output.
198#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
199pub trait LaunchArg: CubeType + Send + Sync + 'static {
200    /// The runtime argument for the kernel.
201    type RuntimeArg<R: Runtime>: Send + Sync;
202    /// Compilation argument.
203    type CompilationArg: CompilationArg;
204
205    fn register<R: Runtime>(
206        arg: Self::RuntimeArg<R>,
207        launcher: &mut KernelLauncher<R>,
208    ) -> Self::CompilationArg;
209
210    /// Register an input variable during compilation that fill the [`KernelBuilder`].
211    fn expand(
212        arg: &Self::CompilationArg,
213        builder: &mut KernelBuilder,
214    ) -> <Self as CubeType>::ExpandType;
215
216    /// Register an output variable during compilation that fill the [`KernelBuilder`].
217    fn expand_output(
218        arg: &Self::CompilationArg,
219        builder: &mut KernelBuilder,
220    ) -> <Self as CubeType>::ExpandType {
221        Self::expand(arg, builder)
222    }
223}
224
225macro_rules! launch_tuple {
226    ($(($T:ident, $t:ident)),*) => {
227        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
228            type RuntimeArg<R: Runtime> = ($($T::RuntimeArg<R>),*);
229            type CompilationArg = ($($T::CompilationArg),*);
230
231            fn register<R: Runtime>(runtime_arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) -> Self::CompilationArg {
232                let ($($t),*) = runtime_arg;
233                ($($T::register($t, launcher)),*)
234            }
235
236            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
237                let ($($t),*) = arg;
238                ($($T::expand($t, builder)),*)
239            }
240
241            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
242                let ($($t),*) = arg;
243                ($($T::expand_output($t, builder)),*)
244            }
245        }
246    };
247}
248
249all_tuples!(launch_tuple, 2, 12, T, t);
250
251/// Expand type of a native GPU type, i.e. scalar primitives, arrays, shared memory.
252#[derive(new)]
253pub struct NativeExpand<T: CubeType> {
254    pub expand: ManagedVariable,
255    pub(crate) _type: PhantomData<T>,
256}
257
258impl<T: CubeType> NativeExpand<T> {
259    /// Casts a reference of this expand element to a different type.
260    /// # Safety
261    /// There's no guarantee the new type is valid for the `ManagedVariable`
262    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &NativeExpand<E> {
263        unsafe { core::mem::transmute::<&NativeExpand<T>, &NativeExpand<E>>(self) }
264    }
265
266    /// Casts a mutable reference of this expand element to a different type.
267    /// # Safety
268    /// There's no guarantee the new type is valid for the `ManagedVariable`
269    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut NativeExpand<E> {
270        unsafe { core::mem::transmute::<&mut NativeExpand<T>, &mut NativeExpand<E>>(self) }
271    }
272}
273
274impl<T: CubeType> From<&NativeExpand<T>> for NativeExpand<T> {
275    fn from(value: &NativeExpand<T>) -> Self {
276        value.clone()
277    }
278}
279
280impl<T: CubeType> From<NativeExpand<T>> for Variable {
281    fn from(value: NativeExpand<T>) -> Self {
282        value.expand.into()
283    }
284}
285
286impl<T: CubeType> From<&mut NativeExpand<T>> for NativeExpand<T> {
287    fn from(value: &mut NativeExpand<T>) -> Self {
288        value.clone()
289    }
290}
291
292macro_rules! from_const {
293    ($lit:ty) => {
294        impl From<$lit> for NativeExpand<$lit> {
295            fn from(value: $lit) -> Self {
296                let variable: Variable = value.into();
297
298                ManagedVariable::Plain(variable).into()
299            }
300        }
301    };
302}
303
304from_const!(u8);
305from_const!(u16);
306from_const!(u32);
307from_const!(u64);
308from_const!(usize);
309from_const!(isize);
310from_const!(i64);
311from_const!(i8);
312from_const!(i16);
313from_const!(i32);
314from_const!(f64);
315from_const!(f16);
316from_const!(bf16);
317from_const!(flex32);
318from_const!(tf32);
319from_const!(f32);
320from_const!(e2m1);
321from_const!(e2m1x2);
322from_const!(e2m3);
323from_const!(e3m2);
324from_const!(e4m3);
325from_const!(e5m2);
326from_const!(ue8m0);
327from_const!(bool);
328
329macro_rules! tuple_cube_type {
330    ($($P:ident),*) => {
331        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
332            type ExpandType = ($($P::ExpandType,)*);
333        }
334    }
335}
336macro_rules! tuple_init {
337    ($($P:ident),*) => {
338        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
339            #[allow(non_snake_case, unused, clippy::unused_unit)]
340            fn into_mut(self, scope: &mut Scope) -> Self {
341                let ($($P,)*) = self;
342                ($(
343                    $P.into_mut(scope),
344                )*)
345            }
346        }
347    }
348}
349macro_rules! tuple_debug {
350    ($($P:ident),*) => {
351        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
352    }
353}
354macro_rules! tuple_runtime {
355    ($($P:ident),*) => {
356        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
357            #[allow(non_snake_case, unused, clippy::unused_unit)]
358            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
359                let ($($P,)*) = self;
360                ($(
361                    $P.__expand_runtime_method(scope),
362                )*)
363            }
364        }
365    }
366}
367macro_rules! tuple_assign {
368    ($(($n: tt, $P:ident)),*) => {
369        impl<$($P: Assign),*> Assign for ($($P,)*) {
370            #[allow(non_snake_case, unused, clippy::unused_unit)]
371            fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
372                let ($($P,)*) = self;
373                $(
374                    $P.expand_assign(scope, value.$n);
375                )*
376            }
377            #[allow(non_snake_case, unused, clippy::unused_unit)]
378            fn init_mut(&self, scope: &mut Scope) -> Self {
379                let ($($P,)*) = self;
380                ($(
381                    $P.init_mut(scope),
382                )*)
383            }
384        }
385    }
386}
387
388all_tuples!(tuple_cube_type, 0, 12, P);
389all_tuples!(tuple_debug, 0, 12, P);
390all_tuples!(tuple_init, 0, 12, P);
391all_tuples!(tuple_runtime, 0, 12, P);
392all_tuples_enumerated!(tuple_assign, 0, 12, P);
393
394impl<P: CubePrimitive> CubeDebug for P {}
395
396/// Trait for native types that can be assigned. For non-native composites, use the normal [`Assign`].
397pub trait NativeAssign: CubeType {
398    fn elem_init_mut(scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
399        init_mut_expand_element(scope, &elem)
400    }
401}
402
403impl<T: NativeAssign> IntoMut for NativeExpand<T> {
404    fn into_mut(self, scope: &mut Scope) -> Self {
405        into_mut_assign(self, scope)
406    }
407}
408
409impl<T: CubeType> CubeDebug for NativeExpand<T> {
410    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
411        scope.update_variable_name(*self.expand, name);
412    }
413}
414
415impl<T: CubeType> CubeDebug for &NativeExpand<T> {
416    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
417        scope.update_variable_name(*self.expand, name);
418    }
419}
420
421impl<T: CubeType> CubeDebug for &mut NativeExpand<T> {
422    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
423        scope.update_variable_name(*self.expand, name);
424    }
425}
426
427impl<T: CubeType> NativeExpand<T> {
428    /// Comptime version of [`crate::frontend::Array::vector_size`].
429    pub fn vector_size(&self) -> VectorSize {
430        self.expand.ty.vector_size()
431    }
432
433    // Expanded version of vectorization factor.
434    pub fn __expand_vector_size_method(self, _scope: &mut Scope) -> VectorSize {
435        self.expand.ty.vector_size()
436    }
437
438    pub fn into_variable(self) -> Variable {
439        self.expand.consume()
440    }
441}
442
443impl<T: CubeType> Clone for NativeExpand<T> {
444    fn clone(&self) -> Self {
445        Self {
446            expand: self.expand.clone(),
447            _type: PhantomData,
448        }
449    }
450}
451
452impl<T: CubeType> From<ManagedVariable> for NativeExpand<T> {
453    fn from(expand: ManagedVariable) -> Self {
454        Self {
455            expand,
456            _type: PhantomData,
457        }
458    }
459}
460
461impl<T: CubeType> From<NativeExpand<T>> for ManagedVariable {
462    fn from(value: NativeExpand<T>) -> Self {
463        value.expand
464    }
465}
466
467impl<T: CubePrimitive> NativeExpand<T> {
468    /// Create an [`NativeExpand`] from a value that is normally a literal.
469    pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
470        let variable: ConstantValue = lit.into();
471        let variable = T::as_type(scope).constant(variable);
472
473        NativeExpand::new(ManagedVariable::Plain(variable))
474    }
475
476    /// Get the [`ConstantValue`] from the variable.
477    pub fn constant(&self) -> Option<ConstantValue> {
478        match self.expand.kind {
479            VariableKind::Constant(val) => Some(val),
480            _ => None,
481        }
482    }
483
484    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
485        let value = self.constant().unwrap();
486        T::from_const_value(value)
487    }
488}
489
490pub(crate) fn init_mut_expand_element(
491    scope: &mut Scope,
492    element: &ManagedVariable,
493) -> ManagedVariable {
494    scope.create_local_mut(element.ty)
495}
496
497impl<T: IntoMut> IntoMut for Option<T> {
498    fn into_mut(self, scope: &mut Scope) -> Self {
499        self.map(|o| IntoMut::into_mut(o, scope))
500    }
501}
502
503impl<T: CubeType> CubeType for Vec<T> {
504    type ExpandType = Vec<T::ExpandType>;
505}
506
507impl<T: CubeType> CubeType for &mut Vec<T> {
508    type ExpandType = Vec<T::ExpandType>;
509}
510
511impl<T: IntoMut> IntoMut for Vec<T> {
512    fn into_mut(self, scope: &mut Scope) -> Self {
513        self.into_iter().map(|e| e.into_mut(scope)).collect()
514    }
515}
516impl<T: CubeDebug> CubeDebug for Vec<T> {}
517
518/// Create a constant element of the correct type during expansion.
519pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
520    scope: &mut Scope,
521    val: C,
522) -> NativeExpand<Out> {
523    let input: ConstantValue = val.into();
524    let var = Out::as_type(scope).constant(input);
525    ManagedVariable::Plain(var).into()
526}
527
528impl LaunchArg for () {
529    type RuntimeArg<R: Runtime> = ();
530    type CompilationArg = ();
531
532    fn register<R: Runtime>(_runtime_arg: Self::RuntimeArg<R>, _launcher: &mut KernelLauncher<R>) {
533        // nothing to do
534    }
535
536    fn expand(
537        _: &Self::CompilationArg,
538        _builder: &mut KernelBuilder,
539    ) -> <Self as CubeType>::ExpandType {
540    }
541}
542
543pub trait DefaultExpand: CubeType {
544    fn __expand_default(scope: &mut Scope) -> Self::ExpandType;
545}
546
547impl<T: CubeType + Default + IntoRuntime> DefaultExpand for T {
548    fn __expand_default(scope: &mut Scope) -> T::ExpandType {
549        T::default().__expand_runtime_method(scope)
550    }
551}
552
553#[derive(Clone, Copy, Debug)]
554pub struct Const<const N: usize>;
555
556pub trait Size: core::fmt::Debug + Clone + Copy + Send + Sync + 'static {
557    fn __expand_value(scope: &Scope) -> usize;
558    fn value() -> usize {
559        unexpanded!()
560    }
561    fn try_value_const() -> Option<usize> {
562        None
563    }
564}
565
566impl<const VALUE: usize> Size for Const<VALUE> {
567    fn __expand_value(_scope: &Scope) -> usize {
568        VALUE
569    }
570    fn value() -> usize {
571        VALUE
572    }
573    fn try_value_const() -> Option<usize> {
574        Some(VALUE)
575    }
576}
577
578impl<Marker: 'static> Size for DynamicSize<Marker> {
579    fn __expand_value(scope: &Scope) -> usize {
580        scope.resolve_size::<Self>().expect("Size to be registered")
581    }
582    fn value() -> usize {
583        unexpanded!()
584    }
585}
586
/// Define a custom type to be used for a comptime scalar type.
/// Useful for cases where generics can't work.
///
/// `define_scalar!(pub MyScalar)` declares a marker struct `__MyScalar` and the alias
/// `MyScalar = DynamicScalar<__MyScalar>`, both with the given visibility.
#[macro_export]
macro_rules! define_scalar {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicScalar<[<__ $alias>]>;
        }
    };
}
598
/// Define a custom type to be used for a comptime size. Useful for cases where generics can't
/// work.
///
/// `define_size!(pub MySize)` declares a marker struct `__MySize` and the alias
/// `MySize = DynamicSize<__MySize>`, both with the given visibility.
#[macro_export]
macro_rules! define_size {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicSize<[<__ $alias>]>;
        }
    };
}