Skip to main content

cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantValue, Operation, Scope, Variable, VariableKind},
4    prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
7use cubecl_ir::{ExpandElement, LineSize};
8use cubecl_runtime::runtime::Runtime;
9use half::{bf16, f16};
10use std::marker::PhantomData;
11use variadics_please::all_tuples;
12
13/// Types used in a cube function must implement this trait
14///
15/// Variables whose values will be known at runtime must
16/// have `ExpandElement` as associated type
17/// Variables whose values will be known at compile time
18/// must have the primitive type as associated type
19///
20/// Note: Cube functions should be written using `CubeTypes`,
21/// so that the code generated uses the associated `ExpandType`.
22/// This allows Cube code to not necessitate cloning, which is cumbersome
23/// in algorithmic code. The necessary cloning will automatically appear in
24/// the generated code.
25#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
26pub trait CubeType {
27    type ExpandType: Clone + IntoMut + CubeDebug;
28
29    /// Wrapper around the init method, necessary to type inference.
30    fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
31        expand.into_mut(scope)
32    }
33}
34
35pub trait CloneExpand {
36    fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
37}
38
39impl<C: Clone> CloneExpand for C {
40    fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
41        self.clone()
42    }
43}
44
45/// Trait useful to convert a comptime value into runtime value.
46pub trait IntoRuntime: CubeType + Sized {
47    fn runtime(self) -> Self {
48        self
49    }
50
51    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
52}
53
/// Marks a function's return value as comptime when the compiler cannot infer it.
pub trait IntoComptime
where
    Self: Sized,
{
    /// Identity function: returns `self` untouched.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

/// Blanket implementation: every sized type can be tagged as comptime.
impl<T> IntoComptime for T {}
63
/// Convert an expand type to a version with mutable registers when necessary.
pub trait IntoMut: Sized {
    /// Consumes `self` and returns a value of the same type. Implementors receive mutable
    /// access to `scope` to perform the conversion.
    fn into_mut(self, scope: &mut Scope) -> Self;
}
68
/// Hook for attaching a debug name to the expansion of a value.
pub trait CubeDebug: Sized {
    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
    /// at runtime.
    ///
    /// The default implementation is a no-op.
    // `unused` is allowed because the empty default body reads neither `scope` nor `name`.
    #[allow(unused)]
    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
}
75
/// A type that can be used as a kernel comptime argument.
///
/// A type does not need to implement `CubeComptime` to be used as a comptime argument;
/// the trait exists to facilitate the declaration of generic cube types.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime
where
    Self: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy,
{
}

/// Blanket implementation: any type that satisfies the bounds is a `CubeComptime`.
impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
92
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as an `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, it isn't
    /// possible to enforce it with the type system, so the value is cast dynamically instead.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// The caller is expected to only request an `Arg` that has the same representation as
    /// `Self`; only the size can be verified here.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` do not have the same size.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Dynamic cast, unlike transmute it does not require statically proving the types are the
        // same size. We assert at runtime to avoid undefined behaviour and help the compiler optimize.
        assert!(core::mem::size_of::<Arg>() == core::mem::size_of::<Self>());

        // `ManuallyDrop` keeps the clone's destructor from running: ownership of whatever the
        // clone holds is handed over to the returned `Arg`, so it is dropped exactly once.
        let owned = core::mem::ManuallyDrop::new(self.clone());

        // SAFETY: the sizes were just checked to be equal, so the read stays within `owned`.
        // `transmute_copy` performs the read correctly even when `Arg` is more strictly
        // aligned than `Self`, and no allocation is involved, so nothing is freed with a
        // mismatched layout. `owned` is never used or dropped afterwards.
        unsafe { core::mem::transmute_copy::<core::mem::ManuallyDrop<Self>, Arg>(&owned) }
    }
}

