cubecl_wgpu/compiler/
base.rs

1use std::fmt::Display;
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::{
5    Compiler, WgpuCompilationOptions,
6    prelude::{CompiledKernel, KernelDefinition},
7    server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use cubecl_runtime::compiler::CompilationError;
12use derive_more::derive::From;
13
14use crate::{WgpuServer, WgslCompiler};
15
16use super::wgsl;
17
18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21    Wgsl(WgslCompiler),
22    #[cfg(feature = "spirv")]
23    SpirV(cubecl_spirv::SpirvCompiler),
24    #[cfg(feature = "msl")]
25    Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31    Wgsl(wgsl::ComputeShader),
32    #[cfg(feature = "spirv")]
33    SpirV(cubecl_spirv::SpirvKernel),
34    #[cfg(feature = "msl")]
35    Msl(MslComputeKernel),
36}
37
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel when this representation came from the SPIR-V backend.
    ///
    /// Returns `None` for every other backend variant.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let AutoRepresentation::SpirV(spirv_kernel) = self {
            Some(spirv_kernel)
        } else {
            None
        }
    }
}
47
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner MSL kernel when this representation came from the Metal backend.
    ///
    /// Returns `None` for every other backend variant.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let AutoRepresentation::Msl(msl_kernel) = self {
            Some(msl_kernel)
        } else {
            None
        }
    }
}
57
58impl Display for AutoRepresentation {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
60        match self {
61            AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
62            #[cfg(feature = "spirv")]
63            AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
64            #[cfg(feature = "msl")]
65            AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
66        }
67    }
68}
69
70impl Compiler for AutoCompiler {
71    type Representation = AutoRepresentation;
72
73    type CompilationOptions = WgpuCompilationOptions;
74
75    fn compile(
76        &mut self,
77        kernel: KernelDefinition,
78        compilation_options: &Self::CompilationOptions,
79        mode: ExecutionMode,
80    ) -> Result<Self::Representation, CompilationError> {
81        let kernel = match self {
82            AutoCompiler::Wgsl(wgsl_compiler) => {
83                Compiler::compile(wgsl_compiler, kernel, compilation_options, mode)?.into()
84            }
85            #[cfg(feature = "spirv")]
86            AutoCompiler::SpirV(spirv_compiler) => {
87                Compiler::compile(spirv_compiler, kernel, compilation_options, mode)?.into()
88            }
89            #[cfg(feature = "msl")]
90            AutoCompiler::Msl(msl_compiler) => {
91                // override compilation options with cpp compiler options for metal
92                use cubecl_cpp;
93                let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
94                Compiler::compile(msl_compiler, kernel, &compilation_options, mode)?.into()
95            }
96        };
97
98        Ok(kernel)
99    }
100
101    fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
102        match self {
103            AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
104            #[cfg(feature = "spirv")]
105            AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
106            #[cfg(feature = "msl")]
107            AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
108        }
109    }
110
111    fn extension(&self) -> &'static str {
112        match self {
113            AutoCompiler::Wgsl(_) => "wgsl",
114            #[cfg(feature = "spirv")]
115            AutoCompiler::SpirV(_) => "spv",
116            #[cfg(feature = "msl")]
117            AutoCompiler::Msl(_) => "msl",
118        }
119    }
120}
121
122impl AutoCompiler {
123    pub fn compile(
124        &mut self,
125        server: &mut WgpuServer,
126        kernel: <WgpuServer as ComputeServer>::Kernel,
127        mode: ExecutionMode,
128    ) -> Result<CompiledKernel<Self>, CompilationError> {
129        match self {
130            AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
131            #[cfg(feature = "spirv")]
132            AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
133            #[cfg(feature = "msl")]
134            AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
135        }
136    }
137
138    pub fn lang_tag(&self) -> &'static str {
139        match self {
140            AutoCompiler::Wgsl(_) => "wgsl",
141            #[cfg(feature = "spirv")]
142            AutoCompiler::SpirV(_) => "spirv",
143            #[cfg(feature = "msl")]
144            AutoCompiler::Msl(_) => "msl",
145        }
146    }
147}