1use nom::{
2 branch::alt,
3 bytes::complete::{tag, take_until, take_until1},
4 character::streaming::multispace0,
5 combinator::{cut, map, opt},
6 error::ParseError,
7 number::complete::float,
8 sequence::{delimited, preceded, terminated},
9 IResult, Parser,
10};
11use wgpu::{BindGroupLayoutDescriptor, ShaderStages, StorageTextureAccess};
12
13use crate::{
14 parser::{
15 get_exports, parse_tokens, process, vec_to_owned, Definition, ExpansionError,
16 ExportedMoreThanOnce, NomError, Token,
17 },
18 utils::{Dispatcher, WorkgroupSize},
19};
20use pollster::FutureExt;
21use std::{
22 borrow::Cow,
23 collections::{BTreeMap, HashMap},
24 fmt::Write,
25 fs::DirEntry,
26 path::Path,
27 sync::Arc,
28};
29
/// A compiled compute pipeline together with the bind group layout and the optional
/// dispatcher it was built with. Produced by `ProcessedShader::build`.
#[derive(Debug)]
pub struct NonBoundPipeline {
    /// Label copied from the shader label of the specs used to build the pipeline.
    pub label: Option<String>,
    /// The compiled compute pipeline.
    pub compute_pipeline: wgpu::ComputePipeline,
    /// The first bind group layout produced by layout inference in `build`.
    pub bind_group_layout: wgpu::BindGroupLayout,
    /// Dispatcher carried over from the specs, if one was configured.
    pub dispatcher: Option<Dispatcher<'static>>,
}
37
/// Build-time configuration for a compute shader: workgroup size, dispatcher,
/// push-constant range, preprocessor definitions, entry point and debug labels.
///
/// Construct with `ShaderSpecs::new` and refine with the builder-style methods.
#[derive(Debug, Clone)]
pub struct ShaderSpecs<'def> {
    /// Workgroup size; `ShaderSpecs::new` also registers its three components in
    /// `shader_defs`.
    pub workgroup_size: WorkgroupSize,
    /// Optional dispatcher handed on to the built pipeline.
    pub dispatcher: Option<Dispatcher<'static>>,
    /// End of the compute push-constant range `0..n` used by `ProcessedShader::build`
    /// (which falls back to 64 when this is `None`).
    pub push_constants: Option<u32>,
    /// Definitions looked up while expanding the shader source.
    pub shader_defs: HashMap<Cow<'def, str>, Definition<'def>>,
    /// Passed as-is to the compute pipeline descriptor's `entry_point`.
    pub entry_point: Option<String>,

    // Debug labels for the objects created by `ProcessedShader::build`.
    pub shader_label: Option<String>,
    pub bindgroup_layout_label: Option<String>,
    pub pipelinelayout_label: Option<String>,
    pub pipeline_label: Option<String>,
}
51
52impl<'def> ShaderSpecs<'def> {
53 pub fn new(workgroup_size: impl Into<WorkgroupSize>) -> Self {
54 let workgroup_size = workgroup_size.into();
55 let shader_defs = HashMap::from([
56 (
57 workgroup_size.x_name.clone(),
58 Definition::UInt(workgroup_size.x),
59 ),
60 (
61 workgroup_size.y_name.clone(),
62 Definition::UInt(workgroup_size.y),
63 ),
64 (
65 workgroup_size.z_name.clone(),
66 Definition::UInt(workgroup_size.z),
67 ),
68 ]);
69 Self {
70 workgroup_size: workgroup_size.into(),
71 dispatcher: None,
72 push_constants: None,
73 shader_defs,
74 entry_point: None,
75 shader_label: None,
76 bindgroup_layout_label: None,
77 pipelinelayout_label: None,
78 pipeline_label: None,
79 }
80 }
81
82 pub fn workgroupsize(mut self, val: WorkgroupSize) -> Self {
83 self.workgroup_size = val;
84 self
85 }
86
87 pub fn dispatcher(mut self, val: Dispatcher<'static>) -> Self {
88 self.dispatcher = Some(val);
89 self
90 }
91
92 pub fn direct_dispatcher(mut self, dims: &[u32; 3]) -> Self {
93 self.dispatcher = Some(Dispatcher::new_direct(dims, &self.workgroup_size));
94 self
95 }
96
97 pub fn push_constants(mut self, val: u32) -> Self {
98 self.push_constants = Some(val);
99 self
100 }
101
102 pub fn extend_defs(
103 mut self,
104 vals: impl IntoIterator<Item = (impl Into<Cow<'static, str>>, Definition<'def>)>,
105 ) -> Self {
106 let iter = vals.into_iter().map(|(key, val)| (key.into(), val));
107 self.shader_defs.extend(iter);
108 self
109 }
110
111 pub fn shader_label(mut self, val: &str) -> Self {
112 self.shader_label = Some(val.to_string());
113 self
114 }
115
116 pub fn bindgroup_layout_label(mut self, val: &str) -> Self {
117 self.bindgroup_layout_label = Some(val.to_string());
118 self
119 }
120
121 pub fn pipelinelayout_label(mut self, val: &str) -> Self {
122 self.pipelinelayout_label = Some(val.to_string());
123 self
124 }
125
126 pub fn pipeline_label(mut self, val: &str) -> Self {
127 self.pipeline_label = Some(val.to_string());
128 self
129 }
130
131 pub fn entry_point(mut self, val: &str) -> Self {
132 self.entry_point = Some(val.to_string());
133 self
134 }
135
136 pub fn labels(self, val: &str) -> Self {
137 self.shader_label(val)
138 .bindgroup_layout_label(val)
139 .pipelinelayout_label(val)
140 .pipeline_label(val)
141 }
142}
143
/// The specific failure behind a [`ParseShaderError`].
#[derive(Debug, thiserror::Error)]
enum ParseShaderErrorVariant<'a> {
    /// Export collection reported `ExportedMoreThanOnce`.
    #[error("{}", .0)]
    MultipleExports(#[from] ExportedMoreThanOnce),
    /// The shader source could not be tokenized; the error borrows the source text.
    #[error("{}", .0)]
    NomError(NomError<'a>),
}

/// Error produced while parsing one named shader. Its message names the shader and
/// includes the underlying cause.
#[derive(Debug, thiserror::Error)]
#[error("Parsing the shader {} encountered an error: {}", .name, .variant)]
pub struct ParseShaderError<'a> {
    /// Name (map key or file stem) of the shader that failed.
    name: String,
    /// What went wrong.
    variant: ParseShaderErrorVariant<'a>,
}
158
/// Errors returned while expanding or building a shader.
///
/// `Debug` is implemented by hand to print the same text as `Display`.
#[derive(thiserror::Error)]
pub enum ShaderError {
    /// Expanding the shader source failed.
    #[error("Expansion error: {}", .0)]
    ExpansionError(#[from] ExpansionError),

    /// wgpu reported a validation error while creating the shader module, pipeline
    /// layout or compute pipeline.
    #[error("Wgpu validation error occured in this shader:\n-------------\n{}\n\n-------------\nWgpu Error: {}\n-------------", .shader, .error_string)]
    ValidationError {
        /// The shader source rendered as a numbered listing.
        shader: String,
        /// The wgpu error rendered with `to_string`.
        error_string: String,
    },
}
170
171impl std::fmt::Debug for ShaderError {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 let s = self.to_string();
174 f.write_str(&s)
175 }
176}
177
/// A shader whose source has been fully expanded, paired with the specs used to build it.
///
/// Its `Debug` output is the source as a numbered listing.
#[derive(Clone)]
pub struct ProcessedShader<'def> {
    /// Expanded WGSL source, compiled by `build`.
    pub source: String,
    /// Build configuration (workgroup size, labels, push constants, ...).
    pub specs: ShaderSpecs<'def>,
}
185
/// Renders `shader` as a numbered listing for error messages.
///
/// The result starts with a newline. Every source line is then emitted as
/// `"<line number> <line>\n"`, with line numbers (starting at 1) right-aligned to the
/// width of the largest line number.
fn format_shader(shader: &str) -> String {
    let line_count = shader.lines().count();
    // Width of the largest line number, computed exactly on integers (a float `log10`
    // can round the wrong way near powers of ten).
    let width = line_count.to_string().len();

    let mut out = String::with_capacity(1 + shader.len() + line_count * (width + 2));
    out.push('\n');
    for (index, line) in shader.lines().enumerate() {
        writeln!(out, "{:>width$} {line}", index + 1)
            .expect("writing to a String cannot fail");
    }
    out
}
198
199impl<'def> std::fmt::Debug for ProcessedShader<'def> {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 let s = format_shader(&self.source);
202 f.write_str(&s)
203 }
204}
205
206impl<'def> ProcessedShader<'def> {
207 pub fn get_source(&self) -> &str {
208 &self.source
209 }
210
211 pub fn build(self, device: &wgpu::Device) -> Result<Arc<NonBoundPipeline>, ShaderError> {
212 let Self { source, specs } = self;
213
214 let mut bind_group_layout =
215 infer_layout(&source, device, specs.bindgroup_layout_label.as_deref());
216
217 let bind_group_layouts = bind_group_layout.iter().collect::<Vec<_>>();
218
219 device.push_error_scope(wgpu::ErrorFilter::Validation);
220 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
221 label: specs.shader_label.as_deref(),
222 source: wgpu::ShaderSource::Wgsl((&source).into()),
223 });
224 if let Some(err) = device.pop_error_scope().block_on() {
225 return Err(ShaderError::ValidationError {
226 error_string: err.to_string(),
227 shader: format_shader(&source),
228 });
229 }
230
231 device.push_error_scope(wgpu::ErrorFilter::Validation);
232 let pipelinelayout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
233 label: specs.pipelinelayout_label.as_deref(),
234 bind_group_layouts: &bind_group_layouts,
235 push_constant_ranges: &[wgpu::PushConstantRange {
236 stages: wgpu::ShaderStages::COMPUTE,
237 range: 0..specs.push_constants.unwrap_or(64),
238 }],
239 });
240 if let Some(err) = device.pop_error_scope().block_on() {
241 return Err(ShaderError::ValidationError {
242 error_string: err.to_string(),
243 shader: format_shader(&source),
244 });
245 }
246
247 device.push_error_scope(wgpu::ErrorFilter::Validation);
248 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
249 label: specs.pipeline_label.as_deref(),
250 layout: Some(&pipelinelayout),
251 module: &shader,
252 entry_point: specs.entry_point.as_deref(),
253 compilation_options: Default::default(),
254 cache: None,
255 });
256 if let Some(err) = device.pop_error_scope().block_on() {
257 return Err(ShaderError::ValidationError {
258 error_string: err.to_string(),
259 shader: format_shader(&source),
260 });
261 }
262
263 Ok(Arc::new(NonBoundPipeline {
264 label: specs.shader_label,
265 compute_pipeline,
266 bind_group_layout: bind_group_layout.swap_remove(0),
269 dispatcher: specs.dispatcher,
270 }))
271 }
272}
273
/// A shader source split into tokens by [`parse_shader`].
#[derive(Debug, Clone)]
pub struct ParsedShader<'a>(pub Vec<Token<'a>>);
276
277impl<'a> ParsedShader<'a> {
278 fn into_owned(self) -> ParsedShader<'static> {
279 ParsedShader(vec_to_owned(self.0))
280 }
281
282 fn get_exports(
283 &self,
284 exports: &mut HashMap<Cow<'a, str>, Vec<Token<'a>>>,
285 ) -> Result<(), ExportedMoreThanOnce> {
286 get_exports(&self.0, exports)
287 }
288}
289
/// A set of parsed shaders plus the exports collected from all of them, used to expand
/// shaders by name with `process_by_name`.
#[derive(Debug, Clone)]
pub struct ShaderProcessor<'a> {
    /// Parsed shaders keyed by name.
    pub shaders: HashMap<Cow<'a, str>, ParsedShader<'a>>,
    /// Exports collected from every shader (filled through `ParsedShader::get_exports`).
    pub exports: HashMap<Cow<'a, str>, Vec<Token<'a>>>,
}
295
/// Checks whether a directory entry from `std::fs::read_dir` is a `.wgsl` file.
///
/// Returns `Some((name, path))` for a regular file whose extension is `wgsl` in any
/// letter case. `name` is the file stem, or the whole path when `full_path` is `true`, and
/// `path` is the full path. Returns `None` if:
/// - the entry or its metadata could not be read,
/// - it is not a regular file,
/// - the extension does not match, or
/// - the name or path is not valid UTF-8.
pub fn validate_wgsl_file(
    file: std::result::Result<DirEntry, std::io::Error>,
    full_path: bool,
) -> Option<(String, String)> {
    let entry = file.ok()?;
    let metadata = entry.metadata().ok()?;
    if !metadata.is_file() {
        return None;
    }

    let path = entry.path();
    let has_wgsl_extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| ext.to_lowercase() == "wgsl");
    if !has_wgsl_extension {
        return None;
    }

    // A path with an extension always has a file stem, so this `unwrap` cannot fail.
    let name = if full_path {
        path.as_os_str().to_os_string()
    } else {
        path.file_stem().unwrap().to_os_string()
    };
    let name = name.into_string().ok()?;
    let path = path.into_os_string().into_string().ok()?;
    Some((name, path))
}
326
327pub fn parse_shader<'a>(input: &'a str) -> Result<ParsedShader<'a>, NomError<'a>> {
328 nom_supreme::final_parser::final_parser(parse_tokens)(input).map(ParsedShader)
329}
330
331impl<'a> ShaderProcessor<'a> {
332 pub fn load_dir_dyn(path: impl AsRef<Path>) -> Result<Self, std::io::Error> {
333 let read_dir = std::fs::read_dir(path)?;
334
335 let mut exports = HashMap::new();
336
337 let shaders = read_dir
338 .filter_map(|file| {
339 validate_wgsl_file(file, false).and_then(|(name, path)| {
340 let source = std::fs::read_to_string(path).ok()?;
341 Some((name, source))
342 })
343 })
344 .map(|(name, source)| {
345 let parsed = match parse_shader(&source) {
346 Ok(val) => val,
347 Err(e) => {
348 println!("Failed parsing of shader {name}");
349 panic!("{}", e);
350 }
351 }
352 .into_owned();
353 match parsed.get_exports(&mut exports) {
357 Ok(()) => (),
358 Err(err) => Err(ParseShaderError {
359 name: name.clone(),
360 variant: ParseShaderErrorVariant::MultipleExports(err),
361 })
362 .unwrap(),
363 };
364 (name.into(), parsed)
365 })
366 .collect();
367 Ok(Self { shaders, exports })
368 }
369
370 pub fn from_shader_hashmap(
371 shaders: &'a HashMap<Cow<'a, str>, String>,
372 ) -> Result<ShaderProcessor<'a>, ParseShaderError<'a>> {
373 let mut exports = HashMap::new();
374 let shaders = shaders
375 .iter()
376 .map(|(name, source)| {
377 let parsed = parse_shader(&source).map_err(|err| ParseShaderError {
378 name: name.to_string(),
379 variant: ParseShaderErrorVariant::NomError(err),
380 })?;
381 match parsed.get_exports(&mut exports) {
382 Ok(()) => (),
383 Err(err) => {
384 return Err(ParseShaderError {
385 name: name.to_string(),
386 variant: ParseShaderErrorVariant::MultipleExports(err),
387 })
388 }
389 };
390 Ok((name.clone(), parsed))
391 })
392 .collect::<Result<HashMap<Cow<str>, ParsedShader>, ParseShaderError>>()?;
393
394 Ok(Self { shaders, exports })
395 }
396
397 pub fn from_parsed_shader_hashmap(
398 shaders: HashMap<Cow<'a, str>, ParsedShader<'a>>,
399 ) -> Result<ShaderProcessor<'a>, ParseShaderError<'a>> {
400 let mut exports = HashMap::new();
401 for (name, parsed) in shaders.iter() {
402 match parsed.get_exports(&mut exports) {
403 Ok(()) => (),
404 Err(e) => {
405 return Err(ParseShaderError {
406 name: name.to_string(),
407 variant: e.into(),
408 })
409 }
410 }
411 }
412
413 Ok(Self { shaders, exports })
414 }
415
416 pub fn process_by_name<'wg, 'def>(
417 &self,
418 name: &str,
419 specs: ShaderSpecs<'def>,
420 ) -> Result<ProcessedShader<'def>, crate::parser::ExpansionError> {
421 let definitions = &specs.shader_defs;
422 let lookup = |s: Cow<str>| definitions.get(&s as &str).cloned();
423 let exports = |s| self.exports.get(&s).cloned();
424 let source = process(self.shaders[name].0.clone(), lookup, exports)?;
425 Ok(ProcessedShader { source, specs })
426 }
427}
428
429fn attribute<'a, Error: ParseError<&'a str>>(
430 attr_name: &'static str,
431) -> impl Fn(&'a str) -> IResult<&'a str, u32, Error> {
432 move |inp| {
433 let (inp, _) = terminated(take_until("@"), tag("@"))(inp)?;
434
435 let (inp, _) = ws(tag(attr_name))(inp)?;
436
437 let (inp, group_idx) = delimited(ws(tag("(")), ws(float), ws(tag(")")))(inp)?;
438
439 Ok((inp, group_idx as u32))
440 }
441}
442
443fn ws<'a, F: Parser<&'a str, O, E>, O, E: ParseError<&'a str>>(
457 f: F,
458) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> {
459 preceded(multispace0, f)
460}
461
462fn buffer_style(inp: &str) -> IResult<&str, wgpu::BindingType> {
463 let (inp, inner) = delimited(ws(tag("<")), ws(cut(take_until1(">"))), tag(">"))(inp)?;
464
465 let (inner, mut buffer_binding_type) = ws(alt((
466 map(tag("storage"), |_| wgpu::BufferBindingType::Storage {
467 read_only: true,
468 }),
469 map(tag("uniform"), |_| wgpu::BufferBindingType::Uniform),
470 )))(inner)?;
471
472 if let wgpu::BufferBindingType::Storage { read_only } = &mut buffer_binding_type {
473 opt(preceded(
474 ws(tag(",")),
475 ws(map(tag("read_write"), |t| {
476 *read_only = false;
477 t
478 })),
479 ))(inner)?;
480 }
481
482 let out = wgpu::BindingType::Buffer {
483 ty: buffer_binding_type,
484 has_dynamic_offset: false,
485 min_binding_size: None,
486 };
487 Ok((inp, out))
488}
489
490fn texture_style(inp: &str) -> IResult<&str, wgpu::BindingType> {
491 let (inp, _) = terminated(take_until(":"), tag(":"))(inp)?;
492
493 let (inp, ty) = ws(alt((
494 map(tag("sampler"), |_| {
495 wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering)
496 }),
497 map(tag("sampler_comparison"), |_| {
498 wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Comparison)
499 }),
500 map(tag("texture_depth_2d"), |_| wgpu::BindingType::Texture {
501 sample_type: wgpu::TextureSampleType::Depth,
502 view_dimension: wgpu::TextureViewDimension::D2,
503 multisampled: false,
504 }),
505 map(tag("texture_depth_2d_array"), |_| {
506 wgpu::BindingType::Texture {
507 sample_type: wgpu::TextureSampleType::Depth,
508 view_dimension: wgpu::TextureViewDimension::D2Array,
509 multisampled: false,
510 }
511 }),
512 map(tag("texture_depth_cube"), |_| wgpu::BindingType::Texture {
513 sample_type: wgpu::TextureSampleType::Depth,
514 view_dimension: wgpu::TextureViewDimension::Cube,
515 multisampled: false,
516 }),
517 map(tag("texture_depth_cube_array"), |_| {
518 wgpu::BindingType::Texture {
519 sample_type: wgpu::TextureSampleType::Depth,
520 view_dimension: wgpu::TextureViewDimension::CubeArray,
521 multisampled: false,
522 }
523 }),
524 map(tag("texture_depth_multisampled_2d"), |_| {
525 wgpu::BindingType::Texture {
526 sample_type: wgpu::TextureSampleType::Depth,
527 view_dimension: wgpu::TextureViewDimension::D2,
528 multisampled: true,
529 }
530 }),
531 parse_texture_type,
532 )))(inp)?;
533
534 Ok((inp, ty))
535}
536
537fn parse_texture_type(inp: &str) -> IResult<&str, wgpu::BindingType> {
538 let (inp, _) = tag("texture_")(inp)?;
539
540 let (inp, (storage, multisampled, view_dimension)) = alt((
541 map(tag("1d"), |_| {
542 (false, false, wgpu::TextureViewDimension::D1)
543 }),
544 map(tag("storage_1d"), |_| {
545 (true, false, wgpu::TextureViewDimension::D1)
546 }),
547 map(tag("2d"), |_| {
548 (false, false, wgpu::TextureViewDimension::D2)
549 }),
550 map(tag("storage_2d"), |_| {
551 (true, false, wgpu::TextureViewDimension::D2)
552 }),
553 map(tag("storage_2d_array"), |_| {
554 (true, false, wgpu::TextureViewDimension::D2Array)
555 }),
556 map(tag("multisampled_2d"), |_| {
557 (false, true, wgpu::TextureViewDimension::D2)
558 }),
559 map(tag("2d_array"), |_| {
560 (false, false, wgpu::TextureViewDimension::D2Array)
561 }),
562 map(tag("3d"), |_| {
563 (false, false, wgpu::TextureViewDimension::D3)
564 }),
565 map(tag("storage_3d"), |_| {
566 (true, false, wgpu::TextureViewDimension::D3)
567 }),
568 map(tag("cube"), |_| {
569 (false, false, wgpu::TextureViewDimension::Cube)
570 }),
571 map(tag("cube_array"), |_| {
572 (false, false, wgpu::TextureViewDimension::CubeArray)
573 }),
574 ))(inp)?;
575
576 let (inp, inner) = delimited(ws(tag("<")), ws(take_until(">")), tag(">"))(inp)?;
577
578 let ty = if storage {
579 let (inner, format) = alt((
580 map(tag("rgba8unorm"), |_| wgpu::TextureFormat::Rgba8Unorm),
581 map(tag("rgba8snorm"), |_| wgpu::TextureFormat::Rgba8Snorm),
582 map(tag("rgba8uint"), |_| wgpu::TextureFormat::Rgba8Uint),
583 map(tag("rgba8sint"), |_| wgpu::TextureFormat::Rgba8Sint),
584 map(tag("rgba16uint"), |_| wgpu::TextureFormat::Rgba16Uint),
585 map(tag("rgba16sint"), |_| wgpu::TextureFormat::Rgba16Sint),
586 map(tag("rgba16float"), |_| wgpu::TextureFormat::Rgba16Float),
587 map(tag("r32uint"), |_| wgpu::TextureFormat::R32Uint),
588 map(tag("r32sint"), |_| wgpu::TextureFormat::R32Sint),
589 map(tag("r32float"), |_| wgpu::TextureFormat::R32Float),
590 map(tag("rg32uint"), |_| wgpu::TextureFormat::Rg32Uint),
591 map(tag("rg32sint"), |_| wgpu::TextureFormat::Rg32Sint),
592 map(tag("rg32float"), |_| wgpu::TextureFormat::Rg32Float),
593 map(tag("rgba32uint"), |_| wgpu::TextureFormat::Rgba32Uint),
594 map(tag("rgba32sint"), |_| wgpu::TextureFormat::Rgba32Sint),
595 map(tag("rgba32float"), |_| wgpu::TextureFormat::Rgba32Float),
596 map(tag("bgra8unorm"), |_| wgpu::TextureFormat::Bgra8Unorm),
597 ))(inner)?;
598
599 let (_, access) = preceded(
600 ws(tag(",")),
601 alt((
602 ws(map(tag("read"), |_| StorageTextureAccess::ReadOnly)),
603 ws(map(tag("write"), |_| StorageTextureAccess::WriteOnly)),
604 ws(map(tag("read_write"), |_| StorageTextureAccess::ReadWrite)),
605 )),
606 )(inner)?;
607
608 wgpu::BindingType::StorageTexture {
609 access,
610 format,
611 view_dimension,
612 }
613 } else {
614 let (_, sample_type) = alt((
615 map(tag("f32"), |_| wgpu::TextureSampleType::Float {
616 filterable: true,
617 }),
618 map(tag("i32"), |_| wgpu::TextureSampleType::Sint),
619 map(tag("u32"), |_| wgpu::TextureSampleType::Uint),
620 ))(inner)?;
621
622 wgpu::BindingType::Texture {
623 sample_type,
624 view_dimension,
625 multisampled,
626 }
627 };
628
629 Ok((inp, ty))
630}
631
632pub fn parse_layout_entry(inp: &str) -> IResult<&str, (u32, wgpu::BindGroupLayoutEntry)> {
633 let (inp, group_idx) = attribute("group")(inp)?;
634
635 let (inp, binding_idx) = attribute("binding")(inp)?;
636
637 let (inp, _) = ws(tag("var"))(inp)?;
638
639 let (inp, ty) = alt((buffer_style, texture_style))(inp)?;
640
641 let out = wgpu::BindGroupLayoutEntry {
642 binding: binding_idx,
643 visibility: ShaderStages::COMPUTE,
645 ty,
646 count: None,
647 };
648
649 Ok((inp, (group_idx, out)))
650}
651
652pub fn infer_layout(
654 mut inp: &str,
655 device: &wgpu::Device,
656 label: Option<&str>,
657) -> Vec<wgpu::BindGroupLayout> {
658 let mut map = BTreeMap::new();
659 while let Ok((new_inp, (group_idx, layout))) = parse_layout_entry(inp) {
660 map.entry(group_idx).or_insert(Vec::new()).push(layout);
661 inp = new_inp;
662 }
663
664 map.into_iter()
665 .map(|(_group_idx, entries)| {
666 let desc = BindGroupLayoutDescriptor {
667 label,
668 entries: &entries,
669 };
670 let layout = device.create_bind_group_layout(&desc);
671 layout
672 })
673 .collect()
674}
675
// Tests that need a real device obtained through `default_device`
// (presumably requiring a usable GPU adapter — confirm in CI setup).
#[cfg(test)]
pub mod tests {
    use crate::utils::default_device;

    use super::*;
    use pollster::FutureExt;

    /// Smoke test: `infer_layout` runs on a small shader without panicking.
    ///
    /// NOTE(review): the only `@group`/`@binding` pair in `data` sits inside a `//`
    /// comment. The layout scanner works on raw text and does not skip comments, so it
    /// still yields an entry here. Nothing is asserted about the result.
    #[test]
    fn yup() {
        let (device, _queue) = default_device().block_on().unwrap();
        let data = "

 // @group(0) @binding(1)
var<storage, read> buffer: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(){
 return;
}
 ";

        let _yup = infer_layout(data, &device, None);

        dbg!("Success!");
    }
}