use glam::{Mat4, Vec3, Vec4};
use crate::bounds::Aabb;
/// Maximum number of shadow cascades; `pssm_splits` clamps the requested cascade count to this.
pub const MAX_CASCADES: usize = 4;
/// One fitted shadow cascade: the light's view and orthographic projection for one slice of
/// the camera frustum, as produced by `fit_cascade`.
#[derive(Clone, Debug)]
pub struct Cascade {
/// Light view matrix (built with `Mat4::look_at_rh`, looking along the light direction).
pub view: Mat4,
/// Orthographic projection fitted around the slice (`Mat4::orthographic_rh`).
pub projection: Mat4,
/// `projection * view`: maps world-space points straight to light clip space.
pub view_projection: Mat4,
/// Average x/y extent of the projection window divided by the shadow-map resolution,
/// i.e. the approximate size of one shadow-map texel in world units.
pub world_per_texel: f32,
}
/// Fits an orthographic shadow cascade around one depth slice of the camera frustum.
///
/// The slice is `[near_normalized, far_normalized]` in the camera's NDC depth, with both
/// values clamped to `[0, 1]`. Its eight corners are unprojected through
/// `camera_inv_view_projection` and enclosed in a bounding sphere. Because the cascade covers
/// that sphere, its footprint does not change size as the camera rotates.
///
/// The light looks along `direction`, which is normalized here. A near-zero vector falls back
/// to straight down (`-Y`).
///
/// To keep shadow edges from shimmering while the camera translates, the sphere centre is
/// snapped to a texel grid defined in a translation-free light frame. That grid depends only
/// on the light orientation, so it stays fixed in world space.
///
/// `casters_world_aabbs` are world-space occluder bounds. A caster is used when its
/// light-space bounds overlap the cascade's x/y footprint and are not entirely beyond the
/// slice's far side. Such a caster moves the near plane toward the light, so occluders
/// outside the slice's sphere still land in the shadow map. The near plane is then pulled
/// back further by the fitted depth range, or by 50 units if that range is smaller.
///
/// `resolution` is the shadow-map size in texels along each axis. A value of 0 is treated
/// as 1 so the result stays finite.
pub fn fit_cascade(
    camera_inv_view_projection: Mat4,
    direction: Vec3,
    near_normalized: f32,
    far_normalized: f32,
    resolution: u32,
    casters_world_aabbs: &[Aabb],
) -> Cascade {
    let z_near = near_normalized.clamp(0.0, 1.0);
    let z_far = far_normalized.clamp(0.0, 1.0);
    let ndc_corners = [
        Vec4::new(-1.0, -1.0, z_near, 1.0),
        Vec4::new(1.0, -1.0, z_near, 1.0),
        Vec4::new(-1.0, 1.0, z_near, 1.0),
        Vec4::new(1.0, 1.0, z_near, 1.0),
        Vec4::new(-1.0, -1.0, z_far, 1.0),
        Vec4::new(1.0, -1.0, z_far, 1.0),
        Vec4::new(-1.0, 1.0, z_far, 1.0),
        Vec4::new(1.0, 1.0, z_far, 1.0),
    ];

    // Unproject the slice corners to world space; their mean is the bounding-sphere centre.
    let mut world_corners = [Vec3::ZERO; 8];
    let mut frustum_center = Vec3::ZERO;
    for (world_corner, ndc) in world_corners.iter_mut().zip(ndc_corners.iter()) {
        let world = camera_inv_view_projection * *ndc;
        // Guard the perspective divide against a degenerate (w ~ 0) unprojection.
        let w = if world.w.abs() < 1e-8 { 1.0 } else { world.w };
        *world_corner = Vec3::new(world.x / w, world.y / w, world.z / w);
        frustum_center += *world_corner;
    }
    frustum_center *= 1.0 / 8.0;

    let sphere_radius = world_corners
        .iter()
        .map(|c| (*c - frustum_center).length())
        .fold(0.0_f32, f32::max);
    let diameter = sphere_radius * 2.0;

    let dir = if direction.length_squared() < 1e-8 {
        Vec3::new(0.0, -1.0, 0.0)
    } else {
        direction.normalize()
    };
    // Any up vector that is not parallel to `dir` works; switch axes when `dir` is close to X.
    let up = if dir.x.abs() < 0.9 { Vec3::X } else { Vec3::Z };

    // A zero resolution would make the texel size infinite and turn every bound into NaN.
    let resolution_f = resolution.max(1) as f32;
    let texel_size = diameter / resolution_f;
    let snap = |v: f32| {
        if texel_size > 0.0 {
            (v / texel_size).round() * texel_size
        } else {
            v
        }
    };

    // Translation-free light frame: coordinates here are a pure rotation of world space, so a
    // texel grid in this frame is anchored in the world. Snapping inside a view that is itself
    // centred on the frustum would be a no-op, because the centre always maps to x = y = 0.
    let light_rotation = Mat4::look_at_rh(Vec3::ZERO, dir, up);
    let center_rot = light_rotation.transform_point3(frustum_center);
    let snapped_center = light_rotation.inverse().transform_point3(Vec3::new(
        snap(center_rot.x),
        snap(center_rot.y),
        center_rot.z,
    ));
    let view = Mat4::look_at_rh(snapped_center - dir, snapped_center, up);

    // In `view` the snapped centre sits at x = y = 0, so a symmetric x/y window keeps texel
    // boundaries world-fixed. The true sphere centre is at most half a texel away from it.
    let center_ls = view.transform_point3(frustum_center);
    let min = Vec3::new(
        -sphere_radius,
        -sphere_radius,
        center_ls.z - sphere_radius,
    );
    let mut max = Vec3::new(sphere_radius, sphere_radius, center_ls.z + sphere_radius);

    // Extend the depth range toward the light (+z in view space) for casters that can throw
    // shadows into the window. Culling happens in light space rather than against a
    // world-axis box, so occluders above/beside the sphere along the light ray are kept.
    for aabb in casters_world_aabbs {
        if aabb.min.x > aabb.max.x || aabb.min.y > aabb.max.y || aabb.min.z > aabb.max.z {
            continue;
        }
        let corners = [
            Vec3::new(aabb.min.x, aabb.min.y, aabb.min.z),
            Vec3::new(aabb.max.x, aabb.min.y, aabb.min.z),
            Vec3::new(aabb.min.x, aabb.max.y, aabb.min.z),
            Vec3::new(aabb.max.x, aabb.max.y, aabb.min.z),
            Vec3::new(aabb.min.x, aabb.min.y, aabb.max.z),
            Vec3::new(aabb.max.x, aabb.min.y, aabb.max.z),
            Vec3::new(aabb.min.x, aabb.max.y, aabb.max.z),
            Vec3::new(aabb.max.x, aabb.max.y, aabb.max.z),
        ];
        let (ls_min, ls_max) = corners.iter().fold(
            (Vec3::splat(f32::INFINITY), Vec3::splat(f32::NEG_INFINITY)),
            |(lo, hi), c| {
                let p = view.transform_point3(*c);
                (lo.min(p), hi.max(p))
            },
        );
        // A caster lying entirely beyond the far side (farther from the light than every
        // receiver in the slice) cannot shadow anything in this cascade.
        let overlaps = ls_max.x >= min.x
            && ls_min.x <= max.x
            && ls_max.y >= min.y
            && ls_min.y <= max.y
            && ls_max.z >= min.z;
        if overlaps {
            max.z = max.z.max(ls_max.z);
        }
    }

    const Z_PULL_BACK_MIN: f32 = 50.0;
    // The view looks down -Z, so distances in front of the eye are the negated z bounds.
    let visible_near = -max.z;
    let visible_far = -min.z;
    // Pull the near plane further toward the light by the fitted depth range (at least
    // Z_PULL_BACK_MIN) to catch occluders just outside it.
    let z_pull_back = (visible_far - visible_near).max(Z_PULL_BACK_MIN);
    let near = visible_near - z_pull_back;
    let far = visible_far;
    let projection = Mat4::orthographic_rh(min.x, max.x, min.y, max.y, near, far);
    let view_projection = projection * view;
    let avg_extent = ((max.x - min.x) + (max.y - min.y)) * 0.5;
    let world_per_texel = avg_extent / resolution_f;
    Cascade {
        view,
        projection,
        view_projection,
        world_per_texel,
    }
}
/// Computes the far distance of each cascade with the practical split scheme (PSSM).
///
/// Split `i` of `n` blends a logarithmic and a uniform distribution between `near` and `far`:
/// `lambda * near * (far / near)^(i / n) + (1 - lambda) * (near + (far - near) * i / n)`.
/// `lambda` is clamped to `[0, 1]`, where 0 is purely uniform and 1 is purely logarithmic,
/// so out-of-range values cannot extrapolate past either scheme. When `near <= 0` the
/// logarithmic term is undefined, and it falls back to the linear ramp `far * i / n`.
///
/// The cascade count is clamped to `1..=MAX_CASCADES`, and one split is returned per cascade.
/// The last split is exactly `far`, so the final cascade always reaches the shadow distance
/// regardless of floating-point rounding.
pub fn pssm_splits(near: f32, far: f32, lambda: f32, cascade_count: u32) -> Vec<f32> {
    let n = cascade_count.clamp(1, MAX_CASCADES as u32);
    let lambda = lambda.clamp(0.0, 1.0);
    let ratio = if near > 0.0 { far / near } else { 1.0 };
    let mut splits: Vec<f32> = (1..=n)
        .map(|i| {
            let p = i as f32 / n as f32;
            let log_split = if near > 0.0 {
                near * ratio.powf(p)
            } else {
                far * p
            };
            let uniform_split = near + (far - near) * p;
            lambda * log_split + (1.0 - lambda) * uniform_split
        })
        .collect();
    // Pin the final split so the last cascade ends exactly at the requested far distance.
    if let Some(last) = splits.last_mut() {
        *last = far;
    }
    splits
}
/// Returns the shadow-map resolution for a cascade: `base`, raised to `min_res` if smaller.
///
/// The cascade index is accepted for future per-cascade tuning but does not currently
/// affect the result.
pub fn cascade_resolution(base: u32, _cascade_index: u32, min_res: u32) -> u32 {
    Ord::max(base, min_res)
}
/// Splits the camera frustum into PSSM cascades and fits a shadow cascade to each slice.
///
/// Split distances come from `pssm_splits` over `[world_near, world_far]`, measured as
/// distances in front of the camera. The camera is assumed right-handed, looking down -Z.
/// Each cascade after the first starts `BLEND_OVERLAP` (55 %) of the previous cascade's depth
/// span before the previous split, so neighbouring cascades overlap. That start never goes
/// below `world_near`.
///
/// `camera_view_projection * camera_view.inverse()` recovers the camera projection. It is
/// used to convert each split distance into the NDC depth that `fit_cascade` expects.
///
/// Returns one `(cascade, resolution, split_far)` tuple per cascade. `split_far` is the
/// distance from the camera at which that cascade ends.
pub fn fit_cascades(
    camera_view_projection: Mat4,
    camera_view: Mat4,
    direction: Vec3,
    world_near: f32,
    world_far: f32,
    cascade_count: u32,
    lambda: f32,
    base_resolution: u32,
    min_resolution: u32,
    casters_world_aabbs: &[Aabb],
) -> Vec<(Cascade, u32, f32)> {
    let inv_view_proj = camera_view_projection.inverse();
    let splits = pssm_splits(world_near, world_far, lambda, cascade_count);
    let proj = camera_view_projection * camera_view.inverse();
    // Maps a distance in front of the camera to NDC depth in [0, 1].
    let split_to_ndc = |z: f32| {
        let clip = proj * Vec4::new(0.0, 0.0, -z, 1.0);
        // w <= 0 means the point is at or behind the eye of a perspective camera. The closest
        // valid depth is then the near plane (0), not the far plane; returning 1.0 here would
        // invert the first cascade's slice whenever `world_near` is 0.
        if clip.w <= 1e-8 {
            return 0.0;
        }
        (clip.z / clip.w).clamp(0.0, 1.0)
    };
    const BLEND_OVERLAP: f32 = 0.55;
    let mut cascades = Vec::with_capacity(splits.len());
    let mut prev_split_world = world_near;
    let mut prev_span = 0.0_f32;
    for (i, &split_world) in splits.iter().enumerate() {
        let span = (split_world - prev_split_world).max(0.0);
        let near_world = if i == 0 {
            prev_split_world
        } else {
            // Start inside the previous cascade so the two can be blended.
            (prev_split_world - BLEND_OVERLAP * prev_span).max(world_near)
        };
        let res = cascade_resolution(base_resolution, i as u32, min_resolution);
        let cascade = fit_cascade(
            inv_view_proj,
            direction,
            split_to_ndc(near_world),
            split_to_ndc(split_world),
            res,
            casters_world_aabbs,
        );
        cascades.push((cascade, res, split_world));
        prev_split_world = split_world;
        prev_span = span;
    }
    cascades
}
// Unit tests for split placement, resolution flooring and an end-to-end cascade fit.
#[cfg(test)]
mod tests {
use super::*;
// Requested cascade counts are clamped to 1..=MAX_CASCADES.
#[test]
fn pssm_splits_count_is_clamped() {
assert_eq!(pssm_splits(0.1, 100.0, 0.5, 3).len(), 3);
assert_eq!(pssm_splits(0.1, 100.0, 0.5, 1).len(), 1);
assert_eq!(pssm_splits(0.1, 100.0, 0.5, 0).len(), 1);
assert_eq!(
pssm_splits(0.1, 100.0, 0.5, 99).len(),
MAX_CASCADES,
"cascade count clamps to MAX_CASCADES"
);
}
// Across several blend weights, splits strictly increase and the last one lands on `far`.
#[test]
fn pssm_splits_monotonic_increasing_and_last_is_far() {
for lambda in [0.0, 0.25, 0.5, 0.75, 1.0] {
let s = pssm_splits(0.5, 200.0, lambda, 4);
for w in s.windows(2) {
assert!(
w[1] > w[0],
"splits must strictly increase (lambda={lambda}): {s:?}"
);
}
let last = *s.last().unwrap();
assert!(
(last - 200.0).abs() < 1e-2,
"last split must equal far (lambda={lambda}): got {last}"
);
}
}
// Every split lies within [near, far]; the small tolerances absorb float rounding.
#[test]
fn pssm_splits_within_near_far() {
let near = 0.3;
let far = 150.0;
let s = pssm_splits(near, far, 0.5, 4);
for v in &s {
assert!(
*v > near - 1e-3 && *v <= far + 1e-2,
"split {v} out of [{near}, {far}]"
);
}
}
// lambda = 0 reduces to the uniform scheme: near + (far - near) * i / n.
#[test]
fn pssm_splits_lambda0_is_uniform() {
let (near, far, n) = (1.0_f32, 101.0_f32, 4u32);
let s = pssm_splits(near, far, 0.0, n);
for (i, v) in s.iter().enumerate() {
let p = (i + 1) as f32 / n as f32;
let expected = near + (far - near) * p;
assert!(
(v - expected).abs() < 1e-3,
"lambda=0 must be uniform: split[{i}]={v} expected {expected}"
);
}
}
// lambda = 1 reduces to the logarithmic scheme: near * (far / near)^(i / n).
#[test]
fn pssm_splits_lambda1_is_logarithmic() {
let (near, far, n) = (1.0_f32, 256.0_f32, 4u32);
let s = pssm_splits(near, far, 1.0, n);
let ratio = far / near;
for (i, v) in s.iter().enumerate() {
let p = (i + 1) as f32 / n as f32;
let expected = near * ratio.powf(p);
assert!(
(v - expected).abs() < 1e-2,
"lambda=1 must be logarithmic: split[{i}]={v} expected {expected}"
);
}
}
// near = 0 takes the linear fallback; splits must stay finite and still end at `far`.
#[test]
fn pssm_splits_near_zero_is_finite() {
let s = pssm_splits(0.0, 100.0, 0.5, 4);
assert_eq!(s.len(), 4);
for v in &s {
assert!(v.is_finite(), "near=0 fallback produced non-finite {v}");
}
assert!((s.last().unwrap() - 100.0).abs() < 1e-2);
}
// Resolutions below `min_res` are raised to it; values at or above it are kept.
#[test]
fn cascade_resolution_floors_at_min() {
assert_eq!(cascade_resolution(2048, 0, 16), 2048);
assert_eq!(cascade_resolution(8, 0, 16), 16, "below min floors to min");
assert_eq!(cascade_resolution(16, 3, 16), 16);
}
// End-to-end: four cascades for a 60-degree perspective camera with no casters. Checks one
// entry per cascade, increasing split distances ending at `far`, resolutions at least the
// minimum, and finite texel sizes and matrices.
#[test]
fn fit_cascades_count_ordering_and_far() {
let near = 0.1_f32;
let far = 100.0_f32;
let count = 4u32;
let proj = Mat4::perspective_rh(60.0_f32.to_radians(), 16.0 / 9.0, near, far);
let view = Mat4::look_at_rh(Vec3::new(0.0, 5.0, 10.0), Vec3::ZERO, Vec3::Y);
let view_proj = proj * view;
let dir = Vec3::new(0.3, -1.0, 0.2).normalize();
let out = fit_cascades(view_proj, view, dir, near, far, count, 0.5, 2048, 16, &[]);
assert_eq!(out.len(), count as usize, "one entry per cascade");
let fars: Vec<f32> = out.iter().map(|(_, _, f)| *f).collect();
for w in fars.windows(2) {
assert!(w[1] > w[0], "cascade split_far must increase: {fars:?}");
}
assert!((fars.last().unwrap() - far).abs() < 1e-1);
for (c, res, _) in &out {
assert!(*res >= 16);
assert!(c.world_per_texel > 0.0 && c.world_per_texel.is_finite());
assert!(c
.view_projection
.to_cols_array()
.iter()
.all(|v| v.is_finite()));
}
}
}