impl CompilationArg for () {}
113
114/// Defines how a [launch argument](LaunchArg) can be expanded.
115///
116/// TODO Verify the accuracy of the next comment.
117///
118/// Normally this type should be implemented two times for an argument.
119/// Once for the reference and the other for the mutable reference. Often time, the reference
120/// should expand the argument as an input while the mutable reference should expand the argument
121/// as an output.
122#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
123pub trait LaunchArg: CubeType + Send + Sync + 'static {
124    /// The runtime argument for the kernel.
125    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
126    /// Compilation argument.
127    type CompilationArg: CompilationArg;
128
129    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
130
131    /// Register an input variable during compilation that fill the [`KernelBuilder`].
132    fn expand(
133        arg: &Self::CompilationArg,
134        builder: &mut KernelBuilder,
135    ) -> <Self as CubeType>::ExpandType;
136
137    /// Register an output variable during compilation that fill the [`KernelBuilder`].
138    fn expand_output(
139        arg: &Self::CompilationArg,
140        builder: &mut KernelBuilder,
141    ) -> <Self as CubeType>::ExpandType {
142        Self::expand(arg, builder)
143    }
144}
145
/// Defines the argument settings used to launch a kernel.
///
/// This is the bound placed on [`LaunchArg::RuntimeArg`].
pub trait ArgSettings<R: Runtime>: Send + Sync {
    /// Register the information of an argument to the [`KernelLauncher`].
    fn register(&self, launcher: &mut KernelLauncher<R>);
}
151
152macro_rules! launch_tuple {
153    ($(($T:ident, $t:ident)),*) => {
154        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
155            type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
156            type CompilationArg = ($($T::CompilationArg),*);
157
158            fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
159                let ($($t),*) = runtime_arg;
160                ($($T::compilation_arg($t)),*)
161            }
162
163            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
164                let ($($t),*) = arg;
165                ($($T::expand($t, builder)),*)
166            }
167
168            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
169                let ($($t),*) = arg;
170                ($($T::expand_output($t, builder)),*)
171            }
172        }
173
174        impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
175
176        impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
177            fn register(&self, launcher: &mut KernelLauncher<R>) {
178                let ($($t),*) = self;
179                $($t.register(launcher);)*
180            }
181        }
182    };
183}
184
185all_tuples!(launch_tuple, 2, 12, T, t);
186
/// Expand type associated with a type.
///
/// Pairs an untyped [`ExpandElement`] with a compile-time marker for the cube type `T`.
#[derive(new)]
pub struct ExpandElementTyped<T: CubeType> {
    /// The underlying untyped expand element.
    pub expand: ExpandElement,
    /// Zero-sized marker tying the element to `T`; no value of type `T` is stored.
    pub(crate) _type: PhantomData<T>,
}
193
194impl<T: CubeType> ExpandElementTyped<T> {
195    /// Casts a reference of this expand element to a different type.
196    /// # Safety
197    /// There's no guarantee the new type is valid for the `ExpandElement`
198    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
199        unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
200    }
201
202    /// Casts a mutable reference of this expand element to a different type.
203    /// # Safety
204    /// There's no guarantee the new type is valid for the `ExpandElement`
205    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
206        unsafe {
207            core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
208        }
209    }
210}
211
212impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
213    fn from(value: &ExpandElementTyped<T>) -> Self {
214        value.clone()
215    }
216}
217
218impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
219    fn from(value: ExpandElementTyped<T>) -> Self {
220        value.expand.into()
221    }
222}
223
224impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
225    fn from(value: &mut ExpandElementTyped<T>) -> Self {
226        value.clone()
227    }
228}
229
230macro_rules! from_const {
231    ($lit:ty) => {
232        impl From<$lit> for ExpandElementTyped<$lit> {
233            fn from(value: $lit) -> Self {
234                let variable: Variable = value.into();
235
236                ExpandElement::Plain(variable).into()
237            }
238        }
239    };
240}
241
242from_const!(u8);
243from_const!(u16);
244from_const!(u32);
245from_const!(u64);
246from_const!(usize);
247from_const!(isize);
248from_const!(i64);
249from_const!(i8);
250from_const!(i16);
251from_const!(i32);
252from_const!(f64);
253from_const!(f16);
254from_const!(bf16);
255from_const!(flex32);
256from_const!(tf32);
257from_const!(f32);
258from_const!(e2m1);
259from_const!(e2m1x2);
260from_const!(e2m3);
261from_const!(e3m2);
262from_const!(e4m3);
263from_const!(e5m2);
264from_const!(ue8m0);
265from_const!(bool);
266
267macro_rules! tuple_cube_type {
268    ($($P:ident),*) => {
269        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
270            type ExpandType = ($($P::ExpandType,)*);
271        }
272    }
273}
274macro_rules! tuple_init {
275    ($($P:ident),*) => {
276        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
277            #[allow(non_snake_case, unused, clippy::unused_unit)]
278            fn into_mut(self, scope: &mut Scope) -> Self {
279                let ($($P,)*) = self;
280                ($(
281                    $P.into_mut(scope),
282                )*)
283            }
284        }
285    }
286}
287macro_rules! tuple_debug {
288    ($($P:ident),*) => {
289        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
290    }
291}
292macro_rules! tuple_runtime {
293    ($($P:ident),*) => {
294        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
295            #[allow(non_snake_case, unused, clippy::unused_unit)]
296            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
297                let ($($P,)*) = self;
298                ($(
299                    $P.__expand_runtime_method(scope),
300                )*)
301            }
302        }
303    }
304}
305
306all_tuples!(tuple_cube_type, 0, 12, P);
307all_tuples!(tuple_debug, 0, 12, P);
308all_tuples!(tuple_init, 0, 12, P);
309all_tuples!(tuple_runtime, 0, 12, P);
310
// Blanket impl with an empty body: every `CubePrimitive` keeps the default
// `set_debug_name` provided by `CubeDebug`.
impl<P: CubePrimitive> CubeDebug for P {}
312
313pub trait ExpandElementIntoMut: CubeType {
314    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
315}
316
317impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
318    fn into_mut(self, scope: &mut Scope) -> Self {
319        <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
320    }
321}
322
323impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
324    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
325        scope.update_variable_name(*self.expand, name);
326    }
327}
328
329impl<T: CubeType> CubeDebug for &ExpandElementTyped<T> {
330    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
331        scope.update_variable_name(*self.expand, name);
332    }
333}
334
335impl<T: CubeType> CubeDebug for &mut ExpandElementTyped<T> {
336    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
337        scope.update_variable_name(*self.expand, name);
338    }
339}
340
341impl<T: CubeType> ExpandElementTyped<T> {
342    /// Comptime version of [`crate::frontend::Array::line_size`].
343    pub fn line_size(&self) -> LineSize {
344        self.expand.ty.line_size()
345    }
346
347    // Expanded version of vectorization factor.
348    pub fn __expand_line_size_method(self, _scope: &mut Scope) -> LineSize {
349        self.expand.ty.line_size()
350    }
351
352    pub fn into_variable(self) -> Variable {
353        self.expand.consume()
354    }
355}
356
357impl<T: CubeType> Clone for ExpandElementTyped<T> {
358    fn clone(&self) -> Self {
359        Self {
360            expand: self.expand.clone(),
361            _type: PhantomData,
362        }
363    }
364}
365
366impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
367    fn from(expand: ExpandElement) -> Self {
368        Self {
369            expand,
370            _type: PhantomData,
371        }
372    }
373}
374
375impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
376    fn from(value: ExpandElementTyped<T>) -> Self {
377        value.expand
378    }
379}
380
381impl<T: CubePrimitive> ExpandElementTyped<T> {
382    /// Create an [`ExpandElementTyped`] from a value that is normally a literal.
383    pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
384        let variable: ConstantValue = lit.into();
385        let variable = T::as_type(scope).constant(variable);
386
387        ExpandElementTyped::new(ExpandElement::Plain(variable))
388    }
389
390    /// Get the [`ConstantValue`] from the variable.
391    pub fn constant(&self) -> Option<ConstantValue> {
392        match self.expand.kind {
393            VariableKind::Constant(val) => Some(val),
394            _ => None,
395        }
396    }
397
398    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
399        let value = self.constant().unwrap();
400        T::from_const_value(value)
401    }
402}
403
404pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
405    scope: &mut Scope,
406    element: E,
407) -> ExpandElement {
408    let elem = element.into();
409
410    match elem.kind {
411        VariableKind::Constant { .. } => init_expand(scope, elem, false, Operation::Copy),
412        _ => elem,
413    }
414}
415
416pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
417    scope: &mut Scope,
418    element: E,
419) -> ExpandElement {
420    let elem = element.into();
421
422    let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
423
424    match elem.kind {
425        VariableKind::GlobalScalar { .. } => init(elem),
426        VariableKind::Constant { .. } => init(elem),
427        VariableKind::LocalMut { .. } => init(elem),
428        VariableKind::Versioned { .. } => init(elem),
429        VariableKind::LocalConst { .. } => init(elem),
430        VariableKind::Builtin(_) => init(elem),
431        VariableKind::Shared { .. }
432        | VariableKind::SharedArray { .. }
433        | VariableKind::GlobalInputArray { .. }
434        | VariableKind::GlobalOutputArray { .. }
435        | VariableKind::LocalArray { .. }
436        | VariableKind::ConstantArray { .. }
437        | VariableKind::Matrix { .. }
438        | VariableKind::BarrierToken { .. }
439        | VariableKind::Pipeline { .. }
440        | VariableKind::TensorMapOutput(_)
441        | VariableKind::TensorMapInput(_) => elem,
442    }
443}
444
445impl IntoMut for ExpandElement {
446    fn into_mut(self, scope: &mut Scope) -> Self {
447        into_mut_expand_element(scope, self)
448    }
449}
450
451impl<T: IntoMut> IntoMut for Option<T> {
452    fn into_mut(self, scope: &mut Scope) -> Self {
453        self.map(|o| IntoMut::into_mut(o, scope))
454    }
455}
456
457impl<T: CubeType> CubeType for Vec<T> {
458    type ExpandType = Vec<T::ExpandType>;
459}
460
461impl<T: CubeType> CubeType for &mut Vec<T> {
462    type ExpandType = Vec<T::ExpandType>;
463}
464
465impl<T: IntoMut> IntoMut for Vec<T> {
466    fn into_mut(self, scope: &mut Scope) -> Self {
467        self.into_iter().map(|e| e.into_mut(scope)).collect()
468    }
469}
470impl<T: CubeDebug> CubeDebug for Vec<T> {}
471
472/// Create a constant element of the correct type during expansion.
473pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
474    scope: &mut Scope,
475    val: C,
476) -> ExpandElementTyped<Out> {
477    let input: ConstantValue = val.into();
478    let var = Out::as_type(scope).constant(input);
479    ExpandElement::Plain(var).into()
480}
481
/// The unit type is a launch argument that carries no data: both its runtime argument and
/// its compilation argument are `()`.
impl LaunchArg for () {
    type RuntimeArg<'a, R: Runtime> = ();
    type CompilationArg = ();

    // Nothing to derive: the empty body evaluates to `()`.
    fn compilation_arg<'a, R: Runtime>(
        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
    ) -> Self::CompilationArg {
    }

    // The builder is left untouched; the empty body evaluates to `()`.
    fn expand(
        _: &Self::CompilationArg,
        _builder: &mut KernelBuilder,
    ) -> <Self as CubeType>::ExpandType {
    }
}

/// The unit runtime argument registers nothing on the launcher.
impl<R: Runtime> ArgSettings<R> for () {
    fn register(&self, _launcher: &mut KernelLauncher<R>) {
        // nothing to do
    }
}