cubecl_core/frontend/element/
base.rs1use super::{CubePrimitive, Numeric};
2use crate::{
3 ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
4 prelude::{KernelBuilder, KernelLauncher, init_expand},
5};
6use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
7use cubecl_ir::ExpandElement;
8use cubecl_runtime::runtime::Runtime;
9use half::{bf16, f16};
10use std::marker::PhantomData;
11use variadics_please::all_tuples;
12
13#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
26pub trait CubeType {
27 type ExpandType: Clone + IntoMut + CubeDebug;
28
29 fn into_mut(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
31 expand.into_mut(scope)
32 }
33}
34
35pub trait IntoRuntime: CubeType + Sized {
37 fn runtime(self) -> Self {
38 self
39 }
40
41 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
42}
43
44pub trait IntoMut: Sized {
46 fn into_mut(self, scope: &mut Scope) -> Self;
47}
48
49pub trait CubeDebug: Sized {
50 #[allow(unused)]
53 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
54}
55
/// Marker for values usable as comptime parameters (judging by the name — confirm with callers).
///
/// Blanket-implemented for every type that is `Clone + Copy + Eq + Hash + Debug`; it adds no
/// methods of its own.
pub trait CubeComptime: Clone + Copy + Eq + core::hash::Hash + core::fmt::Debug {}

impl<T> CubeComptime for T where T: Clone + Copy + Eq + core::hash::Hash + core::fmt::Debug {}
72
/// Compile-time description of a launch argument.
///
/// Implementors must be cheap to compare and hash, since they are plain values that are
/// `Clone + Eq + Hash + Debug + Send + Sync + 'static`.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as a value of type `Arg`.
    ///
    /// This lets generic code convert between compilation-argument types that the type system
    /// cannot prove equal. A typical case is two instantiations of the same generic struct
    /// whose type parameter only appears in `PhantomData`.
    ///
    /// # Panics
    ///
    /// Panics if `Arg` and `Self` have different sizes.
    ///
    /// # Soundness
    ///
    /// Only the size is checked at runtime. Callers must only cast between types that share
    /// the same field layout and validity invariants. Any other cast is undefined behaviour.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        assert!(
            core::mem::size_of::<Arg>() == core::mem::size_of::<Self>(),
            "dynamic_cast: size mismatch between `{}` ({} bytes) and `{}` ({} bytes)",
            core::any::type_name::<Self>(),
            core::mem::size_of::<Self>(),
            core::any::type_name::<Arg>(),
            core::mem::size_of::<Arg>(),
        );

        // The clone's resources are handed over to the returned `Arg`, so the clone itself
        // must never be dropped as `Self` (that would free them twice).
        let this = core::mem::ManuallyDrop::new(self.clone());

        // SAFETY: the sizes are equal (asserted above). `transmute_copy` performs an unaligned
        // read, so differing alignments cannot cause UB. This also avoids the old
        // Box round-trip, which freed memory with `Arg`'s layout. The caller guarantees that
        // `Arg` has the same layout and validity invariants as `Self`. The bits are owned
        // exactly once, by the returned value, because the source is wrapped in `ManuallyDrop`.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

