// Per-step simulation parameters, bound as the uniform `params` (binding 5).
// The host fills this buffer, so field order and types define its byte layout.
struct SimParams {
// Applied to rest_positions of pinned vertices (rest w == 0) in integrate.
anchor_transform: mat4x4<f32>,
// xyz: steady wind direction, w: steady wind strength.
wind: vec4<f32>,
// x: gust amplitude, y: gust frequency (cycles per unit of w; multiplied by TAU),
// z: turbulence amplitude, w: current time.
gust: vec4<f32>,
// xyz: gravitational acceleration (w is not read by the code in this file).
gravity: vec4<f32>,
// x: time step, y: velocity damping factor, z: lower bound for y (ground height).
// w is not read by the code in this file.
integration: vec4<f32>,
// xy: rest spacing between adjacent grid columns (x) and rows (y),
// z: constraint stiffness scale used by solve (w not read here).
constraint: vec4<f32>,
// x: grid columns, y: grid rows, z: number of simulated vertices (w not read here).
counts: vec4<u32>,
}
// Parameters for the write_vertices pass, bound as the uniform `writeback` (binding 8).
struct WritebackParams {
// Transforms simulated (world-space) data into the target mesh's local space.
inverse_anchor: mat4x4<f32>,
// x: number of vertices to write, y: first vertex slot in target_vertices to write to.
// z/w are not read by the code in this file.
counts: vec4<u32>,
}
// Number of f32 values per vertex in target_vertices. write_vertices fills
// floats 0-5 (position, normal) and 10-13 (tangent xyz + w) of each vertex;
// the remaining floats are left untouched.
const VERTEX_STRIDE_FLOATS: u32 = 18u;
// Current simulated positions; w acts as an inverse mass (0 = pinned/immovable).
@group(0) @binding(0) var<storage, read_write> positions: array<vec4<f32>>;
// Positions from before the last integrate step; the difference gives Verlet velocity.
@group(0) @binding(1) var<storage, read_write> previous_positions: array<vec4<f32>>;
// Read-only snapshot consumed by solve (NOTE(review): presumably a copy of
// positions ping-ponged by the host between solver iterations — confirm).
@group(0) @binding(2) var<storage, read> source_positions: array<vec4<f32>>;
// Per-vertex unit normals written by update_normals; read by integrate and write_vertices.
@group(0) @binding(3) var<storage, read_write> normals: array<vec4<f32>>;
// Rest-pose positions; w == 0 marks a vertex pinned to params.anchor_transform.
@group(0) @binding(4) var<storage, read> rest_positions: array<vec4<f32>>;
@group(0) @binding(5) var<uniform> params: SimParams;
// Per-vertex unit tangents written by update_normals; w is forwarded to the
// output unchanged (presumably bitangent handedness — confirm with the mesh format).
@group(0) @binding(6) var<storage, read_write> tangents: array<vec4<f32>>;
// Interleaved output vertex data, VERTEX_STRIDE_FLOATS floats per vertex.
@group(0) @binding(7) var<storage, read_write> target_vertices: array<f32>;
@group(0) @binding(8) var<uniform> writeback: WritebackParams;
// 2 * pi, converts cycles to radians.
const TAU: f32 = 6.2831853;
// Samples the wind velocity at a world-space position.
// The result is the steady direction params.wind.xyz scaled by
// (params.wind.w + gust), plus a position-dependent turbulence vector scaled
// by params.gust.z. The gust oscillates between 0 and params.gust.x at
// params.gust.y cycles per unit of params.gust.w (the current time).
fn wind_velocity(position: vec3<f32>) -> vec3<f32> {
    let t = params.gust.w;

    // Gust: raised sine travelling across the x/y plane.
    let gust_angle = t * params.gust.y * TAU + position.x * 0.45 + position.y * 0.3;
    let gust_amount = params.gust.x * (0.5 + 0.5 * sin(gust_angle));

    // Turbulence: three decorrelated sinusoids sharing one swirl phase.
    let swirl = t * 1.7 + position.y * 1.3 + position.z * 2.1;
    let jitter_x = sin(swirl * 1.3 + position.x);
    let jitter_y = 0.35 * sin(swirl * 0.7 + position.y * 2.0);
    let jitter_z = cos(swirl + position.x * 1.4);
    let turbulence = vec3<f32>(jitter_x, jitter_y, jitter_z) * params.gust.z;

    let steady_speed = params.wind.w + gust_amount;
    return params.wind.xyz * steady_speed + turbulence;
}
// Verlet integration pass, one invocation per simulated vertex.
// Pinned vertices (rest w == 0) are snapped to anchor_transform * rest position
// in both position buffers, so they carry no velocity. Free vertices advance by
// their damped Verlet velocity plus gravity and the normal-projected wind, and
// are then kept at or above params.integration.z in y.
@compute @workgroup_size(256)
fn integrate(@builtin(global_invocation_id) id: vec3<u32>) {
    let vid = id.x;
    if vid >= params.counts.z {
        return;
    }

    let rest = rest_positions[vid];
    let is_pinned = rest.w == 0.0;
    if is_pinned {
        // Same value in both buffers => zero implied velocity next step.
        let world = params.anchor_transform * vec4<f32>(rest.xyz, 1.0);
        let snapped = vec4<f32>(world.xyz, 0.0);
        previous_positions[vid] = snapped;
        positions[vid] = snapped;
        return;
    }

    let dt = params.integration.x;
    let now = positions[vid];
    let before = previous_positions[vid];
    let carried_velocity = (now.xyz - before.xyz) * params.integration.y;

    // Use the stored normal when it is usable, otherwise fall back to +Z.
    let stored_normal = normals[vid].xyz;
    let stored_length = length(stored_normal);
    var facing = vec3<f32>(0.0, 0.0, 1.0);
    if stored_length > 0.001 {
        facing = stored_normal / stored_length;
    }

    // Only the component of the wind along the surface normal pushes the cloth.
    let wind_push = facing * dot(facing, wind_velocity(now.xyz));
    let total_acceleration = params.gravity.xyz + wind_push;
    var moved = now.xyz + carried_velocity + total_acceleration * dt * dt;
    moved.y = max(moved.y, params.integration.z);

    previous_positions[vid] = now;
    positions[vid] = vec4<f32>(moved, now.w);
}
// Distance-constraint pass over the cloth grid, one invocation per vertex.
// Reads only source_positions and writes only positions, so all invocations
// work from the same snapshot (Jacobi-style). Pinned vertices (w == 0) are
// copied through unchanged. A free vertex gathers corrections from up to 12
// grid neighbours: 4 edge-adjacent, 4 diagonal, and 4 two cells away along an
// axis. Each correction is weighted by the vertex's share of the pair's summed
// w and by a per-neighbour weight. The sum is divided by the total weight and
// scaled by params.constraint.z * 2.0. Finally y is kept at or above
// params.integration.z.
@compute @workgroup_size(256)
fn solve(@builtin(global_invocation_id) id: vec3<u32>) {
    let vid = id.x;
    if vid >= params.counts.z {
        return;
    }

    let me = source_positions[vid];
    if me.w == 0.0 {
        positions[vid] = me;
        return;
    }

    let grid_width = i32(params.counts.x);
    let grid_height = i32(params.counts.y);
    let cell_col = i32(vid) % grid_width;
    let cell_row = i32(vid) / grid_width;
    let cell_size = params.constraint.xy;

    // Neighbour stencil and matching weights:
    // [0..4) edge-adjacent, [4..8) diagonal, [8..12) two cells away.
    var stencil = array<vec2<i32>, 12>(
        vec2<i32>(-1, 0), vec2<i32>(1, 0), vec2<i32>(0, -1), vec2<i32>(0, 1),
        vec2<i32>(-1, -1), vec2<i32>(1, -1), vec2<i32>(-1, 1), vec2<i32>(1, 1),
        vec2<i32>(-2, 0), vec2<i32>(2, 0), vec2<i32>(0, -2), vec2<i32>(0, 2),
    );
    var stencil_weight = array<f32, 12>(
        1.0, 1.0, 1.0, 1.0,
        0.7, 0.7, 0.7, 0.7,
        0.35, 0.35, 0.35, 0.35,
    );

    var accumulated = vec3<f32>(0.0);
    var weight_sum = 0.0;
    for (var k = 0; k < 12; k++) {
        let d = stencil[k];
        let other_col = cell_col + d.x;
        let other_row = cell_row + d.y;
        let on_grid = other_col >= 0 && other_col < grid_width && other_row >= 0 && other_row < grid_height;
        if !on_grid {
            continue;
        }

        let other = source_positions[other_row * grid_width + other_col];
        let w_sum = me.w + other.w;
        if w_sum == 0.0 {
            continue;
        }

        let to_other = other.xyz - me.xyz;
        let separation = length(to_other);
        if separation < 0.000001 {
            continue;
        }

        let rest_separation = length(vec2<f32>(f32(d.x) * cell_size.x, f32(d.y) * cell_size.y));
        let relative_error = (separation - rest_separation) / separation;
        let neighbour_weight = stencil_weight[k];
        accumulated += to_other * relative_error * (me.w / w_sum) * neighbour_weight;
        weight_sum += neighbour_weight;
    }

    var result = me.xyz;
    if weight_sum > 0.0 {
        result += accumulated * (params.constraint.z / weight_sum) * 2.0;
    }
    result.y = max(result.y, params.integration.z);
    positions[vid] = vec4<f32>(result, me.w);
}
// Recomputes each vertex's normal and tangent from finite differences of the
// current grid positions. On the grid border the missing neighbour index is
// clamped to the vertex itself, which gives a one-sided difference.
// A degenerate normal falls back to +Z and a degenerate tangent to +X.
// Tangent w is always written as 1.0.
@compute @workgroup_size(256)
fn update_normals(@builtin(global_invocation_id) id: vec3<u32>) {
    let vid = id.x;
    if vid >= params.counts.z {
        return;
    }

    let width = i32(params.counts.x);
    let height = i32(params.counts.y);
    let col = i32(vid) % width;
    let row = i32(vid) / width;

    // Neighbour coordinates, clamped to the grid.
    let col_prev = max(col - 1, 0);
    let col_next = min(col + 1, width - 1);
    let row_prev = max(row - 1, 0);
    let row_next = min(row + 1, height - 1);

    let across = positions[row * width + col_next].xyz - positions[row * width + col_prev].xyz;
    let along = positions[row_next * width + col].xyz - positions[row_prev * width + col].xyz;

    var n = vec3<f32>(0.0, 0.0, 1.0);
    let raw_n = cross(along, across);
    let raw_n_len = length(raw_n);
    if raw_n_len > 0.000001 {
        n = raw_n / raw_n_len;
    }
    normals[vid] = vec4<f32>(n, 0.0);

    var t = vec3<f32>(1.0, 0.0, 0.0);
    let across_len = length(across);
    if across_len > 0.000001 {
        t = across / across_len;
    }
    tangents[vid] = vec4<f32>(t, 1.0);
}
// Copies simulated vertices into the interleaved target_vertices buffer,
// converting from simulation (world) space to the mesh's local space via
// writeback.inverse_anchor.
// Vertex `index` lands in vertex slot writeback.counts.y + index. Of its
// VERTEX_STRIDE_FLOATS floats, only 0-5 (position, normal) and 10-13
// (tangent xyz + w) are written; the rest are left untouched.
//
// Positions and tangents transform with the matrix itself. Normals must use
// its inverse-transpose to stay perpendicular to the surface when the anchor
// has non-uniform scale or shear. The inverse-transpose of the upper 3x3 is
// the cofactor matrix divided by the determinant, so after normalisation only
// the determinant's sign matters. For rigid or uniformly scaled transforms
// this matches transforming the normal by the matrix directly. A mirroring
// transform (negative determinant) also flips tangent w, so a bitangent
// rebuilt from cross(normal, tangent) * w keeps its orientation.
@compute @workgroup_size(256)
fn write_vertices(@builtin(global_invocation_id) id: vec3<u32>) {
    let index = id.x;
    if index >= writeback.counts.x {
        return;
    }

    // Upper 3x3 (linear part) of the world -> local transform.
    let basis = mat3x3<f32>(
        writeback.inverse_anchor[0].xyz,
        writeback.inverse_anchor[1].xyz,
        writeback.inverse_anchor[2].xyz,
    );
    // Cofactor matrix (columns): cofactor = det(basis) * transpose(inverse(basis)).
    let cofactor = mat3x3<f32>(
        cross(basis[1], basis[2]),
        cross(basis[2], basis[0]),
        cross(basis[0], basis[1]),
    );
    let det = dot(basis[0], cofactor[0]);
    // -1 when the transform mirrors geometry, +1 otherwise.
    let handedness = select(1.0, -1.0, det < 0.0);

    let local_position = writeback.inverse_anchor * vec4<f32>(positions[index].xyz, 1.0);
    let local_normal = normalize(cofactor * normals[index].xyz * handedness);
    let world_tangent = tangents[index];
    // Direction vectors ignore translation, so the 3x3 part is sufficient.
    let local_tangent = normalize(basis * world_tangent.xyz);

    let base = (writeback.counts.y + index) * VERTEX_STRIDE_FLOATS;
    target_vertices[base + 0u] = local_position.x;
    target_vertices[base + 1u] = local_position.y;
    target_vertices[base + 2u] = local_position.z;
    target_vertices[base + 3u] = local_normal.x;
    target_vertices[base + 4u] = local_normal.y;
    target_vertices[base + 5u] = local_normal.z;
    target_vertices[base + 10u] = local_tangent.x;
    target_vertices[base + 11u] = local_tangent.y;
    target_vertices[base + 12u] = local_tangent.z;
    target_vertices[base + 13u] = world_tangent.w * handedness;
}