1use super::{CubePrimitive, Numeric};
2use crate::{
3 ir::{ConstantValue, Scope, Variable, VariableKind},
4 prelude::{DynamicSize, KernelBuilder, KernelLauncher, assign},
5 unexpanded,
6};
7use alloc::{boxed::Box, vec::Vec};
8use core::marker::PhantomData;
9use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
10use cubecl_ir::{ManagedVariable, VectorSize};
11use cubecl_runtime::runtime::Runtime;
12use half::{bf16, f16};
13use variadics_please::{all_tuples, all_tuples_enumerated};
14
15#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
28pub trait CubeType {
29 type ExpandType: Clone + IntoMut + CubeDebug;
30}
31
32pub trait CubeEnum: Sized {
33 type RuntimeValue: Clone + CubeDebug;
34
35 fn discriminant(&self) -> NativeExpand<i32>;
36
37 fn runtime_value(self) -> Self::RuntimeValue;
40
41 fn discriminant_of_value(&self, variant_name: &'static str) -> i32 {
42 Self::discriminant_of(variant_name)
43 }
44
45 fn discriminant_of(variant_name: &'static str) -> i32;
46}
47
48pub trait Assign {
49 fn expand_assign(&mut self, scope: &mut Scope, value: Self);
51 fn init_mut(&self, scope: &mut Scope) -> Self;
53}
54
55impl<T: CubePrimitive> Assign for T {
56 fn expand_assign(&mut self, _scope: &mut Scope, value: Self) {
57 *self = value;
58 }
59 fn init_mut(&self, _scope: &mut Scope) -> Self {
60 *self
61 }
62}
63
64impl<T: NativeAssign> Assign for NativeExpand<T> {
65 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
66 assign::expand(scope, value, self.clone());
67 }
68 fn init_mut(&self, scope: &mut Scope) -> Self {
69 T::elem_init_mut(scope, self.expand.clone()).into()
70 }
71}
72
73impl<T: Assign> Assign for Option<T> {
74 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
75 match (self, value) {
76 (Some(this), Some(other)) => this.expand_assign(scope, other),
77 (None, None) => {}
78 _ => panic!("Can't assign mismatched enum variants"),
79 }
80 }
81 fn init_mut(&self, scope: &mut Scope) -> Self {
82 self.as_ref().map(|value| value.init_mut(scope))
83 }
84}
85
86impl<T: Assign> Assign for Vec<T> {
87 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
88 assert!(
89 self.len() == value.len(),
90 "Can't assign mismatched vector lengths"
91 );
92 for (this, other) in self.iter_mut().zip(value) {
93 this.expand_assign(scope, other);
94 }
95 }
96 fn init_mut(&self, scope: &mut Scope) -> Self {
97 self.iter().map(|it| it.init_mut(scope)).collect()
98 }
99}
100
101pub trait CloneExpand {
102 fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
103}
104
105impl<C: Clone> CloneExpand for C {
106 fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
107 self.clone()
108 }
109}
110
111pub trait IntoRuntime: CubeType + Sized {
113 fn runtime(self) -> Self {
114 self
115 }
116
117 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
118}
119
/// Marks a value as comptime. On the host this is the identity function.
pub trait IntoComptime: Sized {
    /// Returns `self` unchanged.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

// Every sized type can be marked as comptime.
impl<T> IntoComptime for T {}
129
130pub trait IntoMut: Sized {
132 fn into_mut(self, scope: &mut Scope) -> Self;
134}
135
136pub fn into_mut_assign<T: Assign>(value: T, scope: &mut Scope) -> T {
137 let mut out = value.init_mut(scope);
138 out.expand_assign(scope, value);
139 out
140}
141
142pub trait CubeDebug: Sized {
143 #[allow(unused)]
146 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
147}
148
/// Bounds required of comptime values: debuggable, hashable, comparable and `Copy`.
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

// Any type meeting the bounds is automatically a comptime value.
impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
165
/// Compile-time data produced when registering a launch argument.
///
/// Blanket-implemented for every `Clone + Eq + Hash + Debug + Send + Sync + 'static` type.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as the compilation-argument type `Arg`.
    ///
    /// This is a raw, bit-level reinterpretation. It is only meaningful between types that
    /// share the same in-memory representation.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` differ in size or alignment. Both must match because the heap
    /// allocation made for the clone of `Self` is re-owned and later freed as a `Box<Arg>`.
    ///
    /// NOTE(review): only the layout is checked. If `Arg` has validity invariants that the bits
    /// of `Self` do not satisfy (e.g. `bool`, enums, references), the result is undefined
    /// behaviour. Callers must only cast between representation-compatible types.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Equal size *and* alignment means `Layout::new::<Self>() == Layout::new::<Arg>()`.
        // `Box::from_raw` requires this, since the box is deallocated with `Arg`'s layout.
        assert!(core::mem::size_of::<Arg>() == core::mem::size_of::<Self>());
        assert!(core::mem::align_of::<Arg>() == core::mem::align_of::<Self>());
        let this = Box::new(self.clone());
        // SAFETY: the pointer comes from `Box::into_raw` on a `Box<Self>`, and `Self` and `Arg`
        // have identical layouts (asserted above), so owning it as a `Box<Arg>` allocates and
        // frees with the same layout. Bit validity for `Arg` is the caller's responsibility.
        unsafe { *Box::from_raw(Box::into_raw(this).cast::<Arg>()) }
    }
}

