cubecl_runtime/features.rs

use alloc::collections::{BTreeMap, BTreeSet};

use cubecl_ir::{SemanticType, StorageType, Type};
use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime.
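///
/// # Example
///
/// A minimal sketch of gating code generation on reported features
/// (`features` is assumed to come from a runtime instance; not a doctest):
///
/// ```ignore
/// if features.cube_cluster {
///     // Cluster launches and cluster shared memory are available.
/// }
/// if features.plane.contains(Plane::Sync) {
///     // Plane-wide synchronization can be emitted.
/// }
/// ```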
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations, such as cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Supported Tensor Memory Accelerator (TMA) features.
    pub tma: EnumSet<Tma>,
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication and unscaling of quantized values in a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
}

/// Operations allowed for this type. CMMA is defined separately.
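///
/// # Example
///
/// Usages form an [`EnumSet`], so they compose with `|` (a sketch, not a
/// doctest):
///
/// ```ignore
/// let usages = TypeUsage::Conversion | TypeUsage::Buffer;
/// assert!(usages.contains(TypeUsage::Buffer));
/// assert!(!usages.contains(TypeUsage::Arithmetic));
/// ```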
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
    /// Atomic loads and stores.
    AtomicLoadStore,
    /// Atomic add/sub.
    AtomicAdd,
    /// Atomic min/max.
    AtomicMinMax,
}

/// Supported plane features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
}

/// Shape and element types of a valid MMA configuration.
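///
/// # Example
///
/// A sketch of checking whether a runtime supports a given CMMA shape. The
/// `f16_storage`/`f32_storage` values are illustrative placeholders, not
/// actual `StorageType` variants (not a doctest):
///
/// ```ignore
/// let cfg = MmaConfig {
///     a_type: f16_storage,
///     b_type: f16_storage,
///     cd_type: f32_storage,
///     m: 16,
///     n: 16,
///     k: 16,
/// };
/// let supported = features.cmma.contains(&cfg);
/// ```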
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct MmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration.
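///
/// # Example
///
/// A sketch of a block-scaled configuration; all element and scale types,
/// as well as the shape, are illustrative placeholders rather than actual
/// `StorageType` variants (not a doctest):
///
/// ```ignore
/// let cfg = ScaledMmaConfig {
///     a_type: quantized_storage,
///     b_type: quantized_storage,
///     cd_type: f32_storage,
///     scales_type: scale_storage,
///     m: 128,
///     n: 128,
///     k: 64,
///     // Two scales per tile row/col: `m x 2` for A, `2 x n` for B.
///     scales_factor: 2,
/// };
/// let supported = features.scaled_mma.contains(&cfg);
/// ```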
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ScaledMmaConfig {
    /// Element type of the A matrix.
    pub a_type: StorageType,
    /// Element type of the B matrix.
    pub b_type: StorageType,
    /// Element type of the C/D matrices.
    pub cd_type: StorageType,
    /// Element type of the block scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
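///
/// # Example
///
/// A sketch of gating a TMA code path (`features` is assumed to come from a
/// runtime instance; not a doctest):
///
/// ```ignore
/// if features.tma.contains(Tma::Im2colWide) {
///     // The wide im2col tensor-map encoding can be used.
/// }
/// ```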
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor map.
    Im2colWide,
}

impl Features {
    /// Get the usages for a type.
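    ///
    /// # Example
    ///
    /// A sketch of querying atomic support (`ty` is some `StorageType` known
    /// to the caller; not a doctest):
    ///
    /// ```ignore
    /// if features.type_usage(ty).contains(TypeUsage::AtomicAdd) {
    ///     // Atomic adds can be emitted for `ty`.
    /// }
    /// ```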
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types.get(&ty).copied().unwrap_or_default()
    }

    /// Whether the type is supported in any way.
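    ///
    /// # Example
    ///
    /// A sketch of validating a type before compiling a kernel (`ty` is any
    /// value convertible into `Type`; not a doctest):
    ///
    /// ```ignore
    /// assert!(features.supports_type(ty), "unsupported type");
    /// ```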
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }
}

impl TypeUsage {
    /// All uses except atomics.
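    ///
    /// # Example
    ///
    /// A sketch of registering a fully supported scalar type (`ty` is a
    /// placeholder `StorageType`; not a doctest):
    ///
    /// ```ignore
    /// features.storage_types.insert(ty, TypeUsage::all_scalar());
    /// ```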
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses.
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}