cubecl_core/frontend/element/base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    Runtime,
4    ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
5    prelude::{KernelBuilder, KernelLauncher, init_expand},
6};
7use cubecl_common::{e2m1, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
8use cubecl_ir::ExpandElement;
9use half::{bf16, f16};
10use std::{
11    any::{Any, TypeId},
12    marker::PhantomData,
13};
14use variadics_please::all_tuples;
15
16/// Types used in a cube function must implement this trait
17///
18/// Variables whose values will be known at runtime must
19/// have ExpandElement as associated type
20/// Variables whose values will be known at compile time
21/// must have the primitive type as associated type
22///
23/// Note: Cube functions should be written using CubeTypes,
24/// so that the code generated uses the associated ExpandType.
25/// This allows Cube code to not necessitate cloning, which is cumbersome
26/// in algorithmic code. The necessary cloning will automatically appear in
27/// the generated code.
28#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
29pub trait CubeType {
30    type ExpandType: Clone + IntoMut + CubeDebug;
31
32    /// Wrapper around the init method, necessary to type inference.
33    fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
34        expand.into_mut(scope)
35    }
36}
37
38/// Trait useful to convert a comptime value into runtime value.
39pub trait IntoRuntime: CubeType + Sized {
40    fn runtime(self) -> Self {
41        self
42    }
43
44    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
45}
46
47/// Convert an expand type to a version with mutable registers when necessary.
48pub trait IntoMut: Sized {
49    fn into_mut(self, scope: &mut Scope) -> Self;
50}
51
52pub trait CubeDebug: Sized {
53    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
54    /// at runtime
55    #[allow(unused)]
56    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
57}
58
/// Bound for values that can be passed to a kernel as comptime arguments.
///
/// A type does not have to implement `CubeComptime` to be used as a comptime argument. The
/// trait exists to make generic cube types easier to declare.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

