//! cubecl_ir/features.rs
1use crate::{AddressType, SemanticType, StorageType, Type};
2use alloc::collections::{BTreeMap, BTreeSet};
3
4use enumset::EnumSetType;
5
6pub use enumset::EnumSet;
7
/// Features supported by a runtime.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Features {
    /// Plane (subgroup/warp) features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the type of containers during kernel execution.
    pub memory_reinterpret: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Type support (addresses, storage types, semantic types, atomics).
    pub types: Types,
    /// Matrix multiplication features (CMMA, manual MMA, scaled MMA, ld/st matrix).
    pub matmul: MatmulFeatures,

    /// Whether `copy_async` is supported.
    pub copy_async: bool,
    /// Tensor Memory Accelerator supported features.
    pub tma: EnumSet<Tma>,
    /// Whether vectors can be read from / stored to addresses not aligned
    /// with the `vector_size`.
    pub unaligned_io: bool,
}
33
/// Type support for a device.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Types {
    /// Valid address types.
    pub address: BTreeSet<AddressType>,
    /// Types supported by this runtime, and which usages they support.
    pub storage: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic: BTreeSet<SemanticType>,
    /// Supported types for atomic ops. Only the specific vectorizations registered here are
    /// supported for a given type; not every vector type is also supported as a scalar
    /// (e.g. Vulkan on Nvidia only supports vectorized `f16`, not scalar). Use only the exact
    /// vectorizations registered here. These may not be supported everywhere - in practice,
    /// f32 vectors are only supported in global memory.
    pub atomic: BTreeMap<Type, EnumSet<AtomicUsage>>,
}
50
/// Matrix multiplication-related features.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct MatmulFeatures {
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported for `ldmatrix`, if any.
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any.
    pub stmatrix: BTreeSet<StorageType>,
}
67
/// Operations allowed for this type. CMMA is defined separately.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
}
80
81impl TypeUsage {
82 pub fn all() -> EnumSet<Self> {
83 EnumSet::all()
84 }
85
86 pub fn no_store() -> EnumSet<Self> {
87 TypeUsage::Conversion | TypeUsage::Arithmetic
88 }
89
90 pub fn maybe_store(storable: bool) -> EnumSet<Self> {
91 if storable {
92 EnumSet::all()
93 } else {
94 Self::no_store()
95 }
96 }
97}
98
/// Atomic operations allowed for this type.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum AtomicUsage {
    /// Atomic loads and stores.
    LoadStore,
    /// Atomic add/sub.
    Add,
    /// Atomic min/max.
    MinMax,
}
109
110impl AtomicUsage {
111 pub fn all() -> EnumSet<Self> {
112 EnumSet::all()
113 }
114}
115
/// Supported plane (subgroup/warp) features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
    /// Allows using plane operations with divergent control flow.
    NonUniformControlFlow,
}
126
/// Shape and element types of a valid MMA configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}
144
/// Shape and element types of a valid block-scaled MMA configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// Element type of the block scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `vector_layout` of the operation.
    pub scales_factor: u32,
}
168
/// Tensor Memory Accelerator (TMA) features that may be supported by a runtime.
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// `im2colWide` encoding for tensor map.
    Im2colWide,
    /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
    SwizzleAtomicity,
}
179
180impl Features {
181 /// Get the usages for a type
182 pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
183 self.types
184 .storage
185 .get(&ty)
186 .cloned()
187 .unwrap_or_else(EnumSet::empty)
188 }
189
190 /// Get the usages for an atomic type
191 pub fn atomic_type_usage(&self, ty: Type) -> EnumSet<AtomicUsage> {
192 self.types
193 .atomic
194 .get(&ty)
195 .cloned()
196 .unwrap_or_else(EnumSet::empty)
197 }
198
199 /// Whether the type is supported in any way
200 pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
201 match ty.into() {
202 Type::Scalar(storage_type) | Type::Vector(storage_type, _) => {
203 self.types.storage.contains_key(&storage_type)
204 }
205 Type::Semantic(semantic_type) => self.types.semantic.contains(&semantic_type),
206 }
207 }
208
209 /// Whether the address type is supported in any way
210 pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
211 self.types.address.contains(&ty.into())
212 }
213}