wgsl_bindgen 0.22.2

Type safe Rust bindings workflow for wgsl shaders in wgpu
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
mod bindings;
mod types;

use std::collections::HashMap;
use std::path::PathBuf;

pub use bindings::*;
use derive_builder::Builder;
use derive_more::IsVariant;
use enumflags2::{bitflags, BitFlags};
pub use naga::valid::Capabilities as WgslShaderIrCapabilities;
use proc_macro2::TokenStream;
use regex::Regex;
pub use types::*;

use crate::{
  FastIndexMap, WGSLBindgen, WgslBindgenError, WgslType, WgslTypeSerializeStrategy,
};

/// An enum representing the source type that will be generated for the output.
///
/// Used as bitflags (see `WgslBindgenOption::shader_source_type`), so several
/// source types can be requested at once; the default set is `EmbedSource`.
#[bitflags(default = EmbedSource)]
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, IsVariant)]
pub enum WgslShaderSourceType {
  /// Preparse the shader modules and embed the final shader string in the output.
  /// This option removes the naga_oil dependency from the output, but doesn't allow shader defines.
  EmbedSource,

  /// Use Composer with embedded strings for each shader module.
  /// This option allows shader defines but doesn't allow hot-reloading.
  EmbedWithNagaOilComposer,

  /// Use Composer with relative paths and user-provided file loading.
  /// This option allows shader defines and custom IO without requiring nightly Rust.
  ComposerWithRelativePath,
}

/// A directory that should be scanned for additional WGSL source files.
///
/// Used when generating Rust bindings for WGSL shaders: `directory` names the
/// directory to scan, and `module_import_root` optionally gives the root
/// prefix/namespace associated with the modules found there.
#[derive(Clone, Debug, Default)]
pub struct AdditionalScanDirectory {
  /// Optional root prefix/namespace for the modules in `directory`.
  pub module_import_root: Option<String>,
  /// Path of the directory to scan.
  pub directory: String,
}

impl From<(Option<&str>, &str)> for AdditionalScanDirectory {
  /// Builds a scan directory from an `(import_root, directory)` pair, copying both strings.
  fn from(pair: (Option<&str>, &str)) -> Self {
    let (root, dir) = pair;
    let module_import_root = root.map(|r| r.to_owned());
    let directory = dir.to_owned();
    Self {
      module_import_root,
      directory,
    }
  }
}

/// Produces the `WgslType` → Rust representation map used during code generation.
///
/// The map translates built-in WGSL types into the types emitted in the
/// generated Rust code. Matrix and vector representations vary between
/// implementations, so the same WGSL type may end up with a different size or
/// alignment depending on which map is used.
pub trait WgslTypeMapBuild {
  /// Returns the `WgslTypeMap` to use for the given serialization `strategy`.
  fn build(&self, strategy: WgslTypeSerializeStrategy) -> WgslTypeMap;
}

/// A ready-made map is returned as a clone; the strategy is ignored.
impl WgslTypeMapBuild for WgslTypeMap {
  fn build(&self, _strategy: WgslTypeSerializeStrategy) -> WgslTypeMap {
    Clone::clone(self)
  }
}

/// A custom WGSL → Rust struct mapping.
///
/// When a WGSL struct matches `from`, no Rust struct is generated for it; the
/// existing Rust type named by `to` is used instead. Because that type is
/// user-provided, the bytemuck size and alignment checks are skipped for it.
/// Useful for core primitive types you want to model yourself on the Rust side.
#[derive(Clone, Debug)]
pub struct OverrideStruct {
  /// Fully qualified WGSL struct name, e.g. `lib::fp64::Fp64`.
  pub from: String,
  /// Fully qualified path of the replacement type in your crate, e.g. `crate::fp64::Fp64`.
  pub to: TokenStream,
  /// Alignment of the struct in bytes, used to ensure the struct is aligned correctly.
  pub alignment: usize,
}

