// wgsl-parser 0.5.0
//
// A zero-copy recursive-descent parser for the WebGPU Shading Language (WGSL).
// See the crate documentation for details.
use std::sync::Arc;

use gramatika::ParseStreamer;
use snapshot::{begin_snapshots, snapshot, snapshot_test};

use super::token;

// Snapshot-testing setup from the `snapshot` crate; presumably required by the
// `#[snapshot_test]` / `snapshot!` uses below (confirm in that crate's docs).
// NOTE(review): snapshot identity may depend on the order/position of each
// `snapshot!` invocation — check before reordering or restyling the tests.
begin_snapshots!();

use crate::{
	decl::Decl,
	modules::ImportedSymbol,
	utils::{WithComments, WithTokens},
	ParseStream,
};

// Snapshot-tests variable-like declarations parsed as `Decl`: `let`, `var`
// (with/without a type annotation, with/without an initializer, and with a
// `<uniform>` qualifier on a plain or module-qualified type), and `const`.
#[snapshot_test]
fn var_decl() {
	snapshot!(Decl(WithTokens) "let foo = 1.0;");
	snapshot!(Decl(WithTokens) "var foo: f32 = 1.0;");
	snapshot!(Decl(WithTokens) "var foo: f32;");
	snapshot!(Decl(WithTokens) "var<uniform> uniforms: Uniforms;");
	// Module-qualified type name (`common::Uniforms`).
	snapshot!(Decl(WithTokens) "var<uniform> uniforms: common::Uniforms;");
	snapshot!(Decl(WithTokens) "const PI: f32 = 3.14159;");
}

// Snapshot-tests `override` declarations: with and without an `@id(...)`
// attribute, with and without an explicit type, and with and without an
// initializer.
#[snapshot_test]
fn override_decl() {
	snapshot!(Decl(WithTokens) "@id(0) override has_point_light: bool = true;");
	snapshot!(Decl(WithTokens) "@id(1200) override specular_param: f32 = 2.3;");
	snapshot!(Decl(WithTokens) "@id(1300) override gain: f32;");
	snapshot!(Decl(WithTokens) "override width: f32 = 0.0;");
	snapshot!(Decl(WithTokens) "override depth: f32;");
	// Untyped, with a non-literal initializer that refers to another override.
	snapshot!(Decl(WithTokens) "override height = 2 * depth;");
}

// Snapshot-tests an `alias` declaration whose target is a generic type
// (`mat4x4<f32>`).
#[snapshot_test]
fn type_alias_decl() {
	snapshot! {
		Decl(WithTokens)
		"alias ViewProjectionMatrix = mat4x4<f32>;"
	};
}

