Skip to main content

cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantValue, Operation, Scope, Variable, VariableKind},
4    prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use alloc::{boxed::Box, vec::Vec};
7use core::marker::PhantomData;
8use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
9use cubecl_ir::{ExpandElement, LineSize};
10use cubecl_runtime::runtime::Runtime;
11use half::{bf16, f16};
12use variadics_please::all_tuples;
13
14/// Types used in a cube function must implement this trait
15///
16/// Variables whose values will be known at runtime must
17/// have `ExpandElement` as associated type
18/// Variables whose values will be known at compile time
19/// must have the primitive type as associated type
20///
21/// Note: Cube functions should be written using `CubeTypes`,
22/// so that the code generated uses the associated `ExpandType`.
23/// This allows Cube code to not necessitate cloning, which is cumbersome
24/// in algorithmic code. The necessary cloning will automatically appear in
25/// the generated code.
26#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
27pub trait CubeType {
28    type ExpandType: Clone + IntoMut + CubeDebug;
29
30    /// Wrapper around the init method, necessary to type inference.
31    fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
32        expand.into_mut(scope)
33    }
34}
35
36pub trait CloneExpand {
37    fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
38}
39
40impl<C: Clone> CloneExpand for C {
41    fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
42        self.clone()
43    }
44}
45
46/// Trait useful to convert a comptime value into runtime value.
47pub trait IntoRuntime: CubeType + Sized {
48    fn runtime(self) -> Self {
49        self
50    }
51
52    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
53}
54
/// Marks the return value of a function as comptime in the cases where the
/// compiler cannot infer it on its own.
pub trait IntoComptime: Sized {
    /// Returns `self` unchanged; the call only serves as a marker.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

/// Blanket implementation: any sized value can be marked as comptime.
impl<T> IntoComptime for T where T: Sized {}
64
65/// Convert an expand type to a version with mutable registers when necessary.
66pub trait IntoMut: Sized {
67    fn into_mut(self, scope: &mut Scope) -> Self;
68}
69
70pub trait CubeDebug: Sized {
71    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
72    /// at runtime
73    #[allow(unused)]
74    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
75}
76
/// A type that can be used as a kernel comptime argument.
///
/// A type does not have to implement `CubeComptime` to be usable as a
/// comptime argument; the trait mainly makes it easier to declare generic
/// cube types.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

/// Blanket implementation for every type that satisfies the supertrait bounds.
impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
93
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, it isn't
    /// possible to enforce it with the type system. So, we make the compilation args serializable
    /// and dynamically cast them.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// # Panics
    ///
    /// Panics when `Arg` and `Self` do not have the same size or the same alignment.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Dynamic cast, unlike transmute it does not require statically proving the types are the
        // same size. We assert at runtime to avoid undefined behaviour and help the compiler optimize.
        assert!(size_of::<Arg>() == size_of::<Self>());
        // The box below is allocated with the layout of `Self` but released with the layout of
        // `Arg`, so the alignments have to agree as well as the sizes.
        assert!(align_of::<Arg>() == align_of::<Self>());
        let this = Box::new(self.clone());
        // SAFETY: the assertions above guarantee that `Arg` and `Self` share the same size and
        // alignment, so the allocation is read and freed with a matching layout. Whether the
        // bits form a valid `Arg` is the caller's contract: both types are expected to be the
        // same compilation argument, differing at most by element type.
        unsafe { *Box::from_raw(Box::into_raw(this).cast::<Arg>()) }
    }
}