impl From<(&str, TokenStream, usize)> for OverrideStruct {
  /// Builds an override from a `(wgsl_name, rust_path, alignment)` tuple.
  fn from(tuple: (&str, TokenStream, usize)) -> Self {
    let (wgsl_name, rust_path, alignment) = tuple;
    Self {
      from: String::from(wgsl_name),
      to: rust_path,
      alignment,
    }
  }
}

/// Overrides the Rust type generated for selected struct fields.
///
/// Each entry pairs a struct pattern and a field pattern with the Rust type to
/// emit for matching fields inside matching structs.
#[derive(Clone, Debug)]
pub struct OverrideStructFieldType {
  /// Pattern selecting the structs this override applies to.
  pub struct_regex: Regex,
  /// Pattern selecting the fields (within matching structs) to override.
  pub field_regex: Regex,
  /// Rust type emitted for every matching field.
  pub override_type: TokenStream,
}

impl From<(Regex, Regex, TokenStream)> for OverrideStructFieldType {
  /// Wraps already-compiled patterns together with the replacement type.
  fn from(parts: (Regex, Regex, TokenStream)) -> Self {
    let (struct_regex, field_regex, override_type) = parts;
    Self {
      struct_regex,
      field_regex,
      override_type,
    }
  }
}

impl From<(&str, &str, TokenStream)> for OverrideStructFieldType {
  /// Compiles the struct and field patterns, then builds the override.
  ///
  /// # Panics
  ///
  /// Panics if either pattern is not a valid regular expression.
  fn from(parts: (&str, &str, TokenStream)) -> Self {
    let (struct_pattern, field_pattern, override_type) = parts;
    let struct_regex = Regex::new(struct_pattern).expect("Failed to create struct regex");
    let field_regex = Regex::new(field_pattern).expect("Failed to create field regex");
    Self::from((struct_regex, field_regex, override_type))
  }
}

/// Overrides the alignment generated for selected structs.
#[derive(Clone, Debug)]
pub struct OverrideStructAlignment {
  /// Pattern selecting the structs whose alignment is overridden.
  pub struct_regex: Regex,
  /// Alignment to use for matching structs.
  pub alignment: u16,
}
impl From<(Regex, u16)> for OverrideStructAlignment {
  /// Wraps an already-compiled pattern with its alignment.
  fn from(parts: (Regex, u16)) -> Self {
    let (struct_regex, alignment) = parts;
    Self {
      struct_regex,
      alignment,
    }
  }
}
impl From<(&str, u16)> for OverrideStructAlignment {
  /// Compiles the struct pattern, then builds the override.
  ///
  /// # Panics
  ///
  /// Panics if the pattern is not a valid regular expression.
  fn from((pattern, alignment): (&str, u16)) -> Self {
    let struct_regex = Regex::new(pattern).expect("Failed to create struct regex");
    Self::from((struct_regex, alignment))
  }
}

/// Redirects matching bind group entries to another module path in the generated code.
#[derive(Clone, Debug)]
pub struct OverrideBindGroupEntryModulePath {
  /// Pattern selecting the bind group entries to relocate.
  pub bind_group_entry_regex: Regex,
  /// Module path under which matching entries are generated.
  pub target_path: String,
}
impl From<(Regex, &str)> for OverrideBindGroupEntryModulePath {
  /// Wraps an already-compiled pattern with an owned copy of the target path.
  fn from(parts: (Regex, &str)) -> Self {
    let (bind_group_entry_regex, target) = parts;
    Self {
      bind_group_entry_regex,
      target_path: String::from(target),
    }
  }
}
impl From<(&str, &str)> for OverrideBindGroupEntryModulePath {
  /// Compiles the entry pattern, then builds the override.
  ///
  /// # Panics
  ///
  /// Panics if the pattern is not a valid regular expression.
  fn from((pattern, target_path): (&str, &str)) -> Self {
    let bind_group_entry_regex =
      Regex::new(pattern).expect("Failed to create bind group entry regex");
    Self::from((bind_group_entry_regex, target_path))
  }
}

