cubecl_wgpu/compiler/
base.rs1use std::fmt::Display;
2
3use cubecl_core::{
4 Compiler, ExecutionMode, WgpuCompilationOptions,
5 ir::StorageType,
6 prelude::{CompiledKernel, KernelDefinition},
7 server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use cubecl_runtime::compiler::CompilationError;
12use derive_more::derive::From;
13
14use crate::{WgpuServer, WgslCompiler};
15
16use super::wgsl;
17
18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21 Wgsl(WgslCompiler),
22 #[cfg(feature = "spirv")]
23 SpirV(cubecl_spirv::SpirvCompiler),
24 #[cfg(feature = "msl")]
25 Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31 Wgsl(wgsl::ComputeShader),
32 #[cfg(feature = "spirv")]
33 SpirV(cubecl_spirv::SpirvKernel),
34 #[cfg(feature = "msl")]
35 Msl(MslComputeKernel),
36}
37
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel when `self` is the `SpirV` variant.
    ///
    /// Returns `None` for every other variant.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(spirv_kernel) = self {
            Some(spirv_kernel)
        } else {
            None
        }
    }
}
47
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner MSL kernel when `self` is the `Msl` variant.
    ///
    /// Returns `None` for every other variant.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(msl_kernel) = self {
            Some(msl_kernel)
        } else {
            None
        }
    }
}
57
58impl Display for AutoRepresentation {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
62 #[cfg(feature = "spirv")]
63 AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
64 #[cfg(feature = "msl")]
65 AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
66 }
67 }
68}
69
70impl Compiler for AutoCompiler {
71 type Representation = AutoRepresentation;
72
73 type CompilationOptions = WgpuCompilationOptions;
74
75 fn compile(
76 &mut self,
77 kernel: KernelDefinition,
78 compilation_options: &Self::CompilationOptions,
79 mode: ExecutionMode,
80 addr_type: StorageType,
81 ) -> Result<Self::Representation, CompilationError> {
82 let kernel = match self {
83 AutoCompiler::Wgsl(wgsl_compiler) => {
84 Compiler::compile(wgsl_compiler, kernel, compilation_options, mode, addr_type)?
85 .into()
86 }
87 #[cfg(feature = "spirv")]
88 AutoCompiler::SpirV(spirv_compiler) => {
89 Compiler::compile(spirv_compiler, kernel, compilation_options, mode, addr_type)?
90 .into()
91 }
92 #[cfg(feature = "msl")]
93 AutoCompiler::Msl(msl_compiler) => {
94 use cubecl_cpp;
96 let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
97 Compiler::compile(msl_compiler, kernel, &compilation_options, mode, addr_type)?
98 .into()
99 }
100 };
101
102 Ok(kernel)
103 }
104
105 fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
106 match self {
107 AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
108 #[cfg(feature = "spirv")]
109 AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
110 #[cfg(feature = "msl")]
111 AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
112 }
113 }
114
115 fn extension(&self) -> &'static str {
116 match self {
117 AutoCompiler::Wgsl(_) => "wgsl",
118 #[cfg(feature = "spirv")]
119 AutoCompiler::SpirV(_) => "spv",
120 #[cfg(feature = "msl")]
121 AutoCompiler::Msl(_) => "msl",
122 }
123 }
124}
125
126impl AutoCompiler {
127 pub fn compile(
128 &mut self,
129 server: &mut WgpuServer,
130 kernel: <WgpuServer as ComputeServer>::Kernel,
131 mode: ExecutionMode,
132 ) -> Result<CompiledKernel<Self>, CompilationError> {
133 match self {
134 AutoCompiler::Wgsl(_) => kernel.compile(
135 self,
136 &server.compilation_options,
137 mode,
138 kernel.address_type(),
139 ),
140 #[cfg(feature = "spirv")]
141 AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
142 #[cfg(feature = "msl")]
143 AutoCompiler::Msl(_) => kernel.compile(
144 self,
145 &server.compilation_options,
146 mode,
147 kernel.address_type(),
148 ),
149 }
150 }
151
152 pub fn lang_tag(&self) -> &'static str {
153 match self {
154 AutoCompiler::Wgsl(_) => "wgsl",
155 #[cfg(feature = "spirv")]
156 AutoCompiler::SpirV(_) => "spirv",
157 #[cfg(feature = "msl")]
158 AutoCompiler::Msl(_) => "msl",
159 }
160 }
161}