ambient_gpu/
shader_module.rs

1use std::{
2    borrow::Cow,
3    collections::{btree_map, BTreeMap},
4    sync::Arc,
5};
6
7use aho_corasick::AhoCorasick;
8use ambient_std::{asset_cache::*, CowStr};
9use anyhow::Context;
10use itertools::Itertools;
11use wgpu::{BindGroupLayout, BindGroupLayoutEntry, ComputePipelineDescriptor, DepthBiasState, TextureFormat};
12
13use super::gpu::{Gpu, GpuKey, DEFAULT_SAMPLE_COUNT};
14
15#[derive(Debug, Clone, PartialEq)]
16pub enum WgslValue {
17    String(CowStr),
18    Raw(CowStr),
19    Float(f32),
20    Int32(u32),
21    Int64(u64),
22}
23
24impl WgslValue {
25    pub fn as_integer(&self) -> Option<u32> {
26        match self {
27            WgslValue::Int32(v) => Some(*v),
28            _ => None,
29        }
30    }
31
32    fn to_wgsl(&self) -> String {
33        match self {
34            WgslValue::String(v) => format!("{v:?}"),
35            WgslValue::Raw(v) => v.to_string(),
36            WgslValue::Float(v) => v.to_string(),
37            WgslValue::Int32(v) => v.to_string(),
38            WgslValue::Int64(v) => v.to_string(),
39        }
40    }
41}
42
43impl From<&'static str> for WgslValue {
44    fn from(v: &'static str) -> Self {
45        Self::String(v.into())
46    }
47}
48impl From<String> for WgslValue {
49    fn from(v: String) -> Self {
50        Self::String(v.into())
51    }
52}
53
54impl From<f32> for WgslValue {
55    fn from(v: f32) -> Self {
56        Self::Float(v)
57    }
58}
59
60impl From<u32> for WgslValue {
61    fn from(v: u32) -> Self {
62        Self::Int32(v)
63    }
64}
65
66impl From<u64> for WgslValue {
67    fn from(v: u64) -> Self {
68        Self::Int64(v)
69    }
70}
71
/// A named placeholder in shader source together with the value that replaces it.
#[derive(Debug, Clone, PartialEq)]
pub struct ShaderIdent {
    /// The text searched for in the shader source.
    name: CowStr,
    /// The value whose WGSL representation replaces `name`.
    value: WgslValue,
}
77
78impl ShaderIdent {
79    /// Shortcut for unescaped text replacement
80    pub fn raw(name: impl Into<CowStr>, value: impl Into<CowStr>) -> Self {
81        Self { name: name.into(), value: WgslValue::Raw(value.into()) }
82    }
83
84    /// Replaces any occurence of `name` with the wgsl representation of `value`
85    pub fn constant(name: impl Into<CowStr>, value: impl Into<WgslValue>) -> Self {
86        Self { name: name.into(), value: value.into() }
87    }
88}
89
/// A bind group layout entry paired with the name of the bind group it belongs to.
type BindingEntry = (CowStr, BindGroupLayoutEntry);
91
/// Defines a part of a shader, with preprocessing.
///
/// Each shader module contains:
/// - Source code
/// - Dependencies
/// - Identifiers, used for preprocessing and replacing, such as constants
/// - A list of binding entries for generating the complete pipeline layout when the shader is assembled.
///     The bindings *do not* describe complete bind groups, as they may be spread out over several shader modules.
///
///     As such, it is not possible to get the bind group layout from a single shader module. Prefer to split out and
///     reuse the entries in a separate function.
#[derive(Debug, Default)]
pub struct ShaderModule {
    /// The unique name of the shader module.
    ///
    /// Dependency resolution tracks modules by this name.
    pub name: CowStr,
    /// The wgsl source for the module, *without* dependencies
    pub source: CowStr,

    /// Dependencies for the module
    pub dependencies: Vec<Arc<ShaderModule>>,

