//! Feature descriptions for a runtime: supported types, plane operations,
//! matrix-multiply configurations, and TMA capabilities.
use crate::{AddressType, SemanticType, StorageType, Type};
use alloc::collections::{BTreeMap, BTreeSet};

use enumset::EnumSetType;

pub use enumset::EnumSet;
7
/// Features supported by a runtime.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Features {
    /// Plane features supported by this runtime.
    pub plane: EnumSet<Plane>,
    /// Clustered launches and intra-cluster operations like cluster shared memory.
    pub cube_cluster: bool,
    /// Enables changing the type of containers during kernel execution.
    pub memory_reinterpret: bool,
    /// Enables explicit alignment. If false, alignment still compiles, but isn't actually applied.
    pub alignment: bool,

    /// Type support (address, storage, semantic, and atomic types).
    pub types: Types,
    /// Matrix multiplication features (CMMA/MMA/scaled-MMA configurations).
    pub matmul: MatmulFeatures,

    /// Whether `copy_async` is supported.
    pub copy_async: bool,
    /// Tensor Memory Accelerator supported features.
    pub tma: EnumSet<Tma>,
    /// Whether vectors can be read from / stored to addresses not aligned
    /// with the `vector_size`.
    pub unaligned_io: bool,
}
33
/// Type support for a device.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct Types {
    /// Valid address types.
    pub address: BTreeSet<AddressType>,
    /// Types supported by this runtime, and which usages they support.
    pub storage: BTreeMap<StorageType, EnumSet<TypeUsage>>,
    /// Semantic constructs supported by this runtime.
    pub semantic: BTreeSet<SemanticType>,
    /// Supported types for atomic ops; only the specific vectorizations registered
    /// here are supported. Not all vector types are supported as scalars, i.e. Vulkan
    /// on Nvidia only supports vectorized `f16`, not scalar. Only use the exact
    /// vectorizations registered here. These may not be supported everywhere — in
    /// practice, f32 vectors are only supported in global memory.
    /// NOTE(review): parts of this comment read as if copied from a general vector-type
    /// description — confirm it accurately describes atomic support.
    pub atomic: BTreeMap<Type, EnumSet<AtomicUsage>>,
}
50
/// Matrix multiplication-related features.
#[derive(Debug, Clone, PartialEq, Eq, Default, Hash)]
pub struct MatmulFeatures {
    /// The cmma feature enables cooperative matrix-multiply and accumulate operations.
    pub cmma: BTreeSet<MmaConfig>,
    /// The manual MMA feature enables cooperative matrix-multiply with manually managed data
    /// movement.
    pub mma: BTreeSet<MmaConfig>,
    /// Scaled MMA allows combining matrix multiplication with unscaling quantized values into a
    /// single instruction. Scales must fit a specific layout and block size.
    pub scaled_mma: BTreeSet<ScaledMmaConfig>,
    /// Types supported for `ldmatrix`, if any.
    pub ldmatrix: BTreeSet<StorageType>,
    /// Types supported by `stmatrix`, if any.
    pub stmatrix: BTreeSet<StorageType>,
}
67
/// Operations allowed for this type. CMMA is defined separately.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum TypeUsage {
    /// Conversion to/from the type. All types should support this.
    Conversion,
    /// All math/logic instructions except dot product.
    Arithmetic,
    /// Dot product, mainly for BF16 on Intel.
    DotProduct,
    /// Whether this type can be stored in a buffer.
    Buffer,
}
80
81impl TypeUsage {
82    pub fn all() -> EnumSet<Self> {
83        EnumSet::all()
84    }
85
86    pub fn no_store() -> EnumSet<Self> {
87        TypeUsage::Conversion | TypeUsage::Arithmetic
88    }
89
90    pub fn maybe_store(storable: bool) -> EnumSet<Self> {
91        if storable {
92            EnumSet::all()
93        } else {
94            Self::no_store()
95        }
96    }
97}
98
/// Atomic operations allowed for this type.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum AtomicUsage {
    /// Atomic loads and stores.
    LoadStore,
    /// Atomic add/sub.
    Add,
    /// Atomic min/max.
    MinMax,
}
109
110impl AtomicUsage {
111    pub fn all() -> EnumSet<Self> {
112        EnumSet::all()
113    }
114}
115
/// Supported plane features.
#[derive(Debug, Hash, PartialOrd, Ord, EnumSetType)]
pub enum Plane {
    /// Basic plane-wide operations.
    Ops,
    /// Plane-wide sync.
    Sync,
    /// Allows using plane operations with divergent control flow.
    NonUniformControlFlow,
}
126
/// Shape and element types of a valid MMA configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
}
144
/// Shape and element types of a valid block-scaled MMA configuration.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ScaledMmaConfig {
    /// Element of the A matrix.
    pub a_type: StorageType,
    /// Element of the B matrix.
    pub b_type: StorageType,
    /// Element of the C/D matrices.
    pub cd_type: StorageType,
    /// Element of the block scales.
    pub scales_type: StorageType,
    /// The size of the matrix on the `m` dimension.
    pub m: u32,
    /// The size of the matrix on the `n` dimension.
    pub n: u32,
    /// The size of the matrix on the `k` dimension.
    pub k: u32,
    /// Number of scales per tile row/col.
    /// A scale factor of 2 means `m x 2` scales for A and `2 x n` for B (in CUDA).
    /// Scale blocks must be organized along the natural `vector_layout` of the operation.
    pub scales_factor: u32,
}
168
/// Tensor Memory Accelerator (TMA) features that may be supported by a ``Runtime``.
/// (Previous comment said "Atomic features" — that was a copy-paste error.)
#[derive(Debug, PartialOrd, Ord, EnumSetType)]
pub enum Tma {
    /// Base feature set for tensor memory accelerator features. Includes tiling and im2col.
    Base,
    /// `im2colWide` encoding for tensor map.
    Im2colWide,
    /// Different atomicities for 128-byte swizzle, i.e. 128-byte with 32-byte atomicity.
    SwizzleAtomicity,
}
179
180impl Features {
181    /// Get the usages for a type
182    pub fn type_usage(&self, ty: StorageType) -> EnumSet<TypeUsage> {
183        self.types
184            .storage
185            .get(&ty)
186            .cloned()
187            .unwrap_or_else(EnumSet::empty)
188    }
189
190    /// Get the usages for an atomic type
191    pub fn atomic_type_usage(&self, ty: Type) -> EnumSet<AtomicUsage> {
192        self.types
193            .atomic
194            .get(&ty)
195            .cloned()
196            .unwrap_or_else(EnumSet::empty)
197    }
198
199    /// Whether the type is supported in any way
200    pub fn supports_type(&self, ty: impl Into<Type>) -> bool {
201        match ty.into() {
202            Type::Scalar(storage_type) | Type::Vector(storage_type, _) => {
203                self.types.storage.contains_key(&storage_type)
204            }
205            Type::Semantic(semantic_type) => self.types.semantic.contains(&semantic_type),
206        }
207    }
208
209    /// Whether the address type is supported in any way
210    pub fn supports_address(&self, ty: impl Into<AddressType>) -> bool {
211        self.types.address.contains(&ty.into())
212    }
213}