// cubecl_runtime/features.rs

use alloc::collections::{BTreeMap, BTreeSet};

use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime
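///
/// # Example
///
/// A minimal sketch of assembling a feature set by hand; real runtimes fill
/// this in from device and compiler queries, so the flags below are purely
/// illustrative:
///
/// ```ignore
/// let mut features = Features::default();
/// features.plane |= Plane::Ops | Plane::Sync;
/// features.dynamic_line_size = true;
/// assert!(features.plane.contains(Plane::Sync));
/// ```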
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Whether `copy_async` is supported
    pub copy_async: bool,
    /// Supported Tensor Memory Accelerator (TMA) features
    pub tma: EnumSet<Tma>,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA combines matrix multiplication with the unscaling of quantized values in a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported by `ldmatrix`, if any
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any
    pub stmatrix: BTreeSet<StorageType>,
    /// Whether `Line`s can be read from / stored to addresses not aligned
    /// with the `line_size`
    pub unaligned_io: bool,
}

/// Operations allowed for this type. CMMA is defined separately.
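///
/// # Example
///
/// A minimal sketch of inspecting the usages of a type, with `features` an
/// illustrative [`Features`] value and `ty` a `StorageType` obtained from the
/// compiler:
///
/// ```ignore
/// let usages = features.type_usage(ty);
/// let can_compute = usages.contains(TypeUsage::Arithmetic);
/// let has_atomics = !(usages & TypeUsage::all_atomic()).is_empty();
/// ```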
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel
    DotProduct,
    /// Whether this type can be stored in a buffer
    Buffer,
    /// Atomic loads and stores
    AtomicLoadStore,
    /// Atomic add/sub
    AtomicAdd,
    /// Atomic min/max
    AtomicMinMax,
}

/// Supported plane features
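///
/// # Example
///
/// A minimal sketch of gating plane operations on divergent control flow
/// support, with `features` an illustrative [`Features`] value:
///
/// ```ignore
/// if features.plane.contains(Plane::NonUniformControlFlow) {
///     // Plane ops remain valid inside divergent branches.
/// }
/// ```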
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations
    Ops,
    /// Plane-wide sync
    Sync,
    /// Allows using plane operations with divergent control flow.
    NonUniformControlFlow,
}

/// Shape and element types of a valid MMA configuration
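///
/// # Example
///
/// A minimal sketch of probing support for a shape, with `features` an
/// illustrative [`Features`] value and `a`, `b`, `cd` the desired storage
/// types (the 16x16x16 shape is illustrative):
///
/// ```ignore
/// let wanted = MmaConfig { a_type: a, b_type: b, cd_type: cd, m: 16, n: 16, k: 16 };
/// let supported = features.cmma.contains(&wanted);
/// ```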
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration
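///
/// # Example
///
/// A minimal sketch of the scale-matrix shapes implied by `scales_factor`,
/// following the CUDA layout described on that field (`cfg` is an
/// illustrative config):
///
/// ```ignore
/// let a_scales_shape = (cfg.m, cfg.scales_factor); // `m x scales_factor` scales for A
/// let b_scales_shape = (cfg.scales_factor, cfg.n); // `scales_factor x n` scales for B
/// ```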
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// Element of the block scales
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
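///
/// # Example
///
/// A minimal sketch of gating TMA usage, with `features` an illustrative
/// [`Features`] value:
///
/// ```ignore
/// let tma_ok = features.tma.contains(Tma::Base);
/// let wide_im2col_ok = tma_ok && features.tma.contains(Tma::Im2colWide);
/// ```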
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for the tensor memory accelerator. Includes tiling and im2col
    Base,
    /// im2colWide encoding for tensor map.
    Im2colWide,
    /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
    SwizzleAtomicity,
}

impl Features {
    /// Get the usages for a type
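    ///
    /// Missing types report an empty usage set rather than an error.
    ///
    /// # Example
    ///
    /// A minimal sketch, with `ty` an illustrative `StorageType`:
    ///
    /// ```ignore
    /// if features.type_usage(ty).contains(TypeUsage::AtomicAdd) {
    ///     // Atomic add is available for `ty`.
    /// }
    /// ```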
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types
            .get(&ty)
            .cloned()
            .unwrap_or_else(EnumSet::empty)
    }

    /// Whether the type is supported in any way
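    ///
    /// # Example
    ///
    /// A minimal sketch; per the implementation below, scalar and line variants
    /// of the same storage type report identical support (`ty` is an
    /// illustrative `StorageType`):
    ///
    /// ```ignore
    /// let scalar_ok = features.supports_type(Type::Scalar(ty));
    /// let line_ok = features.supports_type(Type::Line(ty, 4)); // 4 = illustrative line size
    /// assert_eq!(scalar_ok, line_ok);
    /// ```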
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }
}

impl TypeUsage {
    /// All uses except atomics
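    ///
    /// # Example
    ///
    /// A minimal sketch of registering a fully supported type by combining both
    /// helper sets (`ty` is an illustrative `StorageType`):
    ///
    /// ```ignore
    /// features
    ///     .storage_types
    ///     .insert(ty, TypeUsage::all_scalar() | TypeUsage::all_atomic());
    /// ```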
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}