Skip to main content

polyscope_render/
ssao_pass.rs

//! SSAO (Screen Space Ambient Occlusion) rendering pass: an occlusion pass followed by a depth-aware blur.
2
3use glam::Mat4;
4use std::num::NonZeroU64;
5use wgpu::util::DeviceExt;
6
7/// GPU representation of SSAO uniforms.
8#[repr(C)]
9#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
10#[allow(clippy::pub_underscore_fields)]
11pub struct SsaoUniforms {
12    pub proj: [[f32; 4]; 4],
13    pub inv_proj: [[f32; 4]; 4],
14    pub radius: f32,
15    pub bias: f32,
16    pub intensity: f32,
17    pub sample_count: u32,
18    pub screen_width: f32,
19    pub screen_height: f32,
20    pub _padding: [f32; 2],
21}
22
23impl Default for SsaoUniforms {
24    fn default() -> Self {
25        Self {
26            proj: Mat4::IDENTITY.to_cols_array_2d(),
27            inv_proj: Mat4::IDENTITY.to_cols_array_2d(),
28            radius: 0.5,
29            bias: 0.025,
30            intensity: 1.0,
31            sample_count: 16,
32            screen_width: 1280.0,
33            screen_height: 720.0,
34            _padding: [0.0; 2],
35        }
36    }
37}
38
39/// SSAO blur uniforms.
40#[repr(C)]
41#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
42#[allow(clippy::pub_underscore_fields)]
43pub struct SsaoBlurUniforms {
44    pub texel_size: [f32; 2],
45    pub blur_scale: f32,
46    /// Controls edge preservation (higher = sharper edges preserved)
47    pub blur_sharpness: f32,
48}
49
50/// SSAO pass resources.
51pub struct SsaoPass {
52    // Main SSAO pass
53    ssao_pipeline: wgpu::RenderPipeline,
54    ssao_bind_group_layout: wgpu::BindGroupLayout,
55    ssao_uniform_buffer: wgpu::Buffer,
56    // Blur pass
57    blur_pipeline: wgpu::RenderPipeline,
58    blur_bind_group_layout: wgpu::BindGroupLayout,
59    blur_uniform_buffer: wgpu::Buffer,
60    // Intermediate texture for blur
61    ssao_texture: wgpu::Texture,
62    ssao_view: wgpu::TextureView,
63    // Sampler
64    sampler: wgpu::Sampler,
65}
66
67impl SsaoPass {
68    /// Creates a new SSAO pass.
69    #[must_use]
70    pub fn new(device: &wgpu::Device, width: u32, height: u32) -> Self {
71        // Create SSAO shader
72        let ssao_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
73            label: Some("SSAO Shader"),
74            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/ssao.wgsl").into()),
75        });
76
77        // Create blur shader
78        let blur_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
79            label: Some("SSAO Blur Shader"),
80            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/ssao_blur.wgsl").into()),
81        });
82
83        // SSAO bind group layout
84        let ssao_bind_group_layout =
85            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
86                label: Some("SSAO Bind Group Layout"),
87                entries: &[
88                    // Depth texture
89                    wgpu::BindGroupLayoutEntry {
90                        binding: 0,
91                        visibility: wgpu::ShaderStages::FRAGMENT,
92                        ty: wgpu::BindingType::Texture {
93                            sample_type: wgpu::TextureSampleType::Depth,
94                            view_dimension: wgpu::TextureViewDimension::D2,
95                            multisampled: false,
96                        },
97                        count: None,
98                    },
99                    // Normal texture
100                    wgpu::BindGroupLayoutEntry {
101                        binding: 1,
102                        visibility: wgpu::ShaderStages::FRAGMENT,
103                        ty: wgpu::BindingType::Texture {
104                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
105                            view_dimension: wgpu::TextureViewDimension::D2,
106                            multisampled: false,
107                        },
108                        count: None,
109                    },
110                    // Noise texture
111                    wgpu::BindGroupLayoutEntry {
112                        binding: 2,
113                        visibility: wgpu::ShaderStages::FRAGMENT,
114                        ty: wgpu::BindingType::Texture {
115                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
116                            view_dimension: wgpu::TextureViewDimension::D2,
117                            multisampled: false,
118                        },
119                        count: None,
120                    },
121                    // Sampler
122                    wgpu::BindGroupLayoutEntry {
123                        binding: 3,
124                        visibility: wgpu::ShaderStages::FRAGMENT,
125                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
126                        count: None,
127                    },
128                    // Uniforms
129                    wgpu::BindGroupLayoutEntry {
130                        binding: 4,
131                        visibility: wgpu::ShaderStages::FRAGMENT,
132                        ty: wgpu::BindingType::Buffer {
133                            ty: wgpu::BufferBindingType::Uniform,
134                            has_dynamic_offset: false,
135                            min_binding_size: NonZeroU64::new(160),
136                        },
137                        count: None,
138                    },
139                ],
140            });
141
142        // SSAO pipeline
143        let ssao_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
144            label: Some("SSAO Pipeline Layout"),
145            bind_group_layouts: &[&ssao_bind_group_layout],
146            push_constant_ranges: &[],
147        });
148
149        let ssao_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
150            label: Some("SSAO Pipeline"),
151            layout: Some(&ssao_pipeline_layout),
152            vertex: wgpu::VertexState {
153                module: &ssao_shader,
154                entry_point: Some("vs_main"),
155                buffers: &[],
156                compilation_options: wgpu::PipelineCompilationOptions::default(),
157            },
158            fragment: Some(wgpu::FragmentState {
159                module: &ssao_shader,
160                entry_point: Some("fs_main"),
161                targets: &[Some(wgpu::ColorTargetState {
162                    format: wgpu::TextureFormat::R8Unorm, // Single channel for occlusion
163                    blend: None,
164                    write_mask: wgpu::ColorWrites::ALL,
165                })],
166                compilation_options: wgpu::PipelineCompilationOptions::default(),
167            }),
168            primitive: wgpu::PrimitiveState {
169                topology: wgpu::PrimitiveTopology::TriangleList,
170                ..Default::default()
171            },
172            depth_stencil: None,
173            multisample: wgpu::MultisampleState::default(),
174            multiview: None,
175            cache: None,
176        });
177
178        // Blur bind group layout - now includes depth texture for edge-aware blurring
179        let blur_bind_group_layout =
180            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
181                label: Some("SSAO Blur Bind Group Layout"),
182                entries: &[
183                    // SSAO texture (binding 0)
184                    wgpu::BindGroupLayoutEntry {
185                        binding: 0,
186                        visibility: wgpu::ShaderStages::FRAGMENT,
187                        ty: wgpu::BindingType::Texture {
188                            sample_type: wgpu::TextureSampleType::Float { filterable: true },
189                            view_dimension: wgpu::TextureViewDimension::D2,
190                            multisampled: false,
191                        },
192                        count: None,
193                    },
194                    // Depth texture (binding 1) - for edge-aware bilateral blur
195                    wgpu::BindGroupLayoutEntry {
196                        binding: 1,
197                        visibility: wgpu::ShaderStages::FRAGMENT,
198                        ty: wgpu::BindingType::Texture {
199                            sample_type: wgpu::TextureSampleType::Depth,
200                            view_dimension: wgpu::TextureViewDimension::D2,
201                            multisampled: false,
202                        },
203                        count: None,
204                    },
205                    // Sampler (binding 2)
206                    wgpu::BindGroupLayoutEntry {
207                        binding: 2,
208                        visibility: wgpu::ShaderStages::FRAGMENT,
209                        ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
210                        count: None,
211                    },
212                    // Uniforms (binding 3)
213                    wgpu::BindGroupLayoutEntry {
214                        binding: 3,
215                        visibility: wgpu::ShaderStages::FRAGMENT,
216                        ty: wgpu::BindingType::Buffer {
217                            ty: wgpu::BufferBindingType::Uniform,
218                            has_dynamic_offset: false,
219                            min_binding_size: NonZeroU64::new(16),
220                        },
221                        count: None,
222                    },
223                ],
224            });
225
226        // Blur pipeline
227        let blur_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
228            label: Some("SSAO Blur Pipeline Layout"),
229            bind_group_layouts: &[&blur_bind_group_layout],
230            push_constant_ranges: &[],
231        });
232
233        let blur_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
234            label: Some("SSAO Blur Pipeline"),
235            layout: Some(&blur_pipeline_layout),
236            vertex: wgpu::VertexState {
237                module: &blur_shader,
238                entry_point: Some("vs_main"),
239                buffers: &[],
240                compilation_options: wgpu::PipelineCompilationOptions::default(),
241            },
242            fragment: Some(wgpu::FragmentState {
243                module: &blur_shader,
244                entry_point: Some("fs_main"),
245                targets: &[Some(wgpu::ColorTargetState {
246                    format: wgpu::TextureFormat::R8Unorm,
247                    blend: None,
248                    write_mask: wgpu::ColorWrites::ALL,
249                })],
250                compilation_options: wgpu::PipelineCompilationOptions::default(),
251            }),
252            primitive: wgpu::PrimitiveState {
253                topology: wgpu::PrimitiveTopology::TriangleList,
254                ..Default::default()
255            },
256            depth_stencil: None,
257            multisample: wgpu::MultisampleState::default(),
258            multiview: None,
259            cache: None,
260        });
261
262        // Create uniform buffers
263        let ssao_uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
264            label: Some("SSAO Uniform Buffer"),
265            contents: bytemuck::cast_slice(&[SsaoUniforms::default()]),
266            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
267        });
268
269        let blur_uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
270            label: Some("SSAO Blur Uniform Buffer"),
271            contents: bytemuck::cast_slice(&[SsaoBlurUniforms {
272                texel_size: [1.0 / width as f32, 1.0 / height as f32],
273                blur_scale: 1.0,
274                blur_sharpness: 50.0, // Edge preservation strength
275            }]),
276            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
277        });
278
279        // Create intermediate SSAO texture
280        let ssao_texture = device.create_texture(&wgpu::TextureDescriptor {
281            label: Some("SSAO Texture"),
282            size: wgpu::Extent3d {
283                width,
284                height,
285                depth_or_array_layers: 1,
286            },
287            mip_level_count: 1,
288            sample_count: 1,
289            dimension: wgpu::TextureDimension::D2,
290            format: wgpu::TextureFormat::R8Unorm,
291            usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
292            view_formats: &[],
293        });
294
295        let ssao_view = ssao_texture.create_view(&wgpu::TextureViewDescriptor::default());
296
297        // Create sampler
298        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
299            label: Some("SSAO Sampler"),
300            mag_filter: wgpu::FilterMode::Linear,
301            min_filter: wgpu::FilterMode::Linear,
302            address_mode_u: wgpu::AddressMode::ClampToEdge,
303            address_mode_v: wgpu::AddressMode::ClampToEdge,
304            ..Default::default()
305        });
306
307        Self {
308            ssao_pipeline,
309            ssao_bind_group_layout,
310            ssao_uniform_buffer,
311            blur_pipeline,
312            blur_bind_group_layout,
313            blur_uniform_buffer,
314            ssao_texture,
315            ssao_view,
316            sampler,
317        }
318    }
319
320    /// Resizes the SSAO textures.
321    pub fn resize(&mut self, device: &wgpu::Device, queue: &wgpu::Queue, width: u32, height: u32) {
322        // Recreate SSAO texture
323        self.ssao_texture = device.create_texture(&wgpu::TextureDescriptor {
324            label: Some("SSAO Texture"),
325            size: wgpu::Extent3d {
326                width,
327                height,
328                depth_or_array_layers: 1,
329            },
330            mip_level_count: 1,
331            sample_count: 1,
332            dimension: wgpu::TextureDimension::D2,
333            format: wgpu::TextureFormat::R8Unorm,
334            usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
335            view_formats: &[],
336        });
337
338        self.ssao_view = self
339            .ssao_texture
340            .create_view(&wgpu::TextureViewDescriptor::default());
341
342        // Update blur uniforms
343        queue.write_buffer(
344            &self.blur_uniform_buffer,
345            0,
346            bytemuck::cast_slice(&[SsaoBlurUniforms {
347                texel_size: [1.0 / width as f32, 1.0 / height as f32],
348                blur_scale: 1.0,
349                blur_sharpness: 50.0,
350            }]),
351        );
352    }
353
354    /// Updates SSAO uniforms.
355    pub fn update_uniforms(
356        &self,
357        queue: &wgpu::Queue,
358        proj: Mat4,
359        inv_proj: Mat4,
360        radius: f32,
361        bias: f32,
362        intensity: f32,
363        sample_count: u32,
364        width: f32,
365        height: f32,
366    ) {
367        let uniforms = SsaoUniforms {
368            proj: proj.to_cols_array_2d(),
369            inv_proj: inv_proj.to_cols_array_2d(),
370            radius,
371            bias,
372            intensity,
373            sample_count,
374            screen_width: width,
375            screen_height: height,
376            _padding: [0.0; 2],
377        };
378        queue.write_buffer(
379            &self.ssao_uniform_buffer,
380            0,
381            bytemuck::cast_slice(&[uniforms]),
382        );
383    }
384
385    /// Creates a bind group for the SSAO pass.
386    #[must_use]
387    pub fn create_ssao_bind_group(
388        &self,
389        device: &wgpu::Device,
390        depth_view: &wgpu::TextureView,
391        normal_view: &wgpu::TextureView,
392        noise_view: &wgpu::TextureView,
393    ) -> wgpu::BindGroup {
394        device.create_bind_group(&wgpu::BindGroupDescriptor {
395            label: Some("SSAO Bind Group"),
396            layout: &self.ssao_bind_group_layout,
397            entries: &[
398                wgpu::BindGroupEntry {
399                    binding: 0,
400                    resource: wgpu::BindingResource::TextureView(depth_view),
401                },
402                wgpu::BindGroupEntry {
403                    binding: 1,
404                    resource: wgpu::BindingResource::TextureView(normal_view),
405                },
406                wgpu::BindGroupEntry {
407                    binding: 2,
408                    resource: wgpu::BindingResource::TextureView(noise_view),
409                },
410                wgpu::BindGroupEntry {
411                    binding: 3,
412                    resource: wgpu::BindingResource::Sampler(&self.sampler),
413                },
414                wgpu::BindGroupEntry {
415                    binding: 4,
416                    resource: self.ssao_uniform_buffer.as_entire_binding(),
417                },
418            ],
419        })
420    }
421
422    /// Creates a bind group for the blur pass.
423    /// The `depth_view` is used for edge-aware bilateral blurring.
424    #[must_use]
425    pub fn create_blur_bind_group(
426        &self,
427        device: &wgpu::Device,
428        depth_view: &wgpu::TextureView,
429    ) -> wgpu::BindGroup {
430        device.create_bind_group(&wgpu::BindGroupDescriptor {
431            label: Some("SSAO Blur Bind Group"),
432            layout: &self.blur_bind_group_layout,
433            entries: &[
434                wgpu::BindGroupEntry {
435                    binding: 0,
436                    resource: wgpu::BindingResource::TextureView(&self.ssao_view),
437                },
438                wgpu::BindGroupEntry {
439                    binding: 1,
440                    resource: wgpu::BindingResource::TextureView(depth_view),
441                },
442                wgpu::BindGroupEntry {
443                    binding: 2,
444                    resource: wgpu::BindingResource::Sampler(&self.sampler),
445                },
446                wgpu::BindGroupEntry {
447                    binding: 3,
448                    resource: self.blur_uniform_buffer.as_entire_binding(),
449                },
450            ],
451        })
452    }
453
454    /// Renders the SSAO pass.
455    pub fn render_ssao(&self, encoder: &mut wgpu::CommandEncoder, bind_group: &wgpu::BindGroup) {
456        let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
457            label: Some("SSAO Pass"),
458            color_attachments: &[Some(wgpu::RenderPassColorAttachment {
459                view: &self.ssao_view,
460                resolve_target: None,
461                ops: wgpu::Operations {
462                    load: wgpu::LoadOp::Clear(wgpu::Color::WHITE),
463                    store: wgpu::StoreOp::Store,
464                },
465                depth_slice: None,
466            })],
467            depth_stencil_attachment: None,
468            ..Default::default()
469        });
470
471        render_pass.set_pipeline(&self.ssao_pipeline);
472        render_pass.set_bind_group(0, bind_group, &[]);
473        render_pass.draw(0..3, 0..1);
474    }
475
476    /// Renders the blur pass to the output texture.
477    pub fn render_blur(
478        &self,
479        encoder: &mut wgpu::CommandEncoder,
480        output_view: &wgpu::TextureView,
481        bind_group: &wgpu::BindGroup,
482    ) {
483        let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
484            label: Some("SSAO Blur Pass"),
485            color_attachments: &[Some(wgpu::RenderPassColorAttachment {
486                view: output_view,
487                resolve_target: None,
488                ops: wgpu::Operations {
489                    load: wgpu::LoadOp::Clear(wgpu::Color::WHITE),
490                    store: wgpu::StoreOp::Store,
491                },
492                depth_slice: None,
493            })],
494            depth_stencil_attachment: None,
495            ..Default::default()
496        });
497
498        render_pass.set_pipeline(&self.blur_pipeline);
499        render_pass.set_bind_group(0, bind_group, &[]);
500        render_pass.draw(0..3, 0..1);
501    }
502
503    /// Returns the blurred SSAO texture view.
504    #[must_use]
505    pub fn ssao_view(&self) -> &wgpu::TextureView {
506        &self.ssao_view
507    }
508}