use crate::math_util::{IDENTITY_MAT4, cross, look_at, mul_mat4, normalize3, perspective_90};
use crate::mesh_pipeline::{DepthBuffer, Mesh};
use crate::vertex::Vertex3D;
/// Default shadow-map size.
///
/// NOTE(review): presumably intended as the `size` argument of `ShadowMap::new`
/// (texture width and height) — it is not referenced elsewhere in this file.
pub const DEFAULT_SHADOW_MAP_SIZE: u32 = 2048;
/// A square depth texture together with the view and comparison sampler needed
/// to render into it and to sample it afterwards.
pub struct ShadowMap {
    /// Depth texture created with `RENDER_ATTACHMENT | TEXTURE_BINDING` usage.
    pub depth_texture: wgpu::Texture,
    /// Default view over `depth_texture`; `ShadowPipeline::render` uses it as
    /// the depth attachment.
    pub depth_view: wgpu::TextureView,
    /// Comparison sampler (`LessEqual`) with linear filtering and
    /// clamp-to-edge addressing on U and V.
    pub sampler: wgpu::Sampler,
    /// Width and height of the texture.
    pub size: u32,
}

impl ShadowMap {
    /// Allocates a `size` × `size` depth texture in `DepthBuffer::FORMAT`,
    /// plus its default view and a comparison sampler.
    ///
    /// `size` is forwarded to the device unchanged; no validation happens here.
    #[must_use]
    pub fn new(device: &wgpu::Device, size: u32) -> Self {
        let extent = wgpu::Extent3d {
            width: size,
            height: size,
            depth_or_array_layers: 1,
        };
        // The texture is written by the shadow pass and read by later passes,
        // hence both usages.
        let usage = wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING;
        let texture_desc = wgpu::TextureDescriptor {
            label: Some("shadow_map"),
            size: extent,
            mip_level_count: 1,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            format: DepthBuffer::FORMAT,
            usage,
            view_formats: &[],
        };
        let depth_texture = device.create_texture(&texture_desc);
        let depth_view = depth_texture.create_view(&wgpu::TextureViewDescriptor::default());

        let sampler_desc = wgpu::SamplerDescriptor {
            label: Some("shadow_sampler"),
            address_mode_u: wgpu::AddressMode::ClampToEdge,
            address_mode_v: wgpu::AddressMode::ClampToEdge,
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            compare: Some(wgpu::CompareFunction::LessEqual),
            ..Default::default()
        };
        let sampler = device.create_sampler(&sampler_desc);

        Self {
            depth_texture,
            depth_view,
            sampler,
            size,
        }
    }
}
/// Uniform data for the shadow pass.
///
/// `#[repr(C)]` plus the `bytemuck::Pod` derive allow the struct to be viewed as raw
/// bytes (two 16-float matrices, 128 bytes); `ShadowPipeline::update_uniforms` uploads
/// it that way.
///
/// NOTE(review): the field order presumably mirrors the uniform struct declared in
/// `shadow.wgsl` — confirm before reordering or adding fields.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct ShadowUniforms {
/// Light view-projection matrix stored as 16 floats.
pub light_view_proj: [f32; 16],
/// Model matrix stored as 16 floats.
pub model: [f32; 16],
}
// Both matrices default to `IDENTITY_MAT4`.
impl Default for ShadowUniforms {
fn default() -> Self {
Self {
light_view_proj: IDENTITY_MAT4,
model: IDENTITY_MAT4,
}
}
}
/// Builds the product of an orthographic projection and a light-aligned
/// rotation for a directional light.
///
/// * `direction` — light direction; it is normalized with `normalize3` and
///   becomes the third basis vector of the view rotation.
/// * `extent` — half-size of the orthographic volume; the volume spans
///   `[-extent, extent]` horizontally and vertically.
/// * `near`, `far` — depth range of the projection.
///
/// Returns `IDENTITY_MAT4` (after logging a warning) when `extent <= 0` or
/// when `far` and `near` differ by less than `1e-10`.
///
/// NOTE(review): the element order (translation terms at indices 12..=14)
/// suggests column-major storage with depth mapped to `[0, 1]` — confirm
/// against `mul_mat4`.
#[must_use]
#[inline]
pub fn directional_light_matrix(
    direction: [f32; 3],
    extent: f32,
    near: f32,
    far: f32,
) -> [f32; 16] {
    if extent <= 0.0 {
        tracing::warn!(
            extent,
            "directional_light_matrix: extent <= 0 — returning identity matrix"
        );
        return IDENTITY_MAT4;
    }

    let depth_range = far - near;
    if depth_range.abs() < 1e-10 {
        tracing::warn!(
            near,
            far,
            "directional_light_matrix: far ≈ near — returning identity matrix"
        );
        return IDENTITY_MAT4;
    }

    // Orthonormal basis around the light direction. When the light is almost
    // vertical the world X axis is used as the reference instead of Y, so the
    // cross product below does not degenerate.
    let forward = normalize3(direction);
    let reference_up = if forward[1].abs() > 0.99 {
        [1.0, 0.0, 0.0]
    } else {
        [0.0, 1.0, 0.0]
    };
    let side = normalize3(cross(reference_up, forward));
    let light_up = cross(forward, side);

    #[rustfmt::skip]
    let view = [
        side[0], light_up[0], forward[0], 0.0,
        side[1], light_up[1], forward[1], 0.0,
        side[2], light_up[2], forward[2], 0.0,
        0.0,     0.0,         0.0,        1.0,
    ];

    // Symmetric orthographic volume around the origin.
    let left = -extent;
    let right_edge = extent;
    let bottom = -extent;
    let top = extent;
    let width = right_edge - left;
    let height = top - bottom;

    #[rustfmt::skip]
    let proj = [
        2.0 / width,                  0.0,                      0.0,                 0.0,
        0.0,                          2.0 / height,             0.0,                 0.0,
        0.0,                          0.0,                      1.0 / depth_range,   0.0,
        -(right_edge + left) / width, -(top + bottom) / height, -near / depth_range, 1.0,
    ];

    mul_mat4(proj, view)
}
pub struct ShadowPipeline {
render_pipeline: wgpu::RenderPipeline,
uniform_buffer: wgpu::Buffer,
uniform_bind_group: wgpu::BindGroup,
}
impl ShadowPipeline {
pub fn new(device: &wgpu::Device) -> Self {
tracing::debug!("creating shadow pipeline");
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shadow_shader"),
source: wgpu::ShaderSource::Wgsl(include_str!("shadow.wgsl").into()),
});
let bind_group_layout = mabda::BindGroupLayoutBuilder::new()
.uniform_buffer(wgpu::ShaderStages::VERTEX)
.build(device, "shadow_uniform_layout");
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("shadow_pipeline_layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let render_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("shadow_pipeline"),
layout: Some(&pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
entry_point: Some("vs_main"),
buffers: &[Vertex3D::layout()],
compilation_options: wgpu::PipelineCompilationOptions::default(),
},
fragment: None, primitive: wgpu::PrimitiveState {
topology: wgpu::PrimitiveTopology::TriangleList,
cull_mode: Some(wgpu::Face::Front), ..Default::default()
},
depth_stencil: Some(wgpu::DepthStencilState {
format: DepthBuffer::FORMAT,
depth_write_enabled: Some(true),
depth_compare: Some(wgpu::CompareFunction::LessEqual),
stencil: wgpu::StencilState::default(),
bias: wgpu::DepthBiasState {
constant: 2,
slope_scale: 2.0,
clamp: 0.0,
},
}),
multisample: wgpu::MultisampleState::default(),
multiview_mask: None,
cache: None,
});
let defaults = ShadowUniforms::default();
let uniform_buffer = mabda::create_uniform_buffer(
device,
bytemuck::bytes_of(&defaults),
"shadow_uniform_buffer",
);
let uniform_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("shadow_uniform_bind_group"),
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: uniform_buffer.as_entire_binding(),
}],
});
Self {
render_pipeline,
uniform_buffer,
uniform_bind_group,
}
}
pub fn update_uniforms(&self, queue: &wgpu::Queue, uniforms: &ShadowUniforms) {
queue.write_buffer(&self.uniform_buffer, 0, bytemuck::bytes_of(uniforms));
}
pub fn render(
&self,
device: &wgpu::Device,
queue: &wgpu::Queue,
shadow_map: &ShadowMap,
meshes: &[&Mesh],
) {
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("shadow_encoder"),
});
{
let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
label: Some("shadow_pass"),
color_attachments: &[],
depth_stencil_attachment: Some(wgpu::RenderPassDepthStencilAttachment {
view: &shadow_map.depth_view,
depth_ops: Some(wgpu::Operations {
load: wgpu::LoadOp::Clear(1.0),
store: wgpu::StoreOp::Store,
}),
stencil_ops: None,
}),
..Default::default()
});
render_pass.set_pipeline(&self.render_pipeline);
render_pass.set_bind_group(0, &self.uniform_bind_group, &[]);
for mesh in meshes {
render_pass.set_vertex_buffer(0, mesh.vertex_buffer.slice(..));
render_pass
.set_index_buffer(mesh.index_buffer.slice(..), wgpu::IndexFormat::Uint32);
render_pass.draw_indexed(0..mesh.index_count, 0, 0..1);
}
}
queue.submit(std::iter::once(encoder.finish()));
}
}
/// Upper bound on the number of shadow cascades: `CascadedShadowMap::new` clamps the
/// requested count to it, and `CascadeUniforms::matrices` is sized by it.
pub const MAX_CASCADES: usize = 4;
/// A group of shadow maps (one per cascade) with the split distances and
/// matrices that belong to them.
pub struct CascadedShadowMap {
    /// One shadow map per cascade.
    pub cascades: Vec<ShadowMap>,
    /// Cascade boundaries; `new` allocates `cascades.len() + 1` entries and
    /// `compute_splits` fills them.
    pub split_distances: Vec<f32>,
    /// One matrix per cascade; identity until replaced through
    /// `set_cascade_matrix`.
    pub view_proj_matrices: Vec<[f32; 16]>,
}

