cubecl_core/frontend/element/
base.rs1use super::{CubePrimitive, Numeric};
2use crate::{
3 Runtime,
4 ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
5 prelude::{KernelBuilder, KernelLauncher, init_expand},
6};
7use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
8use cubecl_ir::ExpandElement;
9use half::{bf16, f16};
10use std::marker::PhantomData;
11use variadics_please::all_tuples;
12
13#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
26pub trait CubeType {
27 type ExpandType: Clone + IntoMut + CubeDebug;
28
29 fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
31 expand.into_mut(scope)
32 }
33}
34
35pub trait IntoRuntime: CubeType + Sized {
37 fn runtime(self) -> Self {
38 self
39 }
40
41 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
42}
43
44pub trait IntoMut: Sized {
46 fn into_mut(self, scope: &mut Scope) -> Self;
47}
48
49pub trait CubeDebug: Sized {
50 #[allow(unused)]
53 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
54}
55
/// Shorthand for the bounds `Debug + Hash + Eq + Clone + Copy`.
///
/// Every type satisfying those bounds implements this trait through the blanket impl below.
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
72
/// Compilation-time description of a launch argument (the type of `LaunchArg::CompilationArg`).
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as `Arg`.
    ///
    /// This converts between compilation-argument types that the type system treats as
    /// distinct but that share the same memory layout.
    ///
    /// NOTE(review): only the sizes are checked. Callers must ensure that every bit pattern of
    /// `Self` is a valid `Arg` (for example, the same type or a layout-identical
    /// instantiation). This is not enforced by the type system.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` have different sizes.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        assert_eq!(
            core::mem::size_of::<Arg>(),
            core::mem::size_of::<Self>(),
            "dynamic_cast requires `Self` and `Arg` to have the same size",
        );

        // Ownership of the clone's resources (heap buffers, etc.) moves into the returned
        // `Arg`, so the clone itself must never be dropped.
        let this = core::mem::ManuallyDrop::new(self.clone());
        // SAFETY: the sizes were checked equal above. `transmute_copy` performs an unaligned
        // read, so `Self` and `Arg` may have different alignments, and no allocation is
        // freed with a mismatched layout. Bit-pattern validity as `Arg` is the caller
        // contract documented above.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