/// Forces the filterability of textures for matching bindings.
#[derive(Clone, Debug)]
pub struct OverrideTextureFilterability {
  /// Pattern matched against the binding path (e.g., "shared_data::.*texture.*").
  pub binding_regex: Regex,
  /// `true` if matching textures should be filterable, `false` otherwise.
  pub filterable: bool,
}
impl From<(Regex, bool)> for OverrideTextureFilterability {
  /// Wraps an already-compiled pattern with the desired filterability.
  fn from(parts: (Regex, bool)) -> Self {
    let (binding_regex, filterable) = parts;
    Self {
      binding_regex,
      filterable,
    }
  }
}
impl From<(&str, bool)> for OverrideTextureFilterability {
  /// Compiles the binding pattern, then builds the override.
  ///
  /// # Panics
  ///
  /// Panics if the pattern is not a valid regular expression.
  fn from((pattern, filterable): (&str, bool)) -> Self {
    let binding_regex = Regex::new(pattern).expect("Failed to create binding regex");
    Self::from((binding_regex, filterable))
  }
}

/// Kind of sampler binding to generate: filtering, non-filtering or comparison.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SamplerType {
  Filtering,
  NonFiltering,
  Comparison,
}

/// Forces the sampler binding type for matching bindings.
#[derive(Clone, Debug)]
pub struct OverrideSamplerType {
  /// Pattern matched against the binding path (e.g., ".*shadow_sampler.*").
  pub binding_regex: Regex,
  /// Sampler binding type to emit for matching bindings.
  pub sampler_type: SamplerType,
}
impl From<(Regex, SamplerType)> for OverrideSamplerType {
  /// Wraps an already-compiled pattern with the desired sampler type.
  fn from(parts: (Regex, SamplerType)) -> Self {
    let (binding_regex, sampler_type) = parts;
    Self {
      binding_regex,
      sampler_type,
    }
  }
}
impl From<(&str, SamplerType)> for OverrideSamplerType {
  /// Compiles the binding pattern, then builds the override.
  ///
  /// # Panics
  ///
  /// Panics if the pattern is not a valid regular expression.
  fn from((pattern, sampler_type): (&str, SamplerType)) -> Self {
    let binding_regex = Regex::new(pattern).expect("Failed to create binding regex");
    Self::from((binding_regex, sampler_type))
  }
}

/// Visibility applied to every type exported in the generated output.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum WgslTypeVisibility {
  /// Exported types are declared `pub`. This is the default.
  #[default]
  Public,

  /// Exported types are declared `pub(crate)`.
  RestrictedCrate,

  /// Exported types are declared `pub(super)`.
  RestrictedSuper,
}

/// Options controlling how WGSL shaders are turned into Rust bindings.
///
/// Constructed through the derived `WgslBindgenOptionBuilder`, whose public
/// `build` method produces a `WGSLBindgen` generator.
#[derive(Debug, Default, Builder)]
#[builder(
  setter(into),
  field(private),
  build_fn(private, name = "fallible_build")
)]
pub struct WgslBindgenOption {
  /// A vector of entry points to be added. Each entry point is represented as a `String`.
  #[builder(setter(each(name = "add_entry_point", into)))]
  pub entry_points: Vec<String>,

  /// The root prefix/namespace, if any, applied to all shaders given as the entrypoints.
  #[builder(default, setter(strip_option, into))]
  pub module_import_root: Option<String>,

  /// The root shader workspace directory against which all imports are tested for resolution.
  #[builder(setter(into))]
  pub workspace_root: PathBuf,

  /// A boolean flag indicating whether to emit a rerun-if-changed directive to Cargo. Defaults to `true`.
  #[builder(default = "true")]
  pub emit_rerun_if_change: bool,

