cubecl_core/frontend/element/base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    Runtime,
4    ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
5    prelude::{KernelBuilder, KernelLauncher, init_expand},
6};
7use cubecl_common::{flex32, tf32};
8use cubecl_ir::ExpandElement;
9use half::{bf16, f16};
10use std::{
11    any::{Any, TypeId},
12    marker::PhantomData,
13};
14use variadics_please::all_tuples;
15
16/// Types used in a cube function must implement this trait
17///
18/// Variables whose values will be known at runtime must
19/// have ExpandElement as associated type
20/// Variables whose values will be known at compile time
21/// must have the primitive type as associated type
22///
23/// Note: Cube functions should be written using CubeTypes,
24/// so that the code generated uses the associated ExpandType.
25/// This allows Cube code to not necessitate cloning, which is cumbersome
26/// in algorithmic code. The necessary cloning will automatically appear in
27/// the generated code.
28#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
29pub trait CubeType {
30    type ExpandType: Clone + Init + CubeDebug;
31
32    /// Wrapper around the init method, necessary to type inference.
33    fn init(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
34        expand.init(scope)
35    }
36}
37
38/// Trait useful to convert a comptime value into runtime value.
39pub trait IntoRuntime: CubeType + Sized {
40    fn runtime(self) -> Self {
41        self
42    }
43
44    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
45}
46
47/// Trait to be implemented by [cube types](CubeType) implementations.
48pub trait Init: Sized {
49    /// Initialize a type within a [scope](Scope).
50    ///
51    /// You can return the same value when the variable is a non-mutable data structure or
52    /// if the type can not be deeply cloned/copied.
53    fn init(self, scope: &mut Scope) -> Self;
54}
55
56pub trait CubeDebug: Sized {
57    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
58    /// at runtime
59    #[allow(unused)]
60    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
61}
62
/// A type that can be used as a kernel comptime argument.
///
/// Note that a type doesn't need to implement `CubeComptime` to be used as
/// a comptime argument. However, this trait facilitates the declaration of generic cube types.
///
/// It is implemented automatically for every type that is
/// `Debug + Hash + Eq + Clone + Copy`.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}
impl<T> CubeComptime for T where T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}
79
80/// A [CubeType] that can be used as a kernel argument such as [Array] or [Tensor].
81pub trait CubeLaunch: CubeType + LaunchArg + LaunchArgExpand {}
82impl<T: CubeType + LaunchArg + LaunchArgExpand> CubeLaunch for T {}
83
84/// Argument used during the compilation of kernels.
85pub trait CompilationArg:
86    serde::Serialize
87    + serde::de::DeserializeOwned
88    + Clone
89    + PartialEq
90    + Eq
91    + core::hash::Hash
92    + core::fmt::Debug
93    + Send
94    + Sync
95    + 'static
96{
97    /// Compilation args should be the same even with different element types. However, it isn't
98    /// possible to enforce it with the type system. So, we make the compilation args serializable
99    /// and dynamically cast them.
100    ///
101    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
102    /// since this is only done once when compiling a kernel for the first time.
103    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
104        if TypeId::of::<Arg>() == TypeId::of::<Self>() {
105            let tmp: Box<dyn Any> = Box::new(self.clone());
106            *tmp.downcast().unwrap()
107        } else {
108            let val = serde_json::to_string(self).unwrap();
109            serde_json::from_str(&val)
110                .expect("Compilation argument should be the same even with different element types")
111        }
112    }
113}
114
115impl CompilationArg for () {}
116
117/// Defines how a [launch argument](LaunchArg) can be expanded.
118///
119/// TODO Verify the accuracy of the next comment.
120///
121/// Normally this type should be implemented two times for an argument.
122/// Once for the reference and the other for the mutable reference. Often time, the reference
123/// should expand the argument as an input while the mutable reference should expand the argument
124/// as an output.
125#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
126pub trait LaunchArgExpand: CubeType {
127    /// Compilation argument.
128    type CompilationArg: CompilationArg;
129
130    /// Register an input variable during compilation that fill the [KernelBuilder].
131    fn expand(
132        arg: &Self::CompilationArg,
133        builder: &mut KernelBuilder,
134    ) -> <Self as CubeType>::ExpandType;
135
136    /// Register an output variable during compilation that fill the [KernelBuilder].
137    fn expand_output(
138        arg: &Self::CompilationArg,
139        builder: &mut KernelBuilder,
140    ) -> <Self as CubeType>::ExpandType {
141        Self::expand(arg, builder)
142    }
143}
144
145/// Defines a type that can be used as argument to a kernel.
146pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
147    /// The runtime argument for the kernel.
148    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
149
150    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
151}
152
153/// Defines the argument settings used to launch a kernel.
154pub trait ArgSettings<R: Runtime>: Send + Sync {
155    /// Register the information of an argument to the [KernelLauncher].
156    fn register(&self, launcher: &mut KernelLauncher<R>);
157}
158
159/// Expand type associated with a type.
160#[derive(new)]
161pub struct ExpandElementTyped<T: CubeType> {
162    pub(crate) expand: ExpandElement,
163    pub(crate) _type: PhantomData<T>,
164}
165
166impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
167    fn from(value: &ExpandElementTyped<T>) -> Self {
168        value.clone()
169    }
170}
171
172impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
173    fn from(value: &mut ExpandElementTyped<T>) -> Self {
174        value.clone()
175    }
176}
177
178macro_rules! from_const {
179    ($lit:ty) => {
180        impl From<$lit> for ExpandElementTyped<$lit> {
181            fn from(value: $lit) -> Self {
182                let variable: Variable = value.into();
183
184                ExpandElement::Plain(variable).into()
185            }
186        }
187    };
188}
189
190from_const!(u8);
191from_const!(u16);
192from_const!(u32);
193from_const!(u64);
194from_const!(i64);
195from_const!(i8);
196from_const!(i16);
197from_const!(i32);
198from_const!(f64);
199from_const!(f16);
200from_const!(bf16);
201from_const!(flex32);
202from_const!(tf32);
203from_const!(f32);
204from_const!(bool);
205
// Implements `CubeType` for a tuple whose elements are all `CubeType`s; the expand type is the
// tuple of the element expand types.
macro_rules! tuple_cube_type {
    ($($P:ident),*) => {
        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
            type ExpandType = ($($P::ExpandType,)*);
        }
    }
}
// Implements `Init` for a tuple by initializing each element in order within the same scope.
macro_rules! tuple_init {
    ($($P:ident),*) => {
        impl<$($P: Init),*> Init for ($($P,)*) {
            // The type-parameter identifiers are reused as binding names (non_snake_case); for
            // the empty tuple `scope` goes unused and the body is just `()`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn init(self, scope: &mut Scope) -> Self {
                let ($($P,)*) = self;
                ($(
                    $P.init(scope),
                )*)
            }
        }
    }
}
// Implements `CubeDebug` for a tuple using the trait's default `set_debug_name`.
macro_rules! tuple_debug {
    ($($P:ident),*) => {
        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
    }
}
// Implements `IntoRuntime` for a tuple by converting each element in order.
macro_rules! tuple_runtime {
    ($($P:ident),*) => {
        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
            // Same lint exceptions as in `tuple_init`: identifiers double as bindings and the
            // empty tuple ignores `scope`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
                let ($($P,)*) = self;
                ($(
                    $P.__expand_runtime_method(scope),
                )*)
            }
        }
    }
}

