cubecl_runtime/features.rs
use alloc::collections::{BTreeMap, BTreeSet};

use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime
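///
/// # Example
///
/// An illustrative sketch of a capability check; the `features` value would come
/// from a concrete runtime, and the helper name is hypothetical:
///
/// ```ignore
/// use cubecl_runtime::features::{Features, Plane};
///
/// // Hypothetical helper: clustered launches plus basic plane ops.
/// fn can_use_clusters(features: &Features) -> bool {
///     features.cube_cluster && features.plane.contains(Plane::Ops)
/// }
/// ```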
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Tensor Memory Accelerator features supported by this runtime.
    pub tma: EnumSet<Tma>,
    /// The CMMA feature enables cooperative matrix-multiply-accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix multiplication with manually managed
    /// data movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA combines matrix multiplication with unscaling of quantized values in a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
}

/// Operations allowed for this type. CMMA is defined separately.
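///
/// Usages form an [`EnumSet`], so they can be combined and queried with set
/// operations. An illustrative sketch, assuming this module is reachable at
/// `cubecl_runtime::features`:
///
/// ```ignore
/// use cubecl_runtime::features::TypeUsage;
///
/// let usages = TypeUsage::Conversion | TypeUsage::Buffer;
/// assert!(usages.contains(TypeUsage::Buffer));
/// assert!(!usages.contains(TypeUsage::AtomicAdd));
/// ```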
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
    /// Atomic loads and stores.
    AtomicLoadStore,
    /// Atomic add/sub.
    AtomicAdd,
    /// Atomic min/max.
    AtomicMinMax,
}

/// Supported plane features
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations
    Ops,
    /// Plane-wide sync
    Sync,
}

/// Shape and element types of a valid MMA configuration
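///
/// An illustrative sketch of querying the supported configurations; the helper
/// below is hypothetical, not part of this API:
///
/// ```ignore
/// use cubecl_runtime::features::Features;
///
/// // Hypothetical helper: is any CMMA config available with shape `m x n x k`?
/// fn has_cmma_shape(features: &Features, m: u32, n: u32, k: u32) -> bool {
///     features.cmma.iter().any(|c| (c.m, c.n, c.k) == (m, n, k))
/// }
/// ```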
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
    /// The size of the matrix along the `m` dimension
    pub m: u32,
    /// The size of the matrix along the `n` dimension
    pub n: u32,
    /// The size of the matrix along the `k` dimension
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration
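///
/// An illustrative sketch; the helper below is hypothetical:
///
/// ```ignore
/// use cubecl_runtime::features::{Features, ScaledMmaConfig};
///
/// // Hypothetical helper: find a block-scaled config with the given scale factor.
/// fn with_scales_factor(features: &Features, factor: u32) -> Option<&ScaledMmaConfig> {
///     features.scaled_mma.iter().find(|c| c.scales_factor == factor)
/// }
/// ```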
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
    /// Element type of the block scales
    pub scales_type: StorageType,
    /// The size of the matrix along the `m` dimension
    pub m: u32,
    /// The size of the matrix along the `n` dimension
    pub n: u32,
    /// The size of the matrix along the `k` dimension
    pub k: u32,
    /// Number of scales per tile row/column.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` scales for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a
/// [cube runtime](Runtime).
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for the tensor memory accelerator. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor maps.
    Im2colWide,
}

impl Features {
    /// Get the usages for a type
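    ///
    /// An illustrative sketch (`StorageType` comes from `cubecl_ir`; the helper
    /// is hypothetical):
    ///
    /// ```ignore
    /// use cubecl_ir::StorageType;
    /// use cubecl_runtime::features::{Features, TypeUsage};
    ///
    /// fn supports_atomic_add(features: &Features, ty: StorageType) -> bool {
    ///     features.type_usage(ty).contains(TypeUsage::AtomicAdd)
    /// }
    /// ```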
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types
            .get(&ty)
            .cloned()
            .unwrap_or_else(EnumSet::empty)
    }

    /// Whether the type is supported in any way
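    ///
    /// An illustrative sketch of a hypothetical guard around a kernel launch:
    ///
    /// ```ignore
    /// use cubecl_ir::Type;
    /// use cubecl_runtime::features::Features;
    ///
    /// fn assert_supported(features: &Features, ty: impl Into<Type>) {
    ///     assert!(features.supports_type(ty), "type not supported by this runtime");
    /// }
    /// ```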
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }
}

impl TypeUsage {
    /// All uses except atomics
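    ///
    /// An illustrative sketch:
    ///
    /// ```ignore
    /// use cubecl_runtime::features::TypeUsage;
    ///
    /// let scalar = TypeUsage::all_scalar();
    /// assert!(scalar.contains(TypeUsage::Arithmetic));
    /// assert!(!scalar.contains(TypeUsage::AtomicAdd));
    /// ```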
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}