    // Identifiers that are textually replaced by their values when the shader is assembled
    pub idents: Vec<ShaderIdent>,
    // Binding entries, each tagged with the name of the bind group it is added to
    bindings: Vec<BindingEntry>,
}
116
117impl ShaderModule {
118    pub fn new(name: impl Into<CowStr>, source: impl Into<CowStr>) -> Self {
119        Self {
120            name: name.into(),
121            source: source.into(),
122            idents: Default::default(),
123            bindings: Default::default(),
124            dependencies: Default::default(),
125        }
126    }
127
128    pub fn with_ident(mut self, ident: ShaderIdent) -> Self {
129        self.idents.push(ident);
130        self
131    }
132
133    pub fn with_binding(mut self, group: impl Into<CowStr>, entry: BindGroupLayoutEntry) -> Self {
134        self.bindings.push((group.into(), entry));
135        self
136    }
137
138    pub fn with_bindings(mut self, bindings: impl IntoIterator<Item = (CowStr, BindGroupLayoutEntry)>) -> Self {
139        self.bindings.extend(bindings.into_iter());
140        self
141    }
142
143    pub fn with_binding_desc(mut self, desc: BindGroupDesc<'static>) -> Self {
144        let group = desc.label.clone();
145        self.bindings.extend(desc.entries.iter().map(|&entry| (group.clone(), entry)));
146        self
147    }
148
149    pub fn with_dependency(mut self, module: Arc<ShaderModule>) -> Self {
150        self.dependencies.push(module);
151        self
152    }
153
154    pub fn with_dependencies(mut self, modules: impl IntoIterator<Item = Arc<ShaderModule>>) -> Self {
155        self.dependencies.extend(modules);
156        self
157    }
158
159    fn sanitized_label(&self) -> String {
160        self.name.replace(|v: char| !v.is_ascii_alphanumeric() && !"_-.".contains(v), "?")
161    }
162}
163
/// Describes a bind group layout: its entries plus the name it is known by.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct BindGroupDesc<'a> {
    /// The layout entries of the group.
    pub entries: Vec<wgpu::BindGroupLayoutEntry>,
    // Name of the group: used as the layout's label, and as the group name that
    // the shader preprocessor replaces with the group's index
    pub label: Cow<'a, str>,
}
170
171impl<'a> SyncAssetKey<Arc<wgpu::BindGroupLayout>> for BindGroupDesc<'a> {
172    fn load(&self, assets: AssetCache) -> Arc<wgpu::BindGroupLayout> {
173        let gpu = GpuKey.get(&assets);
174
175        let layout =
176            gpu.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { label: Some(&*self.label), entries: &self.entries });
177
178        Arc::new(layout)
179    }
180}
181
182/// Returns all shader modules in the dependency graph in topological order
183///
184/// # Panics
185///
186/// If the dependency graph contains a cycle
187fn resolve_module_graph<'a>(roots: impl IntoIterator<Item = &'a ShaderModule>) -> Vec<&'a ShaderModule> {
188    enum VisitedState {
189        Pending,
190        Visited,
191    }
192
193    let mut visited = BTreeMap::new();
194
195    fn visit<'a>(
196        visited: &mut BTreeMap<&'a str, VisitedState>,
197        result: &mut Vec<&'a ShaderModule>,
198        module: &'a ShaderModule,
199        backtrace: &[&str],
200    ) {
201        match visited.entry(&module.name) {
202            btree_map::Entry::Vacant(slot) => {
203                slot.insert(VisitedState::Pending);
204            }
205            btree_map::Entry::Occupied(slot) => match slot.get() {
206                VisitedState::Pending => panic!("Circular dependency for module: {:?} in {:?}", module.name, backtrace),
207                VisitedState::Visited => return,
208            },
209        }
210
211        let backtrace = backtrace.iter().copied().chain([&*module.name]).collect_vec();
212
213        // Ensure dependencies are satisfied first
214        for module in &module.dependencies {
215            visit(visited, result, module, &backtrace)
216        }
217
218        visited.insert(&module.name, VisitedState::Visited);
219
220        result.push(module);
221    }
222
223    let mut result = Vec::new();
224    for root in roots {
225        visit(&mut visited, &mut result, root, &[]);
226    }
227
228    result
229}
230
/// Represents a shader and its layout
pub struct Shader {
    /// The compiled wgpu shader module.
    module: wgpu::ShaderModule,
    // One layout per bind group, in the order the bind group names were given to `Shader::new`
    bind_group_layouts: Vec<Arc<wgpu::BindGroupLayout>>,
    /// Label passed to wgpu for the shader module and for the pipelines created from it.
    label: CowStr,
}
238
239impl std::ops::Deref for Shader {
240    type Target = wgpu::ShaderModule;
241
242    fn deref(&self) -> &Self::Target {
243        &self.module
244    }
245}
246
247impl Shader {
248    pub fn new(
249        assets: &AssetCache,
250        label: impl Into<CowStr>,
251        bind_group_names: &[&str],
252        module: &ShaderModule,
253    ) -> anyhow::Result<Arc<Self>> {
254        let label = label.into();
255        let gpu = GpuKey.get(assets);
256
257        let _span = tracing::info_span!("Shader::from_modules", ?label).entered();
258
259        // The complete dependency graph, in the correct order
260        let modules = resolve_module_graph([module]);
261
262        // Resolve all bind groups, resolving the names to an index
263        let bind_group_index: BTreeMap<_, _> = bind_group_names.iter().enumerate().map(|(a, &b)| (b, a)).collect();
264        let mut bind_groups =
265            bind_group_names.iter().map(|group| BindGroupDesc { label: Cow::Borrowed(*group), entries: Default::default() }).collect_vec();
266
267        for module in &modules {
268            for (group, binding) in &module.bindings {
269                let index =
270                    *bind_group_index.get(&**group).with_context(|| format!("Failed to resolve bind group: {group} in {}", module.name))?;
271
272                let desc = &mut bind_groups[index];
273                desc.entries.push(*binding);
274            }
275        }
276
277        // Now for the fun part: constructing the binding group layout descriptors
278        let bind_group_layouts = bind_groups.iter().map(|desc| desc.get(assets)).collect_vec();
279        if bind_group_layouts.len() > 4 {
280            anyhow::bail!(
281                "Maximum bind group layout count exceeded. Expected a maximum of 4, found {}: {bind_group_names:?}",
282                bind_group_layouts.len()
283            );
284        }
285
286        // Efficiently replace all identifiers
287        let (patterns, replace_with): (Vec<_>, Vec<_>) = modules
288            .iter()
289            .flat_map(|v| v.idents.iter().map(|ShaderIdent { name, value }| (format!("{name}"), value.to_wgsl())))
290            .chain(bind_group_index.iter().map(|(name, &index)| (name.to_string(), (index as u32).to_string())))
291            .unzip();
292
293        tracing::debug!(
294            "Preprocessing shader using {}",
295            patterns.iter().zip_eq(&replace_with).map(|(a, b)| { format!("{a} => {b}") }).format("\n")
296        );
297
298        // Collect the raw source code
299        let source = {
300            let source = modules
301                .iter()
302                .map(|module| {
303                    let div = "--------------------------------";
304                    let label = module.sanitized_label();
305                    let source = &module.source;
306                    format!("// {div}\n// @module: {label}\n// {div}\n{source}")
307                })
308                .join("\n\n");
309
310            AhoCorasick::new(patterns).replace_all(&source, &replace_with)
311        };
312
313        #[cfg(all(not(target_os = "unknown"), debug_assertions))]
314        {
315            let path = format!("tmp/{label}.wgsl");
316            std::fs::create_dir_all("tmp/").unwrap();
317            std::fs::write(path, source.as_bytes()).unwrap();
318        }
319
320        let module = gpu
321            .device
322            .create_shader_module(wgpu::ShaderModuleDescriptor { label: Some(&label), source: wgpu::ShaderSource::Wgsl(source.into()) });
323
324        Ok(Arc::new(Self { module, bind_group_layouts, label }))
325    }
326
327    #[inline]
328    pub fn layouts(&self) -> &[Arc<BindGroupLayout>] {
329        &self.bind_group_layouts
330    }
331
332    /// The wgpu shader module
333    #[inline]
334    pub fn module(&self) -> &wgpu::ShaderModule {
335        &self.module
336    }
337
338    pub fn to_pipeline(self: &Arc<Self>, gpu: &Gpu, info: GraphicsPipelineInfo) -> GraphicsPipeline {
339        let layout = gpu.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
340            label: Some(&self.label),
341            bind_group_layouts: &self.layouts().iter().map(|v| &**v).collect_vec(),
342            push_constant_ranges: &[],
343        });
344
345        let pipeline = gpu.device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
346            label: Some(&self.label),
347            layout: Some(&layout),
348            vertex: wgpu::VertexState { module: self.module(), entry_point: info.vs_main, buffers: &[] },
349            primitive: wgpu::PrimitiveState {
350                front_face: info.front_face,
351                cull_mode: info.cull_mode,
352                topology: info.topology,
353                ..Default::default()
354            },
355            fragment: Some(wgpu::FragmentState { module: self.module(), entry_point: info.fs_main, targets: info.targets }),
356            depth_stencil: info.depth,
357            multisample: wgpu::MultisampleState { count: DEFAULT_SAMPLE_COUNT, mask: !0, alpha_to_coverage_enabled: false },
358            multiview: None,
359        });
360
361        GraphicsPipeline { pipeline, shader: self.clone() }
362    }
363
364    pub fn to_compute_pipeline(self: &Arc<Self>, gpu: &Gpu, entry_point: &str) -> ComputePipeline {
365        let layout = gpu.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
366            label: Some(&self.label),
367            bind_group_layouts: &self.layouts().iter().map(|v| &**v).collect_vec(),
368            push_constant_ranges: &[],
369        });
370
371        let pipeline = gpu.device.create_compute_pipeline(&ComputePipelineDescriptor {
372            label: Some(&self.label),
373            layout: Some(&layout),
374            module: self.module(),
375            entry_point,
376        });
377
378        ComputePipeline { pipeline, shader: self.clone() }
379    }
380}
381
/// The settings used by `Shader::to_pipeline` when creating a render pipeline.
#[derive(Debug, Clone, PartialEq)]
pub struct GraphicsPipelineInfo<'a> {
    /// Entry point of the vertex stage.
    pub vs_main: &'a str,
    /// Entry point of the fragment stage.
    pub fs_main: &'a str,
    /// Depth/stencil state of the pipeline; `None` for no depth/stencil state.
    pub depth: Option<wgpu::DepthStencilState>,
    /// Color targets of the fragment stage.
    pub targets: &'a [Option<wgpu::ColorTargetState>],
    /// Winding order of front-facing primitives.
    pub front_face: wgpu::FrontFace,
    /// Which face to cull, if any.
    pub cull_mode: Option<wgpu::Face>,
    /// Primitive topology.
    pub topology: wgpu::PrimitiveTopology,
}
392
393impl<'a> Default for GraphicsPipelineInfo<'a> {
394    fn default() -> Self {
395        Self {
396            vs_main: "vs_main",
397            fs_main: "fs_main",
398            depth: None,
399            targets: &[],
400            front_face: wgpu::FrontFace::Cw,
401            cull_mode: None,
402            topology: wgpu::PrimitiveTopology::TriangleList,
403        }
404    }
405}
406
/// A render pipeline together with the shader it was created from.
pub type GraphicsPipeline = Pipeline<wgpu::RenderPipeline>;
/// A compute pipeline together with the shader it was created from.
pub type ComputePipeline = Pipeline<wgpu::ComputePipeline>;
409
/// A pipeline of type `P` paired with a shared handle to its shader.
pub struct Pipeline<P> {
    /// The underlying pipeline object.
    pipeline: P,
    /// The shader the pipeline was created from.
    shader: Arc<Shader>,
}
414
415impl<P> Pipeline<P> {
416    /// Get a reference to the graphics pipeline's pipeline.
417    pub fn pipeline(&self) -> &P {
418        &self.pipeline
419    }
420
421    /// Get a reference to the pipeline's shader.
422    #[must_use]
423    pub fn shader(&self) -> &Shader {
424        self.shader.as_ref()
425    }
426}
427
428impl<P> std::ops::Deref for Pipeline<P> {
429    type Target = Shader;
430
431    fn deref(&self) -> &Self::Target {
432        &self.shader
433    }
434}
435
/// Depth attachment format used when `target_os` is not `"unknown"`.
#[cfg(not(target_os = "unknown"))]
pub const DEPTH_FORMAT: TextureFormat = TextureFormat::Depth32Float;
/// Depth attachment format used when `target_os` is `"unknown"` (presumably the web/wasm build; confirm).
#[cfg(target_os = "unknown")]
// HACK: float depth formats are broken on wgpu for this target. The reported validation error is:
//
// stencilLoadOp is (LoadOp::Load) and stencilStoreOp is (StoreOp::Store) when stencilReadOnly (0) or the attachment ([TextureView "Renderer.shadow_target_views" of Texture "Renderer.shadow_texture"]) has no stencil aspect.
// - While validating depthStencilAttachment.
// - While encoding [CommandEncoder].BeginRenderPass([RenderPassDescriptor "Shadow cascade 0"]).
//
// The original note here reads "Adding a stencil part crashes the gpu".
// NOTE(review): the format below does include a stencil aspect, so that note looks out of date;
// confirm which of the two statements still holds.
pub const DEPTH_FORMAT: TextureFormat = TextureFormat::Depth24PlusStencil8;
446
447impl<'a> GraphicsPipelineInfo<'a> {
448    pub fn with_depth(self) -> GraphicsPipelineInfo<'a> {
449        Self {
450            depth: Some(wgpu::DepthStencilState {
451                format: DEPTH_FORMAT,
452                depth_write_enabled: true,
453                // This is Greater because we're using reverse-z NDC
454                depth_compare: wgpu::CompareFunction::Greater,
455                stencil: wgpu::StencilState::default(),
456                bias: wgpu::DepthBiasState::default(),
457            }),
458            ..self
459        }
460    }
461
462    pub fn with_depth_bias(mut self, state: DepthBiasState) -> GraphicsPipelineInfo<'a> {
463        self.depth.as_mut().expect("Attempt to set depth bias without a depth buffer").bias = state;
464        self
465    }
466}