cubecl_cpp/cuda/mma/
cuda_compiler.rs

1use super::{WMMA_MINIMUM_VERSION, WMMA_NAMESPACE};
2use crate::{
3    cuda::{
4        CudaDialect,
5        arch::CudaArchitecture,
6        mma::{
7            compile_manual_mma, compile_scaled_mma, supported_mma_combinations,
8            supported_scaled_mma_combinations,
9        },
10    },
11    shared::{
12        Architecture, DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout,
13        ManualMma, SupportedMmaCombinations, SupportedScaledMmaCombinations, Variable,
14        WmmaInstruction, wmma_api_base,
15    },
16};
17use cubecl_core::ir::{self as gpu, features::MmaConfig};
18use itertools::Itertools;
19
/// Field-less marker type on which the CUDA implementation of
/// `DialectWmmaCompiler` is defined.
///
/// It carries no state; every trait method implemented for it is an
/// associated function without a `self` receiver.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CudaWmmaCompiler {}
22
23impl DialectWmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
24    fn compile_wmma_includes(
25        f: &mut std::fmt::Formatter<'_>,
26        _flags: &Flags<CudaDialect<Self>>,
27    ) -> std::fmt::Result {
28        f.write_str("#include <mma.h>\n")
29    }
30
31    fn compile_wmma_fragment_declaration(
32        f: &mut std::fmt::Formatter<'_>,
33        var: &crate::shared::Variable<CudaDialect<Self>>,
34    ) -> std::fmt::Result {
35        wmma_api_base::compile_fragment_declaration(f, var)
36    }
37
38    fn compile_wwma_fragment_ident(
39        f: &mut std::fmt::Formatter<'_>,
40        ident: &FragmentIdent<CudaDialect<Self>>,
41    ) -> std::fmt::Result {
42        wmma_api_base::compile_fragment_ident(f, WMMA_NAMESPACE, ident)
43    }
44
45    fn compile_wmma_fragment_layout(
46        f: &mut std::fmt::Formatter<'_>,
47        layout: &FragmentLayout<CudaDialect<Self>>,
48    ) -> std::fmt::Result {
49        wmma_api_base::compile_fragment_layout(f, WMMA_NAMESPACE, layout)
50    }
51
52    fn compile_wmma_fragment(
53        f: &mut std::fmt::Formatter<'_>,
54        fragment: &Fragment<CudaDialect<Self>>,
55    ) -> std::fmt::Result {
56        wmma_api_base::compile_fragment(f, WMMA_NAMESPACE, fragment)
57    }
58
59    fn compile_wmma_instruction(
60        f: &mut std::fmt::Formatter<'_>,
61        instruction: &WmmaInstruction<CudaDialect<Self>>,
62    ) -> std::fmt::Result {
63        wmma_api_base::compile_instruction(f, WMMA_NAMESPACE, instruction)
64    }
65
66    fn compile_manual_mma(
67        f: &mut std::fmt::Formatter<'_>,
68        mma: ManualMma<CudaDialect<Self>>,
69    ) -> std::fmt::Result {
70        compile_manual_mma(f, mma)
71    }
72
73    fn compile_scaled_mma(
74        f: &mut std::fmt::Formatter<'_>,
75        mma: ManualMma<CudaDialect<Self>>,
76        scales_a: Variable<CudaDialect<Self>>,
77        scales_b: Variable<CudaDialect<Self>>,
78        scales_factor: u32,
79    ) -> std::fmt::Result {
80        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
81    }
82
83    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
84        let mut result: SupportedMmaCombinations = vec![];
85        if arch.get_version() >= WMMA_MINIMUM_VERSION {
86            let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
87            // Types fully supported.
88            let types = vec![
89                (
90                    gpu::ElemType::Float(gpu::FloatKind::F16), // m
91                    gpu::ElemType::Float(gpu::FloatKind::F16), // n
92                    gpu::ElemType::Float(gpu::FloatKind::F16), // k
93                ),
94                (
95                    gpu::ElemType::Float(gpu::FloatKind::F16),
96                    gpu::ElemType::Float(gpu::FloatKind::F16),
97                    gpu::ElemType::Float(gpu::FloatKind::F32),
98                ),
99                (
100                    gpu::ElemType::Float(gpu::FloatKind::BF16),
101                    gpu::ElemType::Float(gpu::FloatKind::BF16),
102                    gpu::ElemType::Float(gpu::FloatKind::F32),
103                ),
104                (
105                    gpu::ElemType::Int(gpu::IntKind::I8),
106                    gpu::ElemType::Int(gpu::IntKind::I8),
107                    gpu::ElemType::Int(gpu::IntKind::I32),
108                ),
109                (
110                    gpu::ElemType::UInt(gpu::UIntKind::U8),
111                    gpu::ElemType::UInt(gpu::UIntKind::U8),
112                    gpu::ElemType::Int(gpu::IntKind::I32),
113                ),
114            ];
115            let combinations: SupportedMmaCombinations = types
116                .into_iter()
117                .cartesian_product(tdims)
118                .map(|((a, b, c), (m, n, k))| MmaConfig {
119                    a_type: a.into(),
120                    b_type: b.into(),
121                    cd_type: c.into(),
122                    m,
123                    n,
124                    k,
125                })
126                .collect();
127            result.extend(combinations);
128            if arch.get_version() >= 80 {
129                result.push(MmaConfig {
130                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
131                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
132                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
133                    m: 16,
134                    n: 16,
135                    k: 8,
136                });
137            }
138        }
139        result
140    }
141
142    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
143        supported_mma_combinations(arch)
144    }
145
146    fn supported_scaled_mma_combinations(
147        arch: &CudaArchitecture,
148    ) -> SupportedScaledMmaCombinations {
149        supported_scaled_mma_combinations(arch)
150    }
151}