1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use alloc::collections::{BTreeMap, BTreeSet};
use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;
pub use enumset::EnumSet;
/// Features supported by a runtime.
///
/// Runtimes fill this in to advertise which hardware/backend capabilities are
/// available; consumers query it (e.g. via [`Features::supports_type`]) before
/// emitting code that relies on an optional feature.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables to change the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,
    /// Tensor Memory Accelerator supported features.
    pub tma: EnumSet<Tma>,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    /// Each entry is one valid shape/type combination.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a single
    /// instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
}
/// Operations allowed for this type. CMMA is defined separately.
///
/// NOTE: variant order is significant — `PartialOrd`/`Ord` derive from it,
/// so do not reorder variants.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
    /// Atomic loads and stores.
    AtomicLoadStore,
    /// Atomic add/sub.
    AtomicAdd,
    /// Atomic min/max.
    AtomicMinMax,
}
/// Supported plane (subgroup/warp) features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
}
/// Shape and element types of a valid MMA configuration.
///
/// Stored in ordered sets ([`Features::cmma`], [`Features::mma`]); the derived
/// `Ord` makes the ordering deterministic across runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}
/// Shape and element types of a valid block-scaled MMA configuration.
///
/// Like [`MmaConfig`] but with an additional per-block scales operand used to
/// dequantize inputs inside the MMA instruction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// Element of the blocks scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scales blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}
/// Tensor Memory Accelerator (TMA) features that may be supported by a
/// [cube runtime](Runtime).
///
/// NOTE(review): the previous header said "Atomic features", which appears to
/// be a copy-paste error — the variants describe TMA tensor-map encodings.
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor map.
    Im2colWide,
}
impl Features {
    /// Look up the set of usages supported for the given storage type.
    ///
    /// Returns the empty set when the runtime does not list the type at all,
    /// so callers can treat "unknown type" and "type with no usages" uniformly.
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        match self.storage_types.get(&ty) {
            Some(usages) => usages.clone(),
            None => EnumSet::empty(),
        }
    }

    /// Whether the type is supported in any way.
    ///
    /// Scalar and line (vectorized) types are checked against the storage-type
    /// table; semantic types are checked against the semantic-type set. Line
    /// width is ignored — only the element type matters here.
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        let ty = ty.into();
        match ty {
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
        }
    }
}
impl TypeUsage {
/// All uses except atomics
pub fn all_scalar() -> EnumSet<TypeUsage> {
TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
}
/// All atomic uses
pub fn all_atomic() -> EnumSet<TypeUsage> {
TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
}
}