cubecl_runtime/features.rs
use alloc::collections::{BTreeMap, BTreeSet};

use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime
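///
/// A minimal sketch of how a backend might advertise and query its capabilities.
/// `some_storage_type` stands in for a concrete [`StorageType`] value from
/// `cubecl_ir`, so the snippet is illustrative rather than a compilable doctest:
///
/// ```rust,ignore
/// let mut features = Features::default();
/// features.plane = Plane::Ops | Plane::Sync;
/// features
///     .storage_types
///     .insert(some_storage_type, TypeUsage::all_scalar());
///
/// assert!(features.plane.contains(Plane::Sync));
/// assert!(features.type_usage(some_storage_type).contains(TypeUsage::Arithmetic));
/// ```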
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Whether `copy_async` is supported
    pub copy_async: bool,
    /// Supported Tensor Memory Accelerator (TMA) features
    pub tma: EnumSet<Tma>,
    /// The `cmma` feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a single
    /// instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported by `ldmatrix`, if any
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any
    pub stmatrix: BTreeSet<StorageType>,
}

/// Operations allowed for this type. CMMA is defined separately.
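///
/// Usage sets combine with the usual [`EnumSet`] operators; a minimal,
/// illustrative example:
///
/// ```rust,ignore
/// let usages = TypeUsage::Conversion | TypeUsage::Buffer;
/// assert!(usages.contains(TypeUsage::Buffer));
/// assert!(!usages.contains(TypeUsage::AtomicAdd));
/// ```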
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel
    DotProduct,
    /// Whether this type can be stored in a buffer
    Buffer,
    /// Atomic loads and stores
    AtomicLoadStore,
    /// Atomic add/sub
    AtomicAdd,
    /// Atomic min/max
    AtomicMinMax,
}

/// Supported plane features
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations
    Ops,
    /// Plane-wide sync
    Sync,
}

/// Shape and element types of a valid MMA configuration
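///
/// A hedged sketch of registering one supported shape; `half_type` and
/// `float_type` are placeholders for concrete [`StorageType`] values from
/// `cubecl_ir`, so this is illustrative rather than a compilable doctest:
///
/// ```rust,ignore
/// features.cmma.insert(MmaConfig {
///     a_type: half_type,
///     b_type: half_type,
///     cd_type: float_type,
///     m: 16,
///     n: 16,
///     k: 16,
/// });
/// ```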
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// Element of the block scales
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col
    Base,
    /// im2colWide encoding for tensor map.
    Im2colWide,
    /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
    SwizzleAtomicity,
}

impl Features {
    /// Get the usages for a type
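    ///
    /// Returns an empty set when the type is not listed in `storage_types`.
    /// A hedged sketch (`some_storage_type` is a placeholder for a concrete
    /// [`StorageType`] value):
    ///
    /// ```rust,ignore
    /// if features.type_usage(some_storage_type).contains(TypeUsage::AtomicAdd) {
    ///     // the runtime can lower atomic adds for this type
    /// }
    /// ```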
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types
            .get(&ty)
            .cloned()
            .unwrap_or_else(EnumSet::empty)
    }

    /// Whether the type is supported in any way
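    ///
    /// Scalar and line types are checked against `storage_types`, semantic types
    /// against `semantic_types`. A hedged sketch (`some_type` is a placeholder for
    /// any value convertible into [`Type`]):
    ///
    /// ```rust,ignore
    /// if features.supports_type(some_type) {
    ///     // the type can be used in kernels on this runtime
    /// }
    /// ```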
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }
}

impl TypeUsage {
    /// All uses except atomics
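    ///
    /// Combined with [`TypeUsage::all_atomic`], this covers every variant; an
    /// illustrative check:
    ///
    /// ```rust,ignore
    /// let full = TypeUsage::all_scalar() | TypeUsage::all_atomic();
    /// assert_eq!(full, EnumSet::all());
    /// ```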
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}