impl CascadedShadowMap {
    /// Creates `cascade_count` shadow maps of `resolution` × `resolution`.
    ///
    /// The count is clamped to `1..=MAX_CASCADES`. Split distances start at
    /// zero and matrices at identity.
    #[must_use]
    pub fn new(device: &wgpu::Device, cascade_count: u32, resolution: u32) -> Self {
        let count = usize::clamp(cascade_count as usize, 1, MAX_CASCADES);

        let mut cascades = Vec::with_capacity(count);
        for _ in 0..count {
            cascades.push(ShadowMap::new(device, resolution));
        }

        Self {
            cascades,
            split_distances: vec![0.0; count + 1],
            view_proj_matrices: vec![IDENTITY_MAT4; count],
        }
    }

    /// Fills `split_distances` with a blend of logarithmic and uniform
    /// partitions of `[near, far]`.
    ///
    /// `lambda` weights the logarithmic term and `1 - lambda` the uniform
    /// one. A non-positive `near` is replaced by `0.001`.
    ///
    /// # Panics
    ///
    /// Panics if `split_distances` holds fewer than `cascades.len() + 1`
    /// entries (possible only when the public fields were modified).
    pub fn compute_splits(&mut self, near: f32, far: f32, lambda: f32) {
        let start = if near <= 0.0 { 0.001_f32 } else { near };
        let cascade_total = self.cascades.len();
        // Loop-invariant pieces of the two split schemes.
        let growth = far / start;
        let span = far - start;

        self.split_distances[0] = start;
        for step in 1..cascade_total {
            let t = step as f32 / cascade_total as f32;
            let logarithmic = start * growth.powf(t);
            let linear = start + span * t;
            self.split_distances[step] = lambda * logarithmic + (1.0 - lambda) * linear;
        }
        self.split_distances[cascade_total] = far;
    }

