use alloc::{borrow::Cow, sync::Arc};
use bevy_asset::{Asset, AssetEvent, AssetId, Handle};
use bevy_color::{ColorToComponents, Gray, LinearRgba};
use bevy_ecs::{
component::Component,
lifecycle::HookContext,
message::MessageReader,
system::{Res, ResMut},
template::FromTemplate,
world::DeferredWorld,
};
use bevy_image::Image;
use bevy_math::curve::{FunctionCurve, Interval, SampleAutoCurve};
use bevy_math::{ops, Curve, FloatPow, Vec3};
use bevy_platform::collections::HashSet;
use bevy_reflect::TypePath;
use bevy_transform::components::GlobalTransform;
use core::f32::{self, consts::PI};
use smallvec::SmallVec;
use wgpu_types::TextureFormat;
/// Component describing a planet's atmosphere: the radii bounding it, the
/// albedo of the ground below it, and the [`ScatteringMedium`] that fills it.
///
/// A [`GlobalTransform`] is required alongside this component, and the
/// `set_default_transform` hook runs when the component is added.
#[derive(Clone, Component, FromTemplate)]
#[require(GlobalTransform)]
#[component(on_add = set_default_transform)]
pub struct Atmosphere {
/// Inner radius of the atmosphere. The presets use planet-sized values such
/// as `6_360_000.0` for Earth, so this is presumably the planet radius in
/// meters -- NOTE(review): confirm units.
pub inner_radius: f32,
/// Outer radius of the atmosphere, presumably measured from the same center
/// and in the same units as `inner_radius` -- NOTE(review): confirm.
pub outer_radius: f32,
/// Albedo of the ground, one value per component (presumably RGB -- confirm).
pub ground_albedo: Vec3,
/// Handle to the [`ScatteringMedium`] asset used by this atmosphere.
pub medium: Handle<ScatteringMedium>,
}
/// `on_add` hook for [`Atmosphere`].
///
/// If the entity's [`GlobalTransform`] is still the default value, it is
/// replaced by a translation of `inner_radius` along negative Y. A transform
/// that was already changed from the default is left alone.
fn set_default_transform(mut world: DeferredWorld<'_>, HookContext { entity, .. }: HookContext) {
    let inner_radius = match world.get::<Atmosphere>(entity) {
        Some(atmosphere) => atmosphere.inner_radius,
        None => unreachable!("on_add hooks guarantee the component is present"),
    };

    let Some(mut transform) = world.get_mut::<GlobalTransform>(entity) else {
        return;
    };

    // Only override a transform that is still at its default value.
    if *transform == GlobalTransform::default() {
        *transform = GlobalTransform::from_translation(-Vec3::Y * inner_radius);
    }
}
impl Atmosphere {
    /// Creates an Earth-like atmosphere filled with `medium`.
    ///
    /// Inner radius `6_360_000.0`, outer radius `6_460_000.0`, and a uniform
    /// ground albedo of `0.3`.
    pub fn earth(medium: Handle<ScatteringMedium>) -> Self {
        let inner_radius = 6_360_000.0;
        let outer_radius = 6_460_000.0;
        let ground_albedo = Vec3::splat(0.3);

        Self {
            inner_radius,
            outer_radius,
            ground_albedo,
            medium,
        }
    }

