// cubecl_core/frontend/element/base.rs

1use super::{flex32, tf32, CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantScalarValue, Operation, Variable, VariableKind},
4    prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher},
5    Runtime,
6};
7use alloc::rc::Rc;
8use half::{bf16, f16};
9use std::marker::PhantomData;
10
11/// Types used in a cube function must implement this trait
12///
13/// Variables whose values will be known at runtime must
14/// have ExpandElement as associated type
15/// Variables whose values will be known at compile time
16/// must have the primitive type as associated type
17///
18/// Note: Cube functions should be written using CubeTypes,
19/// so that the code generated uses the associated ExpandType.
20/// This allows Cube code to not necessitate cloning, which is cumbersome
21/// in algorithmic code. The necessary cloning will automatically appear in
22/// the generated code.
23#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
24pub trait CubeType {
25    type ExpandType: Clone + Init;
26
27    /// Wrapper around the init method, necessary to type inference.
28    fn init(context: &mut CubeContext, expand: Self::ExpandType) -> Self::ExpandType {
29        expand.init(context)
30    }
31}
32
33/// Trait useful for cube types that are also used with comptime.
34///
35/// This is used to set a variable as mutable. (Need to be fixed or at least renamed.)
36pub trait IntoRuntime: CubeType + Sized {
37    /// Make sure a type is actually expanded into its runtime [expand type](CubeType::ExpandType).
38    fn runtime(self) -> Self {
39        self
40    }
41
42    fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType;
43}
44
45/// Trait to be implemented by [cube types](CubeType) implementations.
46pub trait Init: Sized {
47    /// Initialize a type within a [context](CubeContext).
48    ///
49    /// You can return the same value when the variable is a non-mutable data structure or
50    /// if the type can not be deeply cloned/copied.
51    fn init(self, context: &mut CubeContext) -> Self;
52}
53
54/// Argument used during the compilation of kernels.
55pub trait CompilationArg:
56    serde::Serialize
57    + serde::de::DeserializeOwned
58    + Clone
59    + PartialEq
60    + Eq
61    + core::hash::Hash
62    + core::fmt::Debug
63    + Send
64    + Sync
65    + 'static
66{
67    /// Compilation args should be the same even with different element types. However, it isn't
68    /// possible to enforce it with the type system. So, we make the compilation args serializable
69    /// and dynamically cast them.
70    ///
71    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
72    /// since this is only done once when compiling a kernel for the first time.
73    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
74        let val = serde_json::to_string(self).unwrap();
75
76        serde_json::from_str(&val)
77            .expect("Compilation argument should be the same even with different element types")
78    }
79}
80
81impl CompilationArg for () {}
82
83/// Defines how a [launch argument](LaunchArg) can be expanded.
84///
85/// Normally this type should be implemented two times for an argument.
86/// Once for the reference and the other for the mutable reference. Often time, the reference
87/// should expand the argument as an input while the mutable reference should expand the argument
88/// as an output.
89#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
90pub trait LaunchArgExpand: CubeType {
91    /// Compilation argument.
92    type CompilationArg: CompilationArg;
93
94    /// Register an input variable during compilation that fill the [KernelBuilder].
95    fn expand(
96        arg: &Self::CompilationArg,
97        builder: &mut KernelBuilder,
98    ) -> <Self as CubeType>::ExpandType;
99    /// Register an output variable during compilation that fill the [KernelBuilder].
100    fn expand_output(
101        arg: &Self::CompilationArg,
102        builder: &mut KernelBuilder,
103    ) -> <Self as CubeType>::ExpandType {
104        Self::expand(arg, builder)
105    }
106}
107
108/// Defines a type that can be used as argument to a kernel.
109pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
110    /// The runtime argument for the kernel.
111    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
112
113    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
114}
115
116impl LaunchArg for () {
117    type RuntimeArg<'a, R: Runtime> = ();
118
119    fn compilation_arg<'a, R: Runtime>(
120        _runtime_arg: &'a Self::RuntimeArg<'a, R>,
121    ) -> Self::CompilationArg {
122    }
123}
124
125impl<R: Runtime> ArgSettings<R> for () {
126    fn register(&self, _launcher: &mut KernelLauncher<R>) {
127        // nothing to do
128    }
129}
130
131impl LaunchArgExpand for () {
132    type CompilationArg = ();
133
134    fn expand(
135        _: &Self::CompilationArg,
136        _builder: &mut KernelBuilder,
137    ) -> <Self as CubeType>::ExpandType {
138    }
139}
140
141impl CubeType for () {
142    type ExpandType = ();
143}
144
145impl Init for () {
146    fn init(self, _context: &mut CubeContext) -> Self {
147        self
148    }
149}
150
151impl<T: Clone> CubeType for PhantomData<T> {
152    type ExpandType = ();
153}
154
155impl<T: Clone> IntoRuntime for PhantomData<T> {
156    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {}
157}
158
159/// Defines the argument settings used to launch a kernel.
160pub trait ArgSettings<R: Runtime>: Send + Sync {
161    /// Register the information to the [KernelLauncher].
162    fn register(&self, launcher: &mut KernelLauncher<R>);
163}
164
165/// Reference to a JIT variable
166#[derive(Clone, Debug)]
167pub enum ExpandElement {
168    /// Variable kept in the variable pool.
169    Managed(Rc<Variable>),
170    /// Variable not kept in the variable pool.
171    Plain(Variable),
172}
173
174/// Expand type associated with a type.
175#[derive(new)]
176pub struct ExpandElementTyped<T: CubeType> {
177    pub(crate) expand: ExpandElement,
178    pub(crate) _type: PhantomData<T>,
179}
180
181macro_rules! from_const {
182    ($lit:ty) => {
183        impl From<$lit> for ExpandElementTyped<$lit> {
184            fn from(value: $lit) -> Self {
185                let variable: Variable = value.into();
186
187                ExpandElement::Plain(variable).into()
188            }
189        }
190    };
191}
192
193from_const!(u8);
194from_const!(u16);
195from_const!(u32);
196from_const!(u64);
197from_const!(i64);
198from_const!(i8);
199from_const!(i16);
200from_const!(i32);
201from_const!(f64);
202from_const!(f16);
203from_const!(bf16);
204from_const!(flex32);
205from_const!(tf32);
206from_const!(f32);
207from_const!(bool);
208
209macro_rules! tuple_cube_type {
210    ($($P:ident),*) => {
211        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
212            type ExpandType = ($($P::ExpandType,)*);
213        }
214    }
215}
216macro_rules! tuple_init {
217    ($($P:ident),*) => {
218        impl<$($P: Init),*> Init for ($($P,)*) {
219            #[allow(non_snake_case)]
220            fn init(self, context: &mut CubeContext) -> Self {
221                let ($($P,)*) = self;
222                ($(
223                    $P.init(context),
224                )*)
225            }
226        }
227    }
228}
229macro_rules! tuple_runtime {
230    ($($P:ident),*) => {
231        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
232            #[allow(non_snake_case)]
233            fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType {
234                let ($($P,)*) = self;
235                ($(
236                    $P.__expand_runtime_method(context),
237                )*)
238            }
239        }
240    }
241}
242
243tuple_cube_type!(P1);
244tuple_cube_type!(P1, P2);
245tuple_cube_type!(P1, P2, P3);
246tuple_cube_type!(P1, P2, P3, P4);
247tuple_cube_type!(P1, P2, P3, P4, P5);
248tuple_cube_type!(P1, P2, P3, P4, P5, P6);
249
250tuple_init!(P1);
251tuple_init!(P1, P2);
252tuple_init!(P1, P2, P3);
253tuple_init!(P1, P2, P3, P4);
254tuple_init!(P1, P2, P3, P4, P5);
255tuple_init!(P1, P2, P3, P4, P5, P6);
256
257tuple_runtime!(P1);
258tuple_runtime!(P1, P2);
259tuple_runtime!(P1, P2, P3);
260tuple_runtime!(P1, P2, P3, P4);
261tuple_runtime!(P1, P2, P3, P4, P5);
262tuple_runtime!(P1, P2, P3, P4, P5, P6);
263
264pub trait ExpandElementBaseInit: CubeType {
265    fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
266}
267
268impl<T: ExpandElementBaseInit> Init for ExpandElementTyped<T> {
269    fn init(self, context: &mut CubeContext) -> Self {
270        <T as ExpandElementBaseInit>::init_elem(context, self.into()).into()
271    }
272}
273
274impl<T: CubeType> ExpandElementTyped<T> {
275    // Expanded version of vectorization factor.
276    pub fn __expand_vectorization_factor_method(self, _context: &mut CubeContext) -> u32 {
277        self.expand
278            .item
279            .vectorization
280            .map(|it| it.get())
281            .unwrap_or(1) as u32
282    }
283}
284
285impl<T: CubeType> Clone for ExpandElementTyped<T> {
286    fn clone(&self) -> Self {
287        Self {
288            expand: self.expand.clone(),
289            _type: PhantomData,
290        }
291    }
292}
293
294impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
295    fn from(expand: ExpandElement) -> Self {
296        Self {
297            expand,
298            _type: PhantomData,
299        }
300    }
301}
302
303impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
304    fn from(value: ExpandElementTyped<T>) -> Self {
305        value.expand
306    }
307}
308
309impl<T: CubePrimitive> ExpandElementTyped<T> {
310    /// Create an [ExpandElementTyped] from a value that is normally a literal.
311    pub fn from_lit<L: Into<Variable>>(context: &CubeContext, lit: L) -> Self {
312        let variable: Variable = lit.into();
313        let variable = T::as_elem(context).from_constant(variable);
314
315        ExpandElementTyped::new(ExpandElement::Plain(variable))
316    }
317
318    /// Get the [ConstantScalarValue] from the variable.
319    pub fn constant(&self) -> Option<ConstantScalarValue> {
320        match self.expand.kind {
321            VariableKind::ConstantScalar(val) => Some(val),
322            _ => None,
323        }
324    }
325}
326
327impl ExpandElement {
328    /// If the element can be mutated inplace, potentially reusing the register.
329    pub fn can_mut(&self) -> bool {
330        match self {
331            ExpandElement::Managed(var) => {
332                if let VariableKind::LocalMut { .. } = var.as_ref().kind {
333                    Rc::strong_count(var) <= 2
334                } else {
335                    false
336                }
337            }
338            ExpandElement::Plain(_) => false,
339        }
340    }
341
342    /// Explicitly consume the element, freeing it for reuse if no other copies exist.
343    pub fn consume(self) -> Variable {
344        *self
345    }
346}
347
348impl core::ops::Deref for ExpandElement {
349    type Target = Variable;
350
351    fn deref(&self) -> &Self::Target {
352        match self {
353            ExpandElement::Managed(var) => var.as_ref(),
354            ExpandElement::Plain(var) => var,
355        }
356    }
357}
358
359impl From<ExpandElement> for Variable {
360    fn from(value: ExpandElement) -> Self {
361        match value {
362            ExpandElement::Managed(var) => *var,
363            ExpandElement::Plain(var) => var,
364        }
365    }
366}
367
368pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
369    context: &mut CubeContext,
370    element: E,
371) -> ExpandElement {
372    let elem = element.into();
373
374    if elem.can_mut() {
375        // Can reuse inplace :)
376        return elem;
377    }
378
379    let mut init = |elem: ExpandElement| init_expand(context, elem, Operation::Copy);
380
381    match elem.kind {
382        VariableKind::GlobalScalar { .. } => init(elem),
383        VariableKind::ConstantScalar { .. } => init(elem),
384        VariableKind::LocalMut { .. } => init(elem),
385        VariableKind::Versioned { .. } => init(elem),
386        VariableKind::LocalConst { .. } => init(elem),
387        // Constant should be initialized since the new variable can be mutated afterward.
388        // And it is assumed those values are cloned.
389        VariableKind::Builtin(_) => init(elem),
390        // Array types can't be copied, so we should simply return the same variable.
391        VariableKind::SharedMemory { .. }
392        | VariableKind::GlobalInputArray { .. }
393        | VariableKind::GlobalOutputArray { .. }
394        | VariableKind::LocalArray { .. }
395        | VariableKind::ConstantArray { .. }
396        | VariableKind::Slice { .. }
397        | VariableKind::Matrix { .. } => elem,
398    }
399}
400
401impl Init for ExpandElement {
402    fn init(self, context: &mut CubeContext) -> Self {
403        init_expand_element(context, self)
404    }
405}
406
407impl<T: Init> Init for Option<T> {
408    fn init(self, context: &mut CubeContext) -> Self {
409        self.map(|o| Init::init(o, context))
410    }
411}
412
413impl<T: CubeType> CubeType for Vec<T> {
414    type ExpandType = Vec<T::ExpandType>;
415}
416
417impl<T: CubeType> CubeType for &mut Vec<T> {
418    type ExpandType = Vec<T::ExpandType>;
419}
420
421impl<T: Init> Init for Vec<T> {
422    fn init(self, context: &mut CubeContext) -> Self {
423        self.into_iter().map(|e| e.init(context)).collect()
424    }
425}
426
427/// Create a constant element of the correct type during expansion.
428pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
429    context: &mut CubeContext,
430    val: C,
431) -> ExpandElementTyped<Out> {
432    let input: ExpandElementTyped<C> = val.into();
433    <Out as super::Cast>::__expand_cast_from(context, input)
434}