impl CompilationArg for () {}
93
94#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
103pub trait LaunchArg: CubeType + Send + Sync + 'static {
104 type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
106 type CompilationArg: CompilationArg;
108
109 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
110
111 fn expand(
113 arg: &Self::CompilationArg,
114 builder: &mut KernelBuilder,
115 ) -> <Self as CubeType>::ExpandType;
116
117 fn expand_output(
119 arg: &Self::CompilationArg,
120 builder: &mut KernelBuilder,
121 ) -> <Self as CubeType>::ExpandType {
122 Self::expand(arg, builder)
123 }
124}
125
126pub trait ArgSettings<R: Runtime>: Send + Sync {
128 fn register(&self, launcher: &mut KernelLauncher<R>);
130}
131
132macro_rules! launch_tuple {
133 ($(($T:ident, $t:ident)),*) => {
134 impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
135 type RuntimeArg<'a, R: Runtime> = ($($T::RuntimeArg<'a, R>),*);
136 type CompilationArg = ($($T::CompilationArg),*);
137
138 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
139 let ($($t),*) = runtime_arg;
140 ($($T::compilation_arg($t)),*)
141 }
142
143 fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
144 let ($($t),*) = arg;
145 ($($T::expand($t, builder)),*)
146 }
147
148 fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
149 let ($($t),*) = arg;
150 ($($T::expand_output($t, builder)),*)
151 }
152 }
153
154 impl<$($T: CompilationArg),*> CompilationArg for ($($T),*) {}
155
156 impl<R: Runtime, $($T: ArgSettings<R>),*> ArgSettings<R> for ($($T),*) {
157 fn register(&self, launcher: &mut KernelLauncher<R>) {
158 let ($($t),*) = self;
159 $($t.register(launcher);)*
160 }
161 }
162 };
163}
164
165all_tuples!(launch_tuple, 2, 12, T, t);
166
167#[derive(new)]
169pub struct ExpandElementTyped<T: CubeType> {
170 pub(crate) expand: ExpandElement,
171 pub(crate) _type: PhantomData<T>,
172}
173
174impl<T: CubeType> ExpandElementTyped<T> {
175 pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &ExpandElementTyped<E> {
179 unsafe { core::mem::transmute::<&ExpandElementTyped<T>, &ExpandElementTyped<E>>(self) }
180 }
181
182 pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut ExpandElementTyped<E> {
186 unsafe {
187 core::mem::transmute::<&mut ExpandElementTyped<T>, &mut ExpandElementTyped<E>>(self)
188 }
189 }
190}
191
192impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
193 fn from(value: &ExpandElementTyped<T>) -> Self {
194 value.clone()
195 }
196}
197
198impl<T: CubeType> From<ExpandElementTyped<T>> for Variable {
199 fn from(value: ExpandElementTyped<T>) -> Self {
200 value.expand.into()
201 }
202}
203
204impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
205 fn from(value: &mut ExpandElementTyped<T>) -> Self {
206 value.clone()
207 }
208}
209
210macro_rules! from_const {
211 ($lit:ty) => {
212 impl From<$lit> for ExpandElementTyped<$lit> {
213 fn from(value: $lit) -> Self {
214 let variable: Variable = value.into();
215
216 ExpandElement::Plain(variable).into()
217 }
218 }
219 };
220}
221
222from_const!(u8);
223from_const!(u16);
224from_const!(u32);
225from_const!(u64);
226from_const!(i64);
227from_const!(i8);
228from_const!(i16);
229from_const!(i32);
230from_const!(f64);
231from_const!(f16);
232from_const!(bf16);
233from_const!(flex32);
234from_const!(tf32);
235from_const!(f32);
236from_const!(e2m1);
237from_const!(e2m1x2);
238from_const!(e2m3);
239from_const!(e3m2);
240from_const!(e4m3);
241from_const!(e5m2);
242from_const!(ue8m0);
243from_const!(bool);
244
245macro_rules! tuple_cube_type {
246 ($($P:ident),*) => {
247 impl<$($P: CubeType),*> CubeType for ($($P,)*) {
248 type ExpandType = ($($P::ExpandType,)*);
249 }
250 }
251}
252macro_rules! tuple_init {
253 ($($P:ident),*) => {
254 impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
255 #[allow(non_snake_case, unused, clippy::unused_unit)]
256 fn into_mut(self, scope: &mut Scope) -> Self {
257 let ($($P,)*) = self;
258 ($(
259 $P.into_mut(scope),
260 )*)
261 }
262 }
263 }
264}
265macro_rules! tuple_debug {
266 ($($P:ident),*) => {
267 impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
268 }
269}
270macro_rules! tuple_runtime {
271 ($($P:ident),*) => {
272 impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
273 #[allow(non_snake_case, unused, clippy::unused_unit)]
274 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
275 let ($($P,)*) = self;
276 ($(
277 $P.__expand_runtime_method(scope),
278 )*)
279 }
280 }
281 }
282}
283
284all_tuples!(tuple_cube_type, 0, 12, P);
285all_tuples!(tuple_debug, 0, 12, P);
286all_tuples!(tuple_init, 0, 12, P);
287all_tuples!(tuple_runtime, 0, 12, P);
288
289impl<P: CubePrimitive> CubeDebug for P {}
290
291pub trait ExpandElementIntoMut: CubeType {
292 fn elem_into_mut(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
293}
294
295impl<T: ExpandElementIntoMut> IntoMut for ExpandElementTyped<T> {
296 fn into_mut(self, scope: &mut Scope) -> Self {
297 <T as ExpandElementIntoMut>::elem_into_mut(scope, self.into()).into()
298 }
299}
300
301impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
302 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
303 scope.update_variable_name(*self.expand, name);
304 }
305}
306
307impl<T: CubeType> ExpandElementTyped<T> {
308 pub fn line_size(&self) -> u32 {
310 self.expand.ty.line_size()
311 }
312
313 pub fn __expand_line_size_method(self, _scope: &mut Scope) -> u32 {
315 self.expand.ty.line_size()
316 }
317
318 pub fn into_variable(self) -> Variable {
319 self.expand.consume()
320 }
321}
322
323impl<T: CubeType> Clone for ExpandElementTyped<T> {
324 fn clone(&self) -> Self {
325 Self {
326 expand: self.expand.clone(),
327 _type: PhantomData,
328 }
329 }
330}
331
332impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
333 fn from(expand: ExpandElement) -> Self {
334 Self {
335 expand,
336 _type: PhantomData,
337 }
338 }
339}
340
341impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
342 fn from(value: ExpandElementTyped<T>) -> Self {
343 value.expand
344 }
345}
346
347impl<T: CubePrimitive> ExpandElementTyped<T> {
348 pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
350 let variable: Variable = lit.into();
351 let variable = T::as_type(scope).from_constant(variable);
352
353 ExpandElementTyped::new(ExpandElement::Plain(variable))
354 }
355
356 pub fn constant(&self) -> Option<ConstantScalarValue> {
358 match self.expand.kind {
359 VariableKind::ConstantScalar(val) => Some(val),
360 _ => None,
361 }
362 }
363
364 pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
365 let value = self.constant().unwrap();
366 T::from_const_value(value)
367 }
368}
369
370pub(crate) fn into_runtime_expand_element<E: Into<ExpandElement>>(
371 scope: &mut Scope,
372 element: E,
373) -> ExpandElement {
374 let elem = element.into();
375
376 match elem.kind {
377 VariableKind::ConstantScalar { .. } => init_expand(scope, elem, false, Operation::Copy),
378 _ => elem,
379 }
380}
381
382pub(crate) fn into_mut_expand_element<E: Into<ExpandElement>>(
383 scope: &mut Scope,
384 element: E,
385) -> ExpandElement {
386 let elem = element.into();
387
388 let mut init = |elem: ExpandElement| init_expand(scope, elem, true, Operation::Copy);
389
390 match elem.kind {
391 VariableKind::GlobalScalar { .. } => init(elem),
392 VariableKind::ConstantScalar { .. } => init(elem),
393 VariableKind::LocalMut { .. } => init(elem),
394 VariableKind::Versioned { .. } => init(elem),
395 VariableKind::LocalConst { .. } => init(elem),
396 VariableKind::Builtin(_) => init(elem),
397 VariableKind::Shared { .. }
398 | VariableKind::SharedArray { .. }
399 | VariableKind::GlobalInputArray { .. }
400 | VariableKind::GlobalOutputArray { .. }
401 | VariableKind::LocalArray { .. }
402 | VariableKind::ConstantArray { .. }
403 | VariableKind::Matrix { .. }
404 | VariableKind::BarrierToken { .. }
405 | VariableKind::Pipeline { .. }
406 | VariableKind::TensorMapOutput(_)
407 | VariableKind::TensorMapInput(_) => elem,
408 }
409}
410
411impl IntoMut for ExpandElement {
412 fn into_mut(self, scope: &mut Scope) -> Self {
413 into_mut_expand_element(scope, self)
414 }
415}
416
417impl<T: IntoMut> IntoMut for Option<T> {
418 fn into_mut(self, scope: &mut Scope) -> Self {
419 self.map(|o| IntoMut::into_mut(o, scope))
420 }
421}
422
423impl<T: CubeType> CubeType for Vec<T> {
424 type ExpandType = Vec<T::ExpandType>;
425}
426
427impl<T: CubeType> CubeType for &mut Vec<T> {
428 type ExpandType = Vec<T::ExpandType>;
429}
430
431impl<T: IntoMut> IntoMut for Vec<T> {
432 fn into_mut(self, scope: &mut Scope) -> Self {
433 self.into_iter().map(|e| e.into_mut(scope)).collect()
434 }
435}
436impl<T: CubeDebug> CubeDebug for Vec<T> {}
437
438pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
440 scope: &mut Scope,
441 val: C,
442) -> ExpandElementTyped<Out> {
443 let input: ExpandElementTyped<C> = val.into();
444 let const_val = input.expand.as_const().unwrap();
445 let var = Variable::constant(const_val.cast_to(Out::as_type(scope)));
446 ExpandElement::Plain(var).into()
447}
448
449impl LaunchArg for () {
450 type RuntimeArg<'a, R: Runtime> = ();
451 type CompilationArg = ();
452
453 fn compilation_arg<'a, R: Runtime>(
454 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
455 ) -> Self::CompilationArg {
456 }
457
458 fn expand(
459 _: &Self::CompilationArg,
460 _builder: &mut KernelBuilder,
461 ) -> <Self as CubeType>::ExpandType {
462 }
463}
464
465impl<R: Runtime> ArgSettings<R> for () {
466 fn register(&self, _launcher: &mut KernelLauncher<R>) {
467 }
469}