// Snapshot-tests `struct` declarations, both with and without a trailing
// semicolon, including attributed members and generic/array member types.
#[snapshot_test]
fn struct_decl() {
	// With semicolon
	snapshot!(Decl(WithTokens) r#"
		struct Data {
			a: i32,
			b: vec2<f32>,
		};
	"#);

	// Without semicolon; members carry `@builtin` / `@location` attributes
	snapshot!(Decl(WithTokens) r#"
		struct VertexOutput {
			@builtin(position) clip_position: vec4<f32>,
			@location(0) tex_coords: vec2<f32>,
			@location(1) world_normal: vec3<f32>,
			@location(2) world_position: vec3<f32>,
		}
	"#);

	// With semicolon; includes an `array<vec4<f32>, 6>` member
	snapshot!(Decl(WithTokens) r#"
		struct View {
			view_proj: mat4x4<f32>,
			inverse_view_proj: mat4x4<f32>,
			view: mat4x4<f32>,
			inverse_view: mat4x4<f32>,
			projection: mat4x4<f32>,
			inverse_projection: mat4x4<f32>,
			world_position: vec3<f32>,
			viewport: vec4<f32>,
			frustum: array<vec4<f32>, 6>,
		};
	"#);
}

// Snapshot-tests `fn` declarations with and without parameters, with attributed
// parameters and return types, and with a plain return type.
#[snapshot_test]
fn function_decl() {
	// No parameters; attributed return type
	snapshot!(Decl(WithTokens) r#"
		fn main() -> @location(0) vec4<f32> {
			return vec4<f32>(0.4, 0.4, 0.8, 1.0);
		}
	"#);

	// Attributed parameter followed by a trailing comma
	snapshot!(Decl(WithTokens) r#"
		fn main(
			@builtin(position) coord_in: vec4<f32>,
		) -> @location(0) vec4<f32> {
			return vec4<f32>(coord_in.x, coord_in.y, 0.0, 1.0);
		}
	"#);

	// Plain parameters and return type
	snapshot!(Decl(WithTokens) r#"
		fn mul(a: f32, b: f32) -> f32 {
			return a * b;
		}
	"#);
}

// Snapshot-tests declarations that carry leading attributes (`@vertex`,
// `@group(0) @binding(0)`) and a `struct<export>` declaration.
#[snapshot_test]
fn attribute_insertion() {
	// Also passes `WithComments`, presumably so the `// ...` line inside the
	// WGSL function body shows up in the snapshot — confirm in `crate::utils`.
	snapshot!(Decl(WithComments, WithTokens) r#"
		@vertex
		fn vert(model: VertexInput) -> VertexOutput {
			// ...
		}
	"#);

	// Multiple attributes before a module-qualified `var<uniform>`
	snapshot!(Decl(WithTokens) r#"
		@group(0) @binding(0)
		var<uniform> uniforms: common::Uniforms;
	"#);

	snapshot!(Decl(WithTokens) r#"
		struct<export> Uniforms {
			view_pos: vec4<f32>,
			view_proj: mat4x4<f32>,
		};
	"#);
}

// Snapshot-tests `#define_import_path`: one well-formed path, and two inputs
// marked as errors (an `as` binding and a `{...}` block), which this directive
// does not accept.
#[snapshot_test]
fn import_path_decl() {
	// Good
	snapshot!(Decl(WithTokens) "#define_import_path foo::bar::baz");

	// Error (`as` binding)
	snapshot!(Decl(WithTokens) "#define_import_path foo::bar as baz");

	// Error (block)
	snapshot!(Decl(WithTokens) r"#define_import_path foo::bar::{baz}");
}

// Checks the name accessors of a parsed `#define_import_path` declaration:
// - `name()` is the final path segment (`baz`);
// - `qualified_name()` is the whole path as a single `Path` token;
// - `qualified_path()` is one `Module` token per segment.
//
// Expected spans are written `line:col..line:col` with 1-based lines and
// columns and an exclusive end, e.g. `foo` follows the 20-character prefix
// `#define_import_path ` and therefore spans 1:21..1:24.
#[test]
fn import_path_decl_names() {
	let decl = match ParseStream::from("#define_import_path foo::bar::baz").parse::<Decl>() {
		Ok(Decl::ImportPath(decl)) => decl,
		Ok(other) => panic!("Expected Decl::ImportPath(...), received: {other:#?}"),
		Err(other) => panic!("{other}"),
	};

	let name = decl.name();
	let expected_name = token!(Module "baz" (1:31..1:34));
	assert_eq!(
		name,
		&expected_name,
		"\n{}",
		super::diff(&format!("{name:#?}"), &format!("{expected_name:#?}"))
	);

	let qualified_name = decl.qualified_name();
	let expected_qualified_name = token!(Path "foo::bar::baz" (1:21..1:34));
	assert_eq!(
		&qualified_name,
		&expected_qualified_name,
		"\n{}",
		super::diff(
			&format!("{qualified_name:#?}"),
			&format!("{expected_qualified_name:#?}")
		),
	);

	let qualified_path = decl.qualified_path();
	let expected_path = Arc::new([
		token!(Module "foo" (1:21..1:24)),
		token!(Module "bar" (1:26..1:29)),
		token!(Module "baz" (1:31..1:34)),
	]);

	// Show a diff on failure, like the two assertions above. Format the
	// dereferenced slices, which `assert_eq!` already requires to be `Debug`.
	assert_eq!(
		&*qualified_path,
		&*expected_path,
		"\n{}",
		super::diff(
			&format!("{:#?}", &*qualified_path),
			&format!("{:#?}", &*expected_path)
		),
	);
}

// Nested `#import` tree parsed by the `imported_symbols` test.
//
// The spans expected there depend on this literal's exact bytes. The leading
// newline puts `#import` on line 2, so `lorem` is at 2:9..2:14. Each member line
// is tab-indented, so `ipsum` is at 3:2..3:7. Do not reformat or re-indent it.
const COMPLEX_IMPORT: &str = r#"
#import lorem::{
	ipsum::{Dolor, SitAmet},
	foo::bar::baz as foo_bar_baz,
	foobar,
	foobar::bazbar,
}
"#;

// Snapshot-tests `#import` declarations: a plain path, an `as` alias, a braced
// symbol list, and a nested multi-line tree.
#[snapshot_test]
fn import_decl() {
	snapshot!(Decl(WithTokens) "#import foo::bar::baz");
	snapshot!(Decl(WithTokens) "#import foo::bar::baz as foo_bar_baz");
	snapshot!(Decl(WithTokens) r"#import lorem::ipsum::{Dolor, SitAmet}");

	// Same tree as `COMPLEX_IMPORT`, but with different surrounding indentation.
	snapshot!(Decl(WithTokens) r#"
		#import lorem::{
			ipsum::{Dolor, SitAmet},
			foo::bar::baz as foo_bar_baz,
			foobar,
			foobar::bazbar,
		}
	"#);
}

// Checks that `build_imported_symbols` flattens the nested `COMPLEX_IMPORT` tree
// into one `ImportedSymbol` per leaf, in source order. Each entry pairs the full
// path (always rooted at `lorem`) with its local binding. The binding is the
// last path segment, or the `as` alias when one is given (`foo_bar_baz`).
#[test]
fn imported_symbols() {
	let parsed = ParseStream::from(COMPLEX_IMPORT)
		.parse::<Decl>()
		.unwrap_or_else(|err| panic!("{err}"));

	let decl = match parsed {
		Decl::Import(import) => import,
		unexpected => panic!("Expected Decl::Import(...), received: {unexpected:#?}"),
	};

	let actual = decl.build_imported_symbols();

	let expected = vec![
		// lorem::ipsum::{Dolor, ..}
		ImportedSymbol {
			local_binding: token!(Module "Dolor" (3:10..3:15)),
			qualified_path: Arc::new([
				token!(Module "lorem" (2:9..2:14)),
				token!(Module "ipsum" (3:2..3:7)),
				token!(Module "Dolor" (3:10..3:15)),
			]),
		},
		// lorem::ipsum::{.., SitAmet}
		ImportedSymbol {
			local_binding: token!(Module "SitAmet" (3:17..3:24)),
			qualified_path: Arc::new([
				token!(Module "lorem" (2:9..2:14)),
				token!(Module "ipsum" (3:2..3:7)),
				token!(Module "SitAmet" (3:17..3:24)),
			]),
		},
		// lorem::foo::bar::baz, bound to its alias
		ImportedSymbol {
			local_binding: token!(Module "foo_bar_baz" (4:19..4:30)),
			qualified_path: Arc::new([
				token!(Module "lorem" (2:9..2:14)),
				token!(Module "foo" (4:2..4:5)),
				token!(Module "bar" (4:7..4:10)),
				token!(Module "baz" (4:12..4:15)),
			]),
		},
		// lorem::foobar
		ImportedSymbol {
			local_binding: token!(Module "foobar" (5:2..5:8)),
			qualified_path: Arc::new([
				token!(Module "lorem" (2:9..2:14)),
				token!(Module "foobar" (5:2..5:8)),
			]),
		},
		// lorem::foobar::bazbar
		ImportedSymbol {
			local_binding: token!(Module "bazbar" (6:10..6:16)),
			qualified_path: Arc::new([
				token!(Module "lorem" (2:9..2:14)),
				token!(Module "foobar" (6:2..6:8)),
				token!(Module "bazbar" (6:10..6:16)),
			]),
		},
	];

	assert_eq!(
		&actual,
		&expected,
		"\n{}",
		super::diff(&format!("{actual:#?}"), &format!("{expected:#?}")),
	);
}