// cubecl_wgpu/compiler/base.rs

use std::fmt::Display;

use cubecl_core::{
    Compiler, ExecutionMode, WgpuCompilationOptions,
    ir::StorageType,
    prelude::{CompiledKernel, KernelDefinition},
    server::ComputeServer,
};
#[cfg(feature = "msl")]
use cubecl_cpp::shared::MslComputeKernel;
use cubecl_runtime::compiler::CompilationError;
use derive_more::derive::From;

use crate::{WgpuServer, WgslCompiler};

use super::wgsl;

18#[allow(clippy::large_enum_variant)]
19#[derive(Debug, Clone)]
20pub enum AutoCompiler {
21    Wgsl(WgslCompiler),
22    #[cfg(feature = "spirv")]
23    SpirV(cubecl_spirv::SpirvCompiler),
24    #[cfg(feature = "msl")]
25    Msl(cubecl_cpp::MslCompiler),
26}
27
28#[derive(From)]
29#[allow(clippy::large_enum_variant)]
30pub enum AutoRepresentation {
31    Wgsl(wgsl::ComputeShader),
32    #[cfg(feature = "spirv")]
33    SpirV(cubecl_spirv::SpirvKernel),
34    #[cfg(feature = "msl")]
35    Msl(MslComputeKernel),
36}
37
38#[derive(From, Clone, Copy)]
39#[allow(clippy::large_enum_variant)]
40pub enum AutoRepresentationRef<'a> {
41    Wgsl(&'a wgsl::ComputeShader),
42    #[cfg(feature = "spirv")]
43    SpirV(&'a cubecl_spirv::SpirvKernel),
44    #[cfg(feature = "msl")]
45    Msl(&'a MslComputeKernel),
46}
47
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the SPIR-V kernel if this representation came from the SPIR-V backend.
    ///
    /// Returns `None` for every other backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let AutoRepresentation::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}

#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the Metal kernel if this representation came from the MSL backend.
    ///
    /// Returns `None` for every other backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let AutoRepresentation::Msl(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}

68impl AutoRepresentation {
69    pub fn as_ref(&self) -> AutoRepresentationRef<'_> {
70        match self {
71            AutoRepresentation::Wgsl(compute_shader) => AutoRepresentationRef::Wgsl(compute_shader),
72            #[cfg(feature = "spirv")]
73            AutoRepresentation::SpirV(spirv_kernel) => AutoRepresentationRef::SpirV(spirv_kernel),
74            #[cfg(feature = "msl")]
75            AutoRepresentation::Msl(compute_shader) => AutoRepresentationRef::Msl(compute_shader),
76        }
77    }
78}
79
80impl Display for AutoRepresentation {
81    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
82        match self {
83            AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
84            #[cfg(feature = "spirv")]
85            AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
86            #[cfg(feature = "msl")]
87            AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
88        }
89    }
90}
91
92impl Compiler for AutoCompiler {
93    type Representation = AutoRepresentation;
94
95    type CompilationOptions = WgpuCompilationOptions;
96
97    fn compile(
98        &mut self,
99        kernel: KernelDefinition,
100        compilation_options: &Self::CompilationOptions,
101        mode: ExecutionMode,
102        addr_type: StorageType,
103    ) -> Result<Self::Representation, CompilationError> {
104        let kernel = match self {
105            AutoCompiler::Wgsl(wgsl_compiler) => {
106                Compiler::compile(wgsl_compiler, kernel, compilation_options, mode, addr_type)?
107                    .into()
108            }
109            #[cfg(feature = "spirv")]
110            AutoCompiler::SpirV(spirv_compiler) => {
111                Compiler::compile(spirv_compiler, kernel, compilation_options, mode, addr_type)?
112                    .into()
113            }
114            #[cfg(feature = "msl")]
115            AutoCompiler::Msl(msl_compiler) => {
116                // override compilation options with cpp compiler options for metal
117                use cubecl_cpp;
118                let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
119                Compiler::compile(msl_compiler, kernel, &compilation_options, mode, addr_type)?
120                    .into()
121            }
122        };
123
124        Ok(kernel)
125    }
126
127    fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
128        match self {
129            AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
130            #[cfg(feature = "spirv")]
131            AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
132            #[cfg(feature = "msl")]
133            AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
134        }
135    }
136
137    fn extension(&self) -> &'static str {
138        match self {
139            AutoCompiler::Wgsl(_) => "wgsl",
140            #[cfg(feature = "spirv")]
141            AutoCompiler::SpirV(_) => "spv",
142            #[cfg(feature = "msl")]
143            AutoCompiler::Msl(_) => "msl",
144        }
145    }
146}
147
148impl AutoCompiler {
149    pub fn compile(
150        &mut self,
151        server: &mut WgpuServer,
152        kernel: <WgpuServer as ComputeServer>::Kernel,
153        mode: ExecutionMode,
154    ) -> Result<CompiledKernel<Self>, CompilationError> {
155        match self {
156            AutoCompiler::Wgsl(_) => kernel.compile(
157                self,
158                &server.compilation_options,
159                mode,
160                kernel.address_type(),
161            ),
162            #[cfg(feature = "spirv")]
163            AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
164            #[cfg(feature = "msl")]
165            AutoCompiler::Msl(_) => kernel.compile(
166                self,
167                &server.compilation_options,
168                mode,
169                kernel.address_type(),
170            ),
171        }
172    }
173
174    pub fn lang_tag(&self) -> &'static str {
175        match self {
176            AutoCompiler::Wgsl(_) => "wgsl",
177            #[cfg(feature = "spirv")]
178            AutoCompiler::SpirV(_) => "spirv",
179            #[cfg(feature = "msl")]
180            AutoCompiler::Msl(_) => "msl",
181        }
182    }
183}