cubecl_wgpu/compiler/base.rs

1use std::fmt::Display;
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::{
5    Compiler, WgpuCompilationOptions,
6    prelude::{CompiledKernel, KernelDefinition},
7    server::ComputeServer,
8};
9#[cfg(feature = "msl")]
10use cubecl_cpp::shared::MslComputeKernel;
11use derive_more::derive::From;
12
13use crate::{WgpuServer, WgslCompiler};
14
15use super::wgsl;
16
17#[allow(clippy::large_enum_variant)]
18#[derive(Debug, Clone)]
19pub enum AutoCompiler {
20    Wgsl(WgslCompiler),
21    #[cfg(feature = "spirv")]
22    SpirV(cubecl_spirv::SpirvCompiler),
23    #[cfg(feature = "msl")]
24    Msl(cubecl_cpp::MslCompiler),
25}
26
27#[derive(From)]
28#[allow(clippy::large_enum_variant)]
29pub enum AutoRepresentation {
30    Wgsl(wgsl::ComputeShader),
31    #[cfg(feature = "spirv")]
32    SpirV(cubecl_spirv::SpirvKernel),
33    #[cfg(feature = "msl")]
34    Msl(MslComputeKernel),
35}
36
#[cfg(feature = "spirv")]
impl AutoRepresentation {
    /// Borrows the inner SPIR-V kernel.
    ///
    /// Returns `None` when this representation came from a different backend.
    pub fn as_spirv(&self) -> Option<&cubecl_spirv::SpirvKernel> {
        if let Self::SpirV(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
46
#[cfg(feature = "msl")]
impl AutoRepresentation {
    /// Borrows the inner Metal compute kernel.
    ///
    /// Returns `None` when this representation came from a different backend.
    pub fn as_msl(&self) -> Option<&MslComputeKernel> {
        if let Self::Msl(kernel) = self {
            Some(kernel)
        } else {
            None
        }
    }
}
56
57impl Display for AutoRepresentation {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            AutoRepresentation::Wgsl(compute_shader) => compute_shader.fmt(f),
61            #[cfg(feature = "spirv")]
62            AutoRepresentation::SpirV(spirv_kernel) => spirv_kernel.fmt(f),
63            #[cfg(feature = "msl")]
64            AutoRepresentation::Msl(compute_shader) => compute_shader.fmt(f),
65        }
66    }
67}
68
69impl Compiler for AutoCompiler {
70    type Representation = AutoRepresentation;
71
72    type CompilationOptions = WgpuCompilationOptions;
73
74    fn compile(
75        &mut self,
76        kernel: KernelDefinition,
77        compilation_options: &Self::CompilationOptions,
78        mode: ExecutionMode,
79    ) -> Self::Representation {
80        match self {
81            AutoCompiler::Wgsl(wgsl_compiler) => {
82                Compiler::compile(wgsl_compiler, kernel, compilation_options, mode).into()
83            }
84            #[cfg(feature = "spirv")]
85            AutoCompiler::SpirV(spirv_compiler) => {
86                Compiler::compile(spirv_compiler, kernel, compilation_options, mode).into()
87            }
88            #[cfg(feature = "msl")]
89            AutoCompiler::Msl(msl_compiler) => {
90                // override compilation options with cpp compiler options for metal
91                use cubecl_cpp;
92                let compilation_options = cubecl_cpp::shared::CompilationOptions::default();
93                Compiler::compile(msl_compiler, kernel, &compilation_options, mode).into()
94            }
95        }
96    }
97
98    fn elem_size(&self, elem: cubecl_core::ir::Elem) -> usize {
99        match self {
100            AutoCompiler::Wgsl(wgsl_compiler) => wgsl_compiler.elem_size(elem),
101            #[cfg(feature = "spirv")]
102            AutoCompiler::SpirV(spirv_compiler) => spirv_compiler.elem_size(elem),
103            #[cfg(feature = "msl")]
104            AutoCompiler::Msl(msl_compiler) => msl_compiler.elem_size(elem),
105        }
106    }
107
108    fn extension(&self) -> &'static str {
109        match self {
110            AutoCompiler::Wgsl(_) => "wgsl",
111            #[cfg(feature = "spirv")]
112            AutoCompiler::SpirV(_) => "spv",
113            #[cfg(feature = "msl")]
114            AutoCompiler::Msl(_) => "msl",
115        }
116    }
117}
118
119impl AutoCompiler {
120    pub fn compile(
121        &mut self,
122        server: &mut WgpuServer,
123        kernel: <WgpuServer as ComputeServer>::Kernel,
124        mode: ExecutionMode,
125    ) -> CompiledKernel<Self> {
126        match self {
127            AutoCompiler::Wgsl(_) => kernel.compile(self, &server.compilation_options, mode),
128            #[cfg(feature = "spirv")]
129            AutoCompiler::SpirV(_) => crate::vulkan::compile(self, server, kernel, mode),
130            #[cfg(feature = "msl")]
131            AutoCompiler::Msl(_) => kernel.compile(self, &server.compilation_options, mode),
132        }
133    }
134
135    pub fn lang_tag(&self) -> &'static str {
136        match self {
137            AutoCompiler::Wgsl(_) => "wgsl",
138            #[cfg(feature = "spirv")]
139            AutoCompiler::SpirV(_) => "spirv",
140            #[cfg(feature = "msl")]
141            AutoCompiler::Msl(_) => "msl",
142        }
143    }
144}