cubecl_wgpu/compiler/
base.rs

1use std::fmt::Display;
2
3use cubecl_core::{
4    Compiler, ExecutionMode, WgpuCompilationOptions,
5    ir::StorageType,
6    prelude::{CompiledKernel, KernelDefinition},
7    server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use cubecl_runtime::compiler::CompilationError;
12use derive_more::derive::From;
13
14use crate::{WgpuServer, WgslCompiler};
15
16use super::wgsl;
17
18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21    Wgsl(WgslCompiler),
22    #[cfg(feature = "spirv")]
23    SpirV(cubecl_spirv::SpirvCompiler),
24    #[cfg(feature = "msl")]
25    Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31    Wgsl(wgsl::ComputeShader),
32    #[cfg(feature = "spirv")]
33    SpirV(cubecl_spirv::SpirvKernel),
34    #[cfg(feature = "msl")]
35    Msl(MslComputeKernel),
36}
37
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel, or returns `None` when this representation
    /// belongs to any other backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
47
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner Metal kernel, or returns `None` when this representation
    /// belongs to any other backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
57
58impl Display for AutoRepresentation {
59    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
60        match self {
61            AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
62            #[cfg(feature = "spirv")]
63            AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
64            #[cfg(feature = "msl")]
65            AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
66        }
67    }
68}
69
70impl Compiler for AutoCompiler {
71    type Representation = AutoRepresentation;
72
73    type CompilationOptions = WgpuCompilationOptions;
74
75    fn compile(
76        &mut self,
77        kernel: KernelDefinition,
78        compilation_options: &Self::CompilationOptions,
79        mode: ExecutionMode,
80        addr_type: StorageType,
81    ) -> Result<Self::Representation, CompilationError> {
82        let kernel = match self {
83            AutoCompiler::Wgsl(wgsl_compiler) => {
84                Compiler::compile(wgsl_compiler, kernel, compilation_options, mode, addr_type)?
85                    .into()
86            }
87            #[cfg(feature = "spirv")]
88            AutoCompiler::SpirV(spirv_compiler) => {
89                Compiler::compile(spirv_compiler, kernel, compilation_options, mode, addr_type)?
90                    .into()
91            }
92            #[cfg(feature = "msl")]
93            AutoCompiler::Msl(msl_compiler) => {
94                // override compilation options with cpp compiler options for metal
95                use cubecl_cpp;
96                let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
97                Compiler::compile(msl_compiler, kernel, &compilation_options, mode, addr_type)?
98                    .into()
99            }
100        };
101
102        Ok(kernel)
103    }
104
105    fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
106        match self {
107            AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
108            #[cfg(feature = "spirv")]
109            AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
110            #[cfg(feature = "msl")]
111            AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
112        }
113    }
114
115    fn extension(&self) -> &'static str {
116        match self {
117            AutoCompiler::Wgsl(_) => "wgsl",
118            #[cfg(feature = "spirv")]
119            AutoCompiler::SpirV(_) => "spv",
120            #[cfg(feature = "msl")]
121            AutoCompiler::Msl(_) => "msl",
122        }
123    }
124}
125
126impl AutoCompiler {
127    pub fn compile(
128        &mut self,
129        server: &mut WgpuServer,
130        kernel: <WgpuServer as ComputeServer>::Kernel,
131        mode: ExecutionMode,
132    ) -> Result<CompiledKernel<Self>, CompilationError> {
133        match self {
134            AutoCompiler::Wgsl(_) => kernel.compile(
135                self,
136                &server.compilation_options,
137                mode,
138                kernel.address_type(),
139            ),
140            #[cfg(feature = "spirv")]
141            AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
142            #[cfg(feature = "msl")]
143            AutoCompiler::Msl(_) => kernel.compile(
144                self,
145                &server.compilation_options,
146                mode,
147                kernel.address_type(),
148            ),
149        }
150    }
151
152    pub fn lang_tag(&self) -> &'static str {
153        match self {
154            AutoCompiler::Wgsl(_) => "wgsl",
155            #[cfg(feature = "spirv")]
156            AutoCompiler::SpirV(_) => "spirv",
157            #[cfg(feature = "msl")]
158            AutoCompiler::Msl(_) => "msl",
159        }
160    }
161}