// cubecl_wgpu/compiler/base.rs

1use std::fmt::Display;
2
3use cubecl_core::{
4    Compiler, ExecutionMode, WgpuCompilationOptions,
5    prelude::{CompiledKernel, KernelDefinition},
6    server::ComputeServer,
7};
8#[cfg(feature = "msl")]
9use cubecl_cpp::shared::MslComputeKernel;
10use cubecl_runtime::compiler::CompilationError;
11use derive_more::derive::From;
12
13use crate::{WgpuServer, WgslCompiler};
14
15use super::wgsl;
16
17#[allow(clippy::large_enum_variant)]
18#[derive(Debug, Clone)]
19pub enum AutoCompiler {
20    Wgsl(WgslCompiler),
21    #[cfg(feature = "spirv")]
22    SpirV(cubecl_spirv::SpirvCompiler),
23    #[cfg(feature = "msl")]
24    Msl(cubecl_cpp::MslCompiler),
25}
26
27#[derive(From)]
28#[allow(clippy::large_enum_variant)]
29pub enum AutoRepresentation {
30    Wgsl(wgsl::ComputeShader),
31    #[cfg(feature = "spirv")]
32    SpirV(cubecl_spirv::SpirvKernel),
33    #[cfg(feature = "msl")]
34    Msl(MslComputeKernel),
35}
36
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the wrapped SPIR-V kernel.
    ///
    /// Returns `None` when this representation was produced by a different backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
46
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the wrapped Metal (MSL) kernel.
    ///
    /// Returns `None` when this representation was produced by a different backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
56
57impl Display for AutoRepresentation {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
61            #[cfg(feature = "spirv")]
62            AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
63            #[cfg(feature = "msl")]
64            AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
65        }
66    }
67}
68
69impl Compiler for AutoCompiler {
70    type Representation = AutoRepresentation;
71
72    type CompilationOptions = WgpuCompilationOptions;
73
74    fn compile(
75        &mut self,
76        kernel: KernelDefinition,
77        compilation_options: &Self::CompilationOptions,
78        mode: ExecutionMode,
79    ) -> Result<Self::Representation, CompilationError> {
80        let kernel = match self {
81            AutoCompiler::Wgsl(wgsl_compiler) => {
82                Compiler::compile(wgsl_compiler, kernel, compilation_options, mode)?.into()
83            }
84            #[cfg(feature = "spirv")]
85            AutoCompiler::SpirV(spirv_compiler) => {
86                Compiler::compile(spirv_compiler, kernel, compilation_options, mode)?.into()
87            }
88            #[cfg(feature = "msl")]
89            AutoCompiler::Msl(msl_compiler) => {
90                // override compilation options with cpp compiler options for metal
91                use cubecl_cpp;
92                let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
93                Compiler::compile(msl_compiler, kernel, &compilation_options, mode)?.into()
94            }
95        };
96
97        Ok(kernel)
98    }
99
100    fn elem_size(&self, elem: cubecl_core::ir::ElemType) -> usize {
101        match self {
102            AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
103            #[cfg(feature = "spirv")]
104            AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
105            #[cfg(feature = "msl")]
106            AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
107        }
108    }
109
110    fn extension(&self) -> &'static str {
111        match self {
112            AutoCompiler::Wgsl(_) => "wgsl",
113            #[cfg(feature = "spirv")]
114            AutoCompiler::SpirV(_) => "spv",
115            #[cfg(feature = "msl")]
116            AutoCompiler::Msl(_) => "msl",
117        }
118    }
119}
120
121impl AutoCompiler {
122    pub fn compile(
123        &mut self,
124        server: &mut WgpuServer,
125        kernel: <WgpuServer as ComputeServer>::Kernel,
126        mode: ExecutionMode,
127    ) -> Result<CompiledKernel<Self>, CompilationError> {
128        match self {
129            AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
130            #[cfg(feature = "spirv")]
131            AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
132            #[cfg(feature = "msl")]
133            AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
134        }
135    }
136
137    pub fn lang_tag(&self) -> &'static str {
138        match self {
139            AutoCompiler::Wgsl(_) => "wgsl",
140            #[cfg(feature = "spirv")]
141            AutoCompiler::SpirV(_) => "spirv",
142            #[cfg(feature = "msl")]
143            AutoCompiler::Msl(_) => "msl",
144        }
145    }
146}