/// Blanket impl: every type meeting the bounds is a [`CubeComptime`]. `Copy` already
/// implies `Clone`, so it is not repeated here.
impl<T: core::fmt::Debug + core::hash::Hash + Eq + Copy> CubeComptime for T {}
75
76/// A [CubeType] that can be used as a kernel argument such as [Array] or [Tensor].
77pub trait CubeLaunch: CubeType + LaunchArg + LaunchArgExpand {}
78impl<T: CubeType + LaunchArg + LaunchArgExpand> CubeLaunch for T {}
79
80/// Argument used during the compilation of kernels.
81pub trait CompilationArg:
82    serde::Serialize
83    + serde::de::DeserializeOwned
84    + Clone
85    + PartialEq
86    + Eq
87    + core::hash::Hash
88    + core::fmt::Debug
89    + Send
90    + Sync
91    + 'static
92{
93    /// Compilation args should be the same even with different element types. However, it isn't
94    /// possible to enforce it with the type system. So, we make the compilation args serializable
95    /// and dynamically cast them.
96    ///
97    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
98    /// since this is only done once when compiling a kernel for the first time.
99    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
100        if TypeId::of::<Arg>() == TypeId::of::<Self>() {
101            let tmp: Box<dyn Any> = Box::new(self.clone());
102            *tmp.downcast().unwrap()
103        } else {
104            let val = serde_json::to_string(self).unwrap();
105            serde_json::from_str(&val)
106                .expect("Compilation argument should be the same even with different element types")
107        }
108    }
109}
110
111impl CompilationArg for () {}
112
113/// Defines how a [launch argument](LaunchArg) can be expanded.
114///
115/// TODO Verify the accuracy of the next comment.
116///
117/// Normally this type should be implemented two times for an argument.
118/// Once for the reference and the other for the mutable reference. Often time, the reference
119/// should expand the argument as an input while the mutable reference should expand the argument
120/// as an output.
121#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
122pub trait LaunchArgExpand: CubeType {
123    /// Compilation argument.
124    type CompilationArg: CompilationArg;
125
126    /// Register an input variable during compilation that fill the [KernelBuilder].
127    fn expand(
128        arg: &Self::CompilationArg,
129        builder: &mut KernelBuilder,
130    ) -> <Self as CubeType>::ExpandType;
131
132    /// Register an output variable during compilation that fill the [KernelBuilder].
133    fn expand_output(
134        arg: &Self::CompilationArg,
135        builder: &mut KernelBuilder,
136    ) -> <Self as CubeType>::ExpandType {
137        Self::expand(arg, builder)
138    }
139}
140
141/// Defines a type that can be used as argument to a kernel.
142pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
143    /// The runtime argument for the kernel.
144    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
145
146    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
147}
148
149/// Defines the argument settings used to launch a kernel.
150pub trait ArgSettings<R: Runtime>: Send + Sync {
151    /// Register the information of an argument to the [KernelLauncher].
152    fn register(&self, launcher: &mut KernelLauncher<R>);
153}
154
155/// Expand type associated with a type.
156#[derive(new)]
157pub struct ExpandElementTyped<T: CubeType> {
158    pub(crate) expand: ExpandElement,
159    pub(crate) _type: PhantomData<T>,
160}
161
162impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
163    fn from(value: &ExpandElementTyped<T>) -> Self {
164        value.clone()
165    }
166}
167
168impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
169    fn from(value: ExpandElementTyped<T>) -> Self {
170        value.expand.into()
171    }
172}
173
174impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
175    fn from(value: &mut ExpandElementTyped<T>) -> Self {
176        value.clone()
177    }
178}
179
180macro_rules! from_const {
181    ($lit:ty) => {
182        impl From<$lit> for ExpandElementTyped<$lit> {
183            fn from(value: $lit) -> Self {
184                let variable: Variable = value.into();
185
186                ExpandElement::Plain(variable).into()
187            }
188        }
189    };
190}
191
192from_const!(u8);
193from_const!(u16);
194from_const!(u32);
195from_const!(u64);
196from_const!(i64);
197from_const!(i8);
198from_const!(i16);
199from_const!(i32);
200from_const!(f64);
201from_const!(f16);
202from_const!(bf16);
203from_const!(flex32);
204from_const!(tf32);
205from_const!(f32);
206from_const!(e2m1);
207from_const!(e2m3);
208from_const!(e3m2);
209from_const!(e4m3);
210from_const!(e5m2);
211from_const!(ue8m0);
212from_const!(bool);
213
214macro_rules! tuple_cube_type {
215    ($($P:ident),*) => {
216        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
217            type ExpandType = ($($P::ExpandType,)*);
218        }
219    }
220}
221macro_rules! tuple_init {
222    ($($P:ident),*) => {
223        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
224            #[allow(non_snake_case, unused, clippy::unused_unit)]
225            fn into_mut(self, scope: &mut Scope) -> Self {
226                let ($($P,)*) = self;
227                ($(
228                    $P.into_mut(scope),
229                )*)
230            }
231        }
232    }
233}
234macro_rules! tuple_debug {
235    ($($P:ident),*) => {
236        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
237    }
238}
239macro_rules! tuple_runtime {
240    ($($P:ident),*) => {
241        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
242            #[allow(non_snake_case, unused, clippy::unused_unit)]
243            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
244                let ($($P,)*) = self;
245                ($(
246                    $P.__expand_runtime_method(scope),
247                )*)
248            }
249        }
250    }
251}
252
253all_tuples!(tuple_cube_type, 0, 12, P);
254all_tuples!(tuple_debug, 0, 12, P);
255all_tuples!(tuple_init, 0, 12, P);
256all_tuples!(tuple_runtime, 0, 12, P);
257
258impl<P: CubePrimitive> CubeDebug for P {}
259
260pub trait ExpandElementIntoMut: CubeType {
261    fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
262}
263
264impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
265    fn into_mut(self, scope: &mut Scope) -> Self {
266        <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
267    }
268}
269
270impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
271    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
272        scope.update_variable_name(*self.expand, name);
273    }
274}
275
276impl<T: CubeType> ExpandElementTyped<T> {
277    // Expanded version of vectorization factor.
278    pub fn __expand_vectorization_factor_method(self, _scope: &mut Scope) -> u32 {
279        self.expand
280            .item
281            .vectorization
282            .map(|it| it.get())
283            .unwrap_or(1) as u32
284    }
285
286    pub fn into_variable(self) -> Variable {
287        self.expand.consume()
288    }
289}
290
291impl<T: CubeType> Clone for ExpandElementTyped<T> {
292    fn clone(&self) -> Self {
293        Self {
294            expand: self.expand.clone(),
295            _type: PhantomData,
296        }
297    }
298}
299
300impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
301    fn from(expand: ExpandElement) -> Self {
302        Self {
303            expand,
304            _type: PhantomData,
305        }
306    }
307}
308
309impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
310    fn from(value: ExpandElementTyped<T>) -> Self {
311        value.expand
312    }
313}
314
315impl<T: CubePrimitive> ExpandElementTyped<T> {
316    /// Create an [ExpandElementTyped] from a value that is normally a literal.
317    pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
318        let variable: Variable = lit.into();
319        let variable = T::as_elem(scope).from_constant(variable);
320
321        ExpandElementTyped::new(ExpandElement::Plain(variable))
322    }
323
324    /// Get the [ConstantScalarValue] from the variable.
325    pub fn constant(&self) -> Option<ConstantScalarValue> {
326        match self.expand.kind {
327            VariableKind::ConstantScalar(val) => Some(val),
328            _ => None,
329        }
330    }
331}
332
333pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
334    scope: &mut Scope,
335    element: E,
336) -> ExpandElement {
337    let elem = element.into();
338
339    match elem.kind {
340        VariableKind::ConstantScalar { .. } => init_expand(scope, elem, false, Operation::Copy),
341        _ => elem,
342    }
343}
344
345pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
346    scope: &mut Scope,
347    element: E,
348) -> ExpandElement {
349    let elem = element.into();
350
351    let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
352
353    match elem.kind {
354        VariableKind::GlobalScalar { .. } => init(elem),
355        VariableKind::ConstantScalar { .. } => init(elem),
356        VariableKind::LocalMut { .. } => init(elem),
357        VariableKind::Versioned { .. } => init(elem),
358        VariableKind::LocalConst { .. } => init(elem),
359        VariableKind::Builtin(_) => init(elem),
360        VariableKind::SharedMemory { .. }
361        | VariableKind::GlobalInputArray { .. }
362        | VariableKind::GlobalOutputArray { .. }
363        | VariableKind::LocalArray { .. }
364        | VariableKind::ConstantArray { .. }
365        | VariableKind::Matrix { .. }
366        | VariableKind::Barrier { .. }
367        | VariableKind::Pipeline { .. }
368        | VariableKind::TensorMap(_) => elem,
369    }
370}
371
372impl IntoMut for ExpandElement {
373    fn into_mut(self, scope: &mut Scope) -> Self {
374        into_mut_expand_element(scope, self)
375    }
376}
377
378impl<T: IntoMut> IntoMut for Option<T> {
379    fn into_mut(self, scope: &mut Scope) -> Self {
380        self.map(|o| IntoMut::into_mut(o, scope))
381    }
382}
383
384impl<T: CubeType> CubeType for Vec<T> {
385    type ExpandType = Vec<T::ExpandType>;
386}
387
388impl<T: CubeType> CubeType for &mut Vec<T> {
389    type ExpandType = Vec<T::ExpandType>;
390}
391
392impl<T: IntoMut> IntoMut for Vec<T> {
393    fn into_mut(self, scope: &mut Scope) -> Self {
394        self.into_iter().map(|e| e.into_mut(scope)).collect()
395    }
396}
397impl<T: CubeDebug> CubeDebug for Vec<T> {}
398
399/// Create a constant element of the correct type during expansion.
400pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
401    scope: &mut Scope,
402    val: C,
403) -> ExpandElementTyped<Out> {
404    let input: ExpandElementTyped<C> = val.into();
405    let const_val = input.expand.as_const().unwrap();
406    let var = Variable::constant(const_val.cast_to(Out::as_elem(scope)));
407    ExpandElement::Plain(var).into()
408}
409
410impl LaunchArg for () {
411    type RuntimeArg<'a, R: Runtime> = ();
412
413    fn compilation_arg<'a, R: Runtime>(
414        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
415    ) -> Self::CompilationArg {
416    }
417}
418
419impl<R: Runtime> ArgSettings<R> for () {
420    fn register(&self, _launcher: &mut KernelLauncher<R>) {
421        // nothing to do
422    }
423}
424
425impl LaunchArgExpand for () {
426    type CompilationArg = ();
427
428    fn expand(
429        _: &Self::CompilationArg,
430        _builder: &mut KernelBuilder,
431    ) -> <Self as CubeType>::ExpandType {
432    }
433}