    /// Creates a Mars-like atmosphere filled with `medium`.
    ///
    /// Inner radius `3_389_500.0`, outer radius `3_509_500.0`, and a uniform
    /// ground albedo of `0.1`.
    pub fn mars(medium: Handle<ScatteringMedium>) -> Self {
        let inner_radius = 3_389_500.0;
        let outer_radius = 3_509_500.0;
        let ground_albedo = Vec3::splat(0.1);

        Self {
            inner_radius,
            outer_radius,
            ground_albedo,
            medium,
        }
    }
}
/// Asset describing a participating medium as a list of [`ScatteringTerm`]s.
#[derive(TypePath, Asset, Clone)]
pub struct ScatteringMedium {
/// Optional label for the medium; the built-in presets set one
/// (e.g. `"earth_atmosphere"`).
pub label: Option<Cow<'static, str>>,
/// Resolution associated with the terms' [`Falloff`]s.
/// NOTE(review): presumably the sample count used when they are baked into
/// a lookup table -- confirm against the consumer of this asset.
pub falloff_resolution: u32,
/// Resolution associated with the terms' [`PhaseFunction`]s.
/// NOTE(review): presumably the sample count used when they are baked into
/// a lookup table -- confirm against the consumer of this asset.
pub phase_resolution: u32,
/// The individual terms that make up the medium.
pub terms: SmallVec<[ScatteringTerm; 1]>,
}
impl Default for ScatteringMedium {
fn default() -> Self {
ScatteringMedium::earth(256, 256)
}
}
impl ScatteringMedium {
pub fn new(
falloff_resolution: u32,
phase_resolution: u32,
terms: impl IntoIterator<Item = ScatteringTerm>,
) -> Self {
Self {
label: None,
falloff_resolution,
phase_resolution,
terms: terms.into_iter().collect(),
}
}
pub fn with_label(self, label: impl Into<Cow<'static, str>>) -> Self {
Self {
label: Some(label.into()),
..self
}
}
pub fn with_density_multiplier(mut self, multiplier: f32) -> Self {
self.terms.iter_mut().for_each(|term| {
term.absorption *= multiplier;
term.scattering *= multiplier;
});
self
}
pub fn earth(falloff_resolution: u32, phase_resolution: u32) -> Self {
Self::new(
falloff_resolution,
phase_resolution,
[
ScatteringTerm {
absorption: Vec3::ZERO,
scattering: Vec3::new(5.802e-6, 13.558e-6, 33.100e-6),
falloff: Falloff::Exponential { scale: 8.0 / 60.0 },
phase: PhaseFunction::Rayleigh,
},
ScatteringTerm {
absorption: Vec3::splat(3.996e-6),
scattering: Vec3::splat(0.444e-6),
falloff: Falloff::Exponential { scale: 1.2 / 60.0 },
phase: PhaseFunction::Mie { asymmetry: 0.8 },
},
ScatteringTerm {
absorption: Vec3::new(0.650e-6, 1.881e-6, 0.085e-6),
scattering: Vec3::ZERO,
falloff: Falloff::Tent {
center: 0.75,
width: 0.3,
},
phase: PhaseFunction::Isotropic,
},
],
)
.with_label("earth_atmosphere")
}
pub fn mars(falloff_resolution: u32, phase_resolution: u32, dust_phase: Handle<Image>) -> Self {
const MARS_ATMOSPHERE_HEIGHT: f32 = 120_000.0;
const RAYLEIGH_SCALE_HEIGHT: f32 = 8_000.0;
let dust_falloff = Falloff::from_curve(FunctionCurve::new(Interval::UNIT, |p| {
let h = (1.0 - p) * MARS_ATMOSPHERE_HEIGHT;
0.75 * ops::exp(1.0 - ops::exp(h / 4_000.0))
+ 0.25 * ops::exp(1.0 - ops::exp(h / 20_000.0))
}));
Self::new(
falloff_resolution,
phase_resolution,
[
ScatteringTerm {
absorption: Vec3::ZERO,
scattering: Vec3::new(9.91e-8, 2.32e-7, 5.65e-7),
falloff: Falloff::Exponential {
scale: RAYLEIGH_SCALE_HEIGHT / MARS_ATMOSPHERE_HEIGHT,
},
phase: PhaseFunction::Rayleigh,
},
ScatteringTerm {
absorption: Vec3::new(1.26e-6, 5.25e-6, 9.33e-6), scattering: Vec3::new(30.67e-6, 25.39e-6, 20.93e-6), falloff: dust_falloff,
phase: PhaseFunction::from_chromatic_texture(dust_phase),
},
],
)
.with_label("mars_atmosphere")
}
}
/// One component of a [`ScatteringMedium`]: how strongly it absorbs and
/// scatters, how its density falls off, and its phase function.
#[derive(Default, Clone)]
pub struct ScatteringTerm {
/// Absorption coefficient, one value per component (presumably RGB --
/// confirm). Scaled by `ScatteringMedium::with_density_multiplier`.
pub absorption: Vec3,
/// Scattering coefficient, one value per component (presumably RGB --
/// confirm). Scaled by `ScatteringMedium::with_density_multiplier`.
pub scattering: Vec3,
/// Density profile of this term.
pub falloff: Falloff,
/// Phase function of this term.
pub phase: PhaseFunction,
}
/// Density profile of a [`ScatteringTerm`], evaluated by [`Falloff::sample`]
/// at a parameter `p`.
#[derive(Default, Clone)]
pub enum Falloff {
/// The sampled value equals `p`. This is the default.
#[default]
Linear,
/// Exponential profile normalized so that it is `0` at `p = 0` and `1` at
/// `p = 1`. A `scale` of exactly `0.0` is sampled like [`Falloff::Linear`].
Exponential {
/// Scale of the exponential relative to the `p` range; smaller values give
/// a steeper profile. The presets pass ratios such as `8.0 / 60.0`.
scale: f32,
},
/// Triangular profile that is `1` at `p == center` and decreases linearly to
/// `0` at a distance of `width / 2` from `center`.
Tent {
/// Value of `p` at which the profile peaks.
center: f32,
/// Full width of the non-zero region.
width: f32,
},
/// Arbitrary profile given by a curve; sampled as `0.0` wherever the curve
/// yields no value.
Curve(Arc<dyn Curve<f32> + Send + Sync>),
}
impl Falloff {
    /// Wraps `curve` in a [`Falloff::Curve`].
    pub fn from_curve(curve: impl Curve<f32> + Send + Sync + 'static) -> Self {
        let shared: Arc<dyn Curve<f32> + Send + Sync> = Arc::new(curve);
        Self::Curve(shared)
    }

