Skip to main content

cvkg_render_gpu/passes/
shadow.rs

//! Shadow pass types and [`KvasirNode`] implementation: renders a depth-only,
//! cascaded shadow map from the light's perspective using 3D mesh data.
3
4use crate::kvasir::nodes::PassId;
5use crate::kvasir::{ExecutionContext, KvasirNode, ResourceId};
6use glam::Mat4;
7use wgpu::Buffer;
8
/// Directional light used as the source for shadow rendering.
#[derive(Debug, Clone, Copy)]
pub struct DirectionalLight {
    /// Direction the light travels, i.e. from the light toward the scene.
    /// The shadow pass normalizes it, so it need not be unit length, but it
    /// should be non-zero.
    pub direction: glam::Vec3,
    /// Light color. Not read by the shadow pass.
    pub color: glam::Vec3,
    /// Light intensity (presumably a brightness multiplier — confirm with the
    /// lighting shader). Not read by the shadow pass.
    pub intensity: f32,
}
16
17impl Default for DirectionalLight {
18    fn default() -> Self {
19        Self {
20            direction: glam::Vec3::new(0.0, -1.0, 0.0),
21            color: glam::Vec3::ONE,
22            intensity: 1.0,
23        }
24    }
25}
26
/// GPU resources for a single 3D mesh instance ready for rendering.
#[derive(Debug, Clone)]
pub struct GpuMesh3d {
    /// Vertex buffer (position, normal, UV, etc.).
    /// NOTE(review): the exact vertex layout is defined by the pipelines, not here.
    pub vertex_buffer: Buffer,
    /// Index buffer. The shadow pass binds it as 32-bit indices
    /// (`wgpu::IndexFormat::Uint32`), so it must contain `u32` indices.
    pub index_buffer: Buffer,
    /// Number of indices to draw, starting at index 0 of `index_buffer`.
    pub index_count: u32,
    /// Per-instance model matrix.
    pub transform: Mat4,
    /// View depth for transparent sorting (world-space distance from camera).
    /// Used by TransparentNode for back-to-front rendering.
    pub view_depth: f32,
    /// Index of this instance in the 3D instance buffer.
    pub instance_index: u32,
}
44
/// Shadow pass node — renders a depth-only, four-cascade shadow map from the
/// light's perspective.
pub struct ShadowNode {
    /// Light whose `direction` orients the shadow projections.
    pub light: DirectionalLight,
    /// Graph resource id reported as this node's only output. Note that
    /// `execute` renders into the renderer's `shadow_map_texture` directly and
    /// does not resolve this id itself.
    pub shadow_map: ResourceId,
    /// GPU-ready mesh instances to render into the shadow map.
    pub mesh_instances: Vec<GpuMesh3d>,
    /// Cascade split values for CSM, copied unchanged into the CSM uniform
    /// buffer. They do not influence how the cascade frusta are built (those
    /// use fixed NDC depth ranges) — NOTE(review): confirm the two agree.
    pub cascade_splits: [f32; 4],
    /// Camera view-projection matrix. Its inverse is used to unproject each
    /// cascade's slice of the camera frustum into world space.
    pub camera_view_proj: Mat4,
}
56
57impl KvasirNode for ShadowNode {
58    fn label(&self) -> &'static str {
59        "ShadowPass"
60    }
61
62    fn inputs(&self) -> &[ResourceId] {
63        &[]
64    }
65
66    fn outputs(&self) -> &[ResourceId] {
67        std::slice::from_ref(&self.shadow_map)
68    }
69
70    fn pass_id(&self) -> PassId {
71        PassId::Shadow
72    }
73
74    fn execute(&self, ctx: &mut ExecutionContext) {
75        let light_dir = self.light.direction.normalize();
76
77        // 1. Compute 4 cascades VP matrices
78        let inv_cam_vp = self.camera_view_proj.inverse();
79        let ndc_ranges = [
80            (0.0f32, 0.08f32),
81            (0.08f32, 0.22f32),
82            (0.22f32, 0.55f32),
83            (0.55f32, 1.0f32),
84        ];
85
86        let mut cascade_vps = [glam::Mat4::IDENTITY; 4];
87        for i in 0..4 {
88            let (near_ndc, far_ndc) = ndc_ranges[i];
89            let ndc_corners = [
90                glam::Vec3::new(-1.0, -1.0, near_ndc),
91                glam::Vec3::new(1.0, -1.0, near_ndc),
92                glam::Vec3::new(-1.0, 1.0, near_ndc),
93                glam::Vec3::new(1.0, 1.0, near_ndc),
94                glam::Vec3::new(-1.0, -1.0, far_ndc),
95                glam::Vec3::new(1.0, -1.0, far_ndc),
96                glam::Vec3::new(-1.0, 1.0, far_ndc),
97                glam::Vec3::new(1.0, 1.0, far_ndc),
98            ];
99
100            let mut world_corners = [glam::Vec3::ZERO; 8];
101            let mut center = glam::Vec3::ZERO;
102            for j in 0..8 {
103                let p = inv_cam_vp.project_point3(ndc_corners[j]);
104                world_corners[j] = p;
105                center += p;
106            }
107            center /= 8.0;
108
109            let mut radius = 0.0f32;
110            for corner in &world_corners {
111                radius = radius.max(corner.distance(center));
112            }
113
114            // Snap radius to prevent shimmering
115            radius = (radius * 16.0).round() / 16.0;
116
117            let light_pos = center - light_dir * radius * 2.0;
118            let light_view = glam::Mat4::look_at_lh(light_pos, center, glam::Vec3::Y);
119            let light_proj =
120                glam::Mat4::orthographic_lh(-radius, radius, -radius, radius, 0.0, radius * 4.0);
121
122            cascade_vps[i] = light_proj * light_view;
123        }
124
125        // 2. Update CSM buffer with the new cascade splits/VPs
126        let csm = cvkg_core::render_tier::CsmUniforms {
127            cascade_vps,
128            cascade_splits: [
129                self.cascade_splits[0],
130                self.cascade_splits[1],
131                self.cascade_splits[2],
132                self.cascade_splits[3],
133            ],
134            _pad: [0.0; 4],
135        };
136        ctx.queue
137            .write_buffer(&ctx.renderer.csm_buffer, 0, bytemuck::bytes_of(&csm));
138
139        let shadow_texture = match &ctx.renderer.shadow_map_texture {
140            Some(t) => t,
141            None => {
142                tracing::error!("ShadowNode: renderer missing shadow_map_texture");
143                return;
144            }
145        };
146
147        // 3. Render each cascade into its array layer
148        for (i, vp) in cascade_vps.iter().enumerate() {
149            // Write cascade_vps[i] into scene_buffer's light_vp field (offset 320)
150            ctx.queue
151                .write_buffer(&ctx.renderer.scene_buffer, 320, bytemuck::bytes_of(vp));
152
153            let layer_view = shadow_texture.create_view(&wgpu::TextureViewDescriptor {
154                label: Some(&format!("Surtr CSM Shadow Pass Layer {}", i)),
155                dimension: Some(wgpu::TextureViewDimension::D2),
156                base_array_layer: i as u32,
157                array_layer_count: Some(1),
158                ..wgpu::TextureViewDescriptor::default()
159            });
160
161            // Create a depth-only render pass.
162            let mut pass = ctx.encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
163                label: Some(&format!("Shadow Pass Cascade {}", i)),
164                color_attachments: &[], // No color output.
165                depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
166                    view: &layer_view,
167                    depth_ops: Some(wgpu::Operations {
168                        load: wgpu::LoadOp::Clear(1.0),
169                        store: wgpu::StoreOp::Store,
170                    }),
171                    stencil_ops: None,
172                }),
173                timestamp_writes: None,
174                occlusion_query_set: None,
175                multiview_mask: None,
176            });
177
178            // Bind the shadow pipeline and scene uniforms.
179            pass.set_pipeline(&ctx.renderer.shadow_pipeline);
180            pass.set_bind_group(1, &ctx.renderer.berserker_bind_group, &[]);
181
182            // For each mesh, set vertex/index buffers and draw depth only.
183            for mesh in self.mesh_instances.iter() {
184                pass.set_vertex_buffer(0, mesh.vertex_buffer.slice(..));
185                pass.set_index_buffer(mesh.index_buffer.slice(..), wgpu::IndexFormat::Uint32);
186                pass.draw_indexed(0..mesh.index_count, 0, 0..1);
187            }
188        }
189    }
190}