1use std::{
2 borrow::Cow,
3 collections::{btree_map, BTreeMap},
4 sync::Arc,
5};
6
7use aho_corasick::AhoCorasick;
8use ambient_std::{asset_cache::*, CowStr};
9use anyhow::Context;
10use itertools::Itertools;
11use wgpu::{BindGroupLayout, BindGroupLayoutEntry, ComputePipelineDescriptor, DepthBiasState, TextureFormat};
12
13use super::gpu::{Gpu, GpuKey, DEFAULT_SAMPLE_COUNT};
14
15#[derive(Debug, Clone, PartialEq)]
16pub enum WgslValue {
17 String(CowStr),
18 Raw(CowStr),
19 Float(f32),
20 Int32(u32),
21 Int64(u64),
22}
23
24impl WgslValue {
25 pub fn as_integer(&self) -> Option<u32> {
26 match self {
27 WgslValue::Int32(v) => Some(*v),
28 _ => None,
29 }
30 }
31
32 fn to_wgsl(&self) -> String {
33 match self {
34 WgslValue::String(v) => format!("{v:?}"),
35 WgslValue::Raw(v) => v.to_string(),
36 WgslValue::Float(v) => v.to_string(),
37 WgslValue::Int32(v) => v.to_string(),
38 WgslValue::Int64(v) => v.to_string(),
39 }
40 }
41}
42
43impl From<&'static str> for WgslValue {
44 fn from(v: &'static str) -> Self {
45 Self::String(v.into())
46 }
47}
48impl From<String> for WgslValue {
49 fn from(v: String) -> Self {
50 Self::String(v.into())
51 }
52}
53
54impl From<f32> for WgslValue {
55 fn from(v: f32) -> Self {
56 Self::Float(v)
57 }
58}
59
60impl From<u32> for WgslValue {
61 fn from(v: u32) -> Self {
62 Self::Int32(v)
63 }
64}
65
66impl From<u64> for WgslValue {
67 fn from(v: u64) -> Self {
68 Self::Int64(v)
69 }
70}
71
72#[derive(Debug, Clone, PartialEq)]
73pub struct ShaderIdent {
74 name: CowStr,
75 value: WgslValue,
76}
77
78impl ShaderIdent {
79 pub fn raw(name: impl Into<CowStr>, value: impl Into<CowStr>) -> Self {
81 Self { name: name.into(), value: WgslValue::Raw(value.into()) }
82 }
83
84 pub fn constant(name: impl Into<CowStr>, value: impl Into<WgslValue>) -> Self {
86 Self { name: name.into(), value: value.into() }
87 }
88}
89
90type BindingEntry = (CowStr, BindGroupLayoutEntry);
91
92#[derive(Debug, Default)]
103pub struct ShaderModule {
104 pub name: CowStr,
106 pub source: CowStr,
108
109 pub dependencies: Vec<Arc<ShaderModule>>,
111
112 pub idents: Vec<ShaderIdent>,
114 bindings: Vec<BindingEntry>,
115}
116
117impl ShaderModule {
118 pub fn new(name: impl Into<CowStr>, source: impl Into<CowStr>) -> Self {
119 Self {
120 name: name.into(),
121 source: source.into(),
122 idents: Default::default(),
123 bindings: Default::default(),
124 dependencies: Default::default(),
125 }
126 }
127
128 pub fn with_ident(mut self, ident: ShaderIdent) -> Self {
129 self.idents.push(ident);
130 self
131 }
132
133 pub fn with_binding(mut self, group: impl Into<CowStr>, entry: BindGroupLayoutEntry) -> Self {
134 self.bindings.push((group.into(), entry));
135 self
136 }
137
138 pub fn with_bindings(mut self, bindings: impl IntoIterator<Item = (CowStr, BindGroupLayoutEntry)>) -> Self {
139 self.bindings.extend(bindings.into_iter());
140 self
141 }
142
143 pub fn with_binding_desc(mut self, desc: BindGroupDesc<'static>) -> Self {
144 let group = desc.label.clone();
145 self.bindings.extend(desc.entries.iter().map(|&entry| (group.clone(), entry)));
146 self
147 }
148
149 pub fn with_dependency(mut self, module: Arc<ShaderModule>) -> Self {
150 self.dependencies.push(module);
151 self
152 }
153
154 pub fn with_dependencies(mut self, modules: impl IntoIterator<Item = Arc<ShaderModule>>) -> Self {
155 self.dependencies.extend(modules);
156 self
157 }
158
159 fn sanitized_label(&self) -> String {
160 self.name.replace(|v: char| !v.is_ascii_alphanumeric() && !"_-.".contains(v), "?")
161 }
162}
163
164#[derive(Clone, PartialEq, Eq, Debug)]
165pub struct BindGroupDesc<'a> {
166 pub entries: Vec<wgpu::BindGroupLayoutEntry>,
167 pub label: Cow<'a, str>,
169}
170
171impl<'a> SyncAssetKey<Arc<wgpu::BindGroupLayout>> for BindGroupDesc<'a> {
172 fn load(&self, assets: AssetCache) -> Arc<wgpu::BindGroupLayout> {
173 let gpu = GpuKey.get(&assets);
174
175 let layout =
176 gpu.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some(&*self.label), entries: &self.entries });
177
178 Arc::new(layout)
179 }
180}
181
182fn resolve_module_graph<'a>(roots: impl IntoIterator<Item = &'a ShaderModule>) -> Vec<&'a ShaderModule> {
188 enum VisitedState {
189 Pending,
190 Visited,
191 }
192
193 let mut visited = BTreeMap::new();
194
195 fn visit<'a>(
196 visited: &mut BTreeMap<&'a str, VisitedState>,
197 result: &mut Vec<&'a ShaderModule>,
198 module: &'a ShaderModule,
199 backtrace: &[&str],
200 ) {
201 match visited.entry(&module.name) {
202 btree_map::Entry::Vacant(slot) => {
203 slot.insert(VisitedState::Pending);
204 }
205 btree_map::Entry::Occupied(slot) => match slot.get() {
206 VisitedState::Pending => panic!("Circular dependency for module: {:?} in {:?}", module.name, backtrace),
207 VisitedState::Visited => return,
208 },
209 }
210
211 let backtrace = backtrace.iter().copied().chain([&*module.name]).collect_vec();
212
213 for module in &module.dependencies {
215 visit(visited, result, module, &backtrace)
216 }
217
218 visited.insert(&module.name, VisitedState::Visited);
219
220 result.push(module);
221 }
222
223 let mut result = Vec::new();
224 for root in roots {
225 visit(&mut visited, &mut result, root, &[]);
226 }
227
228 result
229}
230
231pub struct Shader {
233 module: wgpu::ShaderModule,
234 bind_group_layouts: Vec<Arc<wgpu::BindGroupLayout>>,
236 label: CowStr,
237}
238
239impl std::ops::Deref for Shader {
240 type Target = wgpu::ShaderModule;
241
242 fn deref(&self) -> &Self::Target {
243 &self.module
244 }
245}
246
247impl Shader {
248 pub fn new(
249 assets: &AssetCache,
250 label: impl Into<CowStr>,
251 bind_group_names: &[&str],
252 module: &ShaderModule,
253 ) -> anyhow::Result<Arc<Self>> {
254 let label = label.into();
255 let gpu = GpuKey.get(assets);
256
257 let _span = tracing::info_span!("Shader::from_modules", ?label).entered();
258
259 let modules = resolve_module_graph([module]);
261
262 let bind_group_index: BTreeMap<_, _> = bind_group_names.iter().enumerate().map(|(a, &b)| (b, a)).collect();
264 let mut bind_groups =
265 bind_group_names.iter().map(|group| BindGroupDesc { label: Cow::Borrowed(*group), entries: Default::default() }).collect_vec();
266
267 for module in &modules {
268 for (group, binding) in &module.bindings {
269 let index =
270 *bind_group_index.get(&**group).with_context(|| format!("Failed to resolve bind group: {group} in {}", module.name))?;
271
272 let desc = &mut bind_groups[index];
273 desc.entries.push(*binding);
274 }
275 }
276
277 let bind_group_layouts = bind_groups.iter().map(|desc| desc.get(assets)).collect_vec();
279 if bind_group_layouts.len() > 4 {
280 anyhow::bail!(
281 "Maximum bind group layout count exceeded. Expected a maximum of 4, found {}: {bind_group_names:?}",
282 bind_group_layouts.len()
283 );
284 }
285
286 let (patterns, replace_with): (Vec<_>, Vec<_>) = modules
288 .iter()
289 .flat_map(|v| v.idents.iter().map(|ShaderIdent { name, value }| (format!("{name}"), value.to_wgsl())))
290 .chain(bind_group_index.iter().map(|(name, &index)| (name.to_string(), (index as u32).to_string())))
291 .unzip();
292
293 tracing::debug!(
294 "Preprocessing shader using {}",
295 patterns.iter().zip_eq(&replace_with).map(|(a, b)| { format!("{a} => {b}") }).format("\n")
296 );
297
298 let source = {
300 let source = modules
301 .iter()
302 .map(|module| {
303 let div = "--------------------------------";
304 let label = module.sanitized_label();
305 let source = &module.source;
306 format!("// {div}\n// @module: {label}\n// {div}\n{source}")
307 })
308 .join("\n\n");
309
310 AhoCorasick::new(patterns).replace_all(&source, &replace_with)
311 };
312
313 #[cfg(all(not(target_os = "unknown"), debug_assertions))]
314 {
315 let path = format!("tmp/{label}.wgsl");
316 std::fs::create_dir_all("tmp/").unwrap();
317 std::fs::write(path, source.as_bytes()).unwrap();
318 }
319
320 let module = gpu
321 .device
322 .create_shader_module(wgpu::ShaderModuleDescriptor { label: Some(&label), source: wgpu::ShaderSource::Wgsl(source.into()) });
323
324 Ok(Arc::new(Self { module, bind_group_layouts, label }))
325 }
326
327 #[inline]
328 pub fn layouts(&self) -> &[Arc<BindGroupLayout>] {
329 &self.bind_group_layouts
330 }
331
332 #[inline]
334 pub fn module(&self) -> &wgpu::ShaderModule {
335 &self.module
336 }
337
338 pub fn to_pipeline(self: &Arc<Self>, gpu: &Gpu, info: GraphicsPipelineInfo) -> GraphicsPipeline {
339 let layout = gpu.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
340 label: Some(&self.label),
341 bind_group_layouts: &self.layouts().iter().map(|v| &**v).collect_vec(),
342 push_constant_ranges: &[],
343 });
344
345 let pipeline = gpu.device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
346 label: Some(&self.label),
347 layout: Some(&layout),
348 vertex: wgpu::VertexState { module: self.module(), entry_point: info.vs_main, buffers: &[] },
349 primitive: wgpu::PrimitiveState {
350 front_face: info.front_face,
351 cull_mode: info.cull_mode,
352 topology: info.topology,
353 ..Default::default()
354 },
355 fragment: Some(wgpu::FragmentState { module: self.module(), entry_point: info.fs_main, targets: info.targets }),
356 depth_stencil: info.depth,
357 multisample: wgpu::MultisampleState { count: DEFAULT_SAMPLE_COUNT, mask: !0, alpha_to_coverage_enabled: false },
358 multiview: None,
359 });
360
361 GraphicsPipeline { pipeline, shader: self.clone() }
362 }
363
364 pub fn to_compute_pipeline(self: &Arc<Self>, gpu: &Gpu, entry_point: &str) -> ComputePipeline {
365 let layout = gpu.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
366 label: Some(&self.label),
367 bind_group_layouts: &self.layouts().iter().map(|v| &**v).collect_vec(),
368 push_constant_ranges: &[],
369 });
370
371 let pipeline = gpu.device.create_compute_pipeline(&ComputePipelineDescriptor {
372 label: Some(&self.label),
373 layout: Some(&layout),
374 module: self.module(),
375 entry_point,
376 });
377
378 ComputePipeline { pipeline, shader: self.clone() }
379 }
380}
381
382#[derive(Debug, Clone, PartialEq)]
383pub struct GraphicsPipelineInfo<'a> {
384 pub vs_main: &'a str,
385 pub fs_main: &'a str,
386 pub depth: Option<wgpu::DepthStencilState>,
387 pub targets: &'a [Option<wgpu::ColorTargetState>],
388 pub front_face: wgpu::FrontFace,
389 pub cull_mode: Option<wgpu::Face>,
390 pub topology: wgpu::PrimitiveTopology,
391}
392
393impl<'a> Default for GraphicsPipelineInfo<'a> {
394 fn default() -> Self {
395 Self {
396 vs_main: "vs_main",
397 fs_main: "fs_main",
398 depth: None,
399 targets: &[],
400 front_face: wgpu::FrontFace::Cw,
401 cull_mode: None,
402 topology: wgpu::PrimitiveTopology::TriangleList,
403 }
404 }
405}
406
407pub type GraphicsPipeline = Pipeline<wgpu::RenderPipeline>;
408pub type ComputePipeline = Pipeline<wgpu::ComputePipeline>;
409
410pub struct Pipeline<P> {
411 pipeline: P,
412 shader: Arc<Shader>,
413}
414
415impl<P> Pipeline<P> {
416 pub fn pipeline(&self) -> &P {
418 &self.pipeline
419 }
420
421 #[must_use]
423 pub fn shader(&self) -> &Shader {
424 self.shader.as_ref()
425 }
426}
427
428impl<P> std::ops::Deref for Pipeline<P> {
429 type Target = Shader;
430
431 fn deref(&self) -> &Self::Target {
432 &self.shader
433 }
434}
435
436#[cfg(not(target_os = "unknown"))]
437pub const DEPTH_FORMAT: TextureFormat = TextureFormat::Depth32Float;
438#[cfg(target_os = "unknown")]
439pub const DEPTH_FORMAT: TextureFormat = TextureFormat::Depth24PlusStencil8;
446
447impl<'a> GraphicsPipelineInfo<'a> {
448 pub fn with_depth(self) -> GraphicsPipelineInfo<'a> {
449 Self {
450 depth: Some(wgpu::DepthStencilState {
451 format: DEPTH_FORMAT,
452 depth_write_enabled: true,
453 depth_compare: wgpu::CompareFunction::Greater,
455 stencil: wgpu::StencilState::default(),
456 bias: wgpu::DepthBiasState::default(),
457 }),
458 ..self
459 }
460 }
461
462 pub fn with_depth_bias(mut self, state: DepthBiasState) -> GraphicsPipelineInfo<'a> {
463 self.depth.as_mut().expect("Attempt to set depth bias without a depth buffer").bias = state;
464 self
465 }
466}