    /// Stores `matrix` for cascade `index`; an out-of-range index is ignored.
    pub fn set_cascade_matrix(&mut self, index: usize, matrix: [f32; 16]) {
        if let Some(slot) = self.view_proj_matrices.get_mut(index) {
            *slot = matrix;
        }
    }

    /// Number of cascades (length of `cascades`).
    #[must_use]
    pub fn cascade_count(&self) -> usize {
        self.cascades.len()
    }
}
/// Uniform data describing all cascades at once.
///
/// `#[repr(C)]` plus the `bytemuck::Pod` derive allow the struct to be viewed as raw
/// bytes: four split floats (16 bytes) followed by `MAX_CASCADES` 16-float matrices,
/// 272 bytes in total.
///
/// NOTE(review): the layout presumably mirrors a uniform struct in a shader that is not
/// part of this file — confirm before changing field order or sizes.
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
pub struct CascadeUniforms {
/// Four split values. NOTE(review): presumably the far distance of each cascade — confirm.
pub splits: [f32; 4],
/// One 16-float matrix per cascade slot.
pub matrices: [[f32; 16]; MAX_CASCADES],
}
// Defaults: splits of 10, 30, 100 and 500, and identity matrices in every slot.
impl Default for CascadeUniforms {
fn default() -> Self {
Self {
splits: [10.0, 30.0, 100.0, 500.0],
matrices: [IDENTITY_MAT4; MAX_CASCADES],
}
}
}
/// A single square depth texture divided into a grid of square tiles, so that
/// several shadow maps can share one texture.
pub struct ShadowAtlas {
    /// Depth texture created with `RENDER_ATTACHMENT | TEXTURE_BINDING` usage.
    pub depth_texture: wgpu::Texture,
    /// Default view over `depth_texture`.
    pub depth_view: wgpu::TextureView,
    /// Comparison sampler (`LessEqual`) with linear filtering and
    /// clamp-to-edge addressing on U and V.
    pub sampler: wgpu::Sampler,
    /// Width and height of the whole atlas texture.
    pub size: u32,
    /// Width and height of one tile.
    pub tile_size: u32,
    /// Number of whole tiles along one edge (`size / tile_size`). This is `0`
    /// when no whole tile fits, i.e. `tile_size > size` or `size == 0`.
    pub columns: u32,
}

