// cubecl_runtime/features.rs

1use alloc::collections::{BTreeMap, BTreeSet};
2
3use cubecl_ir::{SemanticType, StorageType, Type};
4use enumset::EnumSetType;
5
6pub use enumset::EnumSet;
7
/// Features supported by a runtime.
///
/// Populated by each runtime at initialization so kernels and compilers can
/// query what the current device/backend is capable of.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Types supported by this runtime, and which usages they support.
    /// A type absent from this map is not supported at all.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Tensor Memory Accelerator (TMA) supported features.
    pub tma: EnumSet<Tma>,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported for `ldmatrix`, if any (empty set means unsupported).
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any (empty set means unsupported).
    pub stmatrix: BTreeSet<StorageType>,
}
40
/// Operations allowed for this type. CMMA is defined separately.
///
/// Stored per [`StorageType`] in [`Features::storage_types`]; grouped helpers
/// for the scalar vs. atomic subsets live in the `impl TypeUsage` block.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
    /// Atomic loads and stores.
    AtomicLoadStore,
    /// Atomic add/sub.
    AtomicAdd,
    /// Atomic min/max.
    AtomicMinMax,
}
59
/// Supported plane features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
}
68
/// Shape and element types of a valid MMA configuration.
///
/// Derives `Ord` so configurations can be kept in the `BTreeSet`s of
/// [`Features`] (`cmma` / `mma`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}
85
/// Shape and element types of a valid block-scaled MMA configuration.
///
/// Like [`MmaConfig`] but with an additional per-block scales type and
/// scale-factor layout; kept in [`Features::scaled_mma`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// Element of the blocks scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scales blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}
108
/// Tensor Memory Accelerator (TMA) features that may be supported by a
/// [cube runtime](Runtime).
// NOTE(review): the previous header said "Atomic features" — a copy-paste
// leftover; this enum describes TMA capabilities (see `Features::tma`).
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// `im2colWide` encoding for tensor map.
    Im2colWide,
}
117
118impl Features {
119    /// Get the usages for a type
120    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
121        self.storage_types
122            .get(&ty)
123            .cloned()
124            .unwrap_or_else(EnumSet::empty)
125    }
126
127    /// Whether the type is supported in any way
128    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
129        match ty.into() {
130            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
131                self.storage_types.contains_key(&storage_type)
132            }
133            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
134        }
135    }
136}
137
138impl TypeUsage {
139    /// All uses except atomics
140    pub fn all_scalar() -> EnumSet<TypeUsage> {
141        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
142    }
143
144    /// All atomic uses
145    pub fn all_atomic() -> EnumSet<TypeUsage> {
146        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
147    }
148}