1extern crate proc_macro;
102
103mod backends;
104
// Build-time guard: compiling this crate fails with the message below unless
// at least one of the `shaderc` or `naga` features is enabled.
// NOTE(review): the guard tests `shaderc`/`naga` while the message names the
// `glsl`/`hlsl`/`wgsl` features; presumably those features enable the
// backends in Cargo.toml -- confirm.
#[cfg(not(any(feature = "shaderc", feature = "naga")))]
compile_error!("no compiler backend enabled; please specify at least one of \
 the following input source features: `glsl`, `hlsl`, `wgsl`");
108
109use proc_macro::TokenStream;
110use syn::parse::{Parse, ParseStream, Result as ParseResult, Error as ParseError};
111use syn::{parse_macro_input, Ident, LitStr, Token, Expr};
112
/// Source language of the shader passed to the macro.
///
/// Chosen by the `glsl`, `hlsl` or `wgsl` keyword among the macro
/// parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum InputSourceLanguage {
    /// No language keyword was given; this is the configuration default.
    Unknown,
    /// Selected by the `glsl` keyword.
    Glsl,
    /// Selected by the `hlsl` keyword.
    Hlsl,
    /// Selected by the `wgsl` keyword.
    Wgsl,
}
/// SPIR-V version the shader is compiled for.
///
/// Variants are declared in ascending version order, so the derived ordering
/// compares versions (`Spirv1_0 < Spirv1_5`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum TargetSpirvVersion {
    /// Configuration default; also set by `vulkan`, `vulkan1_0`, `opengl`,
    /// `opengl4_5` and `webgpu`.
    Spirv1_0,
    // No parameter keyword in this file selects this version, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    Spirv1_1,
    // No parameter keyword in this file selects this version, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    Spirv1_2,
    /// Set by the `vulkan1_1` keyword.
    Spirv1_3,
    // No parameter keyword in this file selects this version, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    Spirv1_4,
    /// Set by the `vulkan1_2` keyword.
    Spirv1_5,
}
/// Client API environment the shader is compiled for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TargetEnvironmentType {
    /// Configuration default; also set by `vulkan`, `vulkan1_0`, `vulkan1_1`
    /// and `vulkan1_2`.
    Vulkan,
    /// Set by `opengl` and `opengl4_5`.
    OpenGL,
    /// Set by `webgpu`.
    WebGpu,
}
/// Optimization preference forwarded to the compiler backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum OptimizationLevel {
    /// Set by the `min_size` keyword.
    MinSize,
    /// Set by the `max_perf` keyword; the `hlsl` keyword also selects it.
    MaxPerformance,
    /// Configuration default.
    None,
}
/// Pipeline stage of the shader, chosen by a stage keyword among the macro
/// parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ShaderKind {
    /// No stage keyword was given; this is the configuration default.
    Unknown,

    /// Selected by `vert`.
    Vertex,
    /// Selected by `tesc`.
    TesselationControl,
    /// Selected by `tese`.
    TesselationEvaluation,
    /// Selected by `geom`.
    Geometry,
    /// Selected by `frag`.
    Fragment,
    /// Selected by `comp`.
    Compute,
    /// Selected by `mesh`.
    Mesh,
    /// Selected by `task`.
    Task,
    /// Selected by `rgen`.
    RayGeneration,
    /// Selected by `rint`.
    Intersection,
    /// Selected by `rahit`.
    AnyHit,
    /// Selected by `rchit`.
    ClosestHit,
    /// Selected by `rmiss`.
    Miss,
    /// Selected by `rcall`.
    Callable,
}
165
166struct ShaderCompilationConfig {
167 path: Option<String>,
168 lang: InputSourceLanguage,
169 incl_dirs: Vec<String>,
170 defs: Vec<(String, Option<String>)>,
171 spv_ver: TargetSpirvVersion,
172 env_ty: TargetEnvironmentType,
173 entry: String,
174 optim_lv: OptimizationLevel,
175 debug: bool,
176 kind: ShaderKind,
177 auto_bind: bool,
178 #[cfg(feature = "naga")]
180 y_flip: bool,
181}
182impl Default for ShaderCompilationConfig {
183 fn default() -> Self {
184 ShaderCompilationConfig {
185 path: None,
186 lang: InputSourceLanguage::Unknown,
187 incl_dirs: Vec::new(),
188 defs: Vec::new(),
189 spv_ver: TargetSpirvVersion::Spirv1_0,
190 env_ty: TargetEnvironmentType::Vulkan,
191 entry: "main".to_owned(),
192 optim_lv: OptimizationLevel::None,
193 debug: true,
194 kind: ShaderKind::Unknown,
195 auto_bind: false,
196
197 #[cfg(feature = "naga")]
198 y_flip: true,
199 }
200 }
201}
202
/// Parsed form of a `jit_spirv!` invocation: wraps the token stream that the
/// macro expands to.
struct JitSpirv(TokenStream);
204
205#[inline]
206fn parse_str(input: &mut ParseStream) -> ParseResult<String> {
207 input.parse::<LitStr>()
208 .map(|x| x.value())
209}
210#[inline]
211fn parse_ident(input: &mut ParseStream) -> ParseResult<String> {
212 input.parse::<Ident>()
213 .map(|x| x.to_string())
214}
215
216fn parse_compile_cfg(
217 input: &mut ParseStream
218) -> ParseResult<ShaderCompilationConfig> {
219 let mut cfg = ShaderCompilationConfig::default();
220 while !input.is_empty() {
221 use syn::Error;
222 input.parse::<Token![,]>()?;
224 let k = if let Ok(k) = input.parse::<Ident>() { k } else { break };
225 match &k.to_string() as &str {
226 "path" => {
227 input.parse::<Token![,]>()?;
228 cfg.path = Some(parse_str(input)?);
229 },
230
231 "glsl" => cfg.lang = InputSourceLanguage::Glsl,
232 "hlsl" => {
233 cfg.lang = InputSourceLanguage::Hlsl;
234 cfg.optim_lv = OptimizationLevel::MaxPerformance;
237 },
238 "wgsl" => cfg.lang = InputSourceLanguage::Wgsl,
239
240 "vert" => cfg.kind = ShaderKind::Vertex,
241 "tesc" => cfg.kind = ShaderKind::TesselationControl,
242 "tese" => cfg.kind = ShaderKind::TesselationEvaluation,
243 "geom" => cfg.kind = ShaderKind::Geometry,
244 "frag" => cfg.kind = ShaderKind::Fragment,
245 "comp" => cfg.kind = ShaderKind::Compute,
246 "mesh" => cfg.kind = ShaderKind::Mesh,
247 "task" => cfg.kind = ShaderKind::Task,
248 "rgen" => cfg.kind = ShaderKind::RayGeneration,
249 "rint" => cfg.kind = ShaderKind::Intersection,
250 "rahit" => cfg.kind = ShaderKind::AnyHit,
251 "rchit" => cfg.kind = ShaderKind::ClosestHit,
252 "rmiss" => cfg.kind = ShaderKind::Miss,
253 "rcall" => cfg.kind = ShaderKind::Callable,
254
255 "I" => {
256 cfg.incl_dirs.push(parse_str(input)?)
257 },
258 "D" => {
259 let k = parse_ident(input)?;
260 let v = if input.parse::<Token![=]>().is_ok() {
261 Some(parse_str(input)?)
262 } else { None };
263 cfg.defs.push((k, v));
264 },
265
266 "entry" => {
267 if input.parse::<Token![=]>().is_ok() {
268 cfg.entry = parse_str(input)?.to_owned();
269 }
270 }
271
272 "min_size" => cfg.optim_lv = OptimizationLevel::MinSize,
273 "max_perf" => cfg.optim_lv = OptimizationLevel::MaxPerformance,
274
275 "no_debug" => cfg.debug = false,
276
277 "vulkan" | "vulkan1_0" => {
278 cfg.env_ty = TargetEnvironmentType::Vulkan;
279 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
280 },
281 "vulkan1_1" => {
282 cfg.env_ty = TargetEnvironmentType::Vulkan;
283 cfg.spv_ver = TargetSpirvVersion::Spirv1_3;
284 },
285 "vulkan1_2" => {
286 cfg.env_ty = TargetEnvironmentType::Vulkan;
287 cfg.spv_ver = TargetSpirvVersion::Spirv1_5;
288 },
289 "opengl" | "opengl4_5" => {
290 cfg.env_ty = TargetEnvironmentType::OpenGL;
291 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
292 },
293 "webgpu" => {
294 cfg.env_ty = TargetEnvironmentType::WebGpu;
295 cfg.spv_ver = TargetSpirvVersion::Spirv1_0;
296 }
297
298 "auto_bind" => cfg.auto_bind = true,
299
300 #[cfg(feature = "naga")]
301 "no_y_flip" => cfg.y_flip = false,
302
303 _ => return Err(Error::new(k.span(), "unsupported compilation parameter")),
304 }
305 }
306 Ok(cfg)
307}
308
309fn generate_compile_code(
310 src: &Expr,
311 cfg: &ShaderCompilationConfig,
312) -> Result<proc_macro::TokenStream, String> {
313 use quote::quote;
314 let mut is_valid = false;
315 let mut out = quote!(Err(String::default()));
317 if let Ok(generated_code) = backends::naga::generate_compile_code(src, cfg) {
318 out.extend(quote!(.or_else(#generated_code)));
319 is_valid = true;
320 }
321 if let Ok(generated_code) = backends::shaderc::generate_compile_code(src, cfg) {
322 out.extend(quote!(.or_else(#generated_code)));
323 is_valid = true;
324 }
325 if !is_valid {
326 return Err("cannot find a proper shader compiler backend".to_owned());
327 }
328 Ok(out.into())
329}
330
331impl Parse for JitSpirv {
332 fn parse(mut input: ParseStream) -> ParseResult<Self> {
333 let src = input.parse::<Expr>()?;
334
335 let cfg = parse_compile_cfg(&mut input)?;
336 let tokens = generate_compile_code(&src, &cfg)
337 .map_err(|e| ParseError::new(input.span(), e))?;
338 Ok(JitSpirv(tokens))
339 }
340}
341
342#[proc_macro]
345pub fn jit_spirv(tokens: TokenStream) -> TokenStream {
346 let JitSpirv(tokens) = parse_macro_input!(tokens as JitSpirv);
347 tokens
348}