// Per-draw uniform block for the water pass (bound at group 0, binding 0).
// Field meanings below are taken from how this shader reads them.
struct WaterUniform {
// View-projection used to rasterize the surface (vs_main); its inverse is inv_render_vp.
render_vp: mat4x4<f32>,
// Used to unproject depth-buffer samples back to world space; presumably inverse(render_vp) — confirm on the host side.
inv_render_vp: mat4x4<f32>,
// Current- and previous-frame view-projections used only for the velocity output.
cur_vp: mat4x4<f32>,
prev_vp: mat4x4<f32>,
// Places the unit plane in the world; model[3].y is read as the rest height of the water plane.
model: mat4x4<f32>,
// xy: plane scale applied to VertexInput.plane; z: sun specular intensity; w: unused here.
extents: vec4<f32>,
// rgb: shallow-water tint; w: absorption depth scale (clamped to >= 0.1).
shallow_color: vec4<f32>,
// rgb: deep-water tint / fallback refraction color; w: width of the shoreline (edge) foam band.
deep_color: vec4<f32>,
// rgb: foam color; w: overall foam strength.
foam_color: vec4<f32>,
// x: base amplitude, y: steepness (clamped to [0,1]), z: base wavelength (min 0.4), w: speed multiplier.
wave: vec4<f32>,
// xy: main wave direction (only its angle is used), z: current time, w: previous-frame time.
wave_dir_time: vec4<f32>,
// x: base roughness, y: Fresnel exponent, z: reflection strength, w: refraction distortion scale.
params: vec4<f32>,
// xyz: sun direction (normalized in the shader; used as the light vector in the half-vector).
sun_dir: vec4<f32>,
// rgb: sun color multiplying the specular highlight.
sun_color: vec4<f32>,
// xyz: camera world position.
camera_pos: vec4<f32>,
// xy: render-target size in pixels; zw unused here.
resolution: vec4<f32>,
// rgb: fog color; w > 0.5 enables fog.
fog_color: vec4<f32>,
// x: fog start distance, y: fog end distance.
fog_params: vec4<f32>,
// Padding: 5 mat4 (320 B) + 13 vec4 (208 B) + this (240 B) = 768 B total;
// presumably sized to match a fixed host-side buffer — verify before changing.
pad: array<vec4<f32>, 15>,
}
// Bind group 0: water uniforms plus the scene inputs the water pass samples.
@group(0) @binding(0) var<uniform> water: WaterUniform;
// Scene color seen through the water (sampled with linear_samp).
@group(0) @binding(1) var refraction_tex: texture_2d<f32>;
// Scene depth; a sampled value of 0.0 is treated by the fragment stage as "no geometry behind the water".
@group(0) @binding(2) var depth_tex: texture_depth_2d;
// Sampler for refraction_tex.
@group(0) @binding(3) var linear_samp: sampler;
// Sampler for depth_tex; presumably non-filtering (point) — confirm in the pipeline layout.
@group(0) @binding(4) var point_samp: sampler;
// Environment cube map for reflections; the mip level is chosen from surface roughness.
@group(0) @binding(5) var prefiltered_env: texture_cube<f32>;
// Sampler for prefiltered_env.
@group(0) @binding(6) var ibl_samp: sampler;
// Per-vertex input: a 2D coordinate on the water plane mesh, scaled by water.extents.xy in vs_main.
// NOTE(review): the expected range (e.g. [-1,1] vs [0,1]) is set by the host mesh — confirm.
struct VertexInput {
@location(0) plane: vec2<f32>,
}
// Vertex-to-fragment interface.
struct VertexOutput {
// Clip-space position of the wave-displaced vertex (framebuffer coordinates in the fragment stage).
@builtin(position) clip: vec4<f32>,
// World position of the vertex before wave displacement; the fragment stage re-evaluates waves from it.
@location(0) flat_world: vec3<f32>,
}
// Fragment outputs: shaded color and a screen-space motion vector.
struct FragmentOutput {
// Final water color (alpha is always 1.0).
@location(0) color: vec4<f32>,
// Previous-frame UV minus current-frame UV, clamped to [-2, 2].
@location(1) velocity: vec2<f32>,
}
// Normalizes v, or returns `fallback` when v is too short (length < 1e-5)
// for the division to be meaningful.
fn safe_normalize(v: vec3<f32>, fallback: vec3<f32>) -> vec3<f32> {
    let magnitude = length(v);
    // select(false_value, true_value, cond): pick the fallback for degenerate input.
    return select(v / magnitude, fallback, magnitude < 1e-5);
}
// Result of evaluating the wave field at one point.
struct WaveSample {
// Displacement to add to the undisplaced (flat) world position.
offset: vec3<f32>,
// Unit surface normal (falls back to +Y when degenerate).
normal: vec3<f32>,
// Sum of squared per-wave curvature terms; used to widen roughness for specular anti-aliasing.
detail: f32,
}
// Number of Gerstner waves summed in gerstner_waves.
const WAVE_COUNT: i32 = 6;
// Gravitational acceleration for the deep-water dispersion relation sqrt(g * k);
// presumably world units are meters — confirm.
const GRAVITY: f32 = 9.81;
// 2 * pi.
const TAU: f32 = 6.2831853;
// Cheap pseudo-random scalar in [0, 1) from a float seed (sin/fract hash).
fn hash1(n: f32) -> f32 {
    let scrambled = sin(n * 12.9898) * 43758.5453;
    return fract(scrambled);
}
// Cheap pseudo-random scalar in [0, 1) from a 2D seed (sin/fract hash).
fn hash2(p: vec2<f32>) -> f32 {
    let seed = dot(p, vec2<f32>(127.1, 311.7));
    let scrambled = sin(seed) * 43758.5453;
    return fract(scrambled);
}
// 2D value noise: hashes the four lattice corners around p and blends them
// with a smoothstep-shaped weight, giving a result in [0, 1).
fn value_noise(p: vec2<f32>) -> f32 {
    let lattice = floor(p);
    let frac_pos = fract(p);
    // 3t^2 - 2t^3 fade so the blend has zero slope at lattice points.
    let fade = frac_pos * frac_pos * (3.0 - 2.0 * frac_pos);

    let c00 = hash2(lattice + vec2<f32>(0.0, 0.0));
    let c10 = hash2(lattice + vec2<f32>(1.0, 0.0));
    let c01 = hash2(lattice + vec2<f32>(0.0, 1.0));
    let c11 = hash2(lattice + vec2<f32>(1.0, 1.0));

    let row0 = mix(c00, c10, fade.x);
    let row1 = mix(c01, c11, fade.x);
    return mix(row0, row1, fade.y);
}
// Fractal Brownian motion: four octaves of value_noise. Each octave doubles
// the sampling frequency and halves the weight (0.5, 0.25, 0.125, 0.0625).
fn fbm(p: vec2<f32>) -> f32 {
    var total = 0.0;
    var weight = 0.5;
    var coord = p;
    for (var octave = 0; octave < 4; octave++) {
        total += weight * value_noise(coord);
        coord *= 2.0;
        weight *= 0.5;
    }
    return total;
}
// Evaluates a sum of WAVE_COUNT Gerstner waves at world position world_xz and time t.
// Inputs from the uniform block: water.wave (x amplitude, y steepness clamped to [0,1],
// z base wavelength with a 0.4 floor, w speed) and water.wave_dir_time.xy (main direction).
// Returns the displacement for the flat position, the analytic normal, and `detail`,
// the sum over waves of (w^2 * amplitude)^2.
// Both vs_main and fs_main call this, so the evaluation order must stay identical
// for the two stages to agree.
fn gerstner_waves(world_xz: vec2<f32>, t: f32) -> WaveSample {
// Domain warp: shift the sample point by low-frequency fbm noise (frequency and
// strength scale with the wavelength) to break up the regular wave pattern.
let warp_freq = 1.0 / max(water.wave.z * 4.0, 1.0);
let warp_uv = world_xz * warp_freq + vec2<f32>(0.0, t * 0.03);
let warp = vec2<f32>(
fbm(warp_uv) - 0.5,
fbm(warp_uv + vec2<f32>(31.4, 17.2)) - 0.5,
);
// NOTE(review): the phase uses this warped position, but the normal below ignores the
// warp's spatial derivative — an approximation, not an exact surface normal.
let base = world_xz + warp * (max(water.wave.z, 0.4) * 0.8);
let base_angle = atan2(water.wave_dir_time.y, water.wave_dir_time.x);
let base_wavelength = max(water.wave.z, 0.4);
let base_amplitude = water.wave.x;
let steepness = clamp(water.wave.y, 0.0, 1.0);
let speed = water.wave.w;
var offset = vec3<f32>(0.0, 0.0, 0.0);
var slope = vec3<f32>(0.0, 0.0, 0.0);
var detail = 0.0;
for (var i = 0; i < WAVE_COUNT; i = i + 1) {
let fi = f32(i);
// Deterministic per-wave random values in [0, 1).
let r_dir = hash1(fi * 1.7 + 0.3);
let r_len = hash1(fi * 3.1 + 1.7);
let r_phase = hash1(fi * 2.3 + 4.1);
// Direction spread around the main angle; the spread (0.35 + 0.28*i rad) grows for higher waves.
let angle = base_angle + (r_dir * 2.0 - 1.0) * (0.35 + fi * 0.28);
let dir = vec2<f32>(cos(angle), sin(angle));
// Each successive wave is shorter (x0.64) and smaller (x0.52), with a shared jitter in [0.7, 1.3).
let jitter = 0.7 + r_len * 0.6;
let wavelength = base_wavelength * pow(0.64, fi) * jitter;
let amplitude = base_amplitude * pow(0.52, fi) * jitter;
// Angular wavenumber.
let w = TAU / wavelength;
// Per-wave steepness Q_i = Q / (w_i * A_i * N), so the summed horizontal pinch
// stays <= steepness (avoids loops); 1e-4 guards zero amplitude.
let q = steepness / (w * amplitude * f32(WAVE_COUNT) + 1e-4);
// Deep-water dispersion: angular speed sqrt(g * k), scaled by the speed parameter.
let phase = w * dot(dir, base) + t * sqrt(GRAVITY * w) * speed * 0.35 + r_phase * TAU;
let c = cos(phase);
let s = sin(phase);
let wa = w * amplitude;
// Gerstner displacement: horizontal pinch along dir, vertical sine.
offset.x += q * amplitude * dir.x * c;
offset.z += q * amplitude * dir.y * c;
offset.y += amplitude * s;
// Accumulated partial derivatives used to build the analytic normal.
slope.x += dir.x * wa * c;
slope.z += dir.y * wa * c;
slope.y += q * wa * s;
let curvature = w * wa;
detail += curvature * curvature;
}
var sample: WaveSample;
sample.offset = offset;
// Standard Gerstner normal: (-sum D.x*WA*C, 1 - sum Q*WA*S, -sum D.y*WA*C).
sample.normal = safe_normalize(
vec3<f32>(-slope.x, 1.0 - slope.y, -slope.z),
vec3<f32>(0.0, 1.0, 0.0),
);
sample.detail = detail;
return sample;
}
// Vertex stage: scales the plane coordinate by water.extents.xy and places it with
// water.model. It then displaces the point by the wave field at the current time
// (water.wave_dir_time.z) and projects it with water.render_vp. The undisplaced world
// position is forwarded so the fragment stage can re-evaluate the waves per pixel.
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
    let scaled_xz = in.plane * water.extents.xy;
    let rest = (water.model * vec4<f32>(scaled_xz.x, 0.0, scaled_xz.y, 1.0)).xyz;
    let displaced = rest + gerstner_waves(rest.xz, water.wave_dir_time.z).offset;
    return VertexOutput(water.render_vp * vec4<f32>(displaced, 1.0), rest);
}
// Unprojects a screen UV plus a raw depth-buffer value to a world-space position
// using water.inv_render_vp.
// uv: texture coordinates with (0,0) at the top-left; y is flipped into NDC.
// depth: value sampled from depth_tex, used directly as NDC z.
fn reconstruct_world(uv: vec2<f32>, depth: f32) -> vec3<f32> {
    let ndc = vec4<f32>(uv.x * 2.0 - 1.0, 1.0 - uv.y * 2.0, depth, 1.0);
    let world = water.inv_render_vp * ndc;
    // Keep w's sign but bound its magnitude away from zero. WGSL sign(0.0) is 0.0,
    // so `sign(w) * max(abs(w), eps)` would still divide by zero when w == 0;
    // treat w == 0 as positive instead.
    let w_sign = select(-1.0, 1.0, world.w >= 0.0);
    let w = w_sign * max(abs(world.w), 1e-5);
    return world.xyz / w;
}
// Fragment stage for the water surface. Per pixel it:
//  1. re-evaluates the wave field at the interpolated flat position;
//  2. refracts the scene behind the water, with depth-based absorption toward the
//     shallow/deep tints;
//  3. adds a prefiltered-environment reflection weighted by a clamped Schlick-style
//     Fresnel term and a Blinn-Phong sun highlight;
//  4. applies crest and shoreline foam and optional linear distance fog.
// It also writes velocity = previous-frame UV - current-frame UV for the motion-vector target.
@fragment
fn fs_main(in: VertexOutput) -> FragmentOutput {
    var out: FragmentOutput;
    let t = water.wave_dir_time.z;
    let wave = gerstner_waves(in.flat_world.xz, t);
    let world_pos = in.flat_world + wave.offset;
    // Translation row/column of the model matrix: rest height of the water plane.
    let plane_y = water.model[3].y;
    let cam_dist = length(world_pos - water.camera_pos.xyz);
    // Rough world-space size of one pixel at this distance. It widens the specular
    // lobe for distant / high-curvature waves (specular anti-aliasing).
    let pixel_world = cam_dist * 2.0 / max(water.resolution.y, 1.0);
    let normal = wave.normal;
    let view = safe_normalize(water.camera_pos.xyz - world_pos, vec3<f32>(0.0, 1.0, 0.0));
    let base_roughness = clamp(water.params.x, 0.02, 1.0);
    let roughness_aa = clamp(base_roughness + pixel_world * 0.6 + pixel_world * sqrt(wave.detail) * 0.26, 0.02, 1.0);

    // --- Refraction -------------------------------------------------------
    // in.clip.xy is in framebuffer pixels here; convert to [0,1] UV.
    let screen_uv = in.clip.xy / water.resolution.xy;
    let base_depth = textureSampleLevel(depth_tex, point_samp, screen_uv, 0);
    // Depth 0.0 is treated as "nothing behind the water" (presumably reversed-Z
    // far plane — confirm); the water depth then defaults to the 40-unit cap.
    var base_bottom = 40.0;
    if base_depth > 0.0 {
        base_bottom = clamp(plane_y - reconstruct_world(screen_uv, base_depth).y, 0.0, 40.0);
    }
    // Fade distortion out in very shallow water so edges do not smear.
    let distort_fade = clamp(base_bottom * 0.5, 0.0, 1.0);
    let distortion = normal.xz * water.params.w * distort_fade;
    var refract_uv = clamp(screen_uv + distortion, vec2<f32>(0.001), vec2<f32>(0.999));
    var floor_depth = textureSampleLevel(depth_tex, point_samp, refract_uv, 0);
    // If the distorted lookup lands on geometry above the water plane (i.e. in
    // front of the surface), fall back to the undistorted sample to avoid leaking it.
    if floor_depth > 0.0 {
        let probe = reconstruct_world(refract_uv, floor_depth);
        if probe.y > plane_y {
            refract_uv = screen_uv;
            floor_depth = base_depth;
        }
    }
    var bottom_depth = 40.0;
    var refracted_scene = water.deep_color.rgb;
    if floor_depth > 0.0 {
        let floor_world = reconstruct_world(refract_uv, floor_depth);
        bottom_depth = clamp(plane_y - floor_world.y, 0.0, 40.0);
        refracted_scene = textureSampleLevel(refraction_tex, linear_samp, refract_uv, 0.0).rgb;
    }
    // Exponential absorption: deeper water hides the scene and shifts toward deep_color.
    let depth_fade = max(water.shallow_color.w, 0.1);
    let absorb = 1.0 - exp(-bottom_depth / depth_fade);
    let water_tint = mix(water.shallow_color.rgb, water.deep_color.rgb, absorb);
    let refracted = mix(refracted_scene, water_tint, absorb);

    // --- Reflection -------------------------------------------------------
    // Schlick-style Fresnel with F0 = 0.02 and a configurable exponent, capped at 0.6.
    let n_dot_v = max(dot(normal, view), 0.0);
    let fresnel = min(0.02 + 0.98 * pow(1.0 - n_dot_v, water.params.y), 0.6);
    let reflect_dir = reflect(-view, normal);
    let max_reflection_lod = 4.0;
    let env = clamp(
        textureSampleLevel(prefiltered_env, ibl_samp, reflect_dir, roughness_aa * max_reflection_lod).rgb,
        vec3<f32>(0.0),
        vec3<f32>(65000.0),
    );
    // NOTE(review): water.params.z scales both the reflected radiance and the Fresnel
    // blend weight, so it effectively enters twice — confirm this is intended.
    let reflection = env * water.params.z;
    var color = mix(refracted, reflection, clamp(fresnel * water.params.z, 0.0, 1.0));

    // Blinn-Phong sun highlight; rougher surfaces get a broader, dimmer lobe.
    let sun = safe_normalize(water.sun_dir.xyz, vec3<f32>(0.0, 1.0, 0.0));
    let half_vec = safe_normalize(view + sun, normal);
    let shininess = mix(16.0, 180.0, 1.0 - roughness_aa);
    let ndh = max(dot(normal, half_vec), 0.0);
    let specular = pow(ndh, shininess) * water.extents.z;
    color = color + water.sun_color.rgb * specular;

    // --- Foam -------------------------------------------------------------
    // Crest foam near the top of the waves, edge foam where the water is shallow.
    let crest = clamp((world_pos.y - plane_y) / max(water.wave.x * 2.5, 0.01), 0.0, 1.0);
    let crest_foam = smoothstep(0.78, 1.0, crest);
    let edge_foam = 1.0 - smoothstep(0.0, max(water.deep_color.w, 0.001), bottom_depth);
    let foam = water.foam_color.w * max(crest_foam, edge_foam);
    color = mix(color, water.foam_color.rgb, foam);

    // --- Fog --------------------------------------------------------------
    // Linear fog between fog_params.x and fog_params.y, enabled when fog_color.w > 0.5.
    if water.fog_color.w > 0.5 {
        let fog_t = clamp(
            (cam_dist - water.fog_params.x) / max(water.fog_params.y - water.fog_params.x, 0.001),
            0.0,
            1.0,
        );
        color = mix(color, water.fog_color.rgb, fog_t);
    }
    color = clamp(color, vec3<f32>(0.0), vec3<f32>(64.0));
    out.color = vec4<f32>(color, 1.0);

    // --- Velocity ---------------------------------------------------------
    // Re-evaluate the waves at the previous frame's time so animated displacement
    // contributes to the motion vector, then project with unjittered matrices.
    let prev_wave = gerstner_waves(in.flat_world.xz, water.wave_dir_time.w);
    let prev_world_pos = in.flat_world + prev_wave.offset;
    let cur_clip = water.cur_vp * vec4<f32>(world_pos, 1.0);
    let prev_clip = water.prev_vp * vec4<f32>(prev_world_pos, 1.0);
    // Sign-preserving divisor bounded away from zero. WGSL sign(0.0) is 0.0, so
    // sign(w) * max(|w|, eps) would still divide by zero when w == 0; treat 0 as positive.
    let cur_w = select(-1.0, 1.0, cur_clip.w >= 0.0) * max(abs(cur_clip.w), 1e-5);
    let prev_w = select(-1.0, 1.0, prev_clip.w >= 0.0) * max(abs(prev_clip.w), 1e-5);
    let cur_ndc = cur_clip.xy / cur_w;
    let prev_ndc = prev_clip.xy / prev_w;
    // NDC -> UV with y flipped (top-left origin), matching screen_uv.
    let cur_uv = vec2<f32>(cur_ndc.x * 0.5 + 0.5, 0.5 - cur_ndc.y * 0.5);
    let prev_uv = vec2<f32>(prev_ndc.x * 0.5 + 0.5, 0.5 - prev_ndc.y * 0.5);
    var velocity = prev_uv - cur_uv;
    // Best-effort NaN scrub. WGSL allows implementations to assume NaNs never occur,
    // so this self-comparison may be optimized away; the bounded divisors above are
    // the real protection against division by zero.
    velocity = select(velocity, vec2<f32>(0.0), velocity != velocity);
    out.velocity = clamp(velocity, vec2<f32>(-2.0), vec2<f32>(2.0));
    return out;
}