cubecl_ir/features.rs

use crate::{AddressType, SemanticType, StorageType, Type};
use alloc::collections::{BTreeMap, BTreeSet};

use enumset::EnumSetType;

pub use enumset::EnumSet;

/// Features supported by a runtime
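///
/// # Example
///
/// A minimal sketch of gating kernel behavior on a runtime's capabilities; the
/// `Features` value would normally come from the runtime, and the empty default
/// is used here only for illustration (marked `ignore` since exact import paths
/// depend on the crate layout):
///
/// ```ignore
/// fn can_use_async_copy(features: &Features) -> bool {
///     features.copy_async
/// }
///
/// // The default feature set reports no capabilities.
/// assert!(!can_use_async_copy(&Features::default()));
/// ```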
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the line size of containers during kernel execution.
    pub dynamic_line_size: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,
    /// Valid address types.
    pub address_types: BTreeSet<AddressType>,

    /// Types supported by this runtime, and which usages they support.
    pub storage_types: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic_types: BTreeSet<SemanticType>,

    /// Whether `copy_async` is supported.
    pub copy_async: bool,
    /// Supported Tensor Memory Accelerator (TMA) features.
    pub tma: EnumSet<Tma>,
    /// The `cmma` feature enables cooperative matrix-multiply-and-accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix multiplication with manually managed
    /// data movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA combines matrix multiplication with the unscaling of quantized values in a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported for `ldmatrix`, if any.
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported for `stmatrix`, if any.
    pub stmatrix: BTreeSet<StorageType>,
    /// Whether `Line`s can be read from / stored to addresses not aligned
    /// with the line size.
    pub unaligned_io: bool,
}

/// Operations allowed for this type. CMMA is defined separately.
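///
/// # Example
///
/// `TypeUsage` derives [`EnumSetType`], so usages combine into an [`EnumSet`]
/// with `|` (an illustrative sketch):
///
/// ```ignore
/// let usages = TypeUsage::Conversion | TypeUsage::Buffer;
/// assert!(usages.contains(TypeUsage::Buffer));
/// assert!(!usages.contains(TypeUsage::AtomicAdd));
/// ```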
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel
    DotProduct,
    /// Whether this type can be stored in a buffer
    Buffer,
    /// Atomic loads and stores
    AtomicLoadStore,
    /// Atomic add/sub
    AtomicAdd,
    /// Atomic min/max
    AtomicMinMax,
}

/// Supported plane features
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations
    Ops,
    /// Plane-wide sync
    Sync,
    /// Allows using plane operations with divergent control flow.
    NonUniformControlFlow,
}

/// Shape and element types of a valid MMA configuration
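///
/// # Example
///
/// A sketch in the shape of NVIDIA's f16 `m16n8k16` MMA instruction. The `f16`
/// and `f32` names below are placeholders for the crate's actual `StorageType`
/// values, and `features` is assumed to be a [`Features`] value:
///
/// ```ignore
/// let cfg = MmaConfig {
///     a_type: f16, // placeholder for the f16 storage type
///     b_type: f16,
///     cd_type: f32, // placeholder for the f32 storage type
///     m: 16,
///     n: 8,
///     k: 16,
/// };
/// // Check whether the runtime supports this manual MMA shape/type combination.
/// let supported = features.mma.contains(&cfg);
/// ```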
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
}

/// Shape and element types of a valid block-scaled MMA configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScaledMmaConfig {
    /// Element of the A matrix
    pub a_type: StorageType,
    /// Element of the B matrix
    pub b_type: StorageType,
    /// Element of the C/D matrices
    pub cd_type: StorageType,
    /// Element of the block scales
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension
    pub m: u32,
    /// The size of the matrix on the `n` dimension
    pub n: u32,
    /// The size of the matrix on the `k` dimension
    pub k: u32,
    /// Number of scales per tile row/column.
    /// A `scales_factor` of 2 means `m x 2` scales for A and `2 x n` scales for B (in CUDA).
    /// Scale blocks must be organized along the natural `line_layout` of the operation.
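    ///
    /// For example (an illustrative reading of the definition above): with
    /// `m = 128`, `n = 64`, and `scales_factor = 2`, a tile carries
    /// `128 x 2 = 256` scales for A and `2 x 64 = 128` scales for B.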
    pub scales_factor: u32,
}

/// Tensor Memory Accelerator (TMA) features that may be supported by a [cube runtime](Runtime).
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// im2colWide encoding for tensor map.
    Im2colWide,
    /// Different atomicities for 128-byte swizzle, i.e. 128-byte swizzle with 32-byte atomicity.
    SwizzleAtomicity,
}

impl Features {
    /// Get the usages for a type
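    ///
    /// # Example
    ///
    /// A minimal sketch; the empty default `Features` reports no usages for any
    /// type (`ty` stands in for a concrete `StorageType`):
    ///
    /// ```ignore
    /// let features = Features::default();
    /// assert!(features.type_usage(ty).is_empty());
    /// ```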
    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
        self.storage_types
            .get(&ty)
            .cloned()
            .unwrap_or_else(EnumSet::empty)
    }

    /// Whether the type is supported in any way
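    ///
    /// # Example
    ///
    /// A sketch of validating a type before emitting code for it (`features`
    /// and `ty` are illustrative):
    ///
    /// ```ignore
    /// assert!(
    ///     features.supports_type(ty),
    ///     "type not supported by this runtime"
    /// );
    /// ```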
    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
        match ty.into() {
            Type::Scalar(storage_type) | Type::Line(storage_type, _) => {
                self.storage_types.contains_key(&storage_type)
            }
            Type::Semantic(semantic_type) => self.semantic_types.contains(&semantic_type),
        }
    }

    /// Whether the address type is supported in any way
    pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
        self.address_types.contains(&ty.into())
    }
}

impl TypeUsage {
    /// All uses except atomics
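    ///
    /// # Example
    ///
    /// A sketch of registering full non-atomic support for a type when building
    /// a runtime's feature set (`features` and `ty` are illustrative):
    ///
    /// ```ignore
    /// features.storage_types.insert(ty, TypeUsage::all_scalar());
    /// ```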
    pub fn all_scalar() -> EnumSet<TypeUsage> {
        TypeUsage::Conversion | TypeUsage::Arithmetic | TypeUsage::DotProduct | TypeUsage::Buffer
    }

    /// All atomic uses
    pub fn all_atomic() -> EnumSet<TypeUsage> {
        TypeUsage::AtomicAdd | TypeUsage::AtomicLoadStore | TypeUsage::AtomicMinMax
    }
}