cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
4    prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
7use cubecl_ir::ExpandElement;
8use cubecl_runtime::runtime::Runtime;
9use half::{bf16, f16};
10use std::marker::PhantomData;
11use variadics_please::all_tuples;
12
13/// Types used in a cube function must implement this trait
14///
15/// Variables whose values will be known at runtime must
16/// have ExpandElement as associated type
17/// Variables whose values will be known at compile time
18/// must have the primitive type as associated type
19///
20/// Note: Cube functions should be written using CubeTypes,
21/// so that the code generated uses the associated ExpandType.
22/// This allows Cube code to not necessitate cloning, which is cumbersome
23/// in algorithmic code. The necessary cloning will automatically appear in
24/// the generated code.
25#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
26pub trait CubeType {
27    type ExpandType: Clone + IntoMut + CubeDebug;
28
29    /// Wrapper around the init method, necessary to type inference.
30    fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
31        expand.into_mut(scope)
32    }
33}
34
35/// Trait useful to convert a comptime value into runtime value.
36pub trait IntoRuntime: CubeType + Sized {
37    fn runtime(self) -> Self {
38        self
39    }
40
41    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
42}
43
44/// Convert an expand type to a version with mutable registers when necessary.
45pub trait IntoMut: Sized {
46    fn into_mut(self, scope: &mut Scope) -> Self;
47}
48
49pub trait CubeDebug: Sized {
50    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
51    /// at runtime
52    #[allow(unused)]
53    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
54}
55
/// Marker for types usable as a kernel comptime argument.
///
/// A type does not have to implement `CubeComptime` to be used as a comptime argument;
/// the trait mainly facilitates the declaration of generic cube types.
///
/// It is implemented automatically for every type that is
/// `Debug + Hash + Eq + Clone + Copy`.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime
where
    Self: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy,
{
}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
72
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as a value of type `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, it isn't
    /// possible to enforce it with the type system, so the value is cast dynamically instead.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// The cast is only meaningful when `Arg` and `Self` share the same representation; only
    /// the size is verified at runtime.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` do not have the same size.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Dynamic cast: unlike `transmute` it does not require statically proving the types are
        // the same size, so the check is done at runtime to rule out reading out of bounds.
        assert!(core::mem::size_of::<Arg>() == core::mem::size_of::<Self>());

        // The bytes of the clone are handed over to the returned `Arg`, which becomes their sole
        // owner; wrapping the clone in `ManuallyDrop` prevents it from being dropped a second
        // time as `Self`.
        let this = core::mem::ManuallyDrop::new(self.clone());

        // SAFETY: the assertion above guarantees that `this` holds exactly `size_of::<Arg>()`
        // bytes, and `transmute_copy` reads them without requiring `Arg`'s alignment, so no
        // allocation is ever freed with a mismatched layout. That those bytes form a valid
        // `Arg` is the contract documented on this method.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

