cubecl_wgpu/compiler/
base.rs1use std::fmt::Display;
2
3use cubecl_core::{
4 Compiler, ExecutionMode, WgpuCompilationOptions,
5 prelude::{CompiledKernel, KernelDefinition},
6 server::ComputeServer,
7};
8#[cfg(feature = "msl")]
9use cubecl_cpp::shared::MslComputeKernel;
10use cubecl_runtime::compiler::CompilationError;
11use derive_more::derive::From;
12
13use crate::{WgpuServer, WgslCompiler};
14
15use super::wgsl;
16
17#[allow(clippy::large_enum_variant)]
18#[derive(Debug, Clone)]
19pub enum AutoCompiler {
20 Wgsl(WgslCompiler),
21 #[cfg(feature = "spirv")]
22 SpirV(cubecl_spirv::SpirvCompiler),
23 #[cfg(feature = "msl")]
24 Msl(cubecl_cpp::MslCompiler),
25}
26
27#[derive(From)]
28#[allow(clippy::large_enum_variant)]
29pub enum AutoRepresentation {
30 Wgsl(wgsl::ComputeShader),
31 #[cfg(feature = "spirv")]
32 SpirV(cubecl_spirv::SpirvKernel),
33 #[cfg(feature = "msl")]
34 Msl(MslComputeKernel),
35}
36
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel, or yields `None` for any other backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(spirv_kernel) = self {
            Some(spirv_kernel)
        } else {
            None
        }
    }
}
46
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner Metal compute kernel, or yields `None` for any other backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(msl_kernel) = self {
            Some(msl_kernel)
        } else {
            None
        }
    }
}
56
57impl Display for AutoRepresentation {
58 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
61 #[cfg(feature = "spirv")]
62 AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
63 #[cfg(feature = "msl")]
64 AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
65 }
66 }
67}
68
69impl Compiler for AutoCompiler {
70 type Representation = AutoRepresentation;
71
72 type CompilationOptions = WgpuCompilationOptions;
73
74 fn compile(
75 &mut self,
76 kernel: KernelDefinition,
77 compilation_options: &Self::CompilationOptions,
78 mode: ExecutionMode,
79 ) -> Result<Self::Representation, CompilationError> {
80 let kernel = match self {
81 AutoCompiler::Wgsl(wgsl_compiler) => {
82 Compiler::compile(wgsl_compiler, kernel, compilation_options, mode)?.into()
83 }
84 #[cfg(feature = "spirv")]
85 AutoCompiler::SpirV(spirv_compiler) => {
86 Compiler::compile(spirv_compiler, kernel, compilation_options, mode)?.into()
87 }
88 #[cfg(feature = "msl")]
89 AutoCompiler::Msl(msl_compiler) => {
90 use cubecl_cpp;
92 let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
93 Compiler::compile(msl_compiler, kernel, &compilation_options, mode)?.into()
94 }
95 };
96
97 Ok(kernel)
98 }
99
100 fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
101 match self {
102 AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
103 #[cfg(feature = "spirv")]
104 AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
105 #[cfg(feature = "msl")]
106 AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
107 }
108 }
109
110 fn extension(&self) -> &'static str {
111 match self {
112 AutoCompiler::Wgsl(_) => "wgsl",
113 #[cfg(feature = "spirv")]
114 AutoCompiler::SpirV(_) => "spv",
115 #[cfg(feature = "msl")]
116 AutoCompiler::Msl(_) => "msl",
117 }
118 }
119}
120
121impl AutoCompiler {
122 pub fn compile(
123 &mut self,
124 server: &mut WgpuServer,
125 kernel: <WgpuServer as ComputeServer>::Kernel,
126 mode: ExecutionMode,
127 ) -> Result<CompiledKernel<Self>, CompilationError> {
128 match self {
129 AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
130 #[cfg(feature = "spirv")]
131 AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
132 #[cfg(feature = "msl")]
133 AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
134 }
135 }
136
137 pub fn lang_tag(&self) -> &'static str {
138 match self {
139 AutoCompiler::Wgsl(_) => "wgsl",
140 #[cfg(feature = "spirv")]
141 AutoCompiler::SpirV(_) => "spirv",
142 #[cfg(feature = "msl")]
143 AutoCompiler::Msl(_) => "msl",
144 }
145 }
146}