    /// Evaluates the falloff at `p`.
    ///
    /// - `Linear` returns `p`.
    /// - `Exponential` returns a normalized exponential that is `0` at
    ///   `p = 0` and `1` at `p = 1`, or `p` when `scale` is exactly `0.0`.
    /// - `Tent` returns a triangle peaking at `1` for `p == center`, clamped
    ///   below at `0`.
    /// - `Curve` returns the curve's sample, or `0.0` if it yields none.
    pub fn sample(&self, p: f32) -> f32 {
        match self {
            Self::Linear => p,
            // The exponential form divides by `scale`, so a zero scale falls
            // back to the linear profile.
            Self::Exponential { scale } if *scale == 0.0 => p,
            Self::Exponential { scale } => {
                let rate = -1.0 / scale;
                let at_zero = ops::exp(rate);
                let at_p = ops::exp((1.0 - p) * rate);
                // Subtract the value at `p = 0` and rescale to span `[0, 1]`.
                (at_p - at_zero) / (1.0 - at_zero)
            }
            Self::Tent { center, width } => {
                let distance = (p - center).abs();
                let half_width = 0.5 * width;
                (1.0 - distance / half_width).max(0.0)
            }
            Self::Curve(curve) => curve.sample(p).unwrap_or(0.0),
        }
    }
}
/// Phase function of a [`ScatteringTerm`], evaluated by
/// [`PhaseFunction::sample`].
#[derive(Clone)]
pub enum PhaseFunction {
/// Constant `1 / (4π)` in every direction.
Isotropic,
/// `3 / (16π) * (1 + c²)`, where `c` is the value passed to `sample`.
Rayleigh,
/// Henyey-Greenstein-shaped lobe,
/// `(1 - g²) / (4π * (1 + g² - 2gc)^(3/2))`, with `g = asymmetry`.
Mie {
/// The `g` parameter of the lobe.
asymmetry: f32,
},
/// Phase function given by a scalar curve; the sampled value is used as a
/// gray color.
Curve(Arc<dyn Curve<f32> + Send + Sync>),
/// Phase function given by a curve of colors.
ChromaticCurve(Arc<dyn Curve<LinearRgba> + Send + Sync>),
/// Phase function stored in an image. `sample` returns `None` for this
/// variant; `extract_chromatic_phase_textures` replaces it with a
/// [`PhaseFunction::ChromaticCurve`] once the image has loaded.
ChromaticTexture(Handle<Image>),
}
impl PhaseFunction {
    /// Wraps a scalar curve in a [`PhaseFunction::Curve`].
    pub fn from_curve(curve: impl Curve<f32> + Send + Sync + 'static) -> Self {
        let shared: Arc<dyn Curve<f32> + Send + Sync> = Arc::new(curve);
        Self::Curve(shared)
    }