/// The unit type is the empty compilation argument.
impl CompilationArg for () {}
93
94#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
103pub trait LaunchArg: CubeType + Send + Sync + 'static {
104 type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
106 type CompilationArg: CompilationArg;
108
109 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
110
111 fn expand(
113 arg: &Self::CompilationArg,
114 builder: &mut KernelBuilder,
115 ) -> <Self as CubeType>::ExpandType;
116
117 fn expand_output(
119 arg: &Self::CompilationArg,
120 builder: &mut KernelBuilder,
121 ) -> <Self as CubeType>::ExpandType {
122 Self::expand(arg, builder)
123 }
124}
125
126pub trait ArgSettings<R: Runtime>: Send + Sync {
128 fn register(&self, launcher: &mut KernelLauncher<R>);
130}
131
132macro_rules! launch_tuple {
133 ($(($T:ident, $t:ident)),*) => {
134 impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
135 type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
136 type CompilationArg = ($($T::CompilationArg),*);
137
138 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
139 let ($($t),*) = runtime_arg;
140 ($($T::compilation_arg($t)),*)
141 }
142
143 fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
144 let ($($t),*) = arg;
145 ($($T::expand($t, builder)),*)
146 }
147
148 fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
149 let ($($t),*) = arg;
150 ($($T::expand_output($t, builder)),*)
151 }
152 }
153
154 impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
155
156 impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
157 fn register(&self, launcher: &mut KernelLauncher<R>) {
158 let ($($t),*) = self;
159 $($t.register(launcher);)*
160 }
161 }
162 };
163}
164
165all_tuples!(launch_tuple, 2, 12, T, t);
166
167#[derive(new)]
169pub struct ExpandElementTyped<T: CubeType> {
170 pub(crate) expand: ExpandElement,
171 pub(crate) _type: PhantomData<T>,
172}
173
174impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
175 fn from(value: &ExpandElementTyped<T>) -> Self {
176 value.clone()
177 }
178}
179
180impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
181 fn from(value: ExpandElementTyped<T>) -> Self {
182 value.expand.into()
183 }
184}
185
186impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
187 fn from(value: &mut ExpandElementTyped<T>) -> Self {
188 value.clone()
189 }
190}
191
192macro_rules! from_const {
193 ($lit:ty) => {
194 impl From<$lit> for ExpandElementTyped<$lit> {
195 fn from(value: $lit) -> Self {
196 let variable: Variable = value.into();
197
198 ExpandElement::Plain(variable).into()
199 }
200 }
201 };
202}
203
204from_const!(u8);
205from_const!(u16);
206from_const!(u32);
207from_const!(u64);
208from_const!(i64);
209from_const!(i8);
210from_const!(i16);
211from_const!(i32);
212from_const!(f64);
213from_const!(f16);
214from_const!(bf16);
215from_const!(flex32);
216from_const!(tf32);
217from_const!(f32);
218from_const!(e2m1);
219from_const!(e2m1x2);
220from_const!(e2m3);
221from_const!(e3m2);
222from_const!(e4m3);
223from_const!(e5m2);
224from_const!(ue8m0);
225from_const!(bool);
226
227macro_rules! tuple_cube_type {
228 ($($P:ident),*) => {
229 impl<$($P: CubeType),*> CubeType for ($($P,)*) {
230 type ExpandType = ($($P::ExpandType,)*);
231 }
232 }
233}
234macro_rules! tuple_init {
235 ($($P:ident),*) => {
236 impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
237 #[allow(non_snake_case, unused, clippy::unused_unit)]
238 fn into_mut(self, scope: &mut Scope) -> Self {
239 let ($($P,)*) = self;
240 ($(
241 $P.into_mut(scope),
242 )*)
243 }
244 }
245 }
246}
247macro_rules! tuple_debug {
248 ($($P:ident),*) => {
249 impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
250 }
251}
252macro_rules! tuple_runtime {
253 ($($P:ident),*) => {
254 impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
255 #[allow(non_snake_case, unused, clippy::unused_unit)]
256 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
257 let ($($P,)*) = self;
258 ($(
259 $P.__expand_runtime_method(scope),
260 )*)
261 }
262 }
263 }
264}
265
266all_tuples!(tuple_cube_type, 0, 12, P);
267all_tuples!(tuple_debug, 0, 12, P);
268all_tuples!(tuple_init, 0, 12, P);
269all_tuples!(tuple_runtime, 0, 12, P);
270
271impl<P: CubePrimitive> CubeDebug for P {}
272
273pub trait ExpandElementIntoMut: CubeType {
274 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
275}
276
277impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
278 fn into_mut(self, scope: &mut Scope) -> Self {
279 <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
280 }
281}
282
283impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
284 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
285 scope.update_variable_name(*self.expand, name);
286 }
287}
288
289impl<T: CubeType> ExpandElementTyped<T> {
290 pub fn line_size(&self) -> u32 {
292 self.expand.ty.line_size()
293 }
294
295 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
297 self.expand.ty.line_size()
298 }
299
300 pub fn into_variable(self) -> Variable {
301 self.expand.consume()
302 }
303}
304
305impl<T: CubeType> Clone for ExpandElementTyped<T> {
306 fn clone(&self) -> Self {
307 Self {
308 expand: self.expand.clone(),
309 _type: PhantomData,
310 }
311 }
312}
313
314impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
315 fn from(expand: ExpandElement) -> Self {
316 Self {
317 expand,
318 _type: PhantomData,
319 }
320 }
321}
322
323impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
324 fn from(value: ExpandElementTyped<T>) -> Self {
325 value.expand
326 }
327}
328
329impl<T: CubePrimitive> ExpandElementTyped<T> {
330 pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
332 let variable: Variable = lit.into();
333 let variable = T::as_type(scope).from_constant(variable);
334
335 ExpandElementTyped::new(ExpandElement::Plain(variable))
336 }
337
338 pub fn constant(&self) -> Option<ConstantScalarValue> {
340 match self.expand.kind {
341 VariableKind::ConstantScalar(val) => Some(val),
342 _ => None,
343 }
344 }
345
346 pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
347 let value = self.constant().unwrap();
348 T::from_const_value(value)
349 }
350}
351
352pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
353 scope: &mut Scope,
354 element: E,
355) -> ExpandElement {
356 let elem = element.into();
357
358 match elem.kind {
359 VariableKind::ConstantScalar { .. } => init_expand(scope, elem, false, Operation::Copy),
360 _ => elem,
361 }
362}
363
364pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
365 scope: &mut Scope,
366 element: E,
367) -> ExpandElement {
368 let elem = element.into();
369
370 let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
371
372 match elem.kind {
373 VariableKind::GlobalScalar { .. } => init(elem),
374 VariableKind::ConstantScalar { .. } => init(elem),
375 VariableKind::LocalMut { .. } => init(elem),
376 VariableKind::Versioned { .. } => init(elem),
377 VariableKind::LocalConst { .. } => init(elem),
378 VariableKind::Builtin(_) => init(elem),
379 VariableKind::SharedMemory { .. }
380 | VariableKind::GlobalInputArray { .. }
381 | VariableKind::GlobalOutputArray { .. }
382 | VariableKind::LocalArray { .. }
383 | VariableKind::ConstantArray { .. }
384 | VariableKind::Matrix { .. }
385 | VariableKind::Barrier { .. }
386 | VariableKind::Pipeline { .. }
387 | VariableKind::TensorMapOutput(_)
388 | VariableKind::TensorMapInput(_) => elem,
389 }
390}
391
392impl IntoMut for ExpandElement {
393 fn into_mut(self, scope: &mut Scope) -> Self {
394 into_mut_expand_element(scope, self)
395 }
396}
397
398impl<T: IntoMut> IntoMut for Option<T> {
399 fn into_mut(self, scope: &mut Scope) -> Self {
400 self.map(|o| IntoMut::into_mut(o, scope))
401 }
402}
403
404impl<T: CubeType> CubeType for Vec<T> {
405 type ExpandType = Vec<T::ExpandType>;
406}
407
408impl<T: CubeType> CubeType for &mut Vec<T> {
409 type ExpandType = Vec<T::ExpandType>;
410}
411
412impl<T: IntoMut> IntoMut for Vec<T> {
413 fn into_mut(self, scope: &mut Scope) -> Self {
414 self.into_iter().map(|e| e.into_mut(scope)).collect()
415 }
416}
417impl<T: CubeDebug> CubeDebug for Vec<T> {}
418
419pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
421 scope: &mut Scope,
422 val: C,
423) -> ExpandElementTyped<Out> {
424 let input: ExpandElementTyped<C> = val.into();
425 let const_val = input.expand.as_const().unwrap();
426 let var = Variable::constant(const_val.cast_to(Out::as_type(scope)));
427 ExpandElement::Plain(var).into()
428}
429
430impl LaunchArg for () {
431 type RuntimeArg<'a, R: Runtime> = ();
432 type CompilationArg = ();
433
434 fn compilation_arg<'a, R: Runtime>(
435 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
436 ) -> Self::CompilationArg {
437 }
438
439 fn expand(
440 _: &Self::CompilationArg,
441 _builder: &mut KernelBuilder,
442 ) -> <Self as CubeType>::ExpandType {
443 }
444}
445
446impl<R: Runtime> ArgSettings<R> for () {
447 fn register(&self, _launcher: &mut KernelLauncher<R>) {
448 }
450}