cubecl_wgpu/compiler/
base.rs1use std::fmt::Display;
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::{
5 Compiler, WgpuCompilationOptions,
6 prelude::{CompiledKernel, KernelDefinition},
7 server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use cubecl_runtime::compiler::CompilationError;
12use derive_more::derive::From;
13
14use crate::{WgpuServer, WgslCompiler};
15
16use super::wgsl;
17
18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21 Wgsl(WgslCompiler),
22 #[cfg(feature = "spirv")]
23 SpirV(cubecl_spirv::SpirvCompiler),
24 #[cfg(feature = "msl")]
25 Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31 Wgsl(wgsl::ComputeShader),
32 #[cfg(feature = "spirv")]
33 SpirV(cubecl_spirv::SpirvKernel),
34 #[cfg(feature = "msl")]
35 Msl(MslComputeKernel),
36}
37
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the SPIR-V kernel, or yields `None` when another backend produced this value.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(spirv_kernel) = self {
            Some(spirv_kernel)
        } else {
            None
        }
    }
}
47
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the MSL kernel, or yields `None` when another backend produced this value.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(msl_kernel) = self {
            Some(msl_kernel)
        } else {
            None
        }
    }
}
57
58impl Display for AutoRepresentation {
59 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
62 #[cfg(feature = "spirv")]
63 AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
64 #[cfg(feature = "msl")]
65 AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
66 }
67 }
68}
69
70impl Compiler for AutoCompiler {
71 type Representation = AutoRepresentation;
72
73 type CompilationOptions = WgpuCompilationOptions;
74
75 fn compile(
76 &mut self,
77 kernel: KernelDefinition,
78 compilation_options: &Self::CompilationOptions,
79 mode: ExecutionMode,
80 ) -> Result<Self::Representation, CompilationError> {
81 let kernel = match self {
82 AutoCompiler::Wgsl(wgsl_compiler) => {
83 Compiler::compile(wgsl_compiler, kernel, compilation_options, mode)?.into()
84 }
85 #[cfg(feature = "spirv")]
86 AutoCompiler::SpirV(spirv_compiler) => {
87 Compiler::compile(spirv_compiler, kernel, compilation_options, mode)?.into()
88 }
89 #[cfg(feature = "msl")]
90 AutoCompiler::Msl(msl_compiler) => {
91 use cubecl_cpp;
93 let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
94 Compiler::compile(msl_compiler, kernel, &compilation_options, mode)?.into()
95 }
96 };
97
98 Ok(kernel)
99 }
100
101 fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
102 match self {
103 AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
104 #[cfg(feature = "spirv")]
105 AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
106 #[cfg(feature = "msl")]
107 AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
108 }
109 }
110
111 fn extension(&self) -> &'static str {
112 match self {
113 AutoCompiler::Wgsl(_) => "wgsl",
114 #[cfg(feature = "spirv")]
115 AutoCompiler::SpirV(_) => "spv",
116 #[cfg(feature = "msl")]
117 AutoCompiler::Msl(_) => "msl",
118 }
119 }
120}
121
122impl AutoCompiler {
123 pub fn compile(
124 &mut self,
125 server: &mut WgpuServer,
126 kernel: <WgpuServer as ComputeServer>::Kernel,
127 mode: ExecutionMode,
128 ) -> Result<CompiledKernel<Self>, CompilationError> {
129 match self {
130 AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
131 #[cfg(feature = "spirv")]
132 AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
133 #[cfg(feature = "msl")]
134 AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
135 }
136 }
137
138 pub fn lang_tag(&self) -> &'static str {
139 match self {
140 AutoCompiler::Wgsl(_) => "wgsl",
141 #[cfg(feature = "spirv")]
142 AutoCompiler::SpirV(_) => "spirv",
143 #[cfg(feature = "msl")]
144 AutoCompiler::Msl(_) => "msl",
145 }
146 }
147}