use crate::ecs::cloth::components::Cloth;
use crate::ecs::world::{CLOTH, GLOBAL_TRANSFORM, RENDER_MESH, World};
use crate::render::wgpu::rendergraph::{PassExecutionContext, PassNode, SubGraphRunCommand};
use freecs::Entity;
use nalgebra_glm::{Mat4, Vec3};
use std::collections::HashMap;
use wgpu::util::DeviceExt;
/// Particles covered by one compute workgroup; dispatch counts are the particle
/// count divided by this value, rounded up.
/// NOTE(review): presumably has to match the `@workgroup_size` declared in
/// `cloth_simulate.wgsl` — confirm against the shader.
const WORKGROUP_SIZE: u32 = 256;
/// Upper bound applied to the frame delta time before it reaches the simulation.
const MAX_DELTA_TIME: f32 = 1.0 / 30.0;
/// Ground height uploaded for cloths whose `ground_height` is `None`.
/// NOTE(review): presumably low enough that the shader's ground test never
/// triggers — confirm against the shader.
const DISABLED_GROUND_HEIGHT: f32 = -1.0e30;
/// Per-cloth simulation parameters uploaded to the simulation uniform buffer
/// (compute binding 5). Lane meanings below are those written by `prepare`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct SimParams {
// Global transform of the cloth entity.
anchor_transform: [[f32; 4]; 4],
// xyz: normalized wind direction (+Z when the wind vector is near zero),
// w: wind strength scaled by the cloth's wind response.
wind: [f32; 4],
// x: gust strength * wind response, y: gust frequency,
// z: turbulence * wind response, w: accumulated simulation time.
gust: [f32; 4],
// xyz: gravity from the cloth config, w: written as 0.
gravity: [f32; 4],
// x: clamped delta time divided by the substep count, y: damping,
// z: ground height (or DISABLED_GROUND_HEIGHT), w: written as 0.
integration: [f32; 4],
// x: rest spacing between columns, y: rest spacing between rows,
// z: stiffness, w: written as 0.
constraint: [f32; 4],
// x: columns, y: rows, z: columns * rows, w: written as 0.
counts: [u32; 4],
}
/// Parameters for one vertex writeback dispatch, uploaded to the writeback
/// uniform buffer (writeback binding 8).
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct WritebackParams {
// Inverse of the cloth entity's global transform.
inverse_anchor: [[f32; 4]; 4],
// x: particle count (columns * rows), y: vertex offset of the write target,
// z and w: written as 0.
counts: [u32; 4],
}
/// A buffer that the cloth writeback dispatch writes simulated vertices into.
pub struct ClothWriteTarget {
/// Destination buffer; bound whole as read-write storage at writeback binding 7.
pub buffer: wgpu::Buffer,
/// Compared against the cached value to decide whether the writeback bind
/// group has to be rebuilt.
/// NOTE(review): presumably bumped by the owner whenever `buffer` is
/// reallocated — confirm with the code that fills `set_write_targets`.
pub buffer_generation: u64,
/// Forwarded to the writeback shader as `counts[1]`; also part of the
/// rebuild check.
pub vertex_offset: u32,
}
/// Cached GPU objects for one write target of one cloth.
struct WritebackState {
// Generation of the target buffer this bind group was built against.
buffer_generation: u64,
// Vertex offset of the target this bind group was built against.
vertex_offset: u32,
// Uniform buffer holding this target's `WritebackParams`; rewritten every `prepare`.
params_buffer: wgpu::Buffer,
// Bind group tying the cloth's positions/normals/tangents to the target buffer.
bind_group: wgpu::BindGroup,
}
/// GPU-side simulation state for a single cloth entity.
struct ClothState {
// Snapshot of the component the state was built from; a differing component
// triggers a full rebuild in `prepare`.
config: Cloth,
// Grid width in particles, clamped to at least 2.
columns: u32,
// Grid height in particles, clamped to at least 2.
rows: u32,
// Simulation substeps per frame, at least 1.
substeps: u32,
// Constraint solver iterations per substep: at least 1, rounded up to an even count.
solver_iterations: u32,
// Workgroups dispatched per pass: particle count / WORKGROUP_SIZE, rounded up.
workgroups: u32,
// Particle positions; binding 0 of the primary compute group and of every writeback group.
positions: wgpu::Buffer,
// Per-particle normals (binding 3).
normals: wgpu::Buffer,
// Per-particle tangents (binding 6).
tangents: wgpu::Buffer,
// Uniform buffer holding this cloth's `SimParams`.
sim_params_buffer: wgpu::Buffer,
// Compute bind group with `positions` at binding 0 and the scratch buffer at binding 2.
compute_group_primary: wgpu::BindGroup,
// Compute bind group with the scratch buffer at binding 0 and `positions` at binding 2.
compute_group_secondary: wgpu::BindGroup,
// One entry per write target currently bound for this cloth.
writebacks: Vec<WritebackState>,
}
/// Render-graph pass that simulates every cloth entity with compute shaders
/// and writes the result into the registered vertex write targets.
pub struct ClothPass {
// Layout shared by the integrate, solve and normals pipelines.
compute_layout: wgpu::BindGroupLayout,
// Layout used by the writeback pipeline.
writeback_layout: wgpu::BindGroupLayout,
// Pipeline for the `integrate` entry point.
integrate_pipeline: wgpu::ComputePipeline,
// Pipeline for the `solve` entry point.
solve_pipeline: wgpu::ComputePipeline,
// Pipeline for the `update_normals` entry point.
normals_pipeline: wgpu::ComputePipeline,
// Pipeline for the `write_vertices` entry point.
writeback_pipeline: wgpu::ComputePipeline,
// Device limit captured at construction; larger write targets are skipped.
max_storage_binding_size: u64,
// Simulation state per cloth entity.
states: HashMap<Entity, ClothState>,
// Write targets keyed by render mesh name.
write_targets: HashMap<String, Vec<ClothWriteTarget>>,
// Sum of the clamped frame delta times; uploaded as `gust[3]`.
time: f32,
}
/// Builds the rest-pose particle grid of `cloth` in anchor-local space.
///
/// Particles are emitted row by row. Each entry is `[x, y, z, inverse_mass]`:
/// `x` runs from `-width / 2` to `+width / 2` across the columns, `y` runs from
/// `0` down to `-height` across the rows, and `z` is always `0`. Pinned
/// particles (selected by `cloth.pinning`) get an inverse mass of `0.0`, all
/// others `1.0`.
///
/// A grid with fewer than two columns or rows has no spacing to divide the
/// extent by; the divisor is clamped to one so such grids produce finite
/// coordinates (or an empty vector) instead of an integer underflow or NaN.
fn build_rest_positions(cloth: &Cloth, columns: u32, rows: u32) -> Vec<[f32; 4]> {
    use crate::ecs::cloth::components::ClothPinning;
    // `n` particles span `n - 1` gaps; never divide by fewer than one gap.
    let spacing_x = cloth.width / columns.saturating_sub(1).max(1) as f32;
    let spacing_y = cloth.height / rows.saturating_sub(1).max(1) as f32;
    // Multiply in `usize` so the capacity cannot overflow `u32`.
    let particle_count = (columns as usize) * (rows as usize);
    let mut rest = Vec::with_capacity(particle_count);
    for row in 0..rows {
        for column in 0..columns {
            // `columns - 1` is safe here: the loop body only runs when `columns >= 1`.
            let pinned = match cloth.pinning {
                ClothPinning::TopRow => row == 0,
                ClothPinning::TopCorners => row == 0 && (column == 0 || column == columns - 1),
                ClothPinning::None => false,
            };
            let inverse_mass = if pinned { 0.0 } else { 1.0 };
            rest.push([
                -cloth.width / 2.0 + column as f32 * spacing_x,
                -(row as f32 * spacing_y),
                0.0,
                inverse_mass,
            ]);
        }
    }
    rest
}
impl ClothPass {
pub fn new(device: &wgpu::Device) -> Self {
let storage_read_write = wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
};
let storage_read = wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
};
let uniform = wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
};
let compute_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Cloth Compute Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 5,
visibility: wgpu::ShaderStages::COMPUTE,
ty: uniform,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 6,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
],
});
let writeback_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Cloth Writeback Bind Group Layout"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 6,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 7,
visibility: wgpu::ShaderStages::COMPUTE,
ty: storage_read_write,
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 8,
visibility: wgpu::ShaderStages::COMPUTE,
ty: uniform,
count: None,
},
],
});
let simulate_shader = crate::render::wgpu::shader_compose::compile_wgsl(
device,
"Cloth Simulate Shader",
include_str!("../../shaders/cloth_simulate.wgsl"),
);
let compute_pipeline_layout =
device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Cloth Compute Pipeline Layout"),
bind_group_layouts: &[Some(&compute_layout)],
immediate_size: 0,
});
let writeback_pipeline_layout =
device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Cloth Writeback Pipeline Layout"),
bind_group_layouts: &[Some(&writeback_layout)],
immediate_size: 0,
});
let integrate_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Cloth Integrate Pipeline"),
layout: Some(&compute_pipeline_layout),
module: &simulate_shader,
entry_point: Some("integrate"),
compilation_options: Default::default(),
cache: None,
});
let solve_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Cloth Solve Pipeline"),
layout: Some(&compute_pipeline_layout),
module: &simulate_shader,
entry_point: Some("solve"),
compilation_options: Default::default(),
cache: None,
});
let normals_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Cloth Normals Pipeline"),
layout: Some(&compute_pipeline_layout),
module: &simulate_shader,
entry_point: Some("update_normals"),
compilation_options: Default::default(),
cache: None,
});
let writeback_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Cloth Writeback Pipeline"),
layout: Some(&writeback_pipeline_layout),
module: &simulate_shader,
entry_point: Some("write_vertices"),
compilation_options: Default::default(),
cache: None,
});
Self {
compute_layout,
writeback_layout,
integrate_pipeline,
solve_pipeline,
normals_pipeline,
writeback_pipeline,
max_storage_binding_size: device.limits().max_storage_buffer_binding_size,
states: HashMap::new(),
write_targets: HashMap::new(),
time: 0.0,
}
}
/// Replaces the complete set of vertex write targets, keyed by render mesh
/// name. The new targets are picked up by the next `prepare`, which rebuilds
/// any writeback bind group whose generation or vertex offset no longer matches.
pub fn set_write_targets(&mut self, targets: HashMap<String, Vec<ClothWriteTarget>>) {
self.write_targets = targets;
}
/// Allocates all GPU resources for one cloth and returns its fresh state.
///
/// The grid is clamped to at least 2x2 particles. Positions start at the rest
/// pose transformed by `anchor`, with the inverse mass carried in `w`. Two
/// compute bind groups are created that swap the position and scratch buffers
/// between bindings 0 and 2. The returned state has no writebacks yet.
fn build_state(&self, device: &wgpu::Device, config: &Cloth, anchor: &Mat4) -> ClothState {
    /// Binds an entire buffer at `binding`.
    fn whole_buffer(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
        wgpu::BindGroupEntry {
            binding,
            resource: buffer.as_entire_binding(),
        }
    }

    let columns = config.columns.max(2);
    let rows = config.rows.max(2);
    let particle_count = columns * rows;

    let rest_positions = build_rest_positions(config, columns, rows);
    let mut initial_positions: Vec<[f32; 4]> = Vec::with_capacity(rest_positions.len());
    for rest in &rest_positions {
        let world = anchor * nalgebra_glm::Vec4::new(rest[0], rest[1], rest[2], 1.0);
        // `w` keeps the inverse mass from the rest pose.
        initial_positions.push([world.x, world.y, world.z, rest[3]]);
    }
    let initial_normals = vec![[0.0_f32, 0.0, 1.0, 0.0]; particle_count as usize];
    let initial_tangents = vec![[1.0_f32, 0.0, 0.0, 1.0]; particle_count as usize];

    // Every simulation buffer is a storage-only buffer initialised from CPU data.
    let storage_buffer = |label: &str, contents: &[u8]| {
        device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some(label),
            contents,
            usage: wgpu::BufferUsages::STORAGE,
        })
    };
    let positions = storage_buffer("Cloth Positions", bytemuck::cast_slice(&initial_positions));
    let scratch_positions = storage_buffer(
        "Cloth Scratch Positions",
        bytemuck::cast_slice(&initial_positions),
    );
    let previous_positions = storage_buffer(
        "Cloth Previous Positions",
        bytemuck::cast_slice(&initial_positions),
    );
    let normals = storage_buffer("Cloth Normals", bytemuck::cast_slice(&initial_normals));
    let tangents = storage_buffer("Cloth Tangents", bytemuck::cast_slice(&initial_tangents));
    let rest_buffer = storage_buffer(
        "Cloth Rest Positions",
        bytemuck::cast_slice(&rest_positions),
    );
    let sim_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Cloth Sim Params"),
        size: std::mem::size_of::<SimParams>() as u64,
        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
        mapped_at_creation: false,
    });

    // The two groups differ only in which position buffer sits at binding 0
    // (`destination`) and which at binding 2 (`source`).
    let compute_group = |label: &str, destination: &wgpu::Buffer, source: &wgpu::Buffer| {
        device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(label),
            layout: &self.compute_layout,
            entries: &[
                whole_buffer(0, destination),
                whole_buffer(1, &previous_positions),
                whole_buffer(2, source),
                whole_buffer(3, &normals),
                whole_buffer(4, &rest_buffer),
                whole_buffer(5, &sim_params_buffer),
                whole_buffer(6, &tangents),
            ],
        })
    };
    let compute_group_primary = compute_group(
        "Cloth Compute Bind Group Primary",
        &positions,
        &scratch_positions,
    );
    let compute_group_secondary = compute_group(
        "Cloth Compute Bind Group Secondary",
        &scratch_positions,
        &positions,
    );

    let substeps = config.substeps.max(1);
    // Rounded up to an even count (at least 2).
    let solver_iterations = config.solver_iterations.max(1).next_multiple_of(2);
    let workgroups = particle_count.div_ceil(WORKGROUP_SIZE);
    ClothState {
        config: config.clone(),
        columns,
        rows,
        substeps,
        solver_iterations,
        workgroups,
        positions,
        normals,
        tangents,
        sim_params_buffer,
        compute_group_primary,
        compute_group_secondary,
        writebacks: Vec::new(),
    }
}
/// Brings the writeback bind groups of `entity` in line with the write targets
/// registered under `mesh_name`.
///
/// * No targets registered for the mesh: the cloth's writebacks are cleared.
/// * No state for the entity: nothing happens.
/// * Otherwise the cached writebacks are compared, in order, against the
///   targets that can actually be bound (buffer size within the device's
///   storage binding limit). If generation and vertex offset match for all of
///   them, nothing is rebuilt; if not, every writeback is recreated.
///
/// Targets larger than the storage binding limit are skipped with a warning.
/// Comparing against the bindable targets only (rather than all targets)
/// keeps one skipped target from forcing the remaining params buffers and bind
/// groups to be recreated on every call. When nothing at all is bound while
/// targets exist, the rebuild path is still taken so the warning stays
/// visible; that path allocates nothing for rejected targets.
fn sync_writebacks(&mut self, device: &wgpu::Device, entity: Entity, mesh_name: &str) {
    let Some(targets) = self.write_targets.get(mesh_name) else {
        if let Some(state) = self.states.get_mut(&entity) {
            state.writebacks.clear();
        }
        return;
    };
    let Some(state) = self.states.get(&entity) else {
        return;
    };

    let max_binding_size = self.max_storage_binding_size;
    let bound = state
        .writebacks
        .iter()
        .map(|writeback| (writeback.buffer_generation, writeback.vertex_offset));
    let bindable = targets
        .iter()
        .filter(|target| target.buffer.size() <= max_binding_size)
        .map(|target| (target.buffer_generation, target.vertex_offset));
    // Either a freshly built state or every target was rejected last time.
    let nothing_bound = state.writebacks.is_empty() && !targets.is_empty();
    if !nothing_bound && bound.eq(bindable) {
        return;
    }

    let mut writebacks = Vec::with_capacity(targets.len());
    for target in targets {
        if target.buffer.size() > max_binding_size {
            tracing::warn!(
                "cloth '{}' skipped a vertex write target: buffer size {} exceeds the \
                 device's maximum storage binding size {}",
                mesh_name,
                target.buffer.size(),
                self.max_storage_binding_size
            );
            continue;
        }
        // Contents are uploaded by `prepare` every frame.
        let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Cloth Writeback Params"),
            size: std::mem::size_of::<WritebackParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Cloth Writeback Bind Group"),
            layout: &self.writeback_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: state.positions.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: state.normals.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 6,
                    resource: state.tangents.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 7,
                    resource: target.buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 8,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });
        writebacks.push(WritebackState {
            buffer_generation: target.buffer_generation,
            vertex_offset: target.vertex_offset,
            params_buffer,
            bind_group,
        });
    }
    if let Some(state) = self.states.get_mut(&entity) {
        state.writebacks = writebacks;
    }
}
}
impl PassNode<World> for ClothPass {
/// Identifier under which this pass is registered in the render graph.
fn name(&self) -> &str {
"cloth_pass"
}
/// The pass declares no read-only graph resources.
fn reads(&self) -> Vec<&str> {
    Vec::new()
}
/// The pass declares no write-only graph resources.
fn writes(&self) -> Vec<&str> {
    Vec::new()
}
/// Declares the `color` and `depth` graph resources as read-write.
/// NOTE(review): the visible pass code only records compute dispatches and
/// never touches these attachments; presumably they are declared to order the
/// pass relative to the passes that render the cloth — confirm.
fn reads_writes(&self) -> Vec<&str> {
vec!["color", "depth"]
}
/// Synchronises GPU cloth state with the ECS and uploads this frame's
/// parameters.
///
/// Steps, in order:
/// 1. Clamp the frame delta time to `[0, MAX_DELTA_TIME]` and advance the
///    accumulated simulation time.
/// 2. Collect every entity that has cloth, global transform and render mesh
///    components.
/// 3. Drop states of entities that no longer match, (re)build states whose
///    cloth config is new or changed, and sync their writeback bind groups.
/// 4. Upload `SimParams` for every cloth and `WritebackParams` for every
///    writeback.
fn prepare(&mut self, device: &wgpu::Device, queue: &wgpu::Queue, world: &World) {
    let delta_time = world
        .resources
        .window
        .timing
        .delta_time
        .clamp(0.0, MAX_DELTA_TIME);
    self.time += delta_time;

    let wind = world.resources.wind;
    let wind_length = nalgebra_glm::length(&wind.direction);
    // A near-zero wind vector cannot be normalised; fall back to +Z.
    let wind_direction = if wind_length > 0.0001 {
        wind.direction / wind_length
    } else {
        Vec3::new(0.0, 0.0, 1.0)
    };

    let mut gathered: Vec<(Entity, Cloth, Mat4, String)> = Vec::new();
    world
        .core
        .query()
        .with(CLOTH | GLOBAL_TRANSFORM | RENDER_MESH)
        .iter(|entity, table, index| {
            gathered.push((
                entity,
                table.cloth[index].clone(),
                table.global_transform[index].0,
                table.render_mesh[index].name.clone(),
            ));
        });

    // Hash the live entities once so pruning stale states is a constant-time
    // lookup per state instead of a scan over all gathered cloths.
    let live: std::collections::HashSet<Entity> =
        gathered.iter().map(|(entity, ..)| *entity).collect();
    self.states.retain(|entity, _| live.contains(entity));

    for (entity, config, anchor, mesh_name) in &gathered {
        let needs_rebuild = self
            .states
            .get(entity)
            .is_none_or(|state| state.config != *config);
        if needs_rebuild {
            let state = self.build_state(device, config, anchor);
            self.states.insert(*entity, state);
        }
        self.sync_writebacks(device, *entity, mesh_name);
    }

    for (entity, config, anchor, _) in &gathered {
        let Some(state) = self.states.get(entity) else {
            continue;
        };
        let particle_count = state.columns * state.rows;
        let spacing_x = config.width / (state.columns - 1) as f32;
        let spacing_y = config.height / (state.rows - 1) as f32;
        let ground_height = config.ground_height.unwrap_or(DISABLED_GROUND_HEIGHT);
        let sim_params = SimParams {
            anchor_transform: (*anchor).into(),
            wind: [
                wind_direction.x,
                wind_direction.y,
                wind_direction.z,
                wind.strength * config.wind_response,
            ],
            gust: [
                wind.gust_strength * config.wind_response,
                wind.gust_frequency,
                wind.turbulence * config.wind_response,
                self.time,
            ],
            gravity: [config.gravity.x, config.gravity.y, config.gravity.z, 0.0],
            integration: [
                // Each substep advances an equal share of the frame time.
                delta_time / state.substeps as f32,
                config.damping,
                ground_height,
                0.0,
            ],
            constraint: [spacing_x, spacing_y, config.stiffness, 0.0],
            counts: [state.columns, state.rows, particle_count, 0],
        };
        queue.write_buffer(&state.sim_params_buffer, 0, bytemuck::bytes_of(&sim_params));

        let inverse_anchor = nalgebra_glm::inverse(anchor);
        for writeback in &state.writebacks {
            let writeback_params = WritebackParams {
                inverse_anchor: inverse_anchor.into(),
                counts: [particle_count, writeback.vertex_offset, 0, 0],
            };
            queue.write_buffer(
                &writeback.params_buffer,
                0,
                bytemuck::bytes_of(&writeback_params),
            );
        }
    }
}
/// Records the simulation of every cloth into one compute pass.
///
/// Per cloth and per substep: one `integrate` dispatch on the primary bind
/// group, then `solver_iterations` `solve` dispatches that alternate between
/// the secondary group (even iterations) and the primary group (odd
/// iterations). After all substeps the normals pipeline runs on the primary
/// group, followed by one writeback dispatch per write target. Nothing is
/// recorded when there are no cloth states.
fn execute<'r, 'e>(
    &mut self,
    context: PassExecutionContext<'r, 'e, World>,
) -> crate::render::wgpu::rendergraph::Result<Vec<SubGraphRunCommand<'r>>> {
    if !self.states.is_empty() {
        // The pass is dropped at the end of this block, before `context` is consumed.
        let mut pass = context
            .encoder
            .begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Cloth Simulation Pass"),
                timestamp_writes: None,
            });
        for state in self.states.values() {
            let primary = &state.compute_group_primary;
            let secondary = &state.compute_group_secondary;
            let workgroups = state.workgroups;

            for _substep in 0..state.substeps {
                pass.set_pipeline(&self.integrate_pipeline);
                pass.set_bind_group(0, primary, &[]);
                pass.dispatch_workgroups(workgroups, 1, 1);

                pass.set_pipeline(&self.solve_pipeline);
                for iteration in 0..state.solver_iterations {
                    // Ping-pong: even iterations use the secondary group, odd the primary.
                    let group = if iteration % 2 == 0 {
                        secondary
                    } else {
                        primary
                    };
                    pass.set_bind_group(0, group, &[]);
                    pass.dispatch_workgroups(workgroups, 1, 1);
                }
            }

            pass.set_pipeline(&self.normals_pipeline);
            pass.set_bind_group(0, primary, &[]);
            pass.dispatch_workgroups(workgroups, 1, 1);

            pass.set_pipeline(&self.writeback_pipeline);
            for writeback in &state.writebacks {
                pass.set_bind_group(0, &writeback.bind_group, &[]);
                pass.dispatch_workgroups(workgroups, 1, 1);
            }
        }
    }
    Ok(context.into_sub_graph_commands())
}
}