impl ShadowAtlas {
    /// Allocates a `size` × `size` depth texture in `DepthBuffer::FORMAT`,
    /// its default view and a comparison sampler, and records the tile grid.
    pub fn new(device: &wgpu::Device, size: u32, tile_size: u32) -> Self {
        // `.max(1)` keeps a zero tile size from dividing by zero.
        let columns = size / tile_size.max(1);
        let depth_texture = device.create_texture(&wgpu::TextureDescriptor {
            label: Some("shadow_atlas"),
            size: wgpu::Extent3d {
                width: size,
                height: size,
                depth_or_array_layers: 1,
            },
            mip_level_count: 1,
            sample_count: 1,
            dimension: wgpu::TextureDimension::D2,
            format: DepthBuffer::FORMAT,
            usage: wgpu::TextureUsages::RENDER_ATTACHMENT | wgpu::TextureUsages::TEXTURE_BINDING,
            view_formats: &[],
        });
        let depth_view = depth_texture.create_view(&wgpu::TextureViewDescriptor::default());
        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
            label: Some("shadow_atlas_sampler"),
            address_mode_u: wgpu::AddressMode::ClampToEdge,
            address_mode_v: wgpu::AddressMode::ClampToEdge,
            mag_filter: wgpu::FilterMode::Linear,
            min_filter: wgpu::FilterMode::Linear,
            compare: Some(wgpu::CompareFunction::LessEqual),
            ..Default::default()
        });
        Self {
            depth_texture,
            depth_view,
            sampler,
            size,
            tile_size,
            columns,
        }
    }

    /// Maps a tile index to its `(column, row)` cell, filling rows left to
    /// right.
    ///
    /// `columns` can be `0` (no whole tile fits, or the public field was set
    /// to zero). The divisor is clamped to `1` so the lookup never panics on
    /// a remainder or division by zero.
    fn tile_cell(&self, index: u32) -> (u32, u32) {
        let columns = self.columns.max(1);
        (index % columns, index / columns)
    }

    /// Returns the tile rectangle as `(x, y, width, height)`.
    ///
    /// The index is not range-checked; callers should keep it below
    /// [`ShadowAtlas::max_lights`]. When `columns` is `0` the atlas is
    /// treated as a single column instead of panicking.
    #[must_use]
    pub fn tile_viewport(&self, index: u32) -> (u32, u32, u32, u32) {
        let (col, row) = self.tile_cell(index);
        (
            col * self.tile_size,
            row * self.tile_size,
            self.tile_size,
            self.tile_size,
        )
    }

    /// Returns the tile as `[u_offset, v_offset, u_scale, v_scale]`, each
    /// expressed as a fraction of the atlas size.
    ///
    /// The index is not range-checked. When `columns` is `0` the atlas is
    /// treated as a single column instead of panicking.
    #[must_use]
    pub fn tile_uv(&self, index: u32) -> [f32; 4] {
        let (col, row) = self.tile_cell(index);
        let scale = self.tile_size as f32 / self.size as f32;
        [col as f32 * scale, row as f32 * scale, scale, scale]
    }

    /// Number of tiles in the atlas (`columns * columns`); `0` when no whole
    /// tile fits.
    #[must_use]
    pub fn max_lights(&self) -> u32 {
        self.columns * self.columns
    }
}
/// The six projection × view matrices for a point light, one per cube face.
pub struct PointShadowMap {
    /// Face matrices in the order +X, −X, +Y, −Y, +Z, −Z.
    pub face_matrices: [[f32; 16]; 6],
}

