cubecl_wgpu/compiler/
base.rs1use std::fmt::Display;
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::{
5 Compiler, WgpuCompilationOptions,
6 prelude::{CompiledKernel, KernelDefinition},
7 server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use derive_more::derive::From;
12
13use crate::{WgpuServer, WgslCompiler};
14
15use super::wgsl;
16
17#[allow(clippy::large_enum_variant)]
18#[derive(Debug, Clone)]
19pub enum AutoCompiler {
20 Wgsl(WgslCompiler),
21 #[cfg(feature = "spirv")]
22 SpirV(cubecl_spirv::SpirvCompiler),
23 #[cfg(feature = "msl")]
24 Msl(cubecl_cpp::MslCompiler),
25}
26
27#[derive(From)]
28#[allow(clippy::large_enum_variant)]
29pub enum AutoRepresentation {
30 Wgsl(wgsl::ComputeShader),
31 #[cfg(feature = "spirv")]
32 SpirV(cubecl_spirv::SpirvKernel),
33 #[cfg(feature = "msl")]
34 Msl(MslComputeKernel),
35}
36
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel.
    ///
    /// Returns `None` when this representation holds any other variant.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
46
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner MSL compute kernel.
    ///
    /// Returns `None` when this representation holds any other variant.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
56
57impl Display for AutoRepresentation {
58 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
61 #[cfg(feature = "spirv")]
62 AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
63 #[cfg(feature = "msl")]
64 AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
65 }
66 }
67}
68
69impl Compiler for AutoCompiler {
70 type Representation = AutoRepresentation;
71
72 type CompilationOptions = WgpuCompilationOptions;
73
74 fn compile(
75 &mut self,
76 kernel: KernelDefinition,
77 compilation_options: &Self::CompilationOptions,
78 mode: ExecutionMode,
79 ) -> Self::Representation {
80 match self {
81 AutoCompiler::Wgsl(wgsl_compiler) => {
82 Compiler::compile(wgsl_compiler, kernel, compilation_options, mode).into()
83 }
84 #[cfg(feature = "spirv")]
85 AutoCompiler::SpirV(spirv_compiler) => {
86 Compiler::compile(spirv_compiler, kernel, compilation_options, mode).into()
87 }
88 #[cfg(feature = "msl")]
89 AutoCompiler::Msl(msl_compiler) => {
90 use cubecl_cpp;
92 let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
93 Compiler::compile(msl_compiler, kernel, &compilation_options, mode).into()
94 }
95 }
96 }
97
98 fn elem_size(&self, elem: cubecl_core::ir::Elem) -> usize {
99 match self {
100 AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
101 #[cfg(feature = "spirv")]
102 AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
103 #[cfg(feature = "msl")]
104 AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
105 }
106 }
107
108 fn extension(&self) -> &'static str {
109 match self {
110 AutoCompiler::Wgsl(_) => "wgsl",
111 #[cfg(feature = "spirv")]
112 AutoCompiler::SpirV(_) => "spv",
113 #[cfg(feature = "msl")]
114 AutoCompiler::Msl(_) => "msl",
115 }
116 }
117}
118
119impl AutoCompiler {
120 pub fn compile(
121 &mut self,
122 server: &mut WgpuServer,
123 kernel: <WgpuServer as ComputeServer>::Kernel,
124 mode: ExecutionMode,
125 ) -> CompiledKernel<Self> {
126 match self {
127 AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
128 #[cfg(feature = "spirv")]
129 AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
130 #[cfg(feature = "msl")]
131 AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
132 }
133 }
134
135 pub fn lang_tag(&self) -> &'static str {
136 match self {
137 AutoCompiler::Wgsl(_) => "wgsl",
138 #[cfg(feature = "spirv")]
139 AutoCompiler::SpirV(_) => "spirv",
140 #[cfg(feature = "msl")]
141 AutoCompiler::Msl(_) => "msl",
142 }
143 }
144}