cubecl_runtime/features.rs
use alloc::collections::{BTreeMap, BTreeSet};

use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime
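///
/// # Example
///
/// A minimal sketch of how a runtime might advertise its capabilities. The
/// `some_storage_type` value is a hypothetical placeholder; real values come
/// from `cubecl_ir::StorageType`.
///
/// ```ignore
/// let mut features = Features::default();
/// features.plane = Plane::Ops | Plane::Sync;
/// features.dynamic_line_size = true;
/// features
///     .storage_types
///     .insert(some_storage_type, TypeUsage::all_scalar());
/// ```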
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Supported Tensor Memory Accelerator (TMA) features.
    pub tma: EnumSet<Tma>,
    /// The CMMA feature enables cooperative matrix-multiply-and-accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA combines matrix multiplication with unscaling of quantized values in a single
    /// instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported by `ldmatrix`, if any.
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any.
    pub stmatrix: BTreeSet<StorageType>,
}

/// Operations allowed for this type. CMMA is defined separately.
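///
/// # Example
///
/// Usage flags combine into an [`EnumSet`]; a sketch of describing a type that
/// supports conversion and buffer storage but no arithmetic:
///
/// ```ignore
/// let usage: EnumSet<TypeUsage> = TypeUsage::Conversion | TypeUsage::Buffer;
/// assert!(usage.contains(TypeUsage::Buffer));
/// assert!(!usage.contains(TypeUsage::Arithmetic));
/// ```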
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel
    DotProduct,
    /// The type can be stored in a buffer
    Buffer,
    /// Atomic loads and stores
    AtomicLoadStore,
    /// Atomic add/sub
    AtomicAdd,
    /// Atomic min/max
    AtomicMinMax,
}

/// Supported plane features
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations
    Ops,
    /// Plane-wide sync
    Sync,
}

/// Shape and element types of a valid MMA configuration
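///
/// # Example
///
/// A sketch of registering a 16x16x16 configuration; the shape is illustrative
/// only, and `half_type`/`float_type` are hypothetical `StorageType` values.
///
/// ```ignore
/// let config = MmaConfig {
///     a_type: half_type,
///     b_type: half_type,
///     cd_type: float_type,
///     m: 16,
///     n: 16,
///     k: 16,
/// };
/// features.cmma.insert(config);
/// ```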
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration
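///
/// # Example
///
/// A sketch of a block-scaled configuration; all element types and shape
/// values below are illustrative placeholders, not a real hardware shape.
///
/// ```ignore
/// let config = ScaledMmaConfig {
///     a_type: quantized_type,
///     b_type: quantized_type,
///     cd_type: accumulator_type,
///     scales_type: scale_type,
///     m: 16,
///     n: 16,
///     k: 64,
///     // With a factor of 2, A carries `m x 2` scales and B carries `2 x n`.
///     scales_factor: 2,
/// };
/// features.scaled_mma.insert(config);
/// ```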
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
    /// Element type of the block scales
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base tensor memory accelerator feature set. Includes tiling and im2col.
    Base,
    /// `im2colWide` encoding for tensor maps.
    Im2colWide,
}

impl Features {
    /// Get the usages for a type
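    ///
    /// # Example
    ///
    /// A sketch, assuming `ty` is a `StorageType` obtained elsewhere:
    ///
    /// ```ignore
    /// let usage = features.type_usage(ty);
    /// if usage.contains(TypeUsage::AtomicAdd) {
    ///     // The type supports atomic additions.
    /// }
    /// ```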
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types
            .get(&ty)
            .cloned()
            .unwrap_or_else(EnumSet::empty)
    }

    /// Whether the type is supported in any way
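    ///
    /// # Example
    ///
    /// A sketch; both scalar and line types resolve to their underlying storage type:
    ///
    /// ```ignore
    /// if features.supports_type(ty) {
    ///     // Safe to use `ty` in a kernel for this runtime.
    /// }
    /// ```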
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }
}

impl TypeUsage {
    /// All uses except atomics
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}
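
// Sanity-check sketch: the scalar and atomic usage helpers should partition
// the full `TypeUsage` set between them.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn scalar_and_atomic_usages_partition_all_usages() {
        let scalar = TypeUsage::all_scalar();
        let atomic = TypeUsage::all_atomic();

        // The two helper sets do not overlap...
        assert!(scalar.is_disjoint(atomic));
        // ...and together they cover every `TypeUsage` variant.
        assert_eq!(scalar | atomic, EnumSet::all());
    }
}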