impl PointShadowMap {
    /// Builds the six face matrices for a light at `position`, combining
    /// `perspective_90(near, far)` with a `look_at` view per face.
    ///
    /// NOTE(review): the facing vector is handed to `look_at` as its second
    /// argument. If that helper expects a target *point* rather than a
    /// direction, the argument should be `position + facing` — confirm
    /// against `math_util::look_at`.
    #[must_use]
    pub fn new(position: [f32; 3], near: f32, far: f32) -> Self {
        // (facing, up) for each cube face.
        const FACE_AXES: [([f32; 3], [f32; 3]); 6] = [
            ([1.0, 0.0, 0.0], [0.0, -1.0, 0.0]),  // +X
            ([-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]), // -X
            ([0.0, 1.0, 0.0], [0.0, 0.0, 1.0]),   // +Y
            ([0.0, -1.0, 0.0], [0.0, 0.0, -1.0]), // -Y
            ([0.0, 0.0, 1.0], [0.0, -1.0, 0.0]),  // +Z
            ([0.0, 0.0, -1.0], [0.0, -1.0, 0.0]), // -Z
        ];

        let projection = perspective_90(near, far);
        let face_matrices =
            FACE_AXES.map(|(facing, up)| mul_mat4(projection, look_at(position, facing, up)));

        Self { face_matrices }
    }
}
/// Computes cascade split distances as a blend of a logarithmic and a uniform
/// partition of `[near, far]`.
///
/// The result starts with `near`, ends with `far`, and has `count - 1`
/// interior values in between, giving `count + 1` entries for `count >= 1`.
/// Each interior value is `lambda * log_split + (1 - lambda) * uniform_split`.
///
/// A non-positive `near` is replaced by `0.001` so the logarithmic term stays
/// defined.
#[must_use]
pub fn compute_practical_splits(near: f32, far: f32, count: usize, lambda: f32) -> Vec<f32> {
    let start = if near <= 0.0 { 0.001_f32 } else { near };
    // Loop-invariant pieces of the two schemes.
    let growth = far / start;
    let span = far - start;

    let interior = (1..count).map(|step| {
        let t = step as f32 / count as f32;
        let logarithmic = start * growth.powf(t);
        let linear = start + span * t;
        lambda * logarithmic + (1.0 - lambda) * linear
    });

    let mut splits = Vec::with_capacity(count + 1);
    splits.push(start);
    splits.extend(interior);
    splits.push(far);
    splits
}
/// CPU-side description of a square shadow atlas divided into square tiles.
pub struct ShadowAtlasConfig {
    /// Width and height of the whole atlas.
    pub size: u32,
    /// Width and height of one tile.
    pub tile_size: u32,
}

