cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantValue, Operation, Scope, Variable, VariableKind},
4    prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
7use cubecl_ir::{ExpandElement, LineSize};
8use cubecl_runtime::runtime::Runtime;
9use half::{bf16, f16};
10use std::marker::PhantomData;
11use variadics_please::all_tuples;
12
13/// Types used in a cube function must implement this trait
14///
15/// Variables whose values will be known at runtime must
16/// have ExpandElement as associated type
17/// Variables whose values will be known at compile time
18/// must have the primitive type as associated type
19///
20/// Note: Cube functions should be written using CubeTypes,
21/// so that the code generated uses the associated ExpandType.
22/// This allows Cube code to not necessitate cloning, which is cumbersome
23/// in algorithmic code. The necessary cloning will automatically appear in
24/// the generated code.
25#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
26pub trait CubeType {
27    type ExpandType: Clone + IntoMut + CubeDebug;
28
29    /// Wrapper around the init method, necessary to type inference.
30    fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
31        expand.into_mut(scope)
32    }
33}
34
35/// Trait useful to convert a comptime value into runtime value.
36pub trait IntoRuntime: CubeType + Sized {
37    fn runtime(self) -> Self {
38        self
39    }
40
41    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
42}
43
/// Marks a function return value as comptime when the compiler can't infer it.
pub trait IntoComptime: Sized {
    /// Returns `self` unchanged; the call only serves as a marker.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

// Blanket implementation: every sized type can be marked as comptime.
impl<T> IntoComptime for T {}
53
54/// Convert an expand type to a version with mutable registers when necessary.
55pub trait IntoMut: Sized {
56    fn into_mut(self, scope: &mut Scope) -> Self;
57}
58
59pub trait CubeDebug: Sized {
60    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
61    /// at runtime
62    #[allow(unused)]
63    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
64}
65
/// A type that can be used as a kernel comptime argument.
///
/// A type doesn't need to implement `CubeComptime` to be used as a comptime argument.
/// However, this trait facilitates the declaration of generic cube types.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
// `Copy` already requires `Clone`, so the bound set is `Debug + Hash + Eq + Clone + Copy`.
pub trait CubeComptime: Copy + Eq + core::hash::Hash + core::fmt::Debug {}

// Blanket implementation for every type that satisfies the supertrait bounds.
impl<T> CubeComptime for T where T: Copy + Eq + core::hash::Hash + core::fmt::Debug {}
82
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as a value of type `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, it isn't
    /// possible to enforce it with the type system. So, we make the compilation args serializable
    /// and dynamically cast them.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// The cast is a bit-for-bit reinterpretation: callers must only request an `Arg` whose
    /// layout and validity requirements match those of `Self`.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` do not have the same size.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Dynamic cast, unlike transmute it does not require statically proving the types are the
        // same size. We assert at runtime to avoid undefined behaviour and help the compiler optimize.
        assert!(core::mem::size_of::<Arg>() == core::mem::size_of::<Self>());

        // Ownership of the clone is handed over bit-for-bit to the returned `Arg`, so the clone's
        // own destructor must never run.
        let this = core::mem::ManuallyDrop::new(self.clone());

        // SAFETY: both types have the same size (asserted above) and `transmute_copy` reads the
        // source correctly even when `Arg` has a stricter alignment than `Self`. Unlike casting a
        // `Box<Self>` to a `Box<Arg>`, no allocation is ever released with a layout that differs
        // from the one it was created with.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&this) }
    }
}

