cubecl_wgpu/compiler/
base.rs1use std::fmt::Display;
2
3use cubecl_core::{
4 Compiler, ExecutionMode, WgpuCompilationOptions,
5 ir::StorageType,
6 prelude::{CompiledKernel, KernelDefinition},
7 server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use cubecl_runtime::compiler::CompilationError;
12use derive_more::derive::From;
13
14use crate::{WgpuServer, WgslCompiler};
15
16use super::wgsl;
17
18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21 Wgsl(WgslCompiler),
22 #[cfg(feature = "spirv")]
23 SpirV(cubecl_spirv::SpirvCompiler),
24 #[cfg(feature = "msl")]
25 Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31 Wgsl(wgsl::ComputeShader),
32 #[cfg(feature = "spirv")]
33 SpirV(cubecl_spirv::SpirvKernel),
34 #[cfg(feature = "msl")]
35 Msl(MslComputeKernel),
36}
37
38#[derive(From, Clone, Copy)]
39#[allow(clippy::large_enum_variant)]
40pub enum AutoRepresentationRef<'a> {
41 Wgsl(&'a wgsl::ComputeShader),
42 #[cfg(feature = "spirv")]
43 SpirV(&'a cubecl_spirv::SpirvKernel),
44 #[cfg(feature = "msl")]
45 Msl(&'a MslComputeKernel),
46}
47
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Returns the SPIR-V kernel when this representation was produced by the SPIR-V
    /// backend, and `None` for every other backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
57
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Returns the Metal compute kernel when this representation was produced by the
    /// MSL backend, and `None` for every other backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(metal_kernel) = self {
            Some(metal_kernel)
        } else {
            None
        }
    }
}
67
68impl AutoRepresentation {
69 pub fn as_ref(&self) -> AutoRepresentationRef<'_> {
70 match self {
71 AutoRepresentation::Wgsl(compute_shader) => AutoRepresentationRef::Wgsl(compute_shader),
72 #[cfg(feature = "spirv")]
73 AutoRepresentation::SpirV(spirv_kernel) => AutoRepresentationRef::SpirV(spirv_kernel),
74 #[cfg(feature = "msl")]
75 AutoRepresentation::Msl(compute_shader) => AutoRepresentationRef::Msl(compute_shader),
76 }
77 }
78}
79
80impl Display for AutoRepresentation {
81 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
82 match self {
83 AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
84 #[cfg(feature = "spirv")]
85 AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
86 #[cfg(feature = "msl")]
87 AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
88 }
89 }
90}
91
92impl Compiler for AutoCompiler {
93 type Representation = AutoRepresentation;
94
95 type CompilationOptions = WgpuCompilationOptions;
96
97 fn compile(
98 &mut self,
99 kernel: KernelDefinition,
100 compilation_options: &Self::CompilationOptions,
101 mode: ExecutionMode,
102 addr_type: StorageType,
103 ) -> Result<Self::Representation, CompilationError> {
104 let kernel = match self {
105 AutoCompiler::Wgsl(wgsl_compiler) => {
106 Compiler::compile(wgsl_compiler, kernel, compilation_options, mode, addr_type)?
107 .into()
108 }
109 #[cfg(feature = "spirv")]
110 AutoCompiler::SpirV(spirv_compiler) => {
111 Compiler::compile(spirv_compiler, kernel, compilation_options, mode, addr_type)?
112 .into()
113 }
114 #[cfg(feature = "msl")]
115 AutoCompiler::Msl(msl_compiler) => {
116 use cubecl_cpp;
118 let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
119 Compiler::compile(msl_compiler, kernel, &compilation_options, mode, addr_type)?
120 .into()
121 }
122 };
123
124 Ok(kernel)
125 }
126
127 fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
128 match self {
129 AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
130 #[cfg(feature = "spirv")]
131 AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
132 #[cfg(feature = "msl")]
133 AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
134 }
135 }
136
137 fn extension(&self) -> &'static str {
138 match self {
139 AutoCompiler::Wgsl(_) => "wgsl",
140 #[cfg(feature = "spirv")]
141 AutoCompiler::SpirV(_) => "spv",
142 #[cfg(feature = "msl")]
143 AutoCompiler::Msl(_) => "msl",
144 }
145 }
146}
147
148impl AutoCompiler {
149 pub fn compile(
150 &mut self,
151 server: &mut WgpuServer,
152 kernel: <WgpuServer as ComputeServer>::Kernel,
153 mode: ExecutionMode,
154 ) -> Result<CompiledKernel<Self>, CompilationError> {
155 match self {
156 AutoCompiler::Wgsl(_) => kernel.compile(
157 self,
158 &server.compilation_options,
159 mode,
160 kernel.address_type(),
161 ),
162 #[cfg(feature = "spirv")]
163 AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
164 #[cfg(feature = "msl")]
165 AutoCompiler::Msl(_) => kernel.compile(
166 self,
167 &server.compilation_options,
168 mode,
169 kernel.address_type(),
170 ),
171 }
172 }
173
174 pub fn lang_tag(&self) -> &'static str {
175 match self {
176 AutoCompiler::Wgsl(_) => "wgsl",
177 #[cfg(feature = "spirv")]
178 AutoCompiler::SpirV(_) => "spirv",
179 #[cfg(feature = "msl")]
180 AutoCompiler::Msl(_) => "msl",
181 }
182 }
183}