1#[derive(Debug, Default)]
7#[allow(dead_code)]
8pub struct WGSLTypeChecker;
9impl WGSLTypeChecker {
10 #[allow(dead_code)]
12 pub fn new() -> Self {
13 WGSLTypeChecker
14 }
15 #[allow(dead_code)]
17 pub fn check_binop(
18 &self,
19 _op: &str,
20 lhs: &WGSLType,
21 rhs: &WGSLType,
22 ) -> Result<WGSLType, WGSLTypeError> {
23 if lhs == rhs {
24 Ok(lhs.clone())
25 } else {
26 Err(WGSLTypeError::TypeMismatch {
27 expected: lhs.clone(),
28 found: rhs.clone(),
29 })
30 }
31 }
32 #[allow(dead_code)]
34 pub fn check_atomic(&self, ty: &WGSLType) -> Result<(), WGSLTypeError> {
35 match ty {
36 WGSLType::I32 | WGSLType::U32 => Ok(()),
37 _ => Err(WGSLTypeError::InvalidOperandType {
38 op: "atomic".to_string(),
39 ty: ty.clone(),
40 }),
41 }
42 }
43 #[allow(dead_code)]
45 pub fn check_host_shareable(&self, ty: &WGSLType) -> Result<(), WGSLTypeError> {
46 match ty {
47 WGSLType::Bool
48 | WGSLType::Sampler
49 | WGSLType::SamplerComparison
50 | WGSLType::TextureDepth2D
51 | WGSLType::Texture2D
52 | WGSLType::Texture2DArray
53 | WGSLType::TextureCube
54 | WGSLType::Texture3D
55 | WGSLType::TextureMultisampled2D => {
56 Err(WGSLTypeError::NonShareableBinding { ty: ty.clone() })
57 }
58 _ => Ok(()),
59 }
60 }
61 #[allow(dead_code)]
63 pub fn element_type(ty: &WGSLType) -> &WGSLType {
64 match ty {
65 WGSLType::Vec2f
66 | WGSLType::Vec3f
67 | WGSLType::Vec4f
68 | WGSLType::Mat2x2f
69 | WGSLType::Mat3x3f
70 | WGSLType::Mat4x4f
71 | WGSLType::Mat2x4f
72 | WGSLType::Mat4x2f => &WGSLType::F32,
73 WGSLType::Vec2i | WGSLType::Vec3i | WGSLType::Vec4i => &WGSLType::I32,
74 WGSLType::Vec2u | WGSLType::Vec3u | WGSLType::Vec4u => &WGSLType::U32,
75 WGSLType::Vec2b => &WGSLType::Bool,
76 _ => ty,
77 }
78 }
79 #[allow(dead_code)]
81 pub fn vector_width(ty: &WGSLType) -> u32 {
82 match ty {
83 WGSLType::Vec2f | WGSLType::Vec2i | WGSLType::Vec2u | WGSLType::Vec2b => 2,
84 WGSLType::Vec3f | WGSLType::Vec3i | WGSLType::Vec3u => 3,
85 WGSLType::Vec4f | WGSLType::Vec4i | WGSLType::Vec4u => 4,
86 _ => 1,
87 }
88 }
89}
90#[derive(Debug, Clone)]
92pub struct WGSLConstant {
93 pub name: String,
95 pub ty: Option<WGSLType>,
97 pub value: String,
99}
100impl WGSLConstant {
101 pub fn typed(name: impl Into<String>, ty: WGSLType, value: impl Into<String>) -> Self {
103 WGSLConstant {
104 name: name.into(),
105 ty: Some(ty),
106 value: value.into(),
107 }
108 }
109 pub fn inferred(name: impl Into<String>, value: impl Into<String>) -> Self {
111 WGSLConstant {
112 name: name.into(),
113 ty: None,
114 value: value.into(),
115 }
116 }
117 pub fn emit(&self) -> String {
119 match &self.ty {
120 Some(ty) => format!("const {}: {} = {};", self.name, ty, self.value),
121 None => format!("const {} = {};", self.name, self.value),
122 }
123 }
124}
125#[derive(Debug, Clone)]
127pub struct WGSLStruct {
128 pub name: String,
130 pub fields: Vec<WGSLStructField>,
132}
133impl WGSLStruct {
134 pub fn new(name: impl Into<String>) -> Self {
136 WGSLStruct {
137 name: name.into(),
138 fields: Vec::new(),
139 }
140 }
141 pub fn add_field(&mut self, field: WGSLStructField) {
143 self.fields.push(field);
144 }
145 pub fn emit(&self) -> String {
147 let mut out = format!("struct {} {{\n", self.name);
148 for f in &self.fields {
149 out.push_str(&f.emit());
150 out.push('\n');
151 }
152 out.push('}');
153 out
154 }
155}
156#[derive(Debug, Clone)]
158#[allow(dead_code)]
159pub enum WGSLStatement {
160 Let {
162 name: String,
163 ty: Option<WGSLType>,
164 init: String,
165 },
166 Var {
168 name: String,
169 ty: Option<WGSLType>,
170 init: Option<String>,
171 },
172 Assign { lhs: String, rhs: String },
174 CompoundAssign {
176 lhs: String,
177 op: String,
178 rhs: String,
179 },
180 If {
182 cond: String,
183 then_stmts: Vec<WGSLStatement>,
184 else_stmts: Vec<WGSLStatement>,
185 },
186 For {
188 init: Option<Box<WGSLStatement>>,
189 cond: Option<String>,
190 update: Option<Box<WGSLStatement>>,
191 body: Vec<WGSLStatement>,
192 },
193 While {
195 cond: String,
196 body: Vec<WGSLStatement>,
197 },
198 Loop {
200 body: Vec<WGSLStatement>,
201 continuing: Vec<WGSLStatement>,
202 },
203 Switch {
205 expr: String,
206 cases: Vec<(String, Vec<WGSLStatement>)>,
207 default: Vec<WGSLStatement>,
208 },
209 Return(Option<String>),
211 Break,
213 Continue,
215 Discard,
217 Raw(String),
219 Call { func: String, args: Vec<String> },
221}
222impl WGSLStatement {
223 #[allow(dead_code)]
225 pub fn emit(&self, indent: usize) -> String {
226 let pad = " ".repeat(indent);
227 match self {
228 WGSLStatement::Let { name, ty, init } => {
229 let ty_str = ty.as_ref().map(|t| format!(": {}", t)).unwrap_or_default();
230 format!("{}let {}{} = {};", pad, name, ty_str, init)
231 }
232 WGSLStatement::Var { name, ty, init } => {
233 let ty_str = ty.as_ref().map(|t| format!(": {}", t)).unwrap_or_default();
234 let init_str = init
235 .as_ref()
236 .map(|i| format!(" = {}", i))
237 .unwrap_or_default();
238 format!("{}var {}{}{};", pad, name, ty_str, init_str)
239 }
240 WGSLStatement::Assign { lhs, rhs } => format!("{}{} = {};", pad, lhs, rhs),
241 WGSLStatement::CompoundAssign { lhs, op, rhs } => {
242 format!("{}{} {}= {};", pad, lhs, op, rhs)
243 }
244 WGSLStatement::If {
245 cond,
246 then_stmts,
247 else_stmts,
248 } => {
249 let mut out = format!("{}if ({}) {{\n", pad, cond);
250 for s in then_stmts {
251 out.push_str(&s.emit(indent + 1));
252 out.push('\n');
253 }
254 if else_stmts.is_empty() {
255 out.push_str(&format!("{}}}", pad));
256 } else {
257 out.push_str(&format!("{}}} else {{\n", pad));
258 for s in else_stmts {
259 out.push_str(&s.emit(indent + 1));
260 out.push('\n');
261 }
262 out.push_str(&format!("{}}}", pad));
263 }
264 out
265 }
266 WGSLStatement::For {
267 init,
268 cond,
269 update,
270 body,
271 } => {
272 let init_str = init
273 .as_ref()
274 .map(|s| s.emit(0).trim_end_matches(';').to_string())
275 .unwrap_or_default();
276 let cond_str = cond.as_deref().unwrap_or("");
277 let update_str = update
278 .as_ref()
279 .map(|s| s.emit(0).trim_end_matches(';').to_string())
280 .unwrap_or_default();
281 let mut out = format!(
282 "{}for ({}; {}; {}) {{\n",
283 pad, init_str, cond_str, update_str
284 );
285 for s in body {
286 out.push_str(&s.emit(indent + 1));
287 out.push('\n');
288 }
289 out.push_str(&format!("{}}}", pad));
290 out
291 }
292 WGSLStatement::While { cond, body } => {
293 let mut out = format!("{}while ({}) {{\n", pad, cond);
294 for s in body {
295 out.push_str(&s.emit(indent + 1));
296 out.push('\n');
297 }
298 out.push_str(&format!("{}}}", pad));
299 out
300 }
301 WGSLStatement::Loop { body, continuing } => {
302 let mut out = format!("{}loop {{\n", pad);
303 for s in body {
304 out.push_str(&s.emit(indent + 1));
305 out.push('\n');
306 }
307 if !continuing.is_empty() {
308 out.push_str(&format!("{} continuing {{\n", pad));
309 for s in continuing {
310 out.push_str(&s.emit(indent + 2));
311 out.push('\n');
312 }
313 out.push_str(&format!("{} }}\n", pad));
314 }
315 out.push_str(&format!("{}}}", pad));
316 out
317 }
318 WGSLStatement::Switch {
319 expr,
320 cases,
321 default,
322 } => {
323 let mut out = format!("{}switch ({}) {{\n", pad, expr);
324 for (val, stmts) in cases {
325 out.push_str(&format!("{} case {}: {{\n", pad, val));
326 for s in stmts {
327 out.push_str(&s.emit(indent + 2));
328 out.push('\n');
329 }
330 out.push_str(&format!("{} }}\n", pad));
331 }
332 out.push_str(&format!("{} default: {{\n", pad));
333 for s in default {
334 out.push_str(&s.emit(indent + 2));
335 out.push('\n');
336 }
337 out.push_str(&format!("{} }}\n", pad));
338 out.push_str(&format!("{}}}", pad));
339 out
340 }
341 WGSLStatement::Return(Some(expr)) => format!("{}return {};", pad, expr),
342 WGSLStatement::Return(None) => format!("{}return;", pad),
343 WGSLStatement::Break => format!("{}break;", pad),
344 WGSLStatement::Continue => format!("{}continue;", pad),
345 WGSLStatement::Discard => format!("{}discard;", pad),
346 WGSLStatement::Raw(s) => format!("{}{};", pad, s),
347 WGSLStatement::Call { func, args } => {
348 format!("{}{}({});", pad, func, args.join(", "))
349 }
350 }
351 }
352}
/// Entry-point configuration for a generated compute kernel.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct WGSLComputeKernelParams {
    /// Name of the generated `@compute` function.
    pub name: String,
    /// Workgroup size along X (first argument of `@workgroup_size`).
    pub wg_x: u32,
    /// Workgroup size along Y (second argument of `@workgroup_size`).
    pub wg_y: u32,
    /// Workgroup size along Z (third argument of `@workgroup_size`).
    pub wg_z: u32,
    /// Adds a `local_id: vec3<u32>` parameter bound to `@builtin(local_invocation_id)`.
    pub use_local_id: bool,
    /// Adds a `wg_id: vec3<u32>` parameter bound to `@builtin(workgroup_id)`.
    pub use_workgroup_id: bool,
    /// Adds a `num_wgs: vec3<u32>` parameter bound to `@builtin(num_workgroups)`.
    pub use_num_workgroups: bool,
}
372#[derive(Debug, Clone)]
374pub struct WGSLParam {
375 pub name: String,
377 pub ty: WGSLType,
379 pub builtin: Option<String>,
381 pub location: Option<u32>,
383}
384impl WGSLParam {
385 pub fn new(name: impl Into<String>, ty: WGSLType) -> Self {
387 WGSLParam {
388 name: name.into(),
389 ty,
390 builtin: None,
391 location: None,
392 }
393 }
394 pub fn with_builtin(name: impl Into<String>, ty: WGSLType, builtin: impl Into<String>) -> Self {
396 WGSLParam {
397 name: name.into(),
398 ty,
399 builtin: Some(builtin.into()),
400 location: None,
401 }
402 }
403 pub fn with_location(name: impl Into<String>, ty: WGSLType, loc: u32) -> Self {
405 WGSLParam {
406 name: name.into(),
407 ty,
408 builtin: None,
409 location: Some(loc),
410 }
411 }
412 pub fn emit(&self) -> String {
414 let mut attrs = String::new();
415 if let Some(b) = &self.builtin {
416 attrs.push_str(&format!("@builtin({}) ", b));
417 }
418 if let Some(loc) = self.location {
419 attrs.push_str(&format!("@location({}) ", loc));
420 }
421 format!("{}{}: {}", attrs, self.name, self.ty)
422 }
423}
/// Shader stage of a function, which determines its leading stage attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WGSLEntryPoint {
    /// Ordinary helper function (no stage attribute).
    None,
    /// `@vertex` entry point.
    Vertex,
    /// `@fragment` entry point.
    Fragment,
    /// `@compute` entry point with the given `@workgroup_size(x, y, z)`.
    Compute { x: u32, y: u32, z: u32 },
}
impl WGSLEntryPoint {
    /// Returns the attribute line (newline-terminated) to place before `fn`,
    /// or an empty string for helper functions.
    pub fn attribute(&self) -> String {
        match *self {
            Self::None => String::new(),
            Self::Vertex => String::from("@vertex\n"),
            Self::Fragment => String::from("@fragment\n"),
            Self::Compute { x, y, z } => {
                format!("@compute @workgroup_size({}, {}, {})\n", x, y, z)
            }
        }
    }
}
449#[derive(Debug, Clone, PartialEq, Eq, Hash)]
451pub enum WGSLType {
452 Bool,
454 I32,
456 U32,
458 F32,
460 F16,
462 Vec2f,
464 Vec3f,
466 Vec4f,
468 Vec2i,
470 Vec3i,
472 Vec4i,
474 Vec2u,
476 Vec3u,
478 Vec4u,
480 Vec2b,
482 Mat2x2f,
484 Mat3x3f,
486 Mat4x4f,
488 Mat2x4f,
490 Mat4x2f,
492 Texture2D,
494 Texture2DArray,
496 TextureCube,
498 Texture3D,
500 TextureDepth2D,
502 TextureStorage2D { format: String, access: String },
504 TextureMultisampled2D,
506 Sampler,
508 SamplerComparison,
510 Struct(String),
512 Array(Box<WGSLType>, u32),
514 RuntimeArray(Box<WGSLType>),
516 Ptr {
518 address_space: WGSLAddressSpace,
519 inner: Box<WGSLType>,
520 },
521 AtomicU32,
523 AtomicI32,
525}
526impl WGSLType {
527 pub fn keyword(&self) -> String {
529 match self {
530 WGSLType::Bool => "bool".into(),
531 WGSLType::I32 => "i32".into(),
532 WGSLType::U32 => "u32".into(),
533 WGSLType::F32 => "f32".into(),
534 WGSLType::F16 => "f16".into(),
535 WGSLType::Vec2f => "vec2<f32>".into(),
536 WGSLType::Vec3f => "vec3<f32>".into(),
537 WGSLType::Vec4f => "vec4<f32>".into(),
538 WGSLType::Vec2i => "vec2<i32>".into(),
539 WGSLType::Vec3i => "vec3<i32>".into(),
540 WGSLType::Vec4i => "vec4<i32>".into(),
541 WGSLType::Vec2u => "vec2<u32>".into(),
542 WGSLType::Vec3u => "vec3<u32>".into(),
543 WGSLType::Vec4u => "vec4<u32>".into(),
544 WGSLType::Vec2b => "vec2<bool>".into(),
545 WGSLType::Mat2x2f => "mat2x2<f32>".into(),
546 WGSLType::Mat3x3f => "mat3x3<f32>".into(),
547 WGSLType::Mat4x4f => "mat4x4<f32>".into(),
548 WGSLType::Mat2x4f => "mat2x4<f32>".into(),
549 WGSLType::Mat4x2f => "mat4x2<f32>".into(),
550 WGSLType::Texture2D => "texture_2d<f32>".into(),
551 WGSLType::Texture2DArray => "texture_2d_array<f32>".into(),
552 WGSLType::TextureCube => "texture_cube<f32>".into(),
553 WGSLType::Texture3D => "texture_3d<f32>".into(),
554 WGSLType::TextureDepth2D => "texture_depth_2d".into(),
555 WGSLType::TextureStorage2D { format, access } => {
556 format!("texture_storage_2d<{}, {}>", format, access)
557 }
558 WGSLType::TextureMultisampled2D => "texture_multisampled_2d<f32>".into(),
559 WGSLType::Sampler => "sampler".into(),
560 WGSLType::SamplerComparison => "sampler_comparison".into(),
561 WGSLType::Struct(name) => name.clone(),
562 WGSLType::Array(elem, n) => format!("array<{}, {}>", elem.keyword(), n),
563 WGSLType::RuntimeArray(elem) => format!("array<{}>", elem.keyword()),
564 WGSLType::Ptr {
565 address_space,
566 inner,
567 } => {
568 format!("ptr<{}, {}>", address_space, inner.keyword())
569 }
570 WGSLType::AtomicU32 => "atomic<u32>".into(),
571 WGSLType::AtomicI32 => "atomic<i32>".into(),
572 }
573 }
574 pub fn is_opaque(&self) -> bool {
576 matches!(
577 self,
578 WGSLType::Texture2D
579 | WGSLType::Texture2DArray
580 | WGSLType::TextureCube
581 | WGSLType::Texture3D
582 | WGSLType::TextureDepth2D
583 | WGSLType::TextureStorage2D { .. }
584 | WGSLType::TextureMultisampled2D
585 | WGSLType::Sampler
586 | WGSLType::SamplerComparison
587 )
588 }
589 pub fn is_float_like(&self) -> bool {
591 matches!(
592 self,
593 WGSLType::F32 | WGSLType::F16 | WGSLType::Vec2f | WGSLType::Vec3f | WGSLType::Vec4f
594 )
595 }
596}
597#[derive(Debug, Clone)]
599pub struct WGSLBinding {
600 pub group: u32,
602 pub binding: u32,
604 pub name: String,
606 pub ty: WGSLType,
608 pub access: Option<WGSLAccess>,
610}
611impl WGSLBinding {
612 pub fn new(group: u32, binding: u32, name: impl Into<String>, ty: WGSLType) -> Self {
614 WGSLBinding {
615 group,
616 binding,
617 name: name.into(),
618 ty,
619 access: None,
620 }
621 }
622 pub fn storage(
624 group: u32,
625 binding: u32,
626 name: impl Into<String>,
627 ty: WGSLType,
628 access: WGSLAccess,
629 ) -> Self {
630 WGSLBinding {
631 group,
632 binding,
633 name: name.into(),
634 ty,
635 access: Some(access),
636 }
637 }
638 pub fn emit(&self) -> String {
640 let access_str = match &self.access {
641 Some(a) => format!("<{}>", a),
642 None => String::new(),
643 };
644 format!(
645 "@group({}) @binding({}) var{} {}: {};",
646 self.group, self.binding, access_str, self.name, self.ty
647 )
648 }
649}
650#[derive(Debug, Default)]
652#[allow(dead_code)]
653pub struct WGSLComputeKernel {
654 pub params: WGSLComputeKernelParams,
656 pub bindings: Vec<WGSLBinding>,
658 pub shared_vars: Vec<WGSLGlobal>,
660 pub body: Vec<WGSLStatement>,
662}
663impl WGSLComputeKernel {
664 #[allow(dead_code)]
666 pub fn new(name: impl Into<String>, wg_x: u32, wg_y: u32, wg_z: u32) -> Self {
667 WGSLComputeKernel {
668 params: WGSLComputeKernelParams {
669 name: name.into(),
670 wg_x,
671 wg_y,
672 wg_z,
673 ..Default::default()
674 },
675 ..Default::default()
676 }
677 }
678 #[allow(dead_code)]
680 pub fn push(&mut self, stmt: WGSLStatement) {
681 self.body.push(stmt);
682 }
683 #[allow(dead_code)]
685 pub fn add_shared_array(&mut self, name: impl Into<String>, elem_ty: WGSLType, size: u32) {
686 self.shared_vars.push(WGSLGlobal::workgroup(
687 name,
688 WGSLType::Array(Box::new(elem_ty), size),
689 ));
690 }
691 #[allow(dead_code)]
693 pub fn emit_function(&self) -> WGSLFunction {
694 let p = &self.params;
695 let mut func = WGSLFunction::compute(&p.name, p.wg_x, p.wg_y, p.wg_z);
696 func.add_param(WGSLParam {
697 name: "global_id".to_string(),
698 ty: WGSLType::Vec3u,
699 builtin: Some("global_invocation_id".to_string()),
700 location: None,
701 });
702 if p.use_local_id {
703 func.add_param(WGSLParam {
704 name: "local_id".to_string(),
705 ty: WGSLType::Vec3u,
706 builtin: Some("local_invocation_id".to_string()),
707 location: None,
708 });
709 }
710 if p.use_workgroup_id {
711 func.add_param(WGSLParam {
712 name: "wg_id".to_string(),
713 ty: WGSLType::Vec3u,
714 builtin: Some("workgroup_id".to_string()),
715 location: None,
716 });
717 }
718 if p.use_num_workgroups {
719 func.add_param(WGSLParam {
720 name: "num_wgs".to_string(),
721 ty: WGSLType::Vec3u,
722 builtin: Some("num_workgroups".to_string()),
723 location: None,
724 });
725 }
726 for stmt in &self.body {
727 func.add_statement(stmt.emit(0));
728 }
729 func
730 }
731 #[allow(dead_code)]
733 pub fn emit_shader(&self) -> String {
734 let mut shader = WGSLShader::new();
735 for b in &self.bindings {
736 shader.add_binding(b.clone());
737 }
738 for g in &self.shared_vars {
739 shader.add_global(g.clone());
740 }
741 shader.add_function(self.emit_function());
742 WGSLBackend::new().emit_shader(&shader)
743 }
744}
/// Reusable WGSL source snippets: expressions and helper functions rendered
/// as strings.
///
/// Expression snippets splice their string arguments verbatim into the
/// generated WGSL, so callers should pass identifiers or parenthesised
/// expressions.
#[allow(dead_code)]
pub struct WGSLSnippets;
impl WGSLSnippets {
    /// Expression remapping `val` from `[in_lo, in_hi]` to `[out_lo, out_hi]`
    /// via `mix`.
    ///
    /// The `f32` bounds are rendered with Rust's `Display`, so whole numbers
    /// appear without a decimal point (e.g. `1`).
    #[allow(dead_code)]
    pub fn linear_map(val: &str, in_lo: f32, in_hi: f32, out_lo: f32, out_hi: f32) -> String {
        format!(
            "mix({out_lo}, {out_hi}, ({val} - {in_lo}) / ({in_hi} - {in_lo}))",
            val = val,
            in_lo = in_lo,
            in_hi = in_hi,
            out_lo = out_lo,
            out_hi = out_hi,
        )
    }
    /// `vec2f` expression rotating the 2-D vector `v` by `angle` radians
    /// (counter-clockwise).
    #[allow(dead_code)]
    pub fn rotate2d(v: &str, angle: &str) -> String {
        format!(
            "vec2f(cos({a}) * {v}.x - sin({a}) * {v}.y, sin({a}) * {v}.x + cos({a}) * {v}.y)",
            v = v,
            a = angle
        )
    }
    /// `vec4f` expression encoding linear colour `c` with the sRGB transfer function.
    ///
    /// WGSL `select(f, t, cond)` returns `t` where `cond` holds. Components at
    /// or below 0.0031308 therefore take the linear segment (`c * 12.92`), and
    /// the rest take the power curve. All four components, alpha included, are
    /// converted.
    #[allow(dead_code)]
    pub fn linear_to_srgb(c: &str) -> String {
        format!(
            "select(pow({c}, vec4f(1.0 / 2.4)) * 1.055 - vec4f(0.055), {c} * 12.92, {c} <= vec4f(0.0031308))",
            c = c
        )
    }
    /// `vec4f` expression decoding sRGB-encoded colour `c` to linear.
    ///
    /// Components at or below 0.04045 take the linear segment (`c / 12.92`),
    /// and the rest take the power curve. All four components are converted.
    #[allow(dead_code)]
    pub fn srgb_to_linear(c: &str) -> String {
        format!(
            "select(pow(({c} + vec4f(0.055)) / vec4f(1.055), vec4f(2.4)), {c} / 12.92, {c} <= vec4f(0.04045))",
            c = c
        )
    }
    /// Blinn-Phong specular term `pow(max(dot(n, h), 0), shininess)`.
    #[allow(dead_code)]
    pub fn blinn_phong(normal: &str, halfway: &str, shininess: &str) -> String {
        format!(
            "pow(max(dot({n}, {h}), 0.0), {s})",
            n = normal,
            h = halfway,
            s = shininess
        )
    }
    /// `u32` expression computing Thomas Wang's integer hash of `seed`.
    ///
    /// The algorithm is sequential:
    /// - `s = (s ^ 61) ^ (s >> 16)`
    /// - `s *= 9`
    /// - `s ^= s >> 4`
    /// - `s *= 0x27d4eb2d`
    /// - `s ^= s >> 15`
    ///
    /// It is inlined as one expression by substituting each step into the
    /// next. Every operand is fully parenthesised, because WGSL rejects
    /// mixing `^` with other binary operators without parentheses. The seed
    /// expression is repeated 8 times in the output.
    #[allow(dead_code)]
    pub fn wang_hash(seed: &str) -> String {
        let step1 = format!("((({s}) ^ 61u) ^ (({s}) >> 16u))", s = seed);
        let step2 = format!("({x} * 9u)", x = step1);
        let step3 = format!("({x} ^ ({x} >> 4u))", x = step2);
        let step4 = format!("({x} * 0x27d4eb2du)", x = step3);
        format!("({x} ^ ({x} >> 15u))", x = step4)
    }
    /// PCG-style hash of `state`, emitted as two `let` statements followed by
    /// a trailing `u32` expression.
    ///
    /// NOTE(review): the output is not a single expression. Because it declares
    /// `_pcg_tmp` and `_pcg_word`, it can be spliced at most once per scope.
    #[allow(dead_code)]
    pub fn pcg_next(state: &str) -> String {
        format!(
            "let _pcg_tmp = {s} * 747796405u + 2891336453u; let _pcg_word = ((_pcg_tmp >> ((_pcg_tmp >> 28u) + 4u)) ^ _pcg_tmp) * 277803737u; (_pcg_word >> 22u) ^ _pcg_word",
            s = state
        )
    }
    /// WGSL source of a helper `fn rgb_to_hsv(c: vec3f) -> vec3f`.
    #[allow(dead_code)]
    pub fn rgb_to_hsv_fn() -> String {
        r"fn rgb_to_hsv(c: vec3f) -> vec3f {
    let k = vec4f(0.0, -1.0 / 3.0, 2.0 / 3.0, -1.0);
    let p = mix(vec4f(c.bg, k.wz), vec4f(c.gb, k.xy), step(c.b, c.g));
    let q = mix(vec4f(p.xyw, c.r), vec4f(c.r, p.yzx), step(p.x, c.r));
    let d = q.x - min(q.w, q.y);
    let e = 1.0e-10;
    return vec3f(abs(q.z + (q.w - q.y) / (6.0 * d + e)), d / (q.x + e), q.x);
}"
        .to_string()
    }
    /// WGSL source of a helper `fn hsv_to_rgb(c: vec3f) -> vec3f`.
    #[allow(dead_code)]
    pub fn hsv_to_rgb_fn() -> String {
        r"fn hsv_to_rgb(c: vec3f) -> vec3f {
    let k = vec4f(1.0, 2.0 / 3.0, 1.0 / 3.0, 3.0);
    let p = abs(fract(c.xxx + k.xyz) * 6.0 - k.www);
    return c.z * mix(k.xxx, clamp(p - k.xxx, vec3f(0.0), vec3f(1.0)), c.y);
}"
        .to_string()
    }
    /// Normal-distribution density at integer offset `i` with standard deviation `sigma`.
    ///
    /// `sigma == 0` yields NaN or infinity.
    #[allow(dead_code)]
    pub fn gaussian_weight(i: i32, sigma: f32) -> f32 {
        let x = i as f32;
        let denom = (2.0 * std::f32::consts::PI * sigma * sigma).sqrt();
        (-(x * x) / (2.0 * sigma * sigma)).exp() / denom
    }
    /// WGSL source of a separable 1-D Gaussian blur helper.
    ///
    /// The helper is named `gaussian_blur_h` when `horizontal` is true and
    /// `gaussian_blur_v` otherwise. It takes `2 * radius + 1` taps, and the
    /// weights are normalised to sum to 1 before being baked in as `f32`
    /// literals. The body uses `textureSample`, so it can only be called from
    /// fragment shaders.
    #[allow(dead_code)]
    pub fn gaussian_blur_fn(radius: i32, sigma: f32, horizontal: bool) -> String {
        let dir = if horizontal {
            "vec2f(1.0, 0.0)"
        } else {
            "vec2f(0.0, 1.0)"
        };
        let weights: Vec<f32> = (-radius..=radius)
            .map(|i| Self::gaussian_weight(i, sigma))
            .collect();
        let total: f32 = weights.iter().sum();
        let norm_weights: Vec<f32> = weights.iter().map(|w| w / total).collect();
        let mut body = format!(
            "fn gaussian_blur_{}(tex: texture_2d<f32>, samp: sampler, uv: vec2f, texel_size: vec2f) -> vec4f {{\n",
            if horizontal { "h" } else { "v" }
        );
        body.push_str(" var result = vec4f(0.0);\n");
        for (idx, i) in (-radius..=radius).enumerate() {
            body.push_str(&format!(
                " result += textureSample(tex, samp, uv + {} * {} * texel_size) * {}f;\n",
                dir, i, norm_weights[idx]
            ));
        }
        body.push_str(" return result;\n}");
        body
    }
}
/// Shader stages from which a binding is visible (used by `WGSLBindingEntry`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub enum WGSLStageVisibility {
    /// Vertex stage only.
    Vertex,
    /// Fragment stage only.
    Fragment,
    /// Compute stage only.
    Compute,
    /// Presumably both vertex and fragment stages — confirm against consumers.
    VertexFragment,
    /// Presumably every stage — confirm against consumers.
    All,
}
/// Describes one binding slot: its index, resource kind and stage visibility.
///
/// NOTE(review): looks like a bind-group-layout entry on the host side; it is
/// not consumed by any emitter in this section — confirm its intended use.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct WGSLBindingEntry {
    /// Binding slot index within its group.
    pub binding: u32,
    /// Kind of resource bound at this slot.
    pub resource_type: WGSLResourceType,
    /// Stages from which the binding is visible.
    pub visibility: WGSLStageVisibility,
}
891#[derive(Debug, Clone)]
893pub struct WGSLStructField {
894 pub name: String,
896 pub ty: WGSLType,
898 pub builtin: Option<String>,
900 pub location: Option<u32>,
902 pub interpolate: Option<String>,
904}
905impl WGSLStructField {
906 pub fn new(name: impl Into<String>, ty: WGSLType) -> Self {
908 WGSLStructField {
909 name: name.into(),
910 ty,
911 builtin: None,
912 location: None,
913 interpolate: None,
914 }
915 }
916 pub fn builtin(name: impl Into<String>, ty: WGSLType, builtin: impl Into<String>) -> Self {
918 WGSLStructField {
919 name: name.into(),
920 ty,
921 builtin: Some(builtin.into()),
922 location: None,
923 interpolate: None,
924 }
925 }
926 pub fn location(name: impl Into<String>, ty: WGSLType, loc: u32) -> Self {
928 WGSLStructField {
929 name: name.into(),
930 ty,
931 builtin: None,
932 location: Some(loc),
933 interpolate: None,
934 }
935 }
936 pub fn emit(&self) -> String {
938 let mut attrs = String::new();
939 if let Some(b) = &self.builtin {
940 attrs.push_str(&format!("@builtin({}) ", b));
941 }
942 if let Some(loc) = self.location {
943 attrs.push_str(&format!("@location({}) ", loc));
944 }
945 if let Some(interp) = &self.interpolate {
946 attrs.push_str(&format!("@interpolate({}) ", interp));
947 }
948 format!(" {}{}: {},", attrs, self.name, self.ty)
949 }
950}
951#[derive(Debug, Clone, Default)]
953#[allow(dead_code)]
954pub struct WGSLCodeMetrics {
955 pub num_functions: usize,
957 pub num_entry_points: usize,
959 pub num_structs: usize,
961 pub num_bindings: usize,
963 pub num_globals: usize,
965 pub num_constants: usize,
967 pub num_overrides: usize,
969 pub total_statements: usize,
971 pub num_enables: usize,
973}
974impl WGSLCodeMetrics {
975 #[allow(dead_code)]
977 pub fn compute(shader: &WGSLShader) -> Self {
978 let num_entry_points = shader
979 .functions
980 .iter()
981 .filter(|f| !matches!(f.entry_point, WGSLEntryPoint::None))
982 .count();
983 let total_statements = shader.functions.iter().map(|f| f.body.len()).sum();
984 WGSLCodeMetrics {
985 num_functions: shader.functions.len(),
986 num_entry_points,
987 num_structs: shader.structs.len(),
988 num_bindings: shader.bindings.len(),
989 num_globals: shader.globals.len(),
990 num_constants: shader.constants.len(),
991 num_overrides: shader.overrides.len(),
992 total_statements,
993 num_enables: shader.enables.len(),
994 }
995 }
996 #[allow(dead_code)]
998 pub fn summary(&self) -> String {
999 format!(
1000 "functions={} entry_points={} structs={} bindings={} globals={} constants={} overrides={} statements={} enables={}",
1001 self.num_functions, self.num_entry_points, self.num_structs, self
1002 .num_bindings, self.num_globals, self.num_constants, self.num_overrides, self
1003 .total_statements, self.num_enables,
1004 )
1005 }
1006}
/// Access mode requested for a storage binding (see `WGSLBinding::storage`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WGSLAccess {
    /// Read-only access.
    Read,
    /// Write-only access.
    /// NOTE(review): WGSL storage buffers only accept `read`/`read_write`;
    /// write-only is valid for storage textures.
    Write,
    /// Read and write access.
    ReadWrite,
}
1014#[derive(Debug, Clone)]
1016pub struct WGSLFunction {
1017 pub name: String,
1019 pub entry_point: WGSLEntryPoint,
1021 pub params: Vec<WGSLParam>,
1023 pub return_type: Option<WGSLType>,
1025 pub return_attrib: WGSLReturnAttrib,
1027 pub body: Vec<String>,
1029}
1030impl WGSLFunction {
1031 pub fn helper(name: impl Into<String>) -> Self {
1033 WGSLFunction {
1034 name: name.into(),
1035 entry_point: WGSLEntryPoint::None,
1036 params: Vec::new(),
1037 return_type: None,
1038 return_attrib: WGSLReturnAttrib::None,
1039 body: Vec::new(),
1040 }
1041 }
1042 pub fn vertex(name: impl Into<String>) -> Self {
1044 WGSLFunction {
1045 name: name.into(),
1046 entry_point: WGSLEntryPoint::Vertex,
1047 params: Vec::new(),
1048 return_type: None,
1049 return_attrib: WGSLReturnAttrib::None,
1050 body: Vec::new(),
1051 }
1052 }
1053 pub fn fragment(name: impl Into<String>) -> Self {
1055 WGSLFunction {
1056 name: name.into(),
1057 entry_point: WGSLEntryPoint::Fragment,
1058 params: Vec::new(),
1059 return_type: None,
1060 return_attrib: WGSLReturnAttrib::None,
1061 body: Vec::new(),
1062 }
1063 }
1064 pub fn compute(name: impl Into<String>, x: u32, y: u32, z: u32) -> Self {
1066 WGSLFunction {
1067 name: name.into(),
1068 entry_point: WGSLEntryPoint::Compute { x, y, z },
1069 params: Vec::new(),
1070 return_type: None,
1071 return_attrib: WGSLReturnAttrib::None,
1072 body: Vec::new(),
1073 }
1074 }
1075 pub fn add_param(&mut self, param: WGSLParam) {
1077 self.params.push(param);
1078 }
1079 pub fn set_return_type(&mut self, ty: WGSLType) {
1081 self.return_type = Some(ty);
1082 }
1083 pub fn set_return_type_with_attrib(&mut self, ty: WGSLType, attrib: WGSLReturnAttrib) {
1085 self.return_type = Some(ty);
1086 self.return_attrib = attrib;
1087 }
1088 pub fn add_statement(&mut self, stmt: impl Into<String>) {
1090 self.body.push(stmt.into());
1091 }
1092 pub fn emit(&self) -> String {
1094 let mut out = self.entry_point.attribute();
1095 let params: Vec<String> = self.params.iter().map(|p| p.emit()).collect();
1096 let ret = match &self.return_type {
1097 Some(ty) => format!(" -> {}{}", self.return_attrib.prefix(), ty),
1098 None => String::new(),
1099 };
1100 out.push_str(&format!(
1101 "fn {}({}){} {{\n",
1102 self.name,
1103 params.join(", "),
1104 ret
1105 ));
1106 for stmt in &self.body {
1107 out.push_str(&format!(" {};\n", stmt));
1108 }
1109 out.push('}');
1110 out
1111 }
1112}
1113#[derive(Debug, Clone)]
1115#[allow(dead_code)]
1116pub struct WGSLRenderPipeline {
1117 pub name: String,
1119 pub vertex_input: String,
1121 pub varying: String,
1123 pub vs_body: Vec<String>,
1125 pub fs_body: Vec<String>,
1127 pub bindings: Vec<WGSLBinding>,
1129 pub structs: Vec<WGSLStruct>,
1131}
1132impl WGSLRenderPipeline {
1133 #[allow(dead_code)]
1135 pub fn new(name: impl Into<String>) -> Self {
1136 let name_str = name.into();
1137 WGSLRenderPipeline {
1138 vertex_input: format!("{}Input", name_str),
1139 varying: format!("{}Varying", name_str),
1140 name: name_str,
1141 vs_body: Vec::new(),
1142 fs_body: Vec::new(),
1143 bindings: Vec::new(),
1144 structs: Vec::new(),
1145 }
1146 }
1147 #[allow(dead_code)]
1149 pub fn emit(&self) -> String {
1150 let mut shader = WGSLShader::new();
1151 for s in &self.structs {
1152 shader.add_struct(s.clone());
1153 }
1154 for b in &self.bindings {
1155 shader.add_binding(b.clone());
1156 }
1157 let vs_name = format!("{}_vs", self.name);
1158 let mut vs = WGSLFunction::vertex(&vs_name);
1159 vs.add_param(WGSLParam::new(
1160 "input",
1161 WGSLType::Struct(self.vertex_input.clone()),
1162 ));
1163 vs.set_return_type(WGSLType::Struct(self.varying.clone()));
1164 for stmt in &self.vs_body {
1165 vs.add_statement(stmt.clone());
1166 }
1167 shader.add_function(vs);
1168 let fs_name = format!("{}_fs", self.name);
1169 let mut fs = WGSLFunction::fragment(&fs_name);
1170 fs.add_param(WGSLParam::new(
1171 "varying",
1172 WGSLType::Struct(self.varying.clone()),
1173 ));
1174 fs.set_return_type_with_attrib(WGSLType::Vec4f, WGSLReturnAttrib::Location(0));
1175 for stmt in &self.fs_body {
1176 fs.add_statement(stmt.clone());
1177 }
1178 shader.add_function(fs);
1179 WGSLBackend::new().emit_shader(&shader)
1180 }
1181}
/// Kind of resource bound at a binding slot (used by `WGSLBindingEntry`).
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub enum WGSLResourceType {
    /// Uniform buffer.
    UniformBuffer,
    /// Storage buffer with read-only access.
    StorageBufferReadOnly,
    /// Storage buffer with read-write access.
    StorageBufferReadWrite,
    /// Sampled texture.
    SampledTexture,
    /// Storage texture; `format` is presumably a texel format name such as
    /// `rgba8unorm` — confirm against consumers.
    StorageTexture { format: String },
    /// Filtering/non-filtering sampler.
    Sampler,
    /// Comparison sampler (`sampler_comparison`).
    ComparisonSampler,
}
/// WGSL address space, used as the first parameter of pointer types
/// (`ptr<address_space, T>`, rendered through its `Display` impl).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WGSLAddressSpace {
    /// `function`: function-local variables.
    Function,
    /// `private`: per-invocation module-scope variables.
    Private,
    /// `workgroup`: variables shared within a workgroup.
    Workgroup,
    /// `uniform`: uniform buffers.
    Uniform,
    /// `storage`: storage buffers.
    Storage,
    /// `handle`: textures and samplers.
    Handle,
}
/// Errors reported by `WGSLTypeChecker`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[allow(dead_code)]
pub enum WGSLTypeError {
    /// Operand types differ (from `check_binop`; `expected` is the left operand type).
    TypeMismatch { expected: WGSLType, found: WGSLType },
    /// Type not allowed for operation `op` (e.g. `"atomic"` from `check_atomic`).
    InvalidOperandType { op: String, ty: WGSLType },
    /// Swizzle component not available on `ty`.
    /// NOTE(review): not produced by any checker method in this section.
    SwizzleOutOfRange { component: char, ty: WGSLType },
    /// Type cannot back a uniform/storage binding (from `check_host_shareable`).
    NonShareableBinding { ty: WGSLType },
}
/// Stateless WGSL emitter: renders a `WGSLShader` to source text and provides
/// ready-made shader templates.
pub struct WGSLBackend;
1232impl WGSLBackend {
1233 pub fn new() -> Self {
1235 WGSLBackend
1236 }
1237 pub fn emit_shader(&self, shader: &WGSLShader) -> String {
1239 let mut out = String::new();
1240 for e in &shader.enables {
1241 out.push_str(&format!("enable {};\n", e));
1242 }
1243 if !shader.enables.is_empty() {
1244 out.push('\n');
1245 }
1246 for c in &shader.constants {
1247 out.push_str(&c.emit());
1248 out.push('\n');
1249 }
1250 if !shader.constants.is_empty() {
1251 out.push('\n');
1252 }
1253 for o in &shader.overrides {
1254 out.push_str(&o.emit());
1255 out.push('\n');
1256 }
1257 if !shader.overrides.is_empty() {
1258 out.push('\n');
1259 }
1260 for s in &shader.structs {
1261 out.push_str(&s.emit());
1262 out.push_str("\n\n");
1263 }
1264 for b in &shader.bindings {
1265 out.push_str(&b.emit());
1266 out.push('\n');
1267 }
1268 if !shader.bindings.is_empty() {
1269 out.push('\n');
1270 }
1271 for g in &shader.globals {
1272 out.push_str(&g.emit());
1273 out.push('\n');
1274 }
1275 if !shader.globals.is_empty() {
1276 out.push('\n');
1277 }
1278 for f in &shader.functions {
1279 out.push_str(&f.emit());
1280 out.push_str("\n\n");
1281 }
1282 out
1283 }
1284 pub fn triangle_shader_template(&self) -> String {
1286 let mut shader = WGSLShader::new();
1287 let mut vo = WGSLStruct::new("VertexOutput");
1288 vo.add_field(WGSLStructField::builtin(
1289 "position",
1290 WGSLType::Vec4f,
1291 "position",
1292 ));
1293 vo.add_field(WGSLStructField::location("color", WGSLType::Vec4f, 0));
1294 shader.add_struct(vo);
1295 let mut vert = WGSLFunction::vertex("vs_main");
1296 vert.add_param(WGSLParam::with_builtin(
1297 "vertex_index",
1298 WGSLType::U32,
1299 "vertex_index",
1300 ));
1301 vert.set_return_type(WGSLType::Struct("VertexOutput".into()));
1302 vert.add_statement(
1303 "var positions = array<vec2<f32>, 3>(vec2(0.0, 0.5), vec2(-0.5, -0.5), vec2(0.5, -0.5))",
1304 );
1305 vert.add_statement(
1306 "var colors = array<vec4<f32>, 3>(vec4(1.0, 0.0, 0.0, 1.0), vec4(0.0, 1.0, 0.0, 1.0), vec4(0.0, 0.0, 1.0, 1.0))",
1307 );
1308 vert.add_statement("var out: VertexOutput");
1309 vert.add_statement("out.position = vec4<f32>(positions[vertex_index], 0.0, 1.0)");
1310 vert.add_statement("out.color = colors[vertex_index]");
1311 vert.add_statement("return out");
1312 shader.add_function(vert);
1313 let mut frag = WGSLFunction::fragment("fs_main");
1314 frag.add_param(WGSLParam::with_location(
1315 "in",
1316 WGSLType::Struct("VertexOutput".into()),
1317 0,
1318 ));
1319 frag.set_return_type_with_attrib(WGSLType::Vec4f, WGSLReturnAttrib::Location(0));
1320 frag.add_statement("return in.color");
1321 shader.add_function(frag);
1322 self.emit_shader(&shader)
1323 }
1324 pub fn compute_shader_template(&self) -> String {
1326 let mut shader = WGSLShader::new();
1327 shader.add_binding(WGSLBinding::storage(
1328 0,
1329 0,
1330 "input_data",
1331 WGSLType::RuntimeArray(Box::new(WGSLType::F32)),
1332 WGSLAccess::Read,
1333 ));
1334 shader.add_binding(WGSLBinding::storage(
1335 0,
1336 1,
1337 "output_data",
1338 WGSLType::RuntimeArray(Box::new(WGSLType::F32)),
1339 WGSLAccess::ReadWrite,
1340 ));
1341 let mut comp = WGSLFunction::compute("main", 64, 1, 1);
1342 comp.add_param(WGSLParam::with_builtin(
1343 "global_id",
1344 WGSLType::Vec3u,
1345 "global_invocation_id",
1346 ));
1347 comp.add_statement("let idx = global_id.x");
1348 comp.add_statement("output_data[idx] = input_data[idx] * 2.0");
1349 shader.add_function(comp);
1350 self.emit_shader(&shader)
1351 }
1352 pub fn texture_sample_template(&self) -> String {
1354 let mut shader = WGSLShader::new();
1355 let mut vo = WGSLStruct::new("VertexOutput");
1356 vo.add_field(WGSLStructField::builtin(
1357 "position",
1358 WGSLType::Vec4f,
1359 "position",
1360 ));
1361 vo.add_field(WGSLStructField::location("uv", WGSLType::Vec2f, 0));
1362 shader.add_struct(vo);
1363 shader.add_binding(WGSLBinding::new(0, 0, "t_diffuse", WGSLType::Texture2D));
1364 shader.add_binding(WGSLBinding::new(0, 1, "s_diffuse", WGSLType::Sampler));
1365 let mut ub = WGSLStruct::new("Uniforms");
1366 ub.add_field(WGSLStructField::new("transform", WGSLType::Mat4x4f));
1367 shader.add_struct(ub);
1368 shader.add_binding(WGSLBinding::new(
1369 1,
1370 0,
1371 "uniforms",
1372 WGSLType::Struct("Uniforms".into()),
1373 ));
1374 let mut vert = WGSLFunction::vertex("vs_main");
1375 vert.add_param(WGSLParam::with_location("position", WGSLType::Vec4f, 0));
1376 vert.add_param(WGSLParam::with_location("uv", WGSLType::Vec2f, 1));
1377 vert.set_return_type(WGSLType::Struct("VertexOutput".into()));
1378 vert.add_statement("var out: VertexOutput");
1379 vert.add_statement("out.position = uniforms.transform * position");
1380 vert.add_statement("out.uv = uv");
1381 vert.add_statement("return out");
1382 shader.add_function(vert);
1383 let mut frag = WGSLFunction::fragment("fs_main");
1384 frag.add_param(WGSLParam::with_location(
1385 "in",
1386 WGSLType::Struct("VertexOutput".into()),
1387 0,
1388 ));
1389 frag.set_return_type_with_attrib(WGSLType::Vec4f, WGSLReturnAttrib::Location(0));
1390 frag.add_statement("return textureSample(t_diffuse, s_diffuse, in.uv)");
1391 shader.add_function(frag);
1392 self.emit_shader(&shader)
1393 }
1394 pub fn reduction_shader_template(&self, workgroup_size: u32) -> String {
1396 let mut shader = WGSLShader::new();
1397 let mut ws_override = WGSLOverride::new("WORKGROUP_SIZE", WGSLType::U32);
1398 ws_override.default_value = Some(workgroup_size.to_string());
1399 shader.add_override(ws_override);
1400 shader.add_binding(WGSLBinding::storage(
1401 0,
1402 0,
1403 "data",
1404 WGSLType::RuntimeArray(Box::new(WGSLType::F32)),
1405 WGSLAccess::Read,
1406 ));
1407 shader.add_binding(WGSLBinding::storage(
1408 0,
1409 1,
1410 "result",
1411 WGSLType::F32,
1412 WGSLAccess::ReadWrite,
1413 ));
1414 shader.add_global(WGSLGlobal::workgroup(
1415 "shared_data",
1416 WGSLType::Array(Box::new(WGSLType::F32), workgroup_size),
1417 ));
1418 let mut comp = WGSLFunction::compute("reduce", workgroup_size, 1, 1);
1419 comp.add_param(WGSLParam::with_builtin(
1420 "global_id",
1421 WGSLType::Vec3u,
1422 "global_invocation_id",
1423 ));
1424 comp.add_param(WGSLParam::with_builtin(
1425 "local_id",
1426 WGSLType::Vec3u,
1427 "local_invocation_id",
1428 ));
1429 comp.add_statement("let gid = global_id.x");
1430 comp.add_statement("let lid = local_id.x");
1431 comp.add_statement("shared_data[lid] = data[gid]");
1432 comp.add_statement("workgroupBarrier()");
1433 comp.add_statement("var stride = WORKGROUP_SIZE / 2u");
1434 comp.add_statement(
1435 "loop { if stride == 0u { break; } if lid < stride { shared_data[lid] += shared_data[lid + stride]; } workgroupBarrier(); stride /= 2u; }",
1436 );
1437 comp.add_statement("if lid == 0u { result = shared_data[0u]; }");
1438 shader.add_function(comp);
1439 self.emit_shader(&shader)
1440 }
1441}
1442#[derive(Debug, Clone)]
1444#[allow(dead_code)]
1445pub struct WGSLBindingGroupLayout {
1446 pub group: u32,
1448 pub entries: Vec<WGSLBindingEntry>,
1450}
1451impl WGSLBindingGroupLayout {
1452 #[allow(dead_code)]
1454 pub fn new(group: u32) -> Self {
1455 WGSLBindingGroupLayout {
1456 group,
1457 entries: Vec::new(),
1458 }
1459 }
1460 #[allow(dead_code)]
1462 pub fn add_entry(
1463 &mut self,
1464 binding: u32,
1465 resource_type: WGSLResourceType,
1466 visibility: WGSLStageVisibility,
1467 ) {
1468 self.entries.push(WGSLBindingEntry {
1469 binding,
1470 resource_type,
1471 visibility,
1472 });
1473 }
1474 #[allow(dead_code)]
1476 pub fn emit_comment(&self) -> String {
1477 let mut out = format!("// BindGroup {} layout:\n", self.group);
1478 for e in &self.entries {
1479 out.push_str(&format!(
1480 "// binding={} type={:?} visibility={}\n",
1481 e.binding, e.resource_type, e.visibility
1482 ));
1483 }
1484 out
1485 }
1486}
/// A small WGSL expression tree that renders to source text.
///
/// Binary and unary operations are always parenthesised by
/// [`WGSLExpr::emit`]. The tree's structure, not WGSL operator precedence,
/// therefore determines grouping.
#[derive(Debug, Clone)]
pub enum WGSLExpr {
    /// Literal text emitted verbatim (e.g. `1.0`, `3u`, `true`).
    Literal(String),
    /// A name emitted verbatim (variable, parameter, constant, ...).
    Var(String),
    /// Binary operation, rendered as `(lhs op rhs)`.
    BinOp {
        op: String,
        lhs: Box<WGSLExpr>,
        rhs: Box<WGSLExpr>,
    },
    /// Prefix operation, rendered as `(op operand)` without a space unless
    /// one is needed to keep the tokens apart.
    UnaryOp { op: String, operand: Box<WGSLExpr> },
    /// Function call, rendered as `func(arg, arg, ...)`.
    Call { func: String, args: Vec<WGSLExpr> },
    /// Member access / swizzle, rendered as `base.field`.
    Field { base: Box<WGSLExpr>, field: String },
    /// Indexing, rendered as `base[index]`.
    Index {
        base: Box<WGSLExpr>,
        index: Box<WGSLExpr>,
    },
}
impl WGSLExpr {
    /// Renders the expression tree as WGSL source text.
    pub fn emit(&self) -> String {
        match self {
            WGSLExpr::Literal(s) => s.clone(),
            WGSLExpr::Var(name) => name.clone(),
            WGSLExpr::BinOp { op, lhs, rhs } => {
                format!("({} {} {})", lhs.emit(), op, rhs.emit())
            }
            WGSLExpr::UnaryOp { op, operand } => {
                let inner = operand.emit();
                // Gluing `-` onto a negative literal ("--0.5") or `&` onto
                // text starting with `&` would lex as the `--` / `&&` tokens.
                // Separate them with a space in exactly those cases.
                let would_merge = matches!(
                    (op.chars().last(), inner.chars().next()),
                    (Some('-'), Some('-')) | (Some('&'), Some('&'))
                );
                if would_merge {
                    format!("({} {})", op, inner)
                } else {
                    format!("({}{})", op, inner)
                }
            }
            WGSLExpr::Call { func, args } => {
                let arg_strs: Vec<String> = args.iter().map(|a| a.emit()).collect();
                format!("{}({})", func, arg_strs.join(", "))
            }
            WGSLExpr::Field { base, field } => format!("{}.{}", base.emit(), field),
            WGSLExpr::Index { base, index } => {
                format!("{}[{}]", base.emit(), index.emit())
            }
        }
    }
    /// Builds `(lhs op rhs)`.
    pub fn binop(op: impl Into<String>, lhs: WGSLExpr, rhs: WGSLExpr) -> Self {
        WGSLExpr::BinOp {
            op: op.into(),
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }
    /// Builds the call `func(args...)`.
    pub fn call(func: impl Into<String>, args: Vec<WGSLExpr>) -> Self {
        WGSLExpr::Call {
            func: func.into(),
            args,
        }
    }
    /// Builds a reference to the name `name`.
    pub fn var(name: impl Into<String>) -> Self {
        WGSLExpr::Var(name.into())
    }
    /// Builds a float literal for `v`.
    ///
    /// Six fixed decimals (e.g. `0.500000`) are used whenever that text
    /// parses back to exactly `v`. Otherwise the shortest round-tripping form
    /// is used, which may be in exponent notation (e.g. `1e-7`, still a
    /// valid WGSL float literal). Without this fallback, values such as
    /// `1e-7` would collapse to `0.000000`.
    ///
    /// Non-finite values produce `NaN` / `inf`, which are not WGSL literals.
    pub fn f32_lit(v: f32) -> Self {
        let fixed = format!("{:.6}", v);
        if fixed.parse::<f32>().ok() == Some(v) {
            WGSLExpr::Literal(fixed)
        } else {
            WGSLExpr::Literal(format!("{:?}", v))
        }
    }
    /// Builds an unsigned integer literal with the `u` suffix (e.g. `3u`).
    pub fn u32_lit(v: u32) -> Self {
        WGSLExpr::Literal(format!("{}u", v))
    }
}
1560#[derive(Debug, Default)]
1562#[allow(dead_code)]
1563pub struct WGSLValidator {
1564 pub errors: Vec<String>,
1566 pub warnings: Vec<String>,
1568}
1569impl WGSLValidator {
1570 #[allow(dead_code)]
1572 pub fn new() -> Self {
1573 WGSLValidator::default()
1574 }
1575 #[allow(dead_code)]
1577 pub fn validate(&mut self, shader: &WGSLShader) -> bool {
1578 self.errors.clear();
1579 self.warnings.clear();
1580 let mut fn_names = std::collections::HashSet::new();
1581 for f in &shader.functions {
1582 if !fn_names.insert(f.name.clone()) {
1583 self.errors
1584 .push(format!("duplicate function name: '{}'", f.name));
1585 }
1586 }
1587 let mut struct_names = std::collections::HashSet::new();
1588 for s in &shader.structs {
1589 if !struct_names.insert(s.name.clone()) {
1590 self.errors
1591 .push(format!("duplicate struct name: '{}'", s.name));
1592 }
1593 }
1594 let mut binding_slots = std::collections::HashSet::new();
1595 for b in &shader.bindings {
1596 let key = (b.group, b.binding);
1597 if !binding_slots.insert(key) {
1598 self.errors.push(format!(
1599 "duplicate binding @group({}) @binding({})",
1600 b.group, b.binding
1601 ));
1602 }
1603 }
1604 if shader.functions.is_empty() {
1605 self.warnings.push("shader has no functions".to_string());
1606 }
1607 for f in &shader.functions {
1608 match f.entry_point {
1609 WGSLEntryPoint::Compute { .. } => {
1610 if f.params.is_empty() {
1611 self.warnings.push(format!(
1612 "compute entry '{}' has no parameters (no global_invocation_id?)",
1613 f.name
1614 ));
1615 }
1616 }
1617 _ => {}
1618 }
1619 }
1620 self.errors.is_empty()
1621 }
1622 #[allow(dead_code)]
1624 pub fn is_valid(&self) -> bool {
1625 self.errors.is_empty()
1626 }
1627}
/// Attribute placed in front of a function's return type.
#[derive(Debug, Clone)]
pub enum WGSLReturnAttrib {
    /// `@builtin(<name>)`.
    Builtin(String),
    /// `@location(<index>)`.
    Location(u32),
    /// No attribute.
    None,
}
impl WGSLReturnAttrib {
    /// Returns the attribute text followed by one space, or an empty string
    /// for [`WGSLReturnAttrib::None`].
    pub fn prefix(&self) -> String {
        let attribute = match self {
            Self::Builtin(name) => format!("@builtin({})", name),
            Self::Location(index) => format!("@location({})", index),
            Self::None => return String::new(),
        };
        attribute + " "
    }
}
/// String helpers for common WGSL constructors, matrices and built-in calls.
///
/// `f32` components are formatted with Rust's `Display`. It prints the
/// shortest decimal that round-trips, so `1.0` becomes `1`, which WGSL
/// accepts as an abstract-int constructor argument.
#[allow(dead_code)]
pub struct WGSLPrimitiveHelper;
impl WGSLPrimitiveHelper {
    /// Returns `vec2f(x, y)`.
    #[allow(dead_code)]
    pub fn vec2f(x: f32, y: f32) -> String {
        format!("vec2f({}, {})", x, y)
    }
    /// Returns `vec3f(x, y, z)`.
    #[allow(dead_code)]
    pub fn vec3f(x: f32, y: f32, z: f32) -> String {
        format!("vec3f({}, {}, {})", x, y, z)
    }
    /// Returns `vec4f(x, y, z, w)`.
    #[allow(dead_code)]
    pub fn vec4f(x: f32, y: f32, z: f32, w: f32) -> String {
        format!("vec4f({}, {}, {}, {})", x, y, z, w)
    }
    /// Returns `vec2u(xu, yu)`.
    #[allow(dead_code)]
    pub fn vec2u(x: u32, y: u32) -> String {
        format!("vec2u({}u, {}u)", x, y)
    }
    /// Returns `vec3u(xu, yu, zu)`.
    #[allow(dead_code)]
    pub fn vec3u(x: u32, y: u32, z: u32) -> String {
        format!("vec3u({}u, {}u, {}u)", x, y, z)
    }
    /// Returns a `mat4x4f` constructor for the 4x4 identity matrix.
    #[allow(dead_code)]
    pub fn mat4x4_identity() -> String {
        "mat4x4f(1.0,0.0,0.0,0.0, 0.0,1.0,0.0,0.0, 0.0,0.0,1.0,0.0, 0.0,0.0,0.0,1.0)".to_string()
    }
    /// Returns a perspective projection as a `mat4x4f` constructor expression.
    ///
    /// WGSL's scalar `mat4x4f(...)` constructor is column-major, so the
    /// arguments are written column by column, with
    /// `f = 1 / tan(fov_y_rad / 2)`:
    /// * column 0 = `(f / aspect, 0, 0, 0)`
    /// * column 1 = `(0, f, 0, 0)`
    /// * column 2 = `(0, 0, (near + far) / (near - far), -1)`
    /// * column 3 = `(0, 0, 2 * far * near / (near - far), 0)`
    ///
    /// For `matrix * vec4(x, y, z, 1)`, clip `w = -z`, and `z = -near` /
    /// `z = -far` map to NDC depth `-1` / `+1`. This is the OpenGL
    /// convention, the same one `ortho_matrix` uses.
    ///
    /// NOTE(review): WebGPU clips NDC depth to `[0, 1]`. With this
    /// convention, geometry mapped below depth 0 (closest to the near plane)
    /// is clipped. Confirm whether callers remap depth.
    ///
    /// Degenerate inputs (`fov_y_rad == 0`, `aspect == 0`, `near == far`)
    /// yield `inf`/`NaN` components, which are not valid WGSL literals.
    #[allow(dead_code)]
    pub fn perspective_matrix(fov_y_rad: f32, aspect: f32, near: f32, far: f32) -> String {
        let f = 1.0 / (fov_y_rad / 2.0).tan();
        let nf = 1.0 / (near - far);
        format!(
            "mat4x4f({sx},0.0,0.0,0.0, 0.0,{sy},0.0,0.0, 0.0,0.0,{a},-1.0, 0.0,0.0,{b},0.0)",
            sx = f / aspect,
            sy = f,
            a = (near + far) * nf,
            b = 2.0 * far * near * nf,
        )
    }
    /// Returns an orthographic projection as a column-major `mat4x4f`
    /// constructor expression.
    ///
    /// The translation sits in the last column. `z = -near` / `z = -far` map
    /// to NDC depth `-1` / `+1` (OpenGL convention). Equal bounds on any axis
    /// yield `inf`/`NaN` components.
    #[allow(dead_code)]
    pub fn ortho_matrix(
        left: f32,
        right: f32,
        bottom: f32,
        top: f32,
        near: f32,
        far: f32,
    ) -> String {
        let rl = right - left;
        let tb = top - bottom;
        let fn_ = far - near;
        format!(
            "mat4x4f({a},0.0,0.0,0.0, 0.0,{b},0.0,0.0, 0.0,0.0,{c},0.0, {tx},{ty},{tz},1.0)",
            a = 2.0 / rl,
            b = 2.0 / tb,
            c = -2.0 / fn_,
            tx = -(right + left) / rl,
            ty = -(top + bottom) / tb,
            tz = -(far + near) / fn_,
        )
    }
    /// Returns the swizzle / member access `base.components`.
    #[allow(dead_code)]
    pub fn swizzle(base: &str, components: &str) -> String {
        format!("{}.{}", base, components)
    }
    /// Returns `select(false_val, true_val, cond)`, in WGSL's argument order.
    #[allow(dead_code)]
    pub fn select(false_val: &str, true_val: &str, cond: &str) -> String {
        format!("select({}, {}, {})", false_val, true_val, cond)
    }
    /// Returns `atomicAdd(ptr, val)`; `ptr` is emitted verbatim.
    #[allow(dead_code)]
    pub fn atomic_add(ptr: &str, val: &str) -> String {
        format!("atomicAdd({}, {})", ptr, val)
    }
    /// Returns the `workgroupBarrier()` call expression.
    #[allow(dead_code)]
    pub fn barrier() -> &'static str {
        "workgroupBarrier()"
    }
}
1739#[derive(Debug, Default)]
1741#[allow(dead_code)]
1742pub struct WGSLShaderBuilder {
1743 pub(super) shader: WGSLShader,
1744 pub(super) next_group: u32,
1745 pub(super) next_binding: u32,
1746}
1747impl WGSLShaderBuilder {
1748 #[allow(dead_code)]
1750 pub fn new() -> Self {
1751 WGSLShaderBuilder::default()
1752 }
1753 #[allow(dead_code)]
1755 pub fn enable(mut self, ext: impl Into<String>) -> Self {
1756 self.shader.add_enable(ext);
1757 self
1758 }
1759 #[allow(dead_code)]
1761 pub fn constant(
1762 mut self,
1763 name: impl Into<String>,
1764 ty: WGSLType,
1765 value: impl Into<String>,
1766 ) -> Self {
1767 self.shader
1768 .add_constant(WGSLConstant::typed(name, ty, value));
1769 self
1770 }
1771 #[allow(dead_code)]
1773 pub fn struct_def(mut self, s: WGSLStruct) -> Self {
1774 self.shader.add_struct(s);
1775 self
1776 }
1777 #[allow(dead_code)]
1779 pub fn uniform(mut self, name: impl Into<String>, ty: WGSLType) -> Self {
1780 let b = WGSLBinding::new(self.next_group, self.next_binding, name, ty);
1781 self.next_binding += 1;
1782 self.shader.add_binding(b);
1783 self
1784 }
1785 #[allow(dead_code)]
1787 pub fn next_group(mut self) -> Self {
1788 self.next_group += 1;
1789 self.next_binding = 0;
1790 self
1791 }
1792 #[allow(dead_code)]
1794 pub fn storage_read(mut self, name: impl Into<String>, elem_ty: WGSLType) -> Self {
1795 let b = WGSLBinding::storage(
1796 self.next_group,
1797 self.next_binding,
1798 name,
1799 WGSLType::RuntimeArray(Box::new(elem_ty)),
1800 WGSLAccess::Read,
1801 );
1802 self.next_binding += 1;
1803 self.shader.add_binding(b);
1804 self
1805 }
1806 #[allow(dead_code)]
1808 pub fn storage_rw(mut self, name: impl Into<String>, elem_ty: WGSLType) -> Self {
1809 let b = WGSLBinding::storage(
1810 self.next_group,
1811 self.next_binding,
1812 name,
1813 WGSLType::RuntimeArray(Box::new(elem_ty)),
1814 WGSLAccess::ReadWrite,
1815 );
1816 self.next_binding += 1;
1817 self.shader.add_binding(b);
1818 self
1819 }
1820 #[allow(dead_code)]
1822 pub fn texture2d(mut self, name: impl Into<String>) -> Self {
1823 let b = WGSLBinding::new(
1824 self.next_group,
1825 self.next_binding,
1826 name,
1827 WGSLType::Texture2D,
1828 );
1829 self.next_binding += 1;
1830 self.shader.add_binding(b);
1831 self
1832 }
1833 #[allow(dead_code)]
1835 pub fn sampler(mut self, name: impl Into<String>) -> Self {
1836 let b = WGSLBinding::new(self.next_group, self.next_binding, name, WGSLType::Sampler);
1837 self.next_binding += 1;
1838 self.shader.add_binding(b);
1839 self
1840 }
1841 #[allow(dead_code)]
1843 pub fn workgroup_var(mut self, name: impl Into<String>, ty: WGSLType) -> Self {
1844 self.shader.add_global(WGSLGlobal::workgroup(name, ty));
1845 self
1846 }
1847 #[allow(dead_code)]
1849 pub fn private_var(mut self, name: impl Into<String>, ty: WGSLType) -> Self {
1850 self.shader.add_global(WGSLGlobal::private(name, ty));
1851 self
1852 }
1853 #[allow(dead_code)]
1855 pub fn helper(mut self, f: WGSLFunction) -> Self {
1856 self.shader.add_function(f);
1857 self
1858 }
1859 #[allow(dead_code)]
1861 pub fn build(self) -> WGSLShader {
1862 self.shader
1863 }
1864}
1865#[derive(Debug, Clone)]
1867pub struct WGSLGlobal {
1868 pub name: String,
1870 pub ty: WGSLType,
1872 pub address_space: WGSLAddressSpace,
1874 pub access: Option<WGSLAccess>,
1876 pub initializer: Option<String>,
1878}
1879impl WGSLGlobal {
1880 pub fn private(name: impl Into<String>, ty: WGSLType) -> Self {
1882 WGSLGlobal {
1883 name: name.into(),
1884 ty,
1885 address_space: WGSLAddressSpace::Private,
1886 access: None,
1887 initializer: None,
1888 }
1889 }
1890 pub fn workgroup(name: impl Into<String>, ty: WGSLType) -> Self {
1892 WGSLGlobal {
1893 name: name.into(),
1894 ty,
1895 address_space: WGSLAddressSpace::Workgroup,
1896 access: None,
1897 initializer: None,
1898 }
1899 }
1900 pub fn emit(&self) -> String {
1902 let access_str = match self.access {
1903 Some(a) => format!(", {}", a),
1904 None => String::new(),
1905 };
1906 let init = match &self.initializer {
1907 Some(v) => format!(" = {}", v),
1908 None => String::new(),
1909 };
1910 format!(
1911 "var<{}{}> {}: {}{};",
1912 self.address_space, access_str, self.name, self.ty, init
1913 )
1914 }
1915}
1916#[derive(Debug, Clone)]
1918pub struct WGSLOverride {
1919 pub name: String,
1921 pub ty: WGSLType,
1923 pub id: Option<u32>,
1925 pub default_value: Option<String>,
1927}
1928impl WGSLOverride {
1929 pub fn new(name: impl Into<String>, ty: WGSLType) -> Self {
1931 WGSLOverride {
1932 name: name.into(),
1933 ty,
1934 id: None,
1935 default_value: None,
1936 }
1937 }
1938 pub fn emit(&self) -> String {
1940 let id_attr = match self.id {
1941 Some(n) => format!("@id({}) ", n),
1942 None => String::new(),
1943 };
1944 let init = match &self.default_value {
1945 Some(v) => format!(" = {}", v),
1946 None => String::new(),
1947 };
1948 format!("{}override {}: {}{};", id_attr, self.name, self.ty, init)
1949 }
1950}
/// WGSL built-in functions that can be referenced by name or rendered as a
/// call expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(dead_code)]
pub enum WGSLBuiltinFunction {
    // Numeric built-ins.
    Abs,
    Acos,
    Acosh,
    Asin,
    Asinh,
    Atan,
    Atanh,
    Atan2,
    Ceil,
    Clamp,
    Cos,
    Cosh,
    Cross,
    Degrees,
    Distance,
    Dot,
    Exp,
    Exp2,
    FaceForward,
    Floor,
    Fma,
    Fract,
    Frexp,
    InverseSqrt,
    Ldexp,
    Length,
    Log,
    Log2,
    Max,
    Min,
    Mix,
    Modf,
    Normalize,
    Pow,
    Quantize,
    Radians,
    Reflect,
    Refract,
    Round,
    Saturate,
    Sign,
    Sin,
    Sinh,
    Smoothstep,
    Sqrt,
    Step,
    Tan,
    Tanh,
    Transpose,
    Trunc,
    // Bit manipulation.
    CountLeadingZeros,
    CountOneBits,
    CountTrailingZeros,
    ExtractBits,
    FirstLeadingBit,
    FirstTrailingBit,
    InsertBits,
    ReverseBits,
    // Texture built-ins.
    TextureDimensions,
    TextureGather,
    TextureGatherCompare,
    TextureLoad,
    TextureNumLayers,
    TextureNumLevels,
    TextureNumSamples,
    TextureSample,
    TextureSampleBias,
    TextureSampleCompare,
    TextureSampleCompareLevel,
    TextureSampleGrad,
    TextureSampleLevel,
    TextureStore,
    // Derivatives.
    Dpdx,
    Dpdxcoarse,
    Dpdxfine,
    Dpdy,
    Dpdycoarse,
    Dpdyfine,
    Fwidth,
    FwidthCoarse,
    FwidthFine,
    // Atomics.
    AtomicLoad,
    AtomicStore,
    AtomicAdd,
    AtomicSub,
    AtomicMax,
    AtomicMin,
    AtomicAnd,
    AtomicOr,
    AtomicXor,
    AtomicExchange,
    AtomicCompareExchangeWeak,
    // Synchronization.
    WorkgroupBarrier,
    StorageBarrier,
    TextureBarrier,
    // Data packing / unpacking.
    Pack2x16float,
    Pack2x16snorm,
    Pack2x16unorm,
    Pack4x8snorm,
    Pack4x8unorm,
    Unpack2x16float,
    Unpack2x16snorm,
    Unpack2x16unorm,
    Unpack4x8snorm,
    Unpack4x8unorm,
}
impl WGSLBuiltinFunction {
    /// Returns the WGSL spelling of this built-in (e.g. `Quantize` ->
    /// `"quantizeToF16"`).
    #[allow(dead_code)]
    pub fn name(&self) -> &'static str {
        let wgsl_name = match self {
            Self::Abs => "abs",
            Self::Acos => "acos",
            Self::Acosh => "acosh",
            Self::Asin => "asin",
            Self::Asinh => "asinh",
            Self::Atan => "atan",
            Self::Atanh => "atanh",
            Self::Atan2 => "atan2",
            Self::Ceil => "ceil",
            Self::Clamp => "clamp",
            Self::Cos => "cos",
            Self::Cosh => "cosh",
            Self::Cross => "cross",
            Self::Degrees => "degrees",
            Self::Distance => "distance",
            Self::Dot => "dot",
            Self::Exp => "exp",
            Self::Exp2 => "exp2",
            Self::FaceForward => "faceForward",
            Self::Floor => "floor",
            Self::Fma => "fma",
            Self::Fract => "fract",
            Self::Frexp => "frexp",
            Self::InverseSqrt => "inverseSqrt",
            Self::Ldexp => "ldexp",
            Self::Length => "length",
            Self::Log => "log",
            Self::Log2 => "log2",
            Self::Max => "max",
            Self::Min => "min",
            Self::Mix => "mix",
            Self::Modf => "modf",
            Self::Normalize => "normalize",
            Self::Pow => "pow",
            Self::Quantize => "quantizeToF16",
            Self::Radians => "radians",
            Self::Reflect => "reflect",
            Self::Refract => "refract",
            Self::Round => "round",
            Self::Saturate => "saturate",
            Self::Sign => "sign",
            Self::Sin => "sin",
            Self::Sinh => "sinh",
            Self::Smoothstep => "smoothstep",
            Self::Sqrt => "sqrt",
            Self::Step => "step",
            Self::Tan => "tan",
            Self::Tanh => "tanh",
            Self::Transpose => "transpose",
            Self::Trunc => "trunc",
            Self::CountLeadingZeros => "countLeadingZeros",
            Self::CountOneBits => "countOneBits",
            Self::CountTrailingZeros => "countTrailingZeros",
            Self::ExtractBits => "extractBits",
            Self::FirstLeadingBit => "firstLeadingBit",
            Self::FirstTrailingBit => "firstTrailingBit",
            Self::InsertBits => "insertBits",
            Self::ReverseBits => "reverseBits",
            Self::TextureDimensions => "textureDimensions",
            Self::TextureGather => "textureGather",
            Self::TextureGatherCompare => "textureGatherCompare",
            Self::TextureLoad => "textureLoad",
            Self::TextureNumLayers => "textureNumLayers",
            Self::TextureNumLevels => "textureNumLevels",
            Self::TextureNumSamples => "textureNumSamples",
            Self::TextureSample => "textureSample",
            Self::TextureSampleBias => "textureSampleBias",
            Self::TextureSampleCompare => "textureSampleCompare",
            Self::TextureSampleCompareLevel => "textureSampleCompareLevel",
            Self::TextureSampleGrad => "textureSampleGrad",
            Self::TextureSampleLevel => "textureSampleLevel",
            Self::TextureStore => "textureStore",
            Self::Dpdx => "dpdx",
            Self::Dpdxcoarse => "dpdxCoarse",
            Self::Dpdxfine => "dpdxFine",
            Self::Dpdy => "dpdy",
            Self::Dpdycoarse => "dpdyCoarse",
            Self::Dpdyfine => "dpdyFine",
            Self::Fwidth => "fwidth",
            Self::FwidthCoarse => "fwidthCoarse",
            Self::FwidthFine => "fwidthFine",
            Self::AtomicLoad => "atomicLoad",
            Self::AtomicStore => "atomicStore",
            Self::AtomicAdd => "atomicAdd",
            Self::AtomicSub => "atomicSub",
            Self::AtomicMax => "atomicMax",
            Self::AtomicMin => "atomicMin",
            Self::AtomicAnd => "atomicAnd",
            Self::AtomicOr => "atomicOr",
            Self::AtomicXor => "atomicXor",
            Self::AtomicExchange => "atomicExchange",
            Self::AtomicCompareExchangeWeak => "atomicCompareExchangeWeak",
            Self::WorkgroupBarrier => "workgroupBarrier",
            Self::StorageBarrier => "storageBarrier",
            Self::TextureBarrier => "textureBarrier",
            Self::Pack2x16float => "pack2x16float",
            Self::Pack2x16snorm => "pack2x16snorm",
            Self::Pack2x16unorm => "pack2x16unorm",
            Self::Pack4x8snorm => "pack4x8snorm",
            Self::Pack4x8unorm => "pack4x8unorm",
            Self::Unpack2x16float => "unpack2x16float",
            Self::Unpack2x16snorm => "unpack2x16snorm",
            Self::Unpack2x16unorm => "unpack2x16unorm",
            Self::Unpack4x8snorm => "unpack4x8snorm",
            Self::Unpack4x8unorm => "unpack4x8unorm",
        };
        wgsl_name
    }
    /// Renders a call `name(arg, arg, ...)` with the arguments emitted verbatim.
    #[allow(dead_code)]
    pub fn call(&self, args: &[&str]) -> String {
        let mut rendered = String::from(self.name());
        rendered.push('(');
        rendered.push_str(&args.join(", "));
        rendered.push(')');
        rendered
    }
}
2179#[derive(Debug, Clone)]
2181pub struct WGSLShader {
2182 pub enables: Vec<String>,
2184 pub constants: Vec<WGSLConstant>,
2186 pub overrides: Vec<WGSLOverride>,
2188 pub structs: Vec<WGSLStruct>,
2190 pub bindings: Vec<WGSLBinding>,
2192 pub globals: Vec<WGSLGlobal>,
2194 pub functions: Vec<WGSLFunction>,
2196}
2197impl WGSLShader {
2198 pub fn new() -> Self {
2200 WGSLShader {
2201 enables: Vec::new(),
2202 constants: Vec::new(),
2203 overrides: Vec::new(),
2204 structs: Vec::new(),
2205 bindings: Vec::new(),
2206 globals: Vec::new(),
2207 functions: Vec::new(),
2208 }
2209 }
2210 pub fn add_enable(&mut self, ext: impl Into<String>) {
2212 self.enables.push(ext.into());
2213 }
2214 pub fn add_constant(&mut self, c: WGSLConstant) {
2216 self.constants.push(c);
2217 }
2218 pub fn add_override(&mut self, o: WGSLOverride) {
2220 self.overrides.push(o);
2221 }
2222 pub fn add_struct(&mut self, s: WGSLStruct) {
2224 self.structs.push(s);
2225 }
2226 pub fn add_binding(&mut self, b: WGSLBinding) {
2228 self.bindings.push(b);
2229 }
2230 pub fn add_global(&mut self, g: WGSLGlobal) {
2232 self.globals.push(g);
2233 }
2234 pub fn add_function(&mut self, f: WGSLFunction) {
2236 self.functions.push(f);
2237 }
2238}