impl<T: Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static>
    CompilationArg for T
{
}
189
190#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
199pub trait LaunchArg: CubeType + Send + Sync + 'static {
200 type RuntimeArg<R: Runtime>: Send + Sync;
202 type CompilationArg: CompilationArg;
204
205 fn register<R: Runtime>(
206 arg: Self::RuntimeArg<R>,
207 launcher: &mut KernelLauncher<R>,
208 ) -> Self::CompilationArg;
209
210 fn expand(
212 arg: &Self::CompilationArg,
213 builder: &mut KernelBuilder,
214 ) -> <Self as CubeType>::ExpandType;
215
216 fn expand_output(
218 arg: &Self::CompilationArg,
219 builder: &mut KernelBuilder,
220 ) -> <Self as CubeType>::ExpandType {
221 Self::expand(arg, builder)
222 }
223}
224
225macro_rules! launch_tuple {
226 ($(($T:ident, $t:ident)),*) => {
227 impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
228 type RuntimeArg<R: Runtime> = ($($T::RuntimeArg<R>),*);
229 type CompilationArg = ($($T::CompilationArg),*);
230
231 fn register<R: Runtime>(runtime_arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) -> Self::CompilationArg {
232 let ($($t),*) = runtime_arg;
233 ($($T::register($t, launcher)),*)
234 }
235
236 fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
237 let ($($t),*) = arg;
238 ($($T::expand($t, builder)),*)
239 }
240
241 fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
242 let ($($t),*) = arg;
243 ($($T::expand_output($t, builder)),*)
244 }
245 }
246 };
247}
248
249all_tuples!(launch_tuple, 2, 12, T, t);
250
251#[derive(new)]
253pub struct NativeExpand<T: CubeType> {
254 pub expand: ManagedVariable,
255 pub(crate) _type: PhantomData<T>,
256}
257
258impl<T: CubeType> NativeExpand<T> {
259 pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &NativeExpand<E> {
263 unsafe { core::mem::transmute::<&NativeExpand<T>, &NativeExpand<E>>(self) }
264 }
265
266 pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut NativeExpand<E> {
270 unsafe { core::mem::transmute::<&mut NativeExpand<T>, &mut NativeExpand<E>>(self) }
271 }
272}
273
274impl<T: CubeType> From<&NativeExpand<T>> for NativeExpand<T> {
275 fn from(value: &NativeExpand<T>) -> Self {
276 value.clone()
277 }
278}
279
280impl<T: CubeType> From<NativeExpand<T>> for Variable {
281 fn from(value: NativeExpand<T>) -> Self {
282 value.expand.into()
283 }
284}
285
286impl<T: CubeType> From<&mut NativeExpand<T>> for NativeExpand<T> {
287 fn from(value: &mut NativeExpand<T>) -> Self {
288 value.clone()
289 }
290}
291
292macro_rules! from_const {
293 ($lit:ty) => {
294 impl From<$lit> for NativeExpand<$lit> {
295 fn from(value: $lit) -> Self {
296 let variable: Variable = value.into();
297
298 ManagedVariable::Plain(variable).into()
299 }
300 }
301 };
302}
303
304from_const!(u8);
305from_const!(u16);
306from_const!(u32);
307from_const!(u64);
308from_const!(usize);
309from_const!(isize);
310from_const!(i64);
311from_const!(i8);
312from_const!(i16);
313from_const!(i32);
314from_const!(f64);
315from_const!(f16);
316from_const!(bf16);
317from_const!(flex32);
318from_const!(tf32);
319from_const!(f32);
320from_const!(e2m1);
321from_const!(e2m1x2);
322from_const!(e2m3);
323from_const!(e3m2);
324from_const!(e4m3);
325from_const!(e5m2);
326from_const!(ue8m0);
327from_const!(bool);
328
/// Implements [`CubeType`] for a tuple, mapping each element to its `ExpandType`.
macro_rules! tuple_cube_type {
    ($($Elem:ident),*) => {
        impl<$($Elem: CubeType),*> CubeType for ($($Elem,)*) {
            type ExpandType = ($(<$Elem as CubeType>::ExpandType,)*);
        }
    };
}
/// Implements [`IntoMut`] for a tuple by converting each element from left to right.
macro_rules! tuple_init {
    ($($Elem:ident),*) => {
        impl<$($Elem: IntoMut),*> IntoMut for ($($Elem,)*) {
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn into_mut(self, scope: &mut Scope) -> Self {
                let ($($Elem,)*) = self;
                ($(IntoMut::into_mut($Elem, scope),)*)
            }
        }
    };
}
/// Implements [`CubeDebug`] for a tuple, keeping the trait's default `set_debug_name`.
macro_rules! tuple_debug {
    ($($Elem:ident),*) => {
        impl<$($Elem: CubeDebug),*> CubeDebug for ($($Elem,)*) {}
    };
}
/// Implements [`IntoRuntime`] for a tuple by expanding each element from left to right.
macro_rules! tuple_runtime {
    ($($Elem:ident),*) => {
        impl<$($Elem: IntoRuntime),*> IntoRuntime for ($($Elem,)*) {
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
                let ($($Elem,)*) = self;
                ($(IntoRuntime::__expand_runtime_method($Elem, scope),)*)
            }
        }
    };
}
/// Implements [`Assign`] for a tuple element-wise, from left to right.
///
/// Each macro argument is an `(index, TypeParam)` pair.
macro_rules! tuple_assign {
    ($(($idx:tt, $Elem:ident)),*) => {
        impl<$($Elem: Assign),*> Assign for ($($Elem,)*) {
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
                $(
                    Assign::expand_assign(&mut self.$idx, scope, value.$idx);
                )*
            }

            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn init_mut(&self, scope: &mut Scope) -> Self {
                ($(Assign::init_mut(&self.$idx, scope),)*)
            }
        }
    };
}
387
388all_tuples!(tuple_cube_type, 0, 12, P);
389all_tuples!(tuple_debug, 0, 12, P);
390all_tuples!(tuple_init, 0, 12, P);
391all_tuples!(tuple_runtime, 0, 12, P);
392all_tuples_enumerated!(tuple_assign, 0, 12, P);
393
394impl<P: CubePrimitive> CubeDebug for P {}
395
396pub trait NativeAssign: CubeType {
398 fn elem_init_mut(scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
399 init_mut_expand_element(scope, &elem)
400 }
401}
402
403impl<T: NativeAssign> IntoMut for NativeExpand<T> {
404 fn into_mut(self, scope: &mut Scope) -> Self {
405 into_mut_assign(self, scope)
406 }
407}
408
409impl<T: CubeType> CubeDebug for NativeExpand<T> {
410 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
411 scope.update_variable_name(*self.expand, name);
412 }
413}
414
415impl<T: CubeType> CubeDebug for &NativeExpand<T> {
416 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
417 scope.update_variable_name(*self.expand, name);
418 }
419}
420
421impl<T: CubeType> CubeDebug for &mut NativeExpand<T> {
422 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
423 scope.update_variable_name(*self.expand, name);
424 }
425}
426
427impl<T: CubeType> NativeExpand<T> {
428 pub fn vector_size(&self) -> VectorSize {
430 self.expand.ty.vector_size()
431 }
432
433 pub fn __expand_vector_size_method(self, _scope: &mut Scope) -> VectorSize {
435 self.expand.ty.vector_size()
436 }
437
438 pub fn into_variable(self) -> Variable {
439 self.expand.consume()
440 }
441}
442
443impl<T: CubeType> Clone for NativeExpand<T> {
444 fn clone(&self) -> Self {
445 Self {
446 expand: self.expand.clone(),
447 _type: PhantomData,
448 }
449 }
450}
451
452impl<T: CubeType> From<ManagedVariable> for NativeExpand<T> {
453 fn from(expand: ManagedVariable) -> Self {
454 Self {
455 expand,
456 _type: PhantomData,
457 }
458 }
459}
460
461impl<T: CubeType> From<NativeExpand<T>> for ManagedVariable {
462 fn from(value: NativeExpand<T>) -> Self {
463 value.expand
464 }
465}
466
467impl<T: CubePrimitive> NativeExpand<T> {
468 pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
470 let variable: ConstantValue = lit.into();
471 let variable = T::as_type(scope).constant(variable);
472
473 NativeExpand::new(ManagedVariable::Plain(variable))
474 }
475
476 pub fn constant(&self) -> Option<ConstantValue> {
478 match self.expand.kind {
479 VariableKind::Constant(val) => Some(val),
480 _ => None,
481 }
482 }
483
484 pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
485 let value = self.constant().unwrap();
486 T::from_const_value(value)
487 }
488}
489
490pub(crate) fn init_mut_expand_element(
491 scope: &mut Scope,
492 element: &ManagedVariable,
493) -> ManagedVariable {
494 scope.create_local_mut(element.ty)
495}
496
497impl<T: IntoMut> IntoMut for Option<T> {
498 fn into_mut(self, scope: &mut Scope) -> Self {
499 self.map(|o| IntoMut::into_mut(o, scope))
500 }
501}
502
503impl<T: CubeType> CubeType for Vec<T> {
504 type ExpandType = Vec<T::ExpandType>;
505}
506
507impl<T: CubeType> CubeType for &mut Vec<T> {
508 type ExpandType = Vec<T::ExpandType>;
509}
510
511impl<T: IntoMut> IntoMut for Vec<T> {
512 fn into_mut(self, scope: &mut Scope) -> Self {
513 self.into_iter().map(|e| e.into_mut(scope)).collect()
514 }
515}
516impl<T: CubeDebug> CubeDebug for Vec<T> {}
517
518pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
520 scope: &mut Scope,
521 val: C,
522) -> NativeExpand<Out> {
523 let input: ConstantValue = val.into();
524 let var = Out::as_type(scope).constant(input);
525 ManagedVariable::Plain(var).into()
526}
527
528impl LaunchArg for () {
529 type RuntimeArg<R: Runtime> = ();
530 type CompilationArg = ();
531
532 fn register<R: Runtime>(_runtime_arg: Self::RuntimeArg<R>, _launcher: &mut KernelLauncher<R>) {
533 }
535
536 fn expand(
537 _: &Self::CompilationArg,
538 _builder: &mut KernelBuilder,
539 ) -> <Self as CubeType>::ExpandType {
540 }
541}
542
543pub trait DefaultExpand: CubeType {
544 fn __expand_default(scope: &mut Scope) -> Self::ExpandType;
545}
546
547impl<T: CubeType + Default + IntoRuntime> DefaultExpand for T {
548 fn __expand_default(scope: &mut Scope) -> T::ExpandType {
549 T::default().__expand_runtime_method(scope)
550 }
551}
552
/// Size known at compile time through the const parameter `N`. Implements [`Size`].
#[derive(Debug, Clone, Copy)]
pub struct Const<const N: usize>;
555
556pub trait Size: core::fmt::Debug + Clone + Copy + Send + Sync + 'static {
557 fn __expand_value(scope: &Scope) -> usize;
558 fn value() -> usize {
559 unexpanded!()
560 }
561 fn try_value_const() -> Option<usize> {
562 None
563 }
564}
565
566impl<const VALUE: usize> Size for Const<VALUE> {
567 fn __expand_value(_scope: &Scope) -> usize {
568 VALUE
569 }
570 fn value() -> usize {
571 VALUE
572 }
573 fn try_value_const() -> Option<usize> {
574 Some(VALUE)
575 }
576}
577
578impl<Marker: 'static> Size for DynamicSize<Marker> {
579 fn __expand_value(scope: &Scope) -> usize {
580 scope.resolve_size::<Self>().expect("Size to be registered")
581 }
582 fn value() -> usize {
583 unexpanded!()
584 }
585}
586
/// Declares a dynamic scalar type alias.
///
/// `define_scalar!(pub Foo)` expands to a marker struct `__Foo` and
/// `pub type Foo = DynamicScalar<__Foo>`, both with the given visibility.
#[macro_export]
macro_rules! define_scalar {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicScalar<[<__ $alias>]>;
        }
    };
}
598
/// Declares a dynamic size type alias.
///
/// `define_size!(pub N)` expands to a marker struct `__N` and `pub type N = DynamicSize<__N>`,
/// both with the given visibility.
#[macro_export]
macro_rules! define_size {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicSize<[<__ $alias>]>;
        }
    };
}
608}