    /// Wraps a color curve in a [`PhaseFunction::ChromaticCurve`].
    pub fn from_chromatic_curve(curve: impl Curve<LinearRgba> + Send + Sync + 'static) -> Self {
        let shared: Arc<dyn Curve<LinearRgba> + Send + Sync> = Arc::new(curve);
        Self::ChromaticCurve(shared)
    }

    /// Creates a [`PhaseFunction::ChromaticTexture`] referencing `image`.
    pub fn from_chromatic_texture(image: Handle<Image>) -> Self {
        Self::ChromaticTexture(image)
    }

    /// Evaluates the phase function at `neg_l_dot_v`.
    ///
    /// Returns `None` only for [`PhaseFunction::ChromaticTexture`], which has
    /// no data to evaluate here. The curve variants return black when the
    /// curve yields no sample.
    pub fn sample(&self, neg_l_dot_v: f32) -> Option<LinearRgba> {
        const INV_4_PI: f32 = 0.25 / PI;
        const THREE_OVER_16_PI: f32 = 0.1875 / PI;

        let value = match self {
            Self::ChromaticTexture(_) => return None,
            Self::Isotropic => LinearRgba::gray(INV_4_PI),
            Self::Rayleigh => {
                let cos_sq = neg_l_dot_v * neg_l_dot_v;
                LinearRgba::gray(THREE_OVER_16_PI * (1.0 + cos_sq))
            }
            Self::Mie { asymmetry } => {
                let denom = 1.0 + asymmetry.squared() - 2.0 * asymmetry * neg_l_dot_v;
                let lobe = INV_4_PI * (1.0 - asymmetry.squared()) / (denom * denom.sqrt());
                LinearRgba::from_vec3(Vec3::splat(lobe))
            }
            Self::Curve(curve) => curve
                .sample(neg_l_dot_v)
                .map_or(LinearRgba::gray(0.0), LinearRgba::gray),
            Self::ChromaticCurve(curve) => {
                curve.sample(neg_l_dot_v).unwrap_or(LinearRgba::gray(0.0))
            }
        };

        Some(value)
    }
}
impl Default for PhaseFunction {
    /// The default phase function is `Mie` with an asymmetry of `0.8`.
    fn default() -> Self {
        let asymmetry = 0.8;
        Self::Mie { asymmetry }
    }
}
pub fn extract_chromatic_phase_textures(
mut reader: MessageReader<AssetEvent<Image>>,
images: Res<bevy_asset::Assets<Image>>,
mut scattering_media: ResMut<bevy_asset::Assets<ScatteringMedium>>,
) {
let extract_ids: HashSet<AssetId<Image>> = scattering_media
.iter()
.flat_map(|(_, m)| m.terms.iter())
.filter_map(|t| {
let PhaseFunction::ChromaticTexture(h) = &t.phase else {
return None;
};
Some(h.id())
})
.collect();
for event in reader.read() {
let AssetEvent::LoadedWithDependencies { id } = event else {
continue;
};
if !extract_ids.contains(id) {
continue;
}
let Some(image) = images.get(*id) else {
continue;
};
if image.texture_descriptor.format != TextureFormat::Rgba32Float {
continue;
}
let width = image.texture_descriptor.size.width;
if width == 0 {
continue;
}
let Some(samples): Option<Vec<LinearRgba>> = (0..width)
.map(|x| image.get_color_at_1d(x).ok().map(|c| c.to_linear()))
.collect()
else {
continue;
};
let Ok(curve) = SampleAutoCurve::new(
Interval::new(-1.0, 1.0).expect("[-1, 1] valid for cos θ"),
samples,
) else {
continue;
};
let new_phase = PhaseFunction::from_chromatic_curve(curve);
for (_id, medium) in scattering_media.iter_mut() {
for term in medium.terms.iter_mut() {
if let PhaseFunction::ChromaticTexture(handle) = &term.phase
&& handle.id() == *id
{
term.phase = new_phase.clone();
}
}
}
}
}