cubecl_core/frontend/element/
base.rs1use super::{flex32, tf32, CubePrimitive, Numeric};
2use crate::{
3 ir::{ConstantScalarValue, Operation, Variable, VariableKind},
4 prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher},
5 Runtime,
6};
7use alloc::rc::Rc;
8use half::{bf16, f16};
9use std::marker::PhantomData;
10
11#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
24pub trait CubeType {
25 type ExpandType: Clone + Init;
26
27 fn init(context: &mut CubeContext, expand: Self::ExpandType) -> Self::ExpandType {
29 expand.init(context)
30 }
31}
32
33pub trait IntoRuntime: CubeType + Sized {
37 fn runtime(self) -> Self {
39 self
40 }
41
42 fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType;
43}
44
45pub trait Init: Sized {
47 fn init(self, context: &mut CubeContext) -> Self;
52}
53
54pub trait CompilationArg:
56 serde::Serialize
57 + serde::de::DeserializeOwned
58 + Clone
59 + PartialEq
60 + Eq
61 + core::hash::Hash
62 + core::fmt::Debug
63 + Send
64 + Sync
65 + 'static
66{
67 fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
74 let val = serde_json::to_string(self).unwrap();
75
76 serde_json::from_str(&val)
77 .expect("Compilation argument should be the same even with different element types")
78 }
79}
80
81impl CompilationArg for () {}
82
83#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
90pub trait LaunchArgExpand: CubeType {
91 type CompilationArg: CompilationArg;
93
94 fn expand(
96 arg: &Self::CompilationArg,
97 builder: &mut KernelBuilder,
98 ) -> <Self as CubeType>::ExpandType;
99 fn expand_output(
101 arg: &Self::CompilationArg,
102 builder: &mut KernelBuilder,
103 ) -> <Self as CubeType>::ExpandType {
104 Self::expand(arg, builder)
105 }
106}
107
/// A type that can be supplied as an argument when launching a kernel.
///
/// Implementors must also implement [`LaunchArgExpand`] and be `Send`, `Sync`, and `'static`.
pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static {
    /// The value supplied for this argument at launch time on runtime `R`.
    /// It must implement `ArgSettings<R>`.
    type RuntimeArg<'a, R: Runtime>: ArgSettings<R>;

    /// Derives this argument's compile-time data from its runtime value.
    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg;
}
115
116impl LaunchArg for () {
117 type RuntimeArg<'a, R: Runtime> = ();
118
119 fn compilation_arg<'a, R: Runtime>(
120 _runtime_arg: &'a Self::RuntimeArg<'a, R>,
121 ) -> Self::CompilationArg {
122 }
123}
124
125impl<R: Runtime> ArgSettings<R> for () {
126 fn register(&self, _launcher: &mut KernelLauncher<R>) {
127 }
129}
130
131impl LaunchArgExpand for () {
132 type CompilationArg = ();
133
134 fn expand(
135 _: &Self::CompilationArg,
136 _builder: &mut KernelBuilder,
137 ) -> <Self as CubeType>::ExpandType {
138 }
139}
140
141impl CubeType for () {
142 type ExpandType = ();
143}
144
145impl Init for () {
146 fn init(self, _context: &mut CubeContext) -> Self {
147 self
148 }
149}
150
151impl<T: Clone> CubeType for PhantomData<T> {
152 type ExpandType = ();
153}
154
155impl<T: Clone> IntoRuntime for PhantomData<T> {
156 fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {}
157}
158
159pub trait ArgSettings<R: Runtime>: Send + Sync {
161 fn register(&self, launcher: &mut KernelLauncher<R>);
163}
164
165#[derive(Clone, Debug)]
167pub enum ExpandElement {
168 Managed(Rc<Variable>),
170 Plain(Variable),
172}
173
/// An [`ExpandElement`] paired with a type-level tag `T`.
///
/// The tag exists only at the type level (`PhantomData<T>`). The constructor is generated
/// by `#[derive(new)]`.
#[derive(new)]
pub struct ExpandElementTyped<T: CubeType> {
    /// The untyped element being wrapped.
    pub(crate) expand: ExpandElement,
    /// Type-level marker for `T`; holds no data.
    pub(crate) _type: PhantomData<T>,
}
180
181macro_rules! from_const {
182 ($lit:ty) => {
183 impl From<$lit> for ExpandElementTyped<$lit> {
184 fn from(value: $lit) -> Self {
185 let variable: Variable = value.into();
186
187 ExpandElement::Plain(variable).into()
188 }
189 }
190 };
191}
192
193from_const!(u8);
194from_const!(u16);
195from_const!(u32);
196from_const!(u64);
197from_const!(i64);
198from_const!(i8);
199from_const!(i16);
200from_const!(i32);
201from_const!(f64);
202from_const!(f16);
203from_const!(bf16);
204from_const!(flex32);
205from_const!(tf32);
206from_const!(f32);
207from_const!(bool);
208
209macro_rules! tuple_cube_type {
210 ($($P:ident),*) => {
211 impl<$($P: CubeType),*> CubeType for ($($P,)*) {
212 type ExpandType = ($($P::ExpandType,)*);
213 }
214 }
215}
216macro_rules! tuple_init {
217 ($($P:ident),*) => {
218 impl<$($P: Init),*> Init for ($($P,)*) {
219 #[allow(non_snake_case)]
220 fn init(self, context: &mut CubeContext) -> Self {
221 let ($($P,)*) = self;
222 ($(
223 $P.init(context),
224 )*)
225 }
226 }
227 }
228}
229macro_rules! tuple_runtime {
230 ($($P:ident),*) => {
231 impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
232 #[allow(non_snake_case)]
233 fn __expand_runtime_method(self, context: &mut CubeContext) -> Self::ExpandType {
234 let ($($P,)*) = self;
235 ($(
236 $P.__expand_runtime_method(context),
237 )*)
238 }
239 }
240 }
241}
242
243tuple_cube_type!(P1);
244tuple_cube_type!(P1, P2);
245tuple_cube_type!(P1, P2, P3);
246tuple_cube_type!(P1, P2, P3, P4);
247tuple_cube_type!(P1, P2, P3, P4, P5);
248tuple_cube_type!(P1, P2, P3, P4, P5, P6);
249
250tuple_init!(P1);
251tuple_init!(P1, P2);
252tuple_init!(P1, P2, P3);
253tuple_init!(P1, P2, P3, P4);
254tuple_init!(P1, P2, P3, P4, P5);
255tuple_init!(P1, P2, P3, P4, P5, P6);
256
257tuple_runtime!(P1);
258tuple_runtime!(P1, P2);
259tuple_runtime!(P1, P2, P3);
260tuple_runtime!(P1, P2, P3, P4);
261tuple_runtime!(P1, P2, P3, P4, P5);
262tuple_runtime!(P1, P2, P3, P4, P5, P6);
263
264pub trait ExpandElementBaseInit: CubeType {
265 fn init_elem(context: &mut CubeContext, elem: ExpandElement) -> ExpandElement;
266}
267
268impl<T: ExpandElementBaseInit> Init for ExpandElementTyped<T> {
269 fn init(self, context: &mut CubeContext) -> Self {
270 <T as ExpandElementBaseInit>::init_elem(context, self.into()).into()
271 }
272}
273
274impl<T: CubeType> ExpandElementTyped<T> {
275 pub fn __expand_vectorization_factor_method(self, _context: &mut CubeContext) -> u32 {
277 self.expand
278 .item
279 .vectorization
280 .map(|it| it.get())
281 .unwrap_or(1) as u32
282 }
283}
284
285impl<T: CubeType> Clone for ExpandElementTyped<T> {
286 fn clone(&self) -> Self {
287 Self {
288 expand: self.expand.clone(),
289 _type: PhantomData,
290 }
291 }
292}
293
294impl<T: CubeType> From<ExpandElement> for ExpandElementTyped<T> {
295 fn from(expand: ExpandElement) -> Self {
296 Self {
297 expand,
298 _type: PhantomData,
299 }
300 }
301}
302
303impl<T: CubeType> From<ExpandElementTyped<T>> for ExpandElement {
304 fn from(value: ExpandElementTyped<T>) -> Self {
305 value.expand
306 }
307}
308
309impl<T: CubePrimitive> ExpandElementTyped<T> {
310 pub fn from_lit<L: Into<Variable>>(context: &CubeContext, lit: L) -> Self {
312 let variable: Variable = lit.into();
313 let variable = T::as_elem(context).from_constant(variable);
314
315 ExpandElementTyped::new(ExpandElement::Plain(variable))
316 }
317
318 pub fn constant(&self) -> Option<ConstantScalarValue> {
320 match self.expand.kind {
321 VariableKind::ConstantScalar(val) => Some(val),
322 _ => None,
323 }
324 }
325}
326
327impl ExpandElement {
328 pub fn can_mut(&self) -> bool {
330 match self {
331 ExpandElement::Managed(var) => {
332 if let VariableKind::LocalMut { .. } = var.as_ref().kind {
333 Rc::strong_count(var) <= 2
334 } else {
335 false
336 }
337 }
338 ExpandElement::Plain(_) => false,
339 }
340 }
341
342 pub fn consume(self) -> Variable {
344 *self
345 }
346}
347
348impl core::ops::Deref for ExpandElement {
349 type Target = Variable;
350
351 fn deref(&self) -> &Self::Target {
352 match self {
353 ExpandElement::Managed(var) => var.as_ref(),
354 ExpandElement::Plain(var) => var,
355 }
356 }
357}
358
359impl From<ExpandElement> for Variable {
360 fn from(value: ExpandElement) -> Self {
361 match value {
362 ExpandElement::Managed(var) => *var,
363 ExpandElement::Plain(var) => var,
364 }
365 }
366}
367
368pub(crate) fn init_expand_element<E: Into<ExpandElement>>(
369 context: &mut CubeContext,
370 element: E,
371) -> ExpandElement {
372 let elem = element.into();
373
374 if elem.can_mut() {
375 return elem;
377 }
378
379 let mut init = |elem: ExpandElement| init_expand(context, elem, Operation::Copy);
380
381 match elem.kind {
382 VariableKind::GlobalScalar { .. } => init(elem),
383 VariableKind::ConstantScalar { .. } => init(elem),
384 VariableKind::LocalMut { .. } => init(elem),
385 VariableKind::Versioned { .. } => init(elem),
386 VariableKind::LocalConst { .. } => init(elem),
387 VariableKind::Builtin(_) => init(elem),
390 VariableKind::SharedMemory { .. }
392 | VariableKind::GlobalInputArray { .. }
393 | VariableKind::GlobalOutputArray { .. }
394 | VariableKind::LocalArray { .. }
395 | VariableKind::ConstantArray { .. }
396 | VariableKind::Slice { .. }
397 | VariableKind::Matrix { .. } => elem,
398 }
399}
400
401impl Init for ExpandElement {
402 fn init(self, context: &mut CubeContext) -> Self {
403 init_expand_element(context, self)
404 }
405}
406
407impl<T: Init> Init for Option<T> {
408 fn init(self, context: &mut CubeContext) -> Self {
409 self.map(|o| Init::init(o, context))
410 }
411}
412
413impl<T: CubeType> CubeType for Vec<T> {
414 type ExpandType = Vec<T::ExpandType>;
415}
416
417impl<T: CubeType> CubeType for &mut Vec<T> {
418 type ExpandType = Vec<T::ExpandType>;
419}
420
421impl<T: Init> Init for Vec<T> {
422 fn init(self, context: &mut CubeContext) -> Self {
423 self.into_iter().map(|e| e.init(context)).collect()
424 }
425}
426
427pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
429 context: &mut CubeContext,
430 val: C,
431) -> ExpandElementTyped<Out> {
432 let input: ExpandElementTyped<C> = val.into();
433 <Out as super::Cast>::__expand_cast_from(context, input)
434}