cubecl_core/frontend/element/
base.rs1use super::{CubePrimitive, Numeric};
2use crate::{
3 Runtime,
4 ir::{ConstantScalarValue, Operation, Scope, Variable, VariableKind},
5 prelude::{KernelBuilder, KernelLauncher, init_expand},
6};
7use cubecl_common::{flex32, tf32};
8use cubecl_ir::ExpandElement;
9use half::{bf16, f16};
10use std::{
11 any::{Any, TypeId},
12 marker::PhantomData,
13};
14use variadics_please::all_tuples;
15
16#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
29pub trait CubeType {
30 type ExpandType: Clone + Init + CubeDebug;
31
32 fn init(scope: &mut Scope, expand: Self::ExpandType) -> Self::ExpandType {
34 expand.init(scope)
35 }
36}
37
38pub trait IntoRuntime: CubeType + Sized {
40 fn runtime(self) -> Self {
41 self
42 }
43
44 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
45}
46
47pub trait Init: Sized {
49 fn init(self, scope: &mut Scope) -> Self;
54}
55
56pub trait CubeDebug: Sized {
57 #[allow(unused)]
60 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
61}
62
/// Marker for values that can be used at compile time ("comptime") in cube functions.
///
/// Every type that is `Debug + Hash + Eq + Clone + Copy` qualifies through the blanket
/// implementation that follows.
pub trait CubeComptime: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy {}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
79
80pub trait CubeLaunch: CubeType + LaunchArg + LaunchArgExpand {}
82impl<T: CubeType + LaunchArg + LaunchArgExpand> CubeLaunch for T {}
83
84pub trait CompilationArg:
86 serde::Serialize
87 + serde::de::DeserializeOwned
88 + Clone
89 + PartialEq
90 + Eq
91 + core::hash::Hash
92 + core::fmt::Debug
93 + Send
94 + Sync
95 + 'static
96{
97 fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
104 if TypeId::of::<Arg>() == TypeId::of::<Self>() {
105 let tmp: Box<dyn Any> = Box::new(self.clone());
106 *tmp.downcast().unwrap()
107 } else {
108 let val = serde_json::to_string(self).unwrap();
109 serde_json::from_str(&val)
110 .expect("Compilation argument should be the same even with different element types")
111 }
112 }
113}
114
115impl CompilationArg for () {}
116
117#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
126pub trait LaunchArgExpand: CubeType {
127 type CompilationArg: CompilationArg;
129
130 fn expand(
132 arg: &Self::CompilationArg,
133 builder: &mut KernelBuilder,
134 ) -> <Self as CubeType>::ExpandType;
135
136 fn expand_output(
138 arg: &Self::CompilationArg,
139 builder: &mut KernelBuilder,
140 ) -> <Self as CubeType>::ExpandType {
141 Self::expand(arg, builder)
142 }
143}
144
145pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
147 type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;
149
150 fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
151}
152
153pub trait ArgSettings<R: Runtime>: Send + Sync {
155 fn register(&self, launcher: &mut KernelLauncher<R>);
157}
158
159#[derive(new)]
161pub struct ExpandElementTyped<T: CubeType> {
162 pub(crate) expand: ExpandElement,
163 pub(crate) _type: PhantomData<T>,
164}
165
166impl<T: CubeType> From<&ExpandElementTyped<T>> for ExpandElementTyped<T> {
167 fn from(value: &ExpandElementTyped<T>) -> Self {
168 value.clone()
169 }
170}
171
172impl<T: CubeType> From<&mut ExpandElementTyped<T>> for ExpandElementTyped<T> {
173 fn from(value: &mut ExpandElementTyped<T>) -> Self {
174 value.clone()
175 }
176}
177
/// Implements `From<$lit>` for `ExpandElementTyped<$lit>`.
///
/// The literal is first converted into a [`Variable`] and then wrapped as a plain
/// expand element.
macro_rules! from_const {
    ($lit:ty) => {
        impl From<$lit> for ExpandElementTyped<$lit> {
            fn from(value: $lit) -> Self {
                let var: Variable = value.into();
                Self::from(ExpandElement::Plain(var))
            }
        }
    };
}
189
190from_const!(u8);
191from_const!(u16);
192from_const!(u32);
193from_const!(u64);
194from_const!(i64);
195from_const!(i8);
196from_const!(i16);
197from_const!(i32);
198from_const!(f64);
199from_const!(f16);
200from_const!(bf16);
201from_const!(flex32);
202from_const!(tf32);
203from_const!(f32);
204from_const!(bool);
205
/// Implements [`CubeType`] for a tuple whose expand type is the tuple of element expand types.
macro_rules! tuple_cube_type {
    ($($P:ident),*) => {
        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
            type ExpandType = ($(<$P as CubeType>::ExpandType,)*);
        }
    }
}
/// Implements [`Init`] for a tuple by initializing every element in order.
macro_rules! tuple_init {
    ($($P:ident),*) => {
        impl<$($P: Init),*> Init for ($($P,)*) {
            // The lints are silenced for the generated code: the type parameters double as
            // binding names, and the empty tuple leaves `scope` unused and yields `()`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn init(self, scope: &mut Scope) -> Self {
                let ($($P,)*) = self;
                ($(Init::init($P, scope),)*)
            }
        }
    }
}
/// Implements [`CubeDebug`] for a tuple, keeping the default (no-op) debug naming.
macro_rules! tuple_debug {
    ($($P:ident),*) => {
        impl<$($P),*> CubeDebug for ($($P,)*) where $($P: CubeDebug),* {}
    }
}
/// Implements [`IntoRuntime`] for a tuple by expanding every element in order.
macro_rules! tuple_runtime {
    ($($P:ident),*) => {
        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
            // The lints are silenced for the generated code: the type parameters double as
            // binding names, and the empty tuple leaves `scope` unused and yields `()`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
                let ($($P,)*) = self;
                ($(IntoRuntime::__expand_runtime_method($P, scope),)*)
            }
        }
    }
}
244
245all_tuples!(tuple_cube_type, 0, 12, P);
246all_tuples!(tuple_debug, 0, 12, P);
247all_tuples!(tuple_init, 0, 12, P);
248all_tuples!(tuple_runtime, 0, 12, P);
249
250impl<P: CubePrimitive> CubeDebug for P {}
251
252pub trait ExpandElementBaseInit: CubeType {
253 fn init_elem(scope: &mut Scope, elem: ExpandElement) -> ExpandElement;
254}
255
256impl<T: ExpandElementBaseInit> Init for ExpandElementTyped<T> {
257 fn init(self, scope: &mut Scope) -> Self {
258 <T as ExpandElementBaseInit>::init_elem(scope, self.into()).into()
259 }
260}
261
262impl<T: CubeType> CubeDebug for ExpandElementTyped<T> {
263 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
264 scope.update_variable_name(*self.expand, name);
265 }
266}
267
268impl<T: CubeType> ExpandElementTyped<T> {
269 pub fn __expand_vectorization_factor_method(self, _scope: &mut Scope) -> u32 {
271 self.expand
272 .item
273 .vectorization
274 .map(|it| it.get())
275 .unwrap_or(1) as u32
276 }
277
278 pub fn into_variable(self) -> Variable {
279 self.expand.consume()
280 }
281}
282
283impl<T: CubeType> Clone for ExpandElementTyped<T> {
284 fn clone(&self) -> Self {
285 Self {
286 expand: self.expand.clone(),
287 _type: PhantomData,
288 }
289 }
290}
291
292impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
293 fn from(expand: ExpandElement) -> Self {
294 Self {
295 expand,
296 _type: PhantomData,
297 }
298 }
299}
300
301impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
302 fn from(value: ExpandElementTyped<T>) -> Self {
303 value.expand
304 }
305}
306
307impl<T: CubePrimitive> ExpandElementTyped<T> {
308 pub fn from_lit<L: Into<Variable>>(scope: &Scope, lit: L) -> Self {
310 let variable: Variable = lit.into();
311 let variable = T::as_elem(scope).from_constant(variable);
312
313 ExpandElementTyped::new(ExpandElement::Plain(variable))
314 }
315
316 pub fn constant(&self) -> Option<ConstantScalarValue> {
318 match self.expand.kind {
319 VariableKind::ConstantScalar(val) => Some(val),
320 _ => None,
321 }
322 }
323}
324
325pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
326 scope: &mut Scope,
327 element: E,
328) -> ExpandElement {
329 let elem = element.into();
330
331 if elem.can_mut() {
332 return elem;
334 }
335
336 let mut init = |elem: ExpandElement| init_expand(scope, elem, Operation::Copy);
337
338 match elem.kind {
339 VariableKind::GlobalScalar { .. } => init(elem),
340 VariableKind::ConstantScalar { .. } => init(elem),
341 VariableKind::LocalMut { .. } => init(elem),
342 VariableKind::Versioned { .. } => init(elem),
343 VariableKind::LocalConst { .. } => init(elem),
344 VariableKind::Builtin(_) => init(elem),
345 VariableKind::SharedMemory { .. }
346 | VariableKind::GlobalInputArray { .. }
347 | VariableKind::GlobalOutputArray { .. }
348 | VariableKind::LocalArray { .. }
349 | VariableKind::ConstantArray { .. }
350 | VariableKind::Slice { .. }
351 | VariableKind::Matrix { .. }
352 | VariableKind::Barrier { .. }
353 | VariableKind::Pipeline { .. }
354 | VariableKind::TensorMap(_) => elem,
355 }
356}
357
358impl Init for ExpandElement {
359 fn init(self, scope: &mut Scope) -> Self {
360 init_expand_element(scope, self)
361 }
362}
363
364impl<T: Init> Init for Option<T> {
365 fn init(self, scope: &mut Scope) -> Self {
366 self.map(|o| Init::init(o, scope))
367 }
368}
369
370impl<T: CubeType> CubeType for Vec<T> {
371 type ExpandType = Vec<T::ExpandType>;
372}
373
374impl<T: CubeType> CubeType for &mut Vec<T> {
375 type ExpandType = Vec<T::ExpandType>;
376}
377
378impl<T: Init> Init for Vec<T> {
379 fn init(self, scope: &mut Scope) -> Self {
380 self.into_iter().map(|e| e.init(scope)).collect()
381 }
382}
383impl<T: CubeDebug> CubeDebug for Vec<T> {}
384
385pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
387 scope: &mut Scope,
388 val: C,
389) -> ExpandElementTyped<Out> {
390 let input: ExpandElementTyped<C> = val.into();
391 let const_val = input.expand.as_const().unwrap();
392 let var = Variable::constant(const_val.cast_to(Out::as_elem(scope)));
393 ExpandElement::Plain(var).into()
394}
395
396impl LaunchArg for () {
397 type RuntimeArg<'a, R: Runtime> = ();
398
399 fn compilation_arg<'a, R: Runtime>(
400 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
401 ) -> Self::CompilationArg {
402 }
403}
404
405impl<R: Runtime> ArgSettings<R> for () {
406 fn register(&self, _launcher: &mut KernelLauncher<R>) {
407 }
409}
410
411impl LaunchArgExpand for () {
412 type CompilationArg = ();
413
414 fn expand(
415 _: &Self::CompilationArg,
416 _builder: &mut KernelBuilder,
417 ) -> <Self as CubeType>::ExpandType {
418 }
419}