impl CompilationArg for () {}
114
115/// Defines how a [launch argument](LaunchArg) can be expanded.
116///
117/// TODO Verify the accuracy of the next comment.
118///
119/// Normally this type should be implemented two times for an argument.
120/// Once for the reference and the other for the mutable reference. Often time, the reference
121/// should expand the argument as an input while the mutable reference should expand the argument
122/// as an output.
123#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
124pub trait LaunchArg: CubeType + Send + Sync + 'static {
125    /// The runtime argument for the kernel.
126    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
127    /// Compilation argument.
128    type CompilationArg: CompilationArg;
129
130    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
131
132    /// Register an input variable during compilation that fill the [`KernelBuilder`].
133    fn expand(
134        arg: &Self::CompilationArg,
135        builder: &mut KernelBuilder,
136    ) -> <Self as CubeType>::ExpandType;
137
138    /// Register an output variable during compilation that fill the [`KernelBuilder`].
139    fn expand_output(
140        arg: &Self::CompilationArg,
141        builder: &mut KernelBuilder,
142    ) -> <Self as CubeType>::ExpandType {
143        Self::expand(arg, builder)
144    }
145}
146
147/// Defines the argument settings used to launch a kernel.
148pub trait ArgSettings<R: Runtime>: Send + Sync {
149    /// Register the information of an argument to the [`KernelLauncher`].
150    fn register(&self, launcher: &mut KernelLauncher<R>);
151}
152
153macro_rules! launch_tuple {
154    ($(($T:ident, $t:ident)),*) => {
155        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
156            type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
157            type CompilationArg = ($($T::CompilationArg),*);
158
159            fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
160                let ($($t),*) = runtime_arg;
161                ($($T::compilation_arg($t)),*)
162            }
163
164            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
165                let ($($t),*) = arg;
166                ($($T::expand($t, builder)),*)
167            }
168
169            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
170                let ($($t),*) = arg;
171                ($($T::expand_output($t, builder)),*)
172            }
173        }
174
175        impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
176
177        impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
178            fn register(&self, launcher: &mut KernelLauncher<R>) {
179                let ($($t),*) = self;
180                $($t.register(launcher);)*
181            }
182        }
183    };
184}
185
186all_tuples!(launch_tuple, 2, 12, T, t);
187
188/// Expand type associated with a type.
189#[derive(new)]
190pub struct ExpandElementTyped<T: CubeType> {
191    pub expand: ExpandElement,
192    pub(crate) _type: PhantomData<T>,
193}
194
195impl<T: CubeType> ExpandElementTyped<T> {
196    /// Casts a reference of this expand element to a different type.
197    /// # Safety
198    /// There's no guarantee the new type is valid for the `ExpandElement`
199    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
200        unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
201    }
202
203    /// Casts a mutable reference of this expand element to a different type.
204    /// # Safety
205    /// There's no guarantee the new type is valid for the `ExpandElement`
206    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
207        unsafe {
208            core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
209        }
210    }
211}
212
213impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
214    fn from(value: &ExpandElementTyped<T>) -> Self {
215        value.clone()
216    }
217}
218
219impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
220    fn from(value: ExpandElementTyped<T>) -> Self {
221        value.expand.into()
222    }
223}
224
225impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
226    fn from(value: &mut ExpandElementTyped<T>) -> Self {
227        value.clone()
228    }
229}
230
231macro_rules! from_const {
232    ($lit:ty) => {
233        impl From<$lit> for ExpandElementTyped<$lit> {
234            fn from(value: $lit) -> Self {
235                let variable: Variable = value.into();
236
237                ExpandElement::Plain(variable).into()
238            }
239        }
240    };
241}
242
243from_const!(u8);
244from_const!(u16);
245from_const!(u32);
246from_const!(u64);
247from_const!(usize);
248from_const!(isize);
249from_const!(i64);
250from_const!(i8);
251from_const!(i16);
252from_const!(i32);
253from_const!(f64);
254from_const!(f16);
255from_const!(bf16);
256from_const!(flex32);
257from_const!(tf32);
258from_const!(f32);
259from_const!(e2m1);
260from_const!(e2m1x2);
261from_const!(e2m3);
262from_const!(e3m2);
263from_const!(e4m3);
264from_const!(e5m2);
265from_const!(ue8m0);
266from_const!(bool);
267
// Implements `CubeType` for a tuple; its expand type is the tuple of the
// elements' expand types.
macro_rules! tuple_cube_type {
    ($($Elem:ident),*) => {
        impl<$($Elem: CubeType),*> CubeType for ($($Elem,)*) {
            type ExpandType = ($($Elem::ExpandType,)*);
        }
    }
}
// Implements `IntoMut` for a tuple by converting each element in order.
macro_rules! tuple_init {
    ($($P:ident),*) => {
        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
            // The type parameter names double as binding names, and the empty
            // tuple leaves `scope` unused.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn into_mut(self, scope: &mut Scope) -> Self {
                let ($($P,)*) = self;
                ($(<$P as IntoMut>::into_mut($P, scope),)*)
            }
        }
    }
}
// Implements `CubeDebug` for a tuple, relying on the trait's default methods.
macro_rules! tuple_debug {
    ($($Elem:ident),*) => {
        impl<$($Elem: CubeDebug),*> CubeDebug for ($($Elem,)*) {}
    }
}
// Implements `IntoRuntime` for a tuple by expanding each element in order.
macro_rules! tuple_runtime {
    ($($P:ident),*) => {
        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
            // The type parameter names double as binding names, and the empty
            // tuple leaves `scope` unused.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
                let ($($P,)*) = self;
                ($(<$P as IntoRuntime>::__expand_runtime_method($P, scope),)*)
            }
        }
    }
}
306
307all_tuples!(tuple_cube_type, 0, 12, P);
308all_tuples!(tuple_debug, 0, 12, P);
309all_tuples!(tuple_init, 0, 12, P);
310all_tuples!(tuple_runtime, 0, 12, P);
311
312impl<P: CubePrimitive> CubeDebug for P {}
313
314pub trait ExpandElementIntoMut: CubeType {
315    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
316}
317
318impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
319    fn into_mut(self, scope: &mut Scope) -> Self {
320        <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
321    }
322}
323
324impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
325    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
326        scope.update_variable_name(*self.expand, name);
327    }
328}
329
330impl<T: CubeType> CubeDebug for &ExpandElementTyped<T> {
331    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
332        scope.update_variable_name(*self.expand, name);
333    }
334}
335
336impl<T: CubeType> CubeDebug for &mut ExpandElementTyped<T> {
337    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
338        scope.update_variable_name(*self.expand, name);
339    }
340}
341
342impl<T: CubeType> ExpandElementTyped<T> {
343    /// Comptime version of [`crate::frontend::Array::line_size`].
344    pub fn line_size(&self) -> LineSize {
345        self.expand.ty.line_size()
346    }
347
348    // Expanded version of vectorization factor.
349    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
350        self.expand.ty.line_size()
351    }
352
353    pub fn into_variable(self) -> Variable {
354        self.expand.consume()
355    }
356}
357
358impl<T: CubeType> Clone for ExpandElementTyped<T> {
359    fn clone(&self) -> Self {
360        Self {
361            expand: self.expand.clone(),
362            _type: PhantomData,
363        }
364    }
365}
366
367impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
368    fn from(expand: ExpandElement) -> Self {
369        Self {
370            expand,
371            _type: PhantomData,
372        }
373    }
374}
375
376impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
377    fn from(value: ExpandElementTyped<T>) -> Self {
378        value.expand
379    }
380}
381
382impl<T: CubePrimitive> ExpandElementTyped<T> {
383    /// Create an [`ExpandElementTyped`] from a value that is normally a literal.
384    pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
385        let variable: ConstantValue = lit.into();
386        let variable = T::as_type(scope).constant(variable);
387
388        ExpandElementTyped::new(ExpandElement::Plain(variable))
389    }
390
391    /// Get the [`ConstantValue`] from the variable.
392    pub fn constant(&self) -> Option<ConstantValue> {
393        match self.expand.kind {
394            VariableKind::Constant(val) => Some(val),
395            _ => None,
396        }
397    }
398
399    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
400        let value = self.constant().unwrap();
401        T::from_const_value(value)
402    }
403}
404
405pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
406    scope: &mut Scope,
407    element: E,
408) -> ExpandElement {
409    let elem = element.into();
410
411    match elem.kind {
412        VariableKind::Constant { .. } => init_expand(scope, elem, false, Operation::Copy),
413        _ => elem,
414    }
415}
416
417pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
418    scope: &mut Scope,
419    element: E,
420) -> ExpandElement {
421    let elem = element.into();
422
423    let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
424
425    match elem.kind {
426        VariableKind::GlobalScalar { .. } => init(elem),
427        VariableKind::Constant { .. } => init(elem),
428        VariableKind::LocalMut { .. } => init(elem),
429        VariableKind::Versioned { .. } => init(elem),
430        VariableKind::LocalConst { .. } => init(elem),
431        VariableKind::Builtin(_) => init(elem),
432        VariableKind::Shared { .. }
433        | VariableKind::SharedArray { .. }
434        | VariableKind::GlobalInputArray { .. }
435        | VariableKind::GlobalOutputArray { .. }
436        | VariableKind::LocalArray { .. }
437        | VariableKind::ConstantArray { .. }
438        | VariableKind::Matrix { .. }
439        | VariableKind::BarrierToken { .. }
440        | VariableKind::Pipeline { .. }
441        | VariableKind::TensorMapOutput(_)
442        | VariableKind::TensorMapInput(_) => elem,
443    }
444}
445
446impl IntoMut for ExpandElement {
447    fn into_mut(self, scope: &mut Scope) -> Self {
448        into_mut_expand_element(scope, self)
449    }
450}
451
452impl<T: IntoMut> IntoMut for Option<T> {
453    fn into_mut(self, scope: &mut Scope) -> Self {
454        self.map(|o| IntoMut::into_mut(o, scope))
455    }
456}
457
458impl<T: CubeType> CubeType for Vec<T> {
459    type ExpandType = Vec<T::ExpandType>;
460}
461
462impl<T: CubeType> CubeType for &mut Vec<T> {
463    type ExpandType = Vec<T::ExpandType>;
464}
465
466impl<T: IntoMut> IntoMut for Vec<T> {
467    fn into_mut(self, scope: &mut Scope) -> Self {
468        self.into_iter().map(|e| e.into_mut(scope)).collect()
469    }
470}
471impl<T: CubeDebug> CubeDebug for Vec<T> {}
472
473/// Create a constant element of the correct type during expansion.
474pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
475    scope: &mut Scope,
476    val: C,
477) -> ExpandElementTyped<Out> {
478    let input: ConstantValue = val.into();
479    let var = Out::as_type(scope).constant(input);
480    ExpandElement::Plain(var).into()
481}
482
483impl LaunchArg for () {
484    type RuntimeArg<'a, R: Runtime> = ();
485    type CompilationArg = ();
486
487    fn compilation_arg<'a, R: Runtime>(
488        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
489    ) -> Self::CompilationArg {
490    }
491
492    fn expand(
493        _: &Self::CompilationArg,
494        _builder: &mut KernelBuilder,
495    ) -> <Self as CubeType>::ExpandType {
496    }
497}
498
499impl<R: Runtime> ArgSettings<R> for () {
500    fn register(&self, _launcher: &mut KernelLauncher<R>) {
501        // nothing to do
502    }
503}