Skip to main content

cubecl_core/frontend/element/
cube_elem.rs

1use core::fmt::Debug;
2
3use crate::{
4    self as cubecl, Assign, IntoRuntime,
5    prelude::{Const, CubeDebug, IntoMut, Size},
6};
7use cubecl_ir::{ConstantValue, ManagedVariable, StorageType, Type, features::TypeUsage};
8use cubecl_macros::{comptime_type, cube, intrinsic};
9use cubecl_runtime::{client::ComputeClient, runtime::Runtime};
10use enumset::EnumSet;
11
12use crate::frontend::CubeType;
13use crate::ir::Scope;
14
15use super::{NativeAssign, NativeExpand};
16
17/// Form of `CubeType` that encapsulates all primitive types:
18/// Numeric, `UInt`, Bool
19pub trait CubePrimitive:
20    CubeType<ExpandType = NativeExpand<Self>>
21    + NativeAssign
22    // + IntoRuntime
23    + core::cmp::PartialEq
24    + Send
25    + Sync
26    + 'static
27    + Clone
28    + Copy
29{
30    type Scalar: Scalar;
31    type Size: Size;
32    type WithScalar<S: Scalar>: CubePrimitive;
33
34    /// Return the element type to use on GPU.
35    fn as_type(_scope: &Scope) -> Type {
36        Self::as_type_native().expect("To be overridden if not native")
37    }
38
39    /// Native or static element type.
40    fn as_type_native() -> Option<Type> {
41        None
42    }
43
44    /// Native or static element type.
45    fn as_type_native_unchecked() -> Type {
46        Self::as_type_native().expect("To be a native type")
47    }
48
49    /// Only native element types have a size.
50    fn size() -> Option<usize> {
51        Self::as_type_native().map(|t| t.size())
52    }
53
54    /// Only native element types have a size.
55    fn size_bits() -> Option<usize> {
56        Self::as_type_native().map(|t| t.size_bits())
57    }
58
59    /// Only native element types have a size.
60    fn size_bits_unchecked() -> usize {
61        Self::as_type_native_unchecked().size_bits()
62    }
63
64    fn from_expand_elem(elem: ManagedVariable) -> Self::ExpandType {
65        NativeExpand::new(elem)
66    }
67
68    fn from_const_value(value: ConstantValue) -> Self;
69
70    fn into_lit_unchecked(self) -> Self {
71        self
72    }
73
74    fn supported_uses<R: Runtime>(
75        client: &ComputeClient<R>,
76    ) -> EnumSet<TypeUsage> {
77        let elem = Self::as_type_native_unchecked();
78        client.features().type_usage(elem.storage_type())
79    }
80
81    fn type_size() -> usize {
82        Self::as_type_native_unchecked().size()
83    }
84
85    fn type_size_bits() -> usize {
86        Self::as_type_native_unchecked().size_bits()
87    }
88
89    fn packing_factor() -> usize {
90        Self::as_type_native_unchecked().packing_factor()
91    }
92
93    fn vector_size() -> usize {
94        Self::as_type_native_unchecked().vector_size()
95    }
96
97    fn __expand_type_size(scope: &Scope) -> usize {
98        Self::as_type(scope).size()
99    }
100
101    fn __expand_type_size_bits(scope: &Scope) -> usize {
102        Self::as_type(scope).size_bits()
103    }
104
105    fn __expand_packing_factor(scope: &Scope) -> usize {
106        Self::as_type(scope).packing_factor()
107    }
108
109    fn __expand_vector_size(scope: &Scope) -> usize {
110        Self::as_type(scope).vector_size()
111    }
112}
113
114pub trait CubePrimitiveExpand {
115    type Scalar: Clone + IntoMut + CubeDebug + Assign;
116    type WithScalar<S: Scalar>: Clone + IntoMut + CubeDebug + Assign;
117}
118
119impl<T: CubePrimitive> CubePrimitiveExpand for NativeExpand<T> {
120    type Scalar = NativeExpand<T::Scalar>;
121    type WithScalar<S: Scalar> = NativeExpand<T::WithScalar<S>>;
122}
123
124/// Marker trait for scalar primitives. Should be implemented for all scalar `CubePrimitive`s, but
125/// **not** for `Vector` or non-standard primitives like `Barrier`. Alternatively, treat these as
126/// types that can be stored in a [`Vector`]
127pub trait Scalar:
128    CubePrimitive<Scalar = Self, Size = Const<1>> + Default + IntoRuntime + Debug
129{
130}
131
132#[cube]
133pub fn type_of<E: CubePrimitive>() -> comptime_type!(Type) {
134    intrinsic!(|scope| { E::as_type(scope) })
135}
136
137#[cube]
138pub fn storage_type_of<E: CubePrimitive>() -> comptime_type!(StorageType) {
139    intrinsic!(|scope| { E::as_type(scope).storage_type() })
140}