impl CompilationArg for () {}
103
104/// Defines how a [launch argument](LaunchArg) can be expanded.
105///
106/// TODO Verify the accuracy of the next comment.
107///
108/// Normally this type should be implemented two times for an argument.
109/// Once for the reference and the other for the mutable reference. Often time, the reference
110/// should expand the argument as an input while the mutable reference should expand the argument
111/// as an output.
112#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
113pub trait LaunchArg: CubeType + Send + Sync + 'static {
114    /// The runtime argument for the kernel.
115    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
116    /// Compilation argument.
117    type CompilationArg: CompilationArg;
118
119    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
120
121    /// Register an input variable during compilation that fill the [KernelBuilder].
122    fn expand(
123        arg: &Self::CompilationArg,
124        builder: &mut KernelBuilder,
125    ) -> <Self as CubeType>::ExpandType;
126
127    /// Register an output variable during compilation that fill the [KernelBuilder].
128    fn expand_output(
129        arg: &Self::CompilationArg,
130        builder: &mut KernelBuilder,
131    ) -> <Self as CubeType>::ExpandType {
132        Self::expand(arg, builder)
133    }
134}
135
136/// Defines the argument settings used to launch a kernel.
137pub trait ArgSettings<R: Runtime>: Send + Sync {
138    /// Register the information of an argument to the [KernelLauncher].
139    fn register(&self, launcher: &mut KernelLauncher<R>);
140}
141
142macro_rules! launch_tuple {
143    ($(($T:ident, $t:ident)),*) => {
144        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
145            type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
146            type CompilationArg = ($($T::CompilationArg),*);
147
148            fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
149                let ($($t),*) = runtime_arg;
150                ($($T::compilation_arg($t)),*)
151            }
152
153            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
154                let ($($t),*) = arg;
155                ($($T::expand($t, builder)),*)
156            }
157
158            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
159                let ($($t),*) = arg;
160                ($($T::expand_output($t, builder)),*)
161            }
162        }
163
164        impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
165
166        impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
167            fn register(&self, launcher: &mut KernelLauncher<R>) {
168                let ($($t),*) = self;
169                $($t.register(launcher);)*
170            }
171        }
172    };
173}
174
175all_tuples!(launch_tuple, 2, 12, T, t);
176
177/// Expand type associated with a type.
178#[derive(new)]
179pub struct ExpandElementTyped<T: CubeType> {
180    pub expand: ExpandElement,
181    pub(crate) _type: PhantomData<T>,
182}
183
184impl<T: CubeType> ExpandElementTyped<T> {
185    /// Casts a reference of this expand element to a different type.
186    /// # Safety
187    /// There's no guarantee the new type is valid for the `ExpandElement`
188    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
189        unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
190    }
191
192    /// Casts a mutable reference of this expand element to a different type.
193    /// # Safety
194    /// There's no guarantee the new type is valid for the `ExpandElement`
195    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
196        unsafe {
197            core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
198        }
199    }
200}
201
202impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
203    fn from(value: &ExpandElementTyped<T>) -> Self {
204        value.clone()
205    }
206}
207
208impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
209    fn from(value: ExpandElementTyped<T>) -> Self {
210        value.expand.into()
211    }
212}
213
214impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
215    fn from(value: &mut ExpandElementTyped<T>) -> Self {
216        value.clone()
217    }
218}
219
220macro_rules! from_const {
221    ($lit:ty) => {
222        impl From<$lit> for ExpandElementTyped<$lit> {
223            fn from(value: $lit) -> Self {
224                let variable: Variable = value.into();
225
226                ExpandElement::Plain(variable).into()
227            }
228        }
229    };
230}
231
232from_const!(u8);
233from_const!(u16);
234from_const!(u32);
235from_const!(u64);
236from_const!(usize);
237from_const!(isize);
238from_const!(i64);
239from_const!(i8);
240from_const!(i16);
241from_const!(i32);
242from_const!(f64);
243from_const!(f16);
244from_const!(bf16);
245from_const!(flex32);
246from_const!(tf32);
247from_const!(f32);
248from_const!(e2m1);
249from_const!(e2m1x2);
250from_const!(e2m3);
251from_const!(e3m2);
252from_const!(e4m3);
253from_const!(e5m2);
254from_const!(ue8m0);
255from_const!(bool);
256
// Implements `CubeType` for a tuple: its expand type is the tuple of the elements' expand types.
macro_rules! tuple_cube_type {
    ($($Item:ident),*) => {
        impl<$($Item: CubeType),*> CubeType for ($($Item,)*) {
            type ExpandType = ($(<$Item as CubeType>::ExpandType,)*);
        }
    }
}
// Implements `IntoMut` for a tuple by converting every element in order.
macro_rules! tuple_init {
    ($($Item:ident),*) => {
        impl<$($Item: IntoMut),*> IntoMut for ($($Item,)*) {
            // The type parameter names double as binding names, and the empty tuple leaves
            // `scope` unused and returns a bare unit.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn into_mut(self, scope: &mut Scope) -> Self {
                let ($($Item,)*) = self;
                ($(IntoMut::into_mut($Item, scope),)*)
            }
        }
    }
}
// Implements `CubeDebug` for a tuple using the trait's default (no-op) methods.
macro_rules! tuple_debug {
    ($($Item:ident),*) => {
        impl<$($Item: CubeDebug),*> CubeDebug for ($($Item,)*) {}
    }
}
282macro_rules! tuple_runtime {
283    ($($P:ident),*) => {
284        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
285            #[allow(non_snake_case, unused, clippy::unused_unit)]
286            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
287                let ($($P,)*) = self;
288                ($(
289                    $P.__expand_runtime_method(scope),
290                )*)
291            }
292        }
293    }
294}
295
296all_tuples!(tuple_cube_type, 0, 12, P);
297all_tuples!(tuple_debug, 0, 12, P);
298all_tuples!(tuple_init, 0, 12, P);
299all_tuples!(tuple_runtime, 0, 12, P);
300
301impl<P: CubePrimitive> CubeDebug for P {}
302
303pub trait ExpandElementIntoMut: CubeType {
304    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
305}
306
307impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
308    fn into_mut(self, scope: &mut Scope) -> Self {
309        <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
310    }
311}
312
313impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
314    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
315        scope.update_variable_name(*self.expand, name);
316    }
317}
318
319impl<T: CubeType> CubeDebug for &ExpandElementTyped<T> {
320    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
321        scope.update_variable_name(*self.expand, name);
322    }
323}
324
325impl<T: CubeType> CubeDebug for &mut ExpandElementTyped<T> {
326    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
327        scope.update_variable_name(*self.expand, name);
328    }
329}
330
331impl<T: CubeType> ExpandElementTyped<T> {
332    /// Comptime version of [size](Array::line_size).
333    pub fn line_size(&self) -> LineSize {
334        self.expand.ty.line_size()
335    }
336
337    // Expanded version of vectorization factor.
338    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
339        self.expand.ty.line_size()
340    }
341
342    pub fn into_variable(self) -> Variable {
343        self.expand.consume()
344    }
345}
346
347impl<T: CubeType> Clone for ExpandElementTyped<T> {
348    fn clone(&self) -> Self {
349        Self {
350            expand: self.expand.clone(),
351            _type: PhantomData,
352        }
353    }
354}
355
356impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
357    fn from(expand: ExpandElement) -> Self {
358        Self {
359            expand,
360            _type: PhantomData,
361        }
362    }
363}
364
365impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
366    fn from(value: ExpandElementTyped<T>) -> Self {
367        value.expand
368    }
369}
370
371impl<T: CubePrimitive> ExpandElementTyped<T> {
372    /// Create an [ExpandElementTyped] from a value that is normally a literal.
373    pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
374        let variable: ConstantValue = lit.into();
375        let variable = T::as_type(scope).constant(variable);
376
377        ExpandElementTyped::new(ExpandElement::Plain(variable))
378    }
379
380    /// Get the [ConstantScalarValue] from the variable.
381    pub fn constant(&self) -> Option<ConstantValue> {
382        match self.expand.kind {
383            VariableKind::Constant(val) => Some(val),
384            _ => None,
385        }
386    }
387
388    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
389        let value = self.constant().unwrap();
390        T::from_const_value(value)
391    }
392}
393
394pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
395    scope: &mut Scope,
396    element: E,
397) -> ExpandElement {
398    let elem = element.into();
399
400    match elem.kind {
401        VariableKind::Constant { .. } => init_expand(scope, elem, false, Operation::Copy),
402        _ => elem,
403    }
404}
405
406pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
407    scope: &mut Scope,
408    element: E,
409) -> ExpandElement {
410    let elem = element.into();
411
412    let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
413
414    match elem.kind {
415        VariableKind::GlobalScalar { .. } => init(elem),
416        VariableKind::Constant { .. } => init(elem),
417        VariableKind::LocalMut { .. } => init(elem),
418        VariableKind::Versioned { .. } => init(elem),
419        VariableKind::LocalConst { .. } => init(elem),
420        VariableKind::Builtin(_) => init(elem),
421        VariableKind::Shared { .. }
422        | VariableKind::SharedArray { .. }
423        | VariableKind::GlobalInputArray { .. }
424        | VariableKind::GlobalOutputArray { .. }
425        | VariableKind::LocalArray { .. }
426        | VariableKind::ConstantArray { .. }
427        | VariableKind::Matrix { .. }
428        | VariableKind::BarrierToken { .. }
429        | VariableKind::Pipeline { .. }
430        | VariableKind::TensorMapOutput(_)
431        | VariableKind::TensorMapInput(_) => elem,
432    }
433}
434
435impl IntoMut for ExpandElement {
436    fn into_mut(self, scope: &mut Scope) -> Self {
437        into_mut_expand_element(scope, self)
438    }
439}
440
441impl<T: IntoMut> IntoMut for Option<T> {
442    fn into_mut(self, scope: &mut Scope) -> Self {
443        self.map(|o| IntoMut::into_mut(o, scope))
444    }
445}
446
447impl<T: CubeType> CubeType for Vec<T> {
448    type ExpandType = Vec<T::ExpandType>;
449}
450
451impl<T: CubeType> CubeType for &mut Vec<T> {
452    type ExpandType = Vec<T::ExpandType>;
453}
454
455impl<T: IntoMut> IntoMut for Vec<T> {
456    fn into_mut(self, scope: &mut Scope) -> Self {
457        self.into_iter().map(|e| e.into_mut(scope)).collect()
458    }
459}
460impl<T: CubeDebug> CubeDebug for Vec<T> {}
461
462/// Create a constant element of the correct type during expansion.
463pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
464    scope: &mut Scope,
465    val: C,
466) -> ExpandElementTyped<Out> {
467    let input: ConstantValue = val.into();
468    let var = Out::as_type(scope).constant(input);
469    ExpandElement::Plain(var).into()
470}
471
472impl LaunchArg for () {
473    type RuntimeArg<'a, R: Runtime> = ();
474    type CompilationArg = ();
475
476    fn compilation_arg<'a, R: Runtime>(
477        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
478    ) -> Self::CompilationArg {
479    }
480
481    fn expand(
482        _: &Self::CompilationArg,
483        _builder: &mut KernelBuilder,
484    ) -> <Self as CubeType>::ExpandType {
485    }
486}
487
488impl<R: Runtime> ArgSettings<R> for () {
489    fn register(&self, _launcher: &mut KernelLauncher<R>) {
490        // nothing to do
491    }
492}