cubecl_cpp/cuda/mma/cuda_compiler.rs

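//! WMMA/MMA compiler support for the CUDA dialect: cooperative-matrix (`wmma`)
//! operations are formatted through the shared `wmma_api_base` helpers, while
//! manual and block-scaled `mma` lowering is delegated to the CUDA-specific
//! helpers in `cuda::mma`.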
use crate::{
    cuda::{
        CudaDialect,
        arch::CudaArchitecture,
        mma::{
            compile_manual_mma, compile_scaled_mma, supported_mma_combinations,
            supported_scaled_mma_combinations,
        },
    },
    shared::{
        Architecture, DialectWmmaCompiler, Flags, Fragment, FragmentIdent, FragmentLayout,
        ManualMma, SupportedMmaCombinations, SupportedScaledMmaCombinations, Variable,
        WmmaInstruction, wmma_api_base,
    },
};
use cubecl_core::ir::{self as gpu};
use cubecl_runtime::MmaConfig;
use itertools::Itertools;

use super::{WMMA_MINIMUM_VERSION, WMMA_NAMESPACE};

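/// WMMA compiler for the CUDA dialect.
///
/// Cooperative-matrix fragments and instructions are printed through the shared
/// `wmma_api_base` helpers using the CUDA WMMA namespace; manual and block-scaled
/// mma are handled by the CUDA-specific helpers in `cuda::mma`.
///
/// A minimal sketch, assuming a `CudaArchitecture` value `arch` obtained elsewhere:
///
/// ```ignore
/// // `arch` is a hypothetical CudaArchitecture constructed by the runtime.
/// let shapes = CudaWmmaCompiler::supported_wmma_combinations(&arch);
/// ```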
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct CudaWmmaCompiler {}

impl DialectWmmaCompiler<CudaDialect<Self>> for CudaWmmaCompiler {
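    // The C++ WMMA API is defined in <mma.h>; the include is emitted
    // unconditionally (the feature flags are ignored here).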
    fn compile_wmma_includes(f: &mut std::fmt::Formatter<'_>, _flags: &Flags) -> std::fmt::Result {
        f.write_str("#include <mma.h>\n")
    }

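    // Fragment declarations, identifiers, layouts, fragment types and WMMA
    // instructions are all formatted by the shared `wmma_api_base` helpers,
    // parameterized with the CUDA WMMA namespace.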
    fn compile_wmma_fragment_declaration(
        f: &mut std::fmt::Formatter<'_>,
        var: &crate::shared::Variable<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_declaration(f, var)
    }

    fn compile_wwma_fragment_ident(
        f: &mut std::fmt::Formatter<'_>,
        ident: &FragmentIdent<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_ident(f, WMMA_NAMESPACE, ident)
    }

    fn compile_wmma_fragment_layout(
        f: &mut std::fmt::Formatter<'_>,
        layout: &FragmentLayout<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment_layout(f, WMMA_NAMESPACE, layout)
    }

    fn compile_wmma_fragment(
        f: &mut std::fmt::Formatter<'_>,
        fragment: &Fragment<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_fragment(f, WMMA_NAMESPACE, fragment)
    }

    fn compile_wmma_instruction(
        f: &mut std::fmt::Formatter<'_>,
        instruction: &WmmaInstruction<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        wmma_api_base::compile_instruction(f, WMMA_NAMESPACE, instruction)
    }

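    // Manual (register-level) mma and block-scaled mma are lowered by the
    // CUDA-specific helpers in `cuda::mma` rather than the WMMA C++ API.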
    fn compile_manual_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
    ) -> std::fmt::Result {
        compile_manual_mma(f, mma)
    }

    fn compile_scaled_mma(
        f: &mut std::fmt::Formatter<'_>,
        mma: ManualMma<CudaDialect<Self>>,
        scales_a: Variable<CudaDialect<Self>>,
        scales_b: Variable<CudaDialect<Self>>,
        scales_factor: u32,
    ) -> std::fmt::Result {
        compile_scaled_mma(f, mma, scales_a, scales_b, scales_factor)
    }

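    // Tile shapes and element types accepted by the C++ WMMA API on the target
    // architecture.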
    fn supported_wmma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        let mut result: SupportedMmaCombinations = vec![];
        if arch.get_version() >= WMMA_MINIMUM_VERSION {
            // Supported (m, n, k) tile shapes.
            let tdims = vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)];
            // (a, b, cd) element-type combinations supported on every WMMA-capable architecture.
            let types = vec![
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16), // a_type
                    gpu::ElemType::Float(gpu::FloatKind::F16), // b_type
                    gpu::ElemType::Float(gpu::FloatKind::F16), // cd_type
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::BF16),
                    gpu::ElemType::Float(gpu::FloatKind::F32),
                ),
                (
                    gpu::ElemType::Int(gpu::IntKind::I8),
                    gpu::ElemType::Int(gpu::IntKind::I8),
                    gpu::ElemType::Int(gpu::IntKind::I32),
                ),
                (
                    gpu::ElemType::UInt(gpu::UIntKind::U8),
                    gpu::ElemType::UInt(gpu::UIntKind::U8),
                    gpu::ElemType::Int(gpu::IntKind::I32),
                ),
            ];
            let combinations: SupportedMmaCombinations = types
                .into_iter()
                .cartesian_product(tdims)
                .map(|((a, b, c), (m, n, k))| MmaConfig {
                    a_type: a.into(),
                    b_type: b.into(),
                    cd_type: c.into(),
                    m,
                    n,
                    k,
                })
                .collect();
            result.extend(combinations);
            // TF32 inputs require compute capability 8.0 (Ampere) or newer.
            if arch.get_version() >= 80 {
                result.push(MmaConfig {
                    a_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    b_type: gpu::ElemType::Float(gpu::FloatKind::TF32).into(),
                    cd_type: gpu::ElemType::Float(gpu::FloatKind::F32).into(),
                    m: 16,
                    n: 16,
                    k: 8,
                });
            }
        }
        result
    }

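    // Manual and block-scaled mma support tables are architecture-dependent and
    // resolved by the helpers in `cuda::mma`.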
    fn supported_mma_combinations(arch: &CudaArchitecture) -> SupportedMmaCombinations {
        supported_mma_combinations(arch)
    }

    fn supported_scaled_mma_combinations(
        arch: &CudaArchitecture,
    ) -> SupportedScaledMmaCombinations {
        supported_scaled_mma_combinations(arch)
    }
}