Skip to main content

proof_engine/lighting/
volumetric.rs

1//! Volumetric lighting effects for Proof Engine.
2//!
3//! Provides volumetric light shafts (god rays), volumetric fog with ray marching,
4//! tiled light culling for deferred rendering, and 3D frustum-based clustered light
5//! assignment for clustered forward rendering.
6
7use super::lights::{Vec3, Color, Mat4, LightId, Light};
8use std::collections::HashMap;
9use std::f32::consts::PI;
10
11// ── Volumetric Light Shafts ─────────────────────────────────────────────────
12
/// God rays via radial blur from a light's screen position.
///
/// `VolumetricLightShafts::compute` marches from every pixel towards `light_screen_pos`
/// through an occlusion buffer and accumulates the unoccluded samples.
#[derive(Debug, Clone)]
pub struct VolumetricLightShafts {
    /// Number of samples taken along each pixel-to-light ray.
    pub sample_count: u32,
    /// Fraction of the pixel-to-light distance covered by the samples (1.0 reaches the light).
    pub density: f32,
    /// Contribution of each unoccluded sample before decay and exposure are applied.
    pub weight: f32,
    /// Multiplier applied to the sample weight after every step, so with `decay < 1` samples
    /// further along the ray (closer to the light) contribute less.
    pub decay: f32,
    /// Final multiplier applied to the accumulated result.
    pub exposure: f32,
    /// Light position in normalized screen space (0..1 spans the screen; NDC -1 maps to 0).
    /// Usually refreshed each frame via `update_light_position`. Positions up to 0.2 outside
    /// the 0..1 range still count as visible.
    pub light_screen_pos: (f32, f32),
    /// Light color for the shafts.
    pub light_color: Color,
    /// When false, `compute` returns an all-black buffer.
    pub enabled: bool,
}
33
34impl Default for VolumetricLightShafts {
35    fn default() -> Self {
36        Self {
37            sample_count: 64,
38            density: 1.0,
39            weight: 0.01,
40            decay: 0.97,
41            exposure: 1.0,
42            light_screen_pos: (0.5, 0.5),
43            light_color: Color::WHITE,
44            enabled: true,
45        }
46    }
47}
48
49impl VolumetricLightShafts {
50    pub fn new(sample_count: u32) -> Self {
51        Self {
52            sample_count,
53            ..Default::default()
54        }
55    }
56
57    /// Update the light's screen position from a world position and view-projection matrix.
58    pub fn update_light_position(&mut self, light_world_pos: Vec3, view_projection: &Mat4) {
59        let clip = view_projection.transform_point(light_world_pos);
60        self.light_screen_pos = (clip.x * 0.5 + 0.5, clip.y * 0.5 + 0.5);
61    }
62
63    /// Check if the light is on screen.
64    pub fn is_light_visible(&self) -> bool {
65        let (sx, sy) = self.light_screen_pos;
66        sx >= -0.2 && sx <= 1.2 && sy >= -0.2 && sy <= 1.2
67    }
68
69    /// Apply the radial blur effect to a frame buffer.
70    /// The input is a buffer of colors (width x height) and an occlusion buffer (same size).
71    /// Returns the light shaft contribution buffer.
72    pub fn compute(
73        &self,
74        width: u32,
75        height: u32,
76        occlusion_buffer: &[f32],
77    ) -> Vec<Color> {
78        let w = width as usize;
79        let h = height as usize;
80        let mut result = vec![Color::BLACK; w * h];
81
82        if !self.enabled || !self.is_light_visible() {
83            return result;
84        }
85
86        let (lx, ly) = self.light_screen_pos;
87
88        for y in 0..h {
89            for x in 0..w {
90                let pixel_x = x as f32 / w as f32;
91                let pixel_y = y as f32 / h as f32;
92
93                // Direction from pixel to light in screen space
94                let dx = lx - pixel_x;
95                let dy = ly - pixel_y;
96
97                let delta_x = dx * self.density / self.sample_count as f32;
98                let delta_y = dy * self.density / self.sample_count as f32;
99
100                let mut sample_x = pixel_x;
101                let mut sample_y = pixel_y;
102                let mut illumination_decay = 1.0f32;
103                let mut accumulated = Color::BLACK;
104
105                for _ in 0..self.sample_count {
106                    sample_x += delta_x;
107                    sample_y += delta_y;
108
109                    let sx = (sample_x * w as f32) as usize;
110                    let sy = (sample_y * h as f32) as usize;
111
112                    if sx < w && sy < h {
113                        let occlusion = occlusion_buffer[sy * w + sx];
114                        // Only accumulate light where not occluded
115                        let sample_value = (1.0 - occlusion) * illumination_decay * self.weight;
116                        accumulated = Color::new(
117                            accumulated.r + self.light_color.r * sample_value,
118                            accumulated.g + self.light_color.g * sample_value,
119                            accumulated.b + self.light_color.b * sample_value,
120                        );
121                    }
122
123                    illumination_decay *= self.decay;
124                }
125
126                result[y * w + x] = Color::new(
127                    accumulated.r * self.exposure,
128                    accumulated.g * self.exposure,
129                    accumulated.b * self.exposure,
130                );
131            }
132        }
133
134        result
135    }
136
137    /// Apply a cheap half-resolution version for performance.
138    pub fn compute_half_res(
139        &self,
140        width: u32,
141        height: u32,
142        occlusion_buffer: &[f32],
143    ) -> Vec<Color> {
144        let half_w = width / 2;
145        let half_h = height / 2;
146
147        // Downsample occlusion buffer
148        let half_size = (half_w as usize) * (half_h as usize);
149        let mut half_occlusion = vec![0.0f32; half_size];
150        for y in 0..half_h as usize {
151            for x in 0..half_w as usize {
152                let sx = x * 2;
153                let sy = y * 2;
154                let w_full = width as usize;
155                if sx + 1 < width as usize && sy + 1 < height as usize {
156                    let avg = (occlusion_buffer[sy * w_full + sx]
157                        + occlusion_buffer[sy * w_full + sx + 1]
158                        + occlusion_buffer[(sy + 1) * w_full + sx]
159                        + occlusion_buffer[(sy + 1) * w_full + sx + 1])
160                        * 0.25;
161                    half_occlusion[y * half_w as usize + x] = avg;
162                }
163            }
164        }
165
166        self.compute(half_w, half_h, &half_occlusion)
167    }
168}
169
170// ── Volumetric Fog ──────────────────────────────────────────────────────────
171
/// Describes a fog density field: a scalar density evaluated at world positions by
/// `FogDensityField::sample`.
#[derive(Debug, Clone)]
pub enum FogDensityField {
    /// Uniform density everywhere.
    Uniform(f32),
    /// Height-based exponential fog:
    /// `base_density * exp(-falloff * max(y - base_height, 0))`.
    /// Constant at `base_density` below `base_height`, decaying exponentially above it.
    HeightExponential {
        base_density: f32,
        falloff: f32,
        base_height: f32,
    },
    /// Spherical fog volume: `density * (1 - (d / radius)^2)` at distance `d < radius` from
    /// `center`, and 0 on or outside the surface.
    Sphere {
        center: Vec3,
        radius: f32,
        density: f32,
    },
    /// Box-shaped fog volume: constant `density` inside the axis-aligned box `[min, max]`
    /// (bounds inclusive), 0 outside.
    Box {
        min: Vec3,
        max: Vec3,
        density: f32,
    },
    /// Layered: multiple fog sources combined by summing their densities.
    Layered(Vec<FogDensityField>),
}
198
199impl Default for FogDensityField {
200    fn default() -> Self {
201        Self::HeightExponential {
202            base_density: 0.02,
203            falloff: 0.5,
204            base_height: 0.0,
205        }
206    }
207}
208
209impl FogDensityField {
210    /// Sample density at a world position.
211    pub fn sample(&self, pos: Vec3) -> f32 {
212        match self {
213            Self::Uniform(d) => *d,
214            Self::HeightExponential { base_density, falloff, base_height } => {
215                let height_diff = pos.y - base_height;
216                base_density * (-falloff * height_diff.max(0.0)).exp()
217            }
218            Self::Sphere { center, radius, density } => {
219                let dist = center.distance(pos);
220                if dist < *radius {
221                    let t = dist / radius;
222                    density * (1.0 - t * t).max(0.0)
223                } else {
224                    0.0
225                }
226            }
227            Self::Box { min, max, density } => {
228                if pos.x >= min.x && pos.x <= max.x
229                    && pos.y >= min.y && pos.y <= max.y
230                    && pos.z >= min.z && pos.z <= max.z
231                {
232                    *density
233                } else {
234                    0.0
235                }
236            }
237            Self::Layered(layers) => {
238                layers.iter().map(|l| l.sample(pos)).sum()
239            }
240        }
241    }
242}
243
/// Volumetric fog with ray marching, scattering, absorption, and the
/// Henyey-Greenstein phase function.
///
/// Extinction per unit length is `(scattering + absorption) * density`, and transmittance
/// decays as `exp(-extinction * step)` (Beer-Lambert).
#[derive(Debug, Clone)]
pub struct VolumetricFog {
    /// Fog density field sampled along each ray.
    pub density_field: FogDensityField,
    /// Scattering coefficient per unit density. It scales in-scattered (light and ambient)
    /// contributions and, together with `absorption`, the extinction.
    pub scattering: f32,
    /// Absorption coefficient per unit density; it only contributes to extinction.
    pub absorption: f32,
    /// Color of the ambient in-scattering term (scaled by `ambient_contribution`).
    pub fog_color: Color,
    /// Henyey-Greenstein asymmetry parameter (-1 = back scatter, 0 = isotropic, 1 = forward).
    pub hg_asymmetry: f32,
    /// Number of ray marching steps (`optical_depth` uses half of this, minimum 4).
    pub step_count: u32,
    /// Upper bound on the marched distance; the per-call `max_dist` is clamped to it.
    pub max_distance: f32,
    /// When false, the ray-march functions return no fog and full transmittance.
    pub enabled: bool,
    /// Strength of the ambient fog term (light scattered from the environment).
    pub ambient_contribution: f32,
    /// Offset, in units of one step, added to every sample position along the ray; intended
    /// for temporal jitter of the fog samples.
    pub jitter_offset: f32,
}
269
270impl Default for VolumetricFog {
271    fn default() -> Self {
272        Self {
273            density_field: FogDensityField::default(),
274            scattering: 0.05,
275            absorption: 0.01,
276            fog_color: Color::new(0.7, 0.75, 0.85),
277            hg_asymmetry: 0.3,
278            step_count: 64,
279            max_distance: 100.0,
280            enabled: true,
281            ambient_contribution: 0.15,
282            jitter_offset: 0.0,
283        }
284    }
285}
286
287impl VolumetricFog {
288    pub fn new(density_field: FogDensityField) -> Self {
289        Self {
290            density_field,
291            ..Default::default()
292        }
293    }
294
295    /// Henyey-Greenstein phase function.
296    /// Evaluates the probability of light scattering at angle `cos_theta`.
297    pub fn henyey_greenstein(cos_theta: f32, g: f32) -> f32 {
298        let g2 = g * g;
299        let denom = 1.0 + g2 - 2.0 * g * cos_theta;
300        if denom <= 0.0 {
301            return 1.0 / (4.0 * PI);
302        }
303        (1.0 - g2) / (4.0 * PI * denom.powf(1.5))
304    }
305
306    /// Compute the extinction coefficient at a point.
307    fn extinction_at(&self, pos: Vec3) -> f32 {
308        let density = self.density_field.sample(pos);
309        (self.scattering + self.absorption) * density
310    }
311
312    /// Ray march through the fog volume from a camera ray.
313    /// Returns (accumulated fog color, transmittance).
314    pub fn ray_march(
315        &self,
316        ray_origin: Vec3,
317        ray_dir: Vec3,
318        max_dist: f32,
319        light_dir: Vec3,
320        light_color: Color,
321        light_intensity: f32,
322    ) -> (Color, f32) {
323        if !self.enabled {
324            return (Color::BLACK, 1.0);
325        }
326
327        let effective_max = max_dist.min(self.max_distance);
328        let step_size = effective_max / self.step_count as f32;
329        let dir = ray_dir.normalize();
330
331        let mut accumulated_color = Color::BLACK;
332        let mut transmittance = 1.0f32;
333
334        let cos_theta = dir.dot((-light_dir).normalize());
335        let phase = Self::henyey_greenstein(cos_theta, self.hg_asymmetry);
336
337        for i in 0..self.step_count {
338            let t = (i as f32 + 0.5 + self.jitter_offset) * step_size;
339            let sample_pos = ray_origin + dir * t;
340
341            let density = self.density_field.sample(sample_pos);
342            if density <= 0.0 {
343                continue;
344            }
345
346            let extinction = (self.scattering + self.absorption) * density;
347            let sample_transmittance = (-extinction * step_size).exp();
348
349            // In-scattered light from the main light source
350            let in_scattered = Color::new(
351                light_color.r * light_intensity * self.scattering * density * phase,
352                light_color.g * light_intensity * self.scattering * density * phase,
353                light_color.b * light_intensity * self.scattering * density * phase,
354            );
355
356            // Ambient scattering
357            let ambient = Color::new(
358                self.fog_color.r * self.ambient_contribution * self.scattering * density,
359                self.fog_color.g * self.ambient_contribution * self.scattering * density,
360                self.fog_color.b * self.ambient_contribution * self.scattering * density,
361            );
362
363            // Integrate: add in-scattered light weighted by current transmittance
364            let luminance_step = Color::new(
365                (in_scattered.r + ambient.r) * transmittance * step_size,
366                (in_scattered.g + ambient.g) * transmittance * step_size,
367                (in_scattered.b + ambient.b) * transmittance * step_size,
368            );
369
370            accumulated_color = Color::new(
371                accumulated_color.r + luminance_step.r,
372                accumulated_color.g + luminance_step.g,
373                accumulated_color.b + luminance_step.b,
374            );
375
376            transmittance *= sample_transmittance;
377
378            // Early out if nearly fully opaque
379            if transmittance < 0.001 {
380                break;
381            }
382        }
383
384        (accumulated_color, transmittance)
385    }
386
387    /// Ray march with multiple light contributions.
388    pub fn ray_march_multi_light(
389        &self,
390        ray_origin: Vec3,
391        ray_dir: Vec3,
392        max_dist: f32,
393        lights: &[(Vec3, Color, f32)], // (direction, color, intensity)
394    ) -> (Color, f32) {
395        if !self.enabled {
396            return (Color::BLACK, 1.0);
397        }
398
399        let effective_max = max_dist.min(self.max_distance);
400        let step_size = effective_max / self.step_count as f32;
401        let dir = ray_dir.normalize();
402
403        let mut accumulated_color = Color::BLACK;
404        let mut transmittance = 1.0f32;
405
406        // Precompute phase values for each light
407        let phase_values: Vec<f32> = lights.iter().map(|(light_dir, _, _)| {
408            let cos_theta = dir.dot((-*light_dir).normalize());
409            Self::henyey_greenstein(cos_theta, self.hg_asymmetry)
410        }).collect();
411
412        for i in 0..self.step_count {
413            let t = (i as f32 + 0.5 + self.jitter_offset) * step_size;
414            let sample_pos = ray_origin + dir * t;
415
416            let density = self.density_field.sample(sample_pos);
417            if density <= 0.0 {
418                continue;
419            }
420
421            let extinction = (self.scattering + self.absorption) * density;
422            let sample_transmittance = (-extinction * step_size).exp();
423
424            let mut in_scattered = Color::BLACK;
425            for (j, (_, light_color, intensity)) in lights.iter().enumerate() {
426                let phase = phase_values[j];
427                in_scattered = Color::new(
428                    in_scattered.r + light_color.r * intensity * self.scattering * density * phase,
429                    in_scattered.g + light_color.g * intensity * self.scattering * density * phase,
430                    in_scattered.b + light_color.b * intensity * self.scattering * density * phase,
431                );
432            }
433
434            let ambient = Color::new(
435                self.fog_color.r * self.ambient_contribution * self.scattering * density,
436                self.fog_color.g * self.ambient_contribution * self.scattering * density,
437                self.fog_color.b * self.ambient_contribution * self.scattering * density,
438            );
439
440            let luminance_step = Color::new(
441                (in_scattered.r + ambient.r) * transmittance * step_size,
442                (in_scattered.g + ambient.g) * transmittance * step_size,
443                (in_scattered.b + ambient.b) * transmittance * step_size,
444            );
445
446            accumulated_color = Color::new(
447                accumulated_color.r + luminance_step.r,
448                accumulated_color.g + luminance_step.g,
449                accumulated_color.b + luminance_step.b,
450            );
451
452            transmittance *= sample_transmittance;
453
454            if transmittance < 0.001 {
455                break;
456            }
457        }
458
459        (accumulated_color, transmittance)
460    }
461
462    /// Apply fog to a final pixel color given scene depth.
463    pub fn apply_to_pixel(
464        &self,
465        scene_color: Color,
466        fog_color: Color,
467        transmittance: f32,
468    ) -> Color {
469        Color::new(
470            scene_color.r * transmittance + fog_color.r,
471            scene_color.g * transmittance + fog_color.g,
472            scene_color.b * transmittance + fog_color.b,
473        )
474    }
475
476    /// Compute the optical depth along a ray (integral of extinction).
477    pub fn optical_depth(&self, origin: Vec3, direction: Vec3, distance: f32) -> f32 {
478        let steps = (self.step_count / 2).max(4);
479        let step_size = distance / steps as f32;
480        let dir = direction.normalize();
481        let mut depth = 0.0f32;
482
483        for i in 0..steps {
484            let t = (i as f32 + 0.5) * step_size;
485            let pos = origin + dir * t;
486            depth += self.extinction_at(pos) * step_size;
487        }
488
489        depth
490    }
491
492    /// Compute transmittance along a ray.
493    pub fn transmittance(&self, origin: Vec3, direction: Vec3, distance: f32) -> f32 {
494        (-self.optical_depth(origin, direction, distance)).exp()
495    }
496}
497
498// ── Tiled Light Culling ─────────────────────────────────────────────────────
499
/// Screen-space tile for tiled deferred rendering.
#[derive(Debug, Clone)]
pub struct ScreenTile {
    /// Tile position in tiles (not pixels).
    pub tile_x: u32,
    pub tile_y: u32,
    /// Lights that affect this tile (filled by `TiledLightCulling::cull_lights`).
    pub light_ids: Vec<LightId>,
    /// Min and max depth in this tile (for tighter culling).
    /// Only depth values below 1.0 are recorded. When there are none, the range stays
    /// inverted (`min_depth = 1.0`, `max_depth = 0.0`).
    // NOTE(review): `TiledLightCulling::cull_lights` does not currently read this range.
    pub min_depth: f32,
    pub max_depth: f32,
}
512
513impl ScreenTile {
514    pub fn new(tile_x: u32, tile_y: u32) -> Self {
515        Self {
516            tile_x,
517            tile_y,
518            light_ids: Vec::new(),
519            min_depth: 1.0,
520            max_depth: 0.0,
521        }
522    }
523
524    /// Update the depth range from a depth buffer region.
525    pub fn update_depth_range(&mut self, depths: &[f32]) {
526        self.min_depth = 1.0;
527        self.max_depth = 0.0;
528        for &d in depths {
529            if d < 1.0 {
530                self.min_depth = self.min_depth.min(d);
531                self.max_depth = self.max_depth.max(d);
532            }
533        }
534    }
535
536    /// Get the light count for this tile.
537    pub fn light_count(&self) -> usize {
538        self.light_ids.len()
539    }
540}
541
/// Divides the screen into tiles and assigns lights for deferred rendering.
///
/// Tiles are square, `tile_size` pixels wide, stored row-major in `tiles` (index
/// `tile_y * tiles_x + tile_x`); partial tiles at the right/bottom edges are included.
#[derive(Debug, Clone)]
pub struct TiledLightCulling {
    /// Tile size in pixels.
    pub tile_size: u32,
    /// Screen width in pixels.
    pub screen_width: u32,
    /// Screen height in pixels.
    pub screen_height: u32,
    /// Number of tiles in X (screen width divided by tile size, rounded up).
    pub tiles_x: u32,
    /// Number of tiles in Y (screen height divided by tile size, rounded up).
    pub tiles_y: u32,
    /// All tiles.
    pub tiles: Vec<ScreenTile>,
    /// View-projection matrix for the current frame, used to project light spheres.
    pub view_projection: Mat4,
    /// Inverse projection for reconstructing view-space positions.
    // NOTE(review): not read by the culling methods of `TiledLightCulling` — confirm whether
    // it is still needed.
    pub inv_projection: Mat4,
    /// Near plane.
    // NOTE(review): `near`/`far` are not read by the culling methods either — TODO confirm.
    pub near: f32,
    /// Far plane.
    pub far: f32,
}
566
567impl TiledLightCulling {
568    pub fn new(screen_width: u32, screen_height: u32, tile_size: u32) -> Self {
569        let tiles_x = (screen_width + tile_size - 1) / tile_size;
570        let tiles_y = (screen_height + tile_size - 1) / tile_size;
571        let mut tiles = Vec::with_capacity((tiles_x * tiles_y) as usize);
572        for y in 0..tiles_y {
573            for x in 0..tiles_x {
574                tiles.push(ScreenTile::new(x, y));
575            }
576        }
577        Self {
578            tile_size,
579            screen_width,
580            screen_height,
581            tiles_x,
582            tiles_y,
583            tiles,
584            view_projection: Mat4::IDENTITY,
585            inv_projection: Mat4::IDENTITY,
586            near: 0.1,
587            far: 1000.0,
588        }
589    }
590
591    /// Resize the tiling when the screen resolution changes.
592    pub fn resize(&mut self, width: u32, height: u32) {
593        self.screen_width = width;
594        self.screen_height = height;
595        self.tiles_x = (width + self.tile_size - 1) / self.tile_size;
596        self.tiles_y = (height + self.tile_size - 1) / self.tile_size;
597        self.tiles.clear();
598        for y in 0..self.tiles_y {
599            for x in 0..self.tiles_x {
600                self.tiles.push(ScreenTile::new(x, y));
601            }
602        }
603    }
604
605    /// Update depth ranges for all tiles from a full-screen depth buffer.
606    pub fn update_depth_ranges(&mut self, depth_buffer: &[f32]) {
607        let w = self.screen_width as usize;
608
609        for tile in &mut self.tiles {
610            let tx = tile.tile_x as usize;
611            let ty = tile.tile_y as usize;
612            let ts = self.tile_size as usize;
613
614            let x_start = tx * ts;
615            let y_start = ty * ts;
616            let x_end = (x_start + ts).min(self.screen_width as usize);
617            let y_end = (y_start + ts).min(self.screen_height as usize);
618
619            tile.min_depth = 1.0;
620            tile.max_depth = 0.0;
621
622            for y in y_start..y_end {
623                for x in x_start..x_end {
624                    let d = depth_buffer[y * w + x];
625                    if d < 1.0 {
626                        tile.min_depth = tile.min_depth.min(d);
627                        tile.max_depth = tile.max_depth.max(d);
628                    }
629                }
630            }
631        }
632    }
633
634    /// Cull lights against all tiles.
635    pub fn cull_lights(&mut self, lights: &[(LightId, &Light)]) {
636        // Clear previous assignments
637        for tile in &mut self.tiles {
638            tile.light_ids.clear();
639        }
640
641        for &(id, light) in lights {
642            if !light.is_enabled() {
643                continue;
644            }
645
646            match light.position() {
647                None => {
648                    // Directional lights affect all tiles
649                    for tile in &mut self.tiles {
650                        tile.light_ids.push(id);
651                    }
652                }
653                Some(pos) => {
654                    let radius = light.radius();
655
656                    // Project light sphere to screen-space AABB
657                    let screen_bounds = self.project_sphere_to_screen(pos, radius);
658                    if let Some((sx_min, sy_min, sx_max, sy_max)) = screen_bounds {
659                        let tile_x_min = (sx_min / self.tile_size as f32).floor().max(0.0) as u32;
660                        let tile_y_min = (sy_min / self.tile_size as f32).floor().max(0.0) as u32;
661                        let tile_x_max = ((sx_max / self.tile_size as f32).ceil() as u32).min(self.tiles_x);
662                        let tile_y_max = ((sy_max / self.tile_size as f32).ceil() as u32).min(self.tiles_y);
663
664                        for ty in tile_y_min..tile_y_max {
665                            for tx in tile_x_min..tile_x_max {
666                                let idx = (ty * self.tiles_x + tx) as usize;
667                                if idx < self.tiles.len() {
668                                    self.tiles[idx].light_ids.push(id);
669                                }
670                            }
671                        }
672                    }
673                }
674            }
675        }
676    }
677
678    /// Project a sphere onto the screen, returning (min_x, min_y, max_x, max_y) in pixels.
679    fn project_sphere_to_screen(&self, center: Vec3, radius: f32) -> Option<(f32, f32, f32, f32)> {
680        let clip_center = self.view_projection.transform_point(center);
681
682        // Check if the sphere is behind the camera
683        if clip_center.z < -1.0 - radius {
684            return None;
685        }
686
687        // Conservative screen-space bounds
688        let ndc_x = clip_center.x;
689        let ndc_y = clip_center.y;
690        let dist = center.length().max(0.1);
691        let angular_radius = (radius / dist).min(1.0);
692
693        let sx = (ndc_x * 0.5 + 0.5) * self.screen_width as f32;
694        let sy = (ndc_y * 0.5 + 0.5) * self.screen_height as f32;
695        let screen_radius = angular_radius * self.screen_width as f32;
696
697        Some((
698            (sx - screen_radius).max(0.0),
699            (sy - screen_radius).max(0.0),
700            (sx + screen_radius).min(self.screen_width as f32),
701            (sy + screen_radius).min(self.screen_height as f32),
702        ))
703    }
704
705    /// Get the tile at a pixel coordinate.
706    pub fn tile_at_pixel(&self, x: u32, y: u32) -> Option<&ScreenTile> {
707        let tx = x / self.tile_size;
708        let ty = y / self.tile_size;
709        if tx < self.tiles_x && ty < self.tiles_y {
710            Some(&self.tiles[(ty * self.tiles_x + tx) as usize])
711        } else {
712            None
713        }
714    }
715
716    /// Get statistics.
717    pub fn stats(&self) -> TiledCullingStats {
718        let mut total_assignments = 0usize;
719        let mut max_per_tile = 0usize;
720        let mut tiles_with_lights = 0u32;
721
722        for tile in &self.tiles {
723            let count = tile.light_ids.len();
724            total_assignments += count;
725            max_per_tile = max_per_tile.max(count);
726            if count > 0 {
727                tiles_with_lights += 1;
728            }
729        }
730
731        TiledCullingStats {
732            total_tiles: self.tiles.len() as u32,
733            tiles_with_lights,
734            total_light_tile_pairs: total_assignments as u32,
735            max_lights_per_tile: max_per_tile as u32,
736            avg_lights_per_active_tile: if tiles_with_lights > 0 {
737                total_assignments as f32 / tiles_with_lights as f32
738            } else {
739                0.0
740            },
741        }
742    }
743}
744
/// Statistics for tiled light culling, as produced by `TiledLightCulling::stats`.
///
/// All fields are plain numbers, so the type is `Copy`. `Default` yields all-zero stats
/// (an empty grid).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct TiledCullingStats {
    /// Total number of tiles in the grid.
    pub total_tiles: u32,
    /// Number of tiles with at least one light assigned.
    pub tiles_with_lights: u32,
    /// Sum of light counts over all tiles (a light is counted once per tile it touches).
    pub total_light_tile_pairs: u32,
    /// Largest number of lights assigned to any single tile.
    pub max_lights_per_tile: u32,
    /// `total_light_tile_pairs / tiles_with_lights`, or 0.0 when no tile has lights.
    pub avg_lights_per_active_tile: f32,
}
754
755// ── Light Cluster ───────────────────────────────────────────────────────────
756
/// A single 3D cluster in the frustum: a view-space AABB plus the lights overlapping it.
#[derive(Debug, Clone)]
pub struct LightCluster {
    /// Lights assigned to this cluster.
    pub light_ids: Vec<LightId>,
    /// Cluster AABB in view space (minimum corner, then maximum corner). Both bounds are
    /// treated as inclusive by `intersects_sphere`.
    pub min_bounds: Vec3,
    pub max_bounds: Vec3,
}
766
767impl LightCluster {
768    pub fn new(min_bounds: Vec3, max_bounds: Vec3) -> Self {
769        Self {
770            light_ids: Vec::new(),
771            min_bounds,
772            max_bounds,
773        }
774    }
775
776    /// Check if a sphere (in view space) intersects this cluster's AABB.
777    pub fn intersects_sphere(&self, center: Vec3, radius: f32) -> bool {
778        let mut dist_sq = 0.0f32;
779
780        let check = |c: f32, min: f32, max: f32| -> f32 {
781            if c < min {
782                let d = min - c;
783                d * d
784            } else if c > max {
785                let d = c - max;
786                d * d
787            } else {
788                0.0
789            }
790        };
791
792        dist_sq += check(center.x, self.min_bounds.x, self.max_bounds.x);
793        dist_sq += check(center.y, self.min_bounds.y, self.max_bounds.y);
794        dist_sq += check(center.z, self.min_bounds.z, self.max_bounds.z);
795
796        dist_sq <= radius * radius
797    }
798}
799
/// 3D frustum-based clustered light assignment for clustered forward rendering.
///
/// The view frustum is split into `clusters_x * clusters_y` screen tiles and `clusters_z`
/// depth slices between `near` and `far`. Each cluster stores a view-space AABB.
#[derive(Debug, Clone)]
pub struct ClusteredLightAssignment {
    /// Number of clusters in X (screen width).
    pub clusters_x: u32,
    /// Number of clusters in Y (screen height).
    pub clusters_y: u32,
    /// Number of clusters in Z (depth slices).
    pub clusters_z: u32,
    /// All clusters stored in a flat array.
    // NOTE(review): presumably ordered z-major, then y, then x (the loop order in
    // `build_clusters`) — confirm before indexing directly.
    pub clusters: Vec<LightCluster>,
    /// Camera near plane (depth of the first slice boundary).
    pub near: f32,
    /// Camera far plane (depth of the last slice boundary).
    pub far: f32,
    /// Field of view (vertical, in radians).
    pub fov_y: f32,
    /// Aspect ratio (width / height; it scales the X extent of each cluster).
    pub aspect: f32,
    /// View matrix for the current frame (identity after `new`).
    pub view_matrix: Mat4,
    /// Logarithmic depth slice distribution: when true (the default set by `new`), slice
    /// boundaries are spaced geometrically between `near` and `far`; otherwise linearly.
    pub log_depth: bool,
}
824
825impl ClusteredLightAssignment {
826    pub fn new(
827        clusters_x: u32,
828        clusters_y: u32,
829        clusters_z: u32,
830        near: f32,
831        far: f32,
832        fov_y: f32,
833        aspect: f32,
834    ) -> Self {
835        let total = (clusters_x as usize) * (clusters_y as usize) * (clusters_z as usize);
836        let mut assignment = Self {
837            clusters_x,
838            clusters_y,
839            clusters_z,
840            clusters: Vec::with_capacity(total),
841            near,
842            far,
843            fov_y,
844            aspect,
845            view_matrix: Mat4::IDENTITY,
846            log_depth: true,
847        };
848        assignment.build_clusters();
849        assignment
850    }
851
852    /// Compute the depth of a Z slice boundary.
853    fn slice_depth(&self, slice: u32) -> f32 {
854        let t = slice as f32 / self.clusters_z as f32;
855        if self.log_depth {
856            // Logarithmic distribution: more slices near the camera
857            self.near * (self.far / self.near).powf(t)
858        } else {
859            self.near + (self.far - self.near) * t
860        }
861    }
862
863    /// Determine which Z slice a view-space depth falls into.
864    pub fn depth_to_slice(&self, depth: f32) -> u32 {
865        if depth <= self.near {
866            return 0;
867        }
868        if depth >= self.far {
869            return self.clusters_z.saturating_sub(1);
870        }
871
872        let slice = if self.log_depth {
873            let log_near = self.near.ln();
874            let log_far = self.far.ln();
875            let log_depth = depth.ln();
876            ((log_depth - log_near) / (log_far - log_near) * self.clusters_z as f32) as u32
877        } else {
878            (((depth - self.near) / (self.far - self.near)) * self.clusters_z as f32) as u32
879        };
880
881        slice.min(self.clusters_z - 1)
882    }
883
884    /// Build cluster AABBs in view space.
885    fn build_clusters(&mut self) {
886        self.clusters.clear();
887
888        let tan_half_fov = (self.fov_y * 0.5).tan();
889
890        for z in 0..self.clusters_z {
891            let z_near = self.slice_depth(z);
892            let z_far = self.slice_depth(z + 1);
893
894            for y in 0..self.clusters_y {
895                for x in 0..self.clusters_x {
896                    // Compute the tile's NDC extents
897                    let tile_x_ndc = (x as f32 / self.clusters_x as f32) * 2.0 - 1.0;
898                    let tile_x_ndc_end = ((x + 1) as f32 / self.clusters_x as f32) * 2.0 - 1.0;
899                    let tile_y_ndc = (y as f32 / self.clusters_y as f32) * 2.0 - 1.0;
900                    let tile_y_ndc_end = ((y + 1) as f32 / self.clusters_y as f32) * 2.0 - 1.0;
901
902                    // Convert to view space at the near and far depths
903                    let x_min_near = tile_x_ndc * tan_half_fov * self.aspect * z_near;
904                    let x_max_near = tile_x_ndc_end * tan_half_fov * self.aspect * z_near;
905                    let y_min_near = tile_y_ndc * tan_half_fov * z_near;
906                    let y_max_near = tile_y_ndc_end * tan_half_fov * z_near;
907
908                    let x_min_far = tile_x_ndc * tan_half_fov * self.aspect * z_far;
909                    let x_max_far = tile_x_ndc_end * tan_half_fov * self.aspect * z_far;
910                    let y_min_far = tile_y_ndc * tan_half_fov * z_far;
911                    let y_max_far = tile_y_ndc_end * tan_half_fov * z_far;
912
913                    let min_bounds = Vec3::new(
914                        x_min_near.min(x_min_far),
915                        y_min_near.min(y_min_far),
916                        -z_far, // View space Z is negative
917                    );
918                    let max_bounds = Vec3::new(
919                        x_max_near.max(x_max_far),
920                        y_max_near.max(y_max_far),
921                        -z_near,
922                    );
923
924                    self.clusters.push(LightCluster::new(min_bounds, max_bounds));
925                }
926            }
927        }
928    }
929
930    /// Get cluster index from 3D coordinates.
931    fn cluster_index(&self, x: u32, y: u32, z: u32) -> usize {
932        (z as usize) * (self.clusters_x as usize * self.clusters_y as usize)
933            + (y as usize) * (self.clusters_x as usize)
934            + (x as usize)
935    }
936
937    /// Get cluster index from a screen pixel and depth.
938    pub fn cluster_at(&self, pixel_x: u32, pixel_y: u32, depth: f32, screen_w: u32, screen_h: u32) -> usize {
939        let cx = (pixel_x as f32 / screen_w as f32 * self.clusters_x as f32) as u32;
940        let cy = (pixel_y as f32 / screen_h as f32 * self.clusters_y as f32) as u32;
941        let cz = self.depth_to_slice(depth);
942
943        let cx = cx.min(self.clusters_x - 1);
944        let cy = cy.min(self.clusters_y - 1);
945
946        self.cluster_index(cx, cy, cz)
947    }
948
949    /// Assign lights to clusters.
950    pub fn assign_lights(
951        &mut self,
952        lights: &[(LightId, Vec3, f32)], // (id, view-space position, radius)
953    ) {
954        // Clear existing assignments
955        for cluster in &mut self.clusters {
956            cluster.light_ids.clear();
957        }
958
959        for &(id, view_pos, radius) in lights {
960            // Find the depth range this light covers
961            let light_z_near = (-view_pos.z - radius).max(self.near);
962            let light_z_far = (-view_pos.z + radius).min(self.far);
963
964            if light_z_far < self.near || light_z_near > self.far {
965                continue; // Light is outside the frustum depth range
966            }
967
968            let z_start = self.depth_to_slice(light_z_near);
969            let z_end = self.depth_to_slice(light_z_far);
970
971            for z in z_start..=z_end.min(self.clusters_z - 1) {
972                for y in 0..self.clusters_y {
973                    for x in 0..self.clusters_x {
974                        let idx = self.cluster_index(x, y, z);
975                        if idx < self.clusters.len() && self.clusters[idx].intersects_sphere(view_pos, radius) {
976                            self.clusters[idx].light_ids.push(id);
977                        }
978                    }
979                }
980            }
981        }
982    }
983
984    /// Assign lights from world-space positions, transforming to view space first.
985    pub fn assign_lights_world(
986        &mut self,
987        lights: &[(LightId, &Light)],
988    ) {
989        let mut view_lights = Vec::new();
990
991        for &(id, light) in lights {
992            if !light.is_enabled() {
993                continue;
994            }
995            if let Some(pos) = light.position() {
996                let view_pos = self.view_matrix.transform_point(pos);
997                let radius = light.radius();
998                view_lights.push((id, view_pos, radius));
999            }
1000        }
1001
1002        self.assign_lights(&view_lights);
1003
1004        // Directional lights go into every cluster
1005        for &(id, light) in lights {
1006            if let Light::Directional(_) = light {
1007                if light.is_enabled() {
1008                    for cluster in &mut self.clusters {
1009                        cluster.light_ids.push(id);
1010                    }
1011                }
1012            }
1013        }
1014    }
1015
1016    /// Get the lights for a specific cluster.
1017    pub fn lights_for_cluster(&self, index: usize) -> &[LightId] {
1018        if index < self.clusters.len() {
1019            &self.clusters[index].light_ids
1020        } else {
1021            &[]
1022        }
1023    }
1024
1025    /// Get the lights at a screen pixel and depth.
1026    pub fn lights_at_pixel(&self, pixel_x: u32, pixel_y: u32, depth: f32, screen_w: u32, screen_h: u32) -> &[LightId] {
1027        let idx = self.cluster_at(pixel_x, pixel_y, depth, screen_w, screen_h);
1028        self.lights_for_cluster(idx)
1029    }
1030
1031    /// Total number of clusters.
1032    pub fn total_clusters(&self) -> usize {
1033        self.clusters.len()
1034    }
1035
1036    /// Get statistics.
1037    pub fn stats(&self) -> ClusteredStats {
1038        let mut total_assignments = 0usize;
1039        let mut max_per_cluster = 0usize;
1040        let mut active_clusters = 0u32;
1041        let mut empty_clusters = 0u32;
1042
1043        for cluster in &self.clusters {
1044            let count = cluster.light_ids.len();
1045            total_assignments += count;
1046            max_per_cluster = max_per_cluster.max(count);
1047            if count > 0 {
1048                active_clusters += 1;
1049            } else {
1050                empty_clusters += 1;
1051            }
1052        }
1053
1054        ClusteredStats {
1055            total_clusters: self.clusters.len() as u32,
1056            active_clusters,
1057            empty_clusters,
1058            total_light_cluster_pairs: total_assignments as u32,
1059            max_lights_per_cluster: max_per_cluster as u32,
1060            avg_lights_per_active_cluster: if active_clusters > 0 {
1061                total_assignments as f32 / active_clusters as f32
1062            } else {
1063                0.0
1064            },
1065        }
1066    }
1067
1068    /// Rebuild clusters (call when camera params change).
1069    pub fn rebuild(&mut self) {
1070        self.build_clusters();
1071    }
1072}
1073
/// Statistics for clustered light assignment.
///
/// Produced by `ClusteredLightAssignment::stats` from the clusters' current light lists.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ClusteredStats {
    /// Number of clusters in the grid.
    pub total_clusters: u32,
    /// Clusters holding at least one light.
    pub active_clusters: u32,
    /// Clusters holding no lights.
    pub empty_clusters: u32,
    /// Light ids stored across all clusters; a light touching N clusters counts N times.
    pub total_light_cluster_pairs: u32,
    /// Largest number of lights found in any single cluster.
    pub max_lights_per_cluster: u32,
    /// `total_light_cluster_pairs / active_clusters`, or `0.0` when no cluster is active.
    pub avg_lights_per_active_cluster: f32,
}
1084
1085// ── Volumetric System ───────────────────────────────────────────────────────
1086
/// Orchestrates all volumetric effects.
///
/// Owns the light-shaft and fog settings plus the two optional light-culling back
/// ends. Tiled and clustered culling are created lazily by `init_tiled_culling` /
/// `init_clustered`.
#[derive(Debug)]
pub struct VolumetricSystem {
    /// Screen-space god-ray parameters; repositioned by `update_light_shaft_position`.
    pub light_shafts: VolumetricLightShafts,
    /// Ray-marched volumetric fog settings.
    pub fog: VolumetricFog,
    /// Tiled (deferred) light culler; `None` until `init_tiled_culling` is called.
    pub tiled_culling: Option<TiledLightCulling>,
    /// Clustered (forward) light grid; `None` until `init_clustered` is called.
    pub clustered_assignment: Option<ClusteredLightAssignment>,
    /// Whether volumetric light shafts are enabled.
    /// When false, `update_light_shaft_position` does nothing.
    /// NOTE(review): this flag is independent of `light_shafts.enabled`.
    pub shafts_enabled: bool,
    /// Whether volumetric fog is enabled.
    /// Not read by any method of this type; presumably consulted by the renderer.
    pub fog_enabled: bool,
    /// Whether tiled culling is active (gates `update_tiled`; set by `init_tiled_culling`).
    pub tiled_culling_enabled: bool,
    /// Whether clustered assignment is active (gates `update_clustered`; set by `init_clustered`).
    pub clustered_enabled: bool,
}
1103
1104impl VolumetricSystem {
1105    pub fn new() -> Self {
1106        Self {
1107            light_shafts: VolumetricLightShafts::default(),
1108            fog: VolumetricFog::default(),
1109            tiled_culling: None,
1110            clustered_assignment: None,
1111            shafts_enabled: true,
1112            fog_enabled: true,
1113            tiled_culling_enabled: false,
1114            clustered_enabled: false,
1115        }
1116    }
1117
1118    /// Initialize tiled light culling for the given screen resolution.
1119    pub fn init_tiled_culling(&mut self, width: u32, height: u32, tile_size: u32) {
1120        self.tiled_culling = Some(TiledLightCulling::new(width, height, tile_size));
1121        self.tiled_culling_enabled = true;
1122    }
1123
1124    /// Initialize clustered forward rendering.
1125    pub fn init_clustered(
1126        &mut self,
1127        clusters_x: u32,
1128        clusters_y: u32,
1129        clusters_z: u32,
1130        near: f32,
1131        far: f32,
1132        fov_y: f32,
1133        aspect: f32,
1134    ) {
1135        self.clustered_assignment = Some(ClusteredLightAssignment::new(
1136            clusters_x, clusters_y, clusters_z,
1137            near, far, fov_y, aspect,
1138        ));
1139        self.clustered_enabled = true;
1140    }
1141
1142    /// Update tiled light culling with current lights and depth buffer.
1143    pub fn update_tiled(
1144        &mut self,
1145        lights: &[(LightId, &Light)],
1146        depth_buffer: &[f32],
1147    ) {
1148        if !self.tiled_culling_enabled {
1149            return;
1150        }
1151        if let Some(ref mut tiled) = self.tiled_culling {
1152            tiled.update_depth_ranges(depth_buffer);
1153            tiled.cull_lights(lights);
1154        }
1155    }
1156
1157    /// Update clustered light assignment with current lights.
1158    pub fn update_clustered(&mut self, lights: &[(LightId, &Light)]) {
1159        if !self.clustered_enabled {
1160            return;
1161        }
1162        if let Some(ref mut clustered) = self.clustered_assignment {
1163            clustered.assign_lights_world(lights);
1164        }
1165    }
1166
1167    /// Update the light shaft screen position.
1168    pub fn update_light_shaft_position(&mut self, light_world_pos: Vec3, view_projection: &Mat4) {
1169        if self.shafts_enabled {
1170            self.light_shafts.update_light_position(light_world_pos, view_projection);
1171        }
1172    }
1173
1174    /// Get tiled culling stats.
1175    pub fn tiled_stats(&self) -> Option<TiledCullingStats> {
1176        self.tiled_culling.as_ref().map(|t| t.stats())
1177    }
1178
1179    /// Get clustered stats.
1180    pub fn clustered_stats(&self) -> Option<ClusteredStats> {
1181        self.clustered_assignment.as_ref().map(|c| c.stats())
1182    }
1183}
1184
1185impl Default for VolumetricSystem {
1186    fn default() -> Self {
1187        Self::new()
1188    }
1189}
1190
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_henyey_greenstein() {
        // Forward scattering should be stronger when g > 0
        let forward = VolumetricFog::henyey_greenstein(1.0, 0.5);
        let backward = VolumetricFog::henyey_greenstein(-1.0, 0.5);
        assert!(forward > backward);

        // Isotropic when g = 0
        let iso_fwd = VolumetricFog::henyey_greenstein(1.0, 0.0);
        let iso_bwd = VolumetricFog::henyey_greenstein(-1.0, 0.0);
        assert!((iso_fwd - iso_bwd).abs() < 0.01);
    }

    #[test]
    fn test_fog_density_height() {
        let field = FogDensityField::HeightExponential {
            base_density: 1.0,
            falloff: 1.0,
            base_height: 0.0,
        };
        let low = field.sample(Vec3::new(0.0, 0.0, 0.0));
        let high = field.sample(Vec3::new(0.0, 10.0, 0.0));
        assert!(low > high); // Fog should be denser at lower heights
    }

    #[test]
    fn test_fog_density_sphere() {
        let field = FogDensityField::Sphere {
            center: Vec3::ZERO,
            radius: 5.0,
            density: 1.0,
        };
        let center = field.sample(Vec3::ZERO);
        let edge = field.sample(Vec3::new(5.0, 0.0, 0.0));
        let outside = field.sample(Vec3::new(10.0, 0.0, 0.0));
        assert!(center > edge);
        assert!(outside < 1e-5);
    }

    #[test]
    fn test_volumetric_fog_ray_march() {
        let fog = VolumetricFog::new(FogDensityField::Uniform(0.1));
        let (color, transmittance) = fog.ray_march(
            Vec3::ZERO,
            Vec3::FORWARD,
            50.0,
            Vec3::new(0.0, -1.0, 0.0),
            Color::WHITE,
            1.0,
        );
        assert!(transmittance < 1.0); // Some light should be absorbed
        assert!(color.r > 0.0); // Some light should be scattered
    }

    #[test]
    fn test_tiled_culling_creation() {
        let tiled = TiledLightCulling::new(1920, 1080, 16);
        assert_eq!(tiled.tiles_x, 120);
        assert_eq!(tiled.tiles_y, (1080 + 15) / 16);
        assert_eq!(tiled.tiles.len(), (tiled.tiles_x * tiled.tiles_y) as usize);
    }

    #[test]
    fn test_clustered_depth_slicing() {
        let clustered = ClusteredLightAssignment::new(
            16, 8, 24, 0.1, 1000.0, 1.0, 1.777,
        );
        // Near depth should map to slice 0
        assert_eq!(clustered.depth_to_slice(0.1), 0);
        // Far depth should map to the last slice
        assert_eq!(clustered.depth_to_slice(1000.0), 23);
        // Mid depth should be somewhere in between
        let mid = clustered.depth_to_slice(10.0);
        assert!(mid > 0 && mid < 23);
    }

    #[test]
    fn test_clustered_linear_depth_slicing() {
        let mut clustered = ClusteredLightAssignment::new(
            4, 4, 4, 0.1, 100.0, 1.0, 1.0,
        );
        clustered.log_depth = false;
        assert_eq!(clustered.depth_to_slice(0.1), 0);
        // (60 - 0.1) / 99.9 * 4 ≈ 2.4 → slice 2
        assert_eq!(clustered.depth_to_slice(60.0), 2);
        assert_eq!(clustered.depth_to_slice(100.0), 3);
    }

    #[test]
    fn test_clustered_light_assignment() {
        let mut clustered = ClusteredLightAssignment::new(
            4, 4, 4, 0.1, 100.0, 1.0, 1.0,
        );

        // A light at the center of the frustum should hit some clusters
        let lights = vec![
            (LightId(1), Vec3::new(0.0, 0.0, -10.0), 5.0),
        ];
        clustered.assign_lights(&lights);

        let stats = clustered.stats();
        assert!(stats.active_clusters > 0);
        assert!(stats.total_light_cluster_pairs > 0);
    }

    #[test]
    fn test_clustered_reassign_clears_previous() {
        let mut clustered = ClusteredLightAssignment::new(
            4, 4, 4, 0.1, 100.0, 1.0, 1.0,
        );
        clustered.assign_lights(&[(LightId(1), Vec3::new(0.0, 0.0, -10.0), 5.0)]);
        assert!(clustered.stats().total_light_cluster_pairs > 0);

        clustered.assign_lights(&[]);
        assert_eq!(clustered.stats().total_light_cluster_pairs, 0);
    }

    #[test]
    fn test_clustered_light_outside_depth_range() {
        let mut clustered = ClusteredLightAssignment::new(
            4, 4, 4, 0.1, 100.0, 1.0, 1.0,
        );
        // Entirely beyond the far plane
        clustered.assign_lights(&[(LightId(1), Vec3::new(0.0, 0.0, -500.0), 5.0)]);
        assert_eq!(clustered.stats().total_light_cluster_pairs, 0);

        // Entirely behind the camera (view space looks down -Z)
        clustered.assign_lights(&[(LightId(2), Vec3::new(0.0, 0.0, 50.0), 5.0)]);
        assert_eq!(clustered.stats().active_clusters, 0);
    }

    #[test]
    fn test_clustered_indexing() {
        let clustered = ClusteredLightAssignment::new(
            4, 2, 3, 0.1, 100.0, 1.0, 1.0,
        );
        assert_eq!(clustered.total_clusters(), 24);
        assert!(clustered.lights_for_cluster(24).is_empty());

        // Pixel (0, 0) at the near plane is the first cluster
        assert_eq!(clustered.cluster_at(0, 0, 0.1, 400, 200), 0);
        // Last pixel at the far plane is the last cluster
        assert_eq!(clustered.cluster_at(399, 199, 100.0, 400, 200), 23);
        // Off-screen pixels clamp to the last column/row of slice 0
        assert_eq!(clustered.cluster_at(10_000, 10_000, 0.1, 400, 200), 7);
    }

    #[test]
    fn test_volumetric_system_init_clustered() {
        let mut system = VolumetricSystem::default();
        assert!(system.tiled_stats().is_none());
        assert!(system.clustered_stats().is_none());
        assert!(!system.clustered_enabled);

        system.init_clustered(4, 4, 4, 0.1, 100.0, 1.0, 1.0);
        assert!(system.clustered_enabled);
        assert_eq!(system.clustered_stats().map(|s| s.total_clusters), Some(64));
    }

    #[test]
    fn test_light_shafts_visibility() {
        let mut shafts = VolumetricLightShafts::default();
        shafts.light_screen_pos = (0.5, 0.5);
        assert!(shafts.is_light_visible());

        shafts.light_screen_pos = (2.0, 2.0);
        assert!(!shafts.is_light_visible());
    }

    #[test]
    fn test_fog_transmittance() {
        let fog = VolumetricFog::new(FogDensityField::Uniform(0.1));
        let t1 = fog.transmittance(Vec3::ZERO, Vec3::FORWARD, 10.0);
        let t2 = fog.transmittance(Vec3::ZERO, Vec3::FORWARD, 50.0);
        // Longer distance = less transmittance
        assert!(t1 > t2);
        // Both should be between 0 and 1
        assert!(t1 > 0.0 && t1 < 1.0);
        assert!(t2 > 0.0 && t2 < 1.0);
    }

    #[test]
    fn test_cluster_sphere_intersection() {
        let cluster = LightCluster::new(
            Vec3::new(-1.0, -1.0, -1.0),
            Vec3::new(1.0, 1.0, 1.0),
        );
        assert!(cluster.intersects_sphere(Vec3::ZERO, 0.5));
        assert!(cluster.intersects_sphere(Vec3::new(2.0, 0.0, 0.0), 1.5));
        assert!(!cluster.intersects_sphere(Vec3::new(5.0, 5.0, 5.0), 1.0));
    }
}