Skip to main content

cubecl_ir/
features.rs

1use crate::{AddressType, SemanticType, StorageType, Type};
2use alloc::collections::{BTreeMap, BTreeSet};
3
4use enumset::EnumSetType;
5
6pub use enumset::EnumSet;
7
8/// Features supported by a runtime
9#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
10pub struct Features {
11    /// Plane features supported by this runtime.
12    pub plane: EnumSet<Plane>,
13    /// Clustered launches and intra-cluster operations like cluster shared memory
14    pub cube_cluster: bool,
15    /// Enables changing the type of containers during kernel execution.
16    pub memory_reinterpret: bool,
17    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
18    pub alignment: bool,
19
20    /// Type support
21    pub types: Types,
22    /// Matrix multiplication features
23    pub matmul: MatmulFeatures,
24
25    /// Whether `copy_async` is supported
26    pub copy_async: bool,
27    /// Tensor Memory Accelerator supported features
28    pub tma: EnumSet<Tma>,
29    /// Whether vectors can be read from / stored to addresses not aligned
30    /// with the `vector_size`
31    pub unaligned_io: bool,
32}
33
34/// Type support for a device
35#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
36pub struct Types {
37    /// Valid address types
38    pub address: BTreeSet<AddressType>,
39    /// Types supported by this runtime, and which usages they support.
40    pub storage: BTreeMap<StorageType, EnumSet<TypeUsage>>,
41    /// Complex-specific capability families supported by this runtime.
42    pub complex: BTreeMap<StorageType, EnumSet<ComplexUsage>>,
43    /// Semantic constructs supported by this runtime.
44    pub semantic: BTreeSet<SemanticType>,
45    /// Supported vector types for atomic ops, only specific vectorizations for specific types are
46    /// supported here. Not all vector types are supported as scalars, i.e. Vulkan on Nvidia only
47    /// supports vectorized `f16`, not scalar. Only use the exact vectorizations registered here.
48    /// These may not be supported everywhere - in practice, f32 vectors are only supported in global
49    /// memory.
50    pub atomic: BTreeMap<Type, EnumSet<AtomicUsage>>,
51}
52
53/// Matrix multiplication-related features
54#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
55pub struct MatmulFeatures {
56    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
57    pub cmma: BTreeSet<MmaConfig>,
58    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
59    /// movement
60    pub mma: BTreeSet<MmaConfig>,
61    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a single
62    /// instruction. Scales must fit a specific layout and block size.
63    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
64    /// Types supported for ldmatrix, if any
65    pub ldmatrix: BTreeSet<StorageType>,
66    /// Types supported by stmatrix, if any
67    pub stmatrix: BTreeSet<StorageType>,
68}
69
70/// Operations allowed for this type. CMMA is defined separately.
71#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
72pub enum TypeUsage {
73    /// Conversion to/from the type. All types should support this.
74    Conversion,
75    /// All math/logic instructions except dot product
76    Arithmetic,
77    /// Dot product, mainly for BF16 on Intel
78    DotProduct,
79    /// Whether this type can be stored in a buffer
80    Buffer,
81}
82
83/// Complex capability families allowed for a complex storage type.
84#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
85pub enum ComplexUsage {
86    /// Core ML-centric complex functionality: arithmetic, negation, conjugation, real/imag.
87    Core,
88    /// Equality and inequality comparisons.
89    Compare,
90    /// Higher-level math functions such as exp/log/sin/cos/sqrt/tanh/powf and abs.
91    Math,
92}
93
94impl TypeUsage {
95    pub fn all() -> EnumSet<Self> {
96        EnumSet::all()
97    }
98
99    pub fn no_store() -> EnumSet<Self> {
100        TypeUsage::Conversion | TypeUsage::Arithmetic
101    }
102
103    pub fn maybe_store(storable: bool) -> EnumSet<Self> {
104        if storable {
105            EnumSet::all()
106        } else {
107            Self::no_store()
108        }
109    }
110}
111
112/// Atomic operations allowed for this type.
113#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
114pub enum AtomicUsage {
115    /// Atomic loads and stores
116    LoadStore,
117    /// Atomic add/sub
118    Add,
119    /// Atomic min/max
120    MinMax,
121}
122
123impl AtomicUsage {
124    pub fn all() -> EnumSet<Self> {
125        EnumSet::all()
126    }
127}
128
129/// Supported plane features
130#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
131pub enum Plane {
132    /// Basic plane-wide operations
133    Ops,
134    /// Plane-wide sync
135    Sync,
136    /// Allows using plane operations with divergent control flow.
137    NonUniformControlFlow,
138}
139
140/// Shape and element types of a valid MMA configuration
141#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
142#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
143pub struct MmaConfig {
144    /// Element of the A matrix
145    pub a_type: StorageType,
146    /// Element of the B matrix
147    pub b_type: StorageType,
148    /// Element of the C/D matrices
149    pub cd_type: StorageType,
150    /// The size of the matrix on the `m` dimension
151    pub m: u32,
152    /// The size of the matrix on the `n` dimension
153    pub n: u32,
154    /// The size of the matrix on the `k` dimension
155    pub k: u32,
156}
157
158/// Shape and element types of a valid block-scaled MMA configuration
159#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
160#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
161pub struct ScaledMmaConfig {
162    /// Element of the A matrix
163    pub a_type: StorageType,
164    /// Element of the B matrix
165    pub b_type: StorageType,
166    /// Element of the C/D matrices
167    pub cd_type: StorageType,
168    /// Element of the blocks scales
169    pub scales_type: StorageType,
170    /// The size of the matrix on the `m` dimension
171    pub m: u32,
172    /// The size of the matrix on the `n` dimension
173    pub n: u32,
174    /// The size of the matrix on the `k` dimension
175    pub k: u32,
176    /// Number of scales per tile row/col.
177    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA)
178    /// Scales blocks must be organized along the natural `vector_layout` of the operation
179    pub scales_factor: u32,
180}
181
182/// Atomic features that may be supported by a ``Runtime``.
183#[derive(Debug, PartialOrd, Ord, EnumSetType)]
184pub enum Tma {
185    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col
186    Base,
187    /// im2colWide encoding for tensor map.
188    Im2colWide,
189    /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
190    SwizzleAtomicity,
191}
192
193impl Features {
194    /// Get the usages for a type
195    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
196        self.types
197            .storage
198            .get(&ty)
199            .cloned()
200            .unwrap_or_else(EnumSet::empty)
201    }
202
203    /// Get the complex capability families for a type.
204    pub fn complex_usage(&self, ty: StorageType) -> EnumSet<ComplexUsage> {
205        self.types
206            .complex
207            .get(&ty)
208            .cloned()
209            .unwrap_or_else(EnumSet::empty)
210    }
211
212    /// Get the usages for an atomic type
213    pub fn atomic_type_usage(&self, ty: Type) -> EnumSet<AtomicUsage> {
214        self.types
215            .atomic
216            .get(&ty)
217            .cloned()
218            .unwrap_or_else(EnumSet::empty)
219    }
220
221    /// Whether the type is supported in any way
222    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
223        match ty.into() {
224            Type::Scalar(storage_type) | Type::Vector(storage_type, _) => {
225                self.types.storage.contains_key(&storage_type)
226            }
227            Type::Semantic(semantic_type) => self.types.semantic.contains(&semantic_type),
228        }
229    }
230
231    /// Whether the address type is supported in any way
232    pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
233        self.types.address.contains(&ty.into())
234    }
235
236    /// Whether a complex storage type supports the requested capability family.
237    pub fn supports_complex_usage(&self, ty: impl Into<StorageType>, usage: ComplexUsage) -> bool {
238        self.complex_usage(ty.into()).contains(usage)
239    }
240}