  /// A boolean flag indicating whether to skip header comments. Defaults to `false`.
  /// Keeping header comments enabled allows regeneration to be skipped when the contents did not change.
  #[builder(default = "false")]
  pub skip_header_comments: bool,

  /// A boolean flag indicating whether to skip the hash check. Defaults to `false`.
  /// The hash check avoids rerunning bindings generation when the entry shaders,
  /// including their imports, have not changed.
  #[builder(default = "false")]
  pub skip_hash_check: bool,

  /// Derive [encase::ShaderType](https://docs.rs/encase/latest/encase/trait.ShaderType.html#)
  /// for user defined WGSL structs when `WgslTypeSerializeStrategy::Encase`;
  /// otherwise derive bytemuck.
  #[builder(default)]
  pub serialization_strategy: WgslTypeSerializeStrategy,

  /// Derive [serde::Serialize](https://docs.rs/serde/1.0.159/serde/trait.Serialize.html)
  /// and [serde::Deserialize](https://docs.rs/serde/1.0.159/serde/trait.Deserialize.html)
  /// for user defined WGSL structs when `true`.
  #[builder(default = "false")]
  pub derive_serde: bool,

  /// The set of shader source types to generate, as bitflags. Defaults to `WgslShaderSourceType::EmbedSource`.
  #[builder(default)]
  pub shader_source_type: BitFlags<WgslShaderSourceType>,

  /// The output file path for the generated Rust bindings. Defaults to `None`.
  #[builder(default, setter(strip_option, into))]
  pub output: Option<PathBuf>,

  /// The additional set of directories to scan for source files.
  #[builder(default, setter(into, each(name = "additional_scan_dir", into)))]
  pub additional_scan_dirs: Vec<AdditionalScanDirectory>,

  /// The [wgpu::naga::valid::Capabilities](https://docs.rs/wgpu/latest/wgpu/naga/valid/struct.Capabilities.html) to support. Defaults to `None`.
  #[builder(default, setter(strip_option))]
  pub ir_capabilities: Option<WgslShaderIrCapabilities>,

  /// Whether to generate short constructors (similar to enum variant constructors) instead of `new`
  /// when the number of parameters is below the specified threshold.
  /// Defaults to `None`.
  #[builder(default, setter(strip_option, into))]
  pub short_constructor: Option<i32>,

  /// Which visibility to use for the exported types.
  #[builder(default)]
  pub type_visibility: WgslTypeVisibility,

  /// A mapping for WGSL built-in types, used to map them to their corresponding Rust representations.
  #[builder(setter(custom))]
  pub type_map: WgslTypeMap,

  /// A vector of custom struct mappings to be added, which will override the struct to be generated.
  /// This is merged with the default struct mappings.
  #[builder(default, setter(each(name = "add_override_struct_mapping", into)))]
  pub override_struct: Vec<OverrideStruct>,

  /// A vector of `OverrideStructFieldType` to override the generated types for struct fields in matching structs.
  #[builder(default, setter(into))]
  pub override_struct_field_type: Vec<OverrideStructFieldType>,

  /// A vector of regular expressions and alignments that override the generated alignment for matching structs.
  /// This can be used in scenarios where a specific minimum alignment is required for a uniform buffer.
  /// Refer to the [WebGPU specs](https://www.w3.org/TR/webgpu/#dom-supported-limits-minuniformbufferoffsetalignment) for more information.
  #[builder(default, setter(into))]
  pub override_struct_alignment: Vec<OverrideStructAlignment>,

  /// A vector of regular expressions and target module paths that override the module path for bind group entries.
  /// This can be used to customize where bind group entries are generated in the output code.
  #[builder(default, setter(into))]
  pub override_bind_group_entry_module_path: Vec<OverrideBindGroupEntryModulePath>,