impl CompilationArg for () {}
93
94/// Defines how a [launch argument](LaunchArg) can be expanded.
95///
96/// TODO Verify the accuracy of the next comment.
97///
98/// Normally this type should be implemented two times for an argument.
99/// Once for the reference and the other for the mutable reference. Often time, the reference
100/// should expand the argument as an input while the mutable reference should expand the argument
101/// as an output.
102#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
103pub trait LaunchArg: CubeType + Send + Sync + 'static {
104    /// The runtime argument for the kernel.
105    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
106    /// Compilation argument.
107    type CompilationArg: CompilationArg;
108
109    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
110
111    /// Register an input variable during compilation that fill the [KernelBuilder].
112    fn expand(
113        arg: &Self::CompilationArg,
114        builder: &mut KernelBuilder,
115    ) -> <Self as CubeType>::ExpandType;
116
117    /// Register an output variable during compilation that fill the [KernelBuilder].
118    fn expand_output(
119        arg: &Self::CompilationArg,
120        builder: &mut KernelBuilder,
121    ) -> <Self as CubeType>::ExpandType {
122        Self::expand(arg, builder)
123    }
124}
125
126/// Defines the argument settings used to launch a kernel.
127pub trait ArgSettings<R: Runtime>: Send + Sync {
128    /// Register the information of an argument to the [KernelLauncher].
129    fn register(&self, launcher: &mut KernelLauncher<R>);
130}
131
132macro_rules! launch_tuple {
133    ($(($T:ident, $t:ident)),*) => {
134        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
135            type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
136            type CompilationArg = ($($T::CompilationArg),*);
137
138            fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
139                let ($($t),*) = runtime_arg;
140                ($($T::compilation_arg($t)),*)
141            }
142
143            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
144                let ($($t),*) = arg;
145                ($($T::expand($t, builder)),*)
146            }
147
148            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
149                let ($($t),*) = arg;
150                ($($T::expand_output($t, builder)),*)
151            }
152        }
153
154        impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
155
156        impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
157            fn register(&self, launcher: &mut KernelLauncher<R>) {
158                let ($($t),*) = self;
159                $($t.register(launcher);)*
160            }
161        }
162    };
163}
164
165all_tuples!(launch_tuple, 2, 12, T, t);
166
167/// Expand type associated with a type.
168#[derive(new)]
169pub struct ExpandElementTyped<T: CubeType> {
170    pub(crate) expand: ExpandElement,
171    pub(crate) _type: PhantomData<T>,
172}
173
174impl<T: CubeType> ExpandElementTyped<T> {
175    /// Casts a reference of this expand element to a different type.
176    /// # Safety
177    /// There's no guarantee the new type is valid for the `ExpandElement`
178    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
179        unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
180    }
181
182    /// Casts a mutable reference of this expand element to a different type.
183    /// # Safety
184    /// There's no guarantee the new type is valid for the `ExpandElement`
185    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
186        unsafe {
187            core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
188        }
189    }
190}
191
192impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
193    fn from(value: &ExpandElementTyped<T>) -> Self {
194        value.clone()
195    }
196}
197
198impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
199    fn from(value: ExpandElementTyped<T>) -> Self {
200        value.expand.into()
201    }
202}
203
204impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
205    fn from(value: &mut ExpandElementTyped<T>) -> Self {
206        value.clone()
207    }
208}
209
210macro_rules! from_const {
211    ($lit:ty) => {
212        impl From<$lit> for ExpandElementTyped<$lit> {
213            fn from(value: $lit) -> Self {
214                let variable: Variable = value.into();
215
216                ExpandElement::Plain(variable).into()
217            }
218        }
219    };
220}
221
222from_const!(u8);
223from_const!(u16);
224from_const!(u32);
225from_const!(u64);
226from_const!(i64);
227from_const!(i8);
228from_const!(i16);
229from_const!(i32);
230from_const!(f64);
231from_const!(f16);
232from_const!(bf16);
233from_const!(flex32);
234from_const!(tf32);
235from_const!(f32);
236from_const!(e2m1);
237from_const!(e2m1x2);
238from_const!(e2m3);
239from_const!(e3m2);
240from_const!(e4m3);
241from_const!(e5m2);
242from_const!(ue8m0);
243from_const!(bool);
244
// Implements `CubeType` for a tuple of cube types; the expand type is the tuple of the
// elements' expand types.
macro_rules! tuple_cube_type {
    ($($Elem:ident),*) => {
        impl<$($Elem),*> CubeType for ($($Elem,)*)
        where
            $($Elem: CubeType,)*
        {
            type ExpandType = ($(<$Elem as CubeType>::ExpandType,)*);
        }
    }
}
// Implements `IntoMut` for a tuple by converting every element, in tuple order.
macro_rules! tuple_init {
    ($($Elem:ident),*) => {
        impl<$($Elem),*> IntoMut for ($($Elem,)*)
        where
            $($Elem: IntoMut,)*
        {
            // The type parameter names double as binding names, and the empty tuple leaves
            // `scope` unused and returns a bare unit.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn into_mut(self, scope: &mut Scope) -> Self {
                let ($($Elem,)*) = self;
                ($(IntoMut::into_mut($Elem, scope),)*)
            }
        }
    }
}
// Implements `CubeDebug` for a tuple of `CubeDebug` types, relying on the trait's default
// (no-op) `set_debug_name`.
macro_rules! tuple_debug {
    ($($Elem:ident),*) => {
        impl<$($Elem),*> CubeDebug for ($($Elem,)*)
        where
            $($Elem: CubeDebug,)*
        {
        }
    }
}
270macro_rules! tuple_runtime {
271    ($($P:ident),*) => {
272        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
273            #[allow(non_snake_case, unused, clippy::unused_unit)]
274            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
275                let ($($P,)*) = self;
276                ($(
277                    $P.__expand_runtime_method(scope),
278                )*)
279            }
280        }
281    }
282}
283
284all_tuples!(tuple_cube_type, 0, 12, P);
285all_tuples!(tuple_debug, 0, 12, P);
286all_tuples!(tuple_init, 0, 12, P);
287all_tuples!(tuple_runtime, 0, 12, P);
288
289impl<P: CubePrimitive> CubeDebug for P {}
290
291pub trait ExpandElementIntoMut: CubeType {
292    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
293}
294
295impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
296    fn into_mut(self, scope: &mut Scope) -> Self {
297        <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
298    }
299}
300
301impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
302    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
303        scope.update_variable_name(*self.expand, name);
304    }
305}
306
307impl<T: CubeType> ExpandElementTyped<T> {
308    /// Comptime version of [size](Array::line_size).
309    pub fn line_size(&self) -> u32 {
310        self.expand.ty.line_size()
311    }
312
313    // Expanded version of vectorization factor.
314    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
315        self.expand.ty.line_size()
316    }
317
318    pub fn into_variable(self) -> Variable {
319        self.expand.consume()
320    }
321}
322
323impl<T: CubeType> Clone for ExpandElementTyped<T> {
324    fn clone(&self) -> Self {
325        Self {
326            expand: self.expand.clone(),
327            _type: PhantomData,
328        }
329    }
330}
331
332impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
333    fn from(expand: ExpandElement) -> Self {
334        Self {
335            expand,
336            _type: PhantomData,
337        }
338    }
339}
340
341impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
342    fn from(value: ExpandElementTyped<T>) -> Self {
343        value.expand
344    }
345}
346
347impl<T: CubePrimitive> ExpandElementTyped<T> {
348    /// Create an [ExpandElementTyped] from a value that is normally a literal.
349    pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
350        let variable: Variable = lit.into();
351        let variable = T::as_type(scope).from_constant(variable);
352
353        ExpandElementTyped::new(ExpandElement::Plain(variable))
354    }
355
356    /// Get the [ConstantScalarValue] from the variable.
357    pub fn constant(&self) -> Option<ConstantScalarValue> {
358        match self.expand.kind {
359            VariableKind::ConstantScalar(val) => Some(val),
360            _ => None,
361        }
362    }
363
364    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
365        let value = self.constant().unwrap();
366        T::from_const_value(value)
367    }
368}
369
370pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
371    scope: &mut Scope,
372    element: E,
373) -> ExpandElement {
374    let elem = element.into();
375
376    match elem.kind {
377        VariableKind::ConstantScalar { .. } => init_expand(scope, elem, false, Operation::Copy),
378        _ => elem,
379    }
380}
381
382pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
383    scope: &mut Scope,
384    element: E,
385) -> ExpandElement {
386    let elem = element.into();
387
388    let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
389
390    match elem.kind {
391        VariableKind::GlobalScalar { .. } => init(elem),
392        VariableKind::ConstantScalar { .. } => init(elem),
393        VariableKind::LocalMut { .. } => init(elem),
394        VariableKind::Versioned { .. } => init(elem),
395        VariableKind::LocalConst { .. } => init(elem),
396        VariableKind::Builtin(_) => init(elem),
397        VariableKind::Shared { .. }
398        | VariableKind::SharedArray { .. }
399        | VariableKind::GlobalInputArray { .. }
400        | VariableKind::GlobalOutputArray { .. }
401        | VariableKind::LocalArray { .. }
402        | VariableKind::ConstantArray { .. }
403        | VariableKind::Matrix { .. }
404        | VariableKind::BarrierToken { .. }
405        | VariableKind::Pipeline { .. }
406        | VariableKind::TensorMapOutput(_)
407        | VariableKind::TensorMapInput(_) => elem,
408    }
409}
410
411impl IntoMut for ExpandElement {
412    fn into_mut(self, scope: &mut Scope) -> Self {
413        into_mut_expand_element(scope, self)
414    }
415}
416
417impl<T: IntoMut> IntoMut for Option<T> {
418    fn into_mut(self, scope: &mut Scope) -> Self {
419        self.map(|o| IntoMut::into_mut(o, scope))
420    }
421}
422
423impl<T: CubeType> CubeType for Vec<T> {
424    type ExpandType = Vec<T::ExpandType>;
425}
426
427impl<T: CubeType> CubeType for &mut Vec<T> {
428    type ExpandType = Vec<T::ExpandType>;
429}
430
431impl<T: IntoMut> IntoMut for Vec<T> {
432    fn into_mut(self, scope: &mut Scope) -> Self {
433        self.into_iter().map(|e| e.into_mut(scope)).collect()
434    }
435}
436impl<T: CubeDebug> CubeDebug for Vec<T> {}
437
438/// Create a constant element of the correct type during expansion.
439pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
440    scope: &mut Scope,
441    val: C,
442) -> ExpandElementTyped<Out> {
443    let input: ExpandElementTyped<C> = val.into();
444    let const_val = input.expand.as_const().unwrap();
445    let var = Variable::constant(const_val.cast_to(Out::as_type(scope)));
446    ExpandElement::Plain(var).into()
447}
448
449impl LaunchArg for () {
450    type RuntimeArg<'a, R: Runtime> = ();
451    type CompilationArg = ();
452
453    fn compilation_arg<'a, R: Runtime>(
454        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
455    ) -> Self::CompilationArg {
456    }
457
458    fn expand(
459        _: &Self::CompilationArg,
460        _builder: &mut KernelBuilder,
461    ) -> <Self as CubeType>::ExpandType {
462    }
463}
464
465impl<R: Runtime> ArgSettings<R> for () {
466    fn register(&self, _launcher: &mut KernelLauncher<R>) {
467        // nothing to do
468    }
469}