// Instantiate each macro for tuple arities 0 through 12 via `variadics_please::all_tuples!`.
all_tuples!(tuple_cube_type, 0, 12, P);
all_tuples!(tuple_debug, 0, 12, P);
all_tuples!(tuple_init, 0, 12, P);
all_tuples!(tuple_runtime, 0, 12, P);
249
250impl<P: CubePrimitive> CubeDebug for P {}
251
252pub trait ExpandElementBaseInit: CubeType {
253    fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
254}
255
256impl<T: ExpandElementBaseInit> Init for ExpandElementTyped<T> {
257    fn init(self, scope: &mut Scope) -> Self {
258        <T as ExpandElementBaseInit>::init_elem(scope, self.into()).into()
259    }
260}
261
262impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
263    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
264        scope.update_variable_name(*self.expand, name);
265    }
266}
267
268impl<T: CubeType> ExpandElementTyped<T> {
269    // Expanded version of vectorization factor.
270    pub fn __expand_vectorization_factor_method(self, _scope: &mut Scope) -> u32 {
271        self.expand
272            .item
273            .vectorization
274            .map(|it| it.get())
275            .unwrap_or(1) as u32
276    }
277
278    pub fn into_variable(self) -> Variable {
279        self.expand.consume()
280    }
281}
282
283impl<T: CubeType> Clone for ExpandElementTyped<T> {
284    fn clone(&self) -> Self {
285        Self {
286            expand: self.expand.clone(),
287            _type: PhantomData,
288        }
289    }
290}
291
292impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
293    fn from(expand: ExpandElement) -> Self {
294        Self {
295            expand,
296            _type: PhantomData,
297        }
298    }
299}
300
301impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
302    fn from(value: ExpandElementTyped<T>) -> Self {
303        value.expand
304    }
305}
306
307impl<T: CubePrimitive> ExpandElementTyped<T> {
308    /// Create an [ExpandElementTyped] from a value that is normally a literal.
309    pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
310        let variable: Variable = lit.into();
311        let variable = T::as_elem(scope).from_constant(variable);
312
313        ExpandElementTyped::new(ExpandElement::Plain(variable))
314    }
315
316    /// Get the [ConstantScalarValue] from the variable.
317    pub fn constant(&self) -> Option<ConstantScalarValue> {
318        match self.expand.kind {
319            VariableKind::ConstantScalar(val) => Some(val),
320            _ => None,
321        }
322    }
323}
324
325pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
326    scope: &mut Scope,
327    element: E,
328) -> ExpandElement {
329    let elem = element.into();
330
331    if elem.can_mut() {
332        // Can reuse inplace :)
333        return elem;
334    }
335
336    let mut init = |elem: ExpandElement| init_expand(scope, elem, Operation::Copy);
337
338    match elem.kind {
339        VariableKind::GlobalScalar { .. } => init(elem),
340        VariableKind::ConstantScalar { .. } => init(elem),
341        VariableKind::LocalMut { .. } => init(elem),
342        VariableKind::Versioned { .. } => init(elem),
343        VariableKind::LocalConst { .. } => init(elem),
344        VariableKind::Builtin(_) => init(elem),
345        VariableKind::SharedMemory { .. }
346        | VariableKind::GlobalInputArray { .. }
347        | VariableKind::GlobalOutputArray { .. }
348        | VariableKind::LocalArray { .. }
349        | VariableKind::ConstantArray { .. }
350        | VariableKind::Slice { .. }
351        | VariableKind::Matrix { .. }
352        | VariableKind::Barrier { .. }
353        | VariableKind::Pipeline { .. }
354        | VariableKind::TensorMap(_) => elem,
355    }
356}
357
358impl Init for ExpandElement {
359    fn init(self, scope: &mut Scope) -> Self {
360        init_expand_element(scope, self)
361    }
362}
363
364impl<T: Init> Init for Option<T> {
365    fn init(self, scope: &mut Scope) -> Self {
366        self.map(|o| Init::init(o, scope))
367    }
368}
369
370impl<T: CubeType> CubeType for Vec<T> {
371    type ExpandType = Vec<T::ExpandType>;
372}
373
374impl<T: CubeType> CubeType for &mut Vec<T> {
375    type ExpandType = Vec<T::ExpandType>;
376}
377
378impl<T: Init> Init for Vec<T> {
379    fn init(self, scope: &mut Scope) -> Self {
380        self.into_iter().map(|e| e.init(scope)).collect()
381    }
382}
383impl<T: CubeDebug> CubeDebug for Vec<T> {}
384
385/// Create a constant element of the correct type during expansion.
386pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
387    scope: &mut Scope,
388    val: C,
389) -> ExpandElementTyped<Out> {
390    let input: ExpandElementTyped<C> = val.into();
391    let const_val = input.expand.as_const().unwrap();
392    let var = Variable::constant(const_val.cast_to(Out::as_elem(scope)));
393    ExpandElement::Plain(var).into()
394}
395
396impl LaunchArg for () {
397    type RuntimeArg<'a, R: Runtime> = ();
398
399    fn compilation_arg<'a, R: Runtime>(
400        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
401    ) -> Self::CompilationArg {
402    }
403}
404
405impl<R: Runtime> ArgSettings<R> for () {
406    fn register(&self, _launcher: &mut KernelLauncher<R>) {
407        // nothing to do
408    }
409}
410
411impl LaunchArgExpand for () {
412    type CompilationArg = ();
413
414    fn expand(
415        _: &Self::CompilationArg,
416        _builder: &mut KernelBuilder,
417    ) -> <Self as CubeType>::ExpandType {
418    }
419}