  /// The regular expressions matching the padding fields used in the shader struct types.
  /// These fields will be omitted from the generated *Init structs, and will automatically be assigned default values.
  #[builder(default, setter(each(name = "add_custom_padding_field_regexp", into)))]
  pub custom_padding_field_regexps: Vec<Regex>,

  /// Whether to always generate the init struct in the output. This is only applicable when using bytemuck mode.
  #[builder(default = "false")]
  pub always_generate_init_struct: bool,

  /// This field can be used to provide a custom generator for extra bindings that are not covered by the default generator.
  #[builder(default, setter(custom))]
  pub extra_binding_generator: Option<BindingGenerator>,

  /// This field is used to provide the default generator for WGPU bindings. The generator is represented as a `BindingGenerator`.
  #[builder(default, setter(custom))]
  pub wgpu_binding_generator: BindingGenerator,

  /// A vector of texture filterability overrides for specific bindings.
  /// Allows specifying whether matching textures should be filterable.
  #[builder(default, setter(into))]
  pub override_texture_filterability: Vec<OverrideTextureFilterability>,

  /// A vector of sampler type overrides for specific bindings.
  /// Allows specifying the sampler binding type (Filtering, NonFiltering, Comparison).
  #[builder(default, setter(into))]
  pub override_sampler_type: Vec<OverrideSamplerType>,

  /// Shader definitions to be passed to naga-oil for conditional compilation.
  /// These are preprocessor definitions that can be used in WGSL shaders with #ifdef, #ifndef, etc.
  #[builder(default, setter(into))]
  pub shader_defs: Vec<(String, naga_oil::compose::ShaderDefValue)>,
}

impl WgslBindgenOptionBuilder {
  pub fn build(&mut self) -> Result<WGSLBindgen, WgslBindgenError> {
    self.merge_struct_type_overrides();

    let options = self.fallible_build()?;
    WGSLBindgen::new(options)
  }

  pub fn type_map(&mut self, map_build: impl WgslTypeMapBuild) -> &mut Self {
    let serialization_strategy = self
      .serialization_strategy
      .expect("Serialization strategy must be set before `wgs_type_map`");

    let map = map_build.build(serialization_strategy);

    match self.type_map.as_mut() {
      Some(m) => m.extend(map),
      None => self.type_map = Some(map),
    }

    self
  }

  /// Add a shader definition value
  pub fn add_shader_def(
    &mut self,
    name: impl Into<String>,
    value: naga_oil::compose::ShaderDefValue,
  ) -> &mut Self {
    if self.shader_defs.is_none() {
      self.shader_defs = Some(Vec::new());
    }
    self
      .shader_defs
      .as_mut()
      .unwrap()
      .push((name.into(), value));
    self
  }

  /// Add multiple shader definitions from a Vec
  pub fn add_shader_defs(
    &mut self,
    defs: Vec<(String, naga_oil::compose::ShaderDefValue)>,
  ) -> &mut Self {
    match self.shader_defs.as_mut() {
      Some(existing) => existing.extend(defs),
      None => self.shader_defs = Some(defs),
    }
    self
  }

  fn merge_struct_type_overrides(&mut self) {
    let struct_mappings = self
      .override_struct
      .iter()
      .flatten()
      .map(
        |OverrideStruct {
           from,
           to,
           alignment,
         }| {
          let wgsl_type = WgslType::Struct {
            fully_qualified_name: from.clone(),
          };
          // For struct overrides, we don't know the exact size/alignment, so use placeholders
          // These will be calculated later when the struct is actually used
          (wgsl_type, WgslTypeInfo::new(to.clone(), *alignment))
        },
      )
      .collect::<FastIndexMap<_, _>>();

    self.type_map(struct_mappings);
  }

  pub fn extra_binding_generator(
    &mut self,
    config: impl GetBindingsGeneratorConfig,
  ) -> &mut Self {
    let generator = Some(config.get_generator_config());
    self.extra_binding_generator = Some(generator);
    self
  }
}