use crate::{AddressType, SemanticType, StorageType, Type};
use alloc::collections::{BTreeMap, BTreeSet};
use enumset::EnumSetType;
pub use enumset::EnumSet;
/// Features supported by a runtime
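///
/// # Example
///
/// A minimal sketch of querying features; `runtime_features()` and `ty` are placeholders
/// for however a concrete runtime exposes its `Features` and for a `StorageType` of
/// interest:
///
/// ```ignore
/// let features: Features = runtime_features();
/// if features.cube_cluster {
///     // clustered launches and intra-cluster operations are available
/// }
/// // Which usages does `ty` support on this runtime?
/// let storable = features.type_usage(ty).contains(TypeUsage::Buffer);
/// ```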
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Features {
/// Plane features supported by this runtime.
pub plane: EnumSet<Plane>,
/// Clustered launches and intra-cluster operations like cluster shared memory
pub cube_cluster: bool,
/// Enables changing the type of containers during kernel execution.
pub memory_reinterpret: bool,
/// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
pub alignment: bool,
/// Type support
pub types: Types,
/// Matrix multiplication features
pub matmul: MatmulFeatures,
/// Whether `copy_async` is supported
pub copy_async: bool,
/// Tensor Memory Accelerator supported features
pub tma: EnumSet<Tma>,
/// Whether vectors can be read from / stored to addresses not aligned
/// with the `vector_size`
pub unaligned_io: bool,
}
/// Type support for a device
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Types {
/// Valid address types
pub address: BTreeSet<AddressType>,
/// Types supported by this runtime, and which usages they support.
pub storage: BTreeMap<StorageType, EnumSet<TypeUsage>>,
/// Semantic constructs supported by this runtime.
pub semantic: BTreeSet<SemanticType>,
    /// Types supported for atomic operations, and which usages they support. Only the exact
    /// vectorizations registered here are valid; support for a vector type does not imply
    /// support for the scalar (e.g. Vulkan on Nvidia only supports vectorized `f16` atomics,
    /// not scalar). Support may also depend on the memory space - in practice, `f32` vectors
    /// are only supported in global memory.
pub atomic: BTreeMap<Type, EnumSet<AtomicUsage>>,
}
/// Matrix multiplication-related features
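///
/// # Example
///
/// A sketch of checking whether a specific shape/type combination is registered for CMMA
/// or manual MMA; `features` and `config` are assumed to be a `Features` value obtained
/// from the runtime and an `MmaConfig` built elsewhere:
///
/// ```ignore
/// let cmma_ok = features.matmul.cmma.contains(&config);
/// let mma_ok = features.matmul.mma.contains(&config);
/// ```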
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct MatmulFeatures {
/// The cmma feature enables cooperative matrix-multiply and accumulate operations.
pub cmma: BTreeSet<MmaConfig>,
/// The manual MMA feature enables cooperative matrix-multiply with manually managed data
/// movement
pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA combines matrix multiplication with unscaling (dequantizing) quantized values
    /// in a single instruction. Scales must follow a specific layout and block size.
pub scaled_mma: BTreeSet<ScaledMmaConfig>,
/// Types supported for ldmatrix, if any
pub ldmatrix: BTreeSet<StorageType>,
/// Types supported by stmatrix, if any
pub stmatrix: BTreeSet<StorageType>,
}
/// Operations allowed for this type. CMMA is defined separately.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
/// Conversion to/from the type. All types should support this.
Conversion,
/// All math/logic instructions except dot product
Arithmetic,
/// Dot product, mainly for BF16 on Intel
DotProduct,
/// Whether this type can be stored in a buffer
Buffer,
}
impl TypeUsage {
    /// All usages, including buffer storage.
    pub fn all() -> EnumSet<Self> {
        EnumSet::all()
    }
    /// Conversion and arithmetic only, without dot product or buffer storage.
    pub fn no_store() -> EnumSet<Self> {
        TypeUsage::Conversion | TypeUsage::Arithmetic
    }
    /// All usages if `storable` is true, otherwise [`Self::no_store`].
    pub fn maybe_store(storable: bool) -> EnumSet<Self> {
        if storable {
            EnumSet::all()
        } else {
            Self::no_store()
        }
    }
}
/// Atomic operations allowed for this type.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum AtomicUsage {
/// Atomic loads and stores
LoadStore,
/// Atomic add/sub
Add,
/// Atomic min/max
MinMax,
}
impl AtomicUsage {
    /// All atomic usages.
    pub fn all() -> EnumSet<Self> {
        EnumSet::all()
    }
}
/// Supported plane features
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
/// Basic plane-wide operations
Ops,
/// Plane-wide sync
Sync,
/// Allows using plane operations with divergent control flow.
NonUniformControlFlow,
}
/// Shape and element types of a valid MMA configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
/// The size of the matrix on the `m` dimension
pub m: u32,
/// The size of the matrix on the `n` dimension
pub n: u32,
/// The size of the matrix on the `k` dimension
pub k: u32,
}
/// Shape and element types of a valid block-scaled MMA configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix
    pub a_type: StorageType,
    /// Element type of the B matrix
    pub b_type: StorageType,
    /// Element type of the C/D matrices
    pub cd_type: StorageType,
    /// Element type of the block scales
    pub scales_type: StorageType,
/// The size of the matrix on the `m` dimension
pub m: u32,
/// The size of the matrix on the `n` dimension
pub n: u32,
/// The size of the matrix on the `k` dimension
pub k: u32,
    /// Number of scales per tile row/column.
    /// For example, a scales factor of 2 means `m x 2` scales for A and `2 x n` scales for B
    /// (in CUDA). Scale blocks must be organized along the natural `vector_layout` of the
    /// operation.
    pub scales_factor: u32,
}
/// Tensor Memory Accelerator (TMA) features that may be supported by a `Runtime`.
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base tensor memory accelerator feature set. Includes tiling and im2col.
Base,
/// im2colWide encoding for tensor map.
Im2colWide,
/// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
SwizzleAtomicity,
}
impl Features {
/// Get the usages for a type
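    ///
    /// Returns an empty set if the type is not supported at all. A sketch, assuming
    /// `features: Features` and `ty: StorageType` are in scope:
    ///
    /// ```ignore
    /// if features.type_usage(ty).contains(TypeUsage::Arithmetic) {
    ///     // math/logic instructions can operate on `ty` directly
    /// }
    /// ```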
pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
self.types
.storage
.get(&ty)
.cloned()
.unwrap_or_else(EnumSet::empty)
}
/// Get the usages for an atomic type
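    ///
    /// Returns an empty set if the type has no atomic support. Only the exact vectorization
    /// registered in [`Types::atomic`] counts. A sketch, assuming `features: Features` and
    /// `ty: Type` are in scope:
    ///
    /// ```ignore
    /// if features.atomic_type_usage(ty).contains(AtomicUsage::Add) {
    ///     // atomic add/sub is available for exactly this vectorization
    /// }
    /// ```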
pub fn atomic_type_usage(&self, ty: Type) -> EnumSet<AtomicUsage> {
self.types
.atomic
.get(&ty)
.cloned()
.unwrap_or_else(EnumSet::empty)
}
/// Whether the type is supported in any way
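    ///
    /// Vectors are checked by their element's storage type. A sketch, assuming
    /// `features: Features` and `storage_type: StorageType` are in scope:
    ///
    /// ```ignore
    /// if !features.supports_type(storage_type) {
    ///     // fall back to a wider type or report an unsupported-type error
    /// }
    /// ```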
pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
match ty.into() {
Type::Scalar(storage_type) | Type::Vector(storage_type, _) => {
self.types.storage.contains_key(&storage_type)
}
Type::Semantic(semantic_type) => self.types.semantic.contains(&semantic_type),
}
}
/// Whether the address type is supported in any way
pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
self.types.address.contains(&ty.into())
}
}