/// Maps a tile index to its `(column, row)` cell, filling rows left to right.
///
/// The column count is `size / tile_size`, clamped to at least 1. Without the
/// clamp, a configuration in which no whole tile fits (`tile_size > size` or
/// `size == 0`) gives zero columns, and `index % 0` panics.
fn config_tile_cell(config: &ShadowAtlasConfig, index: u32) -> (u32, u32) {
    let columns = (config.size / config.tile_size.max(1)).max(1);
    (index % columns, index / columns)
}

/// Returns the tile rectangle as `(x, y, width, height)`.
///
/// The index is not range-checked. A configuration in which no whole tile
/// fits is treated as a single column instead of panicking.
#[must_use]
pub fn tile_viewport(config: &ShadowAtlasConfig, index: u32) -> (u32, u32, u32, u32) {
    let (col, row) = config_tile_cell(config, index);
    (
        col * config.tile_size,
        row * config.tile_size,
        config.tile_size,
        config.tile_size,
    )
}

/// Returns the tile as `[u_offset, v_offset, u_scale, v_scale]`, each
/// expressed as a fraction of the atlas size.
///
/// The index is not range-checked. A configuration in which no whole tile
/// fits is treated as a single column instead of panicking.
#[must_use]
pub fn tile_uv(config: &ShadowAtlasConfig, index: u32) -> [f32; 4] {
    let (col, row) = config_tile_cell(config, index);
    let scale = config.tile_size as f32 / config.size as f32;
    [col as f32 * scale, row as f32 * scale, scale, scale]
}
// Unit tests for the CPU-side parts of this module. Nothing here creates a `wgpu::Device`,
// so the GPU-backed types (`ShadowMap`, `ShadowPipeline`, `ShadowAtlas`,
// `CascadedShadowMap`) are not exercised.
#[cfg(test)]
mod tests {
use super::*;
// Two 16-float matrices: 2 * 64 = 128 bytes.
#[test]
fn shadow_uniforms_size() {
assert_eq!(std::mem::size_of::<ShadowUniforms>(), 128);
}
// Spot-checks the first and last elements of the default light matrix.
#[test]
fn shadow_uniforms_default() {
let u = ShadowUniforms::default();
assert_eq!(u.light_view_proj[0], 1.0);
assert_eq!(u.light_view_proj[15], 1.0);
}
// The byte view used for uploads has the same length as the struct.
#[test]
fn shadow_uniforms_bytemuck() {
let u = ShadowUniforms::default();
let bytes = bytemuck::bytes_of(&u);
assert_eq!(bytes.len(), 128);
}
// Valid inputs must not hit the identity fallback.
#[test]
fn directional_light_matrix_produces_valid() {
let m = directional_light_matrix([0.0, -1.0, 0.0], 10.0, 0.1, 50.0);
assert_eq!(m.len(), 16);
assert!(m != IDENTITY_MAT4);
}
// A vertical and a horizontal light direction give different matrices.
#[test]
fn directional_light_matrix_different_directions() {
let m1 = directional_light_matrix([0.0, -1.0, 0.0], 10.0, 0.1, 50.0);
let m2 = directional_light_matrix([1.0, 0.0, 0.0], 10.0, 0.1, 50.0);
assert!(
m1 != m2,
"Different directions should produce different matrices"
);
}
// A nearly vertical direction (the case where the alternate reference axis is chosen)
// must not produce NaN.
#[test]
fn directional_light_matrix_diagonal_direction() {
let m = directional_light_matrix([0.0, -0.999, -0.01], 20.0, 1.0, 100.0);
for &v in &m {
assert!(!v.is_nan(), "Matrix contains NaN");
}
}
// Pins the value of the public constant.
#[test]
fn default_shadow_map_size() {
assert_eq!(DEFAULT_SHADOW_MAP_SIZE, 2048);
}
// Four split floats (16 bytes) plus four 16-float matrices (256 bytes) = 272 bytes.
#[test]
fn cascade_uniforms_size() {
assert_eq!(std::mem::size_of::<CascadeUniforms>(), 272);
}
// Spot-checks the first and last default split values.
#[test]
fn cascade_uniforms_default() {
let u = CascadeUniforms::default();
assert_eq!(u.splits[0], 10.0);
assert_eq!(u.splits[3], 500.0);
}
// Four cascades give five boundaries: exact endpoints and strictly increasing values.
#[test]
fn cascade_splits_practical() {
let splits = compute_practical_splits(0.1, 100.0, 4, 0.5);
assert_eq!(splits.len(), 5); assert_eq!(splits[0], 0.1);
assert_eq!(splits[4], 100.0);
for i in 1..splits.len() {
assert!(splits[i] > splits[i - 1]);
}
}
// A 4096 atlas with 1024 tiles is a 4x4 grid; index 4 wraps to the second row.
#[test]
fn shadow_atlas_tile_viewport() {
let atlas = ShadowAtlasConfig {
size: 4096,
tile_size: 1024,
};
let columns = atlas.size / atlas.tile_size;
assert_eq!(columns, 4);
assert_eq!(tile_viewport(&atlas, 0), (0, 0, 1024, 1024));
assert_eq!(tile_viewport(&atlas, 1), (1024, 0, 1024, 1024));
assert_eq!(tile_viewport(&atlas, 4), (0, 1024, 1024, 1024));
}
// Each tile covers a quarter of the atlas along both axes.
#[test]
fn shadow_atlas_tile_uv() {
let atlas = ShadowAtlasConfig {
size: 4096,
tile_size: 1024,
};
let uv = tile_uv(&atlas, 0);
assert_eq!(uv, [0.0, 0.0, 0.25, 0.25]);
let uv1 = tile_uv(&atlas, 1);
assert!((uv1[0] - 0.25).abs() < 0.001);
}
// Arithmetic check only: this recomputes columns * columns from the config and does not
// call `ShadowAtlas::max_lights`, whose constructor needs a `wgpu::Device`.
#[test]
fn shadow_atlas_max_lights() {
let atlas = ShadowAtlasConfig {
size: 4096,
tile_size: 1024,
};
assert_eq!(
atlas.size / atlas.tile_size * (atlas.size / atlas.tile_size),
16
);
}
// All six face matrices must be pairwise distinct.
#[test]
fn point_shadow_6_faces() {
let psm = PointShadowMap::new([0.0, 5.0, 0.0], 0.1, 25.0);
assert_eq!(psm.face_matrices.len(), 6);
for i in 0..6 {
for j in (i + 1)..6 {
assert!(psm.face_matrices[i] != psm.face_matrices[j]);
}
}
}
// No face matrix may contain NaN for an off-origin light.
#[test]
fn point_shadow_no_nan() {
let psm = PointShadowMap::new([10.0, 3.0, -5.0], 0.1, 50.0);
for face in &psm.face_matrices {
for &v in face {
assert!(!v.is_nan(), "Point shadow matrix contains NaN");
}
}
}
// Checks elements 0 and 5 of the projection are exactly 1.0 and element 10 is not NaN.
#[test]
fn perspective_90_valid() {
let p = perspective_90(0.1, 100.0);
assert_eq!(p[0], 1.0); assert_eq!(p[5], 1.0);
assert!(!p[10].is_nan());
}
// near == 0 is replaced by 0.001, so every split stays finite and strictly increasing.
#[test]
fn cascade_splits_zero_near() {
let splits = compute_practical_splits(0.0, 100.0, 4, 0.5);
assert_eq!(splits.len(), 5);
for &s in &splits {
assert!(!s.is_nan(), "split contains NaN");
assert!(!s.is_infinite(), "split contains Inf");
}
assert!((splits[0] - 0.001).abs() < f32::EPSILON);
assert_eq!(splits[4], 100.0);
for i in 1..splits.len() {
assert!(splits